|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import io |
|
|
import os |
|
|
import json |
|
|
import zipfile |
|
|
|
|
|
import lmdb |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
from torchvision.datasets import ImageFolder, VisionDataset |
|
|
|
|
|
|
|
|
|
|
|
def center_crop_arr(pil_image, image_size):
    """
    Resize and center-crop a PIL image to ``image_size`` x ``image_size``.

    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Steps:
      1. While the shorter side is at least ``2 * image_size``, halve both
         sides with a box filter.
      2. Bicubic-resize so the shorter side becomes ``image_size`` (each side
         is scaled by the same ratio and rounded).
      3. Cut the central ``image_size`` x ``image_size`` square.

    Returns a new ``PIL.Image.Image``.
    """
    img = pil_image

    # Cheap, anti-aliased power-of-two downsampling before the final resize.
    while min(img.size) >= 2 * image_size:
        half_w, half_h = (side // 2 for side in img.size)
        img = img.resize((half_w, half_h), resample=Image.BOX)

    ratio = image_size / min(img.size)
    target_size = tuple(round(side * ratio) for side in img.size)
    img = img.resize(target_size, resample=Image.BICUBIC)

    pixels = np.array(img)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    cropped = pixels[top:top + image_size, left:left + image_size]
    return Image.fromarray(cropped)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def lmdb_loader(path, lmdb_data, resolution):
    """Load one image from an LMDB environment and center-crop it.

    Args:
        path: image path used as the LMDB key (encoded as ASCII).
        lmdb_data: an open LMDB environment.
        resolution: side length of the square output image.

    Returns:
        The RGB image, center-cropped to ``resolution`` x ``resolution``
        (a ``PIL.Image.Image``, as produced by ``center_crop_arr``).

    Raises:
        KeyError: if ``path`` is not stored in the LMDB.
    """
    with lmdb_data.begin(write=False, buffers=True) as txn:
        bytedata = txn.get(path.encode('ascii'))
        if bytedata is None:
            raise KeyError('Image {!r} not found in LMDB'.format(path))
        # With buffers=True, bytedata is a view into the LMDB memory map that is
        # only valid while this transaction is open, so decode it here.
        img = Image.open(io.BytesIO(bytedata)).convert('RGB')

    return center_crop_arr(img, resolution)
|
|
|
|
|
|
|
|
def _write_image_lmdb(lmdb_path, samples, map_size=int(1e12)):
    """Store the raw file bytes of every ``(path, class_index)`` sample in an LMDB.

    Keys are the ASCII-encoded paths and values are the unmodified file
    contents. All puts happen in a single write transaction. The environment
    is closed afterwards, even if reading a file fails.
    """
    env = lmdb.open(lmdb_path, map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            for path, _class_index in samples:
                with open(path, 'rb') as f:
                    txn.put(path.encode('ascii'), f.read())
    finally:
        # Release the writer before ImageLMDB opens its own read-only handle on
        # the same path; LMDB advises against opening one environment twice in
        # a single process.
        env.close()


def _load_pickled_imagefolder(pt_path):
    """Load the ``ImageFolder`` object cached with ``torch.save``.

    NOTE: this unpickles arbitrary Python objects; only load cache files you
    created yourself.
    """
    try:
        # torch>=2.6 defaults to weights_only=True, which refuses non-tensor
        # objects such as an ImageFolder instance.
        return torch.load(pt_path, weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only argument.
        return torch.load(pt_path)


def imagenet_lmdb_dataset(
        root,
        transform=None, target_transform=None,
        resolution=256):
    """
    Build (or reuse a cached) LMDB-backed ImageNet dataset.

    You can create this dataloader using:
    train_data = imagenet_lmdb_dataset(traindir, transform=train_transform)
    valid_data = imagenet_lmdb_dataset(validdir, transform=val_transform)

    Two cache files are kept next to ``root``:
      * ``<root>_faster_imagefolder.lmdb``: LMDB mapping each image path to
        its raw file bytes;
      * ``<root>_faster_imagefolder.lmdb.pt``: the pickled ``ImageFolder``
        index (samples, classes, class_to_idx).
    The ``.pt`` file is written only after the LMDB is complete, so it doubles
    as the "cache is valid" marker. An interrupted build is redone on the next
    call instead of being silently reused.

    Args:
        root: ImageFolder-style directory (one subdirectory per class).
        transform, target_transform: forwarded to ``ImageLMDB``.
        resolution: center-crop size used by ``ImageLMDB``.

    Returns:
        An ``ImageLMDB`` dataset.
    """
    # Strip every trailing slash so the cache files sit beside the directory.
    root = root.rstrip('/')
    pt_path = root + '_faster_imagefolder.lmdb.pt'
    lmdb_path = root + '_faster_imagefolder.lmdb'

    if os.path.isfile(pt_path) and os.path.isdir(lmdb_path):
        print('Loading pt {} and lmdb {}'.format(pt_path, lmdb_path))
        data_set = _load_pickled_imagefolder(pt_path)
    else:
        data_set = ImageFolder(
            root, None, None, None)
        print('Building lmdb to {}'.format(lmdb_path))
        _write_image_lmdb(lmdb_path, data_set.imgs)
        torch.save(data_set, pt_path, pickle_protocol=4)
        print('Saving pt to {}'.format(pt_path))

    lmdb_dataset = ImageLMDB(lmdb_path, transform, target_transform, resolution,
                             data_set.imgs, data_set.class_to_idx, data_set.classes)
    return lmdb_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageLMDB(VisionDataset):
    """
    A data loader for ImageNet LMDB dataset, which is faster than the original ImageFolder.

    Encoded image bytes are read from the LMDB environment at ``root``. Keys
    are the ASCII-encoded paths from ``samples``, a sequence of
    ``(path, target)`` pairs. Each image is decoded to RGB, center-cropped to
    ``resolution`` x ``resolution``, then passed through ``transform``; the
    target goes through ``target_transform``.

    The environment is opened lazily on the first ``__getitem__``. The LMDB
    handles are excluded from the pickled state, so a pickled copy (e.g. a
    spawned DataLoader worker) reopens its own.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 resolution=256, samples=None, class_to_idx=None, classes=None):
        super().__init__(root, transform=transform,
                         target_transform=target_transform)
        self.root = root
        self.resolution = resolution
        self.samples = samples
        self.class_to_idx = class_to_idx
        self.classes = classes

    def __getitem__(self, index: int):
        """Return ``(image, target)`` for ``samples[index]``.

        Raises:
            KeyError: if the sample's path is not stored in the LMDB.
        """
        path, target = self.samples[index]

        if not hasattr(self, 'txn'):
            self.open_db()
        bytedata = self.txn.get(path.encode('ascii'))
        if bytedata is None:
            raise KeyError('Image {!r} not found in LMDB {}'.format(path, self.root))
        img = Image.open(io.BytesIO(bytedata)).convert('RGB')
        arr = center_crop_arr(img, self.resolution)
        if self.transform is not None:
            arr = self.transform(arr)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return arr, target

    def __len__(self) -> int:
        return len(self.samples)

    def __getstate__(self):
        """Pickle everything except the LMDB handles, which are not picklable.

        They are reopened lazily by ``__getitem__`` in the unpickled copy.
        """
        state = self.__dict__.copy()
        state.pop('txn', None)
        state.pop('env', None)
        return state

    def open_db(self):
        """Open ``root`` read-only and start a long-lived read transaction."""
        self.env = lmdb.open(self.root, readonly=True, max_readers=256, lock=False, readahead=False, meminit=False)
        # buffers=True returns zero-copy views; they stay valid because this
        # transaction is kept open for the lifetime of the dataset.
        self.txn = self.env.begin(write=False, buffers=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Dataset(torch.utils.data.Dataset):
    """Base class for array-backed datasets that yield ``(image, cond)`` pairs.

    Subclasses implement ``_load_raw_data(raw_idx)`` returning
    ``(image, cond)``. ``image`` must be an ``np.ndarray``. ``cond`` is either
    a label or a list whose first element is a label. Integer labels are
    expanded to one-hot float32 vectors of length ``label_dim``.

    Args:
        name: dataset name exposed via ``.name``.
        raw_shape: ``[num_items, *image_shape]`` of the underlying data.
        max_size: if given and smaller than ``num_items``, keep a random
            subset of that many indices (seeded by ``random_seed``), in
            sorted order.
        label_dim: length of the one-hot label vectors.
        xflip: if True, append ``raw_idx + num_items`` for every kept index.
            The subclass's ``_load_raw_data`` must accept those offset indices.
        random_seed: seed for the ``max_size`` subset selection.
    """

    def __init__(self,
                 name,
                 raw_shape,
                 max_size=None,
                 label_dim=1000,
                 xflip=False,
                 random_seed=0,
                 ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._label_dim = label_dim
        self._label_shape = None

        # Indices into the raw data, optionally subsampled.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Offset copies of the indices; subclasses interpret them as flipped items.
        if xflip:
            self._raw_idx = np.concatenate([self._raw_idx, self._raw_idx + self._raw_shape[0]])

    def close(self):
        """Release resources held by the dataset (no-op in the base class)."""
        pass

    def _load_raw_data(self, raw_idx):
        """Return ``(image, cond)`` for ``raw_idx``; implemented by subclasses."""
        raise NotImplementedError

    def __getstate__(self):
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup in a finalizer: ignore ordinary errors, but do not
        # swallow KeyboardInterrupt / SystemExit the way a bare except would.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return ``(image_copy, cond)`` with integer labels converted to one-hot."""
        raw_idx = self._raw_idx[idx]
        image, cond = self._load_raw_data(raw_idx)
        assert isinstance(image, np.ndarray)
        if isinstance(cond, list):
            cond[0] = self._get_onehot(cond[0])
        else:
            cond = self._get_onehot(cond)
        return image.copy(), cond

    def _get_onehot(self, label):
        """Return ``label`` as an ndarray copy, expanding integer labels to one-hot.

        Integer labels are Python ints, NumPy integer scalars of any width, or
        integer arrays. They become a float32 array of shape ``label_shape``
        with ones at the given index/indices. Any other ndarray is returned
        as a copy.
        """
        is_integer = isinstance(label, (int, np.integer)) or (
            isinstance(label, np.ndarray) and np.issubdtype(label.dtype, np.integer))
        if is_integer:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        assert isinstance(label, np.ndarray)
        return label.copy()

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        """``raw_shape`` without the leading item count."""
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        """First image dimension; requires a 3-D image shape."""
        assert len(self.image_shape) == 3
        return self.image_shape[0]

    @property
    def resolution(self):
        """Side length of the (square) image; requires a 3-D image shape."""
        assert len(self.image_shape) == 3
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        if self._label_shape is None:
            self._label_shape = [self._label_dim]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageNetLatentDataset(Dataset):
    """Dataset of precomputed latents (and optional features) stored in LMDB.

    Reads ``<path>/<split>``, an LMDB containing:
      * ``length``: the item count, as a UTF-8 decimal string;
      * ``z-<idx>``: float32 latent bytes, reshaped to
        ``[-1, resolution, resolution]``;
      * ``y-<idx>``: the integer class label, as a UTF-8 decimal string.
    When ``feat_path`` is an existing directory, ``<feat_path>/<split>`` is
    opened too. It must hold ``feat-<idx>`` (``feat_dim`` float32 values) and
    ``y-<idx>`` entries that agree with the main LMDB. ``cond`` is then
    ``[label, feat]`` instead of just ``label``.

    LMDB handles are dropped when pickled, and closed by ``close()``. Either
    way they are reopened lazily on the next item access.
    """

    def __init__(self,
                 path,
                 resolution=32,
                 num_channels=4,
                 split='train',
                 feat_path=None,
                 feat_dim=0,
                 **super_kwargs,
                 ):
        self._path = os.path.join(path, split)
        self.feat_dim = feat_dim
        # Define every handle up front so close()/__del__ work even if opening
        # fails part-way through construction.
        self.env = None
        self.txn = None
        self.feat_env = None
        self.feat_txn = None
        self._feat_path = None

        self.open_lmdb()
        if feat_path is not None and os.path.isdir(feat_path):
            assert self.feat_dim > 0
            self._feat_path = os.path.join(feat_path, split)
            self.open_feat_lmdb()

        length_bytes = self.txn.get('length'.encode('utf-8'))
        if length_bytes is None:
            raise IOError('LMDB at {} has no "length" entry'.format(self._path))
        length = int(length_bytes.decode('utf-8'))
        name = os.path.basename(path)
        raw_shape = [length, num_channels, resolution, resolution]

        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    def open_lmdb(self):
        """Open the latent LMDB read-only and start a read transaction."""
        self.env = lmdb.open(self._path, readonly=True, lock=False, create=False)
        self.txn = self.env.begin(write=False)

    def open_feat_lmdb(self):
        """Open the feature LMDB read-only and start a read transaction."""
        self.feat_env = lmdb.open(self._feat_path, readonly=True, lock=False, create=False)
        self.feat_txn = self.feat_env.begin(write=False)

    def _ensure_open(self):
        """(Re)open any LMDB handle that is missing, e.g. after unpickling or close()."""
        if getattr(self, 'txn', None) is None:
            self.open_lmdb()
        if getattr(self, '_feat_path', None) is not None and getattr(self, 'feat_txn', None) is None:
            self.open_feat_lmdb()

    def __getstate__(self):
        """Pickle without LMDB handles (not picklable); they reopen lazily."""
        return dict(super().__getstate__(), env=None, txn=None, feat_env=None, feat_txn=None)

    def _load_raw_data(self, idx):
        """Return ``(z, cond)`` for item ``idx``.

        Raises:
            KeyError: if an entry for ``idx`` is missing from an LMDB.
        """
        self._ensure_open()

        z_bytes = self.txn.get(f'z-{str(idx)}'.encode('utf-8'))
        y_bytes = self.txn.get(f'y-{str(idx)}'.encode('utf-8'))
        if z_bytes is None or y_bytes is None:
            raise KeyError('Latent entry {} missing from LMDB {}'.format(idx, self._path))
        z = np.frombuffer(z_bytes, dtype=np.float32).reshape([-1, self.resolution, self.resolution]).copy()
        y = int(y_bytes.decode('utf-8'))

        cond = y
        if self.feat_txn is not None:
            feat_bytes = self.feat_txn.get(f'feat-{str(idx)}'.encode('utf-8'))
            feat_y_bytes = self.feat_txn.get(f'y-{str(idx)}'.encode('utf-8'))
            if feat_bytes is None or feat_y_bytes is None:
                raise KeyError('Feature entry {} missing from LMDB {}'.format(idx, self._feat_path))
            feat = np.frombuffer(feat_bytes, dtype=np.float32).reshape([self.feat_dim]).copy()
            feat_y = int(feat_y_bytes.decode('utf-8'))
            assert y == feat_y, 'Ordering mismatch between txn and feat_txn!'
            cond = [y, feat]

        return z, cond

    def close(self):
        """Close both LMDB environments (if open) and clear all handles."""
        env = getattr(self, 'env', None)
        feat_env = getattr(self, 'feat_env', None)
        # Clear the handles first so later accesses reopen instead of using
        # transactions that the close below invalidates.
        self.env = None
        self.txn = None
        self.feat_env = None
        self.feat_txn = None
        try:
            if env is not None:
                env.close()
        finally:
            if feat_env is not None:
                feat_env.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageFolderDataset(Dataset):
    """Image dataset read from a directory tree or a ``.zip`` archive.

    Every file whose extension is in PIL's registered ``Image.EXTENSION`` is
    an item, ordered by sorted relative path. Images are returned as
    ``uint8`` arrays in channel-first layout; 2-D (grayscale) images get a
    single channel axis. With ``use_labels=True``, labels come from a
    ``dataset.json`` whose ``"labels"`` entry is a mapping (or list of pairs)
    from relative path, with ``/`` separators, to label.

    With ``xflip=True`` (forwarded to ``Dataset``), indices ``>= num_images``
    return the horizontally mirrored copy of image ``idx - num_images``.

    Args:
        path: directory or ``.zip`` file containing the images.
        resolution: if given, the first image's height and width must both
            equal it; otherwise ``IOError`` is raised.
        use_labels: whether to read labels from ``dataset.json``.
        **super_kwargs: forwarded to ``Dataset`` (e.g. ``max_size``, ``xflip``).
    """

    def __init__(self,
                 path,
                 resolution=None,
                 use_labels=False,
                 **super_kwargs,
                 ):
        self._path = path
        self._zipfile = None
        self._raw_labels = None
        self._use_labels = use_labels

        if os.path.isdir(self._path):
            self._type = 'dir'
            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in
                                os.walk(self._path) for fname in files}
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        Image.init()
        self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in Image.EXTENSION)
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')

        name = os.path.splitext(os.path.basename(self._path))[0]
        # Shape is [num_images, C, H, W], taken from the first image.
        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        """Lower-cased file extension including the dot (e.g. ``'.png'``)."""
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        """Return the archive handle, opening it on first use."""
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open ``fname`` (relative to ``path``) for binary reading."""
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def close(self):
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        # The open ZipFile is not picklable; it is reopened lazily.
        return dict(super().__getstate__(), _zipfile=None)

    def _load_raw_data(self, raw_idx):
        """Return ``(image, label)``, mirroring indices added by ``xflip``.

        ``Dataset`` encodes the flipped copy of item ``i`` as
        ``i + num_images``; map it back and flip the width axis.
        """
        num_images = len(self._image_fnames)
        flip = raw_idx >= num_images
        if flip:
            raw_idx = raw_idx - num_images
        image = self._load_raw_image(raw_idx)
        assert image.dtype == np.uint8
        if flip:
            image = image[:, :, ::-1]
        label = self._get_raw_labels()[raw_idx]
        return image, label

    def _load_raw_image(self, raw_idx):
        """Decode image ``raw_idx`` into a channel-first ndarray."""
        fname = self._image_fnames[raw_idx]
        with self._open_file(fname) as f:
            image = np.array(Image.open(f))
        if image.ndim == 2:
            image = image[:, :, np.newaxis]  # HW -> HWC
        image = image.transpose(2, 0, 1)  # HWC -> CHW
        return image

    def _get_raw_labels(self):
        """Return (and cache) per-image labels.

        Without labels this is an empty ``[N, 0]`` float32 array.
        """
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
        return self._raw_labels

    def _load_raw_labels(self):
        """Read labels from ``dataset.json``; return None if absent or null.

        1-D label arrays become int64 (class indices); 2-D ones become float32.
        """
        fname = 'dataset.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        labels = dict(labels)
        # Keys use '/' separators regardless of the host OS.
        labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
        labels = np.array(labels)
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        return labels
|
|
|
|
|
|
|
|
|