import typing as T
import collections
from enum import Enum
from abc import ABC, abstractmethod
from pathlib import Path

import torch
import pytorch_lightning as pl
from torchmetrics import Metric, SumMetric
from massspecgym.utils import ReturnScalarBootStrapper


class Stage(Enum):
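    """Identifier for the current phase of a run, used to prefix logged metric names."""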

    TRAIN = 'train'
    VAL = 'val'
    TEST = 'test'
    NONE = 'none'

    def to_pref(self) -> str:
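        """Return the metric-name prefix for this stage, e.g. ``Stage.VAL.to_pref() == "val_"``."""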
        return f"{self.value}_" if self != Stage.NONE else ""


class MassSpecGymModel(pl.LightningModule, ABC):
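    """
    Abstract base class for MassSpecGym models. It wires the Lightning train/val/test hooks to a
    single task-agnostic `step` method and provides on-the-fly metric logging, optionally with
    bootstrapped standard deviations, via `_update_metric`.
    """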

    def __init__(
        self,
        lr: float = 1e-4,
        weight_decay: float = 0.0,
        log_only_loss_at_stages: T.Sequence[Stage | str] = (),
        bootstrap_metrics: bool = True,
        df_test_path: T.Optional[str | Path] = None,
        *args,
        **kwargs
    ):
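        """
        Args:
            lr: Learning rate for the Adam optimizer.
            weight_decay: Weight decay for the Adam optimizer.
            log_only_loss_at_stages: Stages at which only the loss, and no other metrics,
                should be logged.
            bootstrap_metrics: Whether metrics may additionally be logged with bootstrapped
                standard deviations (see `_update_metric`).
            df_test_path: Optional path to save a test dataframe with per-sample predictions
                and metrics.
        """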
        super().__init__()
        self.save_hyperparameters()

        # Set up metric logging
        self.log_only_loss_at_stages = [
            Stage(s) if isinstance(s, str) else s for s in log_only_loss_at_stages
        ]
        self.bootstrap_metrics = bootstrap_metrics

        # Init dictionary to store dataframe columns where rows correspond to samples
        # (for constructing test dataframe with predictions and metrics for each sample)
        self.df_test_path = Path(df_test_path) if df_test_path is not None else None
        self.df_test = collections.defaultdict(list)

    @abstractmethod
    def step(
        self, batch: dict, stage: Stage = Stage.NONE
    ) -> dict:
        raise NotImplementedError(
            "Method `step` must be implemented in the model-specific child class."
        )

    def training_step(
        self, batch: dict, batch_idx: int
    ) -> dict:
        return self.step(batch, stage=Stage.TRAIN)

    def validation_step(
        self, batch: dict, batch_idx: int
    ) -> dict:
        return self.step(batch, stage=Stage.VAL)

    def test_step(
        self, batch: dict, batch_idx: int
    ) -> dict:
        return self.step(batch, stage=Stage.TEST)

    @abstractmethod
    def on_batch_end(
        self, outputs: T.Any, batch: dict, batch_idx: int, stage: Stage
    ) -> None:
        """
        Method called at the end of each batch. It should be implemented by a task-specific child
        class and contain the evaluation logic necessary for the task.
        """
        raise NotImplementedError(
            "Method `on_batch_end` must be implemented in the task-specific child class."
        )

    def on_train_batch_end(self, *args, **kwargs):
        return self.on_batch_end(*args, **kwargs, stage=Stage.TRAIN)

    def on_validation_batch_end(self, *args, **kwargs):
        return self.on_batch_end(*args, **kwargs, stage=Stage.VAL)

    def on_test_batch_end(self, *args, **kwargs):
        return self.on_batch_end(*args, **kwargs, stage=Stage.TEST)

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
        )

    def get_checkpoint_monitors(self) -> list[dict]:
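        # A training script might consume these monitors roughly as follows (hypothetical
        # sketch, not part of this module's API):
        #   for m in model.get_checkpoint_monitors():
        #       callbacks.append(pl.callbacks.ModelCheckpoint(monitor=m["monitor"], mode=m["mode"]))
        #       if m.get("early_stopping"):
        #           callbacks.append(pl.callbacks.EarlyStopping(monitor=m["monitor"], mode=m["mode"]))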
        monitors = [
            {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
        ]
        return monitors

    def _update_metric(
        self,
        name: str,
        metric_class: type[Metric],
        update_args: T.Any,
        batch_size: T.Optional[int] = None,
        prog_bar: bool = False,
        metric_kwargs: T.Optional[dict] = None,
        log: bool = True,
        log_n_samples: bool = False,
        bootstrap: bool = False,
        num_bootstraps: int = 100
    ) -> None:
        """
        This method enables updating and logging metrics without instantiating them in advance in
        the __init__ method. The metrics are aggregated over batches and logged at the end of the
        epoch. If the metric does not exist yet, it is instantiated and added as an attribute to the
        model.
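
        Example (hypothetical call from a child class's `on_batch_end`; the metric name and the
        choice of `torchmetrics.CosineSimilarity` are illustrative only):

            self._update_metric(
                stage.to_pref() + "cos_sim",
                torchmetrics.CosineSimilarity,
                (preds, targets),
                batch_size=preds.size(0),
                bootstrap=(stage == Stage.TEST),
            )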
        """
        # Process arguments
        bootstrap = bootstrap and self.bootstrap_metrics

        # Log total number of samples (useful for debugging)
        if log_n_samples:
            self._update_metric(
                name=name + "_n_samples",
                metric_class=SumMetric,
                update_args=(len(update_args[0]),),
                batch_size=1,
            )

        # Init the metric if it does not exist yet
        if hasattr(self, name):
            metric = getattr(self, name)
        else:
            if metric_kwargs is None:
                metric_kwargs = dict()
            metric = metric_class(**metric_kwargs)
            metric = metric.to(self.device)
            setattr(self, name, metric)

        # Update
        metric(*update_args)

        # Log
        if log:
            self.log(
                name,
                metric,
                prog_bar=prog_bar,
                batch_size=batch_size,
                on_step=False,
                on_epoch=True,
                add_dataloader_idx=False,
                metric_attribute=name  # Suggested by a torchmetrics error
            )

        # Bootstrap: additionally log the bootstrapped standard deviation under `<name>_std`
        if bootstrap:
            def _bootstrapped_metric_class(**metric_kwargs):
                metric = metric_class(**metric_kwargs)
                return ReturnScalarBootStrapper(metric, std=True, num_bootstraps=num_bootstraps)

            self._update_metric(
                name=name + "_std",
                metric_class=_bootstrapped_metric_class,
                update_args=update_args,
                batch_size=batch_size,
                metric_kwargs=metric_kwargs,
            )

    def _update_df_test(self, dct: dict) -> None:
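        """Append per-sample values to the test dataframe buffer, one column per key; tensors
        are converted to Python lists first."""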
        for col, vals in dct.items():
            if isinstance(vals, torch.Tensor):
                vals = vals.tolist()
            self.df_test[col].extend(vals)
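

# A minimal sketch of a task-specific subclass (hypothetical: the class name, the linear
# "model", the MSE loss, and the MAE metric below are illustrative, not part of MassSpecGym):

import torchmetrics


class _ExampleRegressionModel(MassSpecGymModel):

    def __init__(self, in_dim: int = 16, **kwargs):
        super().__init__(**kwargs)
        self.model = torch.nn.Linear(in_dim, 1)

    def step(self, batch: dict, stage: Stage = Stage.NONE) -> dict:
        # Predict a scalar per sample and compute the loss
        preds = self.model(batch["x"]).squeeze(-1)
        loss = torch.nn.functional.mse_loss(preds, batch["y"])
        self.log(
            stage.to_pref() + "loss", loss, batch_size=len(preds), on_step=False, on_epoch=True
        )
        return dict(loss=loss, preds=preds)

    def on_batch_end(self, outputs: T.Any, batch: dict, batch_idx: int, stage: Stage) -> None:
        if stage in self.log_only_loss_at_stages:
            return
        # Metrics are instantiated lazily by `_update_metric` on first use
        self._update_metric(
            stage.to_pref() + "mae",
            torchmetrics.MeanAbsoluteError,
            (outputs["preds"], batch["y"]),
            batch_size=len(batch["y"]),
        )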