import typing as T
import collections
from enum import Enum
from abc import ABC, abstractmethod
from pathlib import Path
import torch
import pytorch_lightning as pl
from torchmetrics import Metric, SumMetric
from massspecgym.utils import ReturnScalarBootStrapper
class Stage(Enum):
TRAIN = 'train'
VAL = 'val'
TEST = 'test'
NONE = 'none'
def to_pref(self) -> str:
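        """Return the stage as a metric-name prefix, e.g. "val_" for VAL and "" for NONE."""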
return f"{self.value}_" if self != Stage.NONE else ""
class MassSpecGymModel(pl.LightningModule, ABC):
def __init__(
self,
lr: float = 1e-4,
weight_decay: float = 0.0,
log_only_loss_at_stages: T.Sequence[Stage | str] = (),
bootstrap_metrics: bool = True,
df_test_path: T.Optional[str | Path] = None,
*args,
**kwargs
):
super().__init__()
self.save_hyperparameters()
        # Set up metric logging
self.log_only_loss_at_stages = [
Stage(s) if isinstance(s, str) else s for s in log_only_loss_at_stages
]
self.bootstrap_metrics = bootstrap_metrics
        # Init a dictionary of dataframe columns, where rows correspond to samples
        # (used to construct a test dataframe with per-sample predictions and metrics)
self.df_test_path = Path(df_test_path) if df_test_path is not None else None
self.df_test = collections.defaultdict(list)
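        # e.g. a hypothetical child class might accumulate {"identifier": [...], "score": [...]}
        # here via `_update_df_test`, yielding one dataframe row per test sample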
@abstractmethod
def step(
self, batch: dict, stage: Stage = Stage.NONE
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError(
"Method `step` must be implemented in the model-specific child class."
)
def training_step(
        self, batch: dict, batch_idx: int
) -> tuple[torch.Tensor, torch.Tensor]:
return self.step(batch, stage=Stage.TRAIN)
def validation_step(
        self, batch: dict, batch_idx: int
) -> tuple[torch.Tensor, torch.Tensor]:
return self.step(batch, stage=Stage.VAL)
def test_step(
        self, batch: dict, batch_idx: int
) -> tuple[torch.Tensor, torch.Tensor]:
return self.step(batch, stage=Stage.TEST)
@abstractmethod
def on_batch_end(
self, outputs: T.Any, batch: dict, batch_idx: int, stage: Stage
) -> None:
"""
        Method to be called at the end of each batch. It should be implemented by a
        task-specific child class and contain the evaluation logic required for the task.
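
        A minimal sketch of a child-class implementation (the metric name and the contents
        of `outputs` are illustrative assumptions; they depend on the task-specific `step`):

            def on_batch_end(self, outputs, batch, batch_idx, stage):
                if stage in self.log_only_loss_at_stages:
                    return
                self._update_metric(
                    f"{stage.to_pref()}my_metric",   # hypothetical metric name
                    SumMetric,                       # any torchmetrics Metric class
                    (outputs["my_values"],),         # assumes `step` returns this key
                    batch_size=len(outputs["my_values"]),
                )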
"""
raise NotImplementedError(
"Method `on_batch_end` must be implemented in the task-specific child class."
)
def on_train_batch_end(self, *args, **kwargs):
return self.on_batch_end(*args, **kwargs, stage=Stage.TRAIN)
def on_validation_batch_end(self, *args, **kwargs):
return self.on_batch_end(*args, **kwargs, stage=Stage.VAL)
def on_test_batch_end(self, *args, **kwargs):
return self.on_batch_end(*args, **kwargs, stage=Stage.TEST)
def configure_optimizers(self):
return torch.optim.Adam(
self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
)
def get_checkpoint_monitors(self) -> list[dict]:
monitors = [
{"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
]
return monitors
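
    # A sketch of how a training script might consume these monitors (assumes the
    # standard pytorch_lightning callbacks; not part of this class):
    #
    #     callbacks = []
    #     for m in model.get_checkpoint_monitors():
    #         callbacks.append(pl.callbacks.ModelCheckpoint(monitor=m["monitor"], mode=m["mode"]))
    #         if m.get("early_stopping"):
    #             callbacks.append(pl.callbacks.EarlyStopping(monitor=m["monitor"], mode=m["mode"]))
    #     trainer = pl.Trainer(callbacks=callbacks)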
def _update_metric(
self,
name: str,
metric_class: type[Metric],
update_args: T.Any,
batch_size: T.Optional[int] = None,
prog_bar: bool = False,
metric_kwargs: T.Optional[dict] = None,
log: bool = True,
log_n_samples: bool = False,
bootstrap: bool = False,
num_bootstraps: int = 100
) -> None:
"""
        This method enables updating and logging metrics without instantiating them in advance
        in the __init__ method. The metrics are aggregated over batches and logged at the end
        of the epoch. If a metric does not exist yet, it is instantiated and added as an
        attribute of the model.
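
        Example (illustrative only; `BinaryAccuracy`, `preds`, and `targets` are assumptions,
        not part of this module):

            from torchmetrics.classification import BinaryAccuracy
            self._update_metric(
                name=f"{stage.to_pref()}accuracy",
                metric_class=BinaryAccuracy,
                update_args=(preds, targets),
                batch_size=len(preds),
            )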
"""
# Process arguments
bootstrap = bootstrap and self.bootstrap_metrics
# Log total number of samples (useful for debugging)
if log_n_samples:
self._update_metric(
name=name + "_n_samples",
metric_class=SumMetric,
update_args=(len(update_args[0]),),
batch_size=1,
)
        # Init the metric if it does not exist yet
if hasattr(self, name):
metric = getattr(self, name)
else:
if metric_kwargs is None:
metric_kwargs = dict()
metric = metric_class(**metric_kwargs)
metric = metric.to(self.device)
setattr(self, name, metric)
# Update
metric(*update_args)
# Log
if log:
self.log(
name,
metric,
prog_bar=prog_bar,
batch_size=batch_size,
on_step=False,
on_epoch=True,
add_dataloader_idx=False,
metric_attribute=name # Suggested by a torchmetrics error
)
# Bootstrap
if bootstrap:
            def _bootstrapped_metric_class(**metric_kwargs):
metric = metric_class(**metric_kwargs)
return ReturnScalarBootStrapper(metric, std=True, num_bootstraps=num_bootstraps)
self._update_metric(
name=name + "_std",
                metric_class=_bootstrapped_metric_class,
update_args=update_args,
batch_size=batch_size,
metric_kwargs=metric_kwargs,
)
def _update_df_test(self, dct: dict) -> None:
for col, vals in dct.items():
if isinstance(vals, torch.Tensor):
vals = vals.tolist()
self.df_test[col].extend(vals)
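

if __name__ == "__main__":
    # Minimal, illustrative smoke test of the base class. The dummy model and random
    # data below are assumptions for demonstration only; they are not part of MassSpecGym.
    from torch.utils.data import DataLoader

    class DummyModel(MassSpecGymModel):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.linear = torch.nn.Linear(8, 1)

        def step(self, batch: dict, stage: Stage = Stage.NONE) -> dict:
            # Assumes a batch dict with "x" features and "y" regression targets.
            y_hat = self.linear(batch["x"]).squeeze(-1)
            loss = torch.nn.functional.mse_loss(y_hat, batch["y"])
            self.log(f"{stage.to_pref()}loss", loss, batch_size=len(batch["y"]))
            # Lightning accepts a dict containing the "loss" key from `training_step`.
            return {"loss": loss, "y_hat": y_hat}

        def on_batch_end(self, outputs, batch, batch_idx, stage: Stage) -> None:
            # Collect per-sample predictions for the test dataframe.
            if stage == Stage.TEST:
                self._update_df_test({"y_hat": outputs["y_hat"]})

    train_loader = DataLoader(
        [{"x": torch.randn(8), "y": torch.randn(())} for _ in range(16)], batch_size=4
    )
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(DummyModel(), train_loader)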