yzhouchen001 committed on
Commit
d9df210
·
1 Parent(s): 6cd92cd

model code

Browse files
mvp/__pycache__/definitions.cpython-311.pyc CHANGED
Binary files a/mvp/__pycache__/definitions.cpython-311.pyc and b/mvp/__pycache__/definitions.cpython-311.pyc differ
 
mvp/data/__pycache__/data_module.cpython-311.pyc CHANGED
Binary files a/mvp/data/__pycache__/data_module.cpython-311.pyc and b/mvp/data/__pycache__/data_module.cpython-311.pyc differ
 
mvp/data/__pycache__/datasets.cpython-311.pyc CHANGED
Binary files a/mvp/data/__pycache__/datasets.cpython-311.pyc and b/mvp/data/__pycache__/datasets.cpython-311.pyc differ
 
mvp/data/__pycache__/transforms.cpython-311.pyc CHANGED
Binary files a/mvp/data/__pycache__/transforms.cpython-311.pyc and b/mvp/data/__pycache__/transforms.cpython-311.pyc differ
 
mvp/data/data_module.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data.dataloader import DataLoader
2
+ from massspecgym.data.data_module import MassSpecDataModule
3
+ from mvp.data.datasets import ContrastiveDataset
4
+ from functools import partial
5
+ from massspecgym.models.base import Stage
6
+
7
class TestDataModule(MassSpecDataModule):
    """Data module that only serves a test loader over the wrapped dataset.

    Args:
        collate_fn: Callable handed to the test ``DataLoader`` to build batches.
        **kwargs: Forwarded unchanged to ``MassSpecDataModule.__init__``.
    """

    def __init__(self, collate_fn, **kwargs):
        super().__init__(**kwargs)
        self.collate_fn = collate_fn

    def prepare_data(self):
        """No preparation work is done by this module."""
        return None

    def setup(self, stage=None):
        """Use the whole dataset as the test set.

        Raises:
            Exception: If ``stage`` is anything other than ``"test"``.
        """
        if stage != "test":
            raise Exception("Data module supports test set only")
        self.test_dataset = self.dataset

    def test_dataloader(self):
        """Return an unshuffled loader over ``self.test_dataset``."""
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=False,
            collate_fn=self.collate_fn,
        )
        return DataLoader(self.test_dataset, **loader_options)

    def train_dataloader(self):
        """There is no training data; always returns ``None``."""
        return None

    def val_dataset(self):
        # NOTE(review): this is a method called ``val_dataset``, not
        # ``val_dataloader`` -- confirm which hook was intended.
        return None
41
+
42
class ContrastiveDataModule(MassSpecDataModule):
    """Data module whose train/val loaders iterate ``ContrastiveDataset`` wrappers.

    Args:
        collate_fn: Callable accepting a ``stage`` keyword; it is bound to
            ``Stage.TRAIN`` / ``Stage.VAL`` for the respective loaders.
        **kwargs: Forwarded unchanged to ``MassSpecDataModule.__init__``.
    """

    def __init__(self, collate_fn, **kwargs):
        super().__init__(**kwargs)
        self.collate_fn = collate_fn
        self.regularization_flag = False

    def _build_loader(self, data, shuffle, collate):
        """Create a ``DataLoader`` using this module's batch and worker settings."""
        return DataLoader(
            data,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=False,
            collate_fn=collate,
        )

    def train_dataloader(self):
        """Wrap the training split in a ``ContrastiveDataset`` and return a shuffled loader."""
        self.train_contrastive_dataset = ContrastiveDataset(self.train_dataset)
        train_collate = partial(self.collate_fn, stage=Stage.TRAIN)
        return self._build_loader(self.train_contrastive_dataset, True, train_collate)

    def val_dataloader(self):
        """Wrap the validation split in a ``ContrastiveDataset`` and return an ordered loader."""
        self.val_contrastive_dataset = ContrastiveDataset(self.val_dataset)
        val_collate = partial(self.collate_fn, stage=Stage.VAL)
        return self._build_loader(self.val_contrastive_dataset, False, val_collate)

    def test_dataloader(self):
        """Return an ordered loader over the test split.

        Unlike the other loaders this one uses the dataset's own ``collate_fn``.
        """
        return self._build_loader(self.test_dataset, False, self.dataset.collate_fn)
mvp/data/datasets.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ import typing as T
4
+ import numpy as np
5
+ import torch
6
+ import massspecgym.utils as utils
7
+ from pathlib import Path
8
+ from torch.utils.data.dataset import Dataset
9
+ from torch.utils.data.dataloader import default_collate
10
+ import dgl
11
+ from collections import defaultdict
12
+ from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
13
+ from massspecgym.data.datasets import MassSpecDataset
14
+ import mvp.utils.data as data_utils
15
+ from torch.nn.utils.rnn import pad_sequence
16
+ from massspecgym.models.base import Stage
17
+ import pickle
18
+ import math
19
+ import itertools
20
+ from rdkit.Chem import AllChem
21
+ from rdkit import Chem
22
class JESTR1_MassSpecDataset(MassSpecDataset):
    """``MassSpecDataset`` that can attach fingerprints, consensus and NL spectra.

    Args:
        spectra_view: Key of the spectrum view; also used to select the
            transform applied to consensus / NL spectra.
        fp_dir_pth: Pickle mapping SMILES to fingerprints. ``None`` disables
            fingerprints; an empty value enables the flag with an empty mapping.
        cons_spec_dir_pth: Pickle with consensus spectra (``None`` disables).
        NL_spec_dir_pth: Pickle with NL spectra (``None`` disables).
        **kwargs: Forwarded to ``MassSpecDataset.__init__``.
    """

    def __init__(
        self,
        spectra_view: str,
        fp_dir_pth: str = None,
        cons_spec_dir_pth: str = None,
        NL_spec_dir_pth: str = None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.use_fp = False
        self.use_cons_spec = False
        self.use_NL_spec = False
        self.spectra_view = spectra_view

        # Each loader flips its ``use_*`` flag only when a path was supplied.
        self._load_fp(fp_dir_pth)
        self._load_cons_spec(cons_spec_dir_pth)
        self._load_NL_spec(NL_spec_dir_pth)

    def _load_fp(self, fp_dir_pth):
        """Load the SMILES -> fingerprint mapping, if requested."""
        if fp_dir_pth is None:
            return
        self.use_fp = True
        if not fp_dir_pth:
            self.smiles_to_fp = {}
            return
        # NOTE: pickle executes arbitrary code; only load trusted files.
        with open(fp_dir_pth, 'rb') as handle:
            self.smiles_to_fp = pickle.load(handle)

    def _load_cons_spec(self, cons_spec_dir_pth):
        """Load consensus spectra and index the prepared spectra by SMILES."""
        if cons_spec_dir_pth is None:
            return
        self.use_cons_spec = True
        # NOTE: pickle executes arbitrary code; only load trusted files.
        with open(cons_spec_dir_pth, 'rb') as handle:
            cons_table = pickle.load(handle)

        # Convert every row to a matchms spectrum.
        preparer = data_utils.PrepMatchMS(self.spectra_view)
        prepared = cons_table.apply(preparer.prepare, axis=1)
        self.cons_specs = dict(zip(cons_table['smiles'].tolist(), prepared))

    def _load_NL_spec(self, NL_spec_dir_pth):
        """Load NL spectra and convert every row to a matchms spectrum."""
        if NL_spec_dir_pth is None:
            return
        self.use_NL_spec = True
        # NOTE: pickle executes arbitrary code; only load trusted files.
        with open(NL_spec_dir_pth, 'rb') as handle:
            nl_table = pickle.load(handle)

        preparer = data_utils.PrepMatchMS(self.spectra_view)
        self.NL_specs = nl_table.apply(preparer.prepare, axis=1)

    @staticmethod
    def _apply_transform(item, transform, value, single_key):
        """Store ``transform(value)`` in ``item``.

        A dict of transforms writes one entry per key (a ``None`` transform
        stores ``value`` untouched); a single transform writes ``single_key``.
        """
        if isinstance(transform, dict):
            for name, fn in transform.items():
                item[name] = value if fn is None else fn(value)
        else:
            item[single_key] = transform(value)

    def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
        """Build the sample dict for row ``i``.

        Args:
            i: Index used for ``self.spectra[i]`` and ``self.metadata.iloc[i]``.
            transform_spec: Apply ``self.spec_transform`` when it is set.
            transform_mol: Apply ``self.mol_transform`` when it is set.

        Returns:
            dict with the spectrum entry/entries, optional ``mol_freq``,
            ``identifier``, ``fp``, ``cons_spec`` and ``NL_spec``, followed by
            the molecule entry/entries.
        """
        spectrum = self.spectra[i]
        row = self.metadata.iloc[i]
        smiles = row["smiles"]

        item = {}

        # Spectrum view(s)
        if transform_spec and self.spec_transform:
            self._apply_transform(item, self.spec_transform, spectrum, "spec")
        else:
            item["spec"] = spectrum

        if self.return_mol_freq:
            item["mol_freq"] = row["mol_freq"]

        if self.return_identifier:
            item["identifier"] = row["identifier"]

        if self.use_fp and self.smiles_to_fp:
            item['fp'] = torch.Tensor(self.smiles_to_fp[smiles].ToList())

        # The auxiliary spectra index ``spec_transform`` by ``spectra_view``,
        # i.e. they expect a dict of transforms.
        if self.use_cons_spec:
            view_transform = self.spec_transform[self.spectra_view]
            item['cons_spec'] = view_transform(self.cons_specs[smiles])

        if self.use_NL_spec:
            view_transform = self.spec_transform[self.spectra_view]
            item['NL_spec'] = view_transform(self.NL_specs[i])

        # Molecule view(s)
        if transform_mol and self.mol_transform:
            self._apply_transform(item, self.mol_transform, smiles, "mol")
        else:
            item["mol"] = smiles
        return item
121
+
122
class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
    """Dataset whose spectra carry per-peak subformula annotations.

    The parent constructors are not invoked; this ``__init__`` assigns the
    attributes that the inherited ``__getitem__`` reads.

    Args:
        spectra_view: Spectrum view name passed to the subformula loader and
            the matchms preparer.
        spec_transform: Spectrum transform, or a dict of named transforms.
        mol_transform: Molecule transform, or a dict of named transforms.
        pth: Path to a tab-separated metadata file with at least the
            ``identifier`` and ``smiles`` columns.
        subformula_dir_pth: Directory handed to ``Subformula_Loader``.
        fp_dir_pth: Optional fingerprint pickle (see ``_load_fp``).
        NL_spec_dir_pth: Optional NL-spectra pickle (see ``_load_NL_spec``).
        cons_spec_dir_pth: Optional consensus-spectra pickle
            (see ``_load_cons_spec``).
        return_mol_freq: Add a ``mol_freq`` column (spectra per InChIKey).
        return_identifier: Include ``identifier`` in every item.
        dtype: Stored on the instance as ``self.dtype``.
    """

    def __init__(
        self,
        spectra_view: str,
        spec_transform: T.Optional[T.Union[SpecTransform, T.Dict[str, SpecTransform]]],
        mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]],
        pth: T.Optional[Path],
        subformula_dir_pth: str,
        fp_dir_pth: str = None,
        NL_spec_dir_pth: str = None,
        cons_spec_dir_pth: str = None,
        return_mol_freq: bool = False,
        return_identifier: bool = True,
        dtype: T.Type = torch.float32
    ):
        self.pth = pth
        self.spec_transform = spec_transform
        self.mol_transform = mol_transform
        self.return_mol_freq = return_mol_freq
        self.pred_fp = False
        self.use_fp = False
        self.use_cons_spec = False
        self.use_NL_spec = False
        self.spectra_view = spectra_view

        if isinstance(self.pth, str):
            self.pth = Path(self.pth)

        print("Data path: ", self.pth)
        # An earlier variant (training on consensus spectra) loaded the
        # metadata from a pickle and used the SMILES as identifier.
        self.metadata = pd.read_csv(self.pth, sep="\t")

        # load subformulas
        all_spec_ids = self.metadata['identifier'].tolist()
        subformulaLoader = data_utils.Subformula_Loader(spectra_view=spectra_view, dir_path=subformula_dir_pth)
        id_to_spec = subformulaLoader(all_spec_ids)

        # Create placeholder subformula spectra for identifiers the loader did
        # not return. Rows are iterated explicitly: assigning the result of
        # ``DataFrame.apply(axis=1)`` to a column fails when the frame is empty
        # (nothing missing) and would otherwise write into a slice of
        # ``self.metadata``.
        missing_ids = {spec_id for spec_id in all_spec_ids if spec_id not in id_to_spec}
        if missing_ids:
            missing_rows = self.metadata[self.metadata['identifier'].isin(missing_ids)]
            for _, row in missing_rows.iterrows():
                id_to_spec[row['identifier']] = data_utils.make_tmp_subformula_spectra(row)

        # load fingerprints
        self._load_fp(fp_dir_pth)

        # load consensus spectra
        self._load_cons_spec(cons_spec_dir_pth)

        # load NL specs
        self._load_NL_spec(NL_spec_dir_pth)

        # Attach the subformula columns to the metadata, one row per identifier.
        self.metadata = self.metadata[self.metadata['identifier'].isin(id_to_spec)]
        formula_df = pd.DataFrame.from_dict(id_to_spec, orient='index').reset_index().rename(columns={'index': 'identifier'})
        self.metadata = self.metadata.merge(formula_df, on='identifier')

        # create matchms spectra
        matchMS_preparer = data_utils.PrepMatchMS(spectra_view=spectra_view)
        self.spectra = self.metadata.apply(matchMS_preparer.prepare, axis=1)

        if self.return_mol_freq:
            if "inchikey" not in self.metadata.columns:
                self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
            self.metadata["mol_freq"] = self.metadata.groupby("inchikey")["inchikey"].transform("count")

        self.return_identifier = return_identifier
        self.dtype = dtype

    def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
        """Return the item for row ``i``.

        The parent builds the item with the raw SMILES under ``"mol"``; the
        molecule transform(s) are then applied here. When ``mol_transform`` is
        unset the SMILES is left in place, matching the parent's behaviour
        (previously this path tried to call ``None``).
        """
        item = super().__getitem__(i, transform_spec, transform_mol=False)
        mol = item['mol']  # SMILES string

        if transform_mol and self.mol_transform:
            if isinstance(self.mol_transform, dict):
                for key, transform in self.mol_transform.items():
                    item[key] = transform(mol) if transform is not None else mol
            else:
                item["mol"] = self.mol_transform(mol)

        return item
212
+
213
class ContrastiveDataset(Dataset):
    """One item per unique SMILES, cycling through that molecule's spectra.

    Args:
        spec_mol_data: Subset-like object exposing ``indices``, ``dataset``
            (with a ``metadata`` DataFrame containing ``smiles``) and
            ``__getitem__``.
    """

    def __init__(
        self,
        spec_mol_data,
    ):
        super().__init__()

        indices = spec_mol_data.indices
        self.spec_mol_data = spec_mol_data
        # ``groupby(...).indices`` yields positions within the selected rows,
        # which is what ``spec_mol_data.__getitem__`` is called with below.
        self.smiles_to_specmol_ids = spec_mol_data.dataset.metadata.loc[indices].groupby('smiles').indices
        self.smiles_to_spec_couter = defaultdict(int)
        self.smiles_list = list(self.smiles_to_specmol_ids.keys())

    def __len__(self) -> int:
        return len(self.smiles_list)

    def __getitem__(self, i: int) -> dict:
        """Return the next spectrum (round-robin) for the ``i``-th molecule."""
        mol = self.smiles_list[i]

        specmol_ids = self.smiles_to_specmol_ids[mol]
        counter = self.smiles_to_spec_couter[mol]
        specmol_id = specmol_ids[counter % len(specmol_ids)]

        item = self.spec_mol_data.__getitem__(specmol_id)
        self.smiles_to_spec_couter[mol] = counter + 1
        return item

    @staticmethod
    def collate_fn(batch: T.Iterable[dict], spec_enc: str, spectra_view: str, stage=None, mask_peak_ratio: float = 0.0, aug_cands: bool = False) -> dict:
        """Collate items into a batch.

        Args:
            batch: Indexable collection of item dicts.
            spec_enc: Spectrum encoder name; selects the padding value.
            spectra_view: Item key holding the spectrum. Views whose name
                contains ``Formula`` or ``Tokens`` are variable-length and padded.
            stage: ``Stage.TEST`` batches the ``cand`` graphs, anything else
                the ``mol`` graphs. Peak masking only happens for ``Stage.TRAIN``.
            mask_peak_ratio: Fraction of each padded spectrum's peaks that is
                overwritten with ``-5.0`` during training.
            aug_cands: Also batch every item's ``aug_cands`` graphs.

        Returns:
            dict of collated entries.
        """
        mol_key = 'cand' if stage == Stage.TEST else 'mol'
        non_standard_collate = ['mol', 'cand', 'aug_cands', 'cons_spec', 'aug_cands_fp', 'NL_spec']
        require_pad = 'Formula' in spectra_view or 'Tokens' in spectra_view
        padding_value = 0
        if require_pad:
            if spec_enc in ('Transformer_Formula', 'Formula_BinnedSpec', 'Transformer_MzInt'):
                padding_value = -5
            non_standard_collate.append(spectra_view)
        else:
            # Fixed-size spectra: the auxiliary views go through default_collate.
            non_standard_collate.remove('cons_spec')
            non_standard_collate.remove('NL_spec')

        collated_batch = {}
        # standard collate
        for k in batch[0].keys():
            if k not in non_standard_collate:
                collated_batch[k] = default_collate([item[k] for item in batch])

        # batch graphs
        batch_mol = [item[mol_key] for item in batch]
        collated_batch[mol_key] = dgl.batch(batch_mol)
        collated_batch['mol_n_nodes'] = [g.num_nodes() for g in batch_mol]

        def _pad(key):
            """Pad the variable-length tensors stored under ``key``; also return their lengths."""
            seqs = [item[key] for item in batch]
            padded = pad_sequence(seqs, batch_first=True, padding_value=padding_value)
            return padded, [len(seq) for seq in seqs]

        # Peak counts of the main view are kept under their own name: the
        # masking step below must not pick up the cons/NL counts instead.
        spec_n_peaks = None
        if require_pad:
            collated_batch[spectra_view], spec_n_peaks = _pad(spectra_view)
            collated_batch['n_peaks'] = spec_n_peaks

            # Auxiliary views are only padded here; without padding they were
            # already handled by default_collate above.
            if 'cons_spec' in batch[0]:
                collated_batch['cons_spec'], collated_batch['cons_n_peaks'] = _pad('cons_spec')

            if 'NL_spec' in batch[0]:
                collated_batch['NL_spec'], collated_batch['NL_n_peaks'] = _pad('NL_spec')

        # mask peaks (only defined for padded, per-peak views)
        if mask_peak_ratio > 0.0 and stage == Stage.TRAIN and spec_n_peaks is not None:
            n_mask_peaks = [math.floor(n_peak * mask_peak_ratio) for n_peak in spec_n_peaks]
            mask_peak_idx = [np.random.choice(n_peak, n_mask, replace=False) for n_peak, n_mask in zip(spec_n_peaks, n_mask_peaks)]
            # Iterating the batch tensor yields views, so this edits it in place.
            for i, peaks in enumerate(collated_batch[spectra_view]):
                peaks[mask_peak_idx[i]] = -5.0

        # batch candidates
        if aug_cands:
            candidates = list(itertools.chain.from_iterable(item["aug_cands"] for item in batch))
            collated_batch['aug_cands'] = dgl.batch(candidates)

        if 'aug_cands_fp' in batch[0]:
            cand_fp = [item['aug_cands_fp'] for item in batch]
            collated_batch['aug_cands_fp'] = torch.flatten(torch.Tensor(cand_fp), end_dim=1)

        return collated_batch
320
+
321
+
322
+
323
class ExpandedRetrievalDataset:
    """Test-only dataset that yields one item per (spectrum, candidate) pair.

    Assumes the ``fold`` column defines the split. Attributes not found on
    this object are looked up on the wrapped dataset (``self.instance``).

    Args:
        use_formulas: Wrap ``MassSpecDataset_PeakFormulas`` when true,
            otherwise ``JESTR1_MassSpecDataset``.
        mol_label_transform: Stored on the instance; not used in this class.
        candidates_pth: JSON file mapping a key (SMILES, or the identifier for
            external data) to a list of candidate SMILES.
        fp_size: Morgan fingerprint size, used when fingerprints are enabled.
        fp_radius: Morgan fingerprint radius, used when fingerprints are enabled.
        external_test: Target SMILES are unknown; every label is ``False``.
        **kwargs: Forwarded to the wrapped dataset's constructor.

    Raises:
        KeyError: If a test entry has no key in the candidates file.
    """

    def __init__(self,
                 use_formulas: bool = True,
                 mol_label_transform: MolTransform = MolToInChIKey(),
                 candidates_pth: T.Optional[T.Union[Path, str]] = None,
                 fp_size: int = None,
                 fp_radius: int = None,
                 external_test: bool = False,
                 **kwargs):

        self.external_test = external_test

        dataset_cls = MassSpecDataset_PeakFormulas if use_formulas else JESTR1_MassSpecDataset
        self.instance = dataset_cls(**kwargs, return_mol_freq=False)

        if self.use_fp:
            self.fpgen = AllChem.GetMorganGenerator(radius=fp_radius, fpSize=fp_size)

        self.candidates_pth = candidates_pth
        self.mol_label_transform = mol_label_transform

        # Read candidates_pth from json to dict: key -> respective candidate SMILES
        with open(self.candidates_pth, "r") as file:
            raw_candidates = json.load(file)

        # Drop multi-fragment candidates (SMILES containing '.').
        self.candidates = {
            key: [c for c in cands if '.' not in c]
            for key, cands in raw_candidates.items()
        }

        self.spec_cand = []  # (spec index, cand_smiles, true_label)

        # For external datasets the target SMILES is unknown, so the
        # identifier stands in for it and self.candidates is keyed by identifier.
        if self.external_test or 'smiles' not in self.metadata.columns:
            if not isinstance(self.metadata.iloc[0]['identifier'], str):
                self.metadata['smiles'] = self.metadata['identifier'].apply(str)
            else:
                self.metadata['smiles'] = self.metadata['identifier']

        test_rows = self.metadata[self.metadata['fold'] == "test"]
        test_smiles = test_rows['smiles'].tolist()
        test_ms_id = test_rows['identifier'].tolist()

        spec_id_to_index = dict(zip(self.metadata['identifier'], self.metadata.index))
        for spec_id, s in zip(test_ms_id, test_smiles):
            cands = self.candidates[s]
            if not self.external_test:
                labels = [c == s for c in cands]

                if len(cands) == 0:
                    print(f"Skipping {spec_id}; empty candidate set")
                    continue
                if not any(labels):
                    print("Target smiles not in candidate set")
            else:
                labels = [False] * len(cands)

            spec_index = spec_id_to_index[spec_id]
            self.spec_cand.extend((spec_index, cand, label) for cand, label in zip(cands, labels))

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped dataset.

        ``instance`` is read straight from ``__dict__``: going through normal
        attribute access would re-enter this method and recurse forever when
        the object has no ``instance`` yet (unpickling, copying, or a failure
        early in ``__init__``).
        """
        try:
            wrapped = self.__dict__['instance']
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)

    def __len__(self):
        return len(self.spec_cand)

    def __getitem__(self, i):
        """Return the spectrum item of pair ``i`` extended with its candidate.

        Adds ``cand`` (transformed candidate), ``cand_smiles`` and ``label``;
        when fingerprints are enabled, ``fp`` is the candidate's Morgan
        fingerprint.
        """
        spec_i, cand_smiles, label = self.spec_cand[i]

        item = self.instance.__getitem__(spec_i, transform_mol=False)
        item['cand'] = self.mol_transform(cand_smiles)
        item['cand_smiles'] = cand_smiles
        item['label'] = label

        if self.use_fp:
            item['fp'] = torch.Tensor(self.fpgen.GetFingerprint(Chem.MolFromSmiles(cand_smiles)).ToList())

        return item
mvp/data/transforms.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import matchms
4
+ from typing import Optional
5
+ from rdkit.Chem import AllChem as Chem
6
+ from mvp.definitions import CHEM_ELEMS_SMALL
7
+ from massspecgym.data.transforms import MolTransform, SpecTransform, default_matchms_transforms
8
+ from massspecgym.data.transforms import SpecBinner
9
+
10
+ import dgllife.utils as chemutils
11
+ import re
12
+
13
class SpecBinnerLog(SpecTransform):
    """Bin a spectrum into fixed-width m/z bins and log-scale the intensities.

    Args:
        max_mz: Upper m/z limit; peaks above it are dropped.
        bin_width: Width of one bin; must divide ``max_mz``.

    Raises:
        ValueError: If ``max_mz`` is not divisible by ``bin_width``.
    """

    def __init__(
        self,
        max_mz: float = 1005,
        bin_width: float = 1,
    ) -> None:
        self.max_mz = max_mz
        self.bin_width = bin_width
        if not (max_mz / bin_width).is_integer():
            raise ValueError("`max_mz` must be divisible by `bin_width`.")

    def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
        """Apply the default matchms filters up to ``max_mz`` without a peak limit."""
        return default_matchms_transforms(spec, mz_to=self.max_mz, n_max_peaks=None)

    def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
        """Bin the spectrum into a fixed number of bins and return a float32 tensor."""
        binned_spec = self._bin_mass_spectrum(
            mzs=spec.peaks.mz,
            intensities=spec.peaks.intensities,
            max_mz=self.max_mz,
            bin_width=self.bin_width,
        )
        return torch.from_numpy(binned_spec).to(dtype=torch.float32)

    def _bin_mass_spectrum(
        self, mzs, intensities, max_mz, bin_width
    ):
        """Sum intensities per bin, rescale to a maximum of 999 and log-compress.

        Returns:
            1-D array of length ``ceil(max_mz / bin_width)`` holding
            ``log10(x + 1) / 3`` of the rescaled bin sums. An input without
            any valid peak yields all zeros.
        """
        num_bins = int(np.ceil(max_mz / bin_width))

        # Shift m/z by one before dividing by the bin width. The previous
        # expression ``mzs - 1 / bin_width`` bound the division to the
        # constant only; both forms agree for the default ``bin_width == 1``.
        bin_indices = np.floor((mzs - 1) / bin_width).astype(int)

        # Drop peaks above max_mz and clip the rest into the valid bin range.
        in_range = mzs <= max_mz
        valid_indices = np.clip(bin_indices[in_range], 0, num_bins - 1)
        valid_intensities = intensities[in_range]

        # np.add.at accumulates correctly when several peaks share a bin.
        binned_intensities = np.zeros(num_bins)
        np.add.at(binned_intensities, valid_indices, valid_intensities)

        # Guard the normalisation: an empty spectrum would give 0/0 = NaN.
        peak = np.max(binned_intensities)
        if peak > 0:
            binned_intensities = binned_intensities / peak * 999

        return np.log10(binned_intensities + 1) / 3
67
+
68
class SpecFormulaFeaturizer(SpecTransform):
    """Encode each peak by the element counts of its formula.

    Uses the already processed m/z values and intensities; the m/z values are
    only used to filter peaks and are not part of the output.

    Args:
        add_intensities: Append the peak intensity as an extra last column.
        max_mz: Peaks with m/z above this value are dropped.
        element_list: Elements that get a column, in column order.
        formula_normalize_vector: Per-column divisor applied to the element
            counts; defaults to all ones.
    """

    def __init__(
        self,
        add_intensities: bool,
        max_mz: float = 1005,
        element_list: list = CHEM_ELEMS_SMALL,
        formula_normalize_vector: Optional[np.ndarray] = None
    ) -> None:
        self.max_mz = max_mz
        self.elem_to_pos = {e: i for i, e in enumerate(element_list)}
        self.add_intensities = add_intensities
        if formula_normalize_vector is None:
            formula_normalize_vector = np.ones(len(element_list))
        self.formula_normalize_vector = formula_normalize_vector
        # Element symbol followed by an optional count, e.g. "C6", "H", "Cl2".
        self.CHEM_FORMULA_SIZE = r"([A-Z][a-z]*)([0-9]*)"

    def matchms_transforms(self, spec: matchms.Spectrum):
        """Return the spectrum unchanged."""
        return spec

    def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
        """Return a float32 tensor with one row of element counts per kept peak."""
        mzs = spec.peaks.mz
        intensities = spec.peaks.intensities
        # Index through an array so a plain list of formulas works as well.
        formulas = np.asarray(spec.metadata['formulas'])

        peak_idx = np.where(mzs <= self.max_mz)[0]
        intensities = intensities[peak_idx]
        formulas = formulas[peak_idx]

        spec = self._featurize_formula(formulas)
        spec = spec / self.formula_normalize_vector

        if self.add_intensities:
            spec = np.concatenate((spec, intensities.reshape(-1, 1)), axis=1)
        spec = spec.astype(np.float32)

        return torch.from_numpy(spec)

    def _featurize_formula(self, formulas):
        """Convert formula strings into an ``(n_formulas, n_elements)`` count matrix.

        Unsupported elements are reported and skipped; entries that are not
        strings are reported and left as an all-zero row.
        """
        formula_vector = np.zeros((len(formulas), len(self.elem_to_pos)))
        pattern = re.compile(self.CHEM_FORMULA_SIZE)
        for i, f in enumerate(formulas):
            try:
                tokens = pattern.findall(f)
            except TypeError:
                # e.g. None / NaN instead of a formula string
                print(f"Couldn't vectorize {f}, formula not supported")
                continue
            for (e, ct) in tokens:
                ct = 1 if ct == "" else int(ct)
                try:
                    formula_vector[i][self.elem_to_pos[e]] += ct
                except KeyError:
                    print(f"Couldn't vectorize {f}, element {e} not supported")
                    continue
        return formula_vector
121
+
122
class MolToGraph(MolTransform):
    """Convert a SMILES string into a DGL bigraph with atom and bond features.

    Args:
        atom_feature: ``'light'``, ``'medium'`` or ``'full'`` atom feature set.
        bond_feature: ``'light'`` or ``'full'`` bond feature set; any other
            value results in no edge featurizer.
        element_list: Elements used for the one-hot atom type encoding.

    Raises:
        ValueError: If ``atom_feature`` is not one of the supported modes.
    """

    # NOTE(review): width of dgllife's default ``bond_type_one_hot`` encoding
    # (single, double, triple, aromatic) -- confirm against the installed version.
    _N_BOND_TYPES = 4

    def __init__(self, atom_feature: str = "full", bond_feature: str = "full", element_list: list = CHEM_ELEMS_SMALL):
        self.atom_feature = atom_feature
        self.bond_feature = bond_feature
        self.node_featurizer = self._get_atom_featurizer(element_list=element_list)
        self.edge_featurizer = self._get_bond_featurizer()

    def from_smiles(self, mol: str):
        """Parse the SMILES and build the graph with self loops and no virtual nodes."""
        mol = Chem.MolFromSmiles(mol)
        g = chemutils.mol_to_bigraph(mol, node_featurizer=self.node_featurizer, edge_featurizer=self.edge_featurizer, add_self_loop=True,
                                     num_virtual_nodes=0, canonical_atom_order=False)
        return g

    def _get_atom_featurizer(self, element_list):
        """Build the atom featurizer: ``'h'`` holds the features, ``'m'`` the atom mass."""
        feature_mode = self.atom_feature
        atom_mass_fun = chemutils.ConcatFeaturizer(
            [chemutils.atom_mass]
        )
        n_bond_types = self._N_BOND_TYPES

        def atom_bond_type_one_hot(atom):
            """Flag, per bond type, whether the atom takes part in such a bond."""
            bs = atom.GetBonds()
            if not bs:
                # An atom without bonds (e.g. a single-heavy-atom molecule)
                # previously raised IndexError on the empty array.
                return [False] * n_bond_types
            bt = np.array([chemutils.bond_type_one_hot(b) for b in bs])
            return [any(bt[:, i]) for i in range(bt.shape[1])]

        def atom_type_one_hot(atom):
            return chemutils.atom_type_one_hot(
                atom, allowable_set=element_list, encode_unknown=True
            )

        if feature_mode == 'light':
            atom_featurizer_funs = chemutils.ConcatFeaturizer([
                chemutils.atom_mass,
                atom_type_one_hot
            ])
        elif feature_mode == 'full':
            atom_featurizer_funs = chemutils.ConcatFeaturizer([
                chemutils.atom_mass,
                atom_type_one_hot,
                atom_bond_type_one_hot,
                chemutils.atom_degree_one_hot,
                chemutils.atom_total_degree_one_hot,
                chemutils.atom_explicit_valence_one_hot,
                chemutils.atom_implicit_valence_one_hot,
                chemutils.atom_hybridization_one_hot,
                chemutils.atom_total_num_H_one_hot,
                chemutils.atom_formal_charge_one_hot,
                chemutils.atom_num_radical_electrons_one_hot,
                chemutils.atom_is_aromatic_one_hot,
                chemutils.atom_is_in_ring_one_hot,
                chemutils.atom_chiral_tag_one_hot
            ])
        elif feature_mode == 'medium':
            atom_featurizer_funs = chemutils.ConcatFeaturizer([
                chemutils.atom_mass,
                atom_type_one_hot,
                atom_bond_type_one_hot,
                chemutils.atom_total_degree_one_hot,
                chemutils.atom_total_num_H_one_hot,
                chemutils.atom_is_aromatic_one_hot,
                chemutils.atom_is_in_ring_one_hot,
            ])
        else:
            # Previously surfaced as an UnboundLocalError further down.
            raise ValueError(
                f"Unsupported atom_feature {feature_mode!r}; expected 'light', 'medium' or 'full'"
            )
        return chemutils.BaseAtomFeaturizer(
            {"h": atom_featurizer_funs,
             "m": atom_mass_fun}
        )

    def _get_bond_featurizer(self, self_loop=True):
        """Build the bond featurizer writing to edge field ``'e'``.

        Returns ``None`` for modes other than ``'light'`` and ``'full'``.
        """
        feature_mode = self.bond_feature
        if feature_mode == 'light':
            return chemutils.BaseBondFeaturizer(
                featurizer_funcs={'e': chemutils.ConcatFeaturizer([
                    chemutils.bond_type_one_hot
                ])}, self_loop=self_loop
            )
        elif feature_mode == 'full':
            return chemutils.CanonicalBondFeaturizer(
                bond_data_field='e', self_loop=self_loop
            )
        return None
mvp/data_preprocess.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from mvp.utils.preprocessing import generate_cons_spec_formulas, generate_cons_spec
3
+ import os
4
+ import pickle
5
+ import pandas as pd
6
+ from rdkit.Chem import AllChem
7
+ from rdkit import Chem
8
+ from tqdm import tqdm
9
+
10
# Command-line interface of the preprocessing script.
parser = argparse.ArgumentParser()
parser.add_argument("--spec_type", choices=('formSpec', 'binnedSpec'), required=True)
# Mandatory path arguments share the same shape, so register them from a table.
for _flag, _help_text in (
    ("--dataset_pth", "path to spectra data"),
    ("--candidates_pth", "path to candidates data"),
    ("--output_dir", "path to output directory"),
):
    parser.add_argument(_flag, required=True, help=_help_text)
parser.add_argument(
    "--subformula_dir_pth",
    default='',
    help="path to subformula directory if using formSpec",
)
16
+
17
+
18
def check_args():
    """Create the output directory and verify that the input paths exist.

    Reads the module-level ``args`` namespace. Explicit exceptions are raised
    instead of ``assert`` so the checks still run under ``python -O`` and the
    failing path is reported.

    Raises:
        NotADirectoryError: ``--subformula_dir_pth`` is not a directory while
            ``--spec_type`` is ``formSpec``.
        FileNotFoundError: The dataset or candidates path does not exist.
    """
    # create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # check files
    if args.spec_type == 'formSpec' and not os.path.isdir(args.subformula_dir_pth):
        raise NotADirectoryError(
            f"--subformula_dir_pth is not a directory: {args.subformula_dir_pth!r}"
        )

    for option, path in (
        ("--dataset_pth", args.dataset_pth),
        ("--candidates_pth", args.candidates_pth),
    ):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{option} does not exist: {path!r}")
29
+
30
def construct_smiles_to_fp(smiles_list, r=5, fp_size=1024):
    """Compute Morgan fingerprints for the given SMILES and pickle the mapping.

    The mapping is written to ``morganfp_r{r}_{fp_size}.pickle`` inside
    ``args.output_dir`` (module-level ``args``). SMILES that cannot be parsed
    or fingerprinted are skipped and counted.

    Args:
        smiles_list: SMILES strings to process.
        r: Morgan radius.
        fp_size: Fingerprint length in bits.
    """
    fpgen = AllChem.GetMorganGenerator(radius=r, fpSize=fp_size)
    smiles_to_fp = {}
    failed_ct = 0

    for s in tqdm(smiles_list, total=len(smiles_list)):
        # Best-effort per molecule, but without a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still stop the run.
        try:
            mol = Chem.MolFromSmiles(s)
            if mol is None:  # RDKit signals an unparsable SMILES with None
                failed_ct += 1
                continue
            smiles_to_fp[s] = fpgen.GetFingerprint(mol)
        except Exception:
            failed_ct += 1
    print(f'Failed to generate fingerprints for {failed_ct} smiles')

    # save smiles_to_fp
    with open(os.path.join(args.output_dir, f'morganfp_r{r}_{fp_size}.pickle'), 'wb') as f:
        pickle.dump(smiles_to_fp, f)
47
+
48
def construct_consensus_spectra():
    """Generate consensus spectra for ``args.spec_type`` and pickle the result.

    The output is written to ``consensus_{spec_type}.pkl`` inside
    ``args.output_dir`` (module-level ``args``).
    """
    spec_type = args.spec_type
    out_dir = args.output_dir

    if spec_type == 'formSpec':
        df = generate_cons_spec_formulas(args.dataset_pth, args.subformula_dir_pth, out_dir)
    elif spec_type == 'binnedSpec':
        df = generate_cons_spec(args.dataset_pth, out_dir)

    # save consensus spectra df
    target = os.path.join(out_dir, f'consensus_{args.spec_type}.pkl')
    with open(target, 'wb') as handle:
        pickle.dump(df, handle)
57
+
58
def main(data):
    """Run both preprocessing steps.

    Args:
        data: DataFrame with a ``smiles`` column.
    """
    # generate fingerprints
    print("Processing fingerprints...")
    unique_smiles = data['smiles'].unique().tolist()
    construct_smiles_to_fp(unique_smiles)

    # generate consensus spectra
    print("Processing consensus spectra...")
    construct_consensus_spectra()
68
+
69
+
70
if __name__ == '__main__':
    # Without ``__file__`` in the module globals an empty argument list is
    # parsed instead of ``sys.argv``.
    _cli_argv = None if "__file__" in globals() else []
    args = parser.parse_args(_cli_argv)

    check_args()

    # load data
    data = pd.read_csv(args.dataset_pth, sep='\t')

    main(data)
mvp/definitions.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Global variables used across the package."""
2
+ import pathlib
3
+
4
# Directory layout, resolved relative to this file.
ROOT_DIR = pathlib.Path(__file__).absolute().parent
REPO_DIR = ROOT_DIR.parent
DATA_DIR = REPO_DIR / 'data'
TEST_RESULTS_DIR = REPO_DIR / 'experiments'
ASSETS_DIR = REPO_DIR / 'assets'

# Chemical elements
CHEM_ELEMS_SMALL = [
    'H', 'C', 'O', 'N', 'P', 'S', 'Cl',
    'F', 'Br', 'I', 'B', 'As', 'Si', 'Se',
]

# NOTE(review): has as many entries as CHEM_ELEMS_SMALL, presumably one
# normalisation constant per element in the same order -- confirm.
MSGYM_FORMULA_VECTOR_NORM = [
    102.0, 59.0, 25.0, 13.0, 3.0, 6.0, 6.0,
    17.0, 4.0, 4.0, 1.0, 1.0, 5.0, 2.0,
]

# MSGYM standardization
MSGYM_STANDARD_MH = {
    'mz_mean': 195.155185,
    'mz_std': 127.591549,
}
# Values obtained from Yinkai.
MSGYM_STANDARD_all = {
    "mz_mean": 80.88304948022557,
    "mz_std": 197.4588028571758,
}
mvp/models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
import os
import sys

# Make a local MassSpecGym checkout importable. The location defaults to the
# original developer path and can be overridden with the MASSSPECGYM_PATH
# environment variable; the entry is added only once.
_MASSSPECGYM_PATH = os.environ.get("MASSSPECGYM_PATH", "/data/yzhouc01//MassSpecGym")
if _MASSSPECGYM_PATH not in sys.path:
    sys.path.insert(0, _MASSSPECGYM_PATH)
from massspecgym.models import *
mvp/models/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/mvp/models/__pycache__/__init__.cpython-311.pyc and b/mvp/models/__pycache__/__init__.cpython-311.pyc differ
 
mvp/models/__pycache__/contrastive.cpython-311.pyc CHANGED
Binary files a/mvp/models/__pycache__/contrastive.cpython-311.pyc and b/mvp/models/__pycache__/contrastive.cpython-311.pyc differ
 
mvp/models/__pycache__/encoders.cpython-311.pyc CHANGED
Binary files a/mvp/models/__pycache__/encoders.cpython-311.pyc and b/mvp/models/__pycache__/encoders.cpython-311.pyc differ
 
mvp/models/__pycache__/mol_encoder.cpython-311.pyc CHANGED
Binary files a/mvp/models/__pycache__/mol_encoder.cpython-311.pyc and b/mvp/models/__pycache__/mol_encoder.cpython-311.pyc differ
 
mvp/models/__pycache__/spec_encoder.cpython-311.pyc CHANGED
Binary files a/mvp/models/__pycache__/spec_encoder.cpython-311.pyc and b/mvp/models/__pycache__/spec_encoder.cpython-311.pyc differ
 
mvp/models/contrastive.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as T
2
+ import torch
3
+ import torch.nn as nn
4
+ import pandas as pd
5
+ from collections import defaultdict
6
+ import numpy as np
7
+ import os
8
+ from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
9
+ from massspecgym.models.base import Stage
10
+ from massspecgym import utils
11
+
12
+ from mvp.utils.loss import contrastive_loss, cand_spec_sim_loss, fp_loss, cons_spec_loss
13
+ import mvp.utils.models as model_utils
14
+
15
+ import torch.nn.functional as F
16
+
17
+
18
class ContrastiveModel(RetrievalMassSpecGymModel):
    """Contrastive spectrum/molecule retrieval model.

    A spectrum encoder and a molecule encoder are trained with
    ``contrastive_loss``; optional auxiliary terms add a fingerprint
    prediction loss (``pred_fp``) and a consensus-spectrum loss
    (``use_cons_spec``). At test time each candidate is scored by cosine
    similarity against its spectrum encoding and the per-identifier results
    are accumulated in ``self.result_dct``.

    NOTE(review): ``self.loss_wts`` and ``self.loss_updates`` are read by
    ``compute_loss`` / ``_update_loss_weights`` but are never assigned in this
    class; presumably they are attached by the code that builds the model --
    confirm against the caller.
    """

    def __init__(
        self,
        external_test=False,
        **kwargs,
    ):
        """Build the encoders and optional auxiliary heads from the hyperparameters.

        Args:
            external_test: stored on the instance as ``self.external_test``.
            **kwargs: forwarded to the base class and captured by
                ``save_hyperparameters``.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.external_test = external_test

        # Defaults for hyperparameter sets that do not define these keys.
        if 'use_fp' not in self.hparams:
            self.hparams.use_fp = False

        if 'loss_strategy' not in self.hparams:
            self.hparams.loss_strategy = 'static'
            self.hparams.contr_wt = 1.0
            self.hparams.use_contr = True

        self.spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
        self.mol_enc_model = model_utils.get_mol_encoder(self.hparams.mol_enc, self.hparams)

        if self.hparams.pred_fp:
            self.fp_loss = fp_loss(self.hparams.fp_loss_type)
            self.fp_pred_model = model_utils.get_fp_pred_model(self.hparams)
        if self.hparams.use_cons_spec:
            self.cons_spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
            self.cons_loss = cons_spec_loss(self.hparams.cons_loss_type)

        self.spec_view = self.hparams.spectra_view

        # result storage for testing results: identifier -> column name -> list
        self.result_dct = defaultdict(lambda: defaultdict(list))

    def forward(self, batch, stage):
        """Encode one batch and return ``(spec_enc, mol_enc)``.

        At test time the candidate graphs (``batch['cand']``) are encoded;
        otherwise the ground-truth molecules (``batch['mol']``). When
        ``use_cons_spec`` is set and the stage is not TEST, the consensus
        spectrum is encoded with the dedicated consensus encoder instead of
        the regular spectrum view.
        """
        g = batch['cand'] if stage == Stage.TEST else batch['mol']

        if self.hparams.use_cons_spec and stage != Stage.TEST:
            spec = batch['cons_spec']
            n_peaks = batch['cons_n_peaks'] if 'cons_n_peaks' in batch else None
            spec_enc = self.cons_spec_enc_model(spec, n_peaks)
        else:
            spec = batch[self.spec_view]
            n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
            spec_enc = self.spec_enc_model(spec, n_peaks)

        fp = batch['fp'] if self.hparams.use_fp else None
        mol_enc = self.mol_enc_model(g, fp=fp)

        return spec_enc, mol_enc

    def compute_loss(self, batch: dict, spec_enc, mol_enc, output):
        """Combine the weighted contrastive loss with the enabled auxiliary losses.

        Returns:
            dict with the total under ``'loss'`` plus one detached Python
            float per individual term (``'contr_loss'``, ``'cong_loss'``,
            ``'noncong_loss'`` and, when computed, ``'fp_loss'``,
            ``'aug_cand_loss'``, ``'cons_spec_loss'``).
        """
        loss = 0
        losses = {}

        contr_loss, cong_loss, noncong_loss = contrastive_loss(spec_enc, mol_enc, self.hparams.contr_temp)
        contr_loss = self.loss_wts['contr_wt'] * contr_loss
        losses['contr_loss'] = contr_loss.detach().item()
        losses['cong_loss'] = cong_loss.detach().item()
        losses['noncong_loss'] = noncong_loss.detach().item()
        loss += contr_loss

        if self.hparams.pred_fp:
            fp_loss_val = self.loss_wts['fp_wt'] * self.fp_loss(output['fp'], batch['fp'])
            loss += fp_loss_val
            losses['fp_loss'] = fp_loss_val.detach().item()

        if 'aug_cand_enc' in output:
            aug_cand_loss = self.loss_wts['aug_cand_wt'] * cand_spec_sim_loss(spec_enc, output['aug_cand_enc'])
            loss += aug_cand_loss
            losses['aug_cand_loss'] = aug_cand_loss.detach().item()

        if 'ind_spec' in output:
            spec_loss = self.loss_wts['cons_spec_wt'] * self.cons_loss(spec_enc, output['ind_spec'])
            loss += spec_loss
            losses['cons_spec_loss'] = spec_loss.detach().item()

        losses['loss'] = loss

        return losses

    def step(self, batch: dict, stage=Stage.NONE):
        """Run one forward pass.

        Returns the raw encodings at test time, otherwise the loss dict
        produced by ``compute_loss``.
        """
        # Compute spectra and mol encoding
        spec_enc, mol_enc = self.forward(batch, stage)

        if stage == Stage.TEST:
            return dict(spec_enc=spec_enc, mol_enc=mol_enc)

        # Aux tasks
        output = {}
        if self.hparams.pred_fp:
            output['fp'] = self.fp_pred_model(mol_enc)

        if self.hparams.use_cons_spec:
            # Encode the individual spectrum so it can be compared with the
            # consensus encoding returned by forward().
            spec = batch[self.spec_view]
            n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
            output['ind_spec'] = self.spec_enc_model(spec, n_peaks)

        # Calculate loss
        return self.compute_loss(batch, spec_enc, mol_enc, output)

    def _log_loss(self, stage: Stage, name: str, value, batch: dict, prog_bar: bool = False) -> None:
        """Log one loss value under ``<stage prefix><name>`` with epoch aggregation."""
        self.log(
            f'{stage.to_pref()}{name}',
            value,
            batch_size=len(batch['identifier']),
            sync_dist=True,
            prog_bar=prog_bar,
            on_epoch=True,
        )

    def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
        """Log the total loss and every enabled loss component for this batch."""
        # total loss
        self._log_loss(stage, 'loss', outputs['loss'], batch, prog_bar=True)

        # contrastive loss together with its congruent / non-congruent parts
        if self.hparams.use_contr:
            self._log_loss(stage, 'contr_loss', outputs['contr_loss'], batch)
            self._log_loss(stage, 'noncong_loss', outputs['noncong_loss'], batch)
            self._log_loss(stage, 'cong_loss', outputs['cong_loss'], batch)

        if self.hparams.pred_fp:
            # Previously logged as '<prefix>_fp_loss'; the extra underscore
            # was inconsistent with every other key logged here.
            self._log_loss(stage, 'fp_loss', outputs['fp_loss'], batch)

        if self.hparams.use_cons_spec:
            self._log_loss(stage, 'cons_loss', outputs['cons_spec_loss'], batch)

    def test_step(self, batch, batch_idx):
        """Score every candidate in the batch against its spectrum.

        Returns:
            dict with the unique identifiers (in first-seen order) and, per
            identifier, the cosine scores, candidate SMILES and labels.
        """
        # Number of candidate rows per identifier, in first-seen order.
        id_to_ct = defaultdict(int)
        for i in batch['identifier']:
            id_to_ct[i] += 1
        counts = list(id_to_ct.values())
        batch_ptr = torch.tensor(counts)

        outputs = self.step(batch, stage=Stage.TEST)
        spec_enc = outputs['spec_enc']
        mol_enc = outputs['mol_enc']

        # Calculate scores
        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)

        scores = nn.functional.cosine_similarity(spec_enc, mol_enc)
        scores = torch.split(scores, counts)

        cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
        labels = utils.unbatch_list(batch['label'], indexes)

        return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)

    def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
        """Append this batch's candidates, scores and labels to ``self.result_dct``."""
        # save scores
        for i, cands, scores, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['scores'], outputs['labels']):
            self.result_dct[i]['candidates'].extend(cands)
            self.result_dct[i]['scores'].extend(scores.cpu().tolist())
            self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])

    def _compute_rank(self, scores, labels):
        """Return the rank of the first labelled candidate, or -1 if none is labelled.

        The rank is the number of candidates whose score is greater than or
        equal to the target's score (ties count against the target).
        """
        # Force a boolean mask: a plain list of 0/1 integers would otherwise
        # be interpreted by NumPy as positions to gather, not as a mask.
        mask = np.asarray(labels, dtype=bool)
        if not mask.any():
            return -1
        scores = np.asarray(scores)
        target_score = scores[mask][0]
        return int(np.count_nonzero(scores >= target_score))

    def on_test_epoch_end(self) -> None:
        """Build the result dataframe, add the rank column and pickle it."""
        self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})

        # Compute rank
        self.df_test['rank'] = self.df_test.apply(lambda row: self._compute_rank(row['scores'], row['labels']), axis=1)
        if not self.df_test_path:
            self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
        self.df_test.to_pickle(self.df_test_path)

    def get_checkpoint_monitors(self) -> T.List[dict]:
        """Return the checkpoint monitors: train loss, minimized, no early stopping."""
        monitors = [
            {"monitor": f"{Stage.TRAIN.to_pref()}loss", "mode": "min", "early_stopping": False},  # monitor train loss
        ]
        return monitors

    def _update_loss_weights(self) -> None:
        """Advance the loss weights according to ``hparams.loss_strategy``.

        ``'linear'`` adds a fixed increment to each weight; ``'manual'``
        replaces a weight when the current epoch has an entry in its update
        table. Any other strategy leaves the weights unchanged.
        """
        if self.hparams.loss_strategy == 'linear':
            for loss in self.loss_wts:
                self.loss_wts[loss] += self.loss_updates[loss]
        elif self.hparams.loss_strategy == 'manual':
            for loss in self.loss_wts:
                if self.current_epoch in self.loss_updates[loss]:
                    self.loss_wts[loss] = self.loss_updates[loss][self.current_epoch]

    def on_train_epoch_end(self) -> None:
        """Update the loss weights once per training epoch."""
        self._update_loss_weights()
256
+
257
class MultiViewContrastive(ContrastiveModel):
    """Contrastive model over an arbitrary set of view pairs.

    ``forward`` returns a dict of named encodings (``'spec_enc'``,
    ``'mol_enc'`` and, when enabled, ``'fp_enc'``, ``'cons_spec_enc'``,
    ``'NL_spec_enc'``). ``hparams.contr_views`` lists the pairs of view names
    that are contrasted during training and scored at test time. Metric and
    result column names are built from the view names with their last four
    characters (the ``_enc`` suffix) removed.
    """

    def __init__(self,
                 **kwargs):
        """Build the base encoders and, if ``use_fp`` is set, the fingerprint encoder."""
        super().__init__(**kwargs)

        # build fingerprint encoder model
        if self.hparams.use_fp:
            self.fp_enc_model = model_utils.get_fp_enc_model(self.hparams)

        # NOTE(review): no neutral-loss encoder (``self.NL_enc_model``) is
        # constructed here although forward() uses one when ``use_NL_spec``
        # is enabled -- confirm where it is expected to come from.

    def forward(self, batch, stage):
        """Encode every enabled view of the batch and return them keyed by view name."""
        g = batch['cand'] if stage == Stage.TEST else batch['mol']

        spec = batch[self.spec_view]
        n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None

        views = {
            'spec_enc': self.spec_enc_model(spec, n_peaks),
            'mol_enc': self.mol_enc_model(g),
        }

        if self.hparams.use_fp:
            views['fp_enc'] = self.fp_enc_model(batch['fp'])

        if self.hparams.use_cons_spec:
            cons_n_peaks = batch['cons_n_peaks'] if 'cons_n_peaks' in batch else None
            views['cons_spec_enc'] = self.cons_spec_enc_model(batch['cons_spec'], cons_n_peaks)

        # Hyperparameter sets that do not define ``use_NL_spec`` are treated
        # as having the neutral-loss view disabled instead of raising.
        if self.hparams.get('use_NL_spec', False):
            nl_n_peaks = batch['NL_n_peaks'] if 'NL_n_peaks' in batch else None
            views['NL_spec_enc'] = self.NL_enc_model(batch['NL_spec'], nl_n_peaks)

        return views

    def step(self, batch: dict, stage=Stage.NONE):
        """Return the view encodings at test time, otherwise the loss dict."""
        # Compute spectra and mol encoding
        views = self.forward(batch, stage)

        if stage == Stage.TEST:
            return views

        # Calculate loss
        return self.compute_loss(batch, views)

    def compute_loss(self, batch: dict, views: dict):
        """Sum the (unweighted) contrastive loss over every configured view pair.

        Returns:
            dict with the total under ``'loss'`` and, per pair, detached
            floats ``'<a>-<b>_contr_loss'``, ``'<a>-<b>_cong_loss'`` and
            ``'<a>-<b>_noncong_loss'``.
        """
        loss = 0
        losses = {}
        for v1, v2 in self.hparams.contr_views:
            contr_loss, cong_loss, noncong_loss = contrastive_loss(views[v1], views[v2], self.hparams.contr_temp)
            loss += contr_loss

            pair = f'{v1[:-4]}-{v2[:-4]}'
            losses[f'{pair}_contr_loss'] = contr_loss.detach().item()
            losses[f'{pair}_cong_loss'] = cong_loss.detach().item()
            losses[f'{pair}_noncong_loss'] = noncong_loss.detach().item()

        losses['loss'] = loss

        return losses

    def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
        """Log the total loss and the three loss components of every view pair."""
        batch_size = len(batch['identifier'])

        # total loss
        self.log(
            f'{stage.to_pref()}loss',
            outputs['loss'],
            batch_size=batch_size,
            sync_dist=True,
            prog_bar=True,
            on_epoch=True,
        )

        for v1, v2 in self.hparams.contr_views:
            pair = f'{v1[:-4]}-{v2[:-4]}'
            for part in ('contr_loss', 'cong_loss', 'noncong_loss'):
                key = f'{pair}_{part}'
                self.log(
                    f'{stage.to_pref()}{key}',
                    outputs[key],
                    batch_size=batch_size,
                    sync_dist=True,
                    on_epoch=True,
                )

    def test_step(self, batch, batch_idx):
        """Score every candidate for each configured view pair.

        Returns:
            dict with the unique identifiers (first-seen order), a dict of
            per-pair score tuples under ``'scores'`` (keys
            ``'<a>-<b>_scores'``), and the per-identifier candidate SMILES
            and labels.
        """
        # Number of candidate rows per identifier, in first-seen order.
        id_to_ct = defaultdict(int)
        for i in batch['identifier']:
            id_to_ct[i] += 1
        counts = list(id_to_ct.values())
        batch_ptr = torch.tensor(counts)

        outputs = self.step(batch, stage=Stage.TEST)
        scores = {}
        for v1, v2 in self.hparams.contr_views:
            s = nn.functional.cosine_similarity(outputs[v1], outputs[v2])
            scores[f'{v1[:-4]}-{v2[:-4]}_scores'] = torch.split(s, counts)

        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)

        cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
        labels = utils.unbatch_list(batch['label'], indexes)

        return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)

    def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
        """Append this batch's candidates, labels and per-pair scores to ``self.result_dct``."""
        # save candidates and labels
        for i, cands, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['labels']):
            self.result_dct[i]['candidates'].extend(cands)
            self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])

        # save scores, one column per view pair
        for v1, v2 in self.hparams.contr_views:
            col = f'{v1[:-4]}-{v2[:-4]}_scores'
            for i, scores in zip(outputs['identifiers'], outputs['scores'][col]):
                self.result_dct[i][col].extend(scores.cpu().tolist())

    def _get_top_cand(self, scores, candidates):
        """Return the candidate with the highest score (first one on ties)."""
        return candidates[np.argmax(np.array(scores))]

    def on_test_epoch_end(self) -> None:
        """Build the result dataframe, add rank or top-candidate columns and pickle it.

        With ``external_test`` the ``labels`` column is dropped and the
        top-scoring candidate per view pair is stored; otherwise the rank of
        the labelled candidate is stored per view pair.
        """
        self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})

        if self.external_test:
            self.df_test.drop('labels', axis=1, inplace=True)
            for v1, v2 in self.hparams.contr_views:
                pair = f'{v1[:-4]}-{v2[:-4]}'
                self.df_test[f'top_{pair}_cand'] = self.df_test.apply(
                    lambda row: self._get_top_cand(row[f'{pair}_scores'], row['candidates']), axis=1)
        else:
            # Compute rank
            for v1, v2 in self.hparams.contr_views:
                pair = f'{v1[:-4]}-{v2[:-4]}'
                self.df_test[f'{pair}_rank'] = self.df_test.apply(
                    lambda row: self._compute_rank(row[f'{pair}_scores'], row['labels']), axis=1)

        # Same fallback as the parent class, so an unset path does not make
        # the pickle call fail after the whole test run has finished.
        if not self.df_test_path:
            self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
        self.df_test.to_pickle(self.df_test_path)
mvp/models/encoders.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
class MLP(nn.Module):
    """Stack of linear layers with ReLU and dropout between them.

    Args:
        in_dim: size of the last dimension of the input.
        hidden_dims: output size of each linear layer, in order; the last
            entry is the output size. Must contain at least one entry.
        dropout: dropout probability applied after each hidden ReLU.
        final_activation: ``None`` (no activation on the last layer) or one
            of ``'relu'``, ``'sigmoid'``, ``'softmax'``.

    Raises:
        KeyError: if ``final_activation`` is not one of the supported names.
    """

    def __init__(self, in_dim, hidden_dims, dropout=0.1, final_activation=None):
        super().__init__()

        self.dropout = nn.Dropout(dropout)
        layers = [nn.Linear(in_dim, hidden_dims[0])]
        for d1, d2 in zip(hidden_dims[:-1], hidden_dims[1:]):
            layers.append(nn.Linear(d1, d2))
        self.layers = nn.ModuleList(layers)

        self.has_final_activation = final_activation is not None
        # Extra keyword arguments passed to the final activation.
        self.final_activation_kwargs = {}
        if self.has_final_activation:
            self.final_activation = {
                'relu': F.relu,
                'sigmoid': F.sigmoid,
                'softmax': F.softmax,
            }[final_activation]
            if final_activation == 'softmax':
                # Normalize over the feature axis explicitly. Without `dim`,
                # softmax falls back to a deprecated implicit choice that
                # picks axis 0 for 3-D inputs.
                self.final_activation_kwargs = {'dim': -1}

    def forward(self, x):
        """Apply the layers; the last layer gets only the optional final activation."""
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = F.relu(x)
                x = self.dropout(x)
            elif self.has_final_activation:
                x = self.final_activation(x, **self.final_activation_kwargs)
        return x
mvp/models/mol_encoder.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import dgl
4
+ from dgllife.model import GCN, GAT
5
+
6
class MolEnc(nn.Module):
    """Graph neural network encoder for molecules.

    Node features are read from ``g.ndata['h']``, passed through a GCN or GAT
    (selected by ``args.gnn_type``), max-pooled over each graph and projected
    to ``args.final_embedding_dim`` by two linear layers. For the models
    named ``'crossAttenContrastive'`` and ``'filipContrastive'`` the raw GNN
    output is returned instead and no projection layers are created.

    Args:
        args: hyperparameter object providing ``model``, ``gnn_type``,
            ``gnn_channels``, ``gnn_dropout``, ``gnn_hidden_dim``,
            ``final_embedding_dim``, ``fc_dropout`` and, for GAT only,
            ``attn_heads``.
        in_dim: size of the input node features.

    Raises:
        KeyError: if ``args.gnn_type`` is neither ``'gcn'`` nor ``'gat'``.
    """

    def __init__(self,
                 args,
                 in_dim,):
        super().__init__()

        self.return_emb = args.model in ('crossAttenContrastive', 'filipContrastive')

        self.GNN = self._build_gnn(args, in_dim)
        self.pool = dgl.nn.pytorch.glob.MaxPooling()

        if not self.return_emb:
            self.fc1_graph = nn.Linear(args.gnn_channels[-1], args.gnn_hidden_dim * 2)
            self.fc2_graph = nn.Linear(args.gnn_hidden_dim * 2, args.final_embedding_dim)

        self.dropout = nn.Dropout(args.fc_dropout)
        self.relu = nn.ReLU()

    @staticmethod
    def _build_gnn(args, in_dim):
        """Construct only the GNN selected by ``args.gnn_type``.

        Building just the requested network avoids allocating a second,
        unused model and avoids requiring ``args.attn_heads`` for GCN.
        """
        if args.gnn_type == "gcn":
            n_layers = len(args.gnn_channels)
            return GCN(
                in_dim,
                args.gnn_channels,
                batchnorm=[True] * n_layers,
                dropout=[args.gnn_dropout] * n_layers,
            )
        if args.gnn_type == "gat":
            return GAT(in_dim, args.gnn_channels, args.attn_heads)
        raise KeyError(args.gnn_type)

    def forward(self, g, fp=None) -> torch.Tensor:
        """Encode a (batched) graph.

        Args:
            g: graph whose node features are stored under ``ndata['h']``.
            fp: optional tensor concatenated to the pooled graph vector
                before the projection layers.

        Returns:
            The GNN output when ``return_emb`` is set, otherwise the
            projected graph-level embedding.
        """
        node_feats = self.GNN(g, g.ndata['h'])
        if self.return_emb:
            return node_feats

        h = self.pool(g, node_feats)
        if fp is not None:
            # NOTE(review): fc1_graph is sized for gnn_channels[-1] inputs
            # only, so a non-empty `fp` looks incompatible with it --
            # confirm whether this path is ever used.
            h = torch.concat((h, fp), dim=-1)

        h1 = self.relu(self.fc1_graph(h))
        h1 = self.dropout(h1)
        h1 = self.fc2_graph(h1)
        h1 = self.dropout(h1)

        return h1
50
+
mvp/models/spec_encoder.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ from mvp.models.encoders import MLP
4
+ from torch_geometric.nn import global_mean_pool
5
+
6
+
7
class SpecEncMLP_BIN(nn.Module):
    """Three-layer MLP encoder for binned spectra.

    The input width is ``int(args.max_mz / args.bin_width)``. The two hidden
    layers are twice as wide as the output, whose size is ``out_dim`` or,
    when that is falsy, ``args.final_embedding_dim``. Dropout with
    probability ``args.fc_dropout`` follows every layer, including the last.
    """

    def __init__(self, args, out_dim=None):
        super().__init__()

        embed_dim = out_dim if out_dim else args.final_embedding_dim
        hidden_dim = embed_dim * 2
        n_bins = int(args.max_mz / args.bin_width)

        self.dropout = nn.Dropout(args.fc_dropout)
        self.mz_fc1 = nn.Linear(n_bins, hidden_dim)
        self.mz_fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mz_fc3 = nn.Linear(hidden_dim, embed_dim)
        self.relu = nn.ReLU()

    def forward(self, mzi_b, n_peaks=None):
        """Encode binned spectra; ``n_peaks`` is accepted but not used."""
        hidden = mzi_b
        for layer in (self.mz_fc1, self.mz_fc2):
            hidden = self.dropout(self.relu(layer(hidden)))
        return self.dropout(self.mz_fc3(hidden))
33
+
34
+
35
class SpecFormulaTransformer(nn.Module):
    """Transformer encoder over per-peak formula vectors.

    Each peak vector is embedded by an MLP, the sequence is passed through a
    two-layer transformer encoder, and the result is reduced to one vector
    per spectrum: the CLS token when ``args.use_cls`` is set, otherwise the
    mean over the non-padding peaks. A final linear layer maps to
    ``out_dim`` (``args.final_embedding_dim`` when ``out_dim`` is falsy).

    Peaks whose features are all equal to -5 are treated as padding.
    """

    def __init__(self, args, out_dim=None):
        super().__init__()
        in_dim = len(args.element_list)
        if args.add_intensities:  # intensity
            in_dim += 1
        if args.spectra_view == "SpecFormulaMz":  # mz
            in_dim += 1

        # When set to True, forward() returns per-peak embeddings instead of
        # a pooled, projected vector.
        self.returnEmb = False

        self.formulaEnc = MLP(in_dim=in_dim, hidden_dims=args.formula_dims, dropout=args.formula_dropout)

        self.use_cls = args.use_cls
        if args.use_cls:
            self.cls_embed = torch.nn.Embedding(1, args.formula_dims[-1])
        encoder_layer = nn.TransformerEncoderLayer(d_model=args.formula_dims[-1], nhead=2, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

        if not out_dim:
            out_dim = args.final_embedding_dim
        self.fc = nn.Linear(args.formula_dims[-1], out_dim)

    def forward(self, spec, n_peaks):
        """Encode a padded batch of spectra.

        Args:
            spec: per-peak feature vectors; rows filled with -5 mark padding.
            n_peaks: number of real peaks per spectrum, used to assign the
                non-padding peaks to spectra for mean pooling. When ``None``
                the counts are derived from the padding mask.

        Returns:
            One projected vector per spectrum, or the per-peak embeddings
            (padding positions reset to -5) when ``self.returnEmb`` is set.
        """
        h = self.formulaEnc(spec)
        pad = torch.all(spec == -5, dim=-1)

        if self.use_cls:
            n_spec = spec.shape[0]
            # Prepend the learned CLS embedding and a matching non-padding
            # entry in the mask.
            cls_tok = self.cls_embed.weight.expand(n_spec, -1).unsqueeze(1)
            h = torch.concat((cls_tok, h), dim=1)
            cls_pad = torch.zeros(n_spec, 1, dtype=torch.bool, device=spec.device)
            pad = torch.concat((cls_pad, pad), dim=1)

        h = self.transformer(h, src_key_padding_mask=pad)

        if self.returnEmb:
            # repad h
            return h.masked_fill(pad.unsqueeze(-1), -5)

        if self.use_cls:
            # The CLS position already summarizes the spectrum. It must not
            # go through the mask-based pooling below, which expects the
            # full token sequence.
            pooled = h[:, 0, :]
        else:
            valid = ~pad
            tokens = h[valid].reshape(-1, h.shape[-1])
            if n_peaks is None:
                n_peaks = valid.sum(dim=1).tolist()
            indices = torch.tensor([i for i, count in enumerate(n_peaks) for _ in range(count)]).to(h.device)
            pooled = global_mean_pool(tokens, indices)

        return self.fc(pooled)
84
+
85
+
mvp/params_binnedSpec.yaml ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Experiment setup
3
+ job_key: ''
4
+ run_name: 'binnedSpec_experiment'
5
+ run_details: ""
6
+ project_name: ''
7
+ wandb_entity_name: 'mass-spec-ml'
8
+ no_wandb: True
9
+ seed: 0
10
+ debug: False
11
+ checkpoint_pth: ""
12
+
13
+ # Training setup
14
+ max_epochs: 1000
15
+ accelerator: 'gpu'
16
+ devices: [1]
17
+ log_every_n_steps: 250
18
+ val_check_interval: 1.0
19
+
20
+ # Data paths
21
+ candidates_pth: ../data/sample/candidates_mass.json
22
+ dataset_pth: "../data/sample/data.tsv"
23
+ subformula_dir_pth: ""
24
+ split_pth:
25
+ fp_dir_pth: '../data/sample/morganfp_r5_1024.pickle'
26
+ cons_spec_dir_pth: "../data/sample/consensus_binnedSpec.pkl"
27
+ NL_spec_dir_pth: ""
28
+ partial_checkpoint: ""
29
+
30
+ # General hyperparameters
31
+ batch_size: 64
32
+ lr: 5.0e-4
33
+ weight_decay: 0
34
+ contr_temp: 0.05
35
+ early_stopping_patience: 300
36
+ loss_strategy: 'static' # static, linear, manual
37
+ num_workers: 50
38
+
39
+
40
+ ############################## Data transforms ##############################
41
+ # - Spectra
42
+ spectra_view: SpecBinnerLog
43
+ max_mz: 1000
44
+ bin_width: 1
45
+ mask_peak_ratio: 0.00
46
+
47
+ # 2. SpecFormula
48
+ element_list: ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']
49
+ add_intensities: True
50
+ mask_precursor: False
51
+
52
+ # - Molecule
53
+ molecule_view: "MolGraph"
54
+ atom_feature: 'full'
55
+ bond_feature: 'full'
56
+
57
+
58
+ ############################## Views ##############################
59
+ # contrastive
60
+ use_contr: True
61
+ contr_wt: 1
62
+ contr_wt_update: {}
63
+
64
+ # consensus spectra
65
+ use_cons_spec: False
66
+ cons_spec_wt: 3
67
+ cons_spec_wt_update: {}
68
+ cons_loss_type: 'l2' # cosine, l2
69
+
70
+ # fp prediction/usage
71
+ pred_fp: False
72
+ use_fp: False
73
+ fp_loss_type: 'cosine' #cosine, bce
74
+ fp_wt: 3
75
+ fp_wt_update: {}
76
+ fp_size: 1024
77
+ fp_radius: 5
78
+ fp_dropout: 0.4
79
+
80
+ # candidates
81
+ aug_cands: False
82
+ aug_cands_wt: 0.1
83
+ aug_cands_update: {}
84
+ aug_cands_size: 3
85
+
86
+ # neutral loss
87
+ use_NL: False
88
+
89
+
90
+
91
+ ############################## Task and model ##############################
92
+ task: 'retrieval'
93
+ spec_enc: MLP_BIN
94
+ mol_enc: "GNN"
95
+ model: "MultiviewContrastive"
96
+ contr_views: [['spec_enc', 'mol_enc']]
97
+ log_only_loss_at_stages: []
98
+ df_test_path: ""
99
+
100
+ # - Spectra encoder
101
+ final_embedding_dim: 512
102
+ fc_dropout: 0.4
103
+
104
+ # - Spectra Token encoder
105
+ hidden_dims: [64, 128]
106
+ peak_dropout: 0.2
107
+
108
+ # - Formula-based spec encoders
109
+ formula_dropout: 0.2
110
+ formula_dims: [64, 128, 256]
111
+ cross_attn_heads: 2
112
+ use_cls: True
113
+
114
+ # -- GAT params
115
+ attn_heads: [12,12,12]
116
+
117
+ # - Molecule encoder (GNN)
118
+ gnn_channels: [64,128,256]
119
+ gnn_type: "gcn"
120
+ num_gnn_layers: 3
121
+ gnn_hidden_dim: 512
122
+ gnn_dropout: 0.3
mvp/params_formSpec.yaml ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Experiment setup
2
+ job_key: ''
3
+ run_name: 'filip_large'
4
+ run_details: ""
5
+ project_name: ''
6
+ wandb_entity_name: 'mass-spec-ml'
7
+ no_wandb: True
8
+ seed: 0
9
+ debug: False
10
+ checkpoint_pth: #'../pretrained_models/msgym_formSpec.ckpt'
11
+
12
+ # Training setup
13
+ max_epochs: 2000
14
+ accelerator: 'gpu'
15
+ devices: [1]
16
+ log_every_n_steps: 250
17
+ val_check_interval: 1.0
18
+
19
+ # Data paths
20
+ candidates_pth: ../data/sample/candidates_mass.json
21
+ dataset_pth: ../data/MassSpecGym/data/sample_data.tsv
22
+ subformula_dir_pth: ../data/MassSpecGym/data/subformulae_default
23
+ split_pth:
24
+ fp_dir_pth: '../data/MassSpecGym/data/morganfp_r5_1024.pickle'
25
+ cons_spec_dir_pth: "../data/MassSpecGym/data/sample_consensus_formSpec.pkl"
26
+ NL_spec_dir_pth: ""
27
+ partial_checkpoint: ""
28
+
29
+ # General hyperparameters
30
+ batch_size: 64
31
+ lr: 5.0e-05
32
+ weight_decay: 0
33
+ contr_temp: 0.05
34
+ early_stopping_patience: 300
35
+ loss_strategy: 'static'
36
+ num_workers: 50
37
+
38
+
39
+ ############################## Data transforms ##############################
40
+ # - Spectra
41
+ spectra_view: SpecFormula
42
+ # 1. Binner
43
+ max_mz: 1000
44
+ bin_width: 1
45
+ mask_peak_ratio: 0.00
46
+
47
+ # 2. SpecFormula
48
+ element_list: ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']
49
+ add_intensities: True
50
+ mask_precursor: False
51
+
52
+ # - Molecule
53
+ molecule_view: "MolGraph"
54
+ atom_feature: 'full'
55
+ bond_feature: 'full'
56
+
57
+
58
+ ############################## Views ##############################
59
+ # contrastive
60
+ use_contr: False
61
+ contr_wt: 1
62
+ contr_wt_update: {}
63
+
64
+ # consensus spectra
65
+ use_cons_spec: False
66
+ cons_spec_wt: 3
67
+ cons_spec_wt_update: {}
68
+ cons_loss_type: 'l2' # cosine, l2
69
+
70
+ # fp prediction/usage
71
+ pred_fp: False
72
+ use_fp: False
73
+ fp_loss_type: 'cosine' #cosine, bce
74
+ fp_wt: 3
75
+ fp_wt_update: {}
76
+ fp_size: 1024
77
+ fp_radius: 5
78
+ fp_dropout: 0.4
79
+
80
+ # candidates
81
+ aug_cands: False
82
+ aug_cands_wt: 0.1
83
+ aug_cands_update: {}
84
+ aug_cands_size: 3
85
+
86
+ # neutral loss
87
+ use_NL: False
88
+
89
+
90
+ ############################## Task and model ##############################
91
+ task: 'retrieval'
92
+ spec_enc: Transformer_Formula
93
+ mol_enc: "GNN"
94
+ model: MultiviewContrastive
95
+ contr_views: [['spec_enc', 'mol_enc'], ['spec_enc', 'NL_spec_enc'], ['mol_enc', 'NL_spec_enc']] #[['spec_enc', 'mol_enc'], ['mol_enc', 'cons_spec_enc'], ['cons_spec_enc', 'spec_enc'], ['fp_enc', 'mol_enc'], ['fp_enc', 'spec_enc'], ['fp_enc', 'cons_spec_enc']]
96
+ log_only_loss_at_stages: []
97
+ df_test_path: ""
98
+
99
+ # - Spectra encoder
100
+ final_embedding_dim: 512
101
+ fc_dropout: 0.4
102
+
103
+ # - Spectra Token encoder
104
+ hidden_dims: [64, 128]
105
+ peak_dropout: 0.2
106
+
107
+ # - Formula-based spec encoders
108
+ formula_dropout: 0.2
109
+ formula_dims: [64, 128, 256]
110
+ cross_attn_heads: 2
111
+ use_cls: False
112
+
113
+ # -- GAT params
114
+ attn_heads: [12,12,12]
115
+
116
+ # - Molecule encoder (GNN)
117
+ gnn_channels: [64,128,256]
118
+ gnn_type: "gcn"
119
+ num_gnn_layers: 3
120
+ gnn_hidden_dim: 512
121
+ gnn_dropout: 0.3
mvp/run.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # 1. preprocess data (subformula labels should be obtained through MIST)
2
+ python subformula_assign/assign_subformulae.py --spec-files ../data/sample/data.tsv --output-dir ../data/sample/subformulae_default --max-formulae 60 --labels-file ../data/sample/data.tsv
3
+ python data_preprocess.py --spec_type formSpec --dataset_pth ../data/sample/data.tsv --candidates_pth ../data/sample/candidates_mass.json --subformula_dir_pth ../data/sample/subformulae_default/ --output_dir ../data/sample/
4
+
5
+ # 2. train model on msgym
6
+ python train.py --param_pth params_formSpec.yaml
7
+
8
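+ # 3. train model on msgym with the binned-spectrum config (this step runs train.py; use test.py --param_pth <yaml> to evaluate a trained checkpoint)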
9
+ python train.py --param_pth params_binnedSpec.yaml
mvp/test.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import datetime
import os
import sys

# Make the repository root importable so the `mvp` package resolves when this
# file is run as a script. `os` must be imported before this line.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from rdkit import RDLogger
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from massspecgym.models.base import Stage

from mvp.data.data_module import TestDataModule
from mvp.data.datasets import ContrastiveDataset
from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_test_ms_dataset
from mvp.utils.models import get_model

from mvp.definitions import TEST_RESULTS_DIR
import yaml
from functools import partial

# Suppress RDKit warnings and errors
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Command-line interface; values given here override the YAML parameters.
parser = argparse.ArgumentParser()
parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
parser.add_argument('--checkpoint_pth', type=str, default='')
parser.add_argument('--checkpoint_choice', type=str, default='train', choices=['train', 'val'])
parser.add_argument('--df_test_pth', type=str, help='result file name')
parser.add_argument('--exp_dir', type=str)
parser.add_argument('--candidates_pth', type=str)
31
def main(params):
    """Run the test stage for the model described by ``params``.

    Seeds the run, builds the featurizers, test dataset and data module,
    instantiates the model and calls ``Trainer.test``. ``params`` is
    modified in place when ``params['debug']`` is set.
    """
    pl.seed_everything(params['seed'])

    # Debug runs use the bundled sample data and a separate result file.
    if params['debug']:
        params['dataset_pth'] = "../data/sample/data.tsv"
        params['split_pth'] = None
        params['df_test_path'] = os.path.join(params['experiment_dir'], 'debug_result.pkl')

    spectra_view = params['spectra_view']
    molecule_view = params['molecule_view']

    # Featurizers and dataset
    spec_featurizer = get_spec_featurizer(spectra_view, params)
    mol_featurizer = get_mol_featurizer(molecule_view, params)
    dataset = get_test_ms_dataset(spectra_view, molecule_view, spec_featurizer, mol_featurizer, params)

    # Data module with the test-stage collate function
    collate_fn = partial(
        ContrastiveDataset.collate_fn,
        spec_enc=params['spec_enc'],
        spectra_view=spectra_view,
        stage=Stage.TEST,
    )
    data_module = TestDataModule(
        dataset=dataset,
        collate_fn=collate_fn,
        split_pth=params['split_pth'],
        batch_size=params['batch_size'],
        num_workers=params['num_workers'],
    )

    # Model; results are written to the configured path
    model = get_model(params['model'], params)
    model.df_test_path = params['df_test_path']

    trainer = Trainer(
        accelerator=params['accelerator'],
        devices=params['devices'],
        default_root_dir=params['experiment_dir'],
    )

    # Prepare the data module explicitly, then test
    data_module.prepare_data()
    data_module.setup(stage="test")
    trainer.test(model, datamodule=data_module)
72
+
73
+
74
if __name__ == "__main__":
    # In an interactive session (no __file__) ignore argv and use the defaults.
    args = parser.parse_args([] if "__file__" not in globals() else None)

    # Load run parameters.
    # NOTE(review): yaml.FullLoader can construct more than plain data; only
    # point --param_pth at trusted files.
    with open(args.param_pth) as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    # Experiment directory: explicit argument first, then an existing
    # directory whose name ends with "_<run_name>", then a new dated one.
    exp_dir = args.exp_dir
    if not exp_dir:
        run_name = params['run_name']
        if os.path.isdir(TEST_RESULTS_DIR):
            for exp in os.listdir(TEST_RESULTS_DIR):  # find exp dir with matching run_name
                if exp.endswith("_" + run_name):
                    exp_dir = str(TEST_RESULTS_DIR / exp)
                    break
        if not exp_dir:
            now = datetime.datetime.now().strftime("%Y%m%d")
            exp_dir = str(TEST_RESULTS_DIR / f"{now}_{run_name}")
    os.makedirs(exp_dir, exist_ok=True)
    print("EXPERIMENT directory: ",exp_dir)
    params['experiment_dir'] = exp_dir

    # Checkpoint path: command line overrides the YAML value.
    if args.checkpoint_pth:
        params['checkpoint_pth'] = args.checkpoint_pth

    if not params.get('checkpoint_pth'):
        print("No checkpoint provided. Using the checkpoint in the experiment directory")
        # Sorted so the choice is deterministic when several files match.
        for fname in sorted(os.listdir(exp_dir)):
            if fname.endswith("ckpt") and fname.startswith("epoch") and args.checkpoint_choice in fname:
                params['checkpoint_pth'] = os.path.join(exp_dir, fname)
                break

    # An empty YAML value loads as None, which a comparison against ''
    # would let through; fail explicitly for any missing checkpoint.
    if not params.get('checkpoint_pth'):
        raise FileNotFoundError(
            f"No checkpoint given and no '{args.checkpoint_choice}' checkpoint found in {exp_dir}"
        )

    if args.candidates_pth:
        params['candidates_pth'] = args.candidates_pth
    if args.df_test_pth:
        params['df_test_path'] = os.path.join(exp_dir, args.df_test_pth)
    if not params.get('df_test_path'):
        # Default result name is derived from the candidates file name.
        cand_name = params['candidates_pth'].split('/')[-1].split('.')[0]
        params['df_test_path'] = os.path.join(exp_dir, f"result_{cand_name}.pkl")

    main(params)
mvp/train.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+
4
+ import os
5
+ import sys
6
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7
+
8
+ from rdkit import RDLogger
9
+ import pytorch_lightning as pl
10
+ from pytorch_lightning import Trainer
11
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
12
+
13
+
14
+ from mvp.data.data_module import ContrastiveDataModule
15
+
16
+ from mvp.definitions import TEST_RESULTS_DIR
17
+ import yaml
18
+ from mvp.data.datasets import ContrastiveDataset
19
+ from functools import partial
20
+
21
+ from mvp.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
22
+ from mvp.utils.models import get_model
23
+ # Suppress RDKit warnings and errors
24
+ lg = RDLogger.logger()
25
+ lg.setLevel(RDLogger.CRITICAL)
26
+
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
29
+
30
def main(params):
    """Run one training job configured by the ``params`` mapping.

    Seeds the RNGs, builds the featurizers, dataset, data module and model,
    sets up the optional Weights & Biases logger plus checkpoint and
    early-stopping callbacks, then runs a validation pass followed by
    ``trainer.fit``.

    Args:
        params: Parsed parameter file; it is mutated in place when
            ``params['debug']`` is truthy.
    """
    # Seed everything
    pl.seed_everything(params['seed'])

    # Debug runs use the bundled sample data, with no candidates or split file.
    if params['debug']:
        params['dataset_pth'] = "../data/sample/data.tsv"
        params['candidates_pth'] = None
        params['split_pth'] = None

    # Load dataset
    spectra_view = params['spectra_view']
    molecule_view = params['molecule_view']
    spec_featurizer = get_spec_featurizer(spectra_view, params)
    mol_featurizer = get_mol_featurizer(molecule_view, params)
    dataset = get_ms_dataset(spectra_view, molecule_view, spec_featurizer, mol_featurizer, params)

    # Init data module
    collate_fn = partial(
        ContrastiveDataset.collate_fn,
        spec_enc=params['spec_enc'],
        spectra_view=params['spectra_view'],
        mask_peak_ratio=params['mask_peak_ratio'],
        aug_cands=params['aug_cands'],
    )
    data_module = ContrastiveDataModule(
        dataset=dataset,
        collate_fn=collate_fn,
        split_pth=params['split_pth'],
        batch_size=params['batch_size'],
        num_workers=params['num_workers'],
    )

    model = get_model(params['model'], params)

    # Init logger (disabled entirely when ``no_wandb`` is set)
    logger = None
    if not params['no_wandb']:
        experiment_dir = params['experiment_dir']
        logger = pl.loggers.WandbLogger(
            save_dir=experiment_dir,
            dir=experiment_dir,
            log_dir=experiment_dir,
            name=params['run_name'],
            project=params['project_name'],
            log_model=False,
            config=model.hparams
        )

    # Init callbacks for checkpointing and early stopping: one default
    # checkpoint callback, then one checkpoint (and optionally one early
    # stopping callback) per monitor reported by the model.
    callbacks = [pl.callbacks.ModelCheckpoint(save_last=False)]
    for idx, monitor_cfg in enumerate(model.get_checkpoint_monitors()):
        monitor_name = monitor_cfg['monitor']
        callbacks.append(
            pl.callbacks.ModelCheckpoint(
                monitor=monitor_name,
                save_top_k=1,
                mode=monitor_cfg['mode'],
                dirpath=params['experiment_dir'],
                filename=f'{{epoch}}-{{{monitor_name}:.2f}}',
                auto_insert_metric_name=True,
                # Only the first monitored checkpoint also keeps the last epoch.
                save_last=(idx == 0)
            )
        )
        if not monitor_cfg.get('early_stopping', False):
            continue
        callbacks.append(
            EarlyStopping(
                monitor=monitor_name,
                mode=monitor_cfg['mode'],
                verbose=True,
                patience=params['early_stopping_patience'],
            )
        )

    # Init trainer
    trainer = Trainer(
        accelerator=params['accelerator'],
        devices=params['devices'],
        max_epochs=params['max_epochs'],
        logger=logger,
        log_every_n_steps=params['log_every_n_steps'],
        val_check_interval=params['val_check_interval'],
        callbacks=callbacks,
        default_root_dir=params['experiment_dir'],
    )

    # Prepare data module to validate or test before training
    data_module.prepare_data()
    data_module.setup()

    # Validate before training
    trainer.validate(model, datamodule=data_module)

    # Train
    trainer.fit(model, datamodule=data_module)
117
+
118
+
119
+
120
if __name__ == "__main__":
    # Script entry point: read the parameter file, derive the dated
    # experiment directory and default result path, then start training.
    args = parser.parse_args([] if "__file__" not in globals() else None)

    # Date stamp (YYYYMMDD) used as the experiment directory prefix.
    now = datetime.datetime.now()
    now_formatted = now.strftime("%Y%m%d")

    # Load
    with open(args.param_pth) as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    experiment_dir = str(TEST_RESULTS_DIR / f"{now_formatted}_{params['run_name']}")
    params['experiment_dir'] = experiment_dir

    # Fall back to a result file inside the experiment directory.
    if not params['df_test_path']:
        params['df_test_path'] = os.path.join(experiment_dir, "result.pkl")

    main(params)
mvp/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import sys
2
+ sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
3
+ from massspecgym.utils import *
mvp/utils/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/mvp/utils/__pycache__/__init__.cpython-311.pyc and b/mvp/utils/__pycache__/__init__.cpython-311.pyc differ
 
mvp/utils/__pycache__/data.cpython-311.pyc CHANGED
Binary files a/mvp/utils/__pycache__/data.cpython-311.pyc and b/mvp/utils/__pycache__/data.cpython-311.pyc differ
 
mvp/utils/__pycache__/loss.cpython-311.pyc CHANGED
Binary files a/mvp/utils/__pycache__/loss.cpython-311.pyc and b/mvp/utils/__pycache__/loss.cpython-311.pyc differ
 
mvp/utils/__pycache__/models.cpython-311.pyc CHANGED
Binary files a/mvp/utils/__pycache__/models.cpython-311.pyc and b/mvp/utils/__pycache__/models.cpython-311.pyc differ
 
mvp/utils/data.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+
5
+ from mvp.data.transforms import SpecBinner, SpecBinnerLog, SpecFormulaFeaturizer
6
+ from massspecgym.data.transforms import SpecTransform, MolTransform
7
+ from mvp.data.transforms import MolToGraph
8
+ import mvp.data.datasets as jestr_datasets
9
+ import typing as T
10
+ from mvp.definitions import MSGYM_FORMULA_VECTOR_NORM
11
+ import matchms
12
+
13
class Subformula_Loader:
    """Load per-spectrum subformula annotations from ``<dir_path>/<spec_id>.json``.

    Files follow the MIST subformula format:
    https://github.com/samgoldman97/mist/blob/main_v2/src/mist/utils/spectra_utils.py

    ``spectra_view`` selects the record layout returned by ``load``:
    ``'SpecFormula'`` gives formulas as an array sorted by m/z, while
    ``'SpecFormulaMz'`` gives formulas as an ``{mz: formula}`` dict.
    """

    # Failures that mean "no usable annotation for this spectrum": missing or
    # unreadable file, malformed JSON, absent or null ``output_tbl``, a
    # non-string id, or ragged columns. Anything else (e.g. KeyboardInterrupt)
    # now propagates instead of being swallowed by a bare ``except``.
    _LOAD_ERRORS = (OSError, ValueError, KeyError, TypeError, IndexError)

    def __init__(self, spectra_view, dir_path) -> None:
        """Bind ``self.load`` to the loader matching ``spectra_view``.

        Raises:
            Exception: if ``spectra_view`` is neither ``'SpecFormula'`` nor
                ``'SpecFormulaMz'``.
        """
        self.dir_path = dir_path
        if spectra_view == 'SpecFormula':
            self.load = self.load_subformula_data
        elif spectra_view == "SpecFormulaMz":
            self.load = self.load_subformula_dict
        else:
            raise Exception("Spectra view is not supported.")

    def __call__(self, ids):
        """Return ``{spec_id: record}`` for every id whose file loaded successfully."""
        id_to_form_spec = {}
        for spec_id in ids:
            data = self.load(spec_id)
            if data:
                id_to_form_spec[spec_id] = data

        return id_to_form_spec

    def _read_table(self, spec_id):
        """Read one JSON file and return its (mzs, formulas, intensities) arrays, unsorted."""
        file = os.path.join(self.dir_path, spec_id + ".json")
        with open(file) as f:
            table = json.load(f)['output_tbl']
        mzs = np.array(table['mz'])
        formulas = np.array(table['formula'])
        intensities = np.array(table['ms2_inten'])
        return mzs, formulas, intensities

    def load_subformula_data(self, spec_id: str):
        """Return formulas, m/z values and intensities sorted by m/z, or ``None`` on failure."""
        try:
            mzs, formulas, intensities = self._read_table(spec_id)

            # sort by mzs
            ind = mzs.argsort()
            return {'formulas': formulas[ind], 'formula_mzs': mzs[ind], 'formula_intensities': intensities[ind]}
        except self._LOAD_ERRORS:
            return None

    def load_subformula_dict(self, spec_id: str):
        """Return an ``{mz: formula}`` dict plus m/z-sorted m/z and intensity arrays, or ``None`` on failure."""
        try:
            mzs, formulas, intensities = self._read_table(spec_id)

            # Built in file order so that, for a repeated m/z, the last entry
            # in the file wins.
            mz_to_formulas = dict(zip(mzs, formulas))

            ind = mzs.argsort()
            return {'formulas': mz_to_formulas, 'formula_mzs': mzs[ind], 'formula_intensities': intensities[ind]}
        except self._LOAD_ERRORS:
            return None
75
+
76
def make_tmp_subformula_spectra(row):
    """Build a single-peak placeholder subformula record from a data row.

    The record holds the row's ``formula`` at its ``precursor_mz`` (converted
    to ``float``) with intensity ``1.0``.
    """
    formula = row['formula']
    precursor_mz = float(row['precursor_mz'])
    return {
        'formulas': [formula],
        'formula_mzs': [precursor_mz],
        'formula_intensities': [1.0],
    }
78
+
79
def get_spec_featurizer(spectra_view: T.Union[str, list[str]],
                        params) -> T.Union[SpecTransform, T.Dict[str, SpecTransform]]:
    """Instantiate one spectrum featurizer per requested view.

    Args:
        spectra_view: A view name or a list of view names; supported names
            are ``"BinnedSpectra"``, ``"SpecBinnerLog"`` and ``"SpecFormula"``.
        params: Parameter mapping; ``max_mz`` is always read, the remaining
            keys depend on the view.

    Returns:
        A dict mapping each view name to its featurizer instance (a dict is
        returned even for a single view).

    Raises:
        KeyError: for a view name with no registered featurizer.
    """
    registry = {
        "BinnedSpectra": SpecBinner,
        "SpecBinnerLog": SpecBinnerLog,
        "SpecFormula": SpecFormulaFeaturizer,
    }

    views = [spectra_view] if isinstance(spectra_view, str) else spectra_view

    spectra_featurizer = {}
    for view in views:
        kwargs = {'max_mz': params['max_mz']}
        if view in ["BinnedSpectra", "SpecBinnerLog"]:
            kwargs['bin_width'] = params['bin_width']
        elif view in ["SpecFormula"]:
            kwargs['element_list'] = params['element_list']
            kwargs['add_intensities'] = params['add_intensities']
            kwargs['formula_normalize_vector'] = MSGYM_FORMULA_VECTOR_NORM

        spectra_featurizer[view] = registry[view](**kwargs)

    return spectra_featurizer
101
+
102
def get_mol_featurizer(molecule_view: T.Union[str, T.List[str]], params) -> T.Union[MolTransform, T.Dict[str, MolTransform]]:
    """Instantiate the molecule featurizer(s) for the requested view(s).

    Args:
        molecule_view: A view name or a list of view names; ``'MolGraph'`` is
            the only registered view.
        params: Parameter mapping; ``atom_feature``, ``bond_feature`` and
            ``element_list`` are read for ``'MolGraph'``.

    Returns:
        The featurizer itself when exactly one view is requested, otherwise
        a dict mapping view name to featurizer.

    Raises:
        KeyError: for a view name with no registered featurizer.
    """
    featurizers = {'MolGraph': MolToGraph}
    mol_featurizer = {}

    if isinstance(molecule_view, str):
        molecule_view = [molecule_view]
    for view in molecule_view:
        featurizer_params = {}
        # Exact comparison: the former ``view in ('MolGraph')`` was a substring
        # test against the string 'MolGraph', not tuple membership.
        if view == 'MolGraph':
            featurizer_params.update({'atom_feature': params['atom_feature'], 'bond_feature': params['bond_feature'], 'element_list': params['element_list']})

        featurizer = featurizers[view](**featurizer_params)

        # A single view is returned bare rather than wrapped in a dict.
        if len(molecule_view) == 1:
            return featurizer

        mol_featurizer[view] = featurizer

    return mol_featurizer
119
+
120
def get_test_ms_dataset(spectra_view: T.Union[str, T.List[str]],
                        mol_view: T.Union[str, T.List[str]],
                        spectra_featurizer: SpecTransform,
                        mol_featurizer: MolTransform,
                        params,
                        external_test: bool = False,):
    """Build the ``ExpandedRetrievalDataset`` used for testing.

    Optional constructor arguments (subformula directory, consensus-spectra
    directory, fingerprint settings) are added according to the requested
    views and the ``use_cons_spec`` / ``pred_fp`` / ``use_fp`` flags.
    """
    # Flatten both view arguments (each a name or a list of names) into one set.
    views = set()
    for view in (spectra_view, mol_view):
        if isinstance(view, str):
            views.add(view)
        else:
            views.update(view)

    dataset_params = {
        'spectra_view': spectra_view,
        'pth': params['dataset_pth'],
        'spec_transform': spectra_featurizer,
        'mol_transform': mol_featurizer,
        "candidates_pth": params['candidates_pth'],
    }

    use_formulas = "SpecFormula" in views or "SpecFormulaMz" in views
    if use_formulas:
        dataset_params['subformula_dir_pth'] = params['subformula_dir_pth']

    if params['use_cons_spec']:
        dataset_params['cons_spec_dir_pth'] = params['cons_spec_dir_pth']

    if params['pred_fp'] or params['use_fp']:
        # The fingerprint directory is passed as an empty string here.
        dataset_params['fp_dir_pth'] = ''
        dataset_params['fp_size'] = params['fp_size']
        dataset_params['fp_radius'] = params['fp_radius']

    return jestr_datasets.ExpandedRetrievalDataset(
        use_formulas=use_formulas,
        external_test=external_test,
        **dataset_params,
    )
148
+
149
def get_ms_dataset(spectra_view: str,
                   mol_view: str,
                   spectra_featurizer: SpecTransform,
                   mol_featurizer: MolTransform,
                   params):
    """Build the training dataset selected by ``params`` and the spectra view.

    Returns ``MassSpecDataset_Candidates`` when ``params['aug_cands']`` is
    truthy, otherwise ``MassSpecDataset_PeakFormulas`` when the spectra view
    contains ``"SpecFormula"``, otherwise ``JESTR1_MassSpecDataset``.
    ``mol_view`` is accepted but not used.
    """
    # set up dataset_parameters
    dataset_params = {
        'pth': params['dataset_pth'],
        'spec_transform': spectra_featurizer,
        'mol_transform': mol_featurizer,
        'spectra_view': spectra_view,
    }

    # ``in`` is a substring test for a string view, membership for a list.
    use_formulas = "SpecFormula" in spectra_view
    if use_formulas:
        dataset_params['subformula_dir_pth'] = params['subformula_dir_pth']

    if params['pred_fp'] or params['use_fp']:
        dataset_params['fp_dir_pth'] = params['fp_dir_pth']

    if params['use_cons_spec']:
        dataset_params['cons_spec_dir_pth'] = params['cons_spec_dir_pth']

    # select dataset
    if params['aug_cands']:
        dataset_cls = jestr_datasets.MassSpecDataset_Candidates
    elif use_formulas:
        dataset_cls = jestr_datasets.MassSpecDataset_PeakFormulas
    else:
        dataset_cls = jestr_datasets.JESTR1_MassSpecDataset

    return dataset_cls(**dataset_params)
176
+
177
class PrepMatchMS:
    """Convert a data row into a ``matchms.Spectrum`` for a given spectra view.

    After construction, ``self.prepare(row)`` is the converter that matches
    the view passed to ``__init__``.
    """

    def __init__(self, spectra_view) -> None:
        """Select the row converter for ``spectra_view``.

        Raises:
            Exception: if the view is not one of the supported names.
        """
        converters = (
            (('SpecFormula',), self.specFormula),
            (('SpecFormulaMz',), self.specFormulaMz),
            (('SpecBinnerLog', 'BinnedSpectra', 'SpecMzIntTokenizer'), self.specMzInt),
        )
        for view_names, converter in converters:
            if spectra_view in view_names:
                self.prepare = converter
                return
        raise Exception("Spectra view is not supported.")

    def specFormulaMz(self, row):
        """Parse comma-separated ``mzs`` / ``intensities`` strings; keep the row's formulas as metadata."""
        mz_values = np.array([float(m) for m in row["mzs"].split(",")])
        intensity_values = np.array([float(i) for i in row["intensities"].split(",")])
        meta = {'precursor_mz': row['precursor_mz'], 'formulas': row['formulas']}
        return matchms.Spectrum(mz=mz_values, intensities=intensity_values, metadata=meta)

    def specFormula(self, row):
        """Use the row's subformula peaks (``formula_mzs`` / ``formula_intensities``) as the spectrum."""
        mz_values = np.array(row['formula_mzs'])
        intensity_values = np.array(row['formula_intensities'])
        meta = {
            'precursor_mz': row['precursor_mz'],
            'formulas': np.array(row['formulas']),
            'precursor_formula': row['precursor_formula'],
        }
        return matchms.Spectrum(mz=mz_values, intensities=intensity_values, metadata=meta)

    def specMzInt(self, row):
        """Pass the row's ``mzs`` and ``intensities`` through unchanged."""
        meta = {'precursor_mz': row['precursor_mz']}
        return matchms.Spectrum(mz=row['mzs'], intensities=row['intensities'], metadata=meta)
mvp/utils/eval.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from MassSpecGym.massspecgym.utils import MyopicMCES
2
+ import numpy as np
3
+ import tqdm
4
+ from multiprocessing import Pool
5
+
6
+ import os
7
+ import pandas as pd
8
+
9
class Compute_Myopic_MCES:
    """Namespace for computing myopic MCES distances, optionally in parallel.

    The solver is created once as a class attribute and called for each
    (target, candidate) pair.
    """

    mces_compute = MyopicMCES()

    # Both functions take no ``self``; marking them static keeps them callable
    # through the class (as before) and also through an instance.
    @staticmethod
    def compute_mces(tar_cand):
        """Return ``(tar_cand, distance)`` for one ``(target, candidate)`` pair."""
        target, cand = tar_cand

        dist = Compute_Myopic_MCES.mces_compute(target, cand)
        return (tar_cand, dist)

    @staticmethod
    def compute_mces_parallel(target_cand_list, n_processes=25):
        """Map ``compute_mces`` over the pairs with a process pool; results keep input order."""
        with Pool(processes=n_processes) as pool:
            results = list(tqdm.tqdm(pool.imap(Compute_Myopic_MCES.compute_mces, target_cand_list), total=len(target_cand_list)))
        return results
25
+
26
class Compute_Myopic_MCES_timeout:
    """Myopic MCES distance computation with a per-task wait limit.

    The class is self-contained: it uses its own ``mces_compute`` solver and
    its own ``compute_mces`` worker. Previously both calls went through
    ``Compute_Myopic_MCES``, which left this class's solver unused.
    """

    mces_compute = MyopicMCES()

    @staticmethod
    def compute_mces(tar_cand):
        """Return ``(tar_cand, distance)`` for one ``(target, candidate)`` pair."""
        target, cand = tar_cand
        dist = Compute_Myopic_MCES_timeout.mces_compute(target, cand)
        return (tar_cand, dist)

    @staticmethod
    def compute_mces_parallel(target_cand_list, n_processes=35, timeout=60):  # timeout in seconds
        """Compute distances in a process pool, waiting at most ``timeout`` seconds per result.

        Returns:
            A list aligned with ``target_cand_list``. Each entry is
            ``(tar_cand, distance)`` on success, or
            ``(None, "Timeout or error")`` when the wait timed out or the
            worker raised.
        """
        results = []

        with Pool(processes=n_processes) as pool:
            # Submit everything up front, then collect in submission order.
            async_results = [
                pool.apply_async(Compute_Myopic_MCES_timeout.compute_mces, args=(tar_cand,))
                for tar_cand in target_cand_list
            ]
            for async_res in tqdm.tqdm(async_results, total=len(target_cand_list)):
                try:
                    result = async_res.get(timeout=timeout)
                except Exception:
                    # Best effort: a failed or timed-out pair yields a placeholder.
                    result = (None, "Timeout or error")
                results.append(result)

        return results
53
+
54
+
55
def get_result_files(exp_dir, spec_type, views_type):
    """Locate the mass and formula result files of a matching run directory.

    Entries of ``exp_dir`` are expected to be named ``<prefix>_<spec>_<views>``
    (exactly two underscores); other names are ignored. For each entry whose
    middle and last parts equal ``spec_type`` and ``views_type``, its files are
    scanned: a name containing ``'mass_result'`` becomes the mass result and
    any other name containing ``'result'`` becomes the formula result. If
    several entries or files match, the one listed last wins.

    Args:
        exp_dir: Directory to search; a ``str`` or ``pathlib.Path``.
        spec_type: Required middle part of the entry name.
        views_type: Required last part of the entry name.

    Returns:
        ``(mass_result, form_result)`` as ``Path`` objects, with ``''`` for
        any file that was not found.
    """
    from pathlib import Path

    exp_dir = Path(exp_dir)  # accept plain strings as well as Path objects
    mass_result = ''
    form_result = ''

    for entry in os.listdir(exp_dir):
        parts = entry.split('_')
        if len(parts) != 3:
            # Not a "<prefix>_<spec>_<views>" name.
            continue
        _, s, views = parts

        if s == spec_type and views == views_type:
            run_dir = exp_dir / entry
            print(run_dir)

            for fr in os.listdir(run_dir):
                if 'mass_result' in fr:
                    mass_result = run_dir / fr
                elif 'result' in fr:
                    form_result = run_dir / fr

    return mass_result, form_result
77
+
78
+ # get target
79
def get_target(candidates, labels):
    """Return the first candidate selected by ``labels``.

    ``labels`` indexes the candidate array (e.g. a boolean mask marking the
    true candidate).
    """
    selected = np.array(candidates)[labels]
    return selected[0]
81
+
82
+ # get mol rank at 1
83
def get_top_cand(candidates, scores):
    """Return the candidate at the position of the highest score (first one on ties)."""
    best_idx = np.argmax(scores)
    return candidates[best_idx]
85
+
86
+ # split into hit rates
87
def convert_rank_to_hit_rates(row, rank_col ,top_k=[1,5,20]):
    """Turn the rank stored in ``row[rank_col]`` into 0/1 hit indicators.

    Returns a ``pd.Series`` with one entry ``'<rank_col>-hit_rate@<k>'`` per
    ``k`` in ``top_k``: 1 when the rank is at most ``k``, else 0.
    """
    rank = row[rank_col]
    hits = {f'{rank_col}-hit_rate@{k}': (1 if rank <= k else 0) for k in top_k}
    return pd.Series(hits)
96
+
97
+ #################### Rank aggregation #######################
98
+ from collections import defaultdict
99
+ import numpy as np
100
+ from scipy.stats import rankdata
101
+
102
def borda_count(candidates, score_lists, target):
    """Return the 1-based rank of ``target`` under Borda-count aggregation.

    In every score list the best-scored candidate earns ``N`` points, the next
    ``N - 1``, and so on (``N = len(candidates)``). Candidates are then ordered
    by total points, highest first. Returns ``None`` if ``target`` never appears.
    """
    n_cands = len(candidates)
    points = {}
    for score_list in score_lists:
        pairs = list(zip(candidates, score_list))
        pairs.sort(key=lambda pair: pair[1], reverse=True)
        for position, (mol, _) in enumerate(pairs):
            points[mol] = points.get(mol, 0) + (n_cands - position)

    order = sorted(points, key=points.get, reverse=True)
    if target not in order:
        return None
    return order.index(target) + 1
111
+
112
def average_rank(candidates, score_lists, target):
    """Return the 1-based rank of ``target`` after averaging per-list ranks.

    Each score list ranks the candidates by descending score. Candidates are
    then ordered by their mean rank, lowest first. Returns ``None`` if
    ``target`` never appears.
    """
    positions = {}
    for score_list in score_lists:
        ordered = sorted(zip(candidates, score_list), key=lambda pair: pair[1], reverse=True)
        for rank, (mol, _) in enumerate(ordered, start=1):
            positions.setdefault(mol, []).append(rank)

    mean_rank = {mol: np.mean(ranks) for mol, ranks in positions.items()}
    order = sorted(mean_rank, key=mean_rank.get)
    if target not in order:
        return None
    return order.index(target) + 1
121
+
122
def reciprocal_rank_aggregation(candidates, score_lists, target):
    """Return the 1-based rank of ``target`` under reciprocal-rank aggregation.

    Each score list adds ``1 / rank`` to a candidate's total (rank 1 being the
    highest score). Candidates are ordered by total, highest first. Returns
    ``None`` if ``target`` never appears.
    """
    totals = {}
    for score_list in score_lists:
        ordered = sorted(zip(candidates, score_list), key=lambda pair: pair[1], reverse=True)
        for rank, (mol, _) in enumerate(ordered, start=1):
            totals[mol] = totals.get(mol, 0.0) + 1 / rank

    order = sorted(totals, key=totals.get, reverse=True)
    if target not in order:
        return None
    return order.index(target) + 1
130
+
131
def weighted_voting(candidates, score_lists, weights, target):
    """Return the 1-based rank of ``target`` under weighted reciprocal-rank voting.

    Score list ``i`` adds ``weights[i] / rank`` to each candidate's total
    (rank 1 being the highest score). Candidates are ordered by total, highest
    first. Returns ``None`` if ``target`` never appears.
    """
    totals = {}
    for weight, score_list in zip(weights, score_lists):
        ordered = sorted(zip(candidates, score_list), key=lambda pair: pair[1], reverse=True)
        for rank, (mol, _) in enumerate(ordered, start=1):
            totals[mol] = totals.get(mol, 0.0) + weight / rank

    order = sorted(totals, key=totals.get, reverse=True)
    if target not in order:
        return None
    return order.index(target) + 1
139
+
140
def median_rank(candidates, score_lists, target):
    """Return the 1-based rank of ``target`` after taking the median of per-list ranks.

    Each score list ranks the candidates by descending score. Candidates are
    then ordered by their median rank, lowest first. Returns ``None`` if
    ``target`` never appears.
    """
    positions = {}
    for score_list in score_lists:
        ordered = sorted(zip(candidates, score_list), key=lambda pair: pair[1], reverse=True)
        for rank, (mol, _) in enumerate(ordered, start=1):
            positions.setdefault(mol, []).append(rank)

    middle_rank = {mol: np.median(ranks) for mol, ranks in positions.items()}
    order = sorted(middle_rank, key=middle_rank.get)
    if target not in order:
        return None
    return order.index(target) + 1
149
+
150
def score_based_aggregation(candidates, score_lists, target):
    """Return the 1-based rank of ``target`` after averaging raw scores.

    Each candidate's scores are averaged across the score lists and candidates
    are ordered by mean score, highest first. Returns ``None`` if ``target``
    never appears.
    """
    collected = {}
    for score_list in score_lists:
        for mol, score in zip(candidates, score_list):
            collected.setdefault(mol, []).append(score)

    mean_score = {mol: np.mean(vals) for mol, vals in collected.items()}
    order = sorted(mean_score, key=mean_score.get, reverse=True)
    if target not in order:
        return None
    return order.index(target) + 1
mvp/utils/general.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
def pad_graph_nodes(mol_enc, g_n_nodes):
    """
    Split concatenated node embeddings into a zero-padded per-graph batch.

    Args:
        mol_enc: 2D tensor of shape (sum_nodes, D)
            Node embeddings of all graphs, concatenated graph after graph.
        g_n_nodes: list[int] Number of nodes per graph (len = B)

    Returns:
        padded: (B, max_nodes, D) tensor
        mask: (B, max_nodes) bool tensor, True for valid nodes
    """
    n_graphs = len(g_n_nodes)
    feat_dim = mol_enc.shape[1]
    longest = max(g_n_nodes)

    # new_zeros keeps the dtype and device of the input embeddings.
    padded = mol_enc.new_zeros((n_graphs, longest, feat_dim))
    mask = torch.zeros((n_graphs, longest), dtype=torch.bool, device=mol_enc.device)

    start = 0
    for row, count in enumerate(g_n_nodes):
        stop = start + count
        padded[row, :count] = mol_enc[start:stop]
        mask[row, :count] = True
        start = stop
    return padded, mask
mvp/utils/loss.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
def contrastive_loss(v1, v2, tau=1.0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symmetric contrastive loss over a batch of paired embeddings.

    Row ``i`` of ``v1`` and row ``i`` of ``v2`` form the positive pair; every
    other row combination is a negative. Similarities are cosine similarities
    scaled by ``1 / tau`` and exponentiated. The positive pair is kept in the
    denominator of both directions.

    Args:
        v1: 2D tensor of embeddings, one row per sample.
        v2: 2D tensor with the same number of rows as ``v1``.
        tau: Temperature dividing the cosine similarities.

    Returns:
        A 3-tuple ``(loss, mean_numerator, mean_denominator)``: the sum of the
        v1-to-v2 and v2-to-v1 losses, the mean exponentiated positive-pair
        similarity, and the mean of the summed row and column denominators.
        The annotation previously claimed a single tensor.
    """
    v1_norm = torch.norm(v1, dim=1, keepdim=True)
    v2_norm = torch.norm(v2, dim=1, keepdim=True)

    # Pairwise cosine similarity: inner products divided by the norm products.
    inner_prod = torch.matmul(v1, torch.transpose(v2, 0, 1))
    norm_mat = torch.matmul(v1_norm, torch.transpose(v2_norm, 0, 1))
    loss_mat = torch.exp(torch.div(inner_prod, norm_mat) * (1/tau))

    # Positive pairs sit on the diagonal; keep them as a single row.
    numerator = torch.unsqueeze(torch.diagonal(loss_mat), 0)

    # Row sums normalise v1 -> v2, column sums normalise v2 -> v1.
    Lv1_v2_denom = torch.transpose(torch.sum(loss_mat, dim=1, keepdim=True), 0, 1)
    Lv2_v1_denom = torch.sum(loss_mat, dim=0, keepdim=True)

    Lv1_v2 = torch.mean(-1 * torch.log(torch.div(numerator, Lv1_v2_denom)))
    Lv2_v1 = torch.mean(-1 * torch.log(torch.div(numerator, Lv2_v1_denom)))

    return Lv1_v2 + Lv2_v1 , torch.mean(numerator), torch.mean(Lv1_v2_denom+Lv2_v1_denom)
44
+
45
def cand_spec_sim_loss(spec_enc, cand_enc):
    """Return the mean cosine similarity between each spectrum and its candidates.

    Args:
        spec_enc: Spectrum embeddings, B x d.
        cand_enc: Candidate embeddings, B x C x d.
    """
    cands_first = cand_enc.transpose(0, 1)  # C x B x d
    spec_row = spec_enc[None]  # 1 x B x d

    similarity = F.cosine_similarity(spec_row, cands_first, dim=2)
    return similarity.mean()
53
+
54
class cons_spec_loss:
    """Callable loss between two spectrum tensors.

    ``loss_type`` selects the measure: ``'cosine'`` (one minus the mean cosine
    similarity) or ``'l2'`` (``torch.nn.MSELoss``). Any other value raises
    ``KeyError``.
    """

    def __init__(self, loss_type) -> None:
        if loss_type == 'cosine':
            self.loss_compute = self.cos_loss
        elif loss_type == 'l2':
            self.loss_compute = torch.nn.MSELoss()
        else:
            raise KeyError(loss_type)

    def __call__(self,cons_spec, ind_spec):
        return self.loss_compute(cons_spec, ind_spec)

    def cos_loss(self, cons_spec, ind_spec):
        """Return one minus the mean cosine similarity of the two inputs."""
        similarity = F.cosine_similarity(cons_spec, ind_spec)
        return 1 - similarity.mean()
65
+
66
class fp_loss:
    """Callable loss between a predicted and a target fingerprint.

    ``loss_type`` selects the measure: ``'cosine'`` (one minus the mean cosine
    similarity) or ``'bce'`` (``nn.BCELoss``). Any other value raises
    ``KeyError``.
    """

    def __init__(self, loss_type) -> None:
        if loss_type == 'cosine':
            self.loss_compute = self.fp_loss_cos
        elif loss_type == 'bce':
            self.loss_compute = nn.BCELoss()
        else:
            raise KeyError(loss_type)

    def __call__(self, predicted_fp, target_fp):
        return self.loss_compute(predicted_fp, target_fp)

    def fp_loss_cos(self, predicted_fp, target_fp):
        """Return one minus the mean cosine similarity of the two fingerprints."""
        similarity = F.cosine_similarity(predicted_fp, target_fp)
        return 1 - similarity.mean()
77
+
78
+
mvp/utils/models.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mvp.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaTransformer
2
+ from mvp.models.mol_encoder import MolEnc
3
+ from mvp.models.encoders import MLP
4
+ from mvp.models.contrastive import ContrastiveModel, MultiViewContrastive
5
+
6
def get_spec_encoder(spec_enc:str, args):
    """Instantiate the spectrum encoder registered under ``spec_enc``.

    Known names are ``"MLP_BIN"`` and ``"Transformer_Formula"``; any other
    name raises ``KeyError``. ``args`` is passed to the encoder constructor.
    """
    encoders = {
        "MLP_BIN": SpecEncMLP_BIN,
        "Transformer_Formula": SpecFormulaTransformer,
    }
    encoder_cls = encoders[spec_enc]
    return encoder_cls(args)
9
+
10
def get_mol_encoder(mol_enc: str, args):
    """Instantiate the molecule encoder registered under ``mol_enc``.

    ``'GNN'`` is the only known name; any other raises ``KeyError``. The
    encoder is built with a fixed ``in_dim`` of 78.
    """
    encoders = {'GNN': MolEnc}
    encoder_cls = encoders[mol_enc]
    return encoder_cls(args, in_dim=78)
12
+
13
def get_fp_pred_model(args):
    """Build the MLP that maps a final embedding to a fingerprint.

    It has a single hidden layer of size ``args.fp_size``, a sigmoid final
    activation and ``args.fp_dropout`` dropout.
    """
    return MLP(
        in_dim=args.final_embedding_dim,
        hidden_dims=[args.fp_size],
        final_activation='sigmoid',
        dropout=args.fp_dropout,
    )
15
+
16
def get_fp_enc_model(args):
    """Build the MLP that encodes a fingerprint into the embedding space.

    Hidden sizes are ``[d, 2 * d, d]`` with ``d = args.final_embedding_dim``;
    there is no final activation and no dropout.
    """
    emb_dim = args.final_embedding_dim
    return MLP(
        in_dim=args.fp_size,
        hidden_dims=[emb_dim, emb_dim*2, emb_dim],
        final_activation=None,
        dropout=0.0,
    )
18
+
19
def get_model(model:str,
              params):
    """Create the requested model, or restore it from a checkpoint.

    Args:
        model: ``'contrastive'`` or ``'MultiviewContrastive'``.
        params: Parameter mapping. It is expanded into the model constructor
            when no checkpoint is configured. With a non-empty
            ``checkpoint_pth``, the model is loaded from that checkpoint
            using ``log_only_loss_at_stages`` and ``df_test_path`` from
            ``params``.

    Returns:
        The constructed or restored model.

    Raises:
        Exception: if ``model`` is not a known model name.
    """
    model_classes = {
        'contrastive': ContrastiveModel,
        'MultiviewContrastive': MultiViewContrastive,
    }
    if model not in model_classes:
        raise Exception(f"Model {model} not implemented.")
    model_cls = model_classes[model]

    # If a checkpoint path is provided, load the model from it directly. The
    # class is resolved first so that no throwaway model is built just to
    # learn its type.
    checkpoint_pth = params.get('checkpoint_pth')
    if checkpoint_pth is not None and checkpoint_pth != "":
        loaded = model_cls.load_from_checkpoint(
            checkpoint_pth,
            log_only_loss_at_stages=params['log_only_loss_at_stages'],
            df_test_path=params['df_test_path']
        )
        print("Loaded Model from checkpoint")
        return loaded

    return model_cls(**params)
mvp/utils/preprocessing.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import pickle
3
+ import numpy as np
4
+ import mvp.utils.data as data_utils
5
+ import collections
6
+ import os
7
+ import requests
8
+ import tqdm
9
+ from multiprocessing import Pool
10
+ from urllib.parse import quote
11
+ from tqdm import tqdm
12
+
13
class NPClassProcess:
    """Helpers for annotating SMILES through the NPClassifier web service.

    Both helpers are static: they never use instance state and are intended
    to be called as ``NPClassProcess.<name>(...)``. Declaring them with
    ``@staticmethod`` keeps that call style working and additionally makes
    calls through an instance valid.
    """

    # Seconds to wait for the service before a request counts as failed;
    # without a timeout a single stalled connection blocks its worker forever.
    REQUEST_TIMEOUT = 60

    @staticmethod
    def process_smiles(smiles):
        """Query the classifier for one SMILES string (best effort).

        Args:
            smiles: SMILES string to classify; it is URL-quoted before being
                placed in the query string.

        Returns:
            ``(smiles, response_json)`` on success, or ``(smiles, None)`` if
            quoting, the HTTP request or JSON decoding raised an exception.
        """
        try:
            encoded_smiles = quote(smiles)
            url = f"https://npclassifier.gnps2.org/classify?smiles={encoded_smiles}"
            r = requests.get(url, timeout=NPClassProcess.REQUEST_TIMEOUT)
            return (smiles, r.json())
        except Exception:
            # Deliberately broad: one bad molecule must not abort the batch.
            # ``Exception`` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate so the run can still be cancelled.
            return (smiles, None)

    @staticmethod
    def NPclass_from_smiles(pth, output_dir, n_processes=20):
        """Classify every unique SMILES of a TSV file and pickle the result.

        Args:
            pth: Path to a tab-separated file with a ``smiles`` column.
            output_dir: Directory that receives ``SMILES_TO_CLASS.pkl``; it
                is created if it does not exist yet.
            n_processes: Number of worker processes issuing requests.

        Returns:
            Path of the written pickle file. The pickle holds a dict mapping
            each SMILES to the service's JSON response, or to the string
            ``'NA'`` when the lookup failed.
        """
        data = pd.read_csv(pth, sep='\t')
        unique_smiles = data['smiles'].unique().tolist()

        # Create the target directory up front so a missing folder cannot
        # discard the results after all requests have been made.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with Pool(processes=n_processes) as pool:
            results = list(tqdm(pool.imap(NPClassProcess.process_smiles, unique_smiles),
                                total=len(unique_smiles)))

        failed_ct = 0
        smiles_to_class = {}
        for s, out in results:
            if out is None:
                smiles_to_class[s] = 'NA'
                failed_ct += 1
            else:
                smiles_to_class[s] = out

        file_pth = os.path.join(output_dir, 'SMILES_TO_CLASS.pkl')
        with open(file_pth, 'wb') as f:
            pickle.dump(smiles_to_class, f)
        print(f'Failed to process {failed_ct} SMILES')
        print(f'result file saved to {file_pth}')
        return file_pth
47
+
48
+
49
+
50
def construct_NL_spec(pth, output_dir):
    """Convert every spectrum of a TSV file into a neutral-loss spectrum.

    For each row the peak m/z values are replaced by ``precursor_mz - mz``.
    Only losses greater than 1.0 are kept (with their intensities), sorted by
    ascending loss, and a final peak at the precursor m/z with intensity 1.0
    is appended. The ``mzs`` and ``intensities`` columns are overwritten with
    the resulting arrays and the whole table is pickled.

    Args:
        pth: Path to a tab-separated file with comma-separated ``mzs`` and
            ``intensities`` columns and a ``precursor_mz`` column.
        output_dir: Existing directory that receives ``NL_spec.pkl``.

    Returns:
        Path of the written pickle file.
    """
    def _neutral_loss_row(row):
        # Parse the comma-separated peak lists of one spectrum.
        frag_mz = np.array([float(tok) for tok in row["mzs"].split(",")], dtype=np.float32)
        frag_int = np.array([float(tok) for tok in row["intensities"].split(",")], dtype=np.float32)

        # Neutral loss = precursor m/z minus fragment m/z; drop losses <= 1.0.
        losses = float(row['precursor_mz']) - frag_mz
        keep = losses > 1.0
        losses = losses[keep]
        frag_int = frag_int[keep]

        # Sort by loss, then append the precursor itself as the last peak.
        order = np.argsort(losses)
        nl_mzs = np.concatenate((losses[order], [float(row['precursor_mz'])]))
        nl_int = np.concatenate((frag_int[order], [1.0]))
        return nl_mzs, nl_int

    spec_data = pd.read_csv(pth, sep='\t')
    expanded = spec_data.apply(_neutral_loss_row, axis=1, result_type='expand')
    spec_data[['mzs', 'intensities']] = expanded

    file_pth = os.path.join(output_dir, 'NL_spec.pkl')
    with open(file_pth, 'wb') as f:
        pickle.dump(spec_data, f)
    return file_pth
72
+
73
def generate_cons_spec(pth, output_dir):
    """Merge all spectra that share a SMILES into one consensus spectrum.

    The peak lists of every spectrum with the same SMILES are concatenated
    and sorted by ascending m/z. Each consensus entry gets a fixed
    ``precursor_mz`` of 10000.0 and the fold of the first spectrum in its
    group.

    Args:
        pth: Path to a tab-separated file with ``identifier``, ``smiles``,
            ``mzs``, ``intensities`` (comma-separated strings) and ``fold``
            columns.
        output_dir: Accepted for interface compatibility; not used here.

    Returns:
        DataFrame with columns ``smiles``, ``mzs``, ``intensities``,
        ``precursor_mz`` and ``fold``, one row per unique SMILES.
    """
    def _join_csv(values):
        return ','.join(values)

    spec_data = pd.read_csv(pth, sep='\t')
    wanted_cols = ['identifier', 'smiles', 'mzs', 'intensities', 'fold']
    aggregations = {
        'identifier': list,
        'mzs': _join_csv,
        'intensities': _join_csv,
        'fold': list,
    }
    data_by_smiles = spec_data[wanted_cols].groupby('smiles').agg(aggregations)
    smiles_to_fold = dict(zip(data_by_smiles.index.tolist(), data_by_smiles['fold'].tolist()))

    consensus_spectra = {}
    for _, group_row in tqdm(data_by_smiles.iterrows(), total=len(data_by_smiles)):
        # The group key (the SMILES string) is the name of the row.
        smiles = group_row.name

        peak_mz = np.array([float(tok) for tok in group_row["mzs"].split(",")], dtype=np.float32)
        peak_int = np.array([float(tok) for tok in group_row["intensities"].split(",")], dtype=np.float32)

        # Order the pooled peaks by m/z, keeping intensities aligned.
        order = np.argsort(peak_mz)

        consensus_spectra[smiles] = {
            'mzs': peak_mz[order],
            'intensities': peak_int[order],
            'precursor_mz': 10000.0,
            'fold': smiles_to_fold[smiles][0],
        }

    df = pd.DataFrame.from_dict(consensus_spectra, orient='index')
    return df.rename_axis('smiles').reset_index()
95
+
96
+
97
def generate_cons_spec_formulas(pth, subformula_dir, output_dir=''):
    """Build one consensus subformula spectrum per unique SMILES.

    Subformula annotations of all spectra sharing a SMILES are pooled. Within
    a group, rows are merged per formula, keeping the maximum m/z and the
    maximum intensity. When no usable annotation exists for a group, the
    spectrum falls back to a single peak made of the precursor formula, the
    precursor m/z and intensity 1.0.

    Args:
        pth: Path to a tab-separated file with ``identifier``, ``smiles``,
            ``fold``, ``precursor_mz``, ``formula`` and ``adduct`` columns.
        subformula_dir: Directory passed to ``data_utils.Subformula_Loader``.
        output_dir: Accepted for interface compatibility; not used here
            (the result is returned, not written to disk).

    Returns:
        DataFrame with columns ``smiles``, ``formulas``, ``formula_mzs``,
        ``formula_intensities``, ``precursor_mz``, ``fold`` and
        ``precursor_formula``, one row per unique SMILES.

    Raises:
        AssertionError: If the spectra of one SMILES span more than one fold.
    """
    # load tsv file
    spec_data = pd.read_csv(pth, sep='\t')

    # group spectra by SMILES
    data_by_smiles = spec_data[['identifier', 'smiles', 'fold', 'precursor_mz', 'formula', 'adduct']].groupby('smiles').agg({'identifier':list, 'fold': list, 'formula': list, 'precursor_mz': "max", 'adduct': list})
    smiles_index = data_by_smiles.index.tolist()
    smiles_to_id = dict(zip(smiles_index, data_by_smiles['identifier'].tolist()))
    smiles_to_fold = dict(zip(smiles_index, data_by_smiles['fold'].tolist()))
    smiles_to_precursorMz = dict(zip(smiles_index, data_by_smiles['precursor_mz'].tolist()))
    smiles_to_precursorFormula = dict(zip(smiles_index, data_by_smiles['formula'].tolist()))

    # load subformulas
    subformulaLoader = data_utils.Subformula_Loader(spectra_view='SpecFormula', dir_path=subformula_dir)
    id_to_spec = subformulaLoader(spec_data['identifier'].tolist())

    # combine spectra
    consensus_spectra = {}
    for smiles, ids in tqdm(smiles_to_id.items(), total=len(data_by_smiles)):
        cons_spec = collections.defaultdict(list)
        for spec_id in ids:
            if spec_id in id_to_spec:
                for k, v in id_to_spec[spec_id].items():
                    cons_spec[k].extend(v)
        cons_spec = pd.DataFrame(cons_spec)

        assert len(set(smiles_to_fold[smiles])) == 1, \
            f"SMILES {smiles} appears in more than one fold"

        # keep maxed mz and maxed intensity
        try:
            cons_spec = cons_spec.groupby('formulas').agg({'formula_mzs': "max", 'formula_intensities': "max"})
            cons_spec.reset_index(inplace=True)
        except Exception:
            # Expected when no subformulas were found for this SMILES: the
            # frame then has no 'formulas' column and groupby raises. Fall
            # back to a precursor-only spectrum. ``Exception`` instead of a
            # bare except keeps Ctrl-C / SystemExit from being swallowed.
            d = {
                'formulas': [smiles_to_precursorFormula[smiles][0]],
                'formula_mzs': [smiles_to_precursorMz[smiles]],
                'formula_intensities': [1.0]
            }
            cons_spec = pd.DataFrame(d)

        cons_spec = cons_spec.sort_values(by='formula_mzs').reset_index(drop=True)
        cons_spec = {'formulas': cons_spec['formulas'].tolist(),
                     'formula_mzs': cons_spec['formula_mzs'].tolist(),
                     'formula_intensities': cons_spec['formula_intensities'].tolist(),
                     'precursor_mz': smiles_to_precursorMz[smiles],
                     'fold': smiles_to_fold[smiles][0],
                     'precursor_formula': smiles_to_precursorFormula[smiles][0]}  # formula without adduct...

        consensus_spectra[smiles] = cons_spec

    # assemble the consensus spectra into a table (one row per SMILES)
    df = pd.DataFrame.from_dict(consensus_spectra, orient='index')
    df = df.rename_axis('smiles').reset_index()

    return df