| from typing import Optional | |
| from omegaconf import DictConfig | |
| import pytorch_lightning as L | |
| import torch.utils.data as torchdata | |
| from .torch import collate, worker_init_fn | |
def get_dataset(name):
    """Resolve a dataset identifier to its data-module class.

    Each dataset's subpackage is imported lazily inside the matching branch,
    so only the requested backend is loaded.

    Args:
        name: one of ``"mapillary"``, ``"nuscenes"`` or ``"kitti"``.

    Returns:
        The data-module class for ``name`` (not an instance).

    Raises:
        NotImplementedError: if ``name`` is not one of the known datasets.
    """
    if name == "mapillary":
        from .mapillary.data_module import MapillaryDataModule as module_cls
    elif name == "nuscenes":
        from .nuscenes.data_module import NuScenesData as module_cls
    elif name == "kitti":
        from .kitti.data_module import BEVKitti360Data as module_cls
    else:
        raise NotImplementedError(f"Dataset {name} not implemented.")
    return module_cls
class GenericDataModule(L.LightningDataModule):
    """Lightning data module that delegates to a dataset-specific backend.

    The backend class is selected by ``cfg.name`` through ``get_dataset`` and
    constructed with the same ``cfg``. This wrapper forwards ``prepare_data``
    and ``setup`` to the backend and builds a ``DataLoader`` for each stage
    from the per-stage settings found in ``cfg["loading"][stage]``.
    """

    def __init__(self, cfg: DictConfig):
        """Store ``cfg`` and instantiate the backend selected by ``cfg.name``.

        Raises:
            NotImplementedError: if ``cfg.name`` is not a known dataset
                (propagated from ``get_dataset``).
        """
        super().__init__()
        self.cfg = cfg
        self.data_module = get_dataset(cfg.name)(cfg)

    def prepare_data(self) -> None:
        """Forward to the backend's ``prepare_data``."""
        self.data_module.prepare_data()

    def setup(self, stage: Optional[str] = None):
        """Forward to the backend's ``setup`` with the same ``stage``."""
        self.data_module.setup(stage)

    def dataloader(
        self,
        stage: str,
        shuffle: bool = False,
        num_workers: Optional[int] = None,
        sampler: Optional[torchdata.Sampler] = None,
    ) -> torchdata.DataLoader:
        """Build a ``DataLoader`` over the backend's dataset for ``stage``.

        Args:
            stage: key used both for ``self.data_module.dataset(stage)`` and
                for ``self.cfg["loading"][stage]``, which must provide
                ``batch_size`` and ``num_workers``.
            shuffle: shuffle the samples. The ``"train"`` stage is also
                shuffled by default, but only when no ``sampler`` is given.
            num_workers: overrides ``cfg["loading"][stage]["num_workers"]``
                when not ``None``.
            sampler: optional custom sampler. When given, it alone decides the
                sample order for the ``"train"`` stage. Combining it with an
                explicit ``shuffle=True`` is still rejected by ``DataLoader``.

        Returns:
            The configured ``torch.utils.data.DataLoader``.
        """
        dataset = self.data_module.dataset(stage)
        cfg = self.cfg["loading"][stage]
        if num_workers is None:
            num_workers = cfg["num_workers"]
        # DataLoader raises if shuffle=True is combined with a sampler, so the
        # implicit "train" shuffling only applies when no sampler is provided.
        # An explicit shuffle=True is forwarded unchanged so a real conflict
        # still surfaces as an error instead of being silently ignored.
        do_shuffle = shuffle or (stage == "train" and sampler is None)
        return torchdata.DataLoader(
            dataset,
            batch_size=cfg["batch_size"],
            num_workers=num_workers,
            shuffle=do_shuffle,
            pin_memory=True,
            # persistent_workers is only valid when worker processes exist.
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
            sampler=sampler,
        )

    def train_dataloader(self, **kwargs):
        """Return the ``"train"`` loader; ``kwargs`` go to ``dataloader``."""
        return self.dataloader("train", **kwargs)

    def val_dataloader(self, **kwargs):
        """Return the ``"val"`` loader; ``kwargs`` go to ``dataloader``."""
        return self.dataloader("val", **kwargs)

    def test_dataloader(self, **kwargs):
        """Return the ``"test"`` loader; ``kwargs`` go to ``dataloader``."""
        return self.dataloader("test", **kwargs)