yzhouchen001 commited on
Commit
94aa6f9
·
1 Parent(s): c65d76d

partial push

Browse files
.gitignore ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+
164
+ # data
165
+ data/*
166
+ experiments/main_result/*
167
+ experiments/old/*
168
+ experiments/test_dir/*
169
+ my_notebooks/*
170
+ other/*
171
+ .cache/
172
+
173
+
174
+ !data/.gitkeep
175
+ !experiments/.gitkeep
176
+ !data/sample/
massspecgym/__init__.py ADDED
File without changes
massspecgym/data/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .datasets import MassSpecDataset, RetrievalDataset
2
+ from .data_module import MassSpecDataModule
3
+
4
+ __all__ = [
5
+ "MassSpecDataset",
6
+ "RetrievalDataset",
7
+ "MassSpecDataModule"
8
+ ]
massspecgym/data/data_module.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as T
2
+ import pandas as pd
3
+ import numpy as np
4
+ import pytorch_lightning as pl
5
+ import massspecgym.utils as utils
6
+ from pathlib import Path
7
+ from typing import Optional
8
+ from torch.utils.data.dataset import Subset
9
+ from torch.utils.data.dataloader import DataLoader
10
+ from massspecgym.data.datasets import MassSpecDataset
11
+
12
+
13
+ class MassSpecDataModule(pl.LightningDataModule):
14
+ """
15
+ Data module containing a mass spectrometry dataset. This class is responsible for loading, splitting, and wrapping
16
+ the dataset into data loaders according to pre-defined train, validation, test folds.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ dataset: MassSpecDataset,
22
+ batch_size: int,
23
+ num_workers: int = 0,
24
+ persistent_workers: bool = True,
25
+ split_pth: Optional[Path] = None,
26
+ **kwargs
27
+ ):
28
+ """
29
+ Args:
30
+ split_pth (Optional[Path], optional): Path to a .tsv file with columns "identifier" and "fold",
31
+ corresponding to dataset item IDs, and "fold", containg "train", "val", "test"
32
+ values. Default is None, in which case the split from the `dataset` is used.
33
+ """
34
+ super().__init__(**kwargs)
35
+ self.dataset = dataset
36
+ self.split_pth = split_pth
37
+ self.batch_size = batch_size
38
+ self.num_workers = num_workers
39
+ self.persistent_workers = persistent_workers if num_workers > 0 else False
40
+
41
+ def prepare_data(self):
42
+ if self.split_pth is None:
43
+ self.split = self.dataset.metadata[["identifier", "fold"]]
44
+ else:
45
+ # NOTE: custom split is not tested
46
+ self.split = pd.read_csv(self.split_pth, sep="\t")
47
+ if set(self.split.columns) != {"identifier", "fold"}:
48
+ raise ValueError('Split file must contain "id" and "fold" columns.')
49
+ self.split["identifier"] = self.split["identifier"].astype(str)
50
+ if set(self.dataset.metadata["identifier"]) != set(self.split["identifier"]):
51
+ raise ValueError(
52
+ "Dataset item IDs must match the IDs in the split file."
53
+ )
54
+
55
+ self.split = self.split.set_index("identifier")["fold"]
56
+ if not set(self.split) <= {"train", "val", "test"}:
57
+ raise ValueError(
58
+ '"Folds" column must contain only "train", "val", or "test" values.'
59
+ )
60
+
61
+ def setup(self, stage=None):
62
+ split_mask = self.split.loc[self.dataset.metadata["identifier"]].values
63
+ if stage == "fit" or stage is None:
64
+ self.train_dataset = Subset(
65
+ self.dataset, np.where(split_mask == "train")[0]
66
+ )
67
+ self.val_dataset = Subset(self.dataset, np.where(split_mask == "val")[0])
68
+ if stage == "test":
69
+ self.test_dataset = Subset(self.dataset, np.where(split_mask == "test")[0])
70
+
71
+ def train_dataloader(self):
72
+ return DataLoader(
73
+ self.train_dataset,
74
+ batch_size=self.batch_size,
75
+ shuffle=True,
76
+ num_workers=self.num_workers,
77
+ persistent_workers=self.persistent_workers,
78
+ drop_last=False,
79
+ collate_fn=self.dataset.collate_fn,
80
+ )
81
+
82
+ def val_dataloader(self):
83
+ return DataLoader(
84
+ self.val_dataset,
85
+ batch_size=self.batch_size,
86
+ shuffle=False,
87
+ num_workers=self.num_workers,
88
+ persistent_workers=self.persistent_workers,
89
+ drop_last=False,
90
+ collate_fn=self.dataset.collate_fn,
91
+ )
92
+
93
+ def test_dataloader(self):
94
+ return DataLoader(
95
+ self.test_dataset,
96
+ batch_size=self.batch_size,
97
+ shuffle=False,
98
+ num_workers=self.num_workers,
99
+ persistent_workers=self.persistent_workers,
100
+ drop_last=False,
101
+ collate_fn=self.dataset.collate_fn,
102
+ )
massspecgym/data/datasets.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ import typing as T
4
+ import numpy as np
5
+ import torch
6
+ import matchms
7
+ import massspecgym.utils as utils
8
+ from pathlib import Path
9
+ from rdkit import Chem
10
+ from torch.utils.data.dataset import Dataset
11
+ from torch.utils.data.dataloader import default_collate
12
+ from matchms.importing import load_from_mgf
13
+ from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
14
+
15
+
16
+ class MassSpecDataset(Dataset):
17
+ """
18
+ Dataset containing mass spectra and their corresponding molecular structures. This class is
19
+ responsible for loading the data from disk and applying transformation steps to the spectra and
20
+ molecules.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ spec_transform: T.Optional[T.Union[SpecTransform, T.Dict[str, SpecTransform]]] = None,
26
+ mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]] = None,
27
+ pth: T.Optional[Path] = None,
28
+ return_mol_freq: bool = True,
29
+ return_identifier: bool = True,
30
+ dtype: T.Type = torch.float32
31
+ ):
32
+ """
33
+ Args:
34
+ pth (Optional[Path], optional): Path to the .tsv or .mgf file containing the mass spectra.
35
+ Default is None, in which case the MassSpecGym dataset is downloaded from HuggingFace Hub.
36
+ """
37
+ self.pth = pth
38
+ self.spec_transform = spec_transform
39
+ self.mol_transform = mol_transform
40
+ self.return_mol_freq = return_mol_freq
41
+
42
+ if self.pth is None:
43
+ self.pth = utils.hugging_face_download("MassSpecGym.tsv")
44
+
45
+ if isinstance(self.pth, str):
46
+ self.pth = Path(self.pth)
47
+
48
+ if self.pth.suffix == ".tsv":
49
+ self.metadata = pd.read_csv(self.pth, sep="\t")
50
+ self.spectra = self.metadata.apply(
51
+ lambda row: matchms.Spectrum(
52
+ mz=np.array([float(m) for m in row["mzs"].split(",")]),
53
+ intensities=np.array(
54
+ [float(i) for i in row["intensities"].split(",")]
55
+ ),
56
+ metadata={"precursor_mz": row["precursor_mz"]},
57
+ ),
58
+ axis=1,
59
+ )
60
+ self.metadata = self.metadata.drop(columns=["mzs", "intensities"])
61
+ elif self.pth.suffix == ".mgf":
62
+ self.spectra = list(load_from_mgf(str(self.pth)))
63
+ self.metadata = pd.DataFrame([s.metadata for s in self.spectra])
64
+ else:
65
+ raise ValueError(f"{self.pth.suffix} file format not supported.")
66
+
67
+ if self.return_mol_freq:
68
+ if "inchikey" not in self.metadata.columns:
69
+ self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
70
+ self.metadata["mol_freq"] = self.metadata.groupby("inchikey")["inchikey"].transform("count")
71
+
72
+ self.return_identifier = return_identifier
73
+ self.dtype = dtype
74
+
75
+ def __len__(self) -> int:
76
+ return len(self.spectra)
77
+
78
+ def __getitem__(
79
+ self, i: int, transform_spec: bool = True, transform_mol: bool = True
80
+ ) -> dict:
81
+ spec = self.spectra[i]
82
+ metadata = self.metadata.iloc[i]
83
+ mol = metadata["smiles"]
84
+
85
+ # Apply all transformations to the spectrum
86
+ item = {}
87
+ if transform_spec and self.spec_transform:
88
+ if isinstance(self.spec_transform, dict):
89
+ for key, transform in self.spec_transform.items():
90
+ item[key] = transform(spec) if transform is not None else spec
91
+ else:
92
+ item["spec"] = self.spec_transform(spec)
93
+ else:
94
+ item["spec"] = spec
95
+
96
+ # Apply all transformations to the molecule
97
+ if transform_mol and self.mol_transform:
98
+ if isinstance(self.mol_transform, dict):
99
+ for key, transform in self.mol_transform.items():
100
+ item[key] = transform(mol) if transform is not None else mol
101
+ else:
102
+ item["mol"] = self.mol_transform(mol)
103
+ else:
104
+ item["mol"] = mol
105
+
106
+ # Add other metadata to the item
107
+ # item.update({
108
+ # k: metadata[k] for k in ["precursor_mz", "adduct"]
109
+ # })
110
+
111
+ if self.return_mol_freq:
112
+ item["mol_freq"] = metadata["mol_freq"]
113
+
114
+ if self.return_identifier:
115
+ item["identifier"] = metadata["identifier"]
116
+
117
+ # TODO: this should be refactored
118
+ for k, v in item.items():
119
+ if not isinstance(v, str):
120
+ try:
121
+ item[k] = torch.as_tensor(v, dtype=self.dtype)
122
+ except:
123
+ continue
124
+
125
+ return item
126
+
127
+ @staticmethod
128
+ def collate_fn(batch: T.Iterable[dict]) -> dict:
129
+ """
130
+ Custom collate function to handle the outputs of __getitem__.
131
+ """
132
+ return default_collate(batch)
133
+
134
+
135
+ class RetrievalDataset(MassSpecDataset):
136
+ """
137
+ Dataset containing mass spectra and their corresponding molecular structures, with additional
138
+ candidates of molecules for retrieval based on spectral similarity.
139
+ """
140
+
141
+ def __init__(
142
+ self,
143
+ mol_label_transform: MolTransform = MolToInChIKey(),
144
+ candidates_pth: T.Optional[T.Union[Path, str]] = None,
145
+ **kwargs,
146
+ ):
147
+ super().__init__(**kwargs)
148
+
149
+ self.candidates_pth = candidates_pth
150
+ self.mol_label_transform = mol_label_transform
151
+
152
+ # Download candidates from HuggigFace Hub if not a path to exisiting file is passed
153
+ if self.candidates_pth is None:
154
+ self.candidates_pth = utils.hugging_face_download(
155
+ "molecules/MassSpecGym_retrieval_candidates_mass.json"
156
+ )
157
+ elif isinstance(self.candidates_pth, str):
158
+ if Path(self.candidates_pth).is_file():
159
+ self.candidates_pth = Path(self.candidates_pth)
160
+ else:
161
+ self.candidates_pth = utils.hugging_face_download(candidates_pth)
162
+
163
+ # Read candidates_pth from json to dict: SMILES -> respective candidate SMILES
164
+ with open(self.candidates_pth, "r") as file:
165
+ self.candidates = json.load(file)
166
+
167
+ def __getitem__(self, i) -> dict:
168
+ item = super().__getitem__(i, transform_mol=False)
169
+
170
+ # Save the original SMILES representation of the query molecule (for evaluation)
171
+ item["smiles"] = item["mol"]
172
+
173
+ # Get candidates
174
+ if item["mol"] not in self.candidates:
175
+ raise ValueError(f'No candidates for the query molecule {item["mol"]}.')
176
+ item["candidates"] = self.candidates[item["mol"]]
177
+
178
+ # Save the original SMILES representations of the canidates (for evaluation)
179
+ item["candidates_smiles"] = item["candidates"]
180
+
181
+ # Create neg/pos label mask by matching the query molecule with the candidates
182
+ item_label = self.mol_label_transform(item["mol"])
183
+ item["labels"] = [
184
+ self.mol_label_transform(c) == item_label for c in item["candidates"]
185
+ ]
186
+
187
+ if not any(item["labels"]):
188
+ raise ValueError(
189
+ f'Query molecule {item["mol"]} not found in the candidates list.'
190
+ )
191
+
192
+ # Transform the query and candidate molecules
193
+ item["mol"] = self.mol_transform(item["mol"])
194
+ item["candidates"] = [self.mol_transform(c) for c in item["candidates"]]
195
+ if isinstance(item["mol"], np.ndarray):
196
+ item["mol"] = torch.as_tensor(item["mol"], dtype=self.dtype)
197
+ # item["candidates"] = [torch.as_tensor(c, dtype=self.dtype) for c in item["candidates"]]
198
+
199
+ return item
200
+
201
+ @staticmethod
202
+ def collate_fn(batch: T.Iterable[dict]) -> dict:
203
+ # Standard collate for everything except candidates and their labels (which may have different length per sample)
204
+ collated_batch = {}
205
+ for k in batch[0].keys():
206
+ if k not in ["candidates", "labels", "candidates_smiles"]:
207
+ collated_batch[k] = default_collate([item[k] for item in batch])
208
+
209
+ # Collate candidates and labels by concatenating and storing sizes of each list
210
+ collated_batch["candidates"] = torch.as_tensor(
211
+ np.concatenate([item["candidates"] for item in batch])
212
+ )
213
+ collated_batch["labels"] = torch.as_tensor(
214
+ sum([item["labels"] for item in batch], start=[])
215
+ )
216
+ collated_batch["batch_ptr"] = torch.as_tensor(
217
+ [len(item["candidates"]) for item in batch]
218
+ )
219
+ collated_batch["candidates_smiles"] = \
220
+ sum([item["candidates_smiles"] for item in batch], start=[])
221
+
222
+ return collated_batch
223
+
224
+
225
+ # TODO: Datasets for unlabeled data.
massspecgym/data/transforms.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import matchms
4
+ import matchms.filtering as ms_filters
5
+ from rdkit.Chem import AllChem as Chem
6
+ from typing import Optional
7
+ from abc import ABC, abstractmethod
8
+ import massspecgym.utils as utils
9
+ from massspecgym.definitions import CHEM_ELEMS
10
+
11
+
12
+ class SpecTransform(ABC):
13
+ """
14
+ Base class for spectrum transformations. Custom transformatios should inherit from this class.
15
+ The transformation consists of two consecutive steps:
16
+ 1. Apply a series of matchms filters to the input spectrum (method `matchms_transforms`).
17
+ 2. Convert the matchms spectrum to a torch tensor (method `matchms_to_torch`).
18
+ """
19
+
20
+ @abstractmethod
21
+ def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
22
+ """
23
+ Apply a series of matchms filters to the input spectrum. Abstract method.
24
+ """
25
+
26
+ @abstractmethod
27
+ def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
28
+ """
29
+ Convert a matchms spectrum to a torch tensor. Abstract method.
30
+ """
31
+
32
+ def __call__(self, spec: matchms.Spectrum) -> torch.Tensor:
33
+ """
34
+ Compose the matchms filters and the torch conversion.
35
+ """
36
+ return self.matchms_to_torch(self.matchms_transforms(spec))
37
+
38
+
39
+ def default_matchms_transforms(
40
+ spec: matchms.Spectrum,
41
+ n_max_peaks: int = 60,
42
+ mz_from: float = 10,
43
+ mz_to: float = 1000,
44
+ ) -> matchms.Spectrum:
45
+ spec = ms_filters.select_by_mz(spec, mz_from=mz_from, mz_to=mz_to)
46
+ if n_max_peaks is not None:
47
+ spec = ms_filters.reduce_to_number_of_peaks(spec, n_max=n_max_peaks)
48
+ spec = ms_filters.normalize_intensities(spec)
49
+ return spec
50
+
51
+
52
+ class SpecTokenizer(SpecTransform):
53
+ def __init__(
54
+ self,
55
+ n_peaks: Optional[int] = 60,
56
+ prec_mz_intensity: Optional[float] = 1.1,
57
+ matchms_kwargs: Optional[dict] = None
58
+ ) -> None:
59
+ self.n_peaks = n_peaks
60
+ self.prec_mz_intensity = prec_mz_intensity
61
+ self.matchms_kwargs = matchms_kwargs if matchms_kwargs is not None else {}
62
+
63
+ def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
64
+ return default_matchms_transforms(spec, n_max_peaks=self.n_peaks, **self.matchms_kwargs)
65
+
66
+ def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
67
+ """
68
+ Stack arrays of mz and intensities into a matrix of shape (num_peaks, 2).
69
+ If the number of peaks is less than `n_peaks`, pad the matrix with zeros.
70
+ """
71
+ spec_t = np.vstack([spec.peaks.mz, spec.peaks.intensities]).T
72
+ if self.prec_mz_intensity is not None:
73
+ spec_t = np.vstack([[spec.metadata["precursor_mz"], self.prec_mz_intensity], spec_t])
74
+ if self.n_peaks is not None:
75
+ spec_t = utils.pad_spectrum(
76
+ spec_t,
77
+ self.n_peaks + 1 if self.prec_mz_intensity is not None else self.n_peaks
78
+ )
79
+ return torch.from_numpy(spec_t)
80
+
81
+
82
+ class SpecBinner(SpecTransform):
83
+ def __init__(
84
+ self,
85
+ max_mz: float = 1005,
86
+ bin_width: float = 1,
87
+ to_rel_intensities: bool = True,
88
+ ) -> None:
89
+ self.max_mz = max_mz
90
+ self.bin_width = bin_width
91
+ self.to_rel_intensities = to_rel_intensities
92
+ if not (max_mz / bin_width).is_integer():
93
+ raise ValueError("`max_mz` must be divisible by `bin_width`.")
94
+
95
+ def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
96
+ return default_matchms_transforms(spec, mz_to=self.max_mz, n_max_peaks=None)
97
+
98
+ def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
99
+ """
100
+ Bin the spectrum into a fixed number of bins.
101
+ """
102
+ binned_spec = self._bin_mass_spectrum(
103
+ mzs=spec.peaks.mz,
104
+ intensities=spec.peaks.intensities,
105
+ max_mz=self.max_mz,
106
+ bin_width=self.bin_width,
107
+ to_rel_intensities=self.to_rel_intensities,
108
+ )
109
+ return torch.from_numpy(binned_spec)
110
+
111
+ def _bin_mass_spectrum(
112
+ self, mzs, intensities, max_mz, bin_width, to_rel_intensities=True
113
+ ):
114
+ # Calculate the number of bins
115
+ num_bins = int(np.ceil(max_mz / bin_width))
116
+
117
+ # Calculate the bin indices for each mass
118
+ bin_indices = np.floor(mzs / bin_width).astype(int)
119
+
120
+ # Filter out mzs that exceed max_mz
121
+ valid_indices = bin_indices[mzs <= max_mz]
122
+ valid_intensities = intensities[mzs <= max_mz]
123
+
124
+ # Clip bin indices to ensure they are within the valid range
125
+ valid_indices = np.clip(valid_indices, 0, num_bins - 1)
126
+
127
+ # Initialize an array to store the binned intensities
128
+ binned_intensities = np.zeros(num_bins)
129
+
130
+ # Use np.add.at to sum intensities in the appropriate bins
131
+ np.add.at(binned_intensities, valid_indices, valid_intensities)
132
+
133
+ # Generate the bin edges for reference
134
+ # bin_edges = np.arange(0, max_mz + bin_width, bin_width)
135
+
136
+ # Normalize the intensities to relative intensities
137
+ if to_rel_intensities:
138
+ binned_intensities /= np.max(binned_intensities)
139
+
140
+ return binned_intensities # , bin_edges
141
+
142
+
143
+ class MolTransform(ABC):
144
+ @abstractmethod
145
+ def from_smiles(self, mol: str):
146
+ """
147
+ Convert a SMILES string to a tensor-like representation. Abstract method.
148
+ """
149
+
150
+ def __call__(self, mol: str):
151
+ return self.from_smiles(mol)
152
+
153
+
154
+ class MolFingerprinter(MolTransform):
155
+ def __init__(self, type: str = "morgan", fp_size: int = 2048, radius: int = 2):
156
+ if type != "morgan":
157
+ raise NotImplementedError(
158
+ "Only Morgan fingerprints are implemented at the moment."
159
+ )
160
+ self.type = type
161
+ self.fp_size = fp_size
162
+ self.radius = radius
163
+
164
+ def from_smiles(self, mol: str):
165
+ mol = Chem.MolFromSmiles(mol)
166
+ return utils.morgan_fp(
167
+ mol, fp_size=self.fp_size, radius=self.radius, to_np=True
168
+ )
169
+
170
+
171
+ class MolToInChIKey(MolTransform):
172
+ def __init__(self, twod: bool = True) -> None:
173
+ self.twod = twod
174
+
175
+ def from_smiles(self, mol: str) -> str:
176
+ mol = Chem.MolFromSmiles(mol)
177
+ return utils.mol_to_inchi_key(mol, twod=self.twod)
178
+
179
+
180
+ class MolToFormulaVector(MolTransform):
181
+ def __init__(self):
182
+ self.element_index = {element: i for i, element in enumerate(CHEM_ELEMS)}
183
+
184
+ def from_smiles(self, smiles: str):
185
+ mol = Chem.MolFromSmiles(smiles)
186
+ if mol is None:
187
+ raise ValueError(f"Invalid SMILES string: {smiles}")
188
+
189
+ # Add explicit hydrogens to the molecule
190
+ mol = Chem.AddHs(mol)
191
+
192
+ # Initialize a vector of zeros for the 118 elements
193
+ formula_vector = np.zeros(118, dtype=np.int32)
194
+
195
+ # Iterate over atoms in the molecule and count occurrences of each element
196
+ for atom in mol.GetAtoms():
197
+ symbol = atom.GetSymbol()
198
+ if symbol in self.element_index:
199
+ index = self.element_index[symbol]
200
+ formula_vector[index] += 1
201
+ else:
202
+ raise ValueError(f"Element '{symbol}' not found in the list of 118 elements.")
203
+
204
+ return formula_vector
205
+
206
+ @staticmethod
207
+ def num_elements():
208
+ return len(CHEM_ELEMS)
massspecgym/definitions.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Global variables used across the package."""
2
+ import pathlib
3
+
4
+ # Dirs
5
+ MASSSPECGYM_ROOT_DIR = pathlib.Path(__file__).parent.absolute()
6
+ MASSSPECGYM_REPO_DIR = MASSSPECGYM_ROOT_DIR.parent
7
+ MASSSPECGYM_DATA_DIR = MASSSPECGYM_REPO_DIR / 'data'
8
+ MASSSPECGYM_TEST_RESULTS_DIR = MASSSPECGYM_DATA_DIR / 'test_results'
9
+ MASSSPECGYM_ASSETS_DIR = MASSSPECGYM_REPO_DIR / 'assets'
10
+
11
+ # Special tokens
12
+ PAD_TOKEN = "<pad>"
13
+ SOS_TOKEN = "<s>"
14
+ EOS_TOKEN = "</s>"
15
+ UNK_TOKEN = "<unk>"
16
+
17
+ # Chemistry
18
+ # List of all 118 elements (indexed by atomic number)
19
+ CHEM_ELEMS = [
20
+ "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar",
21
+ "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr",
22
+ "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe",
23
+ "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu",
24
+ "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac",
25
+ "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh",
26
+ "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"
27
+ ]
massspecgym/models/__init__.py ADDED
File without changes
massspecgym/models/base.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as T
2
+ import collections
3
+ from enum import Enum
4
+ from abc import ABC, abstractmethod
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import pytorch_lightning as pl
9
+ from torchmetrics import Metric, SumMetric
10
+ from massspecgym.utils import ReturnScalarBootStrapper
11
+
12
+
13
+ class Stage(Enum):
14
+
15
+ TRAIN = 'train'
16
+ VAL = 'val'
17
+ TEST = 'test'
18
+ NONE = 'none'
19
+
20
+ def to_pref(self) -> str:
21
+ return f"{self.value}_" if self != Stage.NONE else ""
22
+
23
+
24
+ class MassSpecGymModel(pl.LightningModule, ABC):
25
+
26
+ def __init__(
27
+ self,
28
+ lr: float = 1e-4,
29
+ weight_decay: float = 0.0,
30
+ log_only_loss_at_stages: T.Sequence[Stage | str] = (),
31
+ bootstrap_metrics: bool = True,
32
+ df_test_path: T.Optional[str | Path] = None,
33
+ *args,
34
+ **kwargs
35
+ ):
36
+ super().__init__()
37
+ self.save_hyperparameters()
38
+
39
+ # Setup metring logging
40
+ self.log_only_loss_at_stages = [
41
+ Stage(s) if isinstance(s, str) else s for s in log_only_loss_at_stages
42
+ ]
43
+ self.bootstrap_metrics = bootstrap_metrics
44
+
45
+ # Init dictionary to store dataframe columns where rows correspond to samples
46
+ # (for constructing test dataframe with predictions and metrics for each sample)
47
+ self.df_test_path = Path(df_test_path) if df_test_path is not None else None
48
+ self.df_test = collections.defaultdict(list)
49
+
50
+ @abstractmethod
51
+ def step(
52
+ self, batch: dict, stage: Stage = Stage.NONE
53
+ ) -> tuple[torch.Tensor, torch.Tensor]:
54
+ raise NotImplementedError(
55
+ "Method `step` must be implemented in the model-specific child class."
56
+ )
57
+
58
+ def training_step(
59
+ self, batch: dict, batch_idx: torch.Tensor
60
+ ) -> tuple[torch.Tensor, torch.Tensor]:
61
+ return self.step(batch, stage=Stage.TRAIN)
62
+
63
+ def validation_step(
64
+ self, batch: dict, batch_idx: torch.Tensor
65
+ ) -> tuple[torch.Tensor, torch.Tensor]:
66
+ return self.step(batch, stage=Stage.VAL)
67
+
68
+ def test_step(
69
+ self, batch: dict, batch_idx: torch.Tensor
70
+ ) -> tuple[torch.Tensor, torch.Tensor]:
71
+ return self.step(batch, stage=Stage.TEST)
72
+
73
+ @abstractmethod
74
+ def on_batch_end(
75
+ self, outputs: T.Any, batch: dict, batch_idx: int, stage: Stage
76
+ ) -> None:
77
+ """
78
+ Method to be called at the end of each batch. This method should be implemented by a child,
79
+ task-dedicated class and contain the evaluation necessary for the task.
80
+ """
81
+ raise NotImplementedError(
82
+ "Method `on_batch_end` must be implemented in the task-specific child class."
83
+ )
84
+
85
+ def on_train_batch_end(self, *args, **kwargs):
86
+ return self.on_batch_end(*args, **kwargs, stage=Stage.TRAIN)
87
+
88
+ def on_validation_batch_end(self, *args, **kwargs):
89
+ return self.on_batch_end(*args, **kwargs, stage=Stage.VAL)
90
+
91
+ def on_test_batch_end(self, *args, **kwargs):
92
+ return self.on_batch_end(*args, **kwargs, stage=Stage.TEST)
93
+
94
+ def configure_optimizers(self):
95
+ return torch.optim.Adam(
96
+ self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
97
+ )
98
+
99
+ def get_checkpoint_monitors(self) -> list[dict]:
100
+ monitors = [
101
+ {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
102
+ ]
103
+ return monitors
104
+
105
+ def _update_metric(
106
+ self,
107
+ name: str,
108
+ metric_class: type[Metric],
109
+ update_args: T.Any,
110
+ batch_size: T.Optional[int] = None,
111
+ prog_bar: bool = False,
112
+ metric_kwargs: T.Optional[dict] = None,
113
+ log: bool = True,
114
+ log_n_samples: bool = False,
115
+ bootstrap: bool = False,
116
+ num_bootstraps: int = 100
117
+ ) -> None:
118
+ """
119
+ This method enables updating and logging metrics without instantiating them in advance in
120
+ the __init__ method. The metrics are aggreated over batches and logged at the end of the
121
+ epoch. If the metric does not exist yet, it is instantiated and added as an attribute to the
122
+ model.
123
+ """
124
+ # Process arguments
125
+ bootstrap = bootstrap and self.bootstrap_metrics
126
+
127
+ # Log total number of samples (useful for debugging)
128
+ if log_n_samples:
129
+ self._update_metric(
130
+ name=name + "_n_samples",
131
+ metric_class=SumMetric,
132
+ update_args=(len(update_args[0]),),
133
+ batch_size=1,
134
+ )
135
+
136
+ # Init metric if does not exits yet
137
+ if hasattr(self, name):
138
+ metric = getattr(self, name)
139
+ else:
140
+ if metric_kwargs is None:
141
+ metric_kwargs = dict()
142
+ metric = metric_class(**metric_kwargs)
143
+ metric = metric.to(self.device)
144
+ setattr(self, name, metric)
145
+
146
+ # Update
147
+ metric(*update_args)
148
+
149
+ # Log
150
+ if log:
151
+ self.log(
152
+ name,
153
+ metric,
154
+ prog_bar=prog_bar,
155
+ batch_size=batch_size,
156
+ on_step=False,
157
+ on_epoch=True,
158
+ add_dataloader_idx=False,
159
+ metric_attribute=name # Suggested by a torchmetrics error
160
+ )
161
+
162
+ # Bootstrap
163
+ if bootstrap:
164
+ def _bootsrapped_metric_class(**metric_kwargs):
165
+ metric = metric_class(**metric_kwargs)
166
+ return ReturnScalarBootStrapper(metric, std=True, num_bootstraps=num_bootstraps)
167
+
168
+ self._update_metric(
169
+ name=name + "_std",
170
+ metric_class=_bootsrapped_metric_class,
171
+ update_args=update_args,
172
+ batch_size=batch_size,
173
+ metric_kwargs=metric_kwargs,
174
+ )
175
+
176
+ def _update_df_test(self, dct: dict) -> None:
177
+ for col, vals in dct.items():
178
+ if isinstance(vals, torch.Tensor):
179
+ vals = vals.tolist()
180
+ self.df_test[col].extend(vals)
massspecgym/models/de_novo/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .base import DeNovoMassSpecGymModel
2
+ from .random import RandomDeNovo
3
+ from .dummy import DummyDeNovo
4
+ from .smiles_tranformer import SmilesTransformer
5
+
6
+ __all__ = ["DeNovoMassSpecGymModel", "RandomDeNovo", "DummyDeNovo", "SmilesTransformer"]
massspecgym/models/de_novo/base.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as T
2
+ from abc import ABC
3
+
4
+ import torch
5
+ import pandas as pd
6
+ from rdkit import Chem
7
+ from rdkit.DataStructs import TanimotoSimilarity
8
+ from torchmetrics.aggregation import MeanMetric
9
+
10
+ from massspecgym.models.base import MassSpecGymModel, Stage
11
+ from massspecgym.utils import morgan_fp, mol_to_inchi_key, MyopicMCES
12
+
13
+
14
+ class DeNovoMassSpecGymModel(MassSpecGymModel, ABC):
15
+
16
+ def __init__(
17
+ self,
18
+ top_ks: T.Iterable[int] = (1, 10),
19
+ myopic_mces_kwargs: T.Optional[T.Mapping] = None,
20
+ *args,
21
+ **kwargs
22
+ ):
23
+ super().__init__(*args, **kwargs)
24
+
25
+ self.top_ks = top_ks
26
+ self.myopic_mces = MyopicMCES(**(myopic_mces_kwargs or {}))
27
+ self.mol_pred_kind: T.Literal["smiles", "rdkit"] = "smiles"
28
+ # caches of already computed results to avoid expensive re-computations
29
+ self.mces_cache = dict()
30
+ self.mol_2_morgan_fp = dict()
31
+
32
+ def on_batch_end(
33
+ self,
34
+ outputs: T.Any,
35
+ batch: dict,
36
+ batch_idx: int,
37
+ stage: Stage
38
+ ) -> None:
39
+ self.log(
40
+ f"{stage.to_pref()}loss",
41
+ outputs['loss'],
42
+ batch_size=batch['spec'].size(0),
43
+ sync_dist=True,
44
+ prog_bar=True,
45
+ )
46
+
47
+ if stage in self.log_only_loss_at_stages:
48
+ return
49
+
50
+ metric_vals = self.evaluate_de_novo_step(
51
+ outputs["mols_pred"], # (bs, k) list of generated rdkit molecules or SMILES strings
52
+ batch["mol"], # (bs) list of ground truth SMILES strings
53
+ stage=stage
54
+ )
55
+
56
+ if stage == Stage.TEST and self.df_test_path is not None:
57
+ self._update_df_test(metric_vals)
58
+
59
+ def evaluate_de_novo_step(
60
+ self,
61
+ mols_pred: list[list[T.Optional[Chem.Mol | str]]],
62
+ mol_true: list[str],
63
+ stage: Stage,
64
+ ) -> dict[str, torch.Tensor]:
65
+ """
66
+ # TODO: refactor to compute only for max(k) and then use the result to obtain the rest by
67
+ subsetting.
68
+
69
+ Main evaluation method for the models for de novo molecule generation from mass spectra.
70
+
71
+ Args:
72
+ mols_pred (list[list[Mol | str]]): (bs, k) list of generated rdkit molecules or SMILES
73
+ strings with possible Nones if no molecule was generated
74
+ mol_true (list[str]): (bs) list of ground-truth SMILES strings
75
+ """
76
+ # Initialize return dictionary to store metric values per sample
77
+ metric_vals = {}
78
+
79
+ # Get SMILES and RDKit molecule objects for all predictions
80
+ if self.mol_pred_kind == "smiles":
81
+ smiles_pred_valid, mols_pred_valid = [], []
82
+ for mols_pred_sample in mols_pred:
83
+ smiles_pred_valid_sample, mols_pred_valid_sample = [], []
84
+ for s in mols_pred_sample:
85
+ m = Chem.MolFromSmiles(s) if s is not None else None
86
+ # If SMILES cannot be converted to RDKit molecule, the molecule is set to None
87
+ smiles_pred_valid_sample.append(s if m is not None else None)
88
+ mols_pred_valid_sample.append(m)
89
+ smiles_pred_valid.append(smiles_pred_valid_sample)
90
+ mols_pred_valid.append(mols_pred_valid_sample)
91
+ smiles_pred, mols_pred = smiles_pred_valid, mols_pred_valid
92
+ elif self.mol_pred_kind == "rdkit":
93
+ smiles_pred = [
94
+ [Chem.MolToSmiles(m) if m is not None else None for m in ms]
95
+ for ms in mols_pred
96
+ ]
97
+ else:
98
+ raise ValueError(f"Invalid mol_pred_kind: {self.mol_pred_kind}")
99
+
100
+ # Auxiliary metric: number of valid molecules
101
+ self._update_metric(
102
+ stage.to_pref() + f"num_valid_mols",
103
+ MeanMetric,
104
+ ([sum([m is not None for m in ms]) for ms in mols_pred],),
105
+ batch_size=len(mols_pred),
106
+ )
107
+
108
+ # Get RDKit molecule objects for ground truth
109
+ smile_true = mol_true
110
+ mol_true = [Chem.MolFromSmiles(sm) for sm in mol_true]
111
+
112
+ def _get_morgan_fp_with_cache(mol):
113
+ """
114
+ A helper function to retrieve either cached Morgan Fingerprint value, or to compute and cache it
115
+ @param mol: RDKit molecule object
116
+ @return:
117
+ """
118
+ if mol not in self.mol_2_morgan_fp:
119
+ self.mol_2_morgan_fp[mol] = morgan_fp(mol, to_np=False)
120
+ return self.mol_2_morgan_fp[mol]
121
+
122
+ # Evaluate top-k metrics
123
+ for top_k in self.top_ks:
124
+ # Get top-k predicted molecules for each ground-truth sample
125
+ smiles_pred_top_k = [smiles_pred_sample[:top_k] for smiles_pred_sample in smiles_pred]
126
+ mols_pred_top_k = [mols_pred_sample[:top_k] for mols_pred_sample in mols_pred]
127
+
128
+ # 1. Evaluate minimum common edge subgraph:
129
+ # Calculate MCES distance between top-k predicted molecules and ground truth and
130
+ # report the minimum distance. The minimum distances for each sample in the batch are
131
+ # averaged across the epoch.
132
+ min_mces_dists = []
133
+ mces_thld = 100
134
+ # Iterate over batch
135
+ for preds, true in zip(smiles_pred_top_k, smile_true):
136
+ # Iterate over top-k predicted molecule samples
137
+ dists = []
138
+ for pred in preds:
139
+ if pred is None:
140
+ dists.append(mces_thld)
141
+ else:
142
+ if (true, pred) not in self.mces_cache:
143
+ mce_val = self.myopic_mces(true, pred)
144
+ self.mces_cache[(true, pred)] = mce_val
145
+ dists.append(self.mces_cache[(true, pred)])
146
+ min_mces_dists.append(min(min(dists), mces_thld))
147
+ min_mces_dists = torch.tensor(min_mces_dists, device=self.device)
148
+
149
+ # Log
150
+ metric_name = stage.to_pref() + f"top_{top_k}_mces_dist"
151
+ self._update_metric(
152
+ metric_name,
153
+ MeanMetric,
154
+ (min_mces_dists,),
155
+ batch_size=len(min_mces_dists),
156
+ bootstrap=stage == Stage.TEST
157
+ )
158
+ metric_vals[metric_name] = min_mces_dists
159
+
160
+ # 2. Evaluate Tanimoto similarity:
161
+ # Calculate Tanimoto similarity between top-k predicted molecules and ground truth and
162
+ # report the maximum similarity. The maximum similarities for each sample in the batch
163
+ # are averaged across the epoch.
164
+ fps_pred_top_k = [
165
+ [_get_morgan_fp_with_cache(m) if m is not None else None for m in ms]
166
+ for ms in mols_pred_top_k
167
+ ]
168
+ fp_true = [_get_morgan_fp_with_cache(m) for m in mol_true]
169
+
170
+ max_tanimoto_sims = []
171
+ # Iterate over batch
172
+ for preds, true in zip(fps_pred_top_k, fp_true):
173
+ # Iterate over top-k predicted molecule samples
174
+ sims = [
175
+ TanimotoSimilarity(true, pred)
176
+ if pred is not None else 0
177
+ for pred in preds
178
+ ]
179
+ max_tanimoto_sims.append(max(sims))
180
+ max_tanimoto_sims = torch.tensor(max_tanimoto_sims, device=self.device)
181
+
182
+ # Log
183
+ metric_name = stage.to_pref() + f"top_{top_k}_max_tanimoto_sim"
184
+ self._update_metric(
185
+ metric_name,
186
+ MeanMetric,
187
+ (max_tanimoto_sims,),
188
+ batch_size=len(max_tanimoto_sims),
189
+ bootstrap=stage == Stage.TEST
190
+ )
191
+ metric_vals[metric_name] = max_tanimoto_sims
192
+
193
+ # 3. Evaluate exact match (accuracy):
194
+ # Calculate if the ground truth molecule is in the top-k predicted molecules and report
195
+ # the average across the epoch.
196
+ in_top_k = [
197
+ mol_to_inchi_key(true) in [
198
+ mol_to_inchi_key(pred)
199
+ if pred is not None else None
200
+ for pred in preds
201
+ ]
202
+ for true, preds in zip(mol_true, mols_pred_top_k)
203
+ ]
204
+ in_top_k = torch.tensor(in_top_k, device=self.device)
205
+
206
+ # Log
207
+ metric_name = stage.to_pref() + f"top_{top_k}_accuracy"
208
+ self._update_metric(
209
+ metric_name,
210
+ MeanMetric,
211
+ (in_top_k,),
212
+ batch_size=len(in_top_k),
213
+ bootstrap=stage == Stage.TEST
214
+ )
215
+ metric_vals[metric_name] = in_top_k
216
+
217
+ return metric_vals
218
+
219
+
220
+ def test_step(
221
+ self,
222
+ batch: dict,
223
+ batch_idx: torch.Tensor
224
+ ) -> tuple[torch.Tensor, torch.Tensor]:
225
+ outputs = super().test_step(batch, batch_idx)
226
+
227
+ # Get generated (i.e., predicted) SMILES
228
+ if self.df_test_path is not None:
229
+ self._update_df_test({
230
+ 'identifier': batch['identifier'],
231
+ 'mols_pred': outputs['mols_pred']
232
+ })
233
+
234
+ return outputs
235
+
236
+ def on_test_epoch_end(self):
237
+ # Save test data frame to disk
238
+ if self.df_test_path is not None:
239
+ df_test = pd.DataFrame(self.df_test)
240
+ self.df_test_path.parent.mkdir(parents=True, exist_ok=True)
241
+ df_test.to_pickle(self.df_test_path)
massspecgym/models/de_novo/dummy.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import torch
4
+
5
+ from massspecgym.models.base import Stage
6
+ from massspecgym.models.de_novo.base import DeNovoMassSpecGymModel
7
+
8
+
9
+ class DummyDeNovo(DeNovoMassSpecGymModel):
10
+
11
+ def __init__(self, n_samples: int = 10, *args, **kwargs):
12
+ super().__init__(*args, **kwargs)
13
+ self.n_samples = n_samples
14
+
15
+ self.dummy_smiles = [
16
+ "O", # Water (H₂O)
17
+ "C", # Methane (CH₄)
18
+ "CCO", # Ethanol (C₂H₆O)
19
+ "C(C1C(C(C(C(O1)O)O)O)O)O", # Glucose (C₆H₁₂O₆)
20
+ "CC(=O)C", # Acetone (C₃H₆O)
21
+ "CC(=O)Oc1ccccc1C(=O)O", # Aspirin (C₉H₈O₄)
22
+ "CN1C=NC2=C1C(=O)N(C(=O)N2C)C", # Caffeine (C₈H₁₀N₄O₂)
23
+ "c1ccccc1", # Benzene (C₆H₆)
24
+ "CC(=O)O", # Acetic Acid (C₂H₄O₂)
25
+ "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", # Ibuprofen (C₁₃H₁₈O₂)
26
+ None
27
+ ]
28
+ self.mol_pred_kind = "smiles"
29
+
30
+ def step(
31
+ self, batch: dict, stage: Stage = Stage.NONE
32
+ ) -> tuple[torch.Tensor, torch.Tensor]:
33
+ bs = batch['spec'].shape[0]
34
+
35
+ # Sample dummy molecules from the pre-defined list
36
+ mols_pred = [[random.choice(self.dummy_smiles) for _ in range(self.n_samples)] for _ in range(bs)]
37
+
38
+ # Random baseline, so we return a dummy loss
39
+ loss = torch.tensor(0.0, requires_grad=True)
40
+
41
+ # Return molecules in the dict
42
+ return dict(loss=loss, mols_pred=mols_pred)
43
+
44
+ def configure_optimizers(self):
45
+ # No optimizer needed for a random baseline
46
+ return None
massspecgym/models/de_novo/random.py ADDED
@@ -0,0 +1,1750 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque, defaultdict
2
+ from collections.abc import Generator
3
+ from dataclasses import dataclass
4
+ from random import choice, shuffle
5
+
6
+ import chemparse
7
+ import numpy as np
8
+ import torch
9
+ from massspecgym.models.base import Stage
10
+ from massspecgym.models.de_novo.base import DeNovoMassSpecGymModel
11
+ from rdkit import Chem
12
+ from rdkit.Chem.MolStandardize import rdMolStandardize
13
+ from rdkit.Chem.rdMolDescriptors import CalcMolFormula
14
+ from rdkit.Chem.Descriptors import ExactMolWt
15
+ from rdkit.Chem.rdchem import Mol, BondType
16
+ from copy import deepcopy
17
+ from collections import Counter
18
+ import bisect
19
+ from itertools import combinations
20
+
21
+ # type aliases for code readability
22
+ chem_element = str
23
+ number_of_atoms = int
24
+
25
+
26
+ @dataclass(frozen=True, order=True)
27
+ class ValenceAndCharge:
28
+ """
29
+ A data class to store valence value with the corresponding charge
30
+ """
31
+
32
+ valence: int
33
+ charge: int
34
+
35
+
36
+ @dataclass(frozen=True, order=True)
37
+ class AtomWithValence:
38
+ """
39
+ A data class to store atom info including the computed valence
40
+ """
41
+
42
+ atom_type: chem_element
43
+ atom_valence_and_charge: ValenceAndCharge
44
+
45
+
46
+ @dataclass(frozen=True, order=True)
47
+ class BondToNeighbouringAtom:
48
+ """
49
+ A data class to store info about the adjacent atom
50
+ """
51
+
52
+ adjacent_atom: AtomWithValence
53
+ bond_type: int
54
+
55
+
56
+ @dataclass
57
+ class AtomNodeForRandomTraversal:
58
+ """
59
+ A data class to store atom info including the computed valence
60
+ """
61
+
62
+ atom_with_valence: AtomWithValence
63
+ _remaining_node_degree: int = None
64
+ _remaining_node_charge: int = None
65
+
66
+ def __post_init__(self):
67
+ """Setting up remaining node degree and charge for random traversal"""
68
+ self._remaining_node_degree = (
69
+ self.atom_with_valence.atom_valence_and_charge.valence
70
+ )
71
+ self._remaining_node_charge = (
72
+ self.atom_with_valence.atom_valence_and_charge.charge
73
+ )
74
+
75
+ @property
76
+ def remaining_node_degree(self):
77
+ """remaining_node_degree variable getter"""
78
+ return self._remaining_node_degree
79
+
80
+ @remaining_node_degree.setter
81
+ def remaining_node_degree(self, value: int):
82
+ """remaining_node_degree variable setter"""
83
+ self._remaining_node_degree = value
84
+
85
+ @property
86
+ def remaining_node_charge(self):
87
+ """remaining_node_charge variable getter"""
88
+ return self._remaining_node_charge
89
+
90
+ @remaining_node_charge.setter
91
+ def remaining_node_charge(self, value: int):
92
+ """remaining_node_charge variable setter"""
93
+ self._remaining_node_charge = value
94
+
95
+
96
+ def create_rdkit_molecule_from_edge_list(
97
+ edge_list: list[tuple[int, int]], all_graph_nodes: list[AtomNodeForRandomTraversal]
98
+ ) -> Mol:
99
+ """
100
+ A helper function converting a randomly generated edge list into rdkit.Chem.rdchem.Mol object
101
+ @param edge_list: a list of edges, where each edge is specified by the index of its nodes
102
+ @param all_graph_nodes: a list of all atomic nodes in the molecular graph
103
+ """
104
+ # first we traverse all randomly generated edges and compute bond types between each pair of atoms
105
+ edge_2_bondtype = defaultdict(int)
106
+ for edge_node_i, edge_node_j in edge_list:
107
+ edge_2_bondtype[
108
+ (min(edge_node_i, edge_node_j), max(edge_node_i, edge_node_j))
109
+ ] += 1
110
+
111
+ # helper routine to get the rdking enum bondtype
112
+ def _get_rdkit_bondtype(bondtype: int) -> BondType:
113
+ int_bondtype_2_enum = {
114
+ 1: BondType.SINGLE,
115
+ 2: BondType.DOUBLE,
116
+ 3: BondType.TRIPLE,
117
+ 4: BondType.QUADRUPLE,
118
+ 5: BondType.QUINTUPLE,
119
+ 6: BondType.HEXTUPLE,
120
+ }
121
+ try:
122
+ return int_bondtype_2_enum[bondtype]
123
+ except KeyError:
124
+ raise NotImplementedError(f"Bond type {bondtype} is not supported")
125
+
126
+ edge_list_rdkit = [
127
+ (node_i, node_j, _get_rdkit_bondtype(bondtype))
128
+ for (node_i, node_j), bondtype in edge_2_bondtype.items()
129
+ ]
130
+ # creating an empty editable molecule
131
+ mol = Chem.RWMol()
132
+ # adding the atoms to the molecule object
133
+
134
+ # as some all_graph nodes can represent charges, we have to remember mapping of molecular atom index to
135
+ # the corresponding atom index in all_graph_nodes
136
+ all_graph_atom_idx_2_mol_atom_idx = {}
137
+ for all_graph_atom_idx, atom in enumerate(all_graph_nodes):
138
+ # ignoring charge-related graph nodes
139
+ if atom.atom_with_valence.atom_type not in {"+", "-"}:
140
+ all_graph_atom_idx_2_mol_atom_idx[all_graph_atom_idx] = mol.GetNumAtoms()
141
+ next_atom = Chem.Atom(atom.atom_with_valence.atom_type)
142
+ next_atom.SetFormalCharge(
143
+ atom.atom_with_valence.atom_valence_and_charge.charge
144
+ )
145
+ mol.AddAtom(next_atom)
146
+
147
+ # adding bonds
148
+ for (edge_node_i, edge_node_j, bond_type) in edge_list_rdkit:
149
+ # checking if the edge represents a charge of connected atom
150
+ the_edge_represents_charge = len(
151
+ {
152
+ all_graph_nodes[node_i].atom_with_valence.atom_type
153
+ for node_i in [edge_node_i, edge_node_j]
154
+ }.intersection({"+", "-"})
155
+ )
156
+ if the_edge_represents_charge:
157
+ # setting a charge to the corresponding atom
158
+ for node_i in [edge_node_i, edge_node_j]:
159
+ if all_graph_nodes[node_i].atom_with_valence.atom_type in {"+", "-"}:
160
+ charge_value = (
161
+ 1
162
+ if all_graph_nodes[node_i].atom_with_valence.atom_type == "+"
163
+ else -1
164
+ )
165
+ else:
166
+ atom_node_i = node_i
167
+ mol.GetAtomWithIdx(
168
+ all_graph_atom_idx_2_mol_atom_idx[atom_node_i]
169
+ ).SetFormalCharge(charge_value)
170
+ else:
171
+ mol.AddBond(
172
+ all_graph_atom_idx_2_mol_atom_idx[edge_node_i],
173
+ all_graph_atom_idx_2_mol_atom_idx[edge_node_j],
174
+ bond_type,
175
+ )
176
+ # returning the rdkit.Chem.rdchem.Mol object
177
+ return mol.GetMol()
178
+
179
+
180
+ class RandomDeNovo(DeNovoMassSpecGymModel):
181
+ def __init__(
182
+ self,
183
+ formula_known: bool = True,
184
+ count_of_valid_valence_assignments: int = 10,
185
+ estimate_chem_element_stats: bool = False,
186
+ max_top_k: int = 10,
187
+ enforce_connectivity: bool = True,
188
+ cache_results: bool = True,
189
+ **kwargs
190
+ ):
191
+ """
192
+
193
+ @param formula_known: a boolean flag about the information available prior to generation
194
+ If formula_known is True, we should generate molecules with the specified formula
195
+ If formula_known is False, we should generate any molecule with the specified mass
196
+ @param count_of_valid_valence_assignments: an integer controlling process of selecting valence assignment
197
+ to each atom in the generated molecule.
198
+ `count_of_valid_valence_assignments` of assignment corresponding to
199
+ the formula are generated, then one assignment is is picked at random.
200
+ The default is set to 3 for the computational speed purposes.
201
+ When setting to 1, the first feasible valence assignment will be used.
202
+ @param estimate_chem_element_stats: a boolean flag controlling if prior information about elements' valences
203
+ and bond type distributions is estimated from training data
204
+ @param max_top_k: a maximum number of candidates to generate. If the count of valid valence assignments do
205
+ not allow generation of max_top_k, then less candidates are returned
206
+ @param enforce_connectivity: a boolean flag controlling connectivity of randomly generated molecules.
207
+ When it is set to True, first a random spanning tree is sampled
208
+ @param cache_results: a boolean flag controlling caching of already generated structures.
209
+ When set to True, for each unique formula the set of random molecules is cached to avoid
210
+ recomputation.
211
+ """
212
+ super(RandomDeNovo, self).__init__(**kwargs)
213
+ self.formula_known = formula_known
214
+ self.count_of_valid_valence_assignments = count_of_valid_valence_assignments
215
+ self.estimate_chem_element_stats = estimate_chem_element_stats
216
+ self.max_top_k = min(max(self.top_ks), max_top_k)
217
+ self.enforce_connectivity = enforce_connectivity
218
+ # prior chemical knownledge about element valences
219
+ self.element_2_valences = ELEMENT_VALENCES
220
+ # a dictionary structure to record molecular weights with corresponding formulas from training data
221
+ # during training steps, for each molecular weight we record all encountered formulas
222
+ # then on training end we compute proportions of the formulas and record it as a mapping
223
+ # mol_weight -> [[formula_1, formula_2], [proportion_of_formula_1, proportion_of_formula_2]]
224
+ self.mol_weight_2_formulas = defaultdict(list)
225
+ # a helper array to store sorted list of train molecular weights.
226
+ # It will be used for the O(logn) lookup of the closest mol weight
227
+ self.mol_weight_trn_values: list[float] = None
228
+ # a dictionary structure for statistics about bond type distributions
229
+ # the dictionary has the following mapping:
230
+ # chem_element ->
231
+ # ValenceAndCharge ->
232
+ # number of already bonded atoms ->
233
+ # [already created BondToNeighbouringAtom] ->
234
+ # AtomWithValence ->
235
+ # list of (bond_type, count) + total_count
236
+ self.element_2_bond_stats = None
237
+ # a cache with already precomputed sets of randomly generated molecules for the given formula
238
+ self.formula_2_random_smiles = {}
239
+ self.cache_results = cache_results
240
+
241
+ def generator_for_splits_of_chem_element_atoms_by_possible_valences(
242
+ self,
243
+ atom_type: chem_element,
244
+ possible_valences: list[ValenceAndCharge],
245
+ atom_count: int,
246
+ already_assigned_groups_of_atoms: dict[AtomWithValence, number_of_atoms],
247
+ ) -> Generator[dict[AtomWithValence, number_of_atoms]]:
248
+ """
249
+ A recursive generator function to iterate over all possible partitions of element atoms
250
+ into groups with different valid valences.
251
+ Each allowed valence value can have any number from atoms, from zero up to total `atom_count`
252
+ @param atom_type: chemical element
253
+ @param possible_valences: a list of allowed valences
254
+ @param atom_count: a total number of element atoms to split into valence groups
255
+ @param already_assigned_groups_of_atoms: partial results to pass into the subsequent recursive calls
256
+
257
+ @return A generator for lazy enumeration over all possible splits of `atom_count` atoms into subgroups
258
+ of valid valences specified in `possible valences` parameters.
259
+ Each return value is a dictionary, mapping atom with fixed valence to a total count of such instances
260
+ in the molecule.
261
+
262
+ @note In the future the method can be made into a function in a separate utils module,
263
+ for the simplicity of codebase organization and testing purposes it's kept as the method for now
264
+ """
265
+ # the check for a base case of the recursion
266
+ if atom_count == 0:
267
+ yield already_assigned_groups_of_atoms
268
+ elif len(possible_valences):
269
+ # taking the first valence value from the possible ones
270
+ next_valence = possible_valences[0]
271
+ # iterating over possible sizes for a group of atoms with `next_valence` value of the valence
272
+ for size_of_group in range(atom_count, -1, -1):
273
+ # recording the assigned size of the group
274
+ already_assigned_groups_of_atoms_next = (
275
+ already_assigned_groups_of_atoms.copy()
276
+ )
277
+ atom_with_valence = AtomWithValence(
278
+ atom_type=atom_type, atom_valence_and_charge=next_valence
279
+ )
280
+ already_assigned_groups_of_atoms_next[atom_with_valence] = size_of_group
281
+ yield from self.generator_for_splits_of_chem_element_atoms_by_possible_valences(
282
+ atom_type=atom_type,
283
+ possible_valences=possible_valences[1:],
284
+ atom_count=atom_count - size_of_group,
285
+ already_assigned_groups_of_atoms=already_assigned_groups_of_atoms_next,
286
+ )
287
+
288
+ def assigner_of_valences_to_all_atoms(
289
+ self,
290
+ unassigned_molecule_elements_with_counts: dict[chem_element, number_of_atoms],
291
+ already_assigned_atoms_with_valences: dict[AtomWithValence, number_of_atoms],
292
+ common_valences_only: bool = True,
293
+ ) -> Generator[dict[AtomWithValence, number_of_atoms]]:
294
+ """
295
+ A recursive function to iterate over all possible valid assignments of valences for each atom in the molecule
296
+ @param unassigned_molecule_elements_with_counts: a dictionary representation of a molecule,
297
+ mapping each present element to a corresponding number of atoms.
298
+ The function is recursive, in the subsequence calls
299
+ the dictionary represents an yet-unprocessed submolecule
300
+ @param already_assigned_atoms_with_valences: partial results to pass into the subsequent recursive calls,
301
+ stored as a dictionary, mapping atom with fixed valence
302
+ to a total count of such atoms in the molecule
303
+ @param common_valences_only: a flag for using the common valence values for each element
304
+
305
+ @return A generator for lazy enumeration over all possible assignments of all molecule atoms into subgroups
306
+ defined by valences. Valence values are the valid ones for the corresponding chemical element.
307
+ Each return value is a dictionary, mapping atom of specified chemical element with a fixed valence
308
+ to a total count of such atoms in the molecule.
309
+
310
+ @note In the future the method can be made into a function in a separate utils module,
311
+ for the simplicity of codebase organization and testing purposes it's kept as the method for now
312
+ """
313
+ # the check for a base case of the recursion
314
+ if len(unassigned_molecule_elements_with_counts) == 0:
315
+ yield already_assigned_atoms_with_valences
316
+ else:
317
+ # processing the next chemical element in the molecule
318
+ chem_element_type, atom_count = list(
319
+ unassigned_molecule_elements_with_counts.items()
320
+ )[0]
321
+ # for the subsequence recursive calls the picked atom will be removed from the yet-to-be-processed
322
+ remaining_unassigned_atoms_with_counts = (
323
+ unassigned_molecule_elements_with_counts.copy()
324
+ )
325
+ del remaining_unassigned_atoms_with_counts[chem_element_type]
326
+ # generating splits of the element count into groups with possible valences
327
+ valences_common, valences_others = self.element_2_valences[
328
+ chem_element_type.capitalize()
329
+ ]
330
+ possible_element_valences = (
331
+ valences_common
332
+ if common_valences_only
333
+ else valences_common + valences_others
334
+ )
335
+ # we ignore "the direction" of ionic bonds, therefore we work with absolute values of valences
336
+ possible_element_valences = map(
337
+ lambda x: ValenceAndCharge(valence=np.abs(x.valence), charge=x.charge),
338
+ possible_element_valences,
339
+ )
340
+ # we require a connected molecule graph, so we ignore possible 0 values of valences
341
+ possible_element_valences = list(
342
+ set(filter(lambda x: x.valence > 0, possible_element_valences))
343
+ )
344
+ # creating a generator for lazy enumeration over all possible splits of element atoms
345
+ # into subgroups of possible valid valences
346
+ valence_split_generator = (
347
+ self.generator_for_splits_of_chem_element_atoms_by_possible_valences(
348
+ atom_type=chem_element_type,
349
+ possible_valences=possible_element_valences,
350
+ atom_count=atom_count,
351
+ already_assigned_groups_of_atoms=dict(),
352
+ )
353
+ )
354
+ # iterating over splits of the element count into groups with possible valences
355
+ for element_atoms_with_valence_2_count in valence_split_generator:
356
+ already_assigned_atoms_with_valences_new = (
357
+ already_assigned_atoms_with_valences.copy()
358
+ )
359
+ already_assigned_atoms_with_valences_new.update(
360
+ element_atoms_with_valence_2_count
361
+ )
362
+ yield from self.assigner_of_valences_to_all_atoms(
363
+ unassigned_molecule_elements_with_counts=remaining_unassigned_atoms_with_counts,
364
+ already_assigned_atoms_with_valences=already_assigned_atoms_with_valences_new,
365
+ common_valences_only=common_valences_only,
366
+ )
367
+
368
+ def is_valence_assignment_feasible(
369
+ self, valence_assignment: dict[AtomWithValence, number_of_atoms]
370
+ ) -> bool:
371
+ """
372
+ A function for checking if the valence assignment to all molecule atoms can be feasible
373
+
374
+ @param valence_assignment: an assignment of all molecule atoms into subgroups of plausible valences
375
+
376
+ @note In the future the method can be made into a function in a separate utils module,
377
+ for the simplicity of codebase organization and testing purposes it's kept as the method for now
378
+ """
379
+ # considering a molecule as a graph with atom being nodes and chemical bonds being edges
380
+ # computing sum of all node degrees
381
+ sum_of_all_node_degrees = sum(
382
+ [
383
+ atom.atom_valence_and_charge.valence * count_of_atoms
384
+ for atom, count_of_atoms in valence_assignment.items()
385
+ ]
386
+ )
387
+ if sum_of_all_node_degrees % 2 == 1:
388
+ # the valence assignment is infeasible as in the graph the number of edges is half of the total degrees sum
389
+ # therefore the sum_of_all_node_degrees must be an even number
390
+ return False
391
+ total_number_of_bonds = sum_of_all_node_degrees / 2
392
+ # the total number of all atoms in the whole molecule
393
+ total_number_of_atoms_in_molecule = sum(valence_assignment.values())
394
+ if total_number_of_bonds < total_number_of_atoms_in_molecule - 1:
395
+ # the valence assignment is infeasible as the molecule graph cannot be connected
396
+ return False
397
+ # check that charges add up to zero
398
+ total_charge = 0
399
+ for atom, count_of_atoms in valence_assignment.items():
400
+ # we do not take virtual nodes for the charged molecules, we force the remaining submolecule to be neutral
401
+ if atom.atom_type not in {"+", "-"}:
402
+ total_charge += atom.atom_valence_and_charge.charge * count_of_atoms
403
+ if total_charge != 0:
404
+ return False
405
+ return True
406
+
407
+ def get_feasible_atom_valence_assignments(
408
+ self, chemical_formula: str
409
+ ) -> list[dict[AtomWithValence, number_of_atoms]]:
410
+ """
411
+ A function generating candidate assignments of valences to individual atoms in the molecule.
412
+ Candidates are returned in a random order.
413
+ @param chemical_formula: a string containing the chemical formula of the molecule
414
+
415
+ @note In the future the method can be made into a function in a separate utils module,
416
+ for the simplicity of codebase organization and testing purposes it's kept as the method for now
417
+ """
418
+ # parsing chemical formula into a dictionary of elements with corresponding counts
419
+ element_2_count = {
420
+ element: int(count)
421
+ for element, count in chemparse.parse_formula(chemical_formula).items()
422
+ }
423
+ # checking that all input elements are valid
424
+ for element in element_2_count.keys():
425
+ if element.capitalize() not in self.element_2_valences:
426
+ raise ValueError(
427
+ f"Found an unknown element {element.capitalize()} in the formula {chemical_formula}"
428
+ )
429
+
430
+ # estimate the total number of all atoms in the whole molecule
431
+ # it will be used to check validity of the valence assignments
432
+ total_number_of_atoms_in_molecule = sum(element_2_count.values())
433
+ generated_candidate_valence_assignments = []
434
+ valence_assignment_generator = self.assigner_of_valences_to_all_atoms(
435
+ unassigned_molecule_elements_with_counts=element_2_count,
436
+ already_assigned_atoms_with_valences=dict(),
437
+ common_valences_only=True,
438
+ )
439
+ termination_assignment_value = {AtomWithValence("No more assignments", -1): -1}
440
+ next_valence_assignment = next(
441
+ valence_assignment_generator, termination_assignment_value
442
+ )
443
+ while (
444
+ len(generated_candidate_valence_assignments)
445
+ < self.count_of_valid_valence_assignments
446
+ and next_valence_assignment != termination_assignment_value
447
+ ):
448
+ if self.is_valence_assignment_feasible(next_valence_assignment):
449
+ generated_candidate_valence_assignments.append(next_valence_assignment)
450
+ next_valence_assignment = next(
451
+ valence_assignment_generator, termination_assignment_value
452
+ )
453
+ # if no valence assignment was found with common valences,
454
+ # then try generating assignments including not-common valences
455
+ if len(generated_candidate_valence_assignments) == 0:
456
+ valence_assignment_generator = self.assigner_of_valences_to_all_atoms(
457
+ unassigned_molecule_elements_with_counts=element_2_count,
458
+ already_assigned_atoms_with_valences=dict(),
459
+ common_valences_only=False,
460
+ )
461
+ next_valence_assignment = next(
462
+ valence_assignment_generator, termination_assignment_value
463
+ )
464
+ while (
465
+ len(generated_candidate_valence_assignments)
466
+ < self.count_of_valid_valence_assignments
467
+ and next_valence_assignment != termination_assignment_value
468
+ ):
469
+ if self.is_valence_assignment_feasible(next_valence_assignment):
470
+ generated_candidate_valence_assignments.append(
471
+ next_valence_assignment
472
+ )
473
+ next_valence_assignment = next(
474
+ valence_assignment_generator, termination_assignment_value
475
+ )
476
+
477
+ if len(generated_candidate_valence_assignments) == 0:
478
+ raise ValueError(
479
+ f"No valence assignments can be generated for the formula {chemical_formula}"
480
+ )
481
+ shuffle(generated_candidate_valence_assignments)
482
+ return generated_candidate_valence_assignments
483
+
484
+ def sample_second_edgenode_at_random(
485
+ self,
486
+ edge_start_node_i: int,
487
+ all_graph_nodes: list[AtomNodeForRandomTraversal],
488
+ open_nodes_for_sampling: dict[str, set[int]],
489
+ possible_candidates_type: str,
490
+ closed_set: set[int],
491
+ use_chem_element_stats: bool = False,
492
+ already_connected_neighbours: list[BondToNeighbouringAtom] = None,
493
+ ):
494
+ """
495
+ A function randomly sampling the second node for an edge
496
+ @param edge_start_node_i: index of the first edge node
497
+ @param all_graph_nodes: a list of all nodes in the molecule graph
498
+ @param open_nodes_for_sampling: dictionary with sets of node indices which
499
+ can be considered for closing the edge.
500
+ Each set is specified by the dictionary key:
501
+ "coordinate_bond_negatively_charged_targets",
502
+ "coordinate_bond_positively_charged_targets",
503
+ "covalent_bond_targets"
504
+ @param possible_candidates_type: the `open_nodes_for_sampling` dictionary key
505
+ @param closed_set: closed set for traversal
506
+ @param use_chem_element_stats: a boolean flag setting up usage of per chem. elements statistics about its bonds
507
+ @param already_connected_neighbours: an adjacency list of already sampled neighbours
508
+ """
509
+ if not use_chem_element_stats:
510
+ edge_end_node_j = choice(
511
+ [
512
+ candidate_node_j
513
+ for candidate_node_j in open_nodes_for_sampling[
514
+ possible_candidates_type
515
+ ]
516
+ if candidate_node_j not in closed_set
517
+ ]
518
+ )
519
+ bond_degree = 1
520
+ else:
521
+ # checking the current state of the atom and gathering the corresponding stats
522
+ number_of_already_sampled_neighbours = len(already_connected_neighbours)
523
+ # note that the graph is undirected, start-end node refers to the random traversal only
524
+ start_atom = all_graph_nodes[edge_start_node_i]
525
+ if self.element_2_bond_stats is None:
526
+ raise RuntimeError(
527
+ "To use chem. element stats, the model has to be trained first,"
528
+ "to record training molecular weights with corresponding formulas."
529
+ )
530
+ # the structure of `self.element_2_bond_stats` is
531
+ # chem_element ->
532
+ # ValenceAndCharge ->
533
+ # number of already bonded atoms ->
534
+ # [already created BondToNeighbouringAtom] ->
535
+ # AtomWithValence ->
536
+ # list of (bond_type, count)+ total_count
537
+
538
+ # if we don't have stats -> fall back to sampling from all candidates
539
+ try:
540
+ element_stats = self.element_2_bond_stats[
541
+ start_atom.atom_with_valence.atom_type
542
+ ][
543
+ ValenceAndCharge(
544
+ start_atom.atom_with_valence.atom_valence_and_charge.valence,
545
+ start_atom.atom_with_valence.atom_valence_and_charge.charge,
546
+ )
547
+ ][
548
+ number_of_already_sampled_neighbours
549
+ ][
550
+ tuple(sorted(already_connected_neighbours))
551
+ ]
552
+
553
+ full_candidates_list = []
554
+ neighb_with_stats_candidates_list = []
555
+ neighb_with_stats_bondcounts = []
556
+ neighb_with_stats_bondlists = []
557
+ # iterating over open nodes of the corresponding bond type
558
+ for candidate_node_j in open_nodes_for_sampling[
559
+ possible_candidates_type
560
+ ]:
561
+ if candidate_node_j not in closed_set:
562
+ # remembering all candidates in case no statistic-based option is there
563
+ full_candidates_list.append(candidate_node_j)
564
+ # checking if the candidate is present in element-specific bond stats
565
+ candidate_neighb_atom = all_graph_nodes[
566
+ candidate_node_j
567
+ ].atom_with_valence
568
+ if candidate_neighb_atom in element_stats:
569
+ neighb_with_stats_candidates_list.append(candidate_node_j)
570
+ bondslist, total_bond_count = element_stats[
571
+ candidate_neighb_atom
572
+ ]
573
+ neighb_with_stats_bondcounts.append(total_bond_count)
574
+ neighb_with_stats_bondlists.append(bondslist)
575
+ # when no stats-based neighbour remain (e.g. hydrogens are not recorded in the stats)
576
+ if len(neighb_with_stats_candidates_list) == 0:
577
+ edge_end_node_j = choice(full_candidates_list)
578
+ bond_degree = 1
579
+ else:
580
+ # sampling based on frequences in bond stats
581
+ total_bondcount_sum = sum(neighb_with_stats_bondcounts)
582
+ proportions = [
583
+ val / total_bondcount_sum
584
+ for val in neighb_with_stats_bondcounts
585
+ ]
586
+ edge_end_node_j = np.random.choice(
587
+ neighb_with_stats_candidates_list, p=proportions
588
+ )
589
+ # getting i of the sampled neighbour to access its bond-stats
590
+ neighb_i = neighb_with_stats_candidates_list.index(edge_end_node_j)
591
+ # for the sampled end node, we sample the type of the bond based on the stats
592
+ bondtypes_possible = []
593
+ counts_of_possible_bondtypes = []
594
+ total_possible_bondtype_count = 0
595
+ # we leave only the bonds which current state of random generation allows
596
+ # i.e., we cannot sample a bond violating the current remaining degree of `edge_start_node_i`
597
+ start_node_remaining_degree = all_graph_nodes[
598
+ edge_start_node_i
599
+ ].remaining_node_degree
600
+ for bondtype, count in neighb_with_stats_bondlists[neighb_i]:
601
+ if bondtype <= start_node_remaining_degree:
602
+ bondtypes_possible.append(bondtype)
603
+ counts_of_possible_bondtypes.append(count)
604
+ total_possible_bondtype_count += count
605
+ # if no bonds can be closed for the sampled element, fall back to sampling from full candidates list
606
+ if len(bondtypes_possible) == 0:
607
+ edge_end_node_j = choice(full_candidates_list)
608
+ bond_degree = 1
609
+ else:
610
+ bond_degree_proportions = [
611
+ num / total_possible_bondtype_count
612
+ for num in counts_of_possible_bondtypes
613
+ ]
614
+ bond_degree = np.random.choice(
615
+ bondtypes_possible, p=bond_degree_proportions
616
+ )
617
+ already_connected_neighbours.append(
618
+ BondToNeighbouringAtom(
619
+ adjacent_atom=all_graph_nodes[
620
+ edge_end_node_j
621
+ ].atom_with_valence,
622
+ bond_type=bond_degree,
623
+ )
624
+ )
625
+ except:
626
+ edge_end_node_j = choice(
627
+ [
628
+ candidate_node_j
629
+ for candidate_node_j in open_nodes_for_sampling[
630
+ possible_candidates_type
631
+ ]
632
+ if candidate_node_j not in closed_set
633
+ ]
634
+ )
635
+ bond_degree = 1
636
+ return edge_end_node_j, bond_degree, already_connected_neighbours
637
+
638
+ def sample_edge_at_random(
639
+ self,
640
+ all_graph_nodes: list[AtomNodeForRandomTraversal],
641
+ open_nodes_for_sampling: dict[str, set[int]],
642
+ edge_start_node_i: int = None,
643
+ closed_set: set[int] = None,
644
+ use_chem_element_stats: bool = False,
645
+ atom_2_already_connected_neighbours: list[list[BondToNeighbouringAtom]] = None,
646
+ ) -> tuple[tuple[int, int], list[AtomNodeForRandomTraversal], set[int]]:
647
+ """
648
+ Helper function to filter atoms suitable for generation of a random bond with `edge_start_node_i`
649
+ and sampling a random edge
650
+ @param all_graph_nodes: a list of all nodes in the molecule graph
651
+ @param edge_start_node_i: index of the first edge node
652
+ @param open_nodes_for_sampling: dictionary with sets of node indices which
653
+ can be considered for closing the edge.
654
+ Each set is specified by the dictionary key:
655
+ "coordinate_bond_negatively_charged_targets",
656
+ "coordinate_bond_positively_charged_targets",
657
+ "covalent_bond_targets"
658
+ @param use_chem_element_stats: a boolean flag setting up usage of per chem. elements statistics about its bonds
659
+ @param closed_set: closed set for traversal
660
+ @param atom_2_already_connected_neighbours: a mapping from atom to its adjacency list of already sampled neighbours
661
+ @return: a sampled edge and updated structures `all_graph_nodes`, `open_nodes_for_sampling`
662
+ """
663
+ # sample the start node for the edge if it's not specified
664
+ if edge_start_node_i is None:
665
+ edge_start_node_i = choice(
666
+ sum(map(list, open_nodes_for_sampling.values()), [])
667
+ )
668
+ if closed_set is None:
669
+ closed_set = {edge_start_node_i}
670
+ # check if the start edge atom has the charge and therefore can form coordinate bond
671
+ can_form_coordinate_bond = (
672
+ all_graph_nodes[edge_start_node_i].remaining_node_charge != 0
673
+ )
674
+ # if possible, create coordinate bond at random
675
+ is_bond_coordinate = can_form_coordinate_bond and np.random.rand() < 0.5
676
+ if is_bond_coordinate:
677
+ start_node_charge_sign = np.sign(
678
+ all_graph_nodes[edge_start_node_i].remaining_node_charge
679
+ )
680
+ # if for the coordinate bond one atom is positively charged, then another must be charged negatively
681
+ if start_node_charge_sign > 0:
682
+ possible_candidates_type = "coordinate_bond_neg_charged_targets"
683
+ else:
684
+ possible_candidates_type = "coordinate_bond_pos_charged_targets"
685
+ else:
686
+ possible_candidates_type = "covalent_bond_targets"
687
+
688
+ (
689
+ edge_end_node_j,
690
+ node_degree_reduction,
691
+ atom_2_already_connected_neighbours[edge_start_node_i],
692
+ ) = self.sample_second_edgenode_at_random(
693
+ edge_start_node_i,
694
+ all_graph_nodes,
695
+ open_nodes_for_sampling,
696
+ possible_candidates_type,
697
+ closed_set,
698
+ use_chem_element_stats,
699
+ atom_2_already_connected_neighbours[edge_start_node_i],
700
+ )
701
+
702
+ # decrease the node degrees correspondingly
703
+ for node_of_a_new_edge_i in [edge_start_node_i, edge_end_node_j]:
704
+ all_graph_nodes[
705
+ node_of_a_new_edge_i
706
+ ].remaining_node_degree -= node_degree_reduction
707
+ # if all bonds are created for the particular atom, it is no more open for traversal
708
+ if all_graph_nodes[node_of_a_new_edge_i].remaining_node_degree <= 0:
709
+ for candidates_type in open_nodes_for_sampling.keys():
710
+ if node_of_a_new_edge_i in open_nodes_for_sampling[candidates_type]:
711
+ open_nodes_for_sampling[candidates_type].remove(
712
+ node_of_a_new_edge_i
713
+ )
714
+ # if the added bond was coordinate, modify the remaining charges correspondingly
715
+ elif is_bond_coordinate:
716
+ new_charge_abs_value = (
717
+ np.abs(all_graph_nodes[node_of_a_new_edge_i].remaining_node_charge)
718
+ - 1
719
+ )
720
+ # check if the node still can form coordinate bonds
721
+ if new_charge_abs_value == 0:
722
+ for candidates_type in [
723
+ "coordinate_bond_neg_charged_targets",
724
+ "coordinate_bond_pos_charged_targets",
725
+ ]:
726
+ if (
727
+ node_of_a_new_edge_i
728
+ in open_nodes_for_sampling[candidates_type]
729
+ ):
730
+ open_nodes_for_sampling[candidates_type].remove(
731
+ node_of_a_new_edge_i
732
+ )
733
+ else:
734
+ charge_sign = np.sign(
735
+ all_graph_nodes[node_of_a_new_edge_i].remaining_node_charge
736
+ )
737
+ all_graph_nodes[node_of_a_new_edge_i].remaining_node_charge = (
738
+ charge_sign * new_charge_abs_value
739
+ )
740
+ return (
741
+ (edge_start_node_i, edge_end_node_j),
742
+ all_graph_nodes,
743
+ open_nodes_for_sampling,
744
+ )
745
+
746
+ def generate_random_molecule_graphs_via_traversal(
747
+ self,
748
+ chemical_formula: str,
749
+ max_number_of_retries_per_valence_assignment: int = 100,
750
+ ) -> list[Mol]:
751
+ """
752
+ A function generating random molecule graph(s).
753
+ The generation process ensures that each graph is connected.
754
+ If any of the `self.count_of_valid_valence_assignments` enables it,
755
+ the function returns graph(s) without self-loops.
756
+
757
+ @param chemical_formula: a string containing the chemical formula of the molecule
758
+ @param max_number_of_retries_per_valence_assignment: a max count of attempts to generate a random spanning tree
759
+ for a given potentially feasible valence assignment
760
+
761
+ @note In the future the method can be made into a function in a separate utils module,
762
+ for the simplicity of codebase organization and testing purposes it's kept as the method for now
763
+ """
764
+ # check if for the input formula the random structures have been already generated
765
+ if self.cache_results and chemical_formula in self.formula_2_random_smiles:
766
+ return self.formula_2_random_smiles[chemical_formula]
767
+
768
+ # get candidate partitions of all molecule atoms into valences
769
+ candidate_valence_assignments = self.get_feasible_atom_valence_assignments(
770
+ chemical_formula
771
+ )
772
+ # iterate over each valence assignment to all atoms, the order is random
773
+ assert (
774
+ len(candidate_valence_assignments) > 0
775
+ ), f"No potentially feasible atom valence assignment for {chemical_formula}"
776
+ # number of iteration over feasible valence assignments
777
+ num_of_iterations_over_splits_into_valences = int(
778
+ np.ceil(self.max_top_k / len(candidate_valence_assignments))
779
+ )
780
+ generated_molecules = []
781
+
782
+ # we request to generate self.max_top_k molecule(s)
783
+ while len(generated_molecules) < self.max_top_k:
784
+ for _ in range(num_of_iterations_over_splits_into_valences):
785
+ for valence_assignment in candidate_valence_assignments:
786
+ # first randomly create a spanning tree of the molecule graph, to ensure the connectivity of molecule.
787
+ # The feasibility check `self.is_valence_assignment_feasible` inside the
788
+ # `self.get_feasible_atom_valence_assignments` function should ensure the possibility to create the tree.
789
+ spanning_tree_was_generated = False
790
+ spanning_tree_generation_attempts = 0
791
+ while (
792
+ not spanning_tree_was_generated
793
+ and spanning_tree_generation_attempts
794
+ < max_number_of_retries_per_valence_assignment
795
+ ):
796
+ spanning_tree_generation_attempts += 1
797
+ # we optimistically set the value of `spanning_tree_was_generated` to True,
798
+ # If the current traversal do not lead to a spanning tree,
799
+ # then `spanning_tree_was_generated` is set to False in the code below
800
+ spanning_tree_was_generated = True
801
+
802
+ # prepare node list for a random edges generation
803
+ all_graph_nodes = []
804
+ for (
805
+ atom_with_valence,
806
+ num_of_atoms_in_molecule,
807
+ ) in valence_assignment.items():
808
+ for _ in range(num_of_atoms_in_molecule):
809
+ all_graph_nodes.append(
810
+ AtomNodeForRandomTraversal(
811
+ atom_with_valence=atom_with_valence
812
+ )
813
+ )
814
+
815
+ # a helper structure to record already sampled bonds
816
+ # it is used only if we use estimated chem elements stats in the generation process
817
+ atom_2_already_connected_neighbours = [
818
+ [] for _ in range(len(all_graph_nodes))
819
+ ]
820
+
821
+ # recording sets of nodes available for random sampling of covalent and coordinate bonds
822
+ coordinate_bond_neg_charged_targets = {
823
+ node_i
824
+ for node_i, node in enumerate(all_graph_nodes)
825
+ if np.sign(node.remaining_node_charge) == -1
826
+ }
827
+ coordinate_bond_pos_charged_targets = {
828
+ node_i
829
+ for node_i, node in enumerate(all_graph_nodes)
830
+ if np.sign(node.remaining_node_charge) == 1
831
+ }
832
+ covalent_bond_targets = {
833
+ node_i
834
+ for node_i, node in enumerate(all_graph_nodes)
835
+ if node.remaining_node_charge == 0
836
+ or node.remaining_node_degree
837
+ > np.abs(node.remaining_node_charge)
838
+ }
839
+
840
+ open_nodes_for_sampling = {
841
+ "coordinate_bond_neg_charged_targets": coordinate_bond_neg_charged_targets,
842
+ "coordinate_bond_pos_charged_targets": coordinate_bond_pos_charged_targets,
843
+ "covalent_bond_targets": covalent_bond_targets,
844
+ }
845
+
846
+ # the final edge list will be stored into the variable below.
847
+ # An edge is defined by a pair of position indices in the `all_graph_nodes` list
848
+ edge_list = []
849
+
850
+ # the nodes already included into the spanning tree
851
+ # the set is used for quick blacklisting, while the list is used for possible backtracking when
852
+ (
853
+ spanning_tree_visited_nodes_set,
854
+ spanning_tree_traversal_list,
855
+ ) = (
856
+ set(),
857
+ deque(),
858
+ )
859
+ # sample a random start of spanning tree generation
860
+ edge_start_node_i = choice(list(range(len(all_graph_nodes))))
861
+ spanning_tree_visited_nodes_set.add(edge_start_node_i)
862
+ spanning_tree_traversal_list.append(edge_start_node_i)
863
+ while self.enforce_connectivity and len(
864
+ spanning_tree_visited_nodes_set
865
+ ) < len(all_graph_nodes):
866
+ # check if the start edge atom has the charge and therefore can form coordinate bond
867
+ try:
868
+ (
869
+ (edge_start_node_i, edge_end_node_i),
870
+ all_graph_nodes,
871
+ open_nodes_for_sampling,
872
+ ) = self.sample_edge_at_random(
873
+ all_graph_nodes,
874
+ open_nodes_for_sampling,
875
+ edge_start_node_i=edge_start_node_i,
876
+ closed_set=spanning_tree_visited_nodes_set,
877
+ use_chem_element_stats=self.estimate_chem_element_stats,
878
+ atom_2_already_connected_neighbours=atom_2_already_connected_neighbours,
879
+ )
880
+ except IndexError:
881
+ spanning_tree_was_generated = False
882
+ break
883
+ # note that the graph is undirected, start-end node refers to the random traversal only
884
+ edge_list.append((edge_start_node_i, edge_end_node_i))
885
+ # recording the node added to the random spanning tree
886
+ spanning_tree_visited_nodes_set.add(edge_end_node_i)
887
+ spanning_tree_traversal_list.append(edge_end_node_i)
888
+
889
+ # finding a start node for the next sampled edge.
890
+ # We have to ensure that such a node still has some degree not covered by sampling nodes.
891
+ # For that, we might need to backtrack.
892
+ candidate_for_start_node_i = edge_end_node_i
893
+ try:
894
+ while (
895
+ all_graph_nodes[
896
+ candidate_for_start_node_i
897
+ ].remaining_node_degree
898
+ == 0
899
+ ):
900
+ spanning_tree_traversal_list.pop()
901
+ candidate_for_start_node_i = (
902
+ spanning_tree_traversal_list[-1]
903
+ )
904
+ except IndexError:
905
+ spanning_tree_was_generated = False
906
+ break
907
+ edge_start_node_i = candidate_for_start_node_i
908
+
909
+ # after the spanning tree edges were sampled,
910
+ # now we randomly connect nodes with remaining degrees yet uncovered by sampled bonds
911
+ while sum(map(len, open_nodes_for_sampling.values())) >= 2:
912
+ try:
913
+ (
914
+ (edge_start_node_i, edge_end_node_i),
915
+ all_graph_nodes,
916
+ open_nodes_for_sampling,
917
+ ) = self.sample_edge_at_random(
918
+ all_graph_nodes,
919
+ open_nodes_for_sampling,
920
+ use_chem_element_stats=self.estimate_chem_element_stats,
921
+ atom_2_already_connected_neighbours=atom_2_already_connected_neighbours,
922
+ )
923
+ except IndexError:
924
+ break
925
+ edge_list.append((edge_start_node_i, edge_end_node_i))
926
+
927
+ # if all nodes were covered by edges without self-loops, then we remember the generated molecule
928
+ if sum(map(len, open_nodes_for_sampling.values())) == 0:
929
+ generated_molecules.append(
930
+ create_rdkit_molecule_from_edge_list(
931
+ edge_list, all_graph_nodes
932
+ )
933
+ )
934
+ if len(generated_molecules) == self.max_top_k:
935
+ if self.cache_results:
936
+ self.formula_2_random_smiles[
937
+ chemical_formula
938
+ ] = generated_molecules
939
+ return generated_molecules
940
+ if self.cache_results:
941
+ self.formula_2_random_smiles[chemical_formula] = generated_molecules
942
+ return generated_molecules
943
+
944
+ def training_step(
945
+ self, batch: dict, batch_idx: torch.Tensor
946
+ ) -> tuple[torch.Tensor, torch.Tensor]:
947
+ # recording statistics about chemical elements
948
+ if self.estimate_chem_element_stats:
949
+ if self.element_2_bond_stats is None:
950
+ self.element_2_bond_stats = defaultdict(dict)
951
+ for mol_smiles in batch["mol"]:
952
+ molecule = Chem.MolFromSmiles(mol_smiles)
953
+ # in order to work with double and single bonds instead of aromatic
954
+ Chem.Kekulize(molecule, clearAromaticFlags=True)
955
+ # we add hydrogen atoms (which are ommited by default)
956
+ molecule = Chem.AddHs(molecule)
957
+ formula = CalcMolFormula(molecule)
958
+ for atom in molecule.GetAtoms():
959
+ valence = atom.GetTotalValence()
960
+ charge = atom.GetFormalCharge()
961
+ valence_charge = ValenceAndCharge(valence, charge)
962
+ chem_element_type = atom.GetSymbol()
963
+ atom_bonds = atom.GetBonds()
964
+ if (
965
+ valence_charge
966
+ not in self.element_2_bond_stats[chem_element_type]
967
+ ):
968
+ # for each value of atom's valence, number of neighbours we will count types of neighbouring atoms
969
+ self.element_2_bond_stats[chem_element_type][
970
+ valence_charge
971
+ ] = dict()
972
+
973
+ all_atom_neighbours = set()
974
+ for bond in atom_bonds:
975
+ start_atom_idx = atom.GetIdx()
976
+ end_atom = [
977
+ _atom
978
+ for _atom in [bond.GetBeginAtom(), bond.GetEndAtom()]
979
+ if _atom.GetIdx() != start_atom_idx
980
+ ][0]
981
+ all_atom_neighbours.add(
982
+ (
983
+ end_atom.GetSymbol(),
984
+ end_atom.GetTotalValence(),
985
+ end_atom.GetFormalCharge(),
986
+ bond.GetBondTypeAsDouble(),
987
+ )
988
+ )
989
+
990
+ all_neighbour_subsets = [
991
+ [set(subset) for subset in combinations(all_atom_neighbours, r)]
992
+ for r in range(len(all_atom_neighbours))
993
+ ]
994
+ for neighbour_subsets_of_fixed_size in all_neighbour_subsets:
995
+ subset_size = len(neighbour_subsets_of_fixed_size[0])
996
+ # for each number of already connected atoms we record the neighbours and then possible bonds yet to be closed
997
+ if (
998
+ subset_size
999
+ not in self.element_2_bond_stats[chem_element_type][
1000
+ valence_charge
1001
+ ]
1002
+ ):
1003
+ self.element_2_bond_stats[chem_element_type][
1004
+ valence_charge
1005
+ ][subset_size] = dict()
1006
+ for neighbours in neighbour_subsets_of_fixed_size:
1007
+ neighbours_tuple = tuple(sorted(neighbours))
1008
+ if (
1009
+ neighbours_tuple
1010
+ not in self.element_2_bond_stats[chem_element_type][
1011
+ valence_charge
1012
+ ][subset_size]
1013
+ ):
1014
+ self.element_2_bond_stats[chem_element_type][
1015
+ valence_charge
1016
+ ][subset_size][neighbours_tuple] = defaultdict(int)
1017
+ remaining_bonds = all_atom_neighbours.difference(neighbours)
1018
+ for remaining_bonded_neighbour in remaining_bonds:
1019
+ self.element_2_bond_stats[chem_element_type][
1020
+ valence_charge
1021
+ ][subset_size][neighbours_tuple][
1022
+ remaining_bonded_neighbour
1023
+ ] += 1
1024
+
1025
+ # recording molecular weight
1026
+ for mol_smiles in batch["mol"]:
1027
+ molecule = Chem.MolFromSmiles(mol_smiles)
1028
+ formula = CalcMolFormula(molecule)
1029
+ weight = ExactMolWt(molecule)
1030
+ self.mol_weight_2_formulas[weight].append(formula)
1031
+ # Random baseline, so we return a dummy loss
1032
+ loss = torch.tensor(0.0, requires_grad=True)
1033
+ return dict(loss=loss, mols_pred=["C"])
1034
+
1035
+ def on_train_end(self) -> None:
1036
+ # for each molecular weight we compute proportions of recorded molecular formulas
1037
+ molecular_weight_2_formula_counts = {
1038
+ weight: Counter(formulas)
1039
+ for weight, formulas in self.mol_weight_2_formulas.items()
1040
+ }
1041
+ weight_2_formula_proportions = {}
1042
+ for weight, formula_2_count in molecular_weight_2_formula_counts.items():
1043
+ total_count = sum(formula_2_count.values())
1044
+ weight_2_formula_proportions[weight] = {
1045
+ formula: count / total_count
1046
+ for formula, count in formula_2_count.items()
1047
+ }
1048
+ # for consequent sampling using numpy.random.choice function, we store the results in the format
1049
+ # weight -> [[formula_1, formula_2], [proportion_of_formula_1, proportion_of_formula_2]]
1050
+ self.mol_weight_2_formulas = {
1051
+ weight: [
1052
+ list(formula_2_proportions.keys()),
1053
+ list(formula_2_proportions.values()),
1054
+ ]
1055
+ for weight, formula_2_proportions in weight_2_formula_proportions.items()
1056
+ }
1057
+ # storing weights in the sorted list for the logarithmic time look-up of the closest weight value
1058
+ self.mol_weight_trn_values = sorted(self.mol_weight_2_formulas.keys())
1059
+
1060
+ # if chem element stats are used, then the corresponding data structure is reformated in accordance with the
1061
+ # description from docstring to the class __init__:
1062
+ # chem_element ->
1063
+ # ValenceAndCharge ->
1064
+ # number of already bonded atoms ->
1065
+ # [already created BondToNeighbouringAtom] ->
1066
+ # AtomWithValence ->
1067
+ # list of (bond_type, count) + total_count
1068
+ if self.estimate_chem_element_stats:
1069
+ element_2_bond_stats = defaultdict(dict)
1070
+ for (
1071
+ chem_element,
1072
+ valence_charge_2_stats,
1073
+ ) in self.element_2_bond_stats.items():
1074
+ for valence_charge, num_bonds_2_stats in valence_charge_2_stats.items():
1075
+ element_2_bond_stats[chem_element][valence_charge] = dict()
1076
+ for num_bonds, bonds_2_stats in num_bonds_2_stats.items():
1077
+ element_2_bond_stats[chem_element][valence_charge][
1078
+ num_bonds
1079
+ ] = dict()
1080
+ for (
1081
+ bonds,
1082
+ neighb_atom_with_valence_2_stats,
1083
+ ) in bonds_2_stats.items():
1084
+ present_bonds_sorted = tuple(
1085
+ sorted(
1086
+ [
1087
+ BondToNeighbouringAtom(
1088
+ adjacent_atom=AtomWithValence(
1089
+ atom_type=bond[0],
1090
+ atom_valence_and_charge=ValenceAndCharge(
1091
+ valence=bond[1], charge=bond[2]
1092
+ ),
1093
+ ),
1094
+ bond_type=bond[3],
1095
+ )
1096
+ for bond in bonds
1097
+ ]
1098
+ )
1099
+ )
1100
+ element_2_bond_stats[chem_element][valence_charge][
1101
+ num_bonds
1102
+ ][present_bonds_sorted] = defaultdict(list)
1103
+ for (
1104
+ neighb_atom_type,
1105
+ neighb_atom_valence,
1106
+ neighb_atom_charge,
1107
+ bondtype,
1108
+ ), count in neighb_atom_with_valence_2_stats.items():
1109
+ neighbouring_atom = AtomWithValence(
1110
+ atom_type=neighb_atom_type,
1111
+ atom_valence_and_charge=ValenceAndCharge(
1112
+ valence=neighb_atom_valence,
1113
+ charge=neighb_atom_charge,
1114
+ ),
1115
+ )
1116
+ element_2_bond_stats[chem_element][valence_charge][
1117
+ num_bonds
1118
+ ][present_bonds_sorted][neighbouring_atom].append(
1119
+ (bondtype, count)
1120
+ )
1121
+ # computing total count of all bound per neighbouring atom
1122
+ for (
1123
+ neighbouring_atom,
1124
+ list_of_bondtype_counts,
1125
+ ) in element_2_bond_stats[chem_element][valence_charge][
1126
+ num_bonds
1127
+ ][
1128
+ present_bonds_sorted
1129
+ ].items():
1130
+ total_count_of_bonds = sum(
1131
+ map(lambda x: x[1], list_of_bondtype_counts)
1132
+ )
1133
+ element_2_bond_stats[chem_element][valence_charge][
1134
+ num_bonds
1135
+ ][present_bonds_sorted][neighbouring_atom] = (
1136
+ list_of_bondtype_counts,
1137
+ total_count_of_bonds,
1138
+ )
1139
+ self.element_2_bond_stats = element_2_bond_stats
1140
+
1141
+ def sample_formula_with_the_closest_molecular_weight(
1142
+ self, molecular_weight: float
1143
+ ) -> str:
1144
+ """
1145
+ A method sampling chemical formula observed in training data with the closest weight to `molecular_weight`
1146
+ @param molecular_weight: Molecular weight of a structure to be generated
1147
+ """
1148
+ if self.mol_weight_trn_values is None:
1149
+ raise RuntimeError(
1150
+ "For random denovo generation without known formula, the model has to be trained first,"
1151
+ "to record training molecular weights with corresponding formulas."
1152
+ )
1153
+ # finding a place in the sorted array for insertion of the `molecular_weight`, while preserving sorted order
1154
+ idx_of_closest_larger = bisect.bisect_left(
1155
+ self.mol_weight_trn_values, molecular_weight
1156
+ )
1157
+ # check if the exact same molecular weight was observed in training data, otherwise select the closest weight
1158
+ if molecular_weight == self.mol_weight_trn_values[idx_of_closest_larger]:
1159
+ idx_of_closest = idx_of_closest_larger
1160
+ elif idx_of_closest_larger > 0:
1161
+ # determining the closest molecular weight out of both neighbours
1162
+ idx_of_closest_smaller = idx_of_closest_larger - 1
1163
+ weight_difference_with_smaller_neighbour = (
1164
+ molecular_weight - self.mol_weight_trn_values[idx_of_closest_smaller]
1165
+ )
1166
+ weight_difference_with_larger_neighbour = (
1167
+ self.mol_weight_trn_values[idx_of_closest_larger] - molecular_weight
1168
+ )
1169
+ if (
1170
+ weight_difference_with_larger_neighbour
1171
+ < weight_difference_with_smaller_neighbour
1172
+ ):
1173
+ idx_of_closest = idx_of_closest_larger
1174
+ else:
1175
+ idx_of_closest = idx_of_closest_smaller
1176
+ else:
1177
+ idx_of_closest = 0
1178
+ # the value of the molecular weight observed in training labels, which is the closest to `molecular_weight`
1179
+ closest_observed_molecular_weight = self.mol_weight_trn_values[idx_of_closest]
1180
+ # getting chemical formulas observed for this molecular weight
1181
+ # self.mol_weight_2_formulas is a dictionary containing the following mapping
1182
+ # weight -> [[formula_1, formula_2], [proportion_of_formula_1, proportion_of_formula_2]]
1183
+ feasible_formulas, formula_proportions = self.mol_weight_2_formulas[
1184
+ closest_observed_molecular_weight
1185
+ ]
1186
+ # if just one formula is known, it is returned directly
1187
+ if len(feasible_formulas) == 1:
1188
+ return feasible_formulas[0]
1189
+ # otherwise we randomly sample in accordance with proportions
1190
+ return np.random.choice(feasible_formulas, p=formula_proportions)
1191
+
1192
+ def step(
1193
+ self, batch: dict, stage: Stage = Stage.NONE
1194
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1195
+ mols = batch["mol"] # List of SMILES of length batch_size
1196
+
1197
+ # If formula_known is True, we should generate molecules with the same formula as label (`mols` above)
1198
+ # If formula_known is False, we should generate any molecule with the same mass as label
1199
+
1200
+ # obtaining molecule objects from SMILES
1201
+ molecules = [Chem.MolFromSmiles(smiles) for smiles in mols]
1202
+ # getting the formulas
1203
+ if self.formula_known:
1204
+ formulas = [CalcMolFormula(molecule) for molecule in molecules]
1205
+ else:
1206
+ molecular_weights = [ExactMolWt(molecule) for molecule in molecules]
1207
+ formulas = [
1208
+ self.sample_formula_with_the_closest_molecular_weight(mol_weight)
1209
+ for mol_weight in molecular_weights
1210
+ ]
1211
+ # (bs, k) list of rdkit molecules
1212
+ mols_pred = [
1213
+ self.generate_random_molecule_graphs_via_traversal(formula)
1214
+ for formula in formulas
1215
+ ]
1216
+
1217
+ for predicted_mol_group in mols_pred:
1218
+ for mol in predicted_mol_group:
1219
+ Chem.RemoveHs(mol)
1220
+
1221
+ # list of predicted smiles
1222
+ smiles_pred = [
1223
+ [
1224
+ Chem.MolToSmiles(mol_candidate)
1225
+ for mol_candidate in candidates_per_input_mol
1226
+ ]
1227
+ for candidates_per_input_mol in mols_pred
1228
+ ]
1229
+
1230
+ # Random baseline, so we return a dummy loss
1231
+ loss = torch.tensor(0.0, requires_grad=True)
1232
+ return dict(loss=loss, mols_pred=smiles_pred)
1233
+
1234
+ def configure_optimizers(self):
1235
+ # No optimizer needed for a random baseline
1236
+ return None
1237
+
1238
+
1239
+ # element valences taken from sources like https://sciencenotes.org/element-valency-pdf
1240
+ # the first list contains the typical valences, each tuple is a valence value with the corresponding charge
1241
+ ELEMENT_VALENCES = {
1242
+ "H": (
1243
+ [ValenceAndCharge(valence=1, charge=0)],
1244
+ [ValenceAndCharge(valence=0, charge=0), ValenceAndCharge(valence=1, charge=-1)],
1245
+ ),
1246
+ "He": ([ValenceAndCharge(valence=0, charge=0)], []),
1247
+ "Li": (
1248
+ [ValenceAndCharge(valence=1, charge=0)],
1249
+ [ValenceAndCharge(valence=1, charge=-1)],
1250
+ ),
1251
+ "Be": ([ValenceAndCharge(valence=2, charge=0)], []),
1252
+ "B": (
1253
+ [ValenceAndCharge(valence=3, charge=0), ValenceAndCharge(valence=4, charge=-1)],
1254
+ [ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=0)],
1255
+ ),
1256
+ "C": (
1257
+ [ValenceAndCharge(valence=4, charge=0)],
1258
+ [
1259
+ ValenceAndCharge(valence=3, charge=-1),
1260
+ ValenceAndCharge(valence=2, charge=0),
1261
+ ValenceAndCharge(valence=2, charge=-1),
1262
+ ValenceAndCharge(valence=1, charge=0),
1263
+ ValenceAndCharge(valence=1, charge=-1),
1264
+ ],
1265
+ ),
1266
+ "N": (
1267
+ [ValenceAndCharge(valence=3, charge=0), ValenceAndCharge(valence=4, charge=1)],
1268
+ [
1269
+ ValenceAndCharge(valence=2, charge=-1),
1270
+ ValenceAndCharge(valence=5, charge=0),
1271
+ ValenceAndCharge(valence=1, charge=0),
1272
+ ValenceAndCharge(valence=0, charge=0),
1273
+ ValenceAndCharge(valence=1, charge=-1),
1274
+ ],
1275
+ ),
1276
+ "O": (
1277
+ [ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=-1)],
1278
+ [ValenceAndCharge(valence=3, charge=1)],
1279
+ ),
1280
+ "F": ([ValenceAndCharge(valence=1, charge=0)], []),
1281
+ "Ne": ([ValenceAndCharge(valence=0, charge=0)], []),
1282
+ "Na": (
1283
+ [ValenceAndCharge(valence=1, charge=0)],
1284
+ [],
1285
+ ),
1286
+ "Mg": ([ValenceAndCharge(valence=2, charge=0)], []),
1287
+ "Al": (
1288
+ [ValenceAndCharge(valence=3, charge=0)],
1289
+ [ValenceAndCharge(valence=1, charge=0)],
1290
+ ),
1291
+ "Si": (
1292
+ [ValenceAndCharge(valence=4, charge=0)],
1293
+ [],
1294
+ ),
1295
+ "P": (
1296
+ [ValenceAndCharge(valence=5, charge=0)],
1297
+ [
1298
+ ValenceAndCharge(valence=4, charge=1),
1299
+ ValenceAndCharge(valence=3, charge=0),
1300
+ ValenceAndCharge(valence=2, charge=0),
1301
+ ValenceAndCharge(valence=1, charge=0),
1302
+ ],
1303
+ ),
1304
+ "S": (
1305
+ [ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=6, charge=0)],
1306
+ [
1307
+ ValenceAndCharge(valence=4, charge=0),
1308
+ ValenceAndCharge(valence=1, charge=-1),
1309
+ ValenceAndCharge(valence=3, charge=1),
1310
+ ],
1311
+ ),
1312
+ "Cl": (
1313
+ [ValenceAndCharge(valence=1, charge=0)],
1314
+ [],
1315
+ ),
1316
+ "Ar": ([ValenceAndCharge(valence=0, charge=0)], []),
1317
+ "K": (
1318
+ [ValenceAndCharge(valence=1, charge=0)],
1319
+ [],
1320
+ ),
1321
+ "Ca": ([ValenceAndCharge(valence=2, charge=0)], []),
1322
+ "Sc": (
1323
+ [ValenceAndCharge(valence=3, charge=0)],
1324
+ [ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=0)],
1325
+ ),
1326
+ "Ti": (
1327
+ [ValenceAndCharge(valence=4, charge=0)],
1328
+ [
1329
+ ValenceAndCharge(valence=3, charge=0),
1330
+ ValenceAndCharge(valence=2, charge=0),
1331
+ ValenceAndCharge(valence=0, charge=0),
1332
+ ],
1333
+ ),
1334
+ "V": (
1335
+ [
1336
+ ValenceAndCharge(valence=5, charge=0),
1337
+ ValenceAndCharge(valence=4, charge=0),
1338
+ ValenceAndCharge(valence=3, charge=0),
1339
+ ],
1340
+ [
1341
+ ValenceAndCharge(valence=2, charge=0),
1342
+ ValenceAndCharge(valence=1, charge=0),
1343
+ ValenceAndCharge(valence=0, charge=0),
1344
+ ],
1345
+ ),
1346
+ "Cr": (
1347
+ [
1348
+ ValenceAndCharge(valence=6, charge=0),
1349
+ ValenceAndCharge(valence=3, charge=0),
1350
+ ValenceAndCharge(valence=2, charge=0),
1351
+ ],
1352
+ [
1353
+ ValenceAndCharge(valence=5, charge=0),
1354
+ ValenceAndCharge(valence=4, charge=0),
1355
+ ValenceAndCharge(valence=1, charge=0),
1356
+ ValenceAndCharge(valence=0, charge=0),
1357
+ ],
1358
+ ),
1359
+ "Mn": (
1360
+ [
1361
+ ValenceAndCharge(valence=7, charge=0),
1362
+ ValenceAndCharge(valence=4, charge=0),
1363
+ ValenceAndCharge(valence=2, charge=0),
1364
+ ],
1365
+ [
1366
+ ValenceAndCharge(valence=6, charge=0),
1367
+ ValenceAndCharge(valence=5, charge=0),
1368
+ ValenceAndCharge(valence=3, charge=0),
1369
+ ValenceAndCharge(valence=1, charge=0),
1370
+ ValenceAndCharge(valence=0, charge=0),
1371
+ ],
1372
+ ),
1373
+ "Fe": (
1374
+ [ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=3, charge=0)],
1375
+ [
1376
+ ValenceAndCharge(valence=6, charge=0),
1377
+ ValenceAndCharge(valence=5, charge=0),
1378
+ ValenceAndCharge(valence=4, charge=0),
1379
+ ValenceAndCharge(valence=1, charge=0),
1380
+ ValenceAndCharge(valence=0, charge=0),
1381
+ ],
1382
+ ),
1383
+ "Co": (
1384
+ [ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=3, charge=0)],
1385
+ [
1386
+ ValenceAndCharge(valence=5, charge=0),
1387
+ ValenceAndCharge(valence=4, charge=0),
1388
+ ValenceAndCharge(valence=1, charge=0),
1389
+ ValenceAndCharge(valence=0, charge=0),
1390
+ ],
1391
+ ),
1392
+ "Ni": (
1393
+ [ValenceAndCharge(valence=2, charge=0)],
1394
+ [
1395
+ ValenceAndCharge(valence=6, charge=0),
1396
+ ValenceAndCharge(valence=4, charge=0),
1397
+ ValenceAndCharge(valence=3, charge=0),
1398
+ ValenceAndCharge(valence=1, charge=0),
1399
+ ValenceAndCharge(valence=0, charge=0),
1400
+ ],
1401
+ ),
1402
+ "Cu": (
1403
+ [ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=0)],
1404
+ [
1405
+ ValenceAndCharge(valence=4, charge=0),
1406
+ ValenceAndCharge(valence=3, charge=0),
1407
+ ValenceAndCharge(valence=0, charge=0),
1408
+ ],
1409
+ ),
1410
+ "Zn": (
1411
+ [ValenceAndCharge(valence=2, charge=0)],
1412
+ [ValenceAndCharge(valence=1, charge=0), ValenceAndCharge(valence=0, charge=0)],
1413
+ ),
1414
+ "Ga": (
1415
+ [ValenceAndCharge(valence=3, charge=0)],
1416
+ [ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=0)],
1417
+ ),
1418
+ "Ge": (
1419
+ [ValenceAndCharge(valence=4, charge=0)],
1420
+ [
1421
+ ValenceAndCharge(valence=3, charge=0),
1422
+ ValenceAndCharge(valence=2, charge=0),
1423
+ ValenceAndCharge(valence=1, charge=0),
1424
+ ],
1425
+ ),
1426
+ "As": (
1427
+ [ValenceAndCharge(valence=5, charge=0), ValenceAndCharge(valence=4, charge=1)],
1428
+ [],
1429
+ ),
1430
+ "Se": (
1431
+ [ValenceAndCharge(valence=2, charge=0)],
1432
+ [],
1433
+ ),
1434
+ "Br": (
1435
+ [ValenceAndCharge(valence=1, charge=0)],
1436
+ [],
1437
+ ),
1438
+ "Kr": (
1439
+ [ValenceAndCharge(valence=0, charge=0)],
1440
+ [ValenceAndCharge(valence=2, charge=0)],
1441
+ ),
1442
+ "Rb": (
1443
+ [ValenceAndCharge(valence=1, charge=0)],
1444
+ [],
1445
+ ),
1446
+ "Sr": ([ValenceAndCharge(valence=2, charge=0)], []),
1447
+ "Y": (
1448
+ [ValenceAndCharge(valence=3, charge=0)],
1449
+ [ValenceAndCharge(valence=2, charge=0)],
1450
+ ),
1451
+ "Zr": (
1452
+ [ValenceAndCharge(valence=4, charge=0)],
1453
+ [
1454
+ ValenceAndCharge(valence=3, charge=0),
1455
+ ValenceAndCharge(valence=2, charge=0),
1456
+ ValenceAndCharge(valence=1, charge=0),
1457
+ ValenceAndCharge(valence=0, charge=0),
1458
+ ],
1459
+ ),
1460
+ "Nb": (
1461
+ [ValenceAndCharge(valence=5, charge=0)],
1462
+ [
1463
+ ValenceAndCharge(valence=4, charge=0),
1464
+ ValenceAndCharge(valence=3, charge=0),
1465
+ ValenceAndCharge(valence=2, charge=0),
1466
+ ValenceAndCharge(valence=1, charge=0),
1467
+ ValenceAndCharge(valence=0, charge=0),
1468
+ ],
1469
+ ),
1470
+ "Mo": (
1471
+ [ValenceAndCharge(valence=6, charge=0), ValenceAndCharge(valence=4, charge=0)],
1472
+ [
1473
+ ValenceAndCharge(valence=5, charge=0),
1474
+ ValenceAndCharge(valence=3, charge=0),
1475
+ ValenceAndCharge(valence=2, charge=0),
1476
+ ValenceAndCharge(valence=1, charge=0),
1477
+ ValenceAndCharge(valence=0, charge=0),
1478
+ ],
1479
+ ),
1480
+ "Tc": (
1481
+ [ValenceAndCharge(valence=7, charge=0), ValenceAndCharge(valence=4, charge=0)],
1482
+ [
1483
+ ValenceAndCharge(valence=6, charge=0),
1484
+ ValenceAndCharge(valence=5, charge=0),
1485
+ ValenceAndCharge(valence=3, charge=0),
1486
+ ValenceAndCharge(valence=2, charge=0),
1487
+ ValenceAndCharge(valence=1, charge=0),
1488
+ ValenceAndCharge(valence=0, charge=0),
1489
+ ],
1490
+ ),
1491
+ "Ru": (
1492
+ [ValenceAndCharge(valence=4, charge=0), ValenceAndCharge(valence=3, charge=0)],
1493
+ [
1494
+ ValenceAndCharge(valence=8, charge=0),
1495
+ ValenceAndCharge(valence=7, charge=0),
1496
+ ValenceAndCharge(valence=6, charge=0),
1497
+ ValenceAndCharge(valence=5, charge=0),
1498
+ ValenceAndCharge(valence=2, charge=0),
1499
+ ValenceAndCharge(valence=1, charge=0),
1500
+ ValenceAndCharge(valence=0, charge=0),
1501
+ ],
1502
+ ),
1503
+ "Rh": (
1504
+ [ValenceAndCharge(valence=3, charge=0)],
1505
+ [
1506
+ ValenceAndCharge(valence=6, charge=0),
1507
+ ValenceAndCharge(valence=5, charge=0),
1508
+ ValenceAndCharge(valence=4, charge=0),
1509
+ ValenceAndCharge(valence=2, charge=0),
1510
+ ValenceAndCharge(valence=1, charge=0),
1511
+ ValenceAndCharge(valence=0, charge=0),
1512
+ ],
1513
+ ),
1514
+ "Pd": (
1515
+ [ValenceAndCharge(valence=4, charge=0), ValenceAndCharge(valence=2, charge=0)],
1516
+ [ValenceAndCharge(valence=0, charge=0)],
1517
+ ),
1518
+ "Ag": (
1519
+ [ValenceAndCharge(valence=1, charge=0)],
1520
+ [
1521
+ ValenceAndCharge(valence=3, charge=0),
1522
+ ValenceAndCharge(valence=2, charge=0),
1523
+ ValenceAndCharge(valence=0, charge=0),
1524
+ ],
1525
+ ),
1526
+ "Cd": (
1527
+ [ValenceAndCharge(valence=2, charge=0)],
1528
+ [ValenceAndCharge(valence=1, charge=0)],
1529
+ ),
1530
+ "In": (
1531
+ [ValenceAndCharge(valence=3, charge=0)],
1532
+ [ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=0)],
1533
+ ),
1534
+ "Sn": (
1535
+ [ValenceAndCharge(valence=2, charge=0)],
1536
+ [ValenceAndCharge(valence=4, charge=0)],
1537
+ ),
1538
+ "Sb": (
1539
+ [ValenceAndCharge(valence=3, charge=0)],
1540
+ [ValenceAndCharge(valence=5, charge=0), ValenceAndCharge(valence=3, charge=-1)],
1541
+ ),
1542
+ "Te": (
1543
+ [ValenceAndCharge(valence=4, charge=0)],
1544
+ [ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=6, charge=0)],
1545
+ ),
1546
+ "I": (
1547
+ [ValenceAndCharge(valence=1, charge=0)],
1548
+ [
1549
+ ValenceAndCharge(valence=1, charge=0),
1550
+ ValenceAndCharge(valence=3, charge=0),
1551
+ ValenceAndCharge(valence=7, charge=0),
1552
+ ValenceAndCharge(valence=0, charge=0),
1553
+ ],
1554
+ ),
1555
+ "Xe": (
1556
+ [ValenceAndCharge(valence=0, charge=0)],
1557
+ [
1558
+ ValenceAndCharge(valence=2, charge=0),
1559
+ ValenceAndCharge(valence=4, charge=0),
1560
+ ValenceAndCharge(valence=6, charge=0),
1561
+ ValenceAndCharge(valence=8, charge=0),
1562
+ ],
1563
+ ),
1564
+ "Cs": (
1565
+ [ValenceAndCharge(valence=1, charge=0)],
1566
+ [],
1567
+ ),
1568
+ "Ba": ([ValenceAndCharge(valence=2, charge=0)], []),
1569
+ "La": ([ValenceAndCharge(valence=3, charge=0)], []),
1570
+ "Ce": (
1571
+ [ValenceAndCharge(valence=3, charge=0)],
1572
+ [ValenceAndCharge(valence=4, charge=0)],
1573
+ ),
1574
+ "Pr": (
1575
+ [ValenceAndCharge(valence=3, charge=0)],
1576
+ [ValenceAndCharge(valence=4, charge=0)],
1577
+ ),
1578
+ "Nd": ([ValenceAndCharge(valence=3, charge=0)], []),
1579
+ "Pm": ([ValenceAndCharge(valence=3, charge=0)], []),
1580
+ "Sm": ([ValenceAndCharge(valence=3, charge=0)], []),
1581
+ "Eu": (
1582
+ [ValenceAndCharge(valence=3, charge=0)],
1583
+ [ValenceAndCharge(valence=2, charge=0)],
1584
+ ),
1585
+ "Gd": ([ValenceAndCharge(valence=3, charge=0)], []),
1586
+ "Tb": (
1587
+ [ValenceAndCharge(valence=3, charge=0)],
1588
+ [ValenceAndCharge(valence=4, charge=0)],
1589
+ ),
1590
+ "Dy": ([ValenceAndCharge(valence=3, charge=0)], []),
1591
+ "Ho": ([ValenceAndCharge(valence=3, charge=0)], []),
1592
+ "Er": ([ValenceAndCharge(valence=3, charge=0)], []),
1593
+ "Tm": ([ValenceAndCharge(valence=3, charge=0)], []),
1594
+ "Yb": (
1595
+ [ValenceAndCharge(valence=3, charge=0)],
1596
+ [ValenceAndCharge(valence=2, charge=0)],
1597
+ ),
1598
+ "Lu": ([ValenceAndCharge(valence=3, charge=0)], []),
1599
+ "Hf": ([ValenceAndCharge(valence=4, charge=0)], []),
1600
+ "Ta": ([ValenceAndCharge(valence=5, charge=0)], []),
1601
+ "W": (
1602
+ [ValenceAndCharge(valence=6, charge=0), ValenceAndCharge(valence=4, charge=0)],
1603
+ [
1604
+ ValenceAndCharge(valence=5, charge=0),
1605
+ ValenceAndCharge(valence=3, charge=0),
1606
+ ValenceAndCharge(valence=2, charge=0),
1607
+ ],
1608
+ ),
1609
+ "Re": (
1610
+ [
1611
+ ValenceAndCharge(valence=5, charge=0),
1612
+ ValenceAndCharge(valence=4, charge=0),
1613
+ ValenceAndCharge(valence=3, charge=0),
1614
+ ],
1615
+ [
1616
+ ValenceAndCharge(valence=7, charge=0),
1617
+ ValenceAndCharge(valence=6, charge=0),
1618
+ ValenceAndCharge(valence=2, charge=0),
1619
+ ValenceAndCharge(valence=1, charge=0),
1620
+ ValenceAndCharge(valence=0, charge=0),
1621
+ ],
1622
+ ),
1623
+ "Os": (
1624
+ [ValenceAndCharge(valence=4, charge=0)],
1625
+ [
1626
+ ValenceAndCharge(valence=8, charge=0),
1627
+ ValenceAndCharge(valence=6, charge=0),
1628
+ ValenceAndCharge(valence=2, charge=0),
1629
+ ],
1630
+ ),
1631
+ "Ir": (
1632
+ [ValenceAndCharge(valence=4, charge=0), ValenceAndCharge(valence=3, charge=0)],
1633
+ [ValenceAndCharge(valence=6, charge=0), ValenceAndCharge(valence=4, charge=0)],
1634
+ ),
1635
+ "Pt": (
1636
+ [ValenceAndCharge(valence=2, charge=0)],
1637
+ [ValenceAndCharge(valence=4, charge=0)],
1638
+ ),
1639
+ "Au": (
1640
+ [ValenceAndCharge(valence=3, charge=0)],
1641
+ [ValenceAndCharge(valence=1, charge=0)],
1642
+ ),
1643
+ "Hg": (
1644
+ [ValenceAndCharge(valence=2, charge=0)],
1645
+ [ValenceAndCharge(valence=1, charge=0)],
1646
+ ),
1647
+ "Tl": (
1648
+ [ValenceAndCharge(valence=3, charge=0)],
1649
+ [ValenceAndCharge(valence=1, charge=0)],
1650
+ ),
1651
+ "Pb": (
1652
+ [ValenceAndCharge(valence=4, charge=0)],
1653
+ [ValenceAndCharge(valence=2, charge=0)],
1654
+ ),
1655
+ "Bi": (
1656
+ [ValenceAndCharge(valence=3, charge=0), ValenceAndCharge(valence=1, charge=0)],
1657
+ [ValenceAndCharge(valence=5, charge=0)],
1658
+ ),
1659
+ "Po": (
1660
+ [ValenceAndCharge(valence=4, charge=0)],
1661
+ [ValenceAndCharge(valence=2, charge=0)],
1662
+ ),
1663
+ "At": (
1664
+ [ValenceAndCharge(valence=1, charge=0)],
1665
+ [
1666
+ ValenceAndCharge(valence=5, charge=0),
1667
+ ValenceAndCharge(valence=3, charge=0),
1668
+ ValenceAndCharge(valence=7, charge=0),
1669
+ ],
1670
+ ),
1671
+ "Rn": (
1672
+ [ValenceAndCharge(valence=0, charge=0)],
1673
+ [ValenceAndCharge(valence=2, charge=0)],
1674
+ ),
1675
+ "Fr": ([ValenceAndCharge(valence=1, charge=0)], []),
1676
+ "Ra": ([ValenceAndCharge(valence=2, charge=0)], []),
1677
+ "Ac": ([ValenceAndCharge(valence=3, charge=0)], []),
1678
+ "Th": ([ValenceAndCharge(valence=4, charge=0)], []),
1679
+ "Pa": (
1680
+ [ValenceAndCharge(valence=5, charge=0)],
1681
+ [ValenceAndCharge(valence=4, charge=0)],
1682
+ ),
1683
+ "U": (
1684
+ [ValenceAndCharge(valence=6, charge=0)],
1685
+ [
1686
+ ValenceAndCharge(valence=5, charge=0),
1687
+ ValenceAndCharge(valence=4, charge=0),
1688
+ ValenceAndCharge(valence=3, charge=0),
1689
+ ],
1690
+ ),
1691
+ "Np": (
1692
+ [ValenceAndCharge(valence=7, charge=0)],
1693
+ [
1694
+ ValenceAndCharge(valence=6, charge=0),
1695
+ ValenceAndCharge(valence=5, charge=0),
1696
+ ValenceAndCharge(valence=4, charge=0),
1697
+ ValenceAndCharge(valence=3, charge=0),
1698
+ ],
1699
+ ),
1700
+ "Pu": (
1701
+ [ValenceAndCharge(valence=7, charge=0), ValenceAndCharge(valence=4, charge=0)],
1702
+ [
1703
+ ValenceAndCharge(valence=6, charge=0),
1704
+ ValenceAndCharge(valence=5, charge=0),
1705
+ ValenceAndCharge(valence=3, charge=0),
1706
+ ],
1707
+ ),
1708
+ "Am": (
1709
+ [ValenceAndCharge(valence=3, charge=0)],
1710
+ [ValenceAndCharge(valence=5, charge=0), ValenceAndCharge(valence=4, charge=0)],
1711
+ ),
1712
+ "Cm": (
1713
+ [
1714
+ ValenceAndCharge(valence=6, charge=0),
1715
+ ValenceAndCharge(valence=5, charge=0),
1716
+ ValenceAndCharge(valence=3, charge=0),
1717
+ ],
1718
+ [],
1719
+ ),
1720
+ "Bk": (
1721
+ [ValenceAndCharge(valence=3, charge=0)],
1722
+ [ValenceAndCharge(valence=4, charge=0)],
1723
+ ),
1724
+ "Cf": ([ValenceAndCharge(valence=3, charge=0)], []),
1725
+ "Es": ([ValenceAndCharge(valence=3, charge=0)], []),
1726
+ "Fm": ([ValenceAndCharge(valence=3, charge=0)], []),
1727
+ "Md": ([ValenceAndCharge(valence=3, charge=0)], []),
1728
+ "No": (
1729
+ [ValenceAndCharge(valence=3, charge=0)],
1730
+ [ValenceAndCharge(valence=2, charge=0)],
1731
+ ),
1732
+ "Lr": ([ValenceAndCharge(valence=3, charge=0)], []),
1733
+ "Rf": ([ValenceAndCharge(valence=4, charge=0)], []),
1734
+ "Db": ([ValenceAndCharge(valence=5, charge=0)], []),
1735
+ "Sg": ([ValenceAndCharge(valence=6, charge=0)], []),
1736
+ "Bh": ([ValenceAndCharge(valence=7, charge=0)], []),
1737
+ "Hs": ([ValenceAndCharge(valence=8, charge=0)], []),
1738
+ "Mt": ([ValenceAndCharge(valence=8, charge=0)], []),
1739
+ "Ds": ([ValenceAndCharge(valence=8, charge=0)], []),
1740
+ "Rg": ([ValenceAndCharge(valence=8, charge=0)], []),
1741
+ "Cn": ([ValenceAndCharge(valence=2, charge=0)], []),
1742
+ "Nh": ([ValenceAndCharge(valence=3, charge=0)], []),
1743
+ "Fl": ([ValenceAndCharge(valence=4, charge=0)], []),
1744
+ "Mc": ([ValenceAndCharge(valence=3, charge=0)], []),
1745
+ "Lv": ([ValenceAndCharge(valence=4, charge=0)], []),
1746
+ "Ts": ([ValenceAndCharge(valence=7, charge=0)], []),
1747
+ "Og": ([ValenceAndCharge(valence=0, charge=0)], []),
1748
+ "+": ([ValenceAndCharge(valence=1, charge=1)], []),
1749
+ "-": ([ValenceAndCharge(valence=1, charge=-1)], []),
1750
+ }
massspecgym/models/de_novo/smiles_tranformer.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import typing as T
5
+ from torch_geometric.nn import MLP
6
+ from massspecgym.models.tokenizers import SpecialTokensBaseTokenizer
7
+ from massspecgym.data.transforms import MolToFormulaVector
8
+ from massspecgym.models.base import Stage
9
+ from massspecgym.models.de_novo.base import DeNovoMassSpecGymModel
10
+ from massspecgym.definitions import PAD_TOKEN, SOS_TOKEN, EOS_TOKEN
11
+
12
+
13
+ class SmilesTransformer(DeNovoMassSpecGymModel):
14
+ def __init__(
15
+ self,
16
+ input_dim: int,
17
+ d_model: int,
18
+ nhead: int,
19
+ num_encoder_layers: int,
20
+ num_decoder_layers: int,
21
+ smiles_tokenizer: SpecialTokensBaseTokenizer,
22
+ start_token: str = SOS_TOKEN,
23
+ end_token: str = EOS_TOKEN,
24
+ pad_token: str = PAD_TOKEN,
25
+ dropout: float = 0.1,
26
+ max_smiles_len: int = 200,
27
+ k_predictions: int = 1,
28
+ temperature: T.Optional[float] = 1.0,
29
+ pre_norm: bool = False,
30
+ chemical_formula: bool = False,
31
+ *args,
32
+ **kwargs
33
+ ):
34
+ super().__init__(*args, **kwargs)
35
+ self.smiles_tokenizer = smiles_tokenizer
36
+ self.vocab_size = smiles_tokenizer.get_vocab_size()
37
+ for token in [start_token, end_token, pad_token]:
38
+ assert token in smiles_tokenizer.get_vocab(), f"Token {token} not found in tokenizer vocabulary."
39
+ self.start_token_id = smiles_tokenizer.token_to_id(start_token)
40
+ self.end_token_id = smiles_tokenizer.token_to_id(end_token)
41
+ self.pad_token_id = smiles_tokenizer.token_to_id(pad_token)
42
+
43
+ self.d_model = d_model
44
+ self.max_smiles_len = max_smiles_len
45
+ self.k_predictions = k_predictions
46
+ self.temperature = temperature
47
+ if self.k_predictions == 1: # TODO: this logic should be changed because sampling with k = 1 also makes sense
48
+ self.temperature = None
49
+
50
+ self.src_encoder = nn.Linear(input_dim, d_model)
51
+ self.tgt_embedding = nn.Embedding(self.vocab_size, d_model)
52
+ self.transformer = nn.Transformer(
53
+ d_model=d_model,
54
+ nhead=nhead,
55
+ num_encoder_layers=num_encoder_layers,
56
+ num_decoder_layers=num_decoder_layers,
57
+ dim_feedforward=4 * d_model,
58
+ dropout=dropout,
59
+ norm_first=pre_norm
60
+ )
61
+ self.tgt_decoder = nn.Linear(d_model, self.vocab_size)
62
+
63
+ self.chemical_formula = chemical_formula
64
+ if self.chemical_formula:
65
+ self.formula_mlp = MLP(
66
+ in_channels=MolToFormulaVector.num_elements(),
67
+ hidden_channels=MolToFormulaVector.num_elements(),
68
+ out_channels=d_model,
69
+ num_layers=1,
70
+ dropout=dropout,
71
+ norm=None
72
+ )
73
+
74
+ self.criterion = nn.CrossEntropyLoss()
75
+
76
+ def forward(self, batch):
77
+
78
+ spec = batch["spec"] # (batch_size, seq_len, in_dim)
79
+ smiles = batch["mol"] # List of SMILES of length batch_size
80
+
81
+ smiles = self.smiles_tokenizer.encode_batch(smiles)
82
+ smiles = [s.ids for s in smiles]
83
+ smiles = torch.tensor(smiles, device=spec.device) # (batch_size, seq_len)
84
+
85
+ # Generating padding masks for variable-length sequences
86
+ src_key_padding_mask = self.generate_src_padding_mask(spec)
87
+ tgt_key_padding_mask = self.generate_tgt_padding_mask(smiles)
88
+
89
+ # Create target mask (causal mask)
90
+ tgt_seq_len = smiles.size(1)
91
+ tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_seq_len).to(smiles.device)
92
+
93
+ # Preapre inputs for transformer teacher forcing
94
+ src = spec.permute(1, 0, 2) # (seq_len, batch_size, in_dim)
95
+ smiles = smiles.permute(1, 0) # (seq_len, batch_size)
96
+ tgt = smiles[:-1, :]
97
+ tgt_mask = tgt_mask[:-1, :-1]
98
+ src_key_padding_mask = src_key_padding_mask
99
+ tgt_key_padding_mask = tgt_key_padding_mask[:, :-1]
100
+
101
+ # Input and output embeddings
102
+ src = self.src_encoder(src) # (seq_len, batch_size, d_model)
103
+ if self.chemical_formula:
104
+ formula_emb = self.formula_mlp(batch["formula"]) # (batch_size, d_model)
105
+ src = src + formula_emb.unsqueeze(0) # (seq_len, batch_size, d_model) + (1, batch_size, d_model)
106
+ src = src * (self.d_model**0.5)
107
+ tgt = self.tgt_embedding(tgt) * (self.d_model**0.5) # (seq_len, batch_size, d_model)
108
+
109
+ # Transformer forward pass
110
+ memory = self.transformer.encoder(src, src_key_padding_mask=src_key_padding_mask)
111
+ output = self.transformer.decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask)
112
+
113
+ # Logits to vocabulary
114
+ output = self.tgt_decoder(output) # (seq_len, batch_size, vocab_size)
115
+
116
+ # Reshape before returning
117
+ smiles_pred = output.view(-1, self.vocab_size)
118
+ smiles = smiles[1:, :].contiguous().view(-1)
119
+ return smiles_pred, smiles
120
+
121
+ def step(self, batch: dict, stage: Stage = Stage.NONE) -> dict:
122
+
123
+ # Forward pass
124
+ smiles_pred, smiles = self.forward(batch)
125
+
126
+ # Compute loss
127
+ loss = self.criterion(smiles_pred, smiles)
128
+
129
+ # Generate SMILES strings
130
+ if stage in self.log_only_loss_at_stages:
131
+ mols_pred = None
132
+ else:
133
+ mols_pred = self.decode_smiles(batch)
134
+
135
+ return dict(loss=loss, mols_pred=mols_pred)
136
+
137
+ def generate_src_padding_mask(self, spec):
138
+ return spec.sum(-1) == 0
139
+
140
+ def generate_tgt_padding_mask(self, smiles):
141
+ return smiles == self.pad_token_id
142
+
143
+ def decode_smiles(self, batch):
144
+
145
+ decoded_smiles_str = []
146
+ for _ in range(self.k_predictions):
147
+ decoded_smiles = self.greedy_decode(
148
+ batch,
149
+ max_len=self.max_smiles_len,
150
+ temperature=self.temperature,
151
+ )
152
+
153
+ decoded_smiles = [seq.tolist() for seq in decoded_smiles]
154
+ decoded_smiles_str.append(self.smiles_tokenizer.decode_batch(decoded_smiles))
155
+
156
+ # Transpose from (k, batch_size) to (batch_size, k)
157
+ decoded_smiles_str = list(map(list, zip(*decoded_smiles_str)))
158
+
159
+ return decoded_smiles_str
160
+
161
+ def greedy_decode(self, batch, max_len, temperature):
162
+
163
+ with torch.inference_mode():
164
+
165
+ spec = batch["spec"] # (batch_size, seq_len, in_dim)
166
+ src_key_padding_mask = self.generate_src_padding_mask(spec)
167
+
168
+ spec = spec.permute(1, 0, 2) # (seq_len, batch_size, in_dim)
169
+ src = self.src_encoder(spec) # (seq_len, batch_size, d_model)
170
+ if self.chemical_formula:
171
+ formula_emb = self.formula_mlp(batch["formula"]) # (batch_size, d_model)
172
+ src = src + formula_emb.unsqueeze(0) # (seq_len, batch_size, d_model) + (1, batch_size, d_model)
173
+ src = src * (self.d_model**0.5)
174
+ memory = self.transformer.encoder(src, src_key_padding_mask=src_key_padding_mask)
175
+
176
+ batch_size = src.size(1)
177
+ out_tokens = torch.ones(1, batch_size).fill_(self.start_token_id).type(torch.long).to(spec.device)
178
+
179
+ for _ in range(max_len - 1):
180
+ tgt = self.tgt_embedding(out_tokens) * (self.d_model**0.5)
181
+ tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(0)).to(src.device)
182
+ out = self.transformer.decoder(tgt, memory, tgt_mask=tgt_mask)
183
+ out = self.tgt_decoder(out[-1, :]) # (batch_size, vocab_size)
184
+
185
+ # Select next token
186
+ if self.temperature is None:
187
+ probs = F.softmax(out, dim=-1)
188
+ next_token = torch.argmax(probs, dim=-1) # (batch_size,)
189
+ else:
190
+ probs = F.softmax(out / temperature, dim=-1)
191
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1) # (batch_size,)
192
+
193
+ next_token = next_token.unsqueeze(0) # (1, batch_size)
194
+
195
+ out_tokens = torch.cat([out_tokens, next_token], dim=0)
196
+ if torch.all(next_token == self.end_token_id):
197
+ break
198
+
199
+ out_tokens = out_tokens.permute(1, 0) # (batch_size, seq_len)
200
+ return out_tokens
massspecgym/models/layers.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Reproduced from https://github.com/pluskal-lab/DreaMS/blob/main/dreams/models/layers/fourier_features.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from math import ceil
6
+
7
+
8
+ class FourierFeatures(nn.Module):
9
+ """
10
+ A module for generating Fourier features for input data. This module maps input data
11
+ to a higher-dimensional space using sinusoidal functions, enhancing the representation
12
+ capabilities for various tasks.
13
+
14
+ Args:
15
+ strategy (str): Strategy for generating frequency components. Available options are
16
+ 'random', 'voronov_et_al', and 'dreams'. Each option corresponds to a certain paper:
17
+ - 'random': https://doi.org/10.48550/arXiv.2006.10739.
18
+ - 'voronov_et_al': https://doi.org/10.48550/arXiv.2207.02980.
19
+ - 'dreams': https://doi.org/10.26434/chemrxiv-2023-kss3r-v2.
20
+ x_min (float, optional): The minimum value for generating frequencies. Defaults to 1e-4.
21
+ x_max (float, optional): The maximum value for generating frequencies. Defaults to 1000.
22
+ trainable (bool, optional): If True, the frequencies are treated as trainable parameters.
23
+ Defaults to False.
24
+ funcs (str, optional): Specifies the trigonometric functions to use. Options are 'both',
25
+ 'sin', and 'cos'. Defaults to 'both'.
26
+ sigma (float, optional): Standard deviation used for random frequency initialization
27
+ when strategy is 'random'. Defaults to 10.
28
+ num_freqs (int, optional): Number of frequency components to generate. Defaults to 512.
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ strategy='dreams',
34
+ x_min=1e-4,
35
+ x_max=1000,
36
+ trainable=False,
37
+ funcs="both",
38
+ sigma=10,
39
+ num_freqs=512,
40
+ ):
41
+ assert funcs in {"both", "sin", "cos"}, "funcs must be 'both', 'sin', or 'cos'"
42
+ assert 0 < x_min < 1, "x_min must be a positive fraction"
43
+
44
+ super().__init__()
45
+ self.funcs = funcs
46
+ self.strategy = strategy
47
+ self.trainable = trainable
48
+ self.num_freqs = num_freqs
49
+
50
+ if strategy == "random":
51
+ self.b = torch.randn(num_freqs) * sigma
52
+ elif self.strategy == "voronov_et_al":
53
+ self.b = torch.tensor(
54
+ [
55
+ 1 / (x_min * (x_max / x_min) ** (2 * i / (num_freqs - 2)))
56
+ for i in range(1, num_freqs)
57
+ ],
58
+ )
59
+ elif self.strategy == "dreams":
60
+ self.b = torch.tensor(
61
+ [1 / (x_min * i) for i in range(2, ceil(1 / x_min), 2)]
62
+ + [1 / (1 * i) for i in range(2, ceil(x_max), 1)],
63
+ )
64
+ else:
65
+ raise ValueError(f"Unknown strategy: {strategy}")
66
+
67
+ self.b = self.b.unsqueeze(0)
68
+ self.b = nn.Parameter(self.b, requires_grad=self.trainable)
69
+ self.register_parameter("Fourier frequencies", self.b)
70
+
71
+ @property
72
+ def num_features(self):
73
+ """
74
+ Returns the number of features generated by the FourierFeatures module.
75
+ If both sine and cosine functions are used, the number of features is doubled.
76
+
77
+ Returns:
78
+ int: The number of features.
79
+ """
80
+ return self.b.shape[1] if self.funcs != "both" else 2 * self.b.shape[1]
81
+
82
+ def forward(self, x):
83
+ """
84
+ Applies the Fourier transformation to the input data.
85
+
86
+ Args:
87
+ x (torch.Tensor): Input tensor of shape (batch_size, input_dim) to transform.
88
+
89
+ Returns:
90
+ torch.Tensor: Fourier features.
91
+ """
92
+ x = 2 * torch.pi * x @ self.b
93
+
94
+ if self.funcs == "both":
95
+ x = torch.cat((torch.cos(x), torch.sin(x)), dim=-1)
96
+ elif self.funcs == "cos":
97
+ x = torch.cos(x)
98
+ elif self.funcs == "sin":
99
+ x = torch.sin(x)
100
+
101
+ return x
massspecgym/models/retrieval/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base import RetrievalMassSpecGymModel
2
+ from .random import RandomRetrieval
3
+ from .deepsets import DeepSetsRetrieval
4
+ from .fingerprint_ffn import FingerprintFFNRetrieval
5
+ from .from_dict import FromDictRetrieval
6
+
7
+ __all__ = [
8
+ "RetrievalMassSpecGymModel",
9
+ "RandomRetrieval",
10
+ "DeepSetsRetrieval",
11
+ "FingerprintFFNRetrieval",
12
+ "FromDictRetrieval"
13
+ ]
massspecgym/models/retrieval/base.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as T
2
+ from abc import ABC
3
+
4
+ import pandas as pd
5
+ import torch
6
+ from torchmetrics import CosineSimilarity, MeanMetric
7
+ from torchmetrics.functional.retrieval import retrieval_hit_rate
8
+ from torch_geometric.utils import unbatch
9
+
10
+ from massspecgym.models.base import MassSpecGymModel, Stage
11
+ import massspecgym.utils as utils
12
+
13
+
14
+ class RetrievalMassSpecGymModel(MassSpecGymModel, ABC):
15
+
16
+ def __init__(
17
+ self,
18
+ at_ks: T.Iterable[int] = (1, 5, 20),
19
+ myopic_mces_kwargs: T.Optional[T.Mapping] = None,
20
+ *args,
21
+ **kwargs
22
+ ):
23
+ super().__init__(*args, **kwargs)
24
+ self.at_ks = at_ks
25
+ self.myopic_mces = utils.MyopicMCES(**(myopic_mces_kwargs or {}))
26
+
27
+ def on_batch_end(
28
+ self, outputs: T.Any, batch: dict, batch_idx: int, stage: Stage
29
+ ) -> None:
30
+ """
31
+ Compute evaluation metrics for the retrieval model based on the batch and corresponding
32
+ predictions.
33
+ """
34
+ self.log(
35
+ f"{stage.to_pref()}loss",
36
+ outputs['loss'],
37
+ batch_size=batch['spec'].size(0),
38
+ sync_dist=True,
39
+ prog_bar=True,
40
+ )
41
+ if stage in self.log_only_loss_at_stages:
42
+ return
43
+
44
+ metric_vals = {}
45
+ metric_vals |= self.evaluate_retrieval_step(
46
+ outputs["scores"],
47
+ batch["labels"],
48
+ batch["batch_ptr"],
49
+ stage=stage,
50
+ )
51
+ metric_vals |= self.evaluate_mces_at_1(
52
+ outputs["scores"],
53
+ batch["labels"],
54
+ batch["smiles"],
55
+ batch["candidates_smiles"],
56
+ batch["batch_ptr"],
57
+ stage=stage,
58
+ )
59
+ if stage == Stage.TEST and self.df_test_path is not None:
60
+ self._update_df_test(metric_vals)
61
+
62
+ def evaluate_retrieval_step(
63
+ self,
64
+ scores: torch.Tensor,
65
+ labels: torch.Tensor,
66
+ batch_ptr: torch.Tensor,
67
+ stage: Stage,
68
+ ) -> dict[str, torch.Tensor]:
69
+ """
70
+ Main evaluation method for the retrieval models. The retrieval step is evaluated by
71
+ computing the hit rate at different top-k values.
72
+
73
+ Args:
74
+ scores (torch.Tensor): Concatenated scores for all candidates for all samples in the
75
+ batch
76
+ labels (torch.Tensor): Concatenated True/False labels for all candidates for all samples
77
+ in the batch
78
+ batch_ptr (torch.Tensor): Number of each sample's candidates in the concatenated tensors
79
+ """
80
+ # Initialize return dictionary to store metric values per sample
81
+ metric_vals = {}
82
+
83
+ # Evaluate hitrate at different top-k values
84
+ indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
85
+ scores = unbatch(scores, indexes)
86
+ labels = unbatch(labels, indexes)
87
+
88
+ for at_k in self.at_ks:
89
+ hit_rates = []
90
+ for scores_sample, labels_sample in zip(scores, labels):
91
+ hit_rates.append(retrieval_hit_rate(scores_sample, labels_sample, top_k=at_k))
92
+ hit_rates = torch.tensor(hit_rates, device=batch_ptr.device)
93
+
94
+ metric_name = f"{stage.to_pref()}hit_rate@{at_k}"
95
+ self._update_metric(
96
+ metric_name,
97
+ MeanMetric,
98
+ (hit_rates,),
99
+ batch_size=batch_ptr.size(0),
100
+ bootstrap=stage == Stage.TEST
101
+ )
102
+ metric_vals[metric_name] = hit_rates
103
+
104
+ return metric_vals
105
+
106
+ def evaluate_mces_at_1(
107
+ self,
108
+ scores: torch.Tensor,
109
+ labels: torch.Tensor,
110
+ smiles: list[str],
111
+ candidates_smiles: list[str],
112
+ batch_ptr: torch.Tensor,
113
+ stage: Stage,
114
+ ) -> dict[str, torch.Tensor]:
115
+ """
116
+ TODO
117
+ """
118
+ if labels.sum() != len(smiles):
119
+ raise ValueError("MCES@1 evaluation currently supports exactly 1 positive candidate per sample.")
120
+
121
+ # Initialize return dictionary to store metric values per sample
122
+ metric_vals = {}
123
+
124
+ # Get top-1 predicted molecules for each ground-truth sample
125
+ smiles_pred_top_1 = []
126
+ batch_ptr = torch.cumsum(batch_ptr, dim=0)
127
+ for i, j in zip(torch.cat([torch.tensor([0], device=batch_ptr.device), batch_ptr]), batch_ptr):
128
+ scores_sample = scores[i:j]
129
+ top_1_idx = i + torch.argmax(scores_sample)
130
+ smiles_pred_top_1.append(candidates_smiles[top_1_idx])
131
+
132
+ # Calculate MCES distance between top-1 predicted molecules and ground truth
133
+ mces_dists = [
134
+ self.myopic_mces(sm, sm_pred)
135
+ for sm, sm_pred in zip(smiles, smiles_pred_top_1)
136
+ ]
137
+ mces_dists = torch.tensor(mces_dists, device=scores.device)
138
+
139
+ # Log
140
+ metric_name = f"{stage.to_pref()}mces@1"
141
+ self._update_metric(
142
+ metric_name,
143
+ MeanMetric,
144
+ (mces_dists,),
145
+ batch_size=len(mces_dists),
146
+ bootstrap=stage == Stage.TEST
147
+ )
148
+ metric_vals[metric_name] = mces_dists
149
+
150
+ return metric_vals
151
+
152
+ def evaluate_fingerprint_step(
153
+ self,
154
+ y_true: torch.Tensor,
155
+ y_pred: torch.Tensor,
156
+ stage: Stage,
157
+ ) -> None:
158
+ """
159
+ Utility evaluation method to assess the quality of predicted fingerprints. This method is
160
+ not a part of the necessary evaluation logic (not called in the `on_batch_end` method)
161
+ since retrieval models are not bound to predict fingerprints.
162
+
163
+ Args:
164
+ y_true (torch.Tensor): [batch_size, fingerprint_size] tensor of true fingerprints
165
+ y_pred (torch.Tensor): [batch_size, fingerprint_size] tensor of predicted fingerprints
166
+ """
167
+ # Cosine similarity between predicted and true fingerprints
168
+ self._update_metric(
169
+ f"{stage.to_pref()}fingerprint_cos_sim",
170
+ CosineSimilarity,
171
+ (y_pred, y_true),
172
+ batch_size=y_true.size(0),
173
+ metric_kwargs=dict(reduction="mean")
174
+ )
175
+
176
+ def test_step(
177
+ self,
178
+ batch: dict,
179
+ batch_idx: torch.Tensor
180
+ ) -> tuple[torch.Tensor, torch.Tensor]:
181
+ outputs = super().test_step(batch, batch_idx)
182
+
183
+ # Get sorted candidate SMILES based on the predicted scores for each sample
184
+ if self.df_test_path is not None:
185
+ indexes = utils.batch_ptr_to_batch_idx(batch['batch_ptr'])
186
+ scores = unbatch(outputs['scores'], indexes)
187
+ candidates_smiles = utils.unbatch_list(batch['candidates_smiles'], indexes)
188
+ sorted_candidate_smiles = []
189
+ for scores_sample, candidates_smiles_sample in zip(scores, candidates_smiles):
190
+ candidates_smiles_sample = [
191
+ x for _, x in sorted(zip(scores_sample, candidates_smiles_sample), reverse=True)
192
+ ]
193
+ sorted_candidate_smiles.append(candidates_smiles_sample)
194
+ self._update_df_test({
195
+ 'identifier': batch['identifier'],
196
+ 'sorted_candidate_smiles': sorted_candidate_smiles
197
+ })
198
+
199
+ return outputs
200
+
201
+ def on_test_epoch_end(self):
202
+ # Save test data frame to disk
203
+ if self.df_test_path is not None:
204
+ df_test = pd.DataFrame(self.df_test)
205
+ self.df_test_path.parent.mkdir(parents=True, exist_ok=True)
206
+ df_test.to_pickle(self.df_test_path)
massspecgym/models/retrieval/deepsets.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as T
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch_geometric.nn import MLP
7
+
8
+ from massspecgym.models.base import Stage
9
+ from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
10
+ from massspecgym.models.layers import FourierFeatures
11
+ from massspecgym.utils import CosSimLoss
12
+
13
+
14
+ class DeepSetsRetrieval(RetrievalMassSpecGymModel):
15
+ def __init__(
16
+ self,
17
+ in_channels: int = 2, # m/z and intensity of a peak
18
+ hidden_channels: int = 512, # hidden layer size
19
+ out_channels: int = 4096, # fingerprint size
20
+ num_layers_per_mlp: int = 2,
21
+ dropout: float = 0.0,
22
+ norm: T.Optional[str] = None,
23
+ fourier_features: bool = True,
24
+ fourier_features_mz_channels: T.Optional[int] = None,
25
+ fourier_features_kwargs: T.Optional[dict] = None,
26
+ **kwargs
27
+ ):
28
+ super().__init__(**kwargs)
29
+
30
+ self.fourier_features = fourier_features
31
+ if fourier_features:
32
+ if fourier_features_kwargs is None:
33
+ fourier_features_kwargs = {}
34
+ self.ff = FourierFeatures(**fourier_features_kwargs)
35
+
36
+ if fourier_features_mz_channels is None:
37
+ fourier_features_mz_channels = int(0.8 * hidden_channels)
38
+ else:
39
+ assert fourier_features_mz_channels < hidden_channels
40
+ self.ff_proj_mz = nn.Linear(self.ff.num_features, fourier_features_mz_channels)
41
+ self.ff_proj_i = nn.Linear(1, hidden_channels - fourier_features_mz_channels)
42
+ in_channels = hidden_channels
43
+
44
+ self.phi = MLP(
45
+ in_channels=in_channels,
46
+ hidden_channels=hidden_channels,
47
+ out_channels=hidden_channels,
48
+ num_layers=num_layers_per_mlp,
49
+ dropout=dropout,
50
+ norm=norm
51
+ )
52
+
53
+ self.rho = MLP(
54
+ in_channels=hidden_channels,
55
+ hidden_channels=hidden_channels,
56
+ out_channels=out_channels,
57
+ num_layers=num_layers_per_mlp,
58
+ dropout=dropout,
59
+ norm=norm
60
+ )
61
+
62
+ self.loss_fn = CosSimLoss()
63
+
64
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
65
+ if self.fourier_features:
66
+ x_mz = x[:, :, 0].unsqueeze(-1)
67
+ x_mz = self.ff(x_mz)
68
+ x_mz = self.ff_proj_mz(x_mz)
69
+ x_i = x[:, :, 1].unsqueeze(-1)
70
+ x_i = self.ff_proj_i(x_i)
71
+ x = torch.cat((x_mz, x_i), dim=-1)
72
+ x = self.phi(x)
73
+ x = x.sum(dim=-2) # sum over peaks
74
+ x = self.rho(x)
75
+ x = F.sigmoid(x) # predict proper fingerprint
76
+ return x
77
+
78
+ def step(
79
+ self, batch: dict, stage: Stage = Stage.NONE
80
+ ) -> tuple[torch.Tensor, torch.Tensor]:
81
+ # Unpack inputs
82
+ x = batch["spec"]
83
+ fp_true = batch["mol"]
84
+ cands = batch["candidates"]
85
+ batch_ptr = batch["batch_ptr"]
86
+
87
+ # Predict fingerprint
88
+ fp_pred = self.forward(x)
89
+
90
+ # Calculate loss
91
+ loss = self.loss_fn(fp_true, fp_pred)
92
+
93
+ # Evaluation performance on fingerprint prediction (optional)
94
+ self.evaluate_fingerprint_step(fp_true, fp_pred, stage=stage)
95
+
96
+ # Calculate final similarity scores between predicted fingerprints and corresponding
97
+ # candidate fingerprints for retrieval
98
+ fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)
99
+ scores = nn.functional.cosine_similarity(fp_pred_repeated, cands)
100
+
101
+ return dict(loss=loss, scores=scores)
massspecgym/models/retrieval/fingerprint_ffn.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as T
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch_geometric.nn import MLP
7
+
8
+ from massspecgym.models.base import Stage
9
+ from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
10
+ from massspecgym.utils import CosSimLoss
11
+
12
+
13
+ class FingerprintFFNRetrieval(RetrievalMassSpecGymModel):
14
+ def __init__(
15
+ self,
16
+ in_channels: int = 1000, # number of bins
17
+ hidden_channels: int = 512, # hidden layer size
18
+ out_channels: int = 4096, # fingerprint size
19
+ num_layers: int = 2,
20
+ dropout: float = 0.0,
21
+ norm: T.Optional[str] = None,
22
+ **kwargs
23
+ ):
24
+ super().__init__(**kwargs)
25
+
26
+ self.ffn = MLP(
27
+ in_channels=in_channels,
28
+ hidden_channels=hidden_channels,
29
+ out_channels=out_channels,
30
+ num_layers=num_layers,
31
+ dropout=dropout,
32
+ norm=norm
33
+ )
34
+
35
+ self.loss_fn = CosSimLoss()
36
+
37
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
38
+ x = self.ffn(x)
39
+ x = F.sigmoid(x) # predict proper fingerprint
40
+ return x
41
+
42
+ def step(
43
+ self, batch: dict, stage: Stage = Stage.NONE
44
+ ) -> tuple[torch.Tensor, torch.Tensor]:
45
+ # Unpack inputs
46
+ x = batch["spec"]
47
+ fp_true = batch["mol"]
48
+ cands = batch["candidates"]
49
+ batch_ptr = batch["batch_ptr"]
50
+
51
+ # Predict fingerprint
52
+ fp_pred = self.forward(x)
53
+
54
+ # Calculate loss
55
+ loss = self.loss_fn(fp_true, fp_pred)
56
+
57
+ # Evaluation performance on fingerprint prediction (optional)
58
+ self.evaluate_fingerprint_step(fp_true, fp_pred, stage=stage)
59
+
60
+ # Calculate final similarity scores between predicted fingerprints and corresponding
61
+ # candidate fingerprints for retrieval
62
+ fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)
63
+ scores = nn.functional.cosine_similarity(fp_pred_repeated, cands)
64
+
65
+ return dict(loss=loss, scores=scores)
massspecgym/models/retrieval/from_dict.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import typing as T
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from massspecgym.models.base import Stage
9
+ from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
10
+
11
+
12
+ class FromDictRetrieval(RetrievalMassSpecGymModel):
13
+ """
14
+ Read predictions from dictionary with MassSpecGym ids as keys. Currently, the class
15
+ only implements reading fingerprints from the dictionary.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ dct: T.Optional[dict[str, T.Any]] = None,
21
+ dct_path: T.Optional[T.Union[str, Path]] = None, # pickled dict path
22
+ *args,
23
+ **kwargs
24
+ ):
25
+ super().__init__(*args, **kwargs)
26
+
27
+ if dct is None and dct_path is None:
28
+ raise ValueError("Either dct or dct_path must be provided.")
29
+
30
+ if dct is not None and dct_path is not None:
31
+ raise ValueError("Only one of dct or dct_path must be provided.")
32
+
33
+ if dct_path is not None:
34
+ with open(dct_path, "rb") as file:
35
+ dct = pickle.load(file)
36
+
37
+ dct = {k: torch.tensor(v) for k, v in dct.items()}
38
+ self.dct = dct
39
+
40
+ def step(
41
+ self, batch: dict, stage: Stage = Stage.NONE
42
+ ) -> tuple[torch.Tensor, torch.Tensor]:
43
+ # Unpack inputs
44
+ ids = batch["identifier"]
45
+ fp_true = batch["mol"]
46
+ cands = batch["candidates"]
47
+ batch_ptr = batch["batch_ptr"]
48
+
49
+ # Read predicted fingerprints from dictionary
50
+ fp_pred = torch.stack([self.dct[id] for id in ids]).to(fp_true.device)
51
+
52
+ # Evaluation performance on fingerprint prediction (optional)
53
+ self.evaluate_fingerprint_step(fp_true, fp_pred, stage=stage)
54
+
55
+ # Calculate final similarity scores between predicted fingerprints and corresponding
56
+ # candidate fingerprints for retrieval
57
+ fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)
58
+ scores = nn.functional.cosine_similarity(fp_pred_repeated, cands).to(fp_true.device)
59
+
60
+ # Random baseline, so we return a dummy loss
61
+ loss = torch.tensor(0.0, requires_grad=True, device=fp_true.device)
62
+
63
+ return dict(loss=loss, scores=scores)
64
+
65
+ def configure_optimizers(self):
66
+ # No training, so no optimizers
67
+ return None
massspecgym/models/retrieval/random.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from massspecgym.models.base import Stage
4
+ from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
5
+
6
+
7
+ class RandomRetrieval(RetrievalMassSpecGymModel):
8
+
9
+ def step(
10
+ self, batch: dict, stage: Stage = Stage.NONE
11
+ ) -> tuple[torch.Tensor, torch.Tensor]:
12
+ # Generate random retrieval scores
13
+ scores = torch.rand(batch["candidates"].shape[0]).to(self.device)
14
+
15
+ # Random baseline, so we return a dummy loss
16
+ loss = torch.tensor(0.0, requires_grad=True)
17
+
18
+ return dict(loss=loss, scores=scores)
19
+
20
+ def configure_optimizers(self):
21
+ # No optimizer needed for a random baseline
22
+ return None
massspecgym/models/simulation/__init__.py ADDED
File without changes
massspecgym/models/simulation/base.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as T
2
+ from abc import ABC
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import pytorch_lightning as pl
7
+ from torchmetrics import RetrievalHitRate, CosineSimilarity
8
+
9
+ from massspecgym.models.base import MassSpecGymModel
10
+
11
+
12
+ class SimulationMassSpecGymModel(MassSpecGymModel, ABC):
13
+
14
+ def on_batch_end(
15
+ self, outputs: T.Any, batch: dict, batch_idx: int, metric_pref: str = ""
16
+ ) -> None:
17
+ """
18
+ Compute evaluation metrics for the retrieval model based on the batch and corresponding predictions.
19
+ This method will be used in the `on_train_batch_end`, `on_validation_batch_end`, since `on_test_batch_end` is
20
+ overriden below.
21
+ """
22
+ self.evaluate_cos_similarity_step(
23
+ outputs["spec_pred"],
24
+ batch["spec"],
25
+ metric_pref=metric_pref,
26
+ )
27
+
28
+ def on_test_batch_end(
29
+ self, outputs: T.Any, batch: dict, batch_idx: int
30
+ ) -> None:
31
+ metric_pref = "_test"
32
+ self.evaluate_cos_similarity_step(
33
+ outputs["spec_pred"],
34
+ batch["spec"],
35
+ metric_pref=metric_pref
36
+ )
37
+ self.evaluate_hit_rate_step(
38
+ outputs["spec_pred"],
39
+ batch["spec"],
40
+ metric_pref=metric_pref
41
+ )
42
+
43
+ def evaluate_cos_similarity_step(
44
+ self,
45
+ specs_pred: torch.Tensor,
46
+ specs: torch.Tensor,
47
+ metric_pref: str = ""
48
+ ) -> None:
49
+ """
50
+ Evaulate cosine similarity.
51
+ """
52
+ raise NotImplementedError
53
+
54
+ def evaluate_hit_rate_step(
55
+ self,
56
+ specs_pred: torch.Tensor,
57
+ specs: torch.Tensor,
58
+ metric_pref: str = ""
59
+ ) -> None:
60
+ """
61
+ Evaulate Hit rate @ {1, 5, 20} (typically reported as Accuracy @ {1, 5, 20}).
62
+ """
63
+ raise NotImplementedError
massspecgym/models/tokenizers.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import typing as T
3
+ import selfies as sf
4
+ from tokenizers import ByteLevelBPETokenizer
5
+ from tokenizers import Tokenizer, processors, models
6
+ from tokenizers.implementations import BaseTokenizer, ByteLevelBPETokenizer
7
+ import massspecgym.utils as utils
8
+ from massspecgym.definitions import PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN
9
+
10
+
11
+ class SpecialTokensBaseTokenizer(BaseTokenizer):
12
+ def __init__(
13
+ self,
14
+ tokenizer: Tokenizer,
15
+ max_len: T.Optional[int] = None,
16
+ ):
17
+ """Initialize the base tokenizer with special tokens performing padding and truncation."""
18
+ super().__init__(tokenizer)
19
+
20
+ # Save essential attributes
21
+ self.pad_token = PAD_TOKEN
22
+ self.sos_token = SOS_TOKEN
23
+ self.eos_token = EOS_TOKEN
24
+ self.unk_token = UNK_TOKEN
25
+ self.max_length = max_len
26
+
27
+ # Add special tokens
28
+ self.add_special_tokens([self.pad_token, self.sos_token, self.eos_token, self.unk_token])
29
+
30
+ # Get token IDs
31
+ self.pad_token_id = self.token_to_id(self.pad_token)
32
+ self.sos_token_id = self.token_to_id(self.sos_token)
33
+ self.eos_token_id = self.token_to_id(self.eos_token)
34
+ self.unk_token_id = self.token_to_id(self.unk_token)
35
+
36
+ # Enable padding
37
+ self.enable_padding(
38
+ direction="right",
39
+ pad_token=self.pad_token,
40
+ pad_id=self.pad_token_id,
41
+ length=max_len,
42
+ )
43
+
44
+ # Enable truncation
45
+ self.enable_truncation(max_len)
46
+
47
+ # Set post-processing to add SOS and EOS tokens
48
+ self._tokenizer.post_processor = processors.TemplateProcessing(
49
+ single=f"{self.sos_token} $A {self.eos_token}",
50
+ pair=f"{self.sos_token} $A {self.eos_token} {self.sos_token} $B {self.eos_token}",
51
+ special_tokens=[
52
+ (self.sos_token, self.sos_token_id),
53
+ (self.eos_token, self.eos_token_id),
54
+ ],
55
+ )
56
+
57
+
58
+ class SelfiesTokenizer(SpecialTokensBaseTokenizer):
59
+ def __init__(
60
+ self,
61
+ selfies_train: T.Optional[T.Union[str, T.List[str]]] = None,
62
+ **kwargs
63
+ ):
64
+ """
65
+ Initialize the SELFIES tokenizer with optional training data to build a vocanulary.
66
+
67
+ Args:
68
+ selfies_train (str or list of str): Either a list of SELFIES strings to build the vocabulary from,
69
+ or a `semantic_robust_alphabet` string indicating the usahe of `selfies.get_semantic_robust_alphabet()`
70
+ alphabet. If None, the MassSpecGym training molecules will be used.
71
+ """
72
+
73
+ if selfies_train == 'semantic_robust_alphabet':
74
+ alphabet = list(sorted(sf.get_semantic_robust_alphabet()))
75
+ else:
76
+ if not selfies_train:
77
+ selfies_train = utils.load_train_mols()
78
+ selfies = [sf.encoder(s, strict=False) for s in selfies_train]
79
+ else:
80
+ selfies = selfies_train
81
+ alphabet = list(sorted(sf.get_alphabet_from_selfies(selfies)))
82
+
83
+ vocab = {symbol: i for i, symbol in enumerate(alphabet)}
84
+ vocab[UNK_TOKEN] = len(vocab)
85
+ tokenizer = Tokenizer(models.WordLevel(vocab=vocab, unk_token=UNK_TOKEN))
86
+
87
+ super().__init__(tokenizer, **kwargs)
88
+
89
+ def encode(self, text: str, add_special_tokens: bool = True) -> Tokenizer:
90
+ """Encodes a SMILES string into a list of SELFIES token IDs."""
91
+ selfies_string = sf.encoder(text, strict=False)
92
+ selfies_tokens = list(sf.split_selfies(selfies_string))
93
+ return super().encode(
94
+ selfies_tokens, is_pretokenized=True, add_special_tokens=add_special_tokens
95
+ )
96
+
97
+ def decode(self, token_ids: T.List[int], skip_special_tokens: bool = True) -> str:
98
+ """Decodes a list of SELFIES token IDs back into a SMILES string."""
99
+ selfies_string = super().decode(
100
+ token_ids, skip_special_tokens=skip_special_tokens
101
+ )
102
+ selfies_string = self._decode_wordlevel_str_to_selfies(selfies_string)
103
+ return sf.decoder(selfies_string)
104
+
105
+ def encode_batch(
106
+ self, texts: T.List[str], add_special_tokens: bool = True
107
+ ) -> T.List[Tokenizer]:
108
+ """Encodes a batch of SMILES strings into a list of SELFIES token IDs."""
109
+ selfies_strings = [
110
+ list(sf.split_selfies(sf.encoder(text, strict=False))) for text in texts
111
+ ]
112
+ return super().encode_batch(
113
+ selfies_strings, is_pretokenized=True, add_special_tokens=add_special_tokens
114
+ )
115
+
116
+ def decode_batch(
117
+ self, token_ids_batch: T.List[T.List[int]], skip_special_tokens: bool = True
118
+ ) -> T.List[str]:
119
+ """Decodes a batch of SELFIES token IDs back into SMILES strings."""
120
+ selfies_strings = super().decode_batch(
121
+ token_ids_batch, skip_special_tokens=skip_special_tokens
122
+ )
123
+ return [
124
+ sf.decoder(
125
+ self._decode_wordlevel_str_to_selfies(
126
+ selfies_string
127
+ )
128
+ )
129
+ for selfies_string in selfies_strings
130
+ ]
131
+
132
+ def _decode_wordlevel_str_to_selfies(self, text: str) -> str:
133
+ """Converts a WordLevel string back to a SELFIES string."""
134
+ return text.replace(" ", "")
135
+
136
+
137
+ class SmilesBPETokenizer(SpecialTokensBaseTokenizer):
138
+ def __init__(self, smiles_pth: T.Optional[str] = None, **kwargs):
139
+ """
140
+ Initialize the BPE tokenizer for SMILES strings, with optional training data.
141
+
142
+ Args:
143
+ smiles_pth (str): Path to a file containing SMILES strings to train the tokenizer on. If None,
144
+ the MassSpecGym training molecules will be used.
145
+ """
146
+ tokenizer = ByteLevelBPETokenizer()
147
+ if smiles_pth:
148
+ tokenizer.train(smiles_pth)
149
+ else:
150
+ smiles = utils.load_unlabeled_mols("smiles").tolist()
151
+ smiles += utils.load_train_mols().tolist()
152
+
153
+ print(f"Training tokenizer on {len(smiles)} SMILES strings.")
154
+ tokenizer.train_from_iterator(smiles)
155
+
156
+ super().__init__(tokenizer, **kwargs)
massspecgym/utils.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ # import seaborn as sns
3
+ import matplotlib as mpl
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.colors
6
+ import matplotlib.cm as cm
7
+ import matplotlib.colors as mcolors
8
+ import matplotlib.ticker as ticker
9
+ import pandas as pd
10
+ import typing as T
11
+ import pulp
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from itertools import groupby
16
+ from pathlib import Path
17
+ from myopic_mces.myopic_mces import MCES
18
+ from rdkit.Chem import AllChem as Chem
19
+ from rdkit.Chem import DataStructs, Draw
20
+ from rdkit.Chem.Descriptors import ExactMolWt
21
+ # from huggingface_hub import hf_hub_download
22
+ # from standardizeUtils.standardizeUtils import (
23
+ # standardize_structure_with_pubchem,
24
+ # standardize_structure_list_with_pubchem,
25
+ # )
26
+ from torchmetrics.wrappers import BootStrapper
27
+ from torchmetrics.metric import Metric
28
+
29
+
30
+ def load_massspecgym(fold: T.Optional[str] = None) -> pd.DataFrame:
31
+ """
32
+ Load the MassSpecGym dataset.
33
+
34
+ Args:
35
+ fold (str, optional): Fold name to load. If None, the entire dataset is loaded.
36
+ """
37
+ df = pd.read_csv(hugging_face_download("MassSpecGym.tsv"), sep="\t")
38
+ df = df.set_index("identifier")
39
+ df['mzs'] = df['mzs'].apply(parse_spec_array)
40
+ df['intensities'] = df['intensities'].apply(parse_spec_array)
41
+ if fold is not None:
42
+ df = df[df['fold'] == fold]
43
+ return df
44
+
45
+
46
+ def load_unlabeled_mols(col_name: str = "smiles") -> pd.Series:
47
+ """
48
+ Load a list of unlabeled molecules.
49
+
50
+ Args:
51
+ col_name (str, optional): Name of the column to return. Should be one of ["smiles", "selfies"].
52
+ """
53
+ return pd.read_csv(
54
+ hugging_face_download(
55
+ "molecules/MassSpecGym_molecules_MCES2_disjoint_with_test_fold_4M.tsv"
56
+ ),
57
+ sep="\t"
58
+ )[col_name]
59
+
60
+
61
+ def load_train_mols(col_name: str = "smiles") -> pd.Series:
62
+ """
63
+ Load a list of training molecules.
64
+
65
+ Args:
66
+ col_name (str, optional): Name of the column to return. Should be one of ["smiles", "selfies"].
67
+ """
68
+ return load_massspecgym("train")[col_name]
69
+
70
+
71
+ def pad_spectrum(
72
+ spec: np.ndarray, max_n_peaks: int, pad_value: float = 0.0
73
+ ) -> np.ndarray:
74
+ """
75
+ Pad a spectrum to a fixed number of peaks by appending zeros to the end of the spectrum.
76
+
77
+ Args:
78
+ spec (np.ndarray): Spectrum to pad represented as numpy array of shape (n_peaks, 2).
79
+ max_n_peaks (int): Maximum number of peaks in the padded spectrum.
80
+ pad_value (float, optional): Value to use for padding.
81
+ """
82
+ n_peaks = spec.shape[0]
83
+ if n_peaks > max_n_peaks:
84
+ raise ValueError(
85
+ f"Number of peaks in the spectrum ({n_peaks}) is greater than the maximum number of peaks."
86
+ )
87
+ else:
88
+ return np.pad(
89
+ spec,
90
+ ((0, max_n_peaks - n_peaks), (0, 0)),
91
+ mode="constant",
92
+ constant_values=pad_value,
93
+ )
94
+
95
+
96
+ def morgan_fp(mol: Chem.Mol, fp_size=2048, radius=2, to_np=True):
97
+ """
98
+ Compute Morgan fingerprint for a molecule.
99
+
100
+ Args:
101
+ mol (Chem.Mol): _description_
102
+ fp_size (int, optional): Size of the fingerprint.
103
+ radius (int, optional): Radius of the fingerprint.
104
+ to_np (bool, optional): Convert the fingerprint to numpy array.
105
+ """
106
+
107
+ fp = Chem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=fp_size)
108
+ if to_np:
109
+ fp_np = np.zeros((0,), dtype=np.int32)
110
+ DataStructs.ConvertToNumpyArray(fp, fp_np)
111
+ fp = fp_np
112
+ return fp
113
+
114
+
115
+ def tanimoto_morgan_similarity(mol1: T.Union[Chem.Mol, str], mol2: T.Union[Chem.Mol, str]) -> float:
116
+ """
117
+ Compute Tanimoto similarity between two molecules using Morgan fingerprints.
118
+
119
+ Args:
120
+ mol1 (T.Union[Chem.Mol, str]): First molecule as RDKit molecule or SMILES string.
121
+ mol2 (T.Union[Chem.Mol, str]): Second molecule as RDKit molecule or SMILES string.
122
+ """
123
+ if isinstance(mol1, str):
124
+ mol1 = Chem.MolFromSmiles(mol1)
125
+ if isinstance(mol2, str):
126
+ mol2 = Chem.MolFromSmiles(mol2)
127
+ return DataStructs.TanimotoSimilarity(morgan_fp(mol1, to_np=False), morgan_fp(mol2, to_np=False))
128
+
129
+
130
+ def standardize_smiles(smiles: T.Union[str, T.List[str]]) -> T.Union[str, T.List[str]]:
131
+ """
132
+ Standardize SMILES representation of a molecule using PubChem standardization.
133
+ """
134
+ if isinstance(smiles, str):
135
+ return standardize_structure_with_pubchem(smiles, 'smiles')
136
+ elif isinstance(smiles, list):
137
+ return standardize_structure_list_with_pubchem(smiles, 'smiles')
138
+ else:
139
+ raise ValueError("Input should be a SMILES tring or a list of SMILES strings.")
140
+
141
+
142
+ def mol_to_inchi_key(mol: Chem.Mol, twod: bool = True) -> str:
143
+ """
144
+ Convert a molecule to InChI Key representation.
145
+
146
+ Args:
147
+ mol (Chem.Mol): RDKit molecule object.
148
+ twod (bool, optional): Return 2D InChI Key (first 14 characers of InChI Key).
149
+ """
150
+ inchi_key = Chem.MolToInchiKey(mol)
151
+ if twod:
152
+ inchi_key = inchi_key.split("-")[0]
153
+ return inchi_key
154
+
155
+
156
+ def smiles_to_inchi_key(mol: str, twod: bool = True) -> str:
157
+ """
158
+ Convert a SMILES molecule to InChI Key representation.
159
+
160
+ Args:
161
+ mol (str): SMILES string.
162
+ twod (bool, optional): Return 2D InChI Key (first 14 characers of InChI Key).
163
+ """
164
+ mol = Chem.MolFromSmiles(mol)
165
+ return mol_to_inchi_key(mol, twod)
166
+
167
+
168
+ def hugging_face_download(file_name: str) -> str:
169
+ """
170
+ Download a file from the Hugging Face Hub and return its location on disk.
171
+
172
+ Args:
173
+ file_name (str): Name of the file to download.
174
+ """
175
+ return hf_hub_download(
176
+ repo_id="roman-bushuiev/MassSpecGym",
177
+ filename="data/" + file_name,
178
+ repo_type="dataset",
179
+ )
180
+
181
+
182
+ def init_plotting(figsize=(6, 2), font_scale=1.0, style="whitegrid"):
183
+ # Set default figure size
184
+ plt.show() # Does not work without this line for some reason
185
+ sns.set_theme(rc={"figure.figsize": figsize})
186
+ mpl.rcParams['svg.fonttype'] = 'none'
187
+ # Set default style and font scale
188
+ sns.set_style(style)
189
+ sns.set_context("paper", font_scale=font_scale)
190
+ sns.set_palette(["#009473", "#D94F70", "#5A5B9F", "#F0C05A", "#7BC4C4", "#FF6F61"])
191
+
192
+
193
+ def parse_spec_array(arr: str) -> np.ndarray:
194
+ return np.array(list(map(float, arr.split(","))))
195
+
196
+
197
+ def spec_array_to_str(arr: np.ndarray) -> str:
198
+ return ",".join(map(str, arr))
199
+
200
+
201
+ def compute_mass(smiles: str) -> float:
202
+ mol = Chem.MolFromSmiles(smiles)
203
+ if mol is None:
204
+ raise ValueError("Invalid SMILES string.")
205
+ return ExactMolWt(mol)
206
+
207
+
208
+ def plot_spectrum(spec, hue=None, xlim=None, ylim=None, mirror_spec=None, highl_idx=None,
209
+ figsize=(6, 2), colors=None, save_pth=None):
210
+
211
+ if colors is not None:
212
+ assert len(colors) >= 3
213
+ else:
214
+ colors = ['blue', 'green', 'red']
215
+
216
+ # Normalize input spectrum
217
+ def norm_spec(spec):
218
+ assert len(spec.shape) == 2
219
+ if spec.shape[0] != 2:
220
+ spec = spec.T
221
+ mzs, ins = spec[0], spec[1]
222
+ return mzs, ins / max(ins) * 100
223
+ mzs, ins = norm_spec(spec)
224
+
225
+ # Initialize plotting
226
+ init_plotting(figsize=figsize)
227
+ fig, ax = plt.subplots(1, 1)
228
+
229
+ # Setup color palette
230
+ if hue is not None:
231
+ norm = matplotlib.colors.Normalize(vmin=min(hue), vmax=max(hue), clip=True)
232
+ mapper = cm.ScalarMappable(norm=norm, cmap=cm.cool)
233
+ plt.colorbar(mapper, ax=ax)
234
+
235
+ # Plot spectrum
236
+ for i in range(len(mzs)):
237
+ if hue is not None:
238
+ color = mcolors.to_hex(mapper.to_rgba(hue[i]))
239
+ else:
240
+ color = colors[0]
241
+ plt.plot([mzs[i], mzs[i]], [0, ins[i]], color=color, marker='o', markevery=(1, 2), mfc='white', zorder=2)
242
+
243
+ # Plot mirror spectrum
244
+ if mirror_spec is not None:
245
+ mzs_m, ins_m = norm_spec(mirror_spec)
246
+
247
+ @ticker.FuncFormatter
248
+ def major_formatter(x, pos):
249
+ label = str(round(-x)) if x < 0 else str(round(x))
250
+ return label
251
+
252
+ for i in range(len(mzs_m)):
253
+ plt.plot([mzs_m[i], mzs_m[i]], [0, -ins_m[i]], color=colors[2], marker='o', markevery=(1, 2), mfc='white',
254
+ zorder=1)
255
+ ax.yaxis.set_major_formatter(major_formatter)
256
+
257
+ # Setup axes
258
+ if xlim is not None:
259
+ plt.xlim(xlim[0], xlim[1])
260
+ else:
261
+ plt.xlim(0, max(mzs) + 10)
262
+ if ylim is not None:
263
+ plt.ylim(ylim[0], ylim[1])
264
+ plt.xlabel('m/z')
265
+ plt.ylabel('Intensity [%]')
266
+
267
+ if save_pth is not None:
268
+ raise NotImplementedError()
269
+
270
+
271
+ def show_mols(mols, legends='new_indices', smiles_in=False, svg=False, sort_by_legend=False, max_mols=500,
272
+ legend_float_decimals=4, mols_per_row=6, save_pth: T.Optional[Path] = None):
273
+ """
274
+ Returns svg image representing a grid of skeletal structures of the given molecules. Copy-pasted
275
+ from https://github.com/pluskal-lab/DreaMS/blob/main/dreams/utils/mols.py
276
+
277
+ :param mols: list of rdkit molecules
278
+ :param smiles_in: True - SMILES inputs, False - RDKit mols
279
+ :param legends: list of labels for each molecule, length must be equal to the length of mols
280
+ :param svg: True - return svg image, False - return png image
281
+ :param sort_by_legend: True - sort molecules by legend values
282
+ :param max_mols: maximum number of molecules to show
283
+ :param legend_float_decimals: number of decimal places to show for float legends
284
+ :param mols_per_row: number of molecules per row to show
285
+ :param save_pth: path to save the .svg image to
286
+ """
287
+ if smiles_in:
288
+ mols = [Chem.MolFromSmiles(e) for e in mols]
289
+
290
+ if legends == 'new_indices':
291
+ legends = list(range(len(mols)))
292
+ elif legends == 'masses':
293
+ legends = [ExactMolWt(m) for m in mols]
294
+ elif callable(legends):
295
+ legends = [legends(e) for e in mols]
296
+
297
+ if sort_by_legend:
298
+ idx = np.argsort(legends).tolist()
299
+ legends = [legends[i] for i in idx]
300
+ mols = [mols[i] for i in idx]
301
+
302
+ legends = [f'{l:.{legend_float_decimals}f}' if isinstance(l, float) else str(l) for l in legends]
303
+
304
+ img = Draw.MolsToGridImage(mols, maxMols=max_mols, legends=legends, molsPerRow=min(max_mols, mols_per_row),
305
+ useSVG=svg, returnPNG=False)
306
+
307
+ if save_pth:
308
+ with open(save_pth, 'w') as f:
309
+ f.write(img.data)
310
+
311
+ return img
312
+
313
+
314
+ class MyopicMCES():
315
+ def __init__(
316
+ self,
317
+ ind: int = 0, # dummy index
318
+ solver: str = pulp.listSolvers(onlyAvailable=True)[0], # Use the first available solver
319
+ threshold: int = 15, # MCES threshold
320
+ always_stronger_bound: bool = True, # "False" makes computations a lot faster, but leads to overall higher MCES values
321
+ solver_options: dict = None
322
+ ):
323
+ self.ind = ind
324
+ self.solver = solver
325
+ self.threshold = threshold
326
+ self.always_stronger_bound = always_stronger_bound
327
+ if solver_options is None:
328
+ solver_options = dict(msg=0) # make ILP solver silent
329
+ self.solver_options = solver_options
330
+
331
+ # def __call__(self, smiles_1: str, smiles_2: str) -> float:
332
+ # retval = MCES(
333
+ # s1=smiles_1,
334
+ # s2=smiles_2,
335
+ # ind=self.ind,
336
+ # threshold=self.threshold,
337
+ # always_stronger_bound=self.always_stronger_bound,
338
+ # solver=self.solver,
339
+ # solver_options=self.solver_options
340
+ # )
341
+ # dist = retval[1]
342
+ # return dist
343
+ def __call__(self, smiles_1: str, smiles_2: str) -> float:
344
+ retval = MCES(
345
+ smiles_1,
346
+ smiles_2,
347
+ threshold=self.threshold,
348
+ always_stronger_bound=self.always_stronger_bound,
349
+ solver=self.solver,
350
+ solver_options = self.solver_options
351
+ )
352
+ dist = retval[1]
353
+ return dist
354
+
355
+
356
+ class ReturnScalarBootStrapper(BootStrapper):
357
+ def __init__(
358
+ self,
359
+ base_metric: Metric,
360
+ num_bootstraps: int = 10,
361
+ mean: bool = False,
362
+ std: bool = False,
363
+ quantile: T.Optional[T.Union[float, torch.Tensor]] = None,
364
+ raw: bool = False,
365
+ sampling_strategy: str = "poisson",
366
+ **kwargs: T.Any
367
+ ) -> None:
368
+ """Wrapper for BootStrapper that returns a scalar value in compute instead of a dictionary."""
369
+
370
+ if mean + std + bool(quantile) + raw != 1:
371
+ raise ValueError("Exactly one of mean, std, quantile or raw should be True.")
372
+
373
+ if std:
374
+ self.compute_key = "std"
375
+ else:
376
+ raise NotImplementedError("Currently only std is implemented.")
377
+
378
+ super().__init__(
379
+ base_metric=base_metric,
380
+ num_bootstraps=num_bootstraps,
381
+ mean=mean,
382
+ std=std,
383
+ quantile=quantile,
384
+ raw=raw,
385
+ sampling_strategy=sampling_strategy,
386
+ **kwargs
387
+ )
388
+
389
+ def compute(self):
390
+ return super().compute()[self.compute_key]
391
+
392
+
393
+ def batch_ptr_to_batch_idx(batch_ptr: torch.Tensor) -> torch.Tensor:
394
+ """
395
+ Convert a tensor of batch pointers to a tensor of batch indexes.
396
+
397
+ For example [1, 3, 2] -> [0, 1, 1, 1, 2, 2]
398
+
399
+ Args:
400
+ batch_ptr (Tensor): Tensor of batch pointers.
401
+ """
402
+ indexes = torch.arange(batch_ptr.size(0), device=batch_ptr.device)
403
+ indexes = torch.repeat_interleave(indexes, batch_ptr)
404
+ return indexes
405
+
406
+
407
+ def unbatch_list(batch_list: list, batch_idx: torch.Tensor) -> list:
408
+ """
409
+ Unbatch a list of items using the batch indexes (i.e., number of samples per batch).
410
+
411
+ Args:
412
+ batch_list (list): List of items to unbatch.
413
+ batch_idx (Tensor): Tensor of batch indexes.
414
+ """
415
+ return [
416
+ [batch_list[j] for j in range(len(batch_list)) if batch_idx[j] == i]
417
+ for i in range(batch_idx[-1] + 1)
418
+ ]
419
+
420
+
421
+ class CosSimLoss(nn.Module):
422
+ def __init__(self):
423
+ super(CosSimLoss, self).__init__()
424
+
425
+ def forward(self, inputs, targets):
426
+ return 1 - F.cosine_similarity(inputs, targets).mean()
427
+
428
+
429
+ def parse_sirius_ms(spectra_file: str) -> T.Tuple[dict, T.List[T.Tuple[str, np.ndarray]]]:
430
+ """
431
+ Parses spectra from the SIRIUS .ms file.
432
+
433
+ Copied from the code of Goldman et al.:
434
+ https://github.com/samgoldman97/mist/blob/4c23d34fc82425ad5474a53e10b4622dcdbca479/src/mist/utils/parse_utils.py#LL10C77-L10C77.
435
+ :return T.Tuple[dict, T.List[T.Tuple[str, np.ndarray]]]: metadata and list of spectra tuples containing name and array
436
+ """
437
+ lines = [i.strip() for i in open(spectra_file, "r").readlines()]
438
+
439
+ group_num = 0
440
+ metadata = {}
441
+ spectras = []
442
+ my_iterator = groupby(
443
+ lines, lambda line: line.startswith(">") or line.startswith("#")
444
+ )
445
+
446
+ for index, (start_line, lines) in enumerate(my_iterator):
447
+ group_lines = list(lines)
448
+ subject_lines = list(next(my_iterator)[1])
449
+ # Get spectra
450
+ if group_num > 0:
451
+ spectra_header = group_lines[0].split(">")[1]
452
+ peak_data = [
453
+ [float(x) for x in peak.split()[:2]]
454
+ for peak in subject_lines
455
+ if peak.strip()
456
+ ]
457
+ # Check if spectra is empty
458
+ if len(peak_data):
459
+ peak_data = np.vstack(peak_data)
460
+ # Add new tuple
461
+ spectras.append((spectra_header, peak_data))
462
+ # Get meta data
463
+ else:
464
+ entries = {}
465
+ for i in group_lines:
466
+ if " " not in i:
467
+ continue
468
+ elif i.startswith("#INSTRUMENT TYPE"):
469
+ key = "#INSTRUMENT TYPE"
470
+ val = i.split(key)[1].strip()
471
+ entries[key[1:]] = val
472
+ else:
473
+ start, end = i.split(" ", 1)
474
+ start = start[1:]
475
+ while start in entries:
476
+ start = f"{start}'"
477
+ entries[start] = end
478
+
479
+ metadata.update(entries)
480
+ group_num += 1
481
+
482
+ metadata["_FILE_PATH"] = spectra_file
483
+ metadata["_FILE"] = Path(spectra_file).stem
484
+ return metadata, spectras