Spaces:
Sleeping
Sleeping
Commit
·
94aa6f9
1
Parent(s):
c65d76d
partial push
Browse files- .gitignore +176 -0
- massspecgym/__init__.py +0 -0
- massspecgym/data/__init__.py +8 -0
- massspecgym/data/data_module.py +102 -0
- massspecgym/data/datasets.py +225 -0
- massspecgym/data/transforms.py +208 -0
- massspecgym/definitions.py +27 -0
- massspecgym/models/__init__.py +0 -0
- massspecgym/models/base.py +180 -0
- massspecgym/models/de_novo/__init__.py +6 -0
- massspecgym/models/de_novo/base.py +241 -0
- massspecgym/models/de_novo/dummy.py +46 -0
- massspecgym/models/de_novo/random.py +1750 -0
- massspecgym/models/de_novo/smiles_tranformer.py +200 -0
- massspecgym/models/layers.py +101 -0
- massspecgym/models/retrieval/__init__.py +13 -0
- massspecgym/models/retrieval/base.py +206 -0
- massspecgym/models/retrieval/deepsets.py +101 -0
- massspecgym/models/retrieval/fingerprint_ffn.py +65 -0
- massspecgym/models/retrieval/from_dict.py +67 -0
- massspecgym/models/retrieval/random.py +22 -0
- massspecgym/models/simulation/__init__.py +0 -0
- massspecgym/models/simulation/base.py +63 -0
- massspecgym/models/tokenizers.py +156 -0
- massspecgym/utils.py +484 -0
.gitignore
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 110 |
+
.pdm.toml
|
| 111 |
+
.pdm-python
|
| 112 |
+
.pdm-build/
|
| 113 |
+
|
| 114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 115 |
+
__pypackages__/
|
| 116 |
+
|
| 117 |
+
# Celery stuff
|
| 118 |
+
celerybeat-schedule
|
| 119 |
+
celerybeat.pid
|
| 120 |
+
|
| 121 |
+
# SageMath parsed files
|
| 122 |
+
*.sage.py
|
| 123 |
+
|
| 124 |
+
# Environments
|
| 125 |
+
.env
|
| 126 |
+
.venv
|
| 127 |
+
env/
|
| 128 |
+
venv/
|
| 129 |
+
ENV/
|
| 130 |
+
env.bak/
|
| 131 |
+
venv.bak/
|
| 132 |
+
|
| 133 |
+
# Spyder project settings
|
| 134 |
+
.spyderproject
|
| 135 |
+
.spyproject
|
| 136 |
+
|
| 137 |
+
# Rope project settings
|
| 138 |
+
.ropeproject
|
| 139 |
+
|
| 140 |
+
# mkdocs documentation
|
| 141 |
+
/site
|
| 142 |
+
|
| 143 |
+
# mypy
|
| 144 |
+
.mypy_cache/
|
| 145 |
+
.dmypy.json
|
| 146 |
+
dmypy.json
|
| 147 |
+
|
| 148 |
+
# Pyre type checker
|
| 149 |
+
.pyre/
|
| 150 |
+
|
| 151 |
+
# pytype static type analyzer
|
| 152 |
+
.pytype/
|
| 153 |
+
|
| 154 |
+
# Cython debug symbols
|
| 155 |
+
cython_debug/
|
| 156 |
+
|
| 157 |
+
# PyCharm
|
| 158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 162 |
+
#.idea/
|
| 163 |
+
|
| 164 |
+
# data
|
| 165 |
+
data/*
|
| 166 |
+
experiments/main_result/*
|
| 167 |
+
experiments/old/*
|
| 168 |
+
experiments/test_dir/*
|
| 169 |
+
my_notebooks/*
|
| 170 |
+
other/*
|
| 171 |
+
.cache/
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
!data/.gitkeep
|
| 175 |
+
!experiments/.gitkeep
|
| 176 |
+
!data/sample/
|
massspecgym/__init__.py
ADDED
|
File without changes
|
massspecgym/data/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .datasets import MassSpecDataset, RetrievalDataset
|
| 2 |
+
from .data_module import MassSpecDataModule
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"MassSpecDataset",
|
| 6 |
+
"RetrievalDataset",
|
| 7 |
+
"MassSpecDataModule"
|
| 8 |
+
]
|
massspecgym/data/data_module.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typing as T
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
import massspecgym.utils as utils
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Optional
|
| 8 |
+
from torch.utils.data.dataset import Subset
|
| 9 |
+
from torch.utils.data.dataloader import DataLoader
|
| 10 |
+
from massspecgym.data.datasets import MassSpecDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MassSpecDataModule(pl.LightningDataModule):
|
| 14 |
+
"""
|
| 15 |
+
Data module containing a mass spectrometry dataset. This class is responsible for loading, splitting, and wrapping
|
| 16 |
+
the dataset into data loaders according to pre-defined train, validation, test folds.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
dataset: MassSpecDataset,
|
| 22 |
+
batch_size: int,
|
| 23 |
+
num_workers: int = 0,
|
| 24 |
+
persistent_workers: bool = True,
|
| 25 |
+
split_pth: Optional[Path] = None,
|
| 26 |
+
**kwargs
|
| 27 |
+
):
|
| 28 |
+
"""
|
| 29 |
+
Args:
|
| 30 |
+
split_pth (Optional[Path], optional): Path to a .tsv file with columns "identifier" and "fold",
|
| 31 |
+
corresponding to dataset item IDs, and "fold", containg "train", "val", "test"
|
| 32 |
+
values. Default is None, in which case the split from the `dataset` is used.
|
| 33 |
+
"""
|
| 34 |
+
super().__init__(**kwargs)
|
| 35 |
+
self.dataset = dataset
|
| 36 |
+
self.split_pth = split_pth
|
| 37 |
+
self.batch_size = batch_size
|
| 38 |
+
self.num_workers = num_workers
|
| 39 |
+
self.persistent_workers = persistent_workers if num_workers > 0 else False
|
| 40 |
+
|
| 41 |
+
def prepare_data(self):
|
| 42 |
+
if self.split_pth is None:
|
| 43 |
+
self.split = self.dataset.metadata[["identifier", "fold"]]
|
| 44 |
+
else:
|
| 45 |
+
# NOTE: custom split is not tested
|
| 46 |
+
self.split = pd.read_csv(self.split_pth, sep="\t")
|
| 47 |
+
if set(self.split.columns) != {"identifier", "fold"}:
|
| 48 |
+
raise ValueError('Split file must contain "id" and "fold" columns.')
|
| 49 |
+
self.split["identifier"] = self.split["identifier"].astype(str)
|
| 50 |
+
if set(self.dataset.metadata["identifier"]) != set(self.split["identifier"]):
|
| 51 |
+
raise ValueError(
|
| 52 |
+
"Dataset item IDs must match the IDs in the split file."
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
self.split = self.split.set_index("identifier")["fold"]
|
| 56 |
+
if not set(self.split) <= {"train", "val", "test"}:
|
| 57 |
+
raise ValueError(
|
| 58 |
+
'"Folds" column must contain only "train", "val", or "test" values.'
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def setup(self, stage=None):
|
| 62 |
+
split_mask = self.split.loc[self.dataset.metadata["identifier"]].values
|
| 63 |
+
if stage == "fit" or stage is None:
|
| 64 |
+
self.train_dataset = Subset(
|
| 65 |
+
self.dataset, np.where(split_mask == "train")[0]
|
| 66 |
+
)
|
| 67 |
+
self.val_dataset = Subset(self.dataset, np.where(split_mask == "val")[0])
|
| 68 |
+
if stage == "test":
|
| 69 |
+
self.test_dataset = Subset(self.dataset, np.where(split_mask == "test")[0])
|
| 70 |
+
|
| 71 |
+
def train_dataloader(self):
|
| 72 |
+
return DataLoader(
|
| 73 |
+
self.train_dataset,
|
| 74 |
+
batch_size=self.batch_size,
|
| 75 |
+
shuffle=True,
|
| 76 |
+
num_workers=self.num_workers,
|
| 77 |
+
persistent_workers=self.persistent_workers,
|
| 78 |
+
drop_last=False,
|
| 79 |
+
collate_fn=self.dataset.collate_fn,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
def val_dataloader(self):
|
| 83 |
+
return DataLoader(
|
| 84 |
+
self.val_dataset,
|
| 85 |
+
batch_size=self.batch_size,
|
| 86 |
+
shuffle=False,
|
| 87 |
+
num_workers=self.num_workers,
|
| 88 |
+
persistent_workers=self.persistent_workers,
|
| 89 |
+
drop_last=False,
|
| 90 |
+
collate_fn=self.dataset.collate_fn,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def test_dataloader(self):
|
| 94 |
+
return DataLoader(
|
| 95 |
+
self.test_dataset,
|
| 96 |
+
batch_size=self.batch_size,
|
| 97 |
+
shuffle=False,
|
| 98 |
+
num_workers=self.num_workers,
|
| 99 |
+
persistent_workers=self.persistent_workers,
|
| 100 |
+
drop_last=False,
|
| 101 |
+
collate_fn=self.dataset.collate_fn,
|
| 102 |
+
)
|
massspecgym/data/datasets.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import json
|
| 3 |
+
import typing as T
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import matchms
|
| 7 |
+
import massspecgym.utils as utils
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from rdkit import Chem
|
| 10 |
+
from torch.utils.data.dataset import Dataset
|
| 11 |
+
from torch.utils.data.dataloader import default_collate
|
| 12 |
+
from matchms.importing import load_from_mgf
|
| 13 |
+
from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MassSpecDataset(Dataset):
|
| 17 |
+
"""
|
| 18 |
+
Dataset containing mass spectra and their corresponding molecular structures. This class is
|
| 19 |
+
responsible for loading the data from disk and applying transformation steps to the spectra and
|
| 20 |
+
molecules.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
spec_transform: T.Optional[T.Union[SpecTransform, T.Dict[str, SpecTransform]]] = None,
|
| 26 |
+
mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]] = None,
|
| 27 |
+
pth: T.Optional[Path] = None,
|
| 28 |
+
return_mol_freq: bool = True,
|
| 29 |
+
return_identifier: bool = True,
|
| 30 |
+
dtype: T.Type = torch.float32
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
Args:
|
| 34 |
+
pth (Optional[Path], optional): Path to the .tsv or .mgf file containing the mass spectra.
|
| 35 |
+
Default is None, in which case the MassSpecGym dataset is downloaded from HuggingFace Hub.
|
| 36 |
+
"""
|
| 37 |
+
self.pth = pth
|
| 38 |
+
self.spec_transform = spec_transform
|
| 39 |
+
self.mol_transform = mol_transform
|
| 40 |
+
self.return_mol_freq = return_mol_freq
|
| 41 |
+
|
| 42 |
+
if self.pth is None:
|
| 43 |
+
self.pth = utils.hugging_face_download("MassSpecGym.tsv")
|
| 44 |
+
|
| 45 |
+
if isinstance(self.pth, str):
|
| 46 |
+
self.pth = Path(self.pth)
|
| 47 |
+
|
| 48 |
+
if self.pth.suffix == ".tsv":
|
| 49 |
+
self.metadata = pd.read_csv(self.pth, sep="\t")
|
| 50 |
+
self.spectra = self.metadata.apply(
|
| 51 |
+
lambda row: matchms.Spectrum(
|
| 52 |
+
mz=np.array([float(m) for m in row["mzs"].split(",")]),
|
| 53 |
+
intensities=np.array(
|
| 54 |
+
[float(i) for i in row["intensities"].split(",")]
|
| 55 |
+
),
|
| 56 |
+
metadata={"precursor_mz": row["precursor_mz"]},
|
| 57 |
+
),
|
| 58 |
+
axis=1,
|
| 59 |
+
)
|
| 60 |
+
self.metadata = self.metadata.drop(columns=["mzs", "intensities"])
|
| 61 |
+
elif self.pth.suffix == ".mgf":
|
| 62 |
+
self.spectra = list(load_from_mgf(str(self.pth)))
|
| 63 |
+
self.metadata = pd.DataFrame([s.metadata for s in self.spectra])
|
| 64 |
+
else:
|
| 65 |
+
raise ValueError(f"{self.pth.suffix} file format not supported.")
|
| 66 |
+
|
| 67 |
+
if self.return_mol_freq:
|
| 68 |
+
if "inchikey" not in self.metadata.columns:
|
| 69 |
+
self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
|
| 70 |
+
self.metadata["mol_freq"] = self.metadata.groupby("inchikey")["inchikey"].transform("count")
|
| 71 |
+
|
| 72 |
+
self.return_identifier = return_identifier
|
| 73 |
+
self.dtype = dtype
|
| 74 |
+
|
| 75 |
+
def __len__(self) -> int:
|
| 76 |
+
return len(self.spectra)
|
| 77 |
+
|
| 78 |
+
def __getitem__(
|
| 79 |
+
self, i: int, transform_spec: bool = True, transform_mol: bool = True
|
| 80 |
+
) -> dict:
|
| 81 |
+
spec = self.spectra[i]
|
| 82 |
+
metadata = self.metadata.iloc[i]
|
| 83 |
+
mol = metadata["smiles"]
|
| 84 |
+
|
| 85 |
+
# Apply all transformations to the spectrum
|
| 86 |
+
item = {}
|
| 87 |
+
if transform_spec and self.spec_transform:
|
| 88 |
+
if isinstance(self.spec_transform, dict):
|
| 89 |
+
for key, transform in self.spec_transform.items():
|
| 90 |
+
item[key] = transform(spec) if transform is not None else spec
|
| 91 |
+
else:
|
| 92 |
+
item["spec"] = self.spec_transform(spec)
|
| 93 |
+
else:
|
| 94 |
+
item["spec"] = spec
|
| 95 |
+
|
| 96 |
+
# Apply all transformations to the molecule
|
| 97 |
+
if transform_mol and self.mol_transform:
|
| 98 |
+
if isinstance(self.mol_transform, dict):
|
| 99 |
+
for key, transform in self.mol_transform.items():
|
| 100 |
+
item[key] = transform(mol) if transform is not None else mol
|
| 101 |
+
else:
|
| 102 |
+
item["mol"] = self.mol_transform(mol)
|
| 103 |
+
else:
|
| 104 |
+
item["mol"] = mol
|
| 105 |
+
|
| 106 |
+
# Add other metadata to the item
|
| 107 |
+
# item.update({
|
| 108 |
+
# k: metadata[k] for k in ["precursor_mz", "adduct"]
|
| 109 |
+
# })
|
| 110 |
+
|
| 111 |
+
if self.return_mol_freq:
|
| 112 |
+
item["mol_freq"] = metadata["mol_freq"]
|
| 113 |
+
|
| 114 |
+
if self.return_identifier:
|
| 115 |
+
item["identifier"] = metadata["identifier"]
|
| 116 |
+
|
| 117 |
+
# TODO: this should be refactored
|
| 118 |
+
for k, v in item.items():
|
| 119 |
+
if not isinstance(v, str):
|
| 120 |
+
try:
|
| 121 |
+
item[k] = torch.as_tensor(v, dtype=self.dtype)
|
| 122 |
+
except:
|
| 123 |
+
continue
|
| 124 |
+
|
| 125 |
+
return item
|
| 126 |
+
|
| 127 |
+
@staticmethod
|
| 128 |
+
def collate_fn(batch: T.Iterable[dict]) -> dict:
|
| 129 |
+
"""
|
| 130 |
+
Custom collate function to handle the outputs of __getitem__.
|
| 131 |
+
"""
|
| 132 |
+
return default_collate(batch)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class RetrievalDataset(MassSpecDataset):
|
| 136 |
+
"""
|
| 137 |
+
Dataset containing mass spectra and their corresponding molecular structures, with additional
|
| 138 |
+
candidates of molecules for retrieval based on spectral similarity.
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
def __init__(
|
| 142 |
+
self,
|
| 143 |
+
mol_label_transform: MolTransform = MolToInChIKey(),
|
| 144 |
+
candidates_pth: T.Optional[T.Union[Path, str]] = None,
|
| 145 |
+
**kwargs,
|
| 146 |
+
):
|
| 147 |
+
super().__init__(**kwargs)
|
| 148 |
+
|
| 149 |
+
self.candidates_pth = candidates_pth
|
| 150 |
+
self.mol_label_transform = mol_label_transform
|
| 151 |
+
|
| 152 |
+
# Download candidates from HuggigFace Hub if not a path to exisiting file is passed
|
| 153 |
+
if self.candidates_pth is None:
|
| 154 |
+
self.candidates_pth = utils.hugging_face_download(
|
| 155 |
+
"molecules/MassSpecGym_retrieval_candidates_mass.json"
|
| 156 |
+
)
|
| 157 |
+
elif isinstance(self.candidates_pth, str):
|
| 158 |
+
if Path(self.candidates_pth).is_file():
|
| 159 |
+
self.candidates_pth = Path(self.candidates_pth)
|
| 160 |
+
else:
|
| 161 |
+
self.candidates_pth = utils.hugging_face_download(candidates_pth)
|
| 162 |
+
|
| 163 |
+
# Read candidates_pth from json to dict: SMILES -> respective candidate SMILES
|
| 164 |
+
with open(self.candidates_pth, "r") as file:
|
| 165 |
+
self.candidates = json.load(file)
|
| 166 |
+
|
| 167 |
+
def __getitem__(self, i) -> dict:
|
| 168 |
+
item = super().__getitem__(i, transform_mol=False)
|
| 169 |
+
|
| 170 |
+
# Save the original SMILES representation of the query molecule (for evaluation)
|
| 171 |
+
item["smiles"] = item["mol"]
|
| 172 |
+
|
| 173 |
+
# Get candidates
|
| 174 |
+
if item["mol"] not in self.candidates:
|
| 175 |
+
raise ValueError(f'No candidates for the query molecule {item["mol"]}.')
|
| 176 |
+
item["candidates"] = self.candidates[item["mol"]]
|
| 177 |
+
|
| 178 |
+
# Save the original SMILES representations of the canidates (for evaluation)
|
| 179 |
+
item["candidates_smiles"] = item["candidates"]
|
| 180 |
+
|
| 181 |
+
# Create neg/pos label mask by matching the query molecule with the candidates
|
| 182 |
+
item_label = self.mol_label_transform(item["mol"])
|
| 183 |
+
item["labels"] = [
|
| 184 |
+
self.mol_label_transform(c) == item_label for c in item["candidates"]
|
| 185 |
+
]
|
| 186 |
+
|
| 187 |
+
if not any(item["labels"]):
|
| 188 |
+
raise ValueError(
|
| 189 |
+
f'Query molecule {item["mol"]} not found in the candidates list.'
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
# Transform the query and candidate molecules
|
| 193 |
+
item["mol"] = self.mol_transform(item["mol"])
|
| 194 |
+
item["candidates"] = [self.mol_transform(c) for c in item["candidates"]]
|
| 195 |
+
if isinstance(item["mol"], np.ndarray):
|
| 196 |
+
item["mol"] = torch.as_tensor(item["mol"], dtype=self.dtype)
|
| 197 |
+
# item["candidates"] = [torch.as_tensor(c, dtype=self.dtype) for c in item["candidates"]]
|
| 198 |
+
|
| 199 |
+
return item
|
| 200 |
+
|
| 201 |
+
@staticmethod
|
| 202 |
+
def collate_fn(batch: T.Iterable[dict]) -> dict:
|
| 203 |
+
# Standard collate for everything except candidates and their labels (which may have different length per sample)
|
| 204 |
+
collated_batch = {}
|
| 205 |
+
for k in batch[0].keys():
|
| 206 |
+
if k not in ["candidates", "labels", "candidates_smiles"]:
|
| 207 |
+
collated_batch[k] = default_collate([item[k] for item in batch])
|
| 208 |
+
|
| 209 |
+
# Collate candidates and labels by concatenating and storing sizes of each list
|
| 210 |
+
collated_batch["candidates"] = torch.as_tensor(
|
| 211 |
+
np.concatenate([item["candidates"] for item in batch])
|
| 212 |
+
)
|
| 213 |
+
collated_batch["labels"] = torch.as_tensor(
|
| 214 |
+
sum([item["labels"] for item in batch], start=[])
|
| 215 |
+
)
|
| 216 |
+
collated_batch["batch_ptr"] = torch.as_tensor(
|
| 217 |
+
[len(item["candidates"]) for item in batch]
|
| 218 |
+
)
|
| 219 |
+
collated_batch["candidates_smiles"] = \
|
| 220 |
+
sum([item["candidates_smiles"] for item in batch], start=[])
|
| 221 |
+
|
| 222 |
+
return collated_batch
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# TODO: Datasets for unlabeled data.
|
massspecgym/data/transforms.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import matchms
|
| 4 |
+
import matchms.filtering as ms_filters
|
| 5 |
+
from rdkit.Chem import AllChem as Chem
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
import massspecgym.utils as utils
|
| 9 |
+
from massspecgym.definitions import CHEM_ELEMS
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SpecTransform(ABC):
|
| 13 |
+
"""
|
| 14 |
+
Base class for spectrum transformations. Custom transformatios should inherit from this class.
|
| 15 |
+
The transformation consists of two consecutive steps:
|
| 16 |
+
1. Apply a series of matchms filters to the input spectrum (method `matchms_transforms`).
|
| 17 |
+
2. Convert the matchms spectrum to a torch tensor (method `matchms_to_torch`).
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
@abstractmethod
|
| 21 |
+
def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
|
| 22 |
+
"""
|
| 23 |
+
Apply a series of matchms filters to the input spectrum. Abstract method.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
@abstractmethod
|
| 27 |
+
def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
|
| 28 |
+
"""
|
| 29 |
+
Convert a matchms spectrum to a torch tensor. Abstract method.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __call__(self, spec: matchms.Spectrum) -> torch.Tensor:
|
| 33 |
+
"""
|
| 34 |
+
Compose the matchms filters and the torch conversion.
|
| 35 |
+
"""
|
| 36 |
+
return self.matchms_to_torch(self.matchms_transforms(spec))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def default_matchms_transforms(
|
| 40 |
+
spec: matchms.Spectrum,
|
| 41 |
+
n_max_peaks: int = 60,
|
| 42 |
+
mz_from: float = 10,
|
| 43 |
+
mz_to: float = 1000,
|
| 44 |
+
) -> matchms.Spectrum:
|
| 45 |
+
spec = ms_filters.select_by_mz(spec, mz_from=mz_from, mz_to=mz_to)
|
| 46 |
+
if n_max_peaks is not None:
|
| 47 |
+
spec = ms_filters.reduce_to_number_of_peaks(spec, n_max=n_max_peaks)
|
| 48 |
+
spec = ms_filters.normalize_intensities(spec)
|
| 49 |
+
return spec
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class SpecTokenizer(SpecTransform):
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
n_peaks: Optional[int] = 60,
|
| 56 |
+
prec_mz_intensity: Optional[float] = 1.1,
|
| 57 |
+
matchms_kwargs: Optional[dict] = None
|
| 58 |
+
) -> None:
|
| 59 |
+
self.n_peaks = n_peaks
|
| 60 |
+
self.prec_mz_intensity = prec_mz_intensity
|
| 61 |
+
self.matchms_kwargs = matchms_kwargs if matchms_kwargs is not None else {}
|
| 62 |
+
|
| 63 |
+
def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
|
| 64 |
+
return default_matchms_transforms(spec, n_max_peaks=self.n_peaks, **self.matchms_kwargs)
|
| 65 |
+
|
| 66 |
+
def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
|
| 67 |
+
"""
|
| 68 |
+
Stack arrays of mz and intensities into a matrix of shape (num_peaks, 2).
|
| 69 |
+
If the number of peaks is less than `n_peaks`, pad the matrix with zeros.
|
| 70 |
+
"""
|
| 71 |
+
spec_t = np.vstack([spec.peaks.mz, spec.peaks.intensities]).T
|
| 72 |
+
if self.prec_mz_intensity is not None:
|
| 73 |
+
spec_t = np.vstack([[spec.metadata["precursor_mz"], self.prec_mz_intensity], spec_t])
|
| 74 |
+
if self.n_peaks is not None:
|
| 75 |
+
spec_t = utils.pad_spectrum(
|
| 76 |
+
spec_t,
|
| 77 |
+
self.n_peaks + 1 if self.prec_mz_intensity is not None else self.n_peaks
|
| 78 |
+
)
|
| 79 |
+
return torch.from_numpy(spec_t)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class SpecBinner(SpecTransform):
|
| 83 |
+
def __init__(
|
| 84 |
+
self,
|
| 85 |
+
max_mz: float = 1005,
|
| 86 |
+
bin_width: float = 1,
|
| 87 |
+
to_rel_intensities: bool = True,
|
| 88 |
+
) -> None:
|
| 89 |
+
self.max_mz = max_mz
|
| 90 |
+
self.bin_width = bin_width
|
| 91 |
+
self.to_rel_intensities = to_rel_intensities
|
| 92 |
+
if not (max_mz / bin_width).is_integer():
|
| 93 |
+
raise ValueError("`max_mz` must be divisible by `bin_width`.")
|
| 94 |
+
|
| 95 |
+
def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
|
| 96 |
+
return default_matchms_transforms(spec, mz_to=self.max_mz, n_max_peaks=None)
|
| 97 |
+
|
| 98 |
+
def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
|
| 99 |
+
"""
|
| 100 |
+
Bin the spectrum into a fixed number of bins.
|
| 101 |
+
"""
|
| 102 |
+
binned_spec = self._bin_mass_spectrum(
|
| 103 |
+
mzs=spec.peaks.mz,
|
| 104 |
+
intensities=spec.peaks.intensities,
|
| 105 |
+
max_mz=self.max_mz,
|
| 106 |
+
bin_width=self.bin_width,
|
| 107 |
+
to_rel_intensities=self.to_rel_intensities,
|
| 108 |
+
)
|
| 109 |
+
return torch.from_numpy(binned_spec)
|
| 110 |
+
|
| 111 |
+
def _bin_mass_spectrum(
|
| 112 |
+
self, mzs, intensities, max_mz, bin_width, to_rel_intensities=True
|
| 113 |
+
):
|
| 114 |
+
# Calculate the number of bins
|
| 115 |
+
num_bins = int(np.ceil(max_mz / bin_width))
|
| 116 |
+
|
| 117 |
+
# Calculate the bin indices for each mass
|
| 118 |
+
bin_indices = np.floor(mzs / bin_width).astype(int)
|
| 119 |
+
|
| 120 |
+
# Filter out mzs that exceed max_mz
|
| 121 |
+
valid_indices = bin_indices[mzs <= max_mz]
|
| 122 |
+
valid_intensities = intensities[mzs <= max_mz]
|
| 123 |
+
|
| 124 |
+
# Clip bin indices to ensure they are within the valid range
|
| 125 |
+
valid_indices = np.clip(valid_indices, 0, num_bins - 1)
|
| 126 |
+
|
| 127 |
+
# Initialize an array to store the binned intensities
|
| 128 |
+
binned_intensities = np.zeros(num_bins)
|
| 129 |
+
|
| 130 |
+
# Use np.add.at to sum intensities in the appropriate bins
|
| 131 |
+
np.add.at(binned_intensities, valid_indices, valid_intensities)
|
| 132 |
+
|
| 133 |
+
# Generate the bin edges for reference
|
| 134 |
+
# bin_edges = np.arange(0, max_mz + bin_width, bin_width)
|
| 135 |
+
|
| 136 |
+
# Normalize the intensities to relative intensities
|
| 137 |
+
if to_rel_intensities:
|
| 138 |
+
binned_intensities /= np.max(binned_intensities)
|
| 139 |
+
|
| 140 |
+
return binned_intensities # , bin_edges
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class MolTransform(ABC):
|
| 144 |
+
@abstractmethod
|
| 145 |
+
def from_smiles(self, mol: str):
|
| 146 |
+
"""
|
| 147 |
+
Convert a SMILES string to a tensor-like representation. Abstract method.
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
def __call__(self, mol: str):
|
| 151 |
+
return self.from_smiles(mol)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class MolFingerprinter(MolTransform):
|
| 155 |
+
def __init__(self, type: str = "morgan", fp_size: int = 2048, radius: int = 2):
|
| 156 |
+
if type != "morgan":
|
| 157 |
+
raise NotImplementedError(
|
| 158 |
+
"Only Morgan fingerprints are implemented at the moment."
|
| 159 |
+
)
|
| 160 |
+
self.type = type
|
| 161 |
+
self.fp_size = fp_size
|
| 162 |
+
self.radius = radius
|
| 163 |
+
|
| 164 |
+
def from_smiles(self, mol: str):
|
| 165 |
+
mol = Chem.MolFromSmiles(mol)
|
| 166 |
+
return utils.morgan_fp(
|
| 167 |
+
mol, fp_size=self.fp_size, radius=self.radius, to_np=True
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class MolToInChIKey(MolTransform):
|
| 172 |
+
def __init__(self, twod: bool = True) -> None:
|
| 173 |
+
self.twod = twod
|
| 174 |
+
|
| 175 |
+
def from_smiles(self, mol: str) -> str:
|
| 176 |
+
mol = Chem.MolFromSmiles(mol)
|
| 177 |
+
return utils.mol_to_inchi_key(mol, twod=self.twod)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class MolToFormulaVector(MolTransform):
|
| 181 |
+
def __init__(self):
|
| 182 |
+
self.element_index = {element: i for i, element in enumerate(CHEM_ELEMS)}
|
| 183 |
+
|
| 184 |
+
def from_smiles(self, smiles: str):
|
| 185 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 186 |
+
if mol is None:
|
| 187 |
+
raise ValueError(f"Invalid SMILES string: {smiles}")
|
| 188 |
+
|
| 189 |
+
# Add explicit hydrogens to the molecule
|
| 190 |
+
mol = Chem.AddHs(mol)
|
| 191 |
+
|
| 192 |
+
# Initialize a vector of zeros for the 118 elements
|
| 193 |
+
formula_vector = np.zeros(118, dtype=np.int32)
|
| 194 |
+
|
| 195 |
+
# Iterate over atoms in the molecule and count occurrences of each element
|
| 196 |
+
for atom in mol.GetAtoms():
|
| 197 |
+
symbol = atom.GetSymbol()
|
| 198 |
+
if symbol in self.element_index:
|
| 199 |
+
index = self.element_index[symbol]
|
| 200 |
+
formula_vector[index] += 1
|
| 201 |
+
else:
|
| 202 |
+
raise ValueError(f"Element '{symbol}' not found in the list of 118 elements.")
|
| 203 |
+
|
| 204 |
+
return formula_vector
|
| 205 |
+
|
| 206 |
+
@staticmethod
|
| 207 |
+
def num_elements():
|
| 208 |
+
return len(CHEM_ELEMS)
|
massspecgym/definitions.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Global variables used across the package."""
|
| 2 |
+
import pathlib
|
| 3 |
+
|
| 4 |
+
# Dirs
|
| 5 |
+
MASSSPECGYM_ROOT_DIR = pathlib.Path(__file__).parent.absolute()
|
| 6 |
+
MASSSPECGYM_REPO_DIR = MASSSPECGYM_ROOT_DIR.parent
|
| 7 |
+
MASSSPECGYM_DATA_DIR = MASSSPECGYM_REPO_DIR / 'data'
|
| 8 |
+
MASSSPECGYM_TEST_RESULTS_DIR = MASSSPECGYM_DATA_DIR / 'test_results'
|
| 9 |
+
MASSSPECGYM_ASSETS_DIR = MASSSPECGYM_REPO_DIR / 'assets'
|
| 10 |
+
|
| 11 |
+
# Special tokens
|
| 12 |
+
PAD_TOKEN = "<pad>"
|
| 13 |
+
SOS_TOKEN = "<s>"
|
| 14 |
+
EOS_TOKEN = "</s>"
|
| 15 |
+
UNK_TOKEN = "<unk>"
|
| 16 |
+
|
| 17 |
+
# Chemistry
|
| 18 |
+
# List of all 118 elements (indexed by atomic number)
|
| 19 |
+
CHEM_ELEMS = [
|
| 20 |
+
"H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar",
|
| 21 |
+
"K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr",
|
| 22 |
+
"Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe",
|
| 23 |
+
"Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu",
|
| 24 |
+
"Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac",
|
| 25 |
+
"Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh",
|
| 26 |
+
"Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"
|
| 27 |
+
]
|
massspecgym/models/__init__.py
ADDED
|
File without changes
|
massspecgym/models/base.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typing as T
|
| 2 |
+
import collections
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import pytorch_lightning as pl
|
| 9 |
+
from torchmetrics import Metric, SumMetric
|
| 10 |
+
from massspecgym.utils import ReturnScalarBootStrapper
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Stage(Enum):
|
| 14 |
+
|
| 15 |
+
TRAIN = 'train'
|
| 16 |
+
VAL = 'val'
|
| 17 |
+
TEST = 'test'
|
| 18 |
+
NONE = 'none'
|
| 19 |
+
|
| 20 |
+
def to_pref(self) -> str:
|
| 21 |
+
return f"{self.value}_" if self != Stage.NONE else ""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MassSpecGymModel(pl.LightningModule, ABC):
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
lr: float = 1e-4,
|
| 29 |
+
weight_decay: float = 0.0,
|
| 30 |
+
log_only_loss_at_stages: T.Sequence[Stage | str] = (),
|
| 31 |
+
bootstrap_metrics: bool = True,
|
| 32 |
+
df_test_path: T.Optional[str | Path] = None,
|
| 33 |
+
*args,
|
| 34 |
+
**kwargs
|
| 35 |
+
):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.save_hyperparameters()
|
| 38 |
+
|
| 39 |
+
# Setup metring logging
|
| 40 |
+
self.log_only_loss_at_stages = [
|
| 41 |
+
Stage(s) if isinstance(s, str) else s for s in log_only_loss_at_stages
|
| 42 |
+
]
|
| 43 |
+
self.bootstrap_metrics = bootstrap_metrics
|
| 44 |
+
|
| 45 |
+
# Init dictionary to store dataframe columns where rows correspond to samples
|
| 46 |
+
# (for constructing test dataframe with predictions and metrics for each sample)
|
| 47 |
+
self.df_test_path = Path(df_test_path) if df_test_path is not None else None
|
| 48 |
+
self.df_test = collections.defaultdict(list)
|
| 49 |
+
|
| 50 |
+
@abstractmethod
|
| 51 |
+
def step(
|
| 52 |
+
self, batch: dict, stage: Stage = Stage.NONE
|
| 53 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 54 |
+
raise NotImplementedError(
|
| 55 |
+
"Method `step` must be implemented in the model-specific child class."
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def training_step(
|
| 59 |
+
self, batch: dict, batch_idx: torch.Tensor
|
| 60 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 61 |
+
return self.step(batch, stage=Stage.TRAIN)
|
| 62 |
+
|
| 63 |
+
def validation_step(
|
| 64 |
+
self, batch: dict, batch_idx: torch.Tensor
|
| 65 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 66 |
+
return self.step(batch, stage=Stage.VAL)
|
| 67 |
+
|
| 68 |
+
def test_step(
|
| 69 |
+
self, batch: dict, batch_idx: torch.Tensor
|
| 70 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 71 |
+
return self.step(batch, stage=Stage.TEST)
|
| 72 |
+
|
| 73 |
+
@abstractmethod
|
| 74 |
+
def on_batch_end(
|
| 75 |
+
self, outputs: T.Any, batch: dict, batch_idx: int, stage: Stage
|
| 76 |
+
) -> None:
|
| 77 |
+
"""
|
| 78 |
+
Method to be called at the end of each batch. This method should be implemented by a child,
|
| 79 |
+
task-dedicated class and contain the evaluation necessary for the task.
|
| 80 |
+
"""
|
| 81 |
+
raise NotImplementedError(
|
| 82 |
+
"Method `on_batch_end` must be implemented in the task-specific child class."
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
def on_train_batch_end(self, *args, **kwargs):
|
| 86 |
+
return self.on_batch_end(*args, **kwargs, stage=Stage.TRAIN)
|
| 87 |
+
|
| 88 |
+
def on_validation_batch_end(self, *args, **kwargs):
|
| 89 |
+
return self.on_batch_end(*args, **kwargs, stage=Stage.VAL)
|
| 90 |
+
|
| 91 |
+
def on_test_batch_end(self, *args, **kwargs):
|
| 92 |
+
return self.on_batch_end(*args, **kwargs, stage=Stage.TEST)
|
| 93 |
+
|
| 94 |
+
def configure_optimizers(self):
|
| 95 |
+
return torch.optim.Adam(
|
| 96 |
+
self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def get_checkpoint_monitors(self) -> list[dict]:
|
| 100 |
+
monitors = [
|
| 101 |
+
{"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
|
| 102 |
+
]
|
| 103 |
+
return monitors
|
| 104 |
+
|
| 105 |
+
def _update_metric(
|
| 106 |
+
self,
|
| 107 |
+
name: str,
|
| 108 |
+
metric_class: type[Metric],
|
| 109 |
+
update_args: T.Any,
|
| 110 |
+
batch_size: T.Optional[int] = None,
|
| 111 |
+
prog_bar: bool = False,
|
| 112 |
+
metric_kwargs: T.Optional[dict] = None,
|
| 113 |
+
log: bool = True,
|
| 114 |
+
log_n_samples: bool = False,
|
| 115 |
+
bootstrap: bool = False,
|
| 116 |
+
num_bootstraps: int = 100
|
| 117 |
+
) -> None:
|
| 118 |
+
"""
|
| 119 |
+
This method enables updating and logging metrics without instantiating them in advance in
|
| 120 |
+
the __init__ method. The metrics are aggreated over batches and logged at the end of the
|
| 121 |
+
epoch. If the metric does not exist yet, it is instantiated and added as an attribute to the
|
| 122 |
+
model.
|
| 123 |
+
"""
|
| 124 |
+
# Process arguments
|
| 125 |
+
bootstrap = bootstrap and self.bootstrap_metrics
|
| 126 |
+
|
| 127 |
+
# Log total number of samples (useful for debugging)
|
| 128 |
+
if log_n_samples:
|
| 129 |
+
self._update_metric(
|
| 130 |
+
name=name + "_n_samples",
|
| 131 |
+
metric_class=SumMetric,
|
| 132 |
+
update_args=(len(update_args[0]),),
|
| 133 |
+
batch_size=1,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# Init metric if does not exits yet
|
| 137 |
+
if hasattr(self, name):
|
| 138 |
+
metric = getattr(self, name)
|
| 139 |
+
else:
|
| 140 |
+
if metric_kwargs is None:
|
| 141 |
+
metric_kwargs = dict()
|
| 142 |
+
metric = metric_class(**metric_kwargs)
|
| 143 |
+
metric = metric.to(self.device)
|
| 144 |
+
setattr(self, name, metric)
|
| 145 |
+
|
| 146 |
+
# Update
|
| 147 |
+
metric(*update_args)
|
| 148 |
+
|
| 149 |
+
# Log
|
| 150 |
+
if log:
|
| 151 |
+
self.log(
|
| 152 |
+
name,
|
| 153 |
+
metric,
|
| 154 |
+
prog_bar=prog_bar,
|
| 155 |
+
batch_size=batch_size,
|
| 156 |
+
on_step=False,
|
| 157 |
+
on_epoch=True,
|
| 158 |
+
add_dataloader_idx=False,
|
| 159 |
+
metric_attribute=name # Suggested by a torchmetrics error
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Bootstrap
|
| 163 |
+
if bootstrap:
|
| 164 |
+
def _bootsrapped_metric_class(**metric_kwargs):
|
| 165 |
+
metric = metric_class(**metric_kwargs)
|
| 166 |
+
return ReturnScalarBootStrapper(metric, std=True, num_bootstraps=num_bootstraps)
|
| 167 |
+
|
| 168 |
+
self._update_metric(
|
| 169 |
+
name=name + "_std",
|
| 170 |
+
metric_class=_bootsrapped_metric_class,
|
| 171 |
+
update_args=update_args,
|
| 172 |
+
batch_size=batch_size,
|
| 173 |
+
metric_kwargs=metric_kwargs,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def _update_df_test(self, dct: dict) -> None:
|
| 177 |
+
for col, vals in dct.items():
|
| 178 |
+
if isinstance(vals, torch.Tensor):
|
| 179 |
+
vals = vals.tolist()
|
| 180 |
+
self.df_test[col].extend(vals)
|
massspecgym/models/de_novo/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base import DeNovoMassSpecGymModel
|
| 2 |
+
from .random import RandomDeNovo
|
| 3 |
+
from .dummy import DummyDeNovo
|
| 4 |
+
from .smiles_tranformer import SmilesTransformer
|
| 5 |
+
|
| 6 |
+
__all__ = ["DeNovoMassSpecGymModel", "RandomDeNovo", "DummyDeNovo", "SmilesTransformer"]
|
massspecgym/models/de_novo/base.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typing as T
|
| 2 |
+
from abc import ABC
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
from rdkit.DataStructs import TanimotoSimilarity
|
| 8 |
+
from torchmetrics.aggregation import MeanMetric
|
| 9 |
+
|
| 10 |
+
from massspecgym.models.base import MassSpecGymModel, Stage
|
| 11 |
+
from massspecgym.utils import morgan_fp, mol_to_inchi_key, MyopicMCES
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DeNovoMassSpecGymModel(MassSpecGymModel, ABC):
|
| 15 |
+
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
top_ks: T.Iterable[int] = (1, 10),
|
| 19 |
+
myopic_mces_kwargs: T.Optional[T.Mapping] = None,
|
| 20 |
+
*args,
|
| 21 |
+
**kwargs
|
| 22 |
+
):
|
| 23 |
+
super().__init__(*args, **kwargs)
|
| 24 |
+
|
| 25 |
+
self.top_ks = top_ks
|
| 26 |
+
self.myopic_mces = MyopicMCES(**(myopic_mces_kwargs or {}))
|
| 27 |
+
self.mol_pred_kind: T.Literal["smiles", "rdkit"] = "smiles"
|
| 28 |
+
# caches of already computed results to avoid expensive re-computations
|
| 29 |
+
self.mces_cache = dict()
|
| 30 |
+
self.mol_2_morgan_fp = dict()
|
| 31 |
+
|
| 32 |
+
def on_batch_end(
|
| 33 |
+
self,
|
| 34 |
+
outputs: T.Any,
|
| 35 |
+
batch: dict,
|
| 36 |
+
batch_idx: int,
|
| 37 |
+
stage: Stage
|
| 38 |
+
) -> None:
|
| 39 |
+
self.log(
|
| 40 |
+
f"{stage.to_pref()}loss",
|
| 41 |
+
outputs['loss'],
|
| 42 |
+
batch_size=batch['spec'].size(0),
|
| 43 |
+
sync_dist=True,
|
| 44 |
+
prog_bar=True,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
if stage in self.log_only_loss_at_stages:
|
| 48 |
+
return
|
| 49 |
+
|
| 50 |
+
metric_vals = self.evaluate_de_novo_step(
|
| 51 |
+
outputs["mols_pred"], # (bs, k) list of generated rdkit molecules or SMILES strings
|
| 52 |
+
batch["mol"], # (bs) list of ground truth SMILES strings
|
| 53 |
+
stage=stage
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
if stage == Stage.TEST and self.df_test_path is not None:
|
| 57 |
+
self._update_df_test(metric_vals)
|
| 58 |
+
|
| 59 |
+
def evaluate_de_novo_step(
|
| 60 |
+
self,
|
| 61 |
+
mols_pred: list[list[T.Optional[Chem.Mol | str]]],
|
| 62 |
+
mol_true: list[str],
|
| 63 |
+
stage: Stage,
|
| 64 |
+
) -> dict[str, torch.Tensor]:
|
| 65 |
+
"""
|
| 66 |
+
# TODO: refactor to compute only for max(k) and then use the result to obtain the rest by
|
| 67 |
+
subsetting.
|
| 68 |
+
|
| 69 |
+
Main evaluation method for the models for de novo molecule generation from mass spectra.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
mols_pred (list[list[Mol | str]]): (bs, k) list of generated rdkit molecules or SMILES
|
| 73 |
+
strings with possible Nones if no molecule was generated
|
| 74 |
+
mol_true (list[str]): (bs) list of ground-truth SMILES strings
|
| 75 |
+
"""
|
| 76 |
+
# Initialize return dictionary to store metric values per sample
|
| 77 |
+
metric_vals = {}
|
| 78 |
+
|
| 79 |
+
# Get SMILES and RDKit molecule objects for all predictions
|
| 80 |
+
if self.mol_pred_kind == "smiles":
|
| 81 |
+
smiles_pred_valid, mols_pred_valid = [], []
|
| 82 |
+
for mols_pred_sample in mols_pred:
|
| 83 |
+
smiles_pred_valid_sample, mols_pred_valid_sample = [], []
|
| 84 |
+
for s in mols_pred_sample:
|
| 85 |
+
m = Chem.MolFromSmiles(s) if s is not None else None
|
| 86 |
+
# If SMILES cannot be converted to RDKit molecule, the molecule is set to None
|
| 87 |
+
smiles_pred_valid_sample.append(s if m is not None else None)
|
| 88 |
+
mols_pred_valid_sample.append(m)
|
| 89 |
+
smiles_pred_valid.append(smiles_pred_valid_sample)
|
| 90 |
+
mols_pred_valid.append(mols_pred_valid_sample)
|
| 91 |
+
smiles_pred, mols_pred = smiles_pred_valid, mols_pred_valid
|
| 92 |
+
elif self.mol_pred_kind == "rdkit":
|
| 93 |
+
smiles_pred = [
|
| 94 |
+
[Chem.MolToSmiles(m) if m is not None else None for m in ms]
|
| 95 |
+
for ms in mols_pred
|
| 96 |
+
]
|
| 97 |
+
else:
|
| 98 |
+
raise ValueError(f"Invalid mol_pred_kind: {self.mol_pred_kind}")
|
| 99 |
+
|
| 100 |
+
# Auxiliary metric: number of valid molecules
|
| 101 |
+
self._update_metric(
|
| 102 |
+
stage.to_pref() + f"num_valid_mols",
|
| 103 |
+
MeanMetric,
|
| 104 |
+
([sum([m is not None for m in ms]) for ms in mols_pred],),
|
| 105 |
+
batch_size=len(mols_pred),
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Get RDKit molecule objects for ground truth
|
| 109 |
+
smile_true = mol_true
|
| 110 |
+
mol_true = [Chem.MolFromSmiles(sm) for sm in mol_true]
|
| 111 |
+
|
| 112 |
+
def _get_morgan_fp_with_cache(mol):
|
| 113 |
+
"""
|
| 114 |
+
A helper function to retrieve either cached Morgan Fingerprint value, or to compute and cache it
|
| 115 |
+
@param mol: RDKit molecule object
|
| 116 |
+
@return:
|
| 117 |
+
"""
|
| 118 |
+
if mol not in self.mol_2_morgan_fp:
|
| 119 |
+
self.mol_2_morgan_fp[mol] = morgan_fp(mol, to_np=False)
|
| 120 |
+
return self.mol_2_morgan_fp[mol]
|
| 121 |
+
|
| 122 |
+
# Evaluate top-k metrics
|
| 123 |
+
for top_k in self.top_ks:
|
| 124 |
+
# Get top-k predicted molecules for each ground-truth sample
|
| 125 |
+
smiles_pred_top_k = [smiles_pred_sample[:top_k] for smiles_pred_sample in smiles_pred]
|
| 126 |
+
mols_pred_top_k = [mols_pred_sample[:top_k] for mols_pred_sample in mols_pred]
|
| 127 |
+
|
| 128 |
+
# 1. Evaluate minimum common edge subgraph:
|
| 129 |
+
# Calculate MCES distance between top-k predicted molecules and ground truth and
|
| 130 |
+
# report the minimum distance. The minimum distances for each sample in the batch are
|
| 131 |
+
# averaged across the epoch.
|
| 132 |
+
min_mces_dists = []
|
| 133 |
+
mces_thld = 100
|
| 134 |
+
# Iterate over batch
|
| 135 |
+
for preds, true in zip(smiles_pred_top_k, smile_true):
|
| 136 |
+
# Iterate over top-k predicted molecule samples
|
| 137 |
+
dists = []
|
| 138 |
+
for pred in preds:
|
| 139 |
+
if pred is None:
|
| 140 |
+
dists.append(mces_thld)
|
| 141 |
+
else:
|
| 142 |
+
if (true, pred) not in self.mces_cache:
|
| 143 |
+
mce_val = self.myopic_mces(true, pred)
|
| 144 |
+
self.mces_cache[(true, pred)] = mce_val
|
| 145 |
+
dists.append(self.mces_cache[(true, pred)])
|
| 146 |
+
min_mces_dists.append(min(min(dists), mces_thld))
|
| 147 |
+
min_mces_dists = torch.tensor(min_mces_dists, device=self.device)
|
| 148 |
+
|
| 149 |
+
# Log
|
| 150 |
+
metric_name = stage.to_pref() + f"top_{top_k}_mces_dist"
|
| 151 |
+
self._update_metric(
|
| 152 |
+
metric_name,
|
| 153 |
+
MeanMetric,
|
| 154 |
+
(min_mces_dists,),
|
| 155 |
+
batch_size=len(min_mces_dists),
|
| 156 |
+
bootstrap=stage == Stage.TEST
|
| 157 |
+
)
|
| 158 |
+
metric_vals[metric_name] = min_mces_dists
|
| 159 |
+
|
| 160 |
+
# 2. Evaluate Tanimoto similarity:
|
| 161 |
+
# Calculate Tanimoto similarity between top-k predicted molecules and ground truth and
|
| 162 |
+
# report the maximum similarity. The maximum similarities for each sample in the batch
|
| 163 |
+
# are averaged across the epoch.
|
| 164 |
+
fps_pred_top_k = [
|
| 165 |
+
[_get_morgan_fp_with_cache(m) if m is not None else None for m in ms]
|
| 166 |
+
for ms in mols_pred_top_k
|
| 167 |
+
]
|
| 168 |
+
fp_true = [_get_morgan_fp_with_cache(m) for m in mol_true]
|
| 169 |
+
|
| 170 |
+
max_tanimoto_sims = []
|
| 171 |
+
# Iterate over batch
|
| 172 |
+
for preds, true in zip(fps_pred_top_k, fp_true):
|
| 173 |
+
# Iterate over top-k predicted molecule samples
|
| 174 |
+
sims = [
|
| 175 |
+
TanimotoSimilarity(true, pred)
|
| 176 |
+
if pred is not None else 0
|
| 177 |
+
for pred in preds
|
| 178 |
+
]
|
| 179 |
+
max_tanimoto_sims.append(max(sims))
|
| 180 |
+
max_tanimoto_sims = torch.tensor(max_tanimoto_sims, device=self.device)
|
| 181 |
+
|
| 182 |
+
# Log
|
| 183 |
+
metric_name = stage.to_pref() + f"top_{top_k}_max_tanimoto_sim"
|
| 184 |
+
self._update_metric(
|
| 185 |
+
metric_name,
|
| 186 |
+
MeanMetric,
|
| 187 |
+
(max_tanimoto_sims,),
|
| 188 |
+
batch_size=len(max_tanimoto_sims),
|
| 189 |
+
bootstrap=stage == Stage.TEST
|
| 190 |
+
)
|
| 191 |
+
metric_vals[metric_name] = max_tanimoto_sims
|
| 192 |
+
|
| 193 |
+
# 3. Evaluate exact match (accuracy):
|
| 194 |
+
# Calculate if the ground truth molecule is in the top-k predicted molecules and report
|
| 195 |
+
# the average across the epoch.
|
| 196 |
+
in_top_k = [
|
| 197 |
+
mol_to_inchi_key(true) in [
|
| 198 |
+
mol_to_inchi_key(pred)
|
| 199 |
+
if pred is not None else None
|
| 200 |
+
for pred in preds
|
| 201 |
+
]
|
| 202 |
+
for true, preds in zip(mol_true, mols_pred_top_k)
|
| 203 |
+
]
|
| 204 |
+
in_top_k = torch.tensor(in_top_k, device=self.device)
|
| 205 |
+
|
| 206 |
+
# Log
|
| 207 |
+
metric_name = stage.to_pref() + f"top_{top_k}_accuracy"
|
| 208 |
+
self._update_metric(
|
| 209 |
+
metric_name,
|
| 210 |
+
MeanMetric,
|
| 211 |
+
(in_top_k,),
|
| 212 |
+
batch_size=len(in_top_k),
|
| 213 |
+
bootstrap=stage == Stage.TEST
|
| 214 |
+
)
|
| 215 |
+
metric_vals[metric_name] = in_top_k
|
| 216 |
+
|
| 217 |
+
return metric_vals
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def test_step(
|
| 221 |
+
self,
|
| 222 |
+
batch: dict,
|
| 223 |
+
batch_idx: torch.Tensor
|
| 224 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 225 |
+
outputs = super().test_step(batch, batch_idx)
|
| 226 |
+
|
| 227 |
+
# Get generated (i.e., predicted) SMILES
|
| 228 |
+
if self.df_test_path is not None:
|
| 229 |
+
self._update_df_test({
|
| 230 |
+
'identifier': batch['identifier'],
|
| 231 |
+
'mols_pred': outputs['mols_pred']
|
| 232 |
+
})
|
| 233 |
+
|
| 234 |
+
return outputs
|
| 235 |
+
|
| 236 |
+
def on_test_epoch_end(self):
|
| 237 |
+
# Save test data frame to disk
|
| 238 |
+
if self.df_test_path is not None:
|
| 239 |
+
df_test = pd.DataFrame(self.df_test)
|
| 240 |
+
self.df_test_path.parent.mkdir(parents=True, exist_ok=True)
|
| 241 |
+
df_test.to_pickle(self.df_test_path)
|
massspecgym/models/de_novo/dummy.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from massspecgym.models.base import Stage
|
| 6 |
+
from massspecgym.models.de_novo.base import DeNovoMassSpecGymModel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DummyDeNovo(DeNovoMassSpecGymModel):
|
| 10 |
+
|
| 11 |
+
def __init__(self, n_samples: int = 10, *args, **kwargs):
|
| 12 |
+
super().__init__(*args, **kwargs)
|
| 13 |
+
self.n_samples = n_samples
|
| 14 |
+
|
| 15 |
+
self.dummy_smiles = [
|
| 16 |
+
"O", # Water (H₂O)
|
| 17 |
+
"C", # Methane (CH₄)
|
| 18 |
+
"CCO", # Ethanol (C₂H₆O)
|
| 19 |
+
"C(C1C(C(C(C(O1)O)O)O)O)O", # Glucose (C₆H₁₂O₆)
|
| 20 |
+
"CC(=O)C", # Acetone (C₃H₆O)
|
| 21 |
+
"CC(=O)Oc1ccccc1C(=O)O", # Aspirin (C₉H₈O₄)
|
| 22 |
+
"CN1C=NC2=C1C(=O)N(C(=O)N2C)C", # Caffeine (C₈H₁₀N₄O₂)
|
| 23 |
+
"c1ccccc1", # Benzene (C₆H₆)
|
| 24 |
+
"CC(=O)O", # Acetic Acid (C₂H₄O₂)
|
| 25 |
+
"CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", # Ibuprofen (C₁₃H₁₈O₂)
|
| 26 |
+
None
|
| 27 |
+
]
|
| 28 |
+
self.mol_pred_kind = "smiles"
|
| 29 |
+
|
| 30 |
+
def step(
|
| 31 |
+
self, batch: dict, stage: Stage = Stage.NONE
|
| 32 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 33 |
+
bs = batch['spec'].shape[0]
|
| 34 |
+
|
| 35 |
+
# Sample dummy molecules from the pre-defined list
|
| 36 |
+
mols_pred = [[random.choice(self.dummy_smiles) for _ in range(self.n_samples)] for _ in range(bs)]
|
| 37 |
+
|
| 38 |
+
# Random baseline, so we return a dummy loss
|
| 39 |
+
loss = torch.tensor(0.0, requires_grad=True)
|
| 40 |
+
|
| 41 |
+
# Return molecules in the dict
|
| 42 |
+
return dict(loss=loss, mols_pred=mols_pred)
|
| 43 |
+
|
| 44 |
+
def configure_optimizers(self):
|
| 45 |
+
# No optimizer needed for a random baseline
|
| 46 |
+
return None
|
massspecgym/models/de_novo/random.py
ADDED
|
@@ -0,0 +1,1750 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import deque, defaultdict
|
| 2 |
+
from collections.abc import Generator
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from random import choice, shuffle
|
| 5 |
+
|
| 6 |
+
import chemparse
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from massspecgym.models.base import Stage
|
| 10 |
+
from massspecgym.models.de_novo.base import DeNovoMassSpecGymModel
|
| 11 |
+
from rdkit import Chem
|
| 12 |
+
from rdkit.Chem.MolStandardize import rdMolStandardize
|
| 13 |
+
from rdkit.Chem.rdMolDescriptors import CalcMolFormula
|
| 14 |
+
from rdkit.Chem.Descriptors import ExactMolWt
|
| 15 |
+
from rdkit.Chem.rdchem import Mol, BondType
|
| 16 |
+
from copy import deepcopy
|
| 17 |
+
from collections import Counter
|
| 18 |
+
import bisect
|
| 19 |
+
from itertools import combinations
|
| 20 |
+
|
| 21 |
+
# type aliases for code readability
|
| 22 |
+
chem_element = str
|
| 23 |
+
number_of_atoms = int
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass(frozen=True, order=True)
|
| 27 |
+
class ValenceAndCharge:
|
| 28 |
+
"""
|
| 29 |
+
A data class to store valence value with the corresponding charge
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
valence: int
|
| 33 |
+
charge: int
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass(frozen=True, order=True)
|
| 37 |
+
class AtomWithValence:
|
| 38 |
+
"""
|
| 39 |
+
A data class to store atom info including the computed valence
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
atom_type: chem_element
|
| 43 |
+
atom_valence_and_charge: ValenceAndCharge
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass(frozen=True, order=True)
|
| 47 |
+
class BondToNeighbouringAtom:
|
| 48 |
+
"""
|
| 49 |
+
A data class to store info about the adjacent atom
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
adjacent_atom: AtomWithValence
|
| 53 |
+
bond_type: int
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass
|
| 57 |
+
class AtomNodeForRandomTraversal:
|
| 58 |
+
"""
|
| 59 |
+
A data class to store atom info including the computed valence
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
atom_with_valence: AtomWithValence
|
| 63 |
+
_remaining_node_degree: int = None
|
| 64 |
+
_remaining_node_charge: int = None
|
| 65 |
+
|
| 66 |
+
def __post_init__(self):
|
| 67 |
+
"""Setting up remaining node degree and charge for random traversal"""
|
| 68 |
+
self._remaining_node_degree = (
|
| 69 |
+
self.atom_with_valence.atom_valence_and_charge.valence
|
| 70 |
+
)
|
| 71 |
+
self._remaining_node_charge = (
|
| 72 |
+
self.atom_with_valence.atom_valence_and_charge.charge
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
@property
|
| 76 |
+
def remaining_node_degree(self):
|
| 77 |
+
"""remaining_node_degree variable getter"""
|
| 78 |
+
return self._remaining_node_degree
|
| 79 |
+
|
| 80 |
+
@remaining_node_degree.setter
|
| 81 |
+
def remaining_node_degree(self, value: int):
|
| 82 |
+
"""remaining_node_degree variable setter"""
|
| 83 |
+
self._remaining_node_degree = value
|
| 84 |
+
|
| 85 |
+
@property
|
| 86 |
+
def remaining_node_charge(self):
|
| 87 |
+
"""remaining_node_charge variable getter"""
|
| 88 |
+
return self._remaining_node_charge
|
| 89 |
+
|
| 90 |
+
@remaining_node_charge.setter
|
| 91 |
+
def remaining_node_charge(self, value: int):
|
| 92 |
+
"""remaining_node_charge variable setter"""
|
| 93 |
+
self._remaining_node_charge = value
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def create_rdkit_molecule_from_edge_list(
|
| 97 |
+
edge_list: list[tuple[int, int]], all_graph_nodes: list[AtomNodeForRandomTraversal]
|
| 98 |
+
) -> Mol:
|
| 99 |
+
"""
|
| 100 |
+
A helper function converting a randomly generated edge list into rdkit.Chem.rdchem.Mol object
|
| 101 |
+
@param edge_list: a list of edges, where each edge is specified by the index of its nodes
|
| 102 |
+
@param all_graph_nodes: a list of all atomic nodes in the molecular graph
|
| 103 |
+
"""
|
| 104 |
+
# first we traverse all randomly generated edges and compute bond types between each pair of atoms
|
| 105 |
+
edge_2_bondtype = defaultdict(int)
|
| 106 |
+
for edge_node_i, edge_node_j in edge_list:
|
| 107 |
+
edge_2_bondtype[
|
| 108 |
+
(min(edge_node_i, edge_node_j), max(edge_node_i, edge_node_j))
|
| 109 |
+
] += 1
|
| 110 |
+
|
| 111 |
+
# helper routine to get the rdking enum bondtype
|
| 112 |
+
def _get_rdkit_bondtype(bondtype: int) -> BondType:
|
| 113 |
+
int_bondtype_2_enum = {
|
| 114 |
+
1: BondType.SINGLE,
|
| 115 |
+
2: BondType.DOUBLE,
|
| 116 |
+
3: BondType.TRIPLE,
|
| 117 |
+
4: BondType.QUADRUPLE,
|
| 118 |
+
5: BondType.QUINTUPLE,
|
| 119 |
+
6: BondType.HEXTUPLE,
|
| 120 |
+
}
|
| 121 |
+
try:
|
| 122 |
+
return int_bondtype_2_enum[bondtype]
|
| 123 |
+
except KeyError:
|
| 124 |
+
raise NotImplementedError(f"Bond type {bondtype} is not supported")
|
| 125 |
+
|
| 126 |
+
edge_list_rdkit = [
|
| 127 |
+
(node_i, node_j, _get_rdkit_bondtype(bondtype))
|
| 128 |
+
for (node_i, node_j), bondtype in edge_2_bondtype.items()
|
| 129 |
+
]
|
| 130 |
+
# creating an empty editable molecule
|
| 131 |
+
mol = Chem.RWMol()
|
| 132 |
+
# adding the atoms to the molecule object
|
| 133 |
+
|
| 134 |
+
# as some all_graph nodes can represent charges, we have to remember mapping of molecular atom index to
|
| 135 |
+
# the corresponding atom index in all_graph_nodes
|
| 136 |
+
all_graph_atom_idx_2_mol_atom_idx = {}
|
| 137 |
+
for all_graph_atom_idx, atom in enumerate(all_graph_nodes):
|
| 138 |
+
# ignoring charge-related graph nodes
|
| 139 |
+
if atom.atom_with_valence.atom_type not in {"+", "-"}:
|
| 140 |
+
all_graph_atom_idx_2_mol_atom_idx[all_graph_atom_idx] = mol.GetNumAtoms()
|
| 141 |
+
next_atom = Chem.Atom(atom.atom_with_valence.atom_type)
|
| 142 |
+
next_atom.SetFormalCharge(
|
| 143 |
+
atom.atom_with_valence.atom_valence_and_charge.charge
|
| 144 |
+
)
|
| 145 |
+
mol.AddAtom(next_atom)
|
| 146 |
+
|
| 147 |
+
# adding bonds
|
| 148 |
+
for (edge_node_i, edge_node_j, bond_type) in edge_list_rdkit:
|
| 149 |
+
# checking if the edge represents a charge of connected atom
|
| 150 |
+
the_edge_represents_charge = len(
|
| 151 |
+
{
|
| 152 |
+
all_graph_nodes[node_i].atom_with_valence.atom_type
|
| 153 |
+
for node_i in [edge_node_i, edge_node_j]
|
| 154 |
+
}.intersection({"+", "-"})
|
| 155 |
+
)
|
| 156 |
+
if the_edge_represents_charge:
|
| 157 |
+
# setting a charge to the corresponding atom
|
| 158 |
+
for node_i in [edge_node_i, edge_node_j]:
|
| 159 |
+
if all_graph_nodes[node_i].atom_with_valence.atom_type in {"+", "-"}:
|
| 160 |
+
charge_value = (
|
| 161 |
+
1
|
| 162 |
+
if all_graph_nodes[node_i].atom_with_valence.atom_type == "+"
|
| 163 |
+
else -1
|
| 164 |
+
)
|
| 165 |
+
else:
|
| 166 |
+
atom_node_i = node_i
|
| 167 |
+
mol.GetAtomWithIdx(
|
| 168 |
+
all_graph_atom_idx_2_mol_atom_idx[atom_node_i]
|
| 169 |
+
).SetFormalCharge(charge_value)
|
| 170 |
+
else:
|
| 171 |
+
mol.AddBond(
|
| 172 |
+
all_graph_atom_idx_2_mol_atom_idx[edge_node_i],
|
| 173 |
+
all_graph_atom_idx_2_mol_atom_idx[edge_node_j],
|
| 174 |
+
bond_type,
|
| 175 |
+
)
|
| 176 |
+
# returning the rdkit.Chem.rdchem.Mol object
|
| 177 |
+
return mol.GetMol()
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class RandomDeNovo(DeNovoMassSpecGymModel):
|
| 181 |
+
def __init__(
|
| 182 |
+
self,
|
| 183 |
+
formula_known: bool = True,
|
| 184 |
+
count_of_valid_valence_assignments: int = 10,
|
| 185 |
+
estimate_chem_element_stats: bool = False,
|
| 186 |
+
max_top_k: int = 10,
|
| 187 |
+
enforce_connectivity: bool = True,
|
| 188 |
+
cache_results: bool = True,
|
| 189 |
+
**kwargs
|
| 190 |
+
):
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
@param formula_known: a boolean flag about the information available prior to generation
|
| 194 |
+
If formula_known is True, we should generate molecules with the specified formula
|
| 195 |
+
If formula_known is False, we should generate any molecule with the specified mass
|
| 196 |
+
@param count_of_valid_valence_assignments: an integer controlling process of selecting valence assignment
|
| 197 |
+
to each atom in the generated molecule.
|
| 198 |
+
`count_of_valid_valence_assignments` of assignment corresponding to
|
| 199 |
+
the formula are generated, then one assignment is is picked at random.
|
| 200 |
+
The default is set to 3 for the computational speed purposes.
|
| 201 |
+
When setting to 1, the first feasible valence assignment will be used.
|
| 202 |
+
@param estimate_chem_element_stats: a boolean flag controlling if prior information about elements' valences
|
| 203 |
+
and bond type distributions is estimated from training data
|
| 204 |
+
@param max_top_k: a maximum number of candidates to generate. If the count of valid valence assignments do
|
| 205 |
+
not allow generation of max_top_k, then less candidates are returned
|
| 206 |
+
@param enforce_connectivity: a boolean flag controlling connectivity of randomly generated molecules.
|
| 207 |
+
When it is set to True, first a random spanning tree is sampled
|
| 208 |
+
@param cache_results: a boolean flag controlling caching of already generated structures.
|
| 209 |
+
When set to True, for each unique formula the set of random molecules is cached to avoid
|
| 210 |
+
recomputation.
|
| 211 |
+
"""
|
| 212 |
+
super(RandomDeNovo, self).__init__(**kwargs)
|
| 213 |
+
self.formula_known = formula_known
|
| 214 |
+
self.count_of_valid_valence_assignments = count_of_valid_valence_assignments
|
| 215 |
+
self.estimate_chem_element_stats = estimate_chem_element_stats
|
| 216 |
+
self.max_top_k = min(max(self.top_ks), max_top_k)
|
| 217 |
+
self.enforce_connectivity = enforce_connectivity
|
| 218 |
+
# prior chemical knownledge about element valences
|
| 219 |
+
self.element_2_valences = ELEMENT_VALENCES
|
| 220 |
+
# a dictionary structure to record molecular weights with corresponding formulas from training data
|
| 221 |
+
# during training steps, for each molecular weight we record all encountered formulas
|
| 222 |
+
# then on training end we compute proportions of the formulas and record it as a mapping
|
| 223 |
+
# mol_weight -> [[formula_1, formula_2], [proportion_of_formula_1, proportion_of_formula_2]]
|
| 224 |
+
self.mol_weight_2_formulas = defaultdict(list)
|
| 225 |
+
# a helper array to store sorted list of train molecular weights.
|
| 226 |
+
# It will be used for the O(logn) lookup of the closest mol weight
|
| 227 |
+
self.mol_weight_trn_values: list[float] = None
|
| 228 |
+
# a dictionary structure for statistics about bond type distributions
|
| 229 |
+
# the dictionary has the following mapping:
|
| 230 |
+
# chem_element ->
|
| 231 |
+
# ValenceAndCharge ->
|
| 232 |
+
# number of already bonded atoms ->
|
| 233 |
+
# [already created BondToNeighbouringAtom] ->
|
| 234 |
+
# AtomWithValence ->
|
| 235 |
+
# list of (bond_type, count) + total_count
|
| 236 |
+
self.element_2_bond_stats = None
|
| 237 |
+
# a cache with already precomputed sets of randomly generated molecules for the given formula
|
| 238 |
+
self.formula_2_random_smiles = {}
|
| 239 |
+
self.cache_results = cache_results
|
| 240 |
+
|
| 241 |
+
def generator_for_splits_of_chem_element_atoms_by_possible_valences(
|
| 242 |
+
self,
|
| 243 |
+
atom_type: chem_element,
|
| 244 |
+
possible_valences: list[ValenceAndCharge],
|
| 245 |
+
atom_count: int,
|
| 246 |
+
already_assigned_groups_of_atoms: dict[AtomWithValence, number_of_atoms],
|
| 247 |
+
) -> Generator[dict[AtomWithValence, number_of_atoms]]:
|
| 248 |
+
"""
|
| 249 |
+
A recursive generator function to iterate over all possible partitions of element atoms
|
| 250 |
+
into groups with different valid valences.
|
| 251 |
+
Each allowed valence value can have any number from atoms, from zero up to total `atom_count`
|
| 252 |
+
@param atom_type: chemical element
|
| 253 |
+
@param possible_valences: a list of allowed valences
|
| 254 |
+
@param atom_count: a total number of element atoms to split into valence groups
|
| 255 |
+
@param already_assigned_groups_of_atoms: partial results to pass into the subsequent recursive calls
|
| 256 |
+
|
| 257 |
+
@return A generator for lazy enumeration over all possible splits of `atom_count` atoms into subgroups
|
| 258 |
+
of valid valences specified in `possible valences` parameters.
|
| 259 |
+
Each return value is a dictionary, mapping atom with fixed valence to a total count of such instances
|
| 260 |
+
in the molecule.
|
| 261 |
+
|
| 262 |
+
@note In the future the method can be made into a function in a separate utils module,
|
| 263 |
+
for the simplicity of codebase organization and testing purposes it's kept as the method for now
|
| 264 |
+
"""
|
| 265 |
+
# the check for a base case of the recursion
|
| 266 |
+
if atom_count == 0:
|
| 267 |
+
yield already_assigned_groups_of_atoms
|
| 268 |
+
elif len(possible_valences):
|
| 269 |
+
# taking the first valence value from the possible ones
|
| 270 |
+
next_valence = possible_valences[0]
|
| 271 |
+
# iterating over possible sizes for a group of atoms with `next_valence` value of the valence
|
| 272 |
+
for size_of_group in range(atom_count, -1, -1):
|
| 273 |
+
# recording the assigned size of the group
|
| 274 |
+
already_assigned_groups_of_atoms_next = (
|
| 275 |
+
already_assigned_groups_of_atoms.copy()
|
| 276 |
+
)
|
| 277 |
+
atom_with_valence = AtomWithValence(
|
| 278 |
+
atom_type=atom_type, atom_valence_and_charge=next_valence
|
| 279 |
+
)
|
| 280 |
+
already_assigned_groups_of_atoms_next[atom_with_valence] = size_of_group
|
| 281 |
+
yield from self.generator_for_splits_of_chem_element_atoms_by_possible_valences(
|
| 282 |
+
atom_type=atom_type,
|
| 283 |
+
possible_valences=possible_valences[1:],
|
| 284 |
+
atom_count=atom_count - size_of_group,
|
| 285 |
+
already_assigned_groups_of_atoms=already_assigned_groups_of_atoms_next,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
def assigner_of_valences_to_all_atoms(
|
| 289 |
+
self,
|
| 290 |
+
unassigned_molecule_elements_with_counts: dict[chem_element, number_of_atoms],
|
| 291 |
+
already_assigned_atoms_with_valences: dict[AtomWithValence, number_of_atoms],
|
| 292 |
+
common_valences_only: bool = True,
|
| 293 |
+
) -> Generator[dict[AtomWithValence, number_of_atoms]]:
|
| 294 |
+
"""
|
| 295 |
+
A recursive function to iterate over all possible valid assignments of valences for each atom in the molecule
|
| 296 |
+
@param unassigned_molecule_elements_with_counts: a dictionary representation of a molecule,
|
| 297 |
+
mapping each present element to a corresponding number of atoms.
|
| 298 |
+
The function is recursive, in the subsequence calls
|
| 299 |
+
the dictionary represents an yet-unprocessed submolecule
|
| 300 |
+
@param already_assigned_atoms_with_valences: partial results to pass into the subsequent recursive calls,
|
| 301 |
+
stored as a dictionary, mapping atom with fixed valence
|
| 302 |
+
to a total count of such atoms in the molecule
|
| 303 |
+
@param common_valences_only: a flag for using the common valence values for each element
|
| 304 |
+
|
| 305 |
+
@return A generator for lazy enumeration over all possible assignments of all molecule atoms into subgroups
|
| 306 |
+
defined by valences. Valence values are the valid ones for the corresponding chemical element.
|
| 307 |
+
Each return value is a dictionary, mapping atom of specified chemical element with a fixed valence
|
| 308 |
+
to a total count of such atoms in the molecule.
|
| 309 |
+
|
| 310 |
+
@note In the future the method can be made into a function in a separate utils module,
|
| 311 |
+
for the simplicity of codebase organization and testing purposes it's kept as the method for now
|
| 312 |
+
"""
|
| 313 |
+
# the check for a base case of the recursion
|
| 314 |
+
if len(unassigned_molecule_elements_with_counts) == 0:
|
| 315 |
+
yield already_assigned_atoms_with_valences
|
| 316 |
+
else:
|
| 317 |
+
# processing the next chemical element in the molecule
|
| 318 |
+
chem_element_type, atom_count = list(
|
| 319 |
+
unassigned_molecule_elements_with_counts.items()
|
| 320 |
+
)[0]
|
| 321 |
+
# for the subsequence recursive calls the picked atom will be removed from the yet-to-be-processed
|
| 322 |
+
remaining_unassigned_atoms_with_counts = (
|
| 323 |
+
unassigned_molecule_elements_with_counts.copy()
|
| 324 |
+
)
|
| 325 |
+
del remaining_unassigned_atoms_with_counts[chem_element_type]
|
| 326 |
+
# generating splits of the element count into groups with possible valences
|
| 327 |
+
valences_common, valences_others = self.element_2_valences[
|
| 328 |
+
chem_element_type.capitalize()
|
| 329 |
+
]
|
| 330 |
+
possible_element_valences = (
|
| 331 |
+
valences_common
|
| 332 |
+
if common_valences_only
|
| 333 |
+
else valences_common + valences_others
|
| 334 |
+
)
|
| 335 |
+
# we ignore "the direction" of ionic bonds, therefore we work with absolute values of valences
|
| 336 |
+
possible_element_valences = map(
|
| 337 |
+
lambda x: ValenceAndCharge(valence=np.abs(x.valence), charge=x.charge),
|
| 338 |
+
possible_element_valences,
|
| 339 |
+
)
|
| 340 |
+
# we require a connected molecule graph, so we ignore possible 0 values of valences
|
| 341 |
+
possible_element_valences = list(
|
| 342 |
+
set(filter(lambda x: x.valence > 0, possible_element_valences))
|
| 343 |
+
)
|
| 344 |
+
# creating a generator for lazy enumeration over all possible splits of element atoms
|
| 345 |
+
# into subgroups of possible valid valences
|
| 346 |
+
valence_split_generator = (
|
| 347 |
+
self.generator_for_splits_of_chem_element_atoms_by_possible_valences(
|
| 348 |
+
atom_type=chem_element_type,
|
| 349 |
+
possible_valences=possible_element_valences,
|
| 350 |
+
atom_count=atom_count,
|
| 351 |
+
already_assigned_groups_of_atoms=dict(),
|
| 352 |
+
)
|
| 353 |
+
)
|
| 354 |
+
# iterating over splits of the element count into groups with possible valences
|
| 355 |
+
for element_atoms_with_valence_2_count in valence_split_generator:
|
| 356 |
+
already_assigned_atoms_with_valences_new = (
|
| 357 |
+
already_assigned_atoms_with_valences.copy()
|
| 358 |
+
)
|
| 359 |
+
already_assigned_atoms_with_valences_new.update(
|
| 360 |
+
element_atoms_with_valence_2_count
|
| 361 |
+
)
|
| 362 |
+
yield from self.assigner_of_valences_to_all_atoms(
|
| 363 |
+
unassigned_molecule_elements_with_counts=remaining_unassigned_atoms_with_counts,
|
| 364 |
+
already_assigned_atoms_with_valences=already_assigned_atoms_with_valences_new,
|
| 365 |
+
common_valences_only=common_valences_only,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
def is_valence_assignment_feasible(
|
| 369 |
+
self, valence_assignment: dict[AtomWithValence, number_of_atoms]
|
| 370 |
+
) -> bool:
|
| 371 |
+
"""
|
| 372 |
+
A function for checking if the valence assignment to all molecule atoms can be feasible
|
| 373 |
+
|
| 374 |
+
@param valence_assignment: an assignment of all molecule atoms into subgroups of plausible valences
|
| 375 |
+
|
| 376 |
+
@note In the future the method can be made into a function in a separate utils module,
|
| 377 |
+
for the simplicity of codebase organization and testing purposes it's kept as the method for now
|
| 378 |
+
"""
|
| 379 |
+
# considering a molecule as a graph with atom being nodes and chemical bonds being edges
|
| 380 |
+
# computing sum of all node degrees
|
| 381 |
+
sum_of_all_node_degrees = sum(
|
| 382 |
+
[
|
| 383 |
+
atom.atom_valence_and_charge.valence * count_of_atoms
|
| 384 |
+
for atom, count_of_atoms in valence_assignment.items()
|
| 385 |
+
]
|
| 386 |
+
)
|
| 387 |
+
if sum_of_all_node_degrees % 2 == 1:
|
| 388 |
+
# the valence assignment is infeasible as in the graph the number of edges is half of the total degrees sum
|
| 389 |
+
# therefore the sum_of_all_node_degrees must be an even number
|
| 390 |
+
return False
|
| 391 |
+
total_number_of_bonds = sum_of_all_node_degrees / 2
|
| 392 |
+
# the total number of all atoms in the whole molecule
|
| 393 |
+
total_number_of_atoms_in_molecule = sum(valence_assignment.values())
|
| 394 |
+
if total_number_of_bonds < total_number_of_atoms_in_molecule - 1:
|
| 395 |
+
# the valence assignment is infeasible as the molecule graph cannot be connected
|
| 396 |
+
return False
|
| 397 |
+
# check that charges add up to zero
|
| 398 |
+
total_charge = 0
|
| 399 |
+
for atom, count_of_atoms in valence_assignment.items():
|
| 400 |
+
# we do not take virtual nodes for the charged molecules, we force the remaining submolecule to be neutral
|
| 401 |
+
if atom.atom_type not in {"+", "-"}:
|
| 402 |
+
total_charge += atom.atom_valence_and_charge.charge * count_of_atoms
|
| 403 |
+
if total_charge != 0:
|
| 404 |
+
return False
|
| 405 |
+
return True
|
| 406 |
+
|
| 407 |
+
def get_feasible_atom_valence_assignments(
|
| 408 |
+
self, chemical_formula: str
|
| 409 |
+
) -> list[dict[AtomWithValence, number_of_atoms]]:
|
| 410 |
+
"""
|
| 411 |
+
A function generating candidate assignments of valences to individual atoms in the molecule.
|
| 412 |
+
Candidates are returned in a random order.
|
| 413 |
+
@param chemical_formula: a string containing the chemical formula of the molecule
|
| 414 |
+
|
| 415 |
+
@note In the future the method can be made into a function in a separate utils module,
|
| 416 |
+
for the simplicity of codebase organization and testing purposes it's kept as the method for now
|
| 417 |
+
"""
|
| 418 |
+
# parsing chemical formula into a dictionary of elements with corresponding counts
|
| 419 |
+
element_2_count = {
|
| 420 |
+
element: int(count)
|
| 421 |
+
for element, count in chemparse.parse_formula(chemical_formula).items()
|
| 422 |
+
}
|
| 423 |
+
# checking that all input elements are valid
|
| 424 |
+
for element in element_2_count.keys():
|
| 425 |
+
if element.capitalize() not in self.element_2_valences:
|
| 426 |
+
raise ValueError(
|
| 427 |
+
f"Found an unknown element {element.capitalize()} in the formula {chemical_formula}"
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
# estimate the total number of all atoms in the whole molecule
|
| 431 |
+
# it will be used to check validity of the valence assignments
|
| 432 |
+
total_number_of_atoms_in_molecule = sum(element_2_count.values())
|
| 433 |
+
generated_candidate_valence_assignments = []
|
| 434 |
+
valence_assignment_generator = self.assigner_of_valences_to_all_atoms(
|
| 435 |
+
unassigned_molecule_elements_with_counts=element_2_count,
|
| 436 |
+
already_assigned_atoms_with_valences=dict(),
|
| 437 |
+
common_valences_only=True,
|
| 438 |
+
)
|
| 439 |
+
termination_assignment_value = {AtomWithValence("No more assignments", -1): -1}
|
| 440 |
+
next_valence_assignment = next(
|
| 441 |
+
valence_assignment_generator, termination_assignment_value
|
| 442 |
+
)
|
| 443 |
+
while (
|
| 444 |
+
len(generated_candidate_valence_assignments)
|
| 445 |
+
< self.count_of_valid_valence_assignments
|
| 446 |
+
and next_valence_assignment != termination_assignment_value
|
| 447 |
+
):
|
| 448 |
+
if self.is_valence_assignment_feasible(next_valence_assignment):
|
| 449 |
+
generated_candidate_valence_assignments.append(next_valence_assignment)
|
| 450 |
+
next_valence_assignment = next(
|
| 451 |
+
valence_assignment_generator, termination_assignment_value
|
| 452 |
+
)
|
| 453 |
+
# if no valence assignment was found with common valences,
|
| 454 |
+
# then try generating assignments including not-common valences
|
| 455 |
+
if len(generated_candidate_valence_assignments) == 0:
|
| 456 |
+
valence_assignment_generator = self.assigner_of_valences_to_all_atoms(
|
| 457 |
+
unassigned_molecule_elements_with_counts=element_2_count,
|
| 458 |
+
already_assigned_atoms_with_valences=dict(),
|
| 459 |
+
common_valences_only=False,
|
| 460 |
+
)
|
| 461 |
+
next_valence_assignment = next(
|
| 462 |
+
valence_assignment_generator, termination_assignment_value
|
| 463 |
+
)
|
| 464 |
+
while (
|
| 465 |
+
len(generated_candidate_valence_assignments)
|
| 466 |
+
< self.count_of_valid_valence_assignments
|
| 467 |
+
and next_valence_assignment != termination_assignment_value
|
| 468 |
+
):
|
| 469 |
+
if self.is_valence_assignment_feasible(next_valence_assignment):
|
| 470 |
+
generated_candidate_valence_assignments.append(
|
| 471 |
+
next_valence_assignment
|
| 472 |
+
)
|
| 473 |
+
next_valence_assignment = next(
|
| 474 |
+
valence_assignment_generator, termination_assignment_value
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
if len(generated_candidate_valence_assignments) == 0:
|
| 478 |
+
raise ValueError(
|
| 479 |
+
f"No valence assignments can be generated for the formula {chemical_formula}"
|
| 480 |
+
)
|
| 481 |
+
shuffle(generated_candidate_valence_assignments)
|
| 482 |
+
return generated_candidate_valence_assignments
|
| 483 |
+
|
| 484 |
+
def sample_second_edgenode_at_random(
|
| 485 |
+
self,
|
| 486 |
+
edge_start_node_i: int,
|
| 487 |
+
all_graph_nodes: list[AtomNodeForRandomTraversal],
|
| 488 |
+
open_nodes_for_sampling: dict[str, set[int]],
|
| 489 |
+
possible_candidates_type: str,
|
| 490 |
+
closed_set: set[int],
|
| 491 |
+
use_chem_element_stats: bool = False,
|
| 492 |
+
already_connected_neighbours: list[BondToNeighbouringAtom] = None,
|
| 493 |
+
):
|
| 494 |
+
"""
|
| 495 |
+
A function randomly sampling the second node for an edge
|
| 496 |
+
@param edge_start_node_i: index of the first edge node
|
| 497 |
+
@param all_graph_nodes: a list of all nodes in the molecule graph
|
| 498 |
+
@param open_nodes_for_sampling: dictionary with sets of node indices which
|
| 499 |
+
can be considered for closing the edge.
|
| 500 |
+
Each set is specified by the dictionary key:
|
| 501 |
+
"coordinate_bond_negatively_charged_targets",
|
| 502 |
+
"coordinate_bond_positively_charged_targets",
|
| 503 |
+
"covalent_bond_targets"
|
| 504 |
+
@param possible_candidates_type: the `open_nodes_for_sampling` dictionary key
|
| 505 |
+
@param closed_set: closed set for traversal
|
| 506 |
+
@param use_chem_element_stats: a boolean flag setting up usage of per chem. elements statistics about its bonds
|
| 507 |
+
@param already_connected_neighbours: an adjacency list of already sampled neighbours
|
| 508 |
+
"""
|
| 509 |
+
if not use_chem_element_stats:
|
| 510 |
+
edge_end_node_j = choice(
|
| 511 |
+
[
|
| 512 |
+
candidate_node_j
|
| 513 |
+
for candidate_node_j in open_nodes_for_sampling[
|
| 514 |
+
possible_candidates_type
|
| 515 |
+
]
|
| 516 |
+
if candidate_node_j not in closed_set
|
| 517 |
+
]
|
| 518 |
+
)
|
| 519 |
+
bond_degree = 1
|
| 520 |
+
else:
|
| 521 |
+
# checking the current state of the atom and gathering the corresponding stats
|
| 522 |
+
number_of_already_sampled_neighbours = len(already_connected_neighbours)
|
| 523 |
+
# note that the graph is undirected, start-end node refers to the random traversal only
|
| 524 |
+
start_atom = all_graph_nodes[edge_start_node_i]
|
| 525 |
+
if self.element_2_bond_stats is None:
|
| 526 |
+
raise RuntimeError(
|
| 527 |
+
"To use chem. element stats, the model has to be trained first,"
|
| 528 |
+
"to record training molecular weights with corresponding formulas."
|
| 529 |
+
)
|
| 530 |
+
# the structure of `self.element_2_bond_stats` is
|
| 531 |
+
# chem_element ->
|
| 532 |
+
# ValenceAndCharge ->
|
| 533 |
+
# number of already bonded atoms ->
|
| 534 |
+
# [already created BondToNeighbouringAtom] ->
|
| 535 |
+
# AtomWithValence ->
|
| 536 |
+
# list of (bond_type, count)+ total_count
|
| 537 |
+
|
| 538 |
+
# if we don't have stats -> fall back to sampling from all candidates
|
| 539 |
+
try:
|
| 540 |
+
element_stats = self.element_2_bond_stats[
|
| 541 |
+
start_atom.atom_with_valence.atom_type
|
| 542 |
+
][
|
| 543 |
+
ValenceAndCharge(
|
| 544 |
+
start_atom.atom_with_valence.atom_valence_and_charge.valence,
|
| 545 |
+
start_atom.atom_with_valence.atom_valence_and_charge.charge,
|
| 546 |
+
)
|
| 547 |
+
][
|
| 548 |
+
number_of_already_sampled_neighbours
|
| 549 |
+
][
|
| 550 |
+
tuple(sorted(already_connected_neighbours))
|
| 551 |
+
]
|
| 552 |
+
|
| 553 |
+
full_candidates_list = []
|
| 554 |
+
neighb_with_stats_candidates_list = []
|
| 555 |
+
neighb_with_stats_bondcounts = []
|
| 556 |
+
neighb_with_stats_bondlists = []
|
| 557 |
+
# iterating over open nodes of the corresponding bond type
|
| 558 |
+
for candidate_node_j in open_nodes_for_sampling[
|
| 559 |
+
possible_candidates_type
|
| 560 |
+
]:
|
| 561 |
+
if candidate_node_j not in closed_set:
|
| 562 |
+
# remembering all candidates in case no statistic-based option is there
|
| 563 |
+
full_candidates_list.append(candidate_node_j)
|
| 564 |
+
# checking if the candidate is present in element-specific bond stats
|
| 565 |
+
candidate_neighb_atom = all_graph_nodes[
|
| 566 |
+
candidate_node_j
|
| 567 |
+
].atom_with_valence
|
| 568 |
+
if candidate_neighb_atom in element_stats:
|
| 569 |
+
neighb_with_stats_candidates_list.append(candidate_node_j)
|
| 570 |
+
bondslist, total_bond_count = element_stats[
|
| 571 |
+
candidate_neighb_atom
|
| 572 |
+
]
|
| 573 |
+
neighb_with_stats_bondcounts.append(total_bond_count)
|
| 574 |
+
neighb_with_stats_bondlists.append(bondslist)
|
| 575 |
+
# when no stats-based neighbour remain (e.g. hydrogens are not recorded in the stats)
|
| 576 |
+
if len(neighb_with_stats_candidates_list) == 0:
|
| 577 |
+
edge_end_node_j = choice(full_candidates_list)
|
| 578 |
+
bond_degree = 1
|
| 579 |
+
else:
|
| 580 |
+
# sampling based on frequences in bond stats
|
| 581 |
+
total_bondcount_sum = sum(neighb_with_stats_bondcounts)
|
| 582 |
+
proportions = [
|
| 583 |
+
val / total_bondcount_sum
|
| 584 |
+
for val in neighb_with_stats_bondcounts
|
| 585 |
+
]
|
| 586 |
+
edge_end_node_j = np.random.choice(
|
| 587 |
+
neighb_with_stats_candidates_list, p=proportions
|
| 588 |
+
)
|
| 589 |
+
# getting i of the sampled neighbour to access its bond-stats
|
| 590 |
+
neighb_i = neighb_with_stats_candidates_list.index(edge_end_node_j)
|
| 591 |
+
# for the sampled end node, we sample the type of the bond based on the stats
|
| 592 |
+
bondtypes_possible = []
|
| 593 |
+
counts_of_possible_bondtypes = []
|
| 594 |
+
total_possible_bondtype_count = 0
|
| 595 |
+
# we leave only the bonds which current state of random generation allows
|
| 596 |
+
# i.e., we cannot sample a bond violating the current remaining degree of `edge_start_node_i`
|
| 597 |
+
start_node_remaining_degree = all_graph_nodes[
|
| 598 |
+
edge_start_node_i
|
| 599 |
+
].remaining_node_degree
|
| 600 |
+
for bondtype, count in neighb_with_stats_bondlists[neighb_i]:
|
| 601 |
+
if bondtype <= start_node_remaining_degree:
|
| 602 |
+
bondtypes_possible.append(bondtype)
|
| 603 |
+
counts_of_possible_bondtypes.append(count)
|
| 604 |
+
total_possible_bondtype_count += count
|
| 605 |
+
# if no bonds can be closed for the sampled element, fall back to sampling from full candidates list
|
| 606 |
+
if len(bondtypes_possible) == 0:
|
| 607 |
+
edge_end_node_j = choice(full_candidates_list)
|
| 608 |
+
bond_degree = 1
|
| 609 |
+
else:
|
| 610 |
+
bond_degree_proportions = [
|
| 611 |
+
num / total_possible_bondtype_count
|
| 612 |
+
for num in counts_of_possible_bondtypes
|
| 613 |
+
]
|
| 614 |
+
bond_degree = np.random.choice(
|
| 615 |
+
bondtypes_possible, p=bond_degree_proportions
|
| 616 |
+
)
|
| 617 |
+
already_connected_neighbours.append(
|
| 618 |
+
BondToNeighbouringAtom(
|
| 619 |
+
adjacent_atom=all_graph_nodes[
|
| 620 |
+
edge_end_node_j
|
| 621 |
+
].atom_with_valence,
|
| 622 |
+
bond_type=bond_degree,
|
| 623 |
+
)
|
| 624 |
+
)
|
| 625 |
+
except:
|
| 626 |
+
edge_end_node_j = choice(
|
| 627 |
+
[
|
| 628 |
+
candidate_node_j
|
| 629 |
+
for candidate_node_j in open_nodes_for_sampling[
|
| 630 |
+
possible_candidates_type
|
| 631 |
+
]
|
| 632 |
+
if candidate_node_j not in closed_set
|
| 633 |
+
]
|
| 634 |
+
)
|
| 635 |
+
bond_degree = 1
|
| 636 |
+
return edge_end_node_j, bond_degree, already_connected_neighbours
|
| 637 |
+
|
| 638 |
+
def sample_edge_at_random(
|
| 639 |
+
self,
|
| 640 |
+
all_graph_nodes: list[AtomNodeForRandomTraversal],
|
| 641 |
+
open_nodes_for_sampling: dict[str, set[int]],
|
| 642 |
+
edge_start_node_i: int = None,
|
| 643 |
+
closed_set: set[int] = None,
|
| 644 |
+
use_chem_element_stats: bool = False,
|
| 645 |
+
atom_2_already_connected_neighbours: list[list[BondToNeighbouringAtom]] = None,
|
| 646 |
+
) -> tuple[tuple[int, int], list[AtomNodeForRandomTraversal], set[int]]:
|
| 647 |
+
"""
|
| 648 |
+
Helper function to filter atoms suitable for generation of a random bond with `edge_start_node_i`
|
| 649 |
+
and sampling a random edge
|
| 650 |
+
@param all_graph_nodes: a list of all nodes in the molecule graph
|
| 651 |
+
@param edge_start_node_i: index of the first edge node
|
| 652 |
+
@param open_nodes_for_sampling: dictionary with sets of node indices which
|
| 653 |
+
can be considered for closing the edge.
|
| 654 |
+
Each set is specified by the dictionary key:
|
| 655 |
+
"coordinate_bond_negatively_charged_targets",
|
| 656 |
+
"coordinate_bond_positively_charged_targets",
|
| 657 |
+
"covalent_bond_targets"
|
| 658 |
+
@param use_chem_element_stats: a boolean flag setting up usage of per chem. elements statistics about its bonds
|
| 659 |
+
@param closed_set: closed set for traversal
|
| 660 |
+
@param atom_2_already_connected_neighbours: a mapping from atom to its adjacency list of already sampled neighbours
|
| 661 |
+
@return: a sampled edge and updated structures `all_graph_nodes`, `open_nodes_for_sampling`
|
| 662 |
+
"""
|
| 663 |
+
# sample the start node for the edge if it's not specified
|
| 664 |
+
if edge_start_node_i is None:
|
| 665 |
+
edge_start_node_i = choice(
|
| 666 |
+
sum(map(list, open_nodes_for_sampling.values()), [])
|
| 667 |
+
)
|
| 668 |
+
if closed_set is None:
|
| 669 |
+
closed_set = {edge_start_node_i}
|
| 670 |
+
# check if the start edge atom has the charge and therefore can form coordinate bond
|
| 671 |
+
can_form_coordinate_bond = (
|
| 672 |
+
all_graph_nodes[edge_start_node_i].remaining_node_charge != 0
|
| 673 |
+
)
|
| 674 |
+
# if possible, create coordinate bond at random
|
| 675 |
+
is_bond_coordinate = can_form_coordinate_bond and np.random.rand() < 0.5
|
| 676 |
+
if is_bond_coordinate:
|
| 677 |
+
start_node_charge_sign = np.sign(
|
| 678 |
+
all_graph_nodes[edge_start_node_i].remaining_node_charge
|
| 679 |
+
)
|
| 680 |
+
# if for the coordinate bond one atom is positively charged, then another must be charged negatively
|
| 681 |
+
if start_node_charge_sign > 0:
|
| 682 |
+
possible_candidates_type = "coordinate_bond_neg_charged_targets"
|
| 683 |
+
else:
|
| 684 |
+
possible_candidates_type = "coordinate_bond_pos_charged_targets"
|
| 685 |
+
else:
|
| 686 |
+
possible_candidates_type = "covalent_bond_targets"
|
| 687 |
+
|
| 688 |
+
(
|
| 689 |
+
edge_end_node_j,
|
| 690 |
+
node_degree_reduction,
|
| 691 |
+
atom_2_already_connected_neighbours[edge_start_node_i],
|
| 692 |
+
) = self.sample_second_edgenode_at_random(
|
| 693 |
+
edge_start_node_i,
|
| 694 |
+
all_graph_nodes,
|
| 695 |
+
open_nodes_for_sampling,
|
| 696 |
+
possible_candidates_type,
|
| 697 |
+
closed_set,
|
| 698 |
+
use_chem_element_stats,
|
| 699 |
+
atom_2_already_connected_neighbours[edge_start_node_i],
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
# decrease the node degrees correspondingly
|
| 703 |
+
for node_of_a_new_edge_i in [edge_start_node_i, edge_end_node_j]:
|
| 704 |
+
all_graph_nodes[
|
| 705 |
+
node_of_a_new_edge_i
|
| 706 |
+
].remaining_node_degree -= node_degree_reduction
|
| 707 |
+
# if all bonds are created for the particular atom, it is no more open for traversal
|
| 708 |
+
if all_graph_nodes[node_of_a_new_edge_i].remaining_node_degree <= 0:
|
| 709 |
+
for candidates_type in open_nodes_for_sampling.keys():
|
| 710 |
+
if node_of_a_new_edge_i in open_nodes_for_sampling[candidates_type]:
|
| 711 |
+
open_nodes_for_sampling[candidates_type].remove(
|
| 712 |
+
node_of_a_new_edge_i
|
| 713 |
+
)
|
| 714 |
+
# if the added bond was coordinate, modify the remaining charges correspondingly
|
| 715 |
+
elif is_bond_coordinate:
|
| 716 |
+
new_charge_abs_value = (
|
| 717 |
+
np.abs(all_graph_nodes[node_of_a_new_edge_i].remaining_node_charge)
|
| 718 |
+
- 1
|
| 719 |
+
)
|
| 720 |
+
# check if the node still can form coordinate bonds
|
| 721 |
+
if new_charge_abs_value == 0:
|
| 722 |
+
for candidates_type in [
|
| 723 |
+
"coordinate_bond_neg_charged_targets",
|
| 724 |
+
"coordinate_bond_pos_charged_targets",
|
| 725 |
+
]:
|
| 726 |
+
if (
|
| 727 |
+
node_of_a_new_edge_i
|
| 728 |
+
in open_nodes_for_sampling[candidates_type]
|
| 729 |
+
):
|
| 730 |
+
open_nodes_for_sampling[candidates_type].remove(
|
| 731 |
+
node_of_a_new_edge_i
|
| 732 |
+
)
|
| 733 |
+
else:
|
| 734 |
+
charge_sign = np.sign(
|
| 735 |
+
all_graph_nodes[node_of_a_new_edge_i].remaining_node_charge
|
| 736 |
+
)
|
| 737 |
+
all_graph_nodes[node_of_a_new_edge_i].remaining_node_charge = (
|
| 738 |
+
charge_sign * new_charge_abs_value
|
| 739 |
+
)
|
| 740 |
+
return (
|
| 741 |
+
(edge_start_node_i, edge_end_node_j),
|
| 742 |
+
all_graph_nodes,
|
| 743 |
+
open_nodes_for_sampling,
|
| 744 |
+
)
|
| 745 |
+
|
| 746 |
+
def generate_random_molecule_graphs_via_traversal(
|
| 747 |
+
self,
|
| 748 |
+
chemical_formula: str,
|
| 749 |
+
max_number_of_retries_per_valence_assignment: int = 100,
|
| 750 |
+
) -> list[Mol]:
|
| 751 |
+
"""
|
| 752 |
+
A function generating random molecule graph(s).
|
| 753 |
+
The generation process ensures that each graph is connected.
|
| 754 |
+
If any of the `self.count_of_valid_valence_assignments` enables it,
|
| 755 |
+
the function returns graph(s) without self-loops.
|
| 756 |
+
|
| 757 |
+
@param chemical_formula: a string containing the chemical formula of the molecule
|
| 758 |
+
@param max_number_of_retries_per_valence_assignment: a max count of attempts to generate a random spanning tree
|
| 759 |
+
for a given potentially feasible valence assignment
|
| 760 |
+
|
| 761 |
+
@note In the future the method can be made into a function in a separate utils module,
|
| 762 |
+
for the simplicity of codebase organization and testing purposes it's kept as the method for now
|
| 763 |
+
"""
|
| 764 |
+
# check if for the input formula the random structures have been already generated
|
| 765 |
+
if self.cache_results and chemical_formula in self.formula_2_random_smiles:
|
| 766 |
+
return self.formula_2_random_smiles[chemical_formula]
|
| 767 |
+
|
| 768 |
+
# get candidate partitions of all molecule atoms into valences
|
| 769 |
+
candidate_valence_assignments = self.get_feasible_atom_valence_assignments(
|
| 770 |
+
chemical_formula
|
| 771 |
+
)
|
| 772 |
+
# iterate over each valence assignment to all atoms, the order is random
|
| 773 |
+
assert (
|
| 774 |
+
len(candidate_valence_assignments) > 0
|
| 775 |
+
), f"No potentially feasible atom valence assignment for {chemical_formula}"
|
| 776 |
+
# number of iteration over feasible valence assignments
|
| 777 |
+
num_of_iterations_over_splits_into_valences = int(
|
| 778 |
+
np.ceil(self.max_top_k / len(candidate_valence_assignments))
|
| 779 |
+
)
|
| 780 |
+
generated_molecules = []
|
| 781 |
+
|
| 782 |
+
# we request to generate self.max_top_k molecule(s)
|
| 783 |
+
while len(generated_molecules) < self.max_top_k:
|
| 784 |
+
for _ in range(num_of_iterations_over_splits_into_valences):
|
| 785 |
+
for valence_assignment in candidate_valence_assignments:
|
| 786 |
+
# first randomly create a spanning tree of the molecule graph, to ensure the connectivity of molecule.
|
| 787 |
+
# The feasibility check `self.is_valence_assignment_feasible` inside the
|
| 788 |
+
# `self.get_feasible_atom_valence_assignments` function should ensure the possibility to create the tree.
|
| 789 |
+
spanning_tree_was_generated = False
|
| 790 |
+
spanning_tree_generation_attempts = 0
|
| 791 |
+
while (
|
| 792 |
+
not spanning_tree_was_generated
|
| 793 |
+
and spanning_tree_generation_attempts
|
| 794 |
+
< max_number_of_retries_per_valence_assignment
|
| 795 |
+
):
|
| 796 |
+
spanning_tree_generation_attempts += 1
|
| 797 |
+
# we optimistically set the value of `spanning_tree_was_generated` to True,
|
| 798 |
+
# If the current traversal do not lead to a spanning tree,
|
| 799 |
+
# then `spanning_tree_was_generated` is set to False in the code below
|
| 800 |
+
spanning_tree_was_generated = True
|
| 801 |
+
|
| 802 |
+
# prepare node list for a random edges generation
|
| 803 |
+
all_graph_nodes = []
|
| 804 |
+
for (
|
| 805 |
+
atom_with_valence,
|
| 806 |
+
num_of_atoms_in_molecule,
|
| 807 |
+
) in valence_assignment.items():
|
| 808 |
+
for _ in range(num_of_atoms_in_molecule):
|
| 809 |
+
all_graph_nodes.append(
|
| 810 |
+
AtomNodeForRandomTraversal(
|
| 811 |
+
atom_with_valence=atom_with_valence
|
| 812 |
+
)
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
# a helper structure to record already sampled bonds
|
| 816 |
+
# it is used only if we use estimated chem elements stats in the generation process
|
| 817 |
+
atom_2_already_connected_neighbours = [
|
| 818 |
+
[] for _ in range(len(all_graph_nodes))
|
| 819 |
+
]
|
| 820 |
+
|
| 821 |
+
# recording sets of nodes available for random sampling of covalent and coordinate bonds
|
| 822 |
+
coordinate_bond_neg_charged_targets = {
|
| 823 |
+
node_i
|
| 824 |
+
for node_i, node in enumerate(all_graph_nodes)
|
| 825 |
+
if np.sign(node.remaining_node_charge) == -1
|
| 826 |
+
}
|
| 827 |
+
coordinate_bond_pos_charged_targets = {
|
| 828 |
+
node_i
|
| 829 |
+
for node_i, node in enumerate(all_graph_nodes)
|
| 830 |
+
if np.sign(node.remaining_node_charge) == 1
|
| 831 |
+
}
|
| 832 |
+
covalent_bond_targets = {
|
| 833 |
+
node_i
|
| 834 |
+
for node_i, node in enumerate(all_graph_nodes)
|
| 835 |
+
if node.remaining_node_charge == 0
|
| 836 |
+
or node.remaining_node_degree
|
| 837 |
+
> np.abs(node.remaining_node_charge)
|
| 838 |
+
}
|
| 839 |
+
|
| 840 |
+
open_nodes_for_sampling = {
|
| 841 |
+
"coordinate_bond_neg_charged_targets": coordinate_bond_neg_charged_targets,
|
| 842 |
+
"coordinate_bond_pos_charged_targets": coordinate_bond_pos_charged_targets,
|
| 843 |
+
"covalent_bond_targets": covalent_bond_targets,
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
# the final edge list will be stored into the variable below.
|
| 847 |
+
# An edge is defined by a pair of position indices in the `all_graph_nodes` list
|
| 848 |
+
edge_list = []
|
| 849 |
+
|
| 850 |
+
# the nodes already included into the spanning tree
|
| 851 |
+
# the set is used for quick blacklisting, while the list is used for possible backtracking when
|
| 852 |
+
(
|
| 853 |
+
spanning_tree_visited_nodes_set,
|
| 854 |
+
spanning_tree_traversal_list,
|
| 855 |
+
) = (
|
| 856 |
+
set(),
|
| 857 |
+
deque(),
|
| 858 |
+
)
|
| 859 |
+
# sample a random start of spanning tree generation
|
| 860 |
+
edge_start_node_i = choice(list(range(len(all_graph_nodes))))
|
| 861 |
+
spanning_tree_visited_nodes_set.add(edge_start_node_i)
|
| 862 |
+
spanning_tree_traversal_list.append(edge_start_node_i)
|
| 863 |
+
while self.enforce_connectivity and len(
|
| 864 |
+
spanning_tree_visited_nodes_set
|
| 865 |
+
) < len(all_graph_nodes):
|
| 866 |
+
# check if the start edge atom has the charge and therefore can form coordinate bond
|
| 867 |
+
try:
|
| 868 |
+
(
|
| 869 |
+
(edge_start_node_i, edge_end_node_i),
|
| 870 |
+
all_graph_nodes,
|
| 871 |
+
open_nodes_for_sampling,
|
| 872 |
+
) = self.sample_edge_at_random(
|
| 873 |
+
all_graph_nodes,
|
| 874 |
+
open_nodes_for_sampling,
|
| 875 |
+
edge_start_node_i=edge_start_node_i,
|
| 876 |
+
closed_set=spanning_tree_visited_nodes_set,
|
| 877 |
+
use_chem_element_stats=self.estimate_chem_element_stats,
|
| 878 |
+
atom_2_already_connected_neighbours=atom_2_already_connected_neighbours,
|
| 879 |
+
)
|
| 880 |
+
except IndexError:
|
| 881 |
+
spanning_tree_was_generated = False
|
| 882 |
+
break
|
| 883 |
+
# note that the graph is undirected, start-end node refers to the random traversal only
|
| 884 |
+
edge_list.append((edge_start_node_i, edge_end_node_i))
|
| 885 |
+
# recording the node added to the random spanning tree
|
| 886 |
+
spanning_tree_visited_nodes_set.add(edge_end_node_i)
|
| 887 |
+
spanning_tree_traversal_list.append(edge_end_node_i)
|
| 888 |
+
|
| 889 |
+
# finding a start node for the next sampled edge.
|
| 890 |
+
# We have to ensure that such a node still has some degree not covered by sampling nodes.
|
| 891 |
+
# For that, we might need to backtrack.
|
| 892 |
+
candidate_for_start_node_i = edge_end_node_i
|
| 893 |
+
try:
|
| 894 |
+
while (
|
| 895 |
+
all_graph_nodes[
|
| 896 |
+
candidate_for_start_node_i
|
| 897 |
+
].remaining_node_degree
|
| 898 |
+
== 0
|
| 899 |
+
):
|
| 900 |
+
spanning_tree_traversal_list.pop()
|
| 901 |
+
candidate_for_start_node_i = (
|
| 902 |
+
spanning_tree_traversal_list[-1]
|
| 903 |
+
)
|
| 904 |
+
except IndexError:
|
| 905 |
+
spanning_tree_was_generated = False
|
| 906 |
+
break
|
| 907 |
+
edge_start_node_i = candidate_for_start_node_i
|
| 908 |
+
|
| 909 |
+
# after the spanning tree edges were sampled,
|
| 910 |
+
# now we randomly connect nodes with remaining degrees yet uncovered by sampled bonds
|
| 911 |
+
while sum(map(len, open_nodes_for_sampling.values())) >= 2:
|
| 912 |
+
try:
|
| 913 |
+
(
|
| 914 |
+
(edge_start_node_i, edge_end_node_i),
|
| 915 |
+
all_graph_nodes,
|
| 916 |
+
open_nodes_for_sampling,
|
| 917 |
+
) = self.sample_edge_at_random(
|
| 918 |
+
all_graph_nodes,
|
| 919 |
+
open_nodes_for_sampling,
|
| 920 |
+
use_chem_element_stats=self.estimate_chem_element_stats,
|
| 921 |
+
atom_2_already_connected_neighbours=atom_2_already_connected_neighbours,
|
| 922 |
+
)
|
| 923 |
+
except IndexError:
|
| 924 |
+
break
|
| 925 |
+
edge_list.append((edge_start_node_i, edge_end_node_i))
|
| 926 |
+
|
| 927 |
+
# if all nodes were covered by edges without self-loops, then we remember the generated molecule
|
| 928 |
+
if sum(map(len, open_nodes_for_sampling.values())) == 0:
|
| 929 |
+
generated_molecules.append(
|
| 930 |
+
create_rdkit_molecule_from_edge_list(
|
| 931 |
+
edge_list, all_graph_nodes
|
| 932 |
+
)
|
| 933 |
+
)
|
| 934 |
+
if len(generated_molecules) == self.max_top_k:
|
| 935 |
+
if self.cache_results:
|
| 936 |
+
self.formula_2_random_smiles[
|
| 937 |
+
chemical_formula
|
| 938 |
+
] = generated_molecules
|
| 939 |
+
return generated_molecules
|
| 940 |
+
if self.cache_results:
|
| 941 |
+
self.formula_2_random_smiles[chemical_formula] = generated_molecules
|
| 942 |
+
return generated_molecules
|
| 943 |
+
|
| 944 |
+
def training_step(
|
| 945 |
+
self, batch: dict, batch_idx: torch.Tensor
|
| 946 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 947 |
+
# recording statistics about chemical elements
|
| 948 |
+
if self.estimate_chem_element_stats:
|
| 949 |
+
if self.element_2_bond_stats is None:
|
| 950 |
+
self.element_2_bond_stats = defaultdict(dict)
|
| 951 |
+
for mol_smiles in batch["mol"]:
|
| 952 |
+
molecule = Chem.MolFromSmiles(mol_smiles)
|
| 953 |
+
# in order to work with double and single bonds instead of aromatic
|
| 954 |
+
Chem.Kekulize(molecule, clearAromaticFlags=True)
|
| 955 |
+
# we add hydrogen atoms (which are ommited by default)
|
| 956 |
+
molecule = Chem.AddHs(molecule)
|
| 957 |
+
formula = CalcMolFormula(molecule)
|
| 958 |
+
for atom in molecule.GetAtoms():
|
| 959 |
+
valence = atom.GetTotalValence()
|
| 960 |
+
charge = atom.GetFormalCharge()
|
| 961 |
+
valence_charge = ValenceAndCharge(valence, charge)
|
| 962 |
+
chem_element_type = atom.GetSymbol()
|
| 963 |
+
atom_bonds = atom.GetBonds()
|
| 964 |
+
if (
|
| 965 |
+
valence_charge
|
| 966 |
+
not in self.element_2_bond_stats[chem_element_type]
|
| 967 |
+
):
|
| 968 |
+
# for each value of atom's valence, number of neighbours we will count types of neighbouring atoms
|
| 969 |
+
self.element_2_bond_stats[chem_element_type][
|
| 970 |
+
valence_charge
|
| 971 |
+
] = dict()
|
| 972 |
+
|
| 973 |
+
all_atom_neighbours = set()
|
| 974 |
+
for bond in atom_bonds:
|
| 975 |
+
start_atom_idx = atom.GetIdx()
|
| 976 |
+
end_atom = [
|
| 977 |
+
_atom
|
| 978 |
+
for _atom in [bond.GetBeginAtom(), bond.GetEndAtom()]
|
| 979 |
+
if _atom.GetIdx() != start_atom_idx
|
| 980 |
+
][0]
|
| 981 |
+
all_atom_neighbours.add(
|
| 982 |
+
(
|
| 983 |
+
end_atom.GetSymbol(),
|
| 984 |
+
end_atom.GetTotalValence(),
|
| 985 |
+
end_atom.GetFormalCharge(),
|
| 986 |
+
bond.GetBondTypeAsDouble(),
|
| 987 |
+
)
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
all_neighbour_subsets = [
|
| 991 |
+
[set(subset) for subset in combinations(all_atom_neighbours, r)]
|
| 992 |
+
for r in range(len(all_atom_neighbours))
|
| 993 |
+
]
|
| 994 |
+
for neighbour_subsets_of_fixed_size in all_neighbour_subsets:
|
| 995 |
+
subset_size = len(neighbour_subsets_of_fixed_size[0])
|
| 996 |
+
# for each number of already connected atoms we record the neighbours and then possible bonds yet to be closed
|
| 997 |
+
if (
|
| 998 |
+
subset_size
|
| 999 |
+
not in self.element_2_bond_stats[chem_element_type][
|
| 1000 |
+
valence_charge
|
| 1001 |
+
]
|
| 1002 |
+
):
|
| 1003 |
+
self.element_2_bond_stats[chem_element_type][
|
| 1004 |
+
valence_charge
|
| 1005 |
+
][subset_size] = dict()
|
| 1006 |
+
for neighbours in neighbour_subsets_of_fixed_size:
|
| 1007 |
+
neighbours_tuple = tuple(sorted(neighbours))
|
| 1008 |
+
if (
|
| 1009 |
+
neighbours_tuple
|
| 1010 |
+
not in self.element_2_bond_stats[chem_element_type][
|
| 1011 |
+
valence_charge
|
| 1012 |
+
][subset_size]
|
| 1013 |
+
):
|
| 1014 |
+
self.element_2_bond_stats[chem_element_type][
|
| 1015 |
+
valence_charge
|
| 1016 |
+
][subset_size][neighbours_tuple] = defaultdict(int)
|
| 1017 |
+
remaining_bonds = all_atom_neighbours.difference(neighbours)
|
| 1018 |
+
for remaining_bonded_neighbour in remaining_bonds:
|
| 1019 |
+
self.element_2_bond_stats[chem_element_type][
|
| 1020 |
+
valence_charge
|
| 1021 |
+
][subset_size][neighbours_tuple][
|
| 1022 |
+
remaining_bonded_neighbour
|
| 1023 |
+
] += 1
|
| 1024 |
+
|
| 1025 |
+
# recording molecular weight
|
| 1026 |
+
for mol_smiles in batch["mol"]:
|
| 1027 |
+
molecule = Chem.MolFromSmiles(mol_smiles)
|
| 1028 |
+
formula = CalcMolFormula(molecule)
|
| 1029 |
+
weight = ExactMolWt(molecule)
|
| 1030 |
+
self.mol_weight_2_formulas[weight].append(formula)
|
| 1031 |
+
# Random baseline, so we return a dummy loss
|
| 1032 |
+
loss = torch.tensor(0.0, requires_grad=True)
|
| 1033 |
+
return dict(loss=loss, mols_pred=["C"])
|
| 1034 |
+
|
| 1035 |
+
def on_train_end(self) -> None:
|
| 1036 |
+
# for each molecular weight we compute proportions of recorded molecular formulas
|
| 1037 |
+
molecular_weight_2_formula_counts = {
|
| 1038 |
+
weight: Counter(formulas)
|
| 1039 |
+
for weight, formulas in self.mol_weight_2_formulas.items()
|
| 1040 |
+
}
|
| 1041 |
+
weight_2_formula_proportions = {}
|
| 1042 |
+
for weight, formula_2_count in molecular_weight_2_formula_counts.items():
|
| 1043 |
+
total_count = sum(formula_2_count.values())
|
| 1044 |
+
weight_2_formula_proportions[weight] = {
|
| 1045 |
+
formula: count / total_count
|
| 1046 |
+
for formula, count in formula_2_count.items()
|
| 1047 |
+
}
|
| 1048 |
+
# for consequent sampling using numpy.random.choice function, we store the results in the format
|
| 1049 |
+
# weight -> [[formula_1, formula_2], [proportion_of_formula_1, proportion_of_formula_2]]
|
| 1050 |
+
self.mol_weight_2_formulas = {
|
| 1051 |
+
weight: [
|
| 1052 |
+
list(formula_2_proportions.keys()),
|
| 1053 |
+
list(formula_2_proportions.values()),
|
| 1054 |
+
]
|
| 1055 |
+
for weight, formula_2_proportions in weight_2_formula_proportions.items()
|
| 1056 |
+
}
|
| 1057 |
+
# storing weights in the sorted list for the logarithmic time look-up of the closest weight value
|
| 1058 |
+
self.mol_weight_trn_values = sorted(self.mol_weight_2_formulas.keys())
|
| 1059 |
+
|
| 1060 |
+
# if chem element stats are used, then the corresponding data structure is reformated in accordance with the
|
| 1061 |
+
# description from docstring to the class __init__:
|
| 1062 |
+
# chem_element ->
|
| 1063 |
+
# ValenceAndCharge ->
|
| 1064 |
+
# number of already bonded atoms ->
|
| 1065 |
+
# [already created BondToNeighbouringAtom] ->
|
| 1066 |
+
# AtomWithValence ->
|
| 1067 |
+
# list of (bond_type, count) + total_count
|
| 1068 |
+
if self.estimate_chem_element_stats:
|
| 1069 |
+
element_2_bond_stats = defaultdict(dict)
|
| 1070 |
+
for (
|
| 1071 |
+
chem_element,
|
| 1072 |
+
valence_charge_2_stats,
|
| 1073 |
+
) in self.element_2_bond_stats.items():
|
| 1074 |
+
for valence_charge, num_bonds_2_stats in valence_charge_2_stats.items():
|
| 1075 |
+
element_2_bond_stats[chem_element][valence_charge] = dict()
|
| 1076 |
+
for num_bonds, bonds_2_stats in num_bonds_2_stats.items():
|
| 1077 |
+
element_2_bond_stats[chem_element][valence_charge][
|
| 1078 |
+
num_bonds
|
| 1079 |
+
] = dict()
|
| 1080 |
+
for (
|
| 1081 |
+
bonds,
|
| 1082 |
+
neighb_atom_with_valence_2_stats,
|
| 1083 |
+
) in bonds_2_stats.items():
|
| 1084 |
+
present_bonds_sorted = tuple(
|
| 1085 |
+
sorted(
|
| 1086 |
+
[
|
| 1087 |
+
BondToNeighbouringAtom(
|
| 1088 |
+
adjacent_atom=AtomWithValence(
|
| 1089 |
+
atom_type=bond[0],
|
| 1090 |
+
atom_valence_and_charge=ValenceAndCharge(
|
| 1091 |
+
valence=bond[1], charge=bond[2]
|
| 1092 |
+
),
|
| 1093 |
+
),
|
| 1094 |
+
bond_type=bond[3],
|
| 1095 |
+
)
|
| 1096 |
+
for bond in bonds
|
| 1097 |
+
]
|
| 1098 |
+
)
|
| 1099 |
+
)
|
| 1100 |
+
element_2_bond_stats[chem_element][valence_charge][
|
| 1101 |
+
num_bonds
|
| 1102 |
+
][present_bonds_sorted] = defaultdict(list)
|
| 1103 |
+
for (
|
| 1104 |
+
neighb_atom_type,
|
| 1105 |
+
neighb_atom_valence,
|
| 1106 |
+
neighb_atom_charge,
|
| 1107 |
+
bondtype,
|
| 1108 |
+
), count in neighb_atom_with_valence_2_stats.items():
|
| 1109 |
+
neighbouring_atom = AtomWithValence(
|
| 1110 |
+
atom_type=neighb_atom_type,
|
| 1111 |
+
atom_valence_and_charge=ValenceAndCharge(
|
| 1112 |
+
valence=neighb_atom_valence,
|
| 1113 |
+
charge=neighb_atom_charge,
|
| 1114 |
+
),
|
| 1115 |
+
)
|
| 1116 |
+
element_2_bond_stats[chem_element][valence_charge][
|
| 1117 |
+
num_bonds
|
| 1118 |
+
][present_bonds_sorted][neighbouring_atom].append(
|
| 1119 |
+
(bondtype, count)
|
| 1120 |
+
)
|
| 1121 |
+
# computing total count of all bound per neighbouring atom
|
| 1122 |
+
for (
|
| 1123 |
+
neighbouring_atom,
|
| 1124 |
+
list_of_bondtype_counts,
|
| 1125 |
+
) in element_2_bond_stats[chem_element][valence_charge][
|
| 1126 |
+
num_bonds
|
| 1127 |
+
][
|
| 1128 |
+
present_bonds_sorted
|
| 1129 |
+
].items():
|
| 1130 |
+
total_count_of_bonds = sum(
|
| 1131 |
+
map(lambda x: x[1], list_of_bondtype_counts)
|
| 1132 |
+
)
|
| 1133 |
+
element_2_bond_stats[chem_element][valence_charge][
|
| 1134 |
+
num_bonds
|
| 1135 |
+
][present_bonds_sorted][neighbouring_atom] = (
|
| 1136 |
+
list_of_bondtype_counts,
|
| 1137 |
+
total_count_of_bonds,
|
| 1138 |
+
)
|
| 1139 |
+
self.element_2_bond_stats = element_2_bond_stats
|
| 1140 |
+
|
| 1141 |
+
def sample_formula_with_the_closest_molecular_weight(
|
| 1142 |
+
self, molecular_weight: float
|
| 1143 |
+
) -> str:
|
| 1144 |
+
"""
|
| 1145 |
+
A method sampling chemical formula observed in training data with the closest weight to `molecular_weight`
|
| 1146 |
+
@param molecular_weight: Molecular weight of a structure to be generated
|
| 1147 |
+
"""
|
| 1148 |
+
if self.mol_weight_trn_values is None:
|
| 1149 |
+
raise RuntimeError(
|
| 1150 |
+
"For random denovo generation without known formula, the model has to be trained first,"
|
| 1151 |
+
"to record training molecular weights with corresponding formulas."
|
| 1152 |
+
)
|
| 1153 |
+
# finding a place in the sorted array for insertion of the `molecular_weight`, while preserving sorted order
|
| 1154 |
+
idx_of_closest_larger = bisect.bisect_left(
|
| 1155 |
+
self.mol_weight_trn_values, molecular_weight
|
| 1156 |
+
)
|
| 1157 |
+
# check if the exact same molecular weight was observed in training data, otherwise select the closest weight
|
| 1158 |
+
if molecular_weight == self.mol_weight_trn_values[idx_of_closest_larger]:
|
| 1159 |
+
idx_of_closest = idx_of_closest_larger
|
| 1160 |
+
elif idx_of_closest_larger > 0:
|
| 1161 |
+
# determining the closest molecular weight out of both neighbours
|
| 1162 |
+
idx_of_closest_smaller = idx_of_closest_larger - 1
|
| 1163 |
+
weight_difference_with_smaller_neighbour = (
|
| 1164 |
+
molecular_weight - self.mol_weight_trn_values[idx_of_closest_smaller]
|
| 1165 |
+
)
|
| 1166 |
+
weight_difference_with_larger_neighbour = (
|
| 1167 |
+
self.mol_weight_trn_values[idx_of_closest_larger] - molecular_weight
|
| 1168 |
+
)
|
| 1169 |
+
if (
|
| 1170 |
+
weight_difference_with_larger_neighbour
|
| 1171 |
+
< weight_difference_with_smaller_neighbour
|
| 1172 |
+
):
|
| 1173 |
+
idx_of_closest = idx_of_closest_larger
|
| 1174 |
+
else:
|
| 1175 |
+
idx_of_closest = idx_of_closest_smaller
|
| 1176 |
+
else:
|
| 1177 |
+
idx_of_closest = 0
|
| 1178 |
+
# the value of the molecular weight observed in training labels, which is the closest to `molecular_weight`
|
| 1179 |
+
closest_observed_molecular_weight = self.mol_weight_trn_values[idx_of_closest]
|
| 1180 |
+
# getting chemical formulas observed for this molecular weight
|
| 1181 |
+
# self.mol_weight_2_formulas is a dictionary containing the following mapping
|
| 1182 |
+
# weight -> [[formula_1, formula_2], [proportion_of_formula_1, proportion_of_formula_2]]
|
| 1183 |
+
feasible_formulas, formula_proportions = self.mol_weight_2_formulas[
|
| 1184 |
+
closest_observed_molecular_weight
|
| 1185 |
+
]
|
| 1186 |
+
# if just one formula is known, it is returned directly
|
| 1187 |
+
if len(feasible_formulas) == 1:
|
| 1188 |
+
return feasible_formulas[0]
|
| 1189 |
+
# otherwise we randomly sample in accordance with proportions
|
| 1190 |
+
return np.random.choice(feasible_formulas, p=formula_proportions)
|
| 1191 |
+
|
| 1192 |
+
def step(
|
| 1193 |
+
self, batch: dict, stage: Stage = Stage.NONE
|
| 1194 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 1195 |
+
mols = batch["mol"] # List of SMILES of length batch_size
|
| 1196 |
+
|
| 1197 |
+
# If formula_known is True, we should generate molecules with the same formula as label (`mols` above)
|
| 1198 |
+
# If formula_known is False, we should generate any molecule with the same mass as label
|
| 1199 |
+
|
| 1200 |
+
# obtaining molecule objects from SMILES
|
| 1201 |
+
molecules = [Chem.MolFromSmiles(smiles) for smiles in mols]
|
| 1202 |
+
# getting the formulas
|
| 1203 |
+
if self.formula_known:
|
| 1204 |
+
formulas = [CalcMolFormula(molecule) for molecule in molecules]
|
| 1205 |
+
else:
|
| 1206 |
+
molecular_weights = [ExactMolWt(molecule) for molecule in molecules]
|
| 1207 |
+
formulas = [
|
| 1208 |
+
self.sample_formula_with_the_closest_molecular_weight(mol_weight)
|
| 1209 |
+
for mol_weight in molecular_weights
|
| 1210 |
+
]
|
| 1211 |
+
# (bs, k) list of rdkit molecules
|
| 1212 |
+
mols_pred = [
|
| 1213 |
+
self.generate_random_molecule_graphs_via_traversal(formula)
|
| 1214 |
+
for formula in formulas
|
| 1215 |
+
]
|
| 1216 |
+
|
| 1217 |
+
for predicted_mol_group in mols_pred:
|
| 1218 |
+
for mol in predicted_mol_group:
|
| 1219 |
+
Chem.RemoveHs(mol)
|
| 1220 |
+
|
| 1221 |
+
# list of predicted smiles
|
| 1222 |
+
smiles_pred = [
|
| 1223 |
+
[
|
| 1224 |
+
Chem.MolToSmiles(mol_candidate)
|
| 1225 |
+
for mol_candidate in candidates_per_input_mol
|
| 1226 |
+
]
|
| 1227 |
+
for candidates_per_input_mol in mols_pred
|
| 1228 |
+
]
|
| 1229 |
+
|
| 1230 |
+
# Random baseline, so we return a dummy loss
|
| 1231 |
+
loss = torch.tensor(0.0, requires_grad=True)
|
| 1232 |
+
return dict(loss=loss, mols_pred=smiles_pred)
|
| 1233 |
+
|
| 1234 |
+
def configure_optimizers(self):
|
| 1235 |
+
# No optimizer needed for a random baseline
|
| 1236 |
+
return None
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
+
# element valences taken from sources like https://sciencenotes.org/element-valency-pdf
|
| 1240 |
+
# the first list contains the typical valences, each tuple is a valence value with the corresponding charge
|
| 1241 |
+
ELEMENT_VALENCES = {
|
| 1242 |
+
"H": (
|
| 1243 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1244 |
+
[ValenceAndCharge(valence=0, charge=0), ValenceAndCharge(valence=1, charge=-1)],
|
| 1245 |
+
),
|
| 1246 |
+
"He": ([ValenceAndCharge(valence=0, charge=0)], []),
|
| 1247 |
+
"Li": (
|
| 1248 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1249 |
+
[ValenceAndCharge(valence=1, charge=-1)],
|
| 1250 |
+
),
|
| 1251 |
+
"Be": ([ValenceAndCharge(valence=2, charge=0)], []),
|
| 1252 |
+
"B": (
|
| 1253 |
+
[ValenceAndCharge(valence=3, charge=0), ValenceAndCharge(valence=4, charge=-1)],
|
| 1254 |
+
[ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=0)],
|
| 1255 |
+
),
|
| 1256 |
+
"C": (
|
| 1257 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1258 |
+
[
|
| 1259 |
+
ValenceAndCharge(valence=3, charge=-1),
|
| 1260 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1261 |
+
ValenceAndCharge(valence=2, charge=-1),
|
| 1262 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1263 |
+
ValenceAndCharge(valence=1, charge=-1),
|
| 1264 |
+
],
|
| 1265 |
+
),
|
| 1266 |
+
"N": (
|
| 1267 |
+
[ValenceAndCharge(valence=3, charge=0), ValenceAndCharge(valence=4, charge=1)],
|
| 1268 |
+
[
|
| 1269 |
+
ValenceAndCharge(valence=2, charge=-1),
|
| 1270 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1271 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1272 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1273 |
+
ValenceAndCharge(valence=1, charge=-1),
|
| 1274 |
+
],
|
| 1275 |
+
),
|
| 1276 |
+
"O": (
|
| 1277 |
+
[ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=-1)],
|
| 1278 |
+
[ValenceAndCharge(valence=3, charge=1)],
|
| 1279 |
+
),
|
| 1280 |
+
"F": ([ValenceAndCharge(valence=1, charge=0)], []),
|
| 1281 |
+
"Ne": ([ValenceAndCharge(valence=0, charge=0)], []),
|
| 1282 |
+
"Na": (
|
| 1283 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1284 |
+
[],
|
| 1285 |
+
),
|
| 1286 |
+
"Mg": ([ValenceAndCharge(valence=2, charge=0)], []),
|
| 1287 |
+
"Al": (
|
| 1288 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1289 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1290 |
+
),
|
| 1291 |
+
"Si": (
|
| 1292 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1293 |
+
[],
|
| 1294 |
+
),
|
| 1295 |
+
"P": (
|
| 1296 |
+
[ValenceAndCharge(valence=5, charge=0)],
|
| 1297 |
+
[
|
| 1298 |
+
ValenceAndCharge(valence=4, charge=1),
|
| 1299 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1300 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1301 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1302 |
+
],
|
| 1303 |
+
),
|
| 1304 |
+
"S": (
|
| 1305 |
+
[ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=6, charge=0)],
|
| 1306 |
+
[
|
| 1307 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1308 |
+
ValenceAndCharge(valence=1, charge=-1),
|
| 1309 |
+
ValenceAndCharge(valence=3, charge=1),
|
| 1310 |
+
],
|
| 1311 |
+
),
|
| 1312 |
+
"Cl": (
|
| 1313 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1314 |
+
[],
|
| 1315 |
+
),
|
| 1316 |
+
"Ar": ([ValenceAndCharge(valence=0, charge=0)], []),
|
| 1317 |
+
"K": (
|
| 1318 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1319 |
+
[],
|
| 1320 |
+
),
|
| 1321 |
+
"Ca": ([ValenceAndCharge(valence=2, charge=0)], []),
|
| 1322 |
+
"Sc": (
|
| 1323 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1324 |
+
[ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=0)],
|
| 1325 |
+
),
|
| 1326 |
+
"Ti": (
|
| 1327 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1328 |
+
[
|
| 1329 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1330 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1331 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1332 |
+
],
|
| 1333 |
+
),
|
| 1334 |
+
"V": (
|
| 1335 |
+
[
|
| 1336 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1337 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1338 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1339 |
+
],
|
| 1340 |
+
[
|
| 1341 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1342 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1343 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1344 |
+
],
|
| 1345 |
+
),
|
| 1346 |
+
"Cr": (
|
| 1347 |
+
[
|
| 1348 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1349 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1350 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1351 |
+
],
|
| 1352 |
+
[
|
| 1353 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1354 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1355 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1356 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1357 |
+
],
|
| 1358 |
+
),
|
| 1359 |
+
"Mn": (
|
| 1360 |
+
[
|
| 1361 |
+
ValenceAndCharge(valence=7, charge=0),
|
| 1362 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1363 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1364 |
+
],
|
| 1365 |
+
[
|
| 1366 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1367 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1368 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1369 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1370 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1371 |
+
],
|
| 1372 |
+
),
|
| 1373 |
+
"Fe": (
|
| 1374 |
+
[ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=3, charge=0)],
|
| 1375 |
+
[
|
| 1376 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1377 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1378 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1379 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1380 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1381 |
+
],
|
| 1382 |
+
),
|
| 1383 |
+
"Co": (
|
| 1384 |
+
[ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=3, charge=0)],
|
| 1385 |
+
[
|
| 1386 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1387 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1388 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1389 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1390 |
+
],
|
| 1391 |
+
),
|
| 1392 |
+
"Ni": (
|
| 1393 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1394 |
+
[
|
| 1395 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1396 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1397 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1398 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1399 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1400 |
+
],
|
| 1401 |
+
),
|
| 1402 |
+
"Cu": (
|
| 1403 |
+
[ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=0)],
|
| 1404 |
+
[
|
| 1405 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1406 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1407 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1408 |
+
],
|
| 1409 |
+
),
|
| 1410 |
+
"Zn": (
|
| 1411 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1412 |
+
[ValenceAndCharge(valence=1, charge=0), ValenceAndCharge(valence=0, charge=0)],
|
| 1413 |
+
),
|
| 1414 |
+
"Ga": (
|
| 1415 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1416 |
+
[ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=0)],
|
| 1417 |
+
),
|
| 1418 |
+
"Ge": (
|
| 1419 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1420 |
+
[
|
| 1421 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1422 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1423 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1424 |
+
],
|
| 1425 |
+
),
|
| 1426 |
+
"As": (
|
| 1427 |
+
[ValenceAndCharge(valence=5, charge=0), ValenceAndCharge(valence=4, charge=1)],
|
| 1428 |
+
[],
|
| 1429 |
+
),
|
| 1430 |
+
"Se": (
|
| 1431 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1432 |
+
[],
|
| 1433 |
+
),
|
| 1434 |
+
"Br": (
|
| 1435 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1436 |
+
[],
|
| 1437 |
+
),
|
| 1438 |
+
"Kr": (
|
| 1439 |
+
[ValenceAndCharge(valence=0, charge=0)],
|
| 1440 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1441 |
+
),
|
| 1442 |
+
"Rb": (
|
| 1443 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1444 |
+
[],
|
| 1445 |
+
),
|
| 1446 |
+
"Sr": ([ValenceAndCharge(valence=2, charge=0)], []),
|
| 1447 |
+
"Y": (
|
| 1448 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1449 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1450 |
+
),
|
| 1451 |
+
"Zr": (
|
| 1452 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1453 |
+
[
|
| 1454 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1455 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1456 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1457 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1458 |
+
],
|
| 1459 |
+
),
|
| 1460 |
+
"Nb": (
|
| 1461 |
+
[ValenceAndCharge(valence=5, charge=0)],
|
| 1462 |
+
[
|
| 1463 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1464 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1465 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1466 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1467 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1468 |
+
],
|
| 1469 |
+
),
|
| 1470 |
+
"Mo": (
|
| 1471 |
+
[ValenceAndCharge(valence=6, charge=0), ValenceAndCharge(valence=4, charge=0)],
|
| 1472 |
+
[
|
| 1473 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1474 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1475 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1476 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1477 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1478 |
+
],
|
| 1479 |
+
),
|
| 1480 |
+
"Tc": (
|
| 1481 |
+
[ValenceAndCharge(valence=7, charge=0), ValenceAndCharge(valence=4, charge=0)],
|
| 1482 |
+
[
|
| 1483 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1484 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1485 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1486 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1487 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1488 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1489 |
+
],
|
| 1490 |
+
),
|
| 1491 |
+
"Ru": (
|
| 1492 |
+
[ValenceAndCharge(valence=4, charge=0), ValenceAndCharge(valence=3, charge=0)],
|
| 1493 |
+
[
|
| 1494 |
+
ValenceAndCharge(valence=8, charge=0),
|
| 1495 |
+
ValenceAndCharge(valence=7, charge=0),
|
| 1496 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1497 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1498 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1499 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1500 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1501 |
+
],
|
| 1502 |
+
),
|
| 1503 |
+
"Rh": (
|
| 1504 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1505 |
+
[
|
| 1506 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1507 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1508 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1509 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1510 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1511 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1512 |
+
],
|
| 1513 |
+
),
|
| 1514 |
+
"Pd": (
|
| 1515 |
+
[ValenceAndCharge(valence=4, charge=0), ValenceAndCharge(valence=2, charge=0)],
|
| 1516 |
+
[ValenceAndCharge(valence=0, charge=0)],
|
| 1517 |
+
),
|
| 1518 |
+
"Ag": (
|
| 1519 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1520 |
+
[
|
| 1521 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1522 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1523 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1524 |
+
],
|
| 1525 |
+
),
|
| 1526 |
+
"Cd": (
|
| 1527 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1528 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1529 |
+
),
|
| 1530 |
+
"In": (
|
| 1531 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1532 |
+
[ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=1, charge=0)],
|
| 1533 |
+
),
|
| 1534 |
+
"Sn": (
|
| 1535 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1536 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1537 |
+
),
|
| 1538 |
+
"Sb": (
|
| 1539 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1540 |
+
[ValenceAndCharge(valence=5, charge=0), ValenceAndCharge(valence=3, charge=-1)],
|
| 1541 |
+
),
|
| 1542 |
+
"Te": (
|
| 1543 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1544 |
+
[ValenceAndCharge(valence=2, charge=0), ValenceAndCharge(valence=6, charge=0)],
|
| 1545 |
+
),
|
| 1546 |
+
"I": (
|
| 1547 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1548 |
+
[
|
| 1549 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1550 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1551 |
+
ValenceAndCharge(valence=7, charge=0),
|
| 1552 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1553 |
+
],
|
| 1554 |
+
),
|
| 1555 |
+
"Xe": (
|
| 1556 |
+
[ValenceAndCharge(valence=0, charge=0)],
|
| 1557 |
+
[
|
| 1558 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1559 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1560 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1561 |
+
ValenceAndCharge(valence=8, charge=0),
|
| 1562 |
+
],
|
| 1563 |
+
),
|
| 1564 |
+
"Cs": (
|
| 1565 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1566 |
+
[],
|
| 1567 |
+
),
|
| 1568 |
+
"Ba": ([ValenceAndCharge(valence=2, charge=0)], []),
|
| 1569 |
+
"La": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1570 |
+
"Ce": (
|
| 1571 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1572 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1573 |
+
),
|
| 1574 |
+
"Pr": (
|
| 1575 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1576 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1577 |
+
),
|
| 1578 |
+
"Nd": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1579 |
+
"Pm": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1580 |
+
"Sm": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1581 |
+
"Eu": (
|
| 1582 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1583 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1584 |
+
),
|
| 1585 |
+
"Gd": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1586 |
+
"Tb": (
|
| 1587 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1588 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1589 |
+
),
|
| 1590 |
+
"Dy": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1591 |
+
"Ho": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1592 |
+
"Er": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1593 |
+
"Tm": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1594 |
+
"Yb": (
|
| 1595 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1596 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1597 |
+
),
|
| 1598 |
+
"Lu": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1599 |
+
"Hf": ([ValenceAndCharge(valence=4, charge=0)], []),
|
| 1600 |
+
"Ta": ([ValenceAndCharge(valence=5, charge=0)], []),
|
| 1601 |
+
"W": (
|
| 1602 |
+
[ValenceAndCharge(valence=6, charge=0), ValenceAndCharge(valence=4, charge=0)],
|
| 1603 |
+
[
|
| 1604 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1605 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1606 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1607 |
+
],
|
| 1608 |
+
),
|
| 1609 |
+
"Re": (
|
| 1610 |
+
[
|
| 1611 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1612 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1613 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1614 |
+
],
|
| 1615 |
+
[
|
| 1616 |
+
ValenceAndCharge(valence=7, charge=0),
|
| 1617 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1618 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1619 |
+
ValenceAndCharge(valence=1, charge=0),
|
| 1620 |
+
ValenceAndCharge(valence=0, charge=0),
|
| 1621 |
+
],
|
| 1622 |
+
),
|
| 1623 |
+
"Os": (
|
| 1624 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1625 |
+
[
|
| 1626 |
+
ValenceAndCharge(valence=8, charge=0),
|
| 1627 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1628 |
+
ValenceAndCharge(valence=2, charge=0),
|
| 1629 |
+
],
|
| 1630 |
+
),
|
| 1631 |
+
"Ir": (
|
| 1632 |
+
[ValenceAndCharge(valence=4, charge=0), ValenceAndCharge(valence=3, charge=0)],
|
| 1633 |
+
[ValenceAndCharge(valence=6, charge=0), ValenceAndCharge(valence=4, charge=0)],
|
| 1634 |
+
),
|
| 1635 |
+
"Pt": (
|
| 1636 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1637 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1638 |
+
),
|
| 1639 |
+
"Au": (
|
| 1640 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1641 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1642 |
+
),
|
| 1643 |
+
"Hg": (
|
| 1644 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1645 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1646 |
+
),
|
| 1647 |
+
"Tl": (
|
| 1648 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1649 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1650 |
+
),
|
| 1651 |
+
"Pb": (
|
| 1652 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1653 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1654 |
+
),
|
| 1655 |
+
"Bi": (
|
| 1656 |
+
[ValenceAndCharge(valence=3, charge=0), ValenceAndCharge(valence=1, charge=0)],
|
| 1657 |
+
[ValenceAndCharge(valence=5, charge=0)],
|
| 1658 |
+
),
|
| 1659 |
+
"Po": (
|
| 1660 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1661 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1662 |
+
),
|
| 1663 |
+
"At": (
|
| 1664 |
+
[ValenceAndCharge(valence=1, charge=0)],
|
| 1665 |
+
[
|
| 1666 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1667 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1668 |
+
ValenceAndCharge(valence=7, charge=0),
|
| 1669 |
+
],
|
| 1670 |
+
),
|
| 1671 |
+
"Rn": (
|
| 1672 |
+
[ValenceAndCharge(valence=0, charge=0)],
|
| 1673 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1674 |
+
),
|
| 1675 |
+
"Fr": ([ValenceAndCharge(valence=1, charge=0)], []),
|
| 1676 |
+
"Ra": ([ValenceAndCharge(valence=2, charge=0)], []),
|
| 1677 |
+
"Ac": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1678 |
+
"Th": ([ValenceAndCharge(valence=4, charge=0)], []),
|
| 1679 |
+
"Pa": (
|
| 1680 |
+
[ValenceAndCharge(valence=5, charge=0)],
|
| 1681 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1682 |
+
),
|
| 1683 |
+
"U": (
|
| 1684 |
+
[ValenceAndCharge(valence=6, charge=0)],
|
| 1685 |
+
[
|
| 1686 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1687 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1688 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1689 |
+
],
|
| 1690 |
+
),
|
| 1691 |
+
"Np": (
|
| 1692 |
+
[ValenceAndCharge(valence=7, charge=0)],
|
| 1693 |
+
[
|
| 1694 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1695 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1696 |
+
ValenceAndCharge(valence=4, charge=0),
|
| 1697 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1698 |
+
],
|
| 1699 |
+
),
|
| 1700 |
+
"Pu": (
|
| 1701 |
+
[ValenceAndCharge(valence=7, charge=0), ValenceAndCharge(valence=4, charge=0)],
|
| 1702 |
+
[
|
| 1703 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1704 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1705 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1706 |
+
],
|
| 1707 |
+
),
|
| 1708 |
+
"Am": (
|
| 1709 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1710 |
+
[ValenceAndCharge(valence=5, charge=0), ValenceAndCharge(valence=4, charge=0)],
|
| 1711 |
+
),
|
| 1712 |
+
"Cm": (
|
| 1713 |
+
[
|
| 1714 |
+
ValenceAndCharge(valence=6, charge=0),
|
| 1715 |
+
ValenceAndCharge(valence=5, charge=0),
|
| 1716 |
+
ValenceAndCharge(valence=3, charge=0),
|
| 1717 |
+
],
|
| 1718 |
+
[],
|
| 1719 |
+
),
|
| 1720 |
+
"Bk": (
|
| 1721 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1722 |
+
[ValenceAndCharge(valence=4, charge=0)],
|
| 1723 |
+
),
|
| 1724 |
+
"Cf": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1725 |
+
"Es": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1726 |
+
"Fm": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1727 |
+
"Md": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1728 |
+
"No": (
|
| 1729 |
+
[ValenceAndCharge(valence=3, charge=0)],
|
| 1730 |
+
[ValenceAndCharge(valence=2, charge=0)],
|
| 1731 |
+
),
|
| 1732 |
+
"Lr": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1733 |
+
"Rf": ([ValenceAndCharge(valence=4, charge=0)], []),
|
| 1734 |
+
"Db": ([ValenceAndCharge(valence=5, charge=0)], []),
|
| 1735 |
+
"Sg": ([ValenceAndCharge(valence=6, charge=0)], []),
|
| 1736 |
+
"Bh": ([ValenceAndCharge(valence=7, charge=0)], []),
|
| 1737 |
+
"Hs": ([ValenceAndCharge(valence=8, charge=0)], []),
|
| 1738 |
+
"Mt": ([ValenceAndCharge(valence=8, charge=0)], []),
|
| 1739 |
+
"Ds": ([ValenceAndCharge(valence=8, charge=0)], []),
|
| 1740 |
+
"Rg": ([ValenceAndCharge(valence=8, charge=0)], []),
|
| 1741 |
+
"Cn": ([ValenceAndCharge(valence=2, charge=0)], []),
|
| 1742 |
+
"Nh": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1743 |
+
"Fl": ([ValenceAndCharge(valence=4, charge=0)], []),
|
| 1744 |
+
"Mc": ([ValenceAndCharge(valence=3, charge=0)], []),
|
| 1745 |
+
"Lv": ([ValenceAndCharge(valence=4, charge=0)], []),
|
| 1746 |
+
"Ts": ([ValenceAndCharge(valence=7, charge=0)], []),
|
| 1747 |
+
"Og": ([ValenceAndCharge(valence=0, charge=0)], []),
|
| 1748 |
+
"+": ([ValenceAndCharge(valence=1, charge=1)], []),
|
| 1749 |
+
"-": ([ValenceAndCharge(valence=1, charge=-1)], []),
|
| 1750 |
+
}
|
massspecgym/models/de_novo/smiles_tranformer.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import typing as T
|
| 5 |
+
from torch_geometric.nn import MLP
|
| 6 |
+
from massspecgym.models.tokenizers import SpecialTokensBaseTokenizer
|
| 7 |
+
from massspecgym.data.transforms import MolToFormulaVector
|
| 8 |
+
from massspecgym.models.base import Stage
|
| 9 |
+
from massspecgym.models.de_novo.base import DeNovoMassSpecGymModel
|
| 10 |
+
from massspecgym.definitions import PAD_TOKEN, SOS_TOKEN, EOS_TOKEN
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SmilesTransformer(DeNovoMassSpecGymModel):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
input_dim: int,
|
| 17 |
+
d_model: int,
|
| 18 |
+
nhead: int,
|
| 19 |
+
num_encoder_layers: int,
|
| 20 |
+
num_decoder_layers: int,
|
| 21 |
+
smiles_tokenizer: SpecialTokensBaseTokenizer,
|
| 22 |
+
start_token: str = SOS_TOKEN,
|
| 23 |
+
end_token: str = EOS_TOKEN,
|
| 24 |
+
pad_token: str = PAD_TOKEN,
|
| 25 |
+
dropout: float = 0.1,
|
| 26 |
+
max_smiles_len: int = 200,
|
| 27 |
+
k_predictions: int = 1,
|
| 28 |
+
temperature: T.Optional[float] = 1.0,
|
| 29 |
+
pre_norm: bool = False,
|
| 30 |
+
chemical_formula: bool = False,
|
| 31 |
+
*args,
|
| 32 |
+
**kwargs
|
| 33 |
+
):
|
| 34 |
+
super().__init__(*args, **kwargs)
|
| 35 |
+
self.smiles_tokenizer = smiles_tokenizer
|
| 36 |
+
self.vocab_size = smiles_tokenizer.get_vocab_size()
|
| 37 |
+
for token in [start_token, end_token, pad_token]:
|
| 38 |
+
assert token in smiles_tokenizer.get_vocab(), f"Token {token} not found in tokenizer vocabulary."
|
| 39 |
+
self.start_token_id = smiles_tokenizer.token_to_id(start_token)
|
| 40 |
+
self.end_token_id = smiles_tokenizer.token_to_id(end_token)
|
| 41 |
+
self.pad_token_id = smiles_tokenizer.token_to_id(pad_token)
|
| 42 |
+
|
| 43 |
+
self.d_model = d_model
|
| 44 |
+
self.max_smiles_len = max_smiles_len
|
| 45 |
+
self.k_predictions = k_predictions
|
| 46 |
+
self.temperature = temperature
|
| 47 |
+
if self.k_predictions == 1: # TODO: this logic should be changed because sampling with k = 1 also makes sense
|
| 48 |
+
self.temperature = None
|
| 49 |
+
|
| 50 |
+
self.src_encoder = nn.Linear(input_dim, d_model)
|
| 51 |
+
self.tgt_embedding = nn.Embedding(self.vocab_size, d_model)
|
| 52 |
+
self.transformer = nn.Transformer(
|
| 53 |
+
d_model=d_model,
|
| 54 |
+
nhead=nhead,
|
| 55 |
+
num_encoder_layers=num_encoder_layers,
|
| 56 |
+
num_decoder_layers=num_decoder_layers,
|
| 57 |
+
dim_feedforward=4 * d_model,
|
| 58 |
+
dropout=dropout,
|
| 59 |
+
norm_first=pre_norm
|
| 60 |
+
)
|
| 61 |
+
self.tgt_decoder = nn.Linear(d_model, self.vocab_size)
|
| 62 |
+
|
| 63 |
+
self.chemical_formula = chemical_formula
|
| 64 |
+
if self.chemical_formula:
|
| 65 |
+
self.formula_mlp = MLP(
|
| 66 |
+
in_channels=MolToFormulaVector.num_elements(),
|
| 67 |
+
hidden_channels=MolToFormulaVector.num_elements(),
|
| 68 |
+
out_channels=d_model,
|
| 69 |
+
num_layers=1,
|
| 70 |
+
dropout=dropout,
|
| 71 |
+
norm=None
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
self.criterion = nn.CrossEntropyLoss()
|
| 75 |
+
|
| 76 |
+
def forward(self, batch):
|
| 77 |
+
|
| 78 |
+
spec = batch["spec"] # (batch_size, seq_len, in_dim)
|
| 79 |
+
smiles = batch["mol"] # List of SMILES of length batch_size
|
| 80 |
+
|
| 81 |
+
smiles = self.smiles_tokenizer.encode_batch(smiles)
|
| 82 |
+
smiles = [s.ids for s in smiles]
|
| 83 |
+
smiles = torch.tensor(smiles, device=spec.device) # (batch_size, seq_len)
|
| 84 |
+
|
| 85 |
+
# Generating padding masks for variable-length sequences
|
| 86 |
+
src_key_padding_mask = self.generate_src_padding_mask(spec)
|
| 87 |
+
tgt_key_padding_mask = self.generate_tgt_padding_mask(smiles)
|
| 88 |
+
|
| 89 |
+
# Create target mask (causal mask)
|
| 90 |
+
tgt_seq_len = smiles.size(1)
|
| 91 |
+
tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_seq_len).to(smiles.device)
|
| 92 |
+
|
| 93 |
+
# Preapre inputs for transformer teacher forcing
|
| 94 |
+
src = spec.permute(1, 0, 2) # (seq_len, batch_size, in_dim)
|
| 95 |
+
smiles = smiles.permute(1, 0) # (seq_len, batch_size)
|
| 96 |
+
tgt = smiles[:-1, :]
|
| 97 |
+
tgt_mask = tgt_mask[:-1, :-1]
|
| 98 |
+
src_key_padding_mask = src_key_padding_mask
|
| 99 |
+
tgt_key_padding_mask = tgt_key_padding_mask[:, :-1]
|
| 100 |
+
|
| 101 |
+
# Input and output embeddings
|
| 102 |
+
src = self.src_encoder(src) # (seq_len, batch_size, d_model)
|
| 103 |
+
if self.chemical_formula:
|
| 104 |
+
formula_emb = self.formula_mlp(batch["formula"]) # (batch_size, d_model)
|
| 105 |
+
src = src + formula_emb.unsqueeze(0) # (seq_len, batch_size, d_model) + (1, batch_size, d_model)
|
| 106 |
+
src = src * (self.d_model**0.5)
|
| 107 |
+
tgt = self.tgt_embedding(tgt) * (self.d_model**0.5) # (seq_len, batch_size, d_model)
|
| 108 |
+
|
| 109 |
+
# Transformer forward pass
|
| 110 |
+
memory = self.transformer.encoder(src, src_key_padding_mask=src_key_padding_mask)
|
| 111 |
+
output = self.transformer.decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask)
|
| 112 |
+
|
| 113 |
+
# Logits to vocabulary
|
| 114 |
+
output = self.tgt_decoder(output) # (seq_len, batch_size, vocab_size)
|
| 115 |
+
|
| 116 |
+
# Reshape before returning
|
| 117 |
+
smiles_pred = output.view(-1, self.vocab_size)
|
| 118 |
+
smiles = smiles[1:, :].contiguous().view(-1)
|
| 119 |
+
return smiles_pred, smiles
|
| 120 |
+
|
| 121 |
+
def step(self, batch: dict, stage: Stage = Stage.NONE) -> dict:
|
| 122 |
+
|
| 123 |
+
# Forward pass
|
| 124 |
+
smiles_pred, smiles = self.forward(batch)
|
| 125 |
+
|
| 126 |
+
# Compute loss
|
| 127 |
+
loss = self.criterion(smiles_pred, smiles)
|
| 128 |
+
|
| 129 |
+
# Generate SMILES strings
|
| 130 |
+
if stage in self.log_only_loss_at_stages:
|
| 131 |
+
mols_pred = None
|
| 132 |
+
else:
|
| 133 |
+
mols_pred = self.decode_smiles(batch)
|
| 134 |
+
|
| 135 |
+
return dict(loss=loss, mols_pred=mols_pred)
|
| 136 |
+
|
| 137 |
+
def generate_src_padding_mask(self, spec):
|
| 138 |
+
return spec.sum(-1) == 0
|
| 139 |
+
|
| 140 |
+
def generate_tgt_padding_mask(self, smiles):
|
| 141 |
+
return smiles == self.pad_token_id
|
| 142 |
+
|
| 143 |
+
def decode_smiles(self, batch):
|
| 144 |
+
|
| 145 |
+
decoded_smiles_str = []
|
| 146 |
+
for _ in range(self.k_predictions):
|
| 147 |
+
decoded_smiles = self.greedy_decode(
|
| 148 |
+
batch,
|
| 149 |
+
max_len=self.max_smiles_len,
|
| 150 |
+
temperature=self.temperature,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
decoded_smiles = [seq.tolist() for seq in decoded_smiles]
|
| 154 |
+
decoded_smiles_str.append(self.smiles_tokenizer.decode_batch(decoded_smiles))
|
| 155 |
+
|
| 156 |
+
# Transpose from (k, batch_size) to (batch_size, k)
|
| 157 |
+
decoded_smiles_str = list(map(list, zip(*decoded_smiles_str)))
|
| 158 |
+
|
| 159 |
+
return decoded_smiles_str
|
| 160 |
+
|
| 161 |
+
def greedy_decode(self, batch, max_len, temperature):
|
| 162 |
+
|
| 163 |
+
with torch.inference_mode():
|
| 164 |
+
|
| 165 |
+
spec = batch["spec"] # (batch_size, seq_len, in_dim)
|
| 166 |
+
src_key_padding_mask = self.generate_src_padding_mask(spec)
|
| 167 |
+
|
| 168 |
+
spec = spec.permute(1, 0, 2) # (seq_len, batch_size, in_dim)
|
| 169 |
+
src = self.src_encoder(spec) # (seq_len, batch_size, d_model)
|
| 170 |
+
if self.chemical_formula:
|
| 171 |
+
formula_emb = self.formula_mlp(batch["formula"]) # (batch_size, d_model)
|
| 172 |
+
src = src + formula_emb.unsqueeze(0) # (seq_len, batch_size, d_model) + (1, batch_size, d_model)
|
| 173 |
+
src = src * (self.d_model**0.5)
|
| 174 |
+
memory = self.transformer.encoder(src, src_key_padding_mask=src_key_padding_mask)
|
| 175 |
+
|
| 176 |
+
batch_size = src.size(1)
|
| 177 |
+
out_tokens = torch.ones(1, batch_size).fill_(self.start_token_id).type(torch.long).to(spec.device)
|
| 178 |
+
|
| 179 |
+
for _ in range(max_len - 1):
|
| 180 |
+
tgt = self.tgt_embedding(out_tokens) * (self.d_model**0.5)
|
| 181 |
+
tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(0)).to(src.device)
|
| 182 |
+
out = self.transformer.decoder(tgt, memory, tgt_mask=tgt_mask)
|
| 183 |
+
out = self.tgt_decoder(out[-1, :]) # (batch_size, vocab_size)
|
| 184 |
+
|
| 185 |
+
# Select next token
|
| 186 |
+
if self.temperature is None:
|
| 187 |
+
probs = F.softmax(out, dim=-1)
|
| 188 |
+
next_token = torch.argmax(probs, dim=-1) # (batch_size,)
|
| 189 |
+
else:
|
| 190 |
+
probs = F.softmax(out / temperature, dim=-1)
|
| 191 |
+
next_token = torch.multinomial(probs, num_samples=1).squeeze(1) # (batch_size,)
|
| 192 |
+
|
| 193 |
+
next_token = next_token.unsqueeze(0) # (1, batch_size)
|
| 194 |
+
|
| 195 |
+
out_tokens = torch.cat([out_tokens, next_token], dim=0)
|
| 196 |
+
if torch.all(next_token == self.end_token_id):
|
| 197 |
+
break
|
| 198 |
+
|
| 199 |
+
out_tokens = out_tokens.permute(1, 0) # (batch_size, seq_len)
|
| 200 |
+
return out_tokens
|
massspecgym/models/layers.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Reproduced from https://github.com/pluskal-lab/DreaMS/blob/main/dreams/models/layers/fourier_features.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from math import ceil
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class FourierFeatures(nn.Module):
|
| 9 |
+
"""
|
| 10 |
+
A module for generating Fourier features for input data. This module maps input data
|
| 11 |
+
to a higher-dimensional space using sinusoidal functions, enhancing the representation
|
| 12 |
+
capabilities for various tasks.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
strategy (str): Strategy for generating frequency components. Available options are
|
| 16 |
+
'random', 'voronov_et_al', and 'dreams'. Each option corresponds to a certain paper:
|
| 17 |
+
- 'random': https://doi.org/10.48550/arXiv.2006.10739.
|
| 18 |
+
- 'voronov_et_al': https://doi.org/10.48550/arXiv.2207.02980.
|
| 19 |
+
- 'dreams': https://doi.org/10.26434/chemrxiv-2023-kss3r-v2.
|
| 20 |
+
x_min (float, optional): The minimum value for generating frequencies. Defaults to 1e-4.
|
| 21 |
+
x_max (float, optional): The maximum value for generating frequencies. Defaults to 1000.
|
| 22 |
+
trainable (bool, optional): If True, the frequencies are treated as trainable parameters.
|
| 23 |
+
Defaults to False.
|
| 24 |
+
funcs (str, optional): Specifies the trigonometric functions to use. Options are 'both',
|
| 25 |
+
'sin', and 'cos'. Defaults to 'both'.
|
| 26 |
+
sigma (float, optional): Standard deviation used for random frequency initialization
|
| 27 |
+
when strategy is 'random'. Defaults to 10.
|
| 28 |
+
num_freqs (int, optional): Number of frequency components to generate. Defaults to 512.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
strategy='dreams',
|
| 34 |
+
x_min=1e-4,
|
| 35 |
+
x_max=1000,
|
| 36 |
+
trainable=False,
|
| 37 |
+
funcs="both",
|
| 38 |
+
sigma=10,
|
| 39 |
+
num_freqs=512,
|
| 40 |
+
):
|
| 41 |
+
assert funcs in {"both", "sin", "cos"}, "funcs must be 'both', 'sin', or 'cos'"
|
| 42 |
+
assert 0 < x_min < 1, "x_min must be a positive fraction"
|
| 43 |
+
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.funcs = funcs
|
| 46 |
+
self.strategy = strategy
|
| 47 |
+
self.trainable = trainable
|
| 48 |
+
self.num_freqs = num_freqs
|
| 49 |
+
|
| 50 |
+
if strategy == "random":
|
| 51 |
+
self.b = torch.randn(num_freqs) * sigma
|
| 52 |
+
elif self.strategy == "voronov_et_al":
|
| 53 |
+
self.b = torch.tensor(
|
| 54 |
+
[
|
| 55 |
+
1 / (x_min * (x_max / x_min) ** (2 * i / (num_freqs - 2)))
|
| 56 |
+
for i in range(1, num_freqs)
|
| 57 |
+
],
|
| 58 |
+
)
|
| 59 |
+
elif self.strategy == "dreams":
|
| 60 |
+
self.b = torch.tensor(
|
| 61 |
+
[1 / (x_min * i) for i in range(2, ceil(1 / x_min), 2)]
|
| 62 |
+
+ [1 / (1 * i) for i in range(2, ceil(x_max), 1)],
|
| 63 |
+
)
|
| 64 |
+
else:
|
| 65 |
+
raise ValueError(f"Unknown strategy: {strategy}")
|
| 66 |
+
|
| 67 |
+
self.b = self.b.unsqueeze(0)
|
| 68 |
+
self.b = nn.Parameter(self.b, requires_grad=self.trainable)
|
| 69 |
+
self.register_parameter("Fourier frequencies", self.b)
|
| 70 |
+
|
| 71 |
+
@property
|
| 72 |
+
def num_features(self):
|
| 73 |
+
"""
|
| 74 |
+
Returns the number of features generated by the FourierFeatures module.
|
| 75 |
+
If both sine and cosine functions are used, the number of features is doubled.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
int: The number of features.
|
| 79 |
+
"""
|
| 80 |
+
return self.b.shape[1] if self.funcs != "both" else 2 * self.b.shape[1]
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
"""
|
| 84 |
+
Applies the Fourier transformation to the input data.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
x (torch.Tensor): Input tensor of shape (batch_size, input_dim) to transform.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
torch.Tensor: Fourier features.
|
| 91 |
+
"""
|
| 92 |
+
x = 2 * torch.pi * x @ self.b
|
| 93 |
+
|
| 94 |
+
if self.funcs == "both":
|
| 95 |
+
x = torch.cat((torch.cos(x), torch.sin(x)), dim=-1)
|
| 96 |
+
elif self.funcs == "cos":
|
| 97 |
+
x = torch.cos(x)
|
| 98 |
+
elif self.funcs == "sin":
|
| 99 |
+
x = torch.sin(x)
|
| 100 |
+
|
| 101 |
+
return x
|
massspecgym/models/retrieval/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base import RetrievalMassSpecGymModel
|
| 2 |
+
from .random import RandomRetrieval
|
| 3 |
+
from .deepsets import DeepSetsRetrieval
|
| 4 |
+
from .fingerprint_ffn import FingerprintFFNRetrieval
|
| 5 |
+
from .from_dict import FromDictRetrieval
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"RetrievalMassSpecGymModel",
|
| 9 |
+
"RandomRetrieval",
|
| 10 |
+
"DeepSetsRetrieval",
|
| 11 |
+
"FingerprintFFNRetrieval",
|
| 12 |
+
"FromDictRetrieval"
|
| 13 |
+
]
|
massspecgym/models/retrieval/base.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typing as T
|
| 2 |
+
from abc import ABC
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
from torchmetrics import CosineSimilarity, MeanMetric
|
| 7 |
+
from torchmetrics.functional.retrieval import retrieval_hit_rate
|
| 8 |
+
from torch_geometric.utils import unbatch
|
| 9 |
+
|
| 10 |
+
from massspecgym.models.base import MassSpecGymModel, Stage
|
| 11 |
+
import massspecgym.utils as utils
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RetrievalMassSpecGymModel(MassSpecGymModel, ABC):
|
| 15 |
+
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
at_ks: T.Iterable[int] = (1, 5, 20),
|
| 19 |
+
myopic_mces_kwargs: T.Optional[T.Mapping] = None,
|
| 20 |
+
*args,
|
| 21 |
+
**kwargs
|
| 22 |
+
):
|
| 23 |
+
super().__init__(*args, **kwargs)
|
| 24 |
+
self.at_ks = at_ks
|
| 25 |
+
self.myopic_mces = utils.MyopicMCES(**(myopic_mces_kwargs or {}))
|
| 26 |
+
|
| 27 |
+
def on_batch_end(
|
| 28 |
+
self, outputs: T.Any, batch: dict, batch_idx: int, stage: Stage
|
| 29 |
+
) -> None:
|
| 30 |
+
"""
|
| 31 |
+
Compute evaluation metrics for the retrieval model based on the batch and corresponding
|
| 32 |
+
predictions.
|
| 33 |
+
"""
|
| 34 |
+
self.log(
|
| 35 |
+
f"{stage.to_pref()}loss",
|
| 36 |
+
outputs['loss'],
|
| 37 |
+
batch_size=batch['spec'].size(0),
|
| 38 |
+
sync_dist=True,
|
| 39 |
+
prog_bar=True,
|
| 40 |
+
)
|
| 41 |
+
if stage in self.log_only_loss_at_stages:
|
| 42 |
+
return
|
| 43 |
+
|
| 44 |
+
metric_vals = {}
|
| 45 |
+
metric_vals |= self.evaluate_retrieval_step(
|
| 46 |
+
outputs["scores"],
|
| 47 |
+
batch["labels"],
|
| 48 |
+
batch["batch_ptr"],
|
| 49 |
+
stage=stage,
|
| 50 |
+
)
|
| 51 |
+
metric_vals |= self.evaluate_mces_at_1(
|
| 52 |
+
outputs["scores"],
|
| 53 |
+
batch["labels"],
|
| 54 |
+
batch["smiles"],
|
| 55 |
+
batch["candidates_smiles"],
|
| 56 |
+
batch["batch_ptr"],
|
| 57 |
+
stage=stage,
|
| 58 |
+
)
|
| 59 |
+
if stage == Stage.TEST and self.df_test_path is not None:
|
| 60 |
+
self._update_df_test(metric_vals)
|
| 61 |
+
|
| 62 |
+
def evaluate_retrieval_step(
|
| 63 |
+
self,
|
| 64 |
+
scores: torch.Tensor,
|
| 65 |
+
labels: torch.Tensor,
|
| 66 |
+
batch_ptr: torch.Tensor,
|
| 67 |
+
stage: Stage,
|
| 68 |
+
) -> dict[str, torch.Tensor]:
|
| 69 |
+
"""
|
| 70 |
+
Main evaluation method for the retrieval models. The retrieval step is evaluated by
|
| 71 |
+
computing the hit rate at different top-k values.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
scores (torch.Tensor): Concatenated scores for all candidates for all samples in the
|
| 75 |
+
batch
|
| 76 |
+
labels (torch.Tensor): Concatenated True/False labels for all candidates for all samples
|
| 77 |
+
in the batch
|
| 78 |
+
batch_ptr (torch.Tensor): Number of each sample's candidates in the concatenated tensors
|
| 79 |
+
"""
|
| 80 |
+
# Initialize return dictionary to store metric values per sample
|
| 81 |
+
metric_vals = {}
|
| 82 |
+
|
| 83 |
+
# Evaluate hitrate at different top-k values
|
| 84 |
+
indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
|
| 85 |
+
scores = unbatch(scores, indexes)
|
| 86 |
+
labels = unbatch(labels, indexes)
|
| 87 |
+
|
| 88 |
+
for at_k in self.at_ks:
|
| 89 |
+
hit_rates = []
|
| 90 |
+
for scores_sample, labels_sample in zip(scores, labels):
|
| 91 |
+
hit_rates.append(retrieval_hit_rate(scores_sample, labels_sample, top_k=at_k))
|
| 92 |
+
hit_rates = torch.tensor(hit_rates, device=batch_ptr.device)
|
| 93 |
+
|
| 94 |
+
metric_name = f"{stage.to_pref()}hit_rate@{at_k}"
|
| 95 |
+
self._update_metric(
|
| 96 |
+
metric_name,
|
| 97 |
+
MeanMetric,
|
| 98 |
+
(hit_rates,),
|
| 99 |
+
batch_size=batch_ptr.size(0),
|
| 100 |
+
bootstrap=stage == Stage.TEST
|
| 101 |
+
)
|
| 102 |
+
metric_vals[metric_name] = hit_rates
|
| 103 |
+
|
| 104 |
+
return metric_vals
|
| 105 |
+
|
| 106 |
+
def evaluate_mces_at_1(
|
| 107 |
+
self,
|
| 108 |
+
scores: torch.Tensor,
|
| 109 |
+
labels: torch.Tensor,
|
| 110 |
+
smiles: list[str],
|
| 111 |
+
candidates_smiles: list[str],
|
| 112 |
+
batch_ptr: torch.Tensor,
|
| 113 |
+
stage: Stage,
|
| 114 |
+
) -> dict[str, torch.Tensor]:
|
| 115 |
+
"""
|
| 116 |
+
TODO
|
| 117 |
+
"""
|
| 118 |
+
if labels.sum() != len(smiles):
|
| 119 |
+
raise ValueError("MCES@1 evaluation currently supports exactly 1 positive candidate per sample.")
|
| 120 |
+
|
| 121 |
+
# Initialize return dictionary to store metric values per sample
|
| 122 |
+
metric_vals = {}
|
| 123 |
+
|
| 124 |
+
# Get top-1 predicted molecules for each ground-truth sample
|
| 125 |
+
smiles_pred_top_1 = []
|
| 126 |
+
batch_ptr = torch.cumsum(batch_ptr, dim=0)
|
| 127 |
+
for i, j in zip(torch.cat([torch.tensor([0], device=batch_ptr.device), batch_ptr]), batch_ptr):
|
| 128 |
+
scores_sample = scores[i:j]
|
| 129 |
+
top_1_idx = i + torch.argmax(scores_sample)
|
| 130 |
+
smiles_pred_top_1.append(candidates_smiles[top_1_idx])
|
| 131 |
+
|
| 132 |
+
# Calculate MCES distance between top-1 predicted molecules and ground truth
|
| 133 |
+
mces_dists = [
|
| 134 |
+
self.myopic_mces(sm, sm_pred)
|
| 135 |
+
for sm, sm_pred in zip(smiles, smiles_pred_top_1)
|
| 136 |
+
]
|
| 137 |
+
mces_dists = torch.tensor(mces_dists, device=scores.device)
|
| 138 |
+
|
| 139 |
+
# Log
|
| 140 |
+
metric_name = f"{stage.to_pref()}mces@1"
|
| 141 |
+
self._update_metric(
|
| 142 |
+
metric_name,
|
| 143 |
+
MeanMetric,
|
| 144 |
+
(mces_dists,),
|
| 145 |
+
batch_size=len(mces_dists),
|
| 146 |
+
bootstrap=stage == Stage.TEST
|
| 147 |
+
)
|
| 148 |
+
metric_vals[metric_name] = mces_dists
|
| 149 |
+
|
| 150 |
+
return metric_vals
|
| 151 |
+
|
| 152 |
+
def evaluate_fingerprint_step(
|
| 153 |
+
self,
|
| 154 |
+
y_true: torch.Tensor,
|
| 155 |
+
y_pred: torch.Tensor,
|
| 156 |
+
stage: Stage,
|
| 157 |
+
) -> None:
|
| 158 |
+
"""
|
| 159 |
+
Utility evaluation method to assess the quality of predicted fingerprints. This method is
|
| 160 |
+
not a part of the necessary evaluation logic (not called in the `on_batch_end` method)
|
| 161 |
+
since retrieval models are not bound to predict fingerprints.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
y_true (torch.Tensor): [batch_size, fingerprint_size] tensor of true fingerprints
|
| 165 |
+
y_pred (torch.Tensor): [batch_size, fingerprint_size] tensor of predicted fingerprints
|
| 166 |
+
"""
|
| 167 |
+
# Cosine similarity between predicted and true fingerprints
|
| 168 |
+
self._update_metric(
|
| 169 |
+
f"{stage.to_pref()}fingerprint_cos_sim",
|
| 170 |
+
CosineSimilarity,
|
| 171 |
+
(y_pred, y_true),
|
| 172 |
+
batch_size=y_true.size(0),
|
| 173 |
+
metric_kwargs=dict(reduction="mean")
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def test_step(
|
| 177 |
+
self,
|
| 178 |
+
batch: dict,
|
| 179 |
+
batch_idx: torch.Tensor
|
| 180 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 181 |
+
outputs = super().test_step(batch, batch_idx)
|
| 182 |
+
|
| 183 |
+
# Get sorted candidate SMILES based on the predicted scores for each sample
|
| 184 |
+
if self.df_test_path is not None:
|
| 185 |
+
indexes = utils.batch_ptr_to_batch_idx(batch['batch_ptr'])
|
| 186 |
+
scores = unbatch(outputs['scores'], indexes)
|
| 187 |
+
candidates_smiles = utils.unbatch_list(batch['candidates_smiles'], indexes)
|
| 188 |
+
sorted_candidate_smiles = []
|
| 189 |
+
for scores_sample, candidates_smiles_sample in zip(scores, candidates_smiles):
|
| 190 |
+
candidates_smiles_sample = [
|
| 191 |
+
x for _, x in sorted(zip(scores_sample, candidates_smiles_sample), reverse=True)
|
| 192 |
+
]
|
| 193 |
+
sorted_candidate_smiles.append(candidates_smiles_sample)
|
| 194 |
+
self._update_df_test({
|
| 195 |
+
'identifier': batch['identifier'],
|
| 196 |
+
'sorted_candidate_smiles': sorted_candidate_smiles
|
| 197 |
+
})
|
| 198 |
+
|
| 199 |
+
return outputs
|
| 200 |
+
|
| 201 |
+
def on_test_epoch_end(self):
|
| 202 |
+
# Save test data frame to disk
|
| 203 |
+
if self.df_test_path is not None:
|
| 204 |
+
df_test = pd.DataFrame(self.df_test)
|
| 205 |
+
self.df_test_path.parent.mkdir(parents=True, exist_ok=True)
|
| 206 |
+
df_test.to_pickle(self.df_test_path)
|
massspecgym/models/retrieval/deepsets.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typing as T
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch_geometric.nn import MLP
|
| 7 |
+
|
| 8 |
+
from massspecgym.models.base import Stage
|
| 9 |
+
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
|
| 10 |
+
from massspecgym.models.layers import FourierFeatures
|
| 11 |
+
from massspecgym.utils import CosSimLoss
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DeepSetsRetrieval(RetrievalMassSpecGymModel):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
in_channels: int = 2, # m/z and intensity of a peak
|
| 18 |
+
hidden_channels: int = 512, # hidden layer size
|
| 19 |
+
out_channels: int = 4096, # fingerprint size
|
| 20 |
+
num_layers_per_mlp: int = 2,
|
| 21 |
+
dropout: float = 0.0,
|
| 22 |
+
norm: T.Optional[str] = None,
|
| 23 |
+
fourier_features: bool = True,
|
| 24 |
+
fourier_features_mz_channels: T.Optional[int] = None,
|
| 25 |
+
fourier_features_kwargs: T.Optional[dict] = None,
|
| 26 |
+
**kwargs
|
| 27 |
+
):
|
| 28 |
+
super().__init__(**kwargs)
|
| 29 |
+
|
| 30 |
+
self.fourier_features = fourier_features
|
| 31 |
+
if fourier_features:
|
| 32 |
+
if fourier_features_kwargs is None:
|
| 33 |
+
fourier_features_kwargs = {}
|
| 34 |
+
self.ff = FourierFeatures(**fourier_features_kwargs)
|
| 35 |
+
|
| 36 |
+
if fourier_features_mz_channels is None:
|
| 37 |
+
fourier_features_mz_channels = int(0.8 * hidden_channels)
|
| 38 |
+
else:
|
| 39 |
+
assert fourier_features_mz_channels < hidden_channels
|
| 40 |
+
self.ff_proj_mz = nn.Linear(self.ff.num_features, fourier_features_mz_channels)
|
| 41 |
+
self.ff_proj_i = nn.Linear(1, hidden_channels - fourier_features_mz_channels)
|
| 42 |
+
in_channels = hidden_channels
|
| 43 |
+
|
| 44 |
+
self.phi = MLP(
|
| 45 |
+
in_channels=in_channels,
|
| 46 |
+
hidden_channels=hidden_channels,
|
| 47 |
+
out_channels=hidden_channels,
|
| 48 |
+
num_layers=num_layers_per_mlp,
|
| 49 |
+
dropout=dropout,
|
| 50 |
+
norm=norm
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
self.rho = MLP(
|
| 54 |
+
in_channels=hidden_channels,
|
| 55 |
+
hidden_channels=hidden_channels,
|
| 56 |
+
out_channels=out_channels,
|
| 57 |
+
num_layers=num_layers_per_mlp,
|
| 58 |
+
dropout=dropout,
|
| 59 |
+
norm=norm
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
self.loss_fn = CosSimLoss()
|
| 63 |
+
|
| 64 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 65 |
+
if self.fourier_features:
|
| 66 |
+
x_mz = x[:, :, 0].unsqueeze(-1)
|
| 67 |
+
x_mz = self.ff(x_mz)
|
| 68 |
+
x_mz = self.ff_proj_mz(x_mz)
|
| 69 |
+
x_i = x[:, :, 1].unsqueeze(-1)
|
| 70 |
+
x_i = self.ff_proj_i(x_i)
|
| 71 |
+
x = torch.cat((x_mz, x_i), dim=-1)
|
| 72 |
+
x = self.phi(x)
|
| 73 |
+
x = x.sum(dim=-2) # sum over peaks
|
| 74 |
+
x = self.rho(x)
|
| 75 |
+
x = F.sigmoid(x) # predict proper fingerprint
|
| 76 |
+
return x
|
| 77 |
+
|
| 78 |
+
def step(
|
| 79 |
+
self, batch: dict, stage: Stage = Stage.NONE
|
| 80 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 81 |
+
# Unpack inputs
|
| 82 |
+
x = batch["spec"]
|
| 83 |
+
fp_true = batch["mol"]
|
| 84 |
+
cands = batch["candidates"]
|
| 85 |
+
batch_ptr = batch["batch_ptr"]
|
| 86 |
+
|
| 87 |
+
# Predict fingerprint
|
| 88 |
+
fp_pred = self.forward(x)
|
| 89 |
+
|
| 90 |
+
# Calculate loss
|
| 91 |
+
loss = self.loss_fn(fp_true, fp_pred)
|
| 92 |
+
|
| 93 |
+
# Evaluation performance on fingerprint prediction (optional)
|
| 94 |
+
self.evaluate_fingerprint_step(fp_true, fp_pred, stage=stage)
|
| 95 |
+
|
| 96 |
+
# Calculate final similarity scores between predicted fingerprints and corresponding
|
| 97 |
+
# candidate fingerprints for retrieval
|
| 98 |
+
fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)
|
| 99 |
+
scores = nn.functional.cosine_similarity(fp_pred_repeated, cands)
|
| 100 |
+
|
| 101 |
+
return dict(loss=loss, scores=scores)
|
massspecgym/models/retrieval/fingerprint_ffn.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typing as T
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch_geometric.nn import MLP
|
| 7 |
+
|
| 8 |
+
from massspecgym.models.base import Stage
|
| 9 |
+
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
|
| 10 |
+
from massspecgym.utils import CosSimLoss
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FingerprintFFNRetrieval(RetrievalMassSpecGymModel):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
in_channels: int = 1000, # number of bins
|
| 17 |
+
hidden_channels: int = 512, # hidden layer size
|
| 18 |
+
out_channels: int = 4096, # fingerprint size
|
| 19 |
+
num_layers: int = 2,
|
| 20 |
+
dropout: float = 0.0,
|
| 21 |
+
norm: T.Optional[str] = None,
|
| 22 |
+
**kwargs
|
| 23 |
+
):
|
| 24 |
+
super().__init__(**kwargs)
|
| 25 |
+
|
| 26 |
+
self.ffn = MLP(
|
| 27 |
+
in_channels=in_channels,
|
| 28 |
+
hidden_channels=hidden_channels,
|
| 29 |
+
out_channels=out_channels,
|
| 30 |
+
num_layers=num_layers,
|
| 31 |
+
dropout=dropout,
|
| 32 |
+
norm=norm
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
self.loss_fn = CosSimLoss()
|
| 36 |
+
|
| 37 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 38 |
+
x = self.ffn(x)
|
| 39 |
+
x = F.sigmoid(x) # predict proper fingerprint
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
def step(
|
| 43 |
+
self, batch: dict, stage: Stage = Stage.NONE
|
| 44 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 45 |
+
# Unpack inputs
|
| 46 |
+
x = batch["spec"]
|
| 47 |
+
fp_true = batch["mol"]
|
| 48 |
+
cands = batch["candidates"]
|
| 49 |
+
batch_ptr = batch["batch_ptr"]
|
| 50 |
+
|
| 51 |
+
# Predict fingerprint
|
| 52 |
+
fp_pred = self.forward(x)
|
| 53 |
+
|
| 54 |
+
# Calculate loss
|
| 55 |
+
loss = self.loss_fn(fp_true, fp_pred)
|
| 56 |
+
|
| 57 |
+
# Evaluation performance on fingerprint prediction (optional)
|
| 58 |
+
self.evaluate_fingerprint_step(fp_true, fp_pred, stage=stage)
|
| 59 |
+
|
| 60 |
+
# Calculate final similarity scores between predicted fingerprints and corresponding
|
| 61 |
+
# candidate fingerprints for retrieval
|
| 62 |
+
fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)
|
| 63 |
+
scores = nn.functional.cosine_similarity(fp_pred_repeated, cands)
|
| 64 |
+
|
| 65 |
+
return dict(loss=loss, scores=scores)
|
massspecgym/models/retrieval/from_dict.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import typing as T
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from massspecgym.models.base import Stage
|
| 9 |
+
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FromDictRetrieval(RetrievalMassSpecGymModel):
|
| 13 |
+
"""
|
| 14 |
+
Read predictions from dictionary with MassSpecGym ids as keys. Currently, the class
|
| 15 |
+
only implements reading fingerprints from the dictionary.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
dct: T.Optional[dict[str, T.Any]] = None,
|
| 21 |
+
dct_path: T.Optional[T.Union[str, Path]] = None, # pickled dict path
|
| 22 |
+
*args,
|
| 23 |
+
**kwargs
|
| 24 |
+
):
|
| 25 |
+
super().__init__(*args, **kwargs)
|
| 26 |
+
|
| 27 |
+
if dct is None and dct_path is None:
|
| 28 |
+
raise ValueError("Either dct or dct_path must be provided.")
|
| 29 |
+
|
| 30 |
+
if dct is not None and dct_path is not None:
|
| 31 |
+
raise ValueError("Only one of dct or dct_path must be provided.")
|
| 32 |
+
|
| 33 |
+
if dct_path is not None:
|
| 34 |
+
with open(dct_path, "rb") as file:
|
| 35 |
+
dct = pickle.load(file)
|
| 36 |
+
|
| 37 |
+
dct = {k: torch.tensor(v) for k, v in dct.items()}
|
| 38 |
+
self.dct = dct
|
| 39 |
+
|
| 40 |
+
def step(
|
| 41 |
+
self, batch: dict, stage: Stage = Stage.NONE
|
| 42 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 43 |
+
# Unpack inputs
|
| 44 |
+
ids = batch["identifier"]
|
| 45 |
+
fp_true = batch["mol"]
|
| 46 |
+
cands = batch["candidates"]
|
| 47 |
+
batch_ptr = batch["batch_ptr"]
|
| 48 |
+
|
| 49 |
+
# Read predicted fingerprints from dictionary
|
| 50 |
+
fp_pred = torch.stack([self.dct[id] for id in ids]).to(fp_true.device)
|
| 51 |
+
|
| 52 |
+
# Evaluation performance on fingerprint prediction (optional)
|
| 53 |
+
self.evaluate_fingerprint_step(fp_true, fp_pred, stage=stage)
|
| 54 |
+
|
| 55 |
+
# Calculate final similarity scores between predicted fingerprints and corresponding
|
| 56 |
+
# candidate fingerprints for retrieval
|
| 57 |
+
fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)
|
| 58 |
+
scores = nn.functional.cosine_similarity(fp_pred_repeated, cands).to(fp_true.device)
|
| 59 |
+
|
| 60 |
+
# Random baseline, so we return a dummy loss
|
| 61 |
+
loss = torch.tensor(0.0, requires_grad=True, device=fp_true.device)
|
| 62 |
+
|
| 63 |
+
return dict(loss=loss, scores=scores)
|
| 64 |
+
|
| 65 |
+
def configure_optimizers(self):
|
| 66 |
+
# No training, so no optimizers
|
| 67 |
+
return None
|
massspecgym/models/retrieval/random.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from massspecgym.models.base import Stage
|
| 4 |
+
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RandomRetrieval(RetrievalMassSpecGymModel):
|
| 8 |
+
|
| 9 |
+
def step(
|
| 10 |
+
self, batch: dict, stage: Stage = Stage.NONE
|
| 11 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 12 |
+
# Generate random retrieval scores
|
| 13 |
+
scores = torch.rand(batch["candidates"].shape[0]).to(self.device)
|
| 14 |
+
|
| 15 |
+
# Random baseline, so we return a dummy loss
|
| 16 |
+
loss = torch.tensor(0.0, requires_grad=True)
|
| 17 |
+
|
| 18 |
+
return dict(loss=loss, scores=scores)
|
| 19 |
+
|
| 20 |
+
def configure_optimizers(self):
|
| 21 |
+
# No optimizer needed for a random baseline
|
| 22 |
+
return None
|
massspecgym/models/simulation/__init__.py
ADDED
|
File without changes
|
massspecgym/models/simulation/base.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typing as T
|
| 2 |
+
from abc import ABC
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
from torchmetrics import RetrievalHitRate, CosineSimilarity
|
| 8 |
+
|
| 9 |
+
from massspecgym.models.base import MassSpecGymModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SimulationMassSpecGymModel(MassSpecGymModel, ABC):
|
| 13 |
+
|
| 14 |
+
def on_batch_end(
|
| 15 |
+
self, outputs: T.Any, batch: dict, batch_idx: int, metric_pref: str = ""
|
| 16 |
+
) -> None:
|
| 17 |
+
"""
|
| 18 |
+
Compute evaluation metrics for the retrieval model based on the batch and corresponding predictions.
|
| 19 |
+
This method will be used in the `on_train_batch_end`, `on_validation_batch_end`, since `on_test_batch_end` is
|
| 20 |
+
overriden below.
|
| 21 |
+
"""
|
| 22 |
+
self.evaluate_cos_similarity_step(
|
| 23 |
+
outputs["spec_pred"],
|
| 24 |
+
batch["spec"],
|
| 25 |
+
metric_pref=metric_pref,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
def on_test_batch_end(
|
| 29 |
+
self, outputs: T.Any, batch: dict, batch_idx: int
|
| 30 |
+
) -> None:
|
| 31 |
+
metric_pref = "_test"
|
| 32 |
+
self.evaluate_cos_similarity_step(
|
| 33 |
+
outputs["spec_pred"],
|
| 34 |
+
batch["spec"],
|
| 35 |
+
metric_pref=metric_pref
|
| 36 |
+
)
|
| 37 |
+
self.evaluate_hit_rate_step(
|
| 38 |
+
outputs["spec_pred"],
|
| 39 |
+
batch["spec"],
|
| 40 |
+
metric_pref=metric_pref
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
def evaluate_cos_similarity_step(
|
| 44 |
+
self,
|
| 45 |
+
specs_pred: torch.Tensor,
|
| 46 |
+
specs: torch.Tensor,
|
| 47 |
+
metric_pref: str = ""
|
| 48 |
+
) -> None:
|
| 49 |
+
"""
|
| 50 |
+
Evaulate cosine similarity.
|
| 51 |
+
"""
|
| 52 |
+
raise NotImplementedError
|
| 53 |
+
|
| 54 |
+
def evaluate_hit_rate_step(
|
| 55 |
+
self,
|
| 56 |
+
specs_pred: torch.Tensor,
|
| 57 |
+
specs: torch.Tensor,
|
| 58 |
+
metric_pref: str = ""
|
| 59 |
+
) -> None:
|
| 60 |
+
"""
|
| 61 |
+
Evaulate Hit rate @ {1, 5, 20} (typically reported as Accuracy @ {1, 5, 20}).
|
| 62 |
+
"""
|
| 63 |
+
raise NotImplementedError
|
massspecgym/models/tokenizers.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import typing as T
|
| 3 |
+
import selfies as sf
|
| 4 |
+
from tokenizers import ByteLevelBPETokenizer
|
| 5 |
+
from tokenizers import Tokenizer, processors, models
|
| 6 |
+
from tokenizers.implementations import BaseTokenizer, ByteLevelBPETokenizer
|
| 7 |
+
import massspecgym.utils as utils
|
| 8 |
+
from massspecgym.definitions import PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SpecialTokensBaseTokenizer(BaseTokenizer):
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
tokenizer: Tokenizer,
|
| 15 |
+
max_len: T.Optional[int] = None,
|
| 16 |
+
):
|
| 17 |
+
"""Initialize the base tokenizer with special tokens performing padding and truncation."""
|
| 18 |
+
super().__init__(tokenizer)
|
| 19 |
+
|
| 20 |
+
# Save essential attributes
|
| 21 |
+
self.pad_token = PAD_TOKEN
|
| 22 |
+
self.sos_token = SOS_TOKEN
|
| 23 |
+
self.eos_token = EOS_TOKEN
|
| 24 |
+
self.unk_token = UNK_TOKEN
|
| 25 |
+
self.max_length = max_len
|
| 26 |
+
|
| 27 |
+
# Add special tokens
|
| 28 |
+
self.add_special_tokens([self.pad_token, self.sos_token, self.eos_token, self.unk_token])
|
| 29 |
+
|
| 30 |
+
# Get token IDs
|
| 31 |
+
self.pad_token_id = self.token_to_id(self.pad_token)
|
| 32 |
+
self.sos_token_id = self.token_to_id(self.sos_token)
|
| 33 |
+
self.eos_token_id = self.token_to_id(self.eos_token)
|
| 34 |
+
self.unk_token_id = self.token_to_id(self.unk_token)
|
| 35 |
+
|
| 36 |
+
# Enable padding
|
| 37 |
+
self.enable_padding(
|
| 38 |
+
direction="right",
|
| 39 |
+
pad_token=self.pad_token,
|
| 40 |
+
pad_id=self.pad_token_id,
|
| 41 |
+
length=max_len,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Enable truncation
|
| 45 |
+
self.enable_truncation(max_len)
|
| 46 |
+
|
| 47 |
+
# Set post-processing to add SOS and EOS tokens
|
| 48 |
+
self._tokenizer.post_processor = processors.TemplateProcessing(
|
| 49 |
+
single=f"{self.sos_token} $A {self.eos_token}",
|
| 50 |
+
pair=f"{self.sos_token} $A {self.eos_token} {self.sos_token} $B {self.eos_token}",
|
| 51 |
+
special_tokens=[
|
| 52 |
+
(self.sos_token, self.sos_token_id),
|
| 53 |
+
(self.eos_token, self.eos_token_id),
|
| 54 |
+
],
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class SelfiesTokenizer(SpecialTokensBaseTokenizer):
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
selfies_train: T.Optional[T.Union[str, T.List[str]]] = None,
|
| 62 |
+
**kwargs
|
| 63 |
+
):
|
| 64 |
+
"""
|
| 65 |
+
Initialize the SELFIES tokenizer with optional training data to build a vocanulary.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
selfies_train (str or list of str): Either a list of SELFIES strings to build the vocabulary from,
|
| 69 |
+
or a `semantic_robust_alphabet` string indicating the usahe of `selfies.get_semantic_robust_alphabet()`
|
| 70 |
+
alphabet. If None, the MassSpecGym training molecules will be used.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
if selfies_train == 'semantic_robust_alphabet':
|
| 74 |
+
alphabet = list(sorted(sf.get_semantic_robust_alphabet()))
|
| 75 |
+
else:
|
| 76 |
+
if not selfies_train:
|
| 77 |
+
selfies_train = utils.load_train_mols()
|
| 78 |
+
selfies = [sf.encoder(s, strict=False) for s in selfies_train]
|
| 79 |
+
else:
|
| 80 |
+
selfies = selfies_train
|
| 81 |
+
alphabet = list(sorted(sf.get_alphabet_from_selfies(selfies)))
|
| 82 |
+
|
| 83 |
+
vocab = {symbol: i for i, symbol in enumerate(alphabet)}
|
| 84 |
+
vocab[UNK_TOKEN] = len(vocab)
|
| 85 |
+
tokenizer = Tokenizer(models.WordLevel(vocab=vocab, unk_token=UNK_TOKEN))
|
| 86 |
+
|
| 87 |
+
super().__init__(tokenizer, **kwargs)
|
| 88 |
+
|
| 89 |
+
def encode(self, text: str, add_special_tokens: bool = True) -> Tokenizer:
|
| 90 |
+
"""Encodes a SMILES string into a list of SELFIES token IDs."""
|
| 91 |
+
selfies_string = sf.encoder(text, strict=False)
|
| 92 |
+
selfies_tokens = list(sf.split_selfies(selfies_string))
|
| 93 |
+
return super().encode(
|
| 94 |
+
selfies_tokens, is_pretokenized=True, add_special_tokens=add_special_tokens
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
def decode(self, token_ids: T.List[int], skip_special_tokens: bool = True) -> str:
|
| 98 |
+
"""Decodes a list of SELFIES token IDs back into a SMILES string."""
|
| 99 |
+
selfies_string = super().decode(
|
| 100 |
+
token_ids, skip_special_tokens=skip_special_tokens
|
| 101 |
+
)
|
| 102 |
+
selfies_string = self._decode_wordlevel_str_to_selfies(selfies_string)
|
| 103 |
+
return sf.decoder(selfies_string)
|
| 104 |
+
|
| 105 |
+
def encode_batch(
|
| 106 |
+
self, texts: T.List[str], add_special_tokens: bool = True
|
| 107 |
+
) -> T.List[Tokenizer]:
|
| 108 |
+
"""Encodes a batch of SMILES strings into a list of SELFIES token IDs."""
|
| 109 |
+
selfies_strings = [
|
| 110 |
+
list(sf.split_selfies(sf.encoder(text, strict=False))) for text in texts
|
| 111 |
+
]
|
| 112 |
+
return super().encode_batch(
|
| 113 |
+
selfies_strings, is_pretokenized=True, add_special_tokens=add_special_tokens
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def decode_batch(
|
| 117 |
+
self, token_ids_batch: T.List[T.List[int]], skip_special_tokens: bool = True
|
| 118 |
+
) -> T.List[str]:
|
| 119 |
+
"""Decodes a batch of SELFIES token IDs back into SMILES strings."""
|
| 120 |
+
selfies_strings = super().decode_batch(
|
| 121 |
+
token_ids_batch, skip_special_tokens=skip_special_tokens
|
| 122 |
+
)
|
| 123 |
+
return [
|
| 124 |
+
sf.decoder(
|
| 125 |
+
self._decode_wordlevel_str_to_selfies(
|
| 126 |
+
selfies_string
|
| 127 |
+
)
|
| 128 |
+
)
|
| 129 |
+
for selfies_string in selfies_strings
|
| 130 |
+
]
|
| 131 |
+
|
| 132 |
+
def _decode_wordlevel_str_to_selfies(self, text: str) -> str:
|
| 133 |
+
"""Converts a WordLevel string back to a SELFIES string."""
|
| 134 |
+
return text.replace(" ", "")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class SmilesBPETokenizer(SpecialTokensBaseTokenizer):
|
| 138 |
+
def __init__(self, smiles_pth: T.Optional[str] = None, **kwargs):
|
| 139 |
+
"""
|
| 140 |
+
Initialize the BPE tokenizer for SMILES strings, with optional training data.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
smiles_pth (str): Path to a file containing SMILES strings to train the tokenizer on. If None,
|
| 144 |
+
the MassSpecGym training molecules will be used.
|
| 145 |
+
"""
|
| 146 |
+
tokenizer = ByteLevelBPETokenizer()
|
| 147 |
+
if smiles_pth:
|
| 148 |
+
tokenizer.train(smiles_pth)
|
| 149 |
+
else:
|
| 150 |
+
smiles = utils.load_unlabeled_mols("smiles").tolist()
|
| 151 |
+
smiles += utils.load_train_mols().tolist()
|
| 152 |
+
|
| 153 |
+
print(f"Training tokenizer on {len(smiles)} SMILES strings.")
|
| 154 |
+
tokenizer.train_from_iterator(smiles)
|
| 155 |
+
|
| 156 |
+
super().__init__(tokenizer, **kwargs)
|
massspecgym/utils.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
# import seaborn as sns
|
| 3 |
+
import matplotlib as mpl
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import matplotlib.colors
|
| 6 |
+
import matplotlib.cm as cm
|
| 7 |
+
import matplotlib.colors as mcolors
|
| 8 |
+
import matplotlib.ticker as ticker
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import typing as T
|
| 11 |
+
import pulp
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from itertools import groupby
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from myopic_mces.myopic_mces import MCES
|
| 18 |
+
from rdkit.Chem import AllChem as Chem
|
| 19 |
+
from rdkit.Chem import DataStructs, Draw
|
| 20 |
+
from rdkit.Chem.Descriptors import ExactMolWt
|
| 21 |
+
# from huggingface_hub import hf_hub_download
|
| 22 |
+
# from standardizeUtils.standardizeUtils import (
|
| 23 |
+
# standardize_structure_with_pubchem,
|
| 24 |
+
# standardize_structure_list_with_pubchem,
|
| 25 |
+
# )
|
| 26 |
+
from torchmetrics.wrappers import BootStrapper
|
| 27 |
+
from torchmetrics.metric import Metric
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_massspecgym(fold: T.Optional[str] = None) -> pd.DataFrame:
|
| 31 |
+
"""
|
| 32 |
+
Load the MassSpecGym dataset.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
fold (str, optional): Fold name to load. If None, the entire dataset is loaded.
|
| 36 |
+
"""
|
| 37 |
+
df = pd.read_csv(hugging_face_download("MassSpecGym.tsv"), sep="\t")
|
| 38 |
+
df = df.set_index("identifier")
|
| 39 |
+
df['mzs'] = df['mzs'].apply(parse_spec_array)
|
| 40 |
+
df['intensities'] = df['intensities'].apply(parse_spec_array)
|
| 41 |
+
if fold is not None:
|
| 42 |
+
df = df[df['fold'] == fold]
|
| 43 |
+
return df
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_unlabeled_mols(col_name: str = "smiles") -> pd.Series:
|
| 47 |
+
"""
|
| 48 |
+
Load a list of unlabeled molecules.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
col_name (str, optional): Name of the column to return. Should be one of ["smiles", "selfies"].
|
| 52 |
+
"""
|
| 53 |
+
return pd.read_csv(
|
| 54 |
+
hugging_face_download(
|
| 55 |
+
"molecules/MassSpecGym_molecules_MCES2_disjoint_with_test_fold_4M.tsv"
|
| 56 |
+
),
|
| 57 |
+
sep="\t"
|
| 58 |
+
)[col_name]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def load_train_mols(col_name: str = "smiles") -> pd.Series:
|
| 62 |
+
"""
|
| 63 |
+
Load a list of training molecules.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
col_name (str, optional): Name of the column to return. Should be one of ["smiles", "selfies"].
|
| 67 |
+
"""
|
| 68 |
+
return load_massspecgym("train")[col_name]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def pad_spectrum(
|
| 72 |
+
spec: np.ndarray, max_n_peaks: int, pad_value: float = 0.0
|
| 73 |
+
) -> np.ndarray:
|
| 74 |
+
"""
|
| 75 |
+
Pad a spectrum to a fixed number of peaks by appending zeros to the end of the spectrum.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
spec (np.ndarray): Spectrum to pad represented as numpy array of shape (n_peaks, 2).
|
| 79 |
+
max_n_peaks (int): Maximum number of peaks in the padded spectrum.
|
| 80 |
+
pad_value (float, optional): Value to use for padding.
|
| 81 |
+
"""
|
| 82 |
+
n_peaks = spec.shape[0]
|
| 83 |
+
if n_peaks > max_n_peaks:
|
| 84 |
+
raise ValueError(
|
| 85 |
+
f"Number of peaks in the spectrum ({n_peaks}) is greater than the maximum number of peaks."
|
| 86 |
+
)
|
| 87 |
+
else:
|
| 88 |
+
return np.pad(
|
| 89 |
+
spec,
|
| 90 |
+
((0, max_n_peaks - n_peaks), (0, 0)),
|
| 91 |
+
mode="constant",
|
| 92 |
+
constant_values=pad_value,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def morgan_fp(mol: Chem.Mol, fp_size=2048, radius=2, to_np=True):
|
| 97 |
+
"""
|
| 98 |
+
Compute Morgan fingerprint for a molecule.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
mol (Chem.Mol): _description_
|
| 102 |
+
fp_size (int, optional): Size of the fingerprint.
|
| 103 |
+
radius (int, optional): Radius of the fingerprint.
|
| 104 |
+
to_np (bool, optional): Convert the fingerprint to numpy array.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
fp = Chem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=fp_size)
|
| 108 |
+
if to_np:
|
| 109 |
+
fp_np = np.zeros((0,), dtype=np.int32)
|
| 110 |
+
DataStructs.ConvertToNumpyArray(fp, fp_np)
|
| 111 |
+
fp = fp_np
|
| 112 |
+
return fp
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def tanimoto_morgan_similarity(mol1: T.Union[Chem.Mol, str], mol2: T.Union[Chem.Mol, str]) -> float:
|
| 116 |
+
"""
|
| 117 |
+
Compute Tanimoto similarity between two molecules using Morgan fingerprints.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
mol1 (T.Union[Chem.Mol, str]): First molecule as RDKit molecule or SMILES string.
|
| 121 |
+
mol2 (T.Union[Chem.Mol, str]): Second molecule as RDKit molecule or SMILES string.
|
| 122 |
+
"""
|
| 123 |
+
if isinstance(mol1, str):
|
| 124 |
+
mol1 = Chem.MolFromSmiles(mol1)
|
| 125 |
+
if isinstance(mol2, str):
|
| 126 |
+
mol2 = Chem.MolFromSmiles(mol2)
|
| 127 |
+
return DataStructs.TanimotoSimilarity(morgan_fp(mol1, to_np=False), morgan_fp(mol2, to_np=False))
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def standardize_smiles(smiles: T.Union[str, T.List[str]]) -> T.Union[str, T.List[str]]:
|
| 131 |
+
"""
|
| 132 |
+
Standardize SMILES representation of a molecule using PubChem standardization.
|
| 133 |
+
"""
|
| 134 |
+
if isinstance(smiles, str):
|
| 135 |
+
return standardize_structure_with_pubchem(smiles, 'smiles')
|
| 136 |
+
elif isinstance(smiles, list):
|
| 137 |
+
return standardize_structure_list_with_pubchem(smiles, 'smiles')
|
| 138 |
+
else:
|
| 139 |
+
raise ValueError("Input should be a SMILES tring or a list of SMILES strings.")
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def mol_to_inchi_key(mol: Chem.Mol, twod: bool = True) -> str:
|
| 143 |
+
"""
|
| 144 |
+
Convert a molecule to InChI Key representation.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
mol (Chem.Mol): RDKit molecule object.
|
| 148 |
+
twod (bool, optional): Return 2D InChI Key (first 14 characers of InChI Key).
|
| 149 |
+
"""
|
| 150 |
+
inchi_key = Chem.MolToInchiKey(mol)
|
| 151 |
+
if twod:
|
| 152 |
+
inchi_key = inchi_key.split("-")[0]
|
| 153 |
+
return inchi_key
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def smiles_to_inchi_key(mol: str, twod: bool = True) -> str:
|
| 157 |
+
"""
|
| 158 |
+
Convert a SMILES molecule to InChI Key representation.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
mol (str): SMILES string.
|
| 162 |
+
twod (bool, optional): Return 2D InChI Key (first 14 characers of InChI Key).
|
| 163 |
+
"""
|
| 164 |
+
mol = Chem.MolFromSmiles(mol)
|
| 165 |
+
return mol_to_inchi_key(mol, twod)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def hugging_face_download(file_name: str) -> str:
|
| 169 |
+
"""
|
| 170 |
+
Download a file from the Hugging Face Hub and return its location on disk.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
file_name (str): Name of the file to download.
|
| 174 |
+
"""
|
| 175 |
+
return hf_hub_download(
|
| 176 |
+
repo_id="roman-bushuiev/MassSpecGym",
|
| 177 |
+
filename="data/" + file_name,
|
| 178 |
+
repo_type="dataset",
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def init_plotting(figsize=(6, 2), font_scale=1.0, style="whitegrid"):
|
| 183 |
+
# Set default figure size
|
| 184 |
+
plt.show() # Does not work without this line for some reason
|
| 185 |
+
sns.set_theme(rc={"figure.figsize": figsize})
|
| 186 |
+
mpl.rcParams['svg.fonttype'] = 'none'
|
| 187 |
+
# Set default style and font scale
|
| 188 |
+
sns.set_style(style)
|
| 189 |
+
sns.set_context("paper", font_scale=font_scale)
|
| 190 |
+
sns.set_palette(["#009473", "#D94F70", "#5A5B9F", "#F0C05A", "#7BC4C4", "#FF6F61"])
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def parse_spec_array(arr: str) -> np.ndarray:
|
| 194 |
+
return np.array(list(map(float, arr.split(","))))
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def spec_array_to_str(arr: np.ndarray) -> str:
|
| 198 |
+
return ",".join(map(str, arr))
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def compute_mass(smiles: str) -> float:
|
| 202 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 203 |
+
if mol is None:
|
| 204 |
+
raise ValueError("Invalid SMILES string.")
|
| 205 |
+
return ExactMolWt(mol)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def plot_spectrum(spec, hue=None, xlim=None, ylim=None, mirror_spec=None, highl_idx=None,
|
| 209 |
+
figsize=(6, 2), colors=None, save_pth=None):
|
| 210 |
+
|
| 211 |
+
if colors is not None:
|
| 212 |
+
assert len(colors) >= 3
|
| 213 |
+
else:
|
| 214 |
+
colors = ['blue', 'green', 'red']
|
| 215 |
+
|
| 216 |
+
# Normalize input spectrum
|
| 217 |
+
def norm_spec(spec):
|
| 218 |
+
assert len(spec.shape) == 2
|
| 219 |
+
if spec.shape[0] != 2:
|
| 220 |
+
spec = spec.T
|
| 221 |
+
mzs, ins = spec[0], spec[1]
|
| 222 |
+
return mzs, ins / max(ins) * 100
|
| 223 |
+
mzs, ins = norm_spec(spec)
|
| 224 |
+
|
| 225 |
+
# Initialize plotting
|
| 226 |
+
init_plotting(figsize=figsize)
|
| 227 |
+
fig, ax = plt.subplots(1, 1)
|
| 228 |
+
|
| 229 |
+
# Setup color palette
|
| 230 |
+
if hue is not None:
|
| 231 |
+
norm = matplotlib.colors.Normalize(vmin=min(hue), vmax=max(hue), clip=True)
|
| 232 |
+
mapper = cm.ScalarMappable(norm=norm, cmap=cm.cool)
|
| 233 |
+
plt.colorbar(mapper, ax=ax)
|
| 234 |
+
|
| 235 |
+
# Plot spectrum
|
| 236 |
+
for i in range(len(mzs)):
|
| 237 |
+
if hue is not None:
|
| 238 |
+
color = mcolors.to_hex(mapper.to_rgba(hue[i]))
|
| 239 |
+
else:
|
| 240 |
+
color = colors[0]
|
| 241 |
+
plt.plot([mzs[i], mzs[i]], [0, ins[i]], color=color, marker='o', markevery=(1, 2), mfc='white', zorder=2)
|
| 242 |
+
|
| 243 |
+
# Plot mirror spectrum
|
| 244 |
+
if mirror_spec is not None:
|
| 245 |
+
mzs_m, ins_m = norm_spec(mirror_spec)
|
| 246 |
+
|
| 247 |
+
@ticker.FuncFormatter
|
| 248 |
+
def major_formatter(x, pos):
|
| 249 |
+
label = str(round(-x)) if x < 0 else str(round(x))
|
| 250 |
+
return label
|
| 251 |
+
|
| 252 |
+
for i in range(len(mzs_m)):
|
| 253 |
+
plt.plot([mzs_m[i], mzs_m[i]], [0, -ins_m[i]], color=colors[2], marker='o', markevery=(1, 2), mfc='white',
|
| 254 |
+
zorder=1)
|
| 255 |
+
ax.yaxis.set_major_formatter(major_formatter)
|
| 256 |
+
|
| 257 |
+
# Setup axes
|
| 258 |
+
if xlim is not None:
|
| 259 |
+
plt.xlim(xlim[0], xlim[1])
|
| 260 |
+
else:
|
| 261 |
+
plt.xlim(0, max(mzs) + 10)
|
| 262 |
+
if ylim is not None:
|
| 263 |
+
plt.ylim(ylim[0], ylim[1])
|
| 264 |
+
plt.xlabel('m/z')
|
| 265 |
+
plt.ylabel('Intensity [%]')
|
| 266 |
+
|
| 267 |
+
if save_pth is not None:
|
| 268 |
+
raise NotImplementedError()
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def show_mols(mols, legends='new_indices', smiles_in=False, svg=False, sort_by_legend=False, max_mols=500,
|
| 272 |
+
legend_float_decimals=4, mols_per_row=6, save_pth: T.Optional[Path] = None):
|
| 273 |
+
"""
|
| 274 |
+
Returns svg image representing a grid of skeletal structures of the given molecules. Copy-pasted
|
| 275 |
+
from https://github.com/pluskal-lab/DreaMS/blob/main/dreams/utils/mols.py
|
| 276 |
+
|
| 277 |
+
:param mols: list of rdkit molecules
|
| 278 |
+
:param smiles_in: True - SMILES inputs, False - RDKit mols
|
| 279 |
+
:param legends: list of labels for each molecule, length must be equal to the length of mols
|
| 280 |
+
:param svg: True - return svg image, False - return png image
|
| 281 |
+
:param sort_by_legend: True - sort molecules by legend values
|
| 282 |
+
:param max_mols: maximum number of molecules to show
|
| 283 |
+
:param legend_float_decimals: number of decimal places to show for float legends
|
| 284 |
+
:param mols_per_row: number of molecules per row to show
|
| 285 |
+
:param save_pth: path to save the .svg image to
|
| 286 |
+
"""
|
| 287 |
+
if smiles_in:
|
| 288 |
+
mols = [Chem.MolFromSmiles(e) for e in mols]
|
| 289 |
+
|
| 290 |
+
if legends == 'new_indices':
|
| 291 |
+
legends = list(range(len(mols)))
|
| 292 |
+
elif legends == 'masses':
|
| 293 |
+
legends = [ExactMolWt(m) for m in mols]
|
| 294 |
+
elif callable(legends):
|
| 295 |
+
legends = [legends(e) for e in mols]
|
| 296 |
+
|
| 297 |
+
if sort_by_legend:
|
| 298 |
+
idx = np.argsort(legends).tolist()
|
| 299 |
+
legends = [legends[i] for i in idx]
|
| 300 |
+
mols = [mols[i] for i in idx]
|
| 301 |
+
|
| 302 |
+
legends = [f'{l:.{legend_float_decimals}f}' if isinstance(l, float) else str(l) for l in legends]
|
| 303 |
+
|
| 304 |
+
img = Draw.MolsToGridImage(mols, maxMols=max_mols, legends=legends, molsPerRow=min(max_mols, mols_per_row),
|
| 305 |
+
useSVG=svg, returnPNG=False)
|
| 306 |
+
|
| 307 |
+
if save_pth:
|
| 308 |
+
with open(save_pth, 'w') as f:
|
| 309 |
+
f.write(img.data)
|
| 310 |
+
|
| 311 |
+
return img
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class MyopicMCES():
|
| 315 |
+
def __init__(
|
| 316 |
+
self,
|
| 317 |
+
ind: int = 0, # dummy index
|
| 318 |
+
solver: str = pulp.listSolvers(onlyAvailable=True)[0], # Use the first available solver
|
| 319 |
+
threshold: int = 15, # MCES threshold
|
| 320 |
+
always_stronger_bound: bool = True, # "False" makes computations a lot faster, but leads to overall higher MCES values
|
| 321 |
+
solver_options: dict = None
|
| 322 |
+
):
|
| 323 |
+
self.ind = ind
|
| 324 |
+
self.solver = solver
|
| 325 |
+
self.threshold = threshold
|
| 326 |
+
self.always_stronger_bound = always_stronger_bound
|
| 327 |
+
if solver_options is None:
|
| 328 |
+
solver_options = dict(msg=0) # make ILP solver silent
|
| 329 |
+
self.solver_options = solver_options
|
| 330 |
+
|
| 331 |
+
# def __call__(self, smiles_1: str, smiles_2: str) -> float:
|
| 332 |
+
# retval = MCES(
|
| 333 |
+
# s1=smiles_1,
|
| 334 |
+
# s2=smiles_2,
|
| 335 |
+
# ind=self.ind,
|
| 336 |
+
# threshold=self.threshold,
|
| 337 |
+
# always_stronger_bound=self.always_stronger_bound,
|
| 338 |
+
# solver=self.solver,
|
| 339 |
+
# solver_options=self.solver_options
|
| 340 |
+
# )
|
| 341 |
+
# dist = retval[1]
|
| 342 |
+
# return dist
|
| 343 |
+
def __call__(self, smiles_1: str, smiles_2: str) -> float:
|
| 344 |
+
retval = MCES(
|
| 345 |
+
smiles_1,
|
| 346 |
+
smiles_2,
|
| 347 |
+
threshold=self.threshold,
|
| 348 |
+
always_stronger_bound=self.always_stronger_bound,
|
| 349 |
+
solver=self.solver,
|
| 350 |
+
solver_options = self.solver_options
|
| 351 |
+
)
|
| 352 |
+
dist = retval[1]
|
| 353 |
+
return dist
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class ReturnScalarBootStrapper(BootStrapper):
|
| 357 |
+
def __init__(
|
| 358 |
+
self,
|
| 359 |
+
base_metric: Metric,
|
| 360 |
+
num_bootstraps: int = 10,
|
| 361 |
+
mean: bool = False,
|
| 362 |
+
std: bool = False,
|
| 363 |
+
quantile: T.Optional[T.Union[float, torch.Tensor]] = None,
|
| 364 |
+
raw: bool = False,
|
| 365 |
+
sampling_strategy: str = "poisson",
|
| 366 |
+
**kwargs: T.Any
|
| 367 |
+
) -> None:
|
| 368 |
+
"""Wrapper for BootStrapper that returns a scalar value in compute instead of a dictionary."""
|
| 369 |
+
|
| 370 |
+
if mean + std + bool(quantile) + raw != 1:
|
| 371 |
+
raise ValueError("Exactly one of mean, std, quantile or raw should be True.")
|
| 372 |
+
|
| 373 |
+
if std:
|
| 374 |
+
self.compute_key = "std"
|
| 375 |
+
else:
|
| 376 |
+
raise NotImplementedError("Currently only std is implemented.")
|
| 377 |
+
|
| 378 |
+
super().__init__(
|
| 379 |
+
base_metric=base_metric,
|
| 380 |
+
num_bootstraps=num_bootstraps,
|
| 381 |
+
mean=mean,
|
| 382 |
+
std=std,
|
| 383 |
+
quantile=quantile,
|
| 384 |
+
raw=raw,
|
| 385 |
+
sampling_strategy=sampling_strategy,
|
| 386 |
+
**kwargs
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
def compute(self):
|
| 390 |
+
return super().compute()[self.compute_key]
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def batch_ptr_to_batch_idx(batch_ptr: torch.Tensor) -> torch.Tensor:
|
| 394 |
+
"""
|
| 395 |
+
Convert a tensor of batch pointers to a tensor of batch indexes.
|
| 396 |
+
|
| 397 |
+
For example [1, 3, 2] -> [0, 1, 1, 1, 2, 2]
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
batch_ptr (Tensor): Tensor of batch pointers.
|
| 401 |
+
"""
|
| 402 |
+
indexes = torch.arange(batch_ptr.size(0), device=batch_ptr.device)
|
| 403 |
+
indexes = torch.repeat_interleave(indexes, batch_ptr)
|
| 404 |
+
return indexes
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def unbatch_list(batch_list: list, batch_idx: torch.Tensor) -> list:
|
| 408 |
+
"""
|
| 409 |
+
Unbatch a list of items using the batch indexes (i.e., number of samples per batch).
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
batch_list (list): List of items to unbatch.
|
| 413 |
+
batch_idx (Tensor): Tensor of batch indexes.
|
| 414 |
+
"""
|
| 415 |
+
return [
|
| 416 |
+
[batch_list[j] for j in range(len(batch_list)) if batch_idx[j] == i]
|
| 417 |
+
for i in range(batch_idx[-1] + 1)
|
| 418 |
+
]
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class CosSimLoss(nn.Module):
|
| 422 |
+
def __init__(self):
|
| 423 |
+
super(CosSimLoss, self).__init__()
|
| 424 |
+
|
| 425 |
+
def forward(self, inputs, targets):
|
| 426 |
+
return 1 - F.cosine_similarity(inputs, targets).mean()
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def parse_sirius_ms(spectra_file: str) -> T.Tuple[dict, T.List[T.Tuple[str, np.ndarray]]]:
|
| 430 |
+
"""
|
| 431 |
+
Parses spectra from the SIRIUS .ms file.
|
| 432 |
+
|
| 433 |
+
Copied from the code of Goldman et al.:
|
| 434 |
+
https://github.com/samgoldman97/mist/blob/4c23d34fc82425ad5474a53e10b4622dcdbca479/src/mist/utils/parse_utils.py#LL10C77-L10C77.
|
| 435 |
+
:return T.Tuple[dict, T.List[T.Tuple[str, np.ndarray]]]: metadata and list of spectra tuples containing name and array
|
| 436 |
+
"""
|
| 437 |
+
lines = [i.strip() for i in open(spectra_file, "r").readlines()]
|
| 438 |
+
|
| 439 |
+
group_num = 0
|
| 440 |
+
metadata = {}
|
| 441 |
+
spectras = []
|
| 442 |
+
my_iterator = groupby(
|
| 443 |
+
lines, lambda line: line.startswith(">") or line.startswith("#")
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
for index, (start_line, lines) in enumerate(my_iterator):
|
| 447 |
+
group_lines = list(lines)
|
| 448 |
+
subject_lines = list(next(my_iterator)[1])
|
| 449 |
+
# Get spectra
|
| 450 |
+
if group_num > 0:
|
| 451 |
+
spectra_header = group_lines[0].split(">")[1]
|
| 452 |
+
peak_data = [
|
| 453 |
+
[float(x) for x in peak.split()[:2]]
|
| 454 |
+
for peak in subject_lines
|
| 455 |
+
if peak.strip()
|
| 456 |
+
]
|
| 457 |
+
# Check if spectra is empty
|
| 458 |
+
if len(peak_data):
|
| 459 |
+
peak_data = np.vstack(peak_data)
|
| 460 |
+
# Add new tuple
|
| 461 |
+
spectras.append((spectra_header, peak_data))
|
| 462 |
+
# Get meta data
|
| 463 |
+
else:
|
| 464 |
+
entries = {}
|
| 465 |
+
for i in group_lines:
|
| 466 |
+
if " " not in i:
|
| 467 |
+
continue
|
| 468 |
+
elif i.startswith("#INSTRUMENT TYPE"):
|
| 469 |
+
key = "#INSTRUMENT TYPE"
|
| 470 |
+
val = i.split(key)[1].strip()
|
| 471 |
+
entries[key[1:]] = val
|
| 472 |
+
else:
|
| 473 |
+
start, end = i.split(" ", 1)
|
| 474 |
+
start = start[1:]
|
| 475 |
+
while start in entries:
|
| 476 |
+
start = f"{start}'"
|
| 477 |
+
entries[start] = end
|
| 478 |
+
|
| 479 |
+
metadata.update(entries)
|
| 480 |
+
group_num += 1
|
| 481 |
+
|
| 482 |
+
metadata["_FILE_PATH"] = spectra_file
|
| 483 |
+
metadata["_FILE"] = Path(spectra_file).stem
|
| 484 |
+
return metadata, spectras
|