| import cv2 | |
| import os | |
| from os.path import join as osp | |
| import numpy | |
| import torch.utils.data | |
class Dataset(torch.utils.data.Dataset):
    """Dataset of paired images (A/B) with a per-pair label image.

    Expected directory layout::

        <file_root>/<mode>/A/<name>      first ("pre") image
        <file_root>/<mode>/B/<name>      second ("post") image
        <file_root>/<mode>/label/<name>  label image

    The sample names are taken from the ``A`` directory. The same name is
    assumed to exist under ``B`` and ``label``; that is only checked when a
    sample is actually loaded.
    """

    def __init__(self, file_root='data/', mode='train', transform=None):
        """Index the files of one split.

        Args:
            file_root: root directory holding one sub-directory per split.
            mode: split name (e.g. 'train'); selects ``file_root/mode``.
            transform: optional callable invoked as ``transform(img, label)``
                that must return a two-element sequence ``(img, label)``.
        """
        # os.listdir order is arbitrary and platform dependent; sort so a
        # given index refers to the same sample on every run and machine.
        self.file_list = sorted(os.listdir(osp(file_root, mode, 'A')))
        self.pre_images = [osp(file_root, mode, 'A', x) for x in self.file_list]
        self.post_images = [osp(file_root, mode, 'B', x) for x in self.file_list]
        self.gts = [osp(file_root, mode, 'label', x) for x in self.file_list]
        self.transform = transform

    def __len__(self):
        """Return the number of samples (files found in the ``A`` directory)."""
        return len(self.pre_images)

    @staticmethod
    def _imread(path, *flags):
        """Read ``path`` with ``cv2.imread``, raising instead of returning None.

        ``flags`` is forwarded to ``cv2.imread`` unchanged (e.g.
        ``cv2.IMREAD_GRAYSCALE``); with no flags cv2's default colour read
        is used.
        """
        image = cv2.imread(path, *flags)
        if image is None:
            # cv2.imread reports missing/unreadable/unsupported files by
            # returning None rather than raising; surface it with the path.
            raise OSError(f"could not read image: {path}")
        return image

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, label)``.

        ``img`` is the A and B images concatenated along axis 2 (the channel
        axis). With cv2's default colour read, each image has 3 channels, so
        ``img`` has 6. ``label`` is the label file read as single-channel
        grayscale. If a transform was supplied, its outputs are returned.

        Raises:
            OSError: if any of the three files cannot be read.
        """
        pre_image = self._imread(self.pre_images[idx])
        label = self._imread(self.gts[idx], cv2.IMREAD_GRAYSCALE)
        post_image = self._imread(self.post_images[idx])
        img = numpy.concatenate((pre_image, post_image), axis=2)
        if self.transform:
            img, label = self.transform(img, label)
        return img, label

    def get_img_info(self, idx):
        """Return ``{"height": h, "width": w}`` of the A image of sample ``idx``.

        Raises:
            OSError: if the A image cannot be read.
        """
        height, width = self._imread(self.pre_images[idx]).shape[:2]
        return {"height": height, "width": width}