# massspecgym/data/data_module.py
import typing as T
import pandas as pd
import numpy as np
import pytorch_lightning as pl
import massspecgym.utils as utils
from pathlib import Path
from typing import Optional
from torch.utils.data.dataset import Subset
from torch.utils.data.dataloader import DataLoader
from massspecgym.data.datasets import MassSpecDataset
class MassSpecDataModule(pl.LightningDataModule):
    """
    Data module containing a mass spectrometry dataset. This class is responsible for loading, splitting, and wrapping
    the dataset into data loaders according to pre-defined train, validation, test folds.
    """

    def __init__(
        self,
        dataset: MassSpecDataset,
        batch_size: int,
        num_workers: int = 0,
        persistent_workers: bool = True,
        split_pth: Optional[Path] = None,
        **kwargs
    ):
        """
        Args:
            dataset (MassSpecDataset): Dataset to be split into folds and wrapped into data loaders.
            batch_size (int): Batch size used by all three data loaders.
            num_workers (int, optional): Number of data-loading worker processes. Default is 0.
            persistent_workers (bool, optional): Keep worker processes alive between epochs.
                Forced to False when `num_workers == 0`, since DataLoader forbids that
                combination. Default is True.
            split_pth (Optional[Path], optional): Path to a .tsv file with columns "identifier" and "fold",
                corresponding to dataset item IDs, and "fold", containg "train", "val", "test"
                values. Default is None, in which case the split from the `dataset` is used.
        """
        super().__init__(**kwargs)
        self.dataset = dataset
        self.split_pth = split_pth
        self.batch_size = batch_size
        self.num_workers = num_workers
        # DataLoader raises ValueError if persistent_workers=True while num_workers == 0.
        self.persistent_workers = persistent_workers if num_workers > 0 else False

    def _load_split(self) -> None:
        """Load and validate the identifier -> fold mapping into `self.split` (a pandas Series).

        Raises:
            ValueError: If the split file has unexpected columns, its identifiers do not
                match the dataset, or a fold value is not one of "train"/"val"/"test".
        """
        if self.split_pth is None:
            self.split = self.dataset.metadata[["identifier", "fold"]]
        else:
            # NOTE: custom split is not tested
            self.split = pd.read_csv(self.split_pth, sep="\t")
            if set(self.split.columns) != {"identifier", "fold"}:
                # Fixed message: the required columns are "identifier" and "fold", not "id".
                raise ValueError('Split file must contain "identifier" and "fold" columns.')
            self.split["identifier"] = self.split["identifier"].astype(str)
            if set(self.dataset.metadata["identifier"]) != set(self.split["identifier"]):
                raise ValueError(
                    "Dataset item IDs must match the IDs in the split file."
                )
        self.split = self.split.set_index("identifier")["fold"]
        if not set(self.split) <= {"train", "val", "test"}:
            raise ValueError(
                '"fold" column must contain only "train", "val", or "test" values.'
            )

    def prepare_data(self):
        """Load the train/val/test split.

        NOTE: Under multi-process (DDP) training, Lightning calls `prepare_data` on rank 0
        only and attributes set here are not replicated to other processes, so `setup`
        re-derives the split when it is missing.
        """
        self._load_split()

    def setup(self, stage=None):
        """Create per-fold `Subset` views of the dataset for the requested stage."""
        # Guard against `prepare_data` not having run in this process (e.g. non-zero DDP rank).
        if not hasattr(self, "split"):
            self._load_split()
        # Fold label for each dataset item, aligned to the dataset's metadata order.
        split_mask = self.split.loc[self.dataset.metadata["identifier"]].values
        if stage in ("fit", None):
            self.train_dataset = Subset(
                self.dataset, np.where(split_mask == "train")[0]
            )
        # "validate" is the stage Lightning passes for `trainer.validate(...)`;
        # the original only handled "fit"/None, leaving `val_dataset` undefined there.
        if stage in ("fit", "validate", None):
            self.val_dataset = Subset(self.dataset, np.where(split_mask == "val")[0])
        if stage in ("test", None):
            self.test_dataset = Subset(self.dataset, np.where(split_mask == "test")[0])

    def _dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader over `dataset` with the module-wide loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=False,
            collate_fn=self.dataset.collate_fn,
        )

    def train_dataloader(self):
        """Return the training DataLoader (shuffled)."""
        return self._dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation DataLoader (not shuffled)."""
        return self._dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test DataLoader (not shuffled)."""
        return self._dataloader(self.test_dataset, shuffle=False)