| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import random as rd |
|
|
| from diffusion import Diffusion |
| from scoring.scoring_functions import ScoringFunctions |
| from utils.filter import PeptideAnalyzer |
| import noise_schedule |
|
|
| """" |
| Open design questions: |
| - Should the fully rolled-out sequence be stored on each node? |
| - Should a path hold Node objects or decoded strings? |
| - Should selection be restricted to valid, expandable leaf nodes? |
| - Should similarity between sibling nodes be calculated? |
| - Should generated sequences be evaluated? |
| """ |
class Node:
    """
    Node class: partially unmasked SMILES string

    Attributes:
    - config: configuration object; reads config.mcts.num_objectives,
      config.mcts.sample_prob and config.sampling.steps
    - parentNode: Node object at the previous time step (None for the root)
    - childNodes: list of Node objects added through addChildNode
    - scoreVector: optional score vector stored as given
    - totalReward: vector of cumulative rewards for all K objectives
    - visits: number of times the node has been visited by an iteration (starts at 1)
    - timestep: the time step where the sequence was sampled
    - sampleProb: probability of sampling the sequence from the diffusion model
    - tokens: token container stored as given by the caller
    """

    def __init__(self, config, tokens=None, parentNode=None, childNodes=None, scoreVector=None, totalReward=None, timestep=None, sampleProb=None):
        self.config = config
        self.parentNode = parentNode
        # A fresh list per node: a mutable default argument would be shared by
        # every node constructed without an explicit childNodes list.
        self.childNodes = [] if childNodes is None else childNodes
        self.scoreVector = scoreVector

        if totalReward is not None:
            self.totalReward = totalReward
        else:
            self.totalReward = np.zeros(self.config.mcts.num_objectives)

        # Visits start at 1 so that totalReward / visits is always defined.
        self.visits = 1
        self.timestep = timestep
        self.sampleProb = sampleProb
        self.tokens = tokens

    def selectNode(self, num_func):
        """
        Selects a node to move to among the children nodes.

        Returns a (node, status) tuple where status is the value of
        getExpandStatus() for the returned node. A node that is not already
        expanded (status 1 or 2) returns itself. Otherwise one child is drawn
        uniformly at random from the front built by updateParetoFront over the
        select scores of the expandable children (status 2 or 3), comparing the
        first num_func objectives.
        """
        nodeStatus = self.getExpandStatus()
        if nodeStatus != 3:
            return self, nodeStatus

        paretoFront = {}
        for childNode in self.childNodes:
            if childNode.getExpandStatus() in (2, 3):
                selectScore = childNode.calcSelectScore()
                paretoFront = updateParetoFront(paretoFront, childNode, selectScore, num_func)

        # When every child is terminal the front is empty; fall back to the
        # children themselves instead of calling rd.choice on an empty list.
        candidates = list(paretoFront) or list(self.childNodes)
        selected = rd.choice(candidates)
        return selected, selected.getExpandStatus()

    def addChildNode(self, tokens, totalReward, prob=None):
        """
        Adds a child node one time step after this node and returns it.
        """
        child = Node(config=self.config,
                     tokens=tokens,
                     parentNode=self,
                     childNodes=[],
                     totalReward=totalReward,
                     timestep=self.timestep + 1,
                     sampleProb=prob)

        self.childNodes.append(child)
        return child

    def updateNode(self, rewards):
        """
        Updates the cumulative rewards vector with the reward vector at a descendent leaf node.
        Increments the number of visits to the node.
        """
        self.visits += 1
        # Rebind rather than add in place so an array handed in by the caller
        # at construction time is never mutated behind their back.
        self.totalReward = self.totalReward + rewards

    def calcSelectScore(self):
        """
        Calculates the select score for the node from the cumulative rewards vector and number of visits.

        The score is totalReward / visits. When sampleProb is set, the term
        config.mcts.sample_prob * sampleProb * sqrt(root visits) / visits
        is added, where the root is found by following parentNode links.
        """
        normRewards = self.totalReward / self.visits
        if self.sampleProb is None:
            return normRewards

        # Nodes keep no direct reference to the root, so walk up the parent chain.
        root = self
        while root.parentNode is not None:
            root = root.parentNode

        print("Sample Prob")
        print(self.sampleProb)
        return normRewards + (self.config.mcts.sample_prob * self.sampleProb * np.sqrt(root.visits) / self.visits)

    def getExpandStatus(self):
        """
        Returns an integer indicating whether the node is a:
        1. terminal node (time step has reached config.sampling.steps)
        2. legal leaf node (not terminal and without child nodes)
        3. legal non-leaf node (not terminal and already has child nodes)
        """
        # >= rather than ==: a node past the last step is terminal, not expanded.
        if self.timestep >= self.config.sampling.steps:
            return 1
        if not self.childNodes:
            return 2
        return 3
| |
| """END OF NODE CLASS""" |
|
|
def updateParetoFront(paretoFront, node, scoreVector, num_func):
    """
    Updates paretoFront (a dict mapping node -> score vector) with a new node.

    Only the first num_func objectives are compared.
    - An existing entry is removed when scoreVector is >= it on every compared
      objective and > it on at least one.
    - The node is added when the front is empty, or when scoreVector is >=
      every existing entry on every compared objective.

    The dict is modified in place and also returned.
    """
    if not paretoFront:
        paretoFront[node] = scoreVector
        return paretoFront

    dominatedKeys = []
    matchesAll = True
    for key, existing in paretoFront.items():
        existing = np.asarray(existing)
        # Slicing clamps on its own, so num_func may exceed the number of
        # objectives without leaving the slices undefined.
        atLeastAsGood = (scoreVector >= existing)[:num_func]
        strictlyBetter = (scoreVector > existing)[:num_func]

        if atLeastAsGood.all():
            if strictlyBetter.any():
                dominatedKeys.append(key)
        else:
            matchesAll = False

    if matchesAll:
        paretoFront[node] = scoreVector

    # Removal happens after the scan so the dict is not resized while iterating.
    for key in dominatedKeys:
        del paretoFront[key]
    return paretoFront
| |
| """BEGINNING OF MCTS CLASS""" |
|
|
class MCTS:
    """
    Tree search that repeatedly selects a leaf node, expands it by sampling
    children from the diffusion model, scores the rolled-out sequences and
    keeps the best ones in peptideParetoFront.

    peptideParetoFront maps a decoded sequence to
    {'scores': score vector, 'token_ids': token ids of the rolled-out sequence}.
    """

    def __init__(self, config, max_sequence_length=None, mdlm=None, score_func_names=None, prot_seqs=None, num_func=None):
        """
        - config: configuration object (reads time_conditioning, sampling.* and mcts.*)
        - max_sequence_length: overrides config.sampling.seq_length when given
        - mdlm: diffusion model; its tokenizer and device attributes are read here
        - score_func_names, prot_seqs: forwarded to ScoringFunctions
        - num_func: list of iteration thresholds; entry i is the iteration from
          which the first i + 1 objectives are compared (see updateParetoFront)
        """
        self.config = config
        self.noise = noise_schedule.get_noise(config)
        self.time_conditioning = self.config.time_conditioning

        self.peptideParetoFront = {}
        self.num_steps = config.sampling.steps
        self.num_sequences = config.sampling.num_sequences

        self.mdlm = mdlm
        self.tokenizer = mdlm.tokenizer
        self.device = mdlm.device

        if max_sequence_length is None:
            self.sequence_length = self.config.sampling.seq_length
        else:
            self.sequence_length = max_sequence_length

        self.num_iter = config.mcts.num_iter
        self.num_child = config.mcts.num_children

        # None defaults replace mutable [] defaults shared between instances.
        if score_func_names is None:
            score_func_names = []
        self.score_functions = ScoringFunctions(score_func_names, prot_seqs)
        self.num_func = [] if num_func is None else num_func
        self.iter_num = 0
        self.curr_num_func = 1
        self.analyzer = PeptideAnalyzer()

        self._clear_logs()

    def _clear_logs(self):
        """Resets every per-iteration log list to empty."""
        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []

    def reset(self):
        """Clears the iteration counter, the logs and the Pareto front."""
        self.iter_num = 0
        self._clear_logs()
        self.peptideParetoFront = {}

    def forward(self, rootNode):
        """
        Runs num_iter select/expand iterations starting from rootNode and
        returns peptideParetoFront. State from a previous run is reset first.
        """
        self.reset()

        while self.iter_num < self.num_iter:
            self.iter_num += 1
            leafNode, _ = self.select(rootNode)
            self.expand(leafNode)

        return self.peptideParetoFront

    def updateParetoFront(self, sequence, scoreVector, tokens):
        """
        Updates peptideParetoFront with a scored sequence and returns its reward vector.

        curr_num_func (how many leading objectives are compared) is recomputed
        from num_func and iter_num on every call.

        The sequence is stored when the front is empty, when scoreVector is >=
        every stored entry on the compared objectives, or while the front holds
        fewer than num_sequences entries. Entries that scoreVector is >= on all
        compared objectives and > on at least one are removed, but only while
        the size measured before this call exceeds num_sequences.

        Returns a vector with, per objective, the fraction of previously stored
        entries that scoreVector is >= to (all ones when the front was empty).
        """
        paretoSize = len(self.peptideParetoFront)

        self.curr_num_func = 1
        for i, startIter in enumerate(self.num_func):
            if self.iter_num >= startIter:
                self.curr_num_func = i + 1

        if paretoSize == 0:
            self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}
            return np.ones(len(scoreVector))

        nondominate = []
        delete = []
        rewardVector = np.zeros(len(scoreVector))
        for k, v in self.peptideParetoFront.items():
            existing = np.asarray(v['scores'])
            nondominated = scoreVector >= existing
            dominant = scoreVector > existing

            # The reward counts every objective, not only the compared ones.
            rewardVector += nondominated

            # Slicing clamps on its own, so curr_num_func may exceed the number
            # of objectives without leaving the slices undefined.
            attn_nondominated = nondominated[:self.curr_num_func]
            attn_dominant = dominant[:self.curr_num_func]

            if attn_nondominated.all():
                nondominate.append(True)
                if attn_dominant.any():
                    delete.append(k)
            else:
                nondominate.append(False)

        if all(nondominate) or paretoSize < self.num_sequences:
            self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}

        rewardVector = rewardVector / paretoSize

        while (paretoSize > self.num_sequences) and delete:
            del self.peptideParetoFront[delete.pop(0)]
            paretoSize -= 1

        return rewardVector

    def isPathEnd(self, path, maxDepth):
        """
        Checks if the last element of path contains no mask token (ie. end of path)
        or if the path is at the max depth
        """
        if (path[-1] != self.config.mcts.mask_token).all():
            return True
        return len(path) >= maxDepth

    def select(self, currNode):
        """
        Traverse the tree from currNode until selectNode returns a node whose
        status is not 3 (already expanded). Returns (node, status).
        """
        while True:
            currNode, nodeStatus = currNode.selectNode(self.curr_num_func)
            if nodeStatus != 3:
                return currNode, nodeStatus

    def expand(self, parentNode, eps=1e-5, checkSimilarity = True):
        """
        Sample unmasking steps from the pre-trained MDLM
        adds num_children partially unmasked sequences to the children of the parentNode

        Each child is rolled out for the remaining time steps, decoded, and
        checked with the peptide analyzer. Valid sequences are scored and
        offered to the Pareto front; invalid ones get a zero reward. The summed
        child rewards, minus invalid_penalty times the invalid fraction, are
        backpropagated from parentNode. checkSimilarity is currently unused.
        """
        num_rollout_steps = self.num_steps - parentNode.timestep
        if num_rollout_steps <= 0:
            # Terminal node: no time steps remain, so there is nothing to sample.
            return

        num_children = self.config.mcts.num_children
        allChildReward = np.zeros_like(parentNode.totalReward)

        rollout_t = torch.linspace(1, eps, num_rollout_steps, device=self.device)
        dt = (1 - eps) / self.num_steps

        x = parentNode.tokens['input_ids'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)

        t = rollout_t[0] * torch.ones(num_children, 1, device = self.device)

        print("token array:")
        print(x)
        # First step: sample num_children children from the parent tokens.
        p_x0_cache, x_children = self.mdlm.batch_cached_reverse_step(token_array=x,
                                                                     t=t, dt=dt,
                                                                     batch_size=num_children,
                                                                     attn_mask=attn_mask)
        x_rollout = x_children

        # Remaining steps: roll every child forward to the final time step.
        for i in range(1, num_rollout_steps):
            t = rollout_t[i] * torch.ones(num_children, 1, device = self.device)

            p_x0_cache, x_next = self.mdlm.cached_reverse_step(x=x_rollout,
                                                               t=t, dt=dt, p_x0=p_x0_cache,
                                                               attn_mask=attn_mask)

            # NOTE(review): this compares against the parent tokens x rather
            # than the previous rollout state x_rollout - confirm which is intended.
            if (not torch.allclose(x_next, x) or self.time_conditioning):
                p_x0_cache = None

            x_rollout = x_next

        if self.config.sampling.noise_removal:
            t = rollout_t[-1] * torch.ones(x.shape[0], 1, device=self.device)
            time_cond = self.noise(t)[0]
            x_rollout = self.mdlm.forward(x_rollout, attn_mask, time_cond).argmax(dim=-1)

        childSequences = self.tokenizer.batch_decode(x_rollout)

        validSequences = []
        maskedTokens = []
        unmaskedTokens = []
        for i in range(num_children):
            childSeq = childSequences[i]

            if self.analyzer.is_peptide(childSeq):
                validSequences.append(childSeq)
                maskedTokens.append(x_children[i])
                unmaskedTokens.append(x_rollout[i])
            else:
                # Invalid rollouts still become children, with a zero reward.
                rewardVector = np.zeros(self.config.mcts.num_objectives)
                childTokens = {'input_ids': x_children[i], 'attention_mask': attn_mask}
                parentNode.addChildNode(tokens=childTokens,
                                        totalReward=rewardVector)

        if validSequences:
            scoreVectors = self.score_functions(input_seqs=validSequences)
            # After the transpose, row j is indexed per objective.
            average_scores = scoreVectors.T
            if self.config.mcts.single:
                self.permeability_log.append(average_scores[0])
            else:
                self.affinity1_log.append(average_scores[0])
                self.sol_log.append(average_scores[1])
                self.hemo_log.append(average_scores[2])
                self.nf_log.append(average_scores[3])
                if self.config.mcts.perm:
                    self.permeability_log.append(average_scores[4])
                elif self.config.mcts.dual:
                    self.affinity2_log.append(average_scores[4])
        else:
            # No valid sequence this iteration: log zero placeholders.
            placeholderShape = (self.config.mcts.num_objectives, self.config.sampling.num_sequences)
            self.affinity1_log.append(np.zeros(placeholderShape))
            self.sol_log.append(np.zeros(placeholderShape))
            self.hemo_log.append(np.zeros(placeholderShape))
            self.nf_log.append(np.zeros(placeholderShape))

            if self.config.mcts.perm:
                self.permeability_log.append(np.zeros(placeholderShape))
            elif self.config.mcts.dual:
                self.affinity2_log.append(np.zeros(placeholderShape))

        for i, validSeq in enumerate(validSequences):
            scoreVector = scoreVectors[i]

            rewardVector = self.updateParetoFront(validSeq, scoreVector, unmaskedTokens[i])
            print(scoreVector)
            print(rewardVector)

            allChildReward += rewardVector

            # The child keeps the one-step tokens, not the fully rolled-out ones.
            childTokens = {'input_ids': maskedTokens[i], 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    totalReward=rewardVector)

        invalid = (num_children - len(validSequences)) / num_children

        valid_fraction = len(validSequences) / num_children
        print(f"Valid fraction: {valid_fraction}")
        self.valid_fraction_log.append(valid_fraction)

        print(self.config.mcts.invalid_penalty)

        allChildReward = allChildReward - (self.config.mcts.invalid_penalty * invalid)

        self.backprop(parentNode, allChildReward)

    def backprop(self, node, reward_vector):
        """
        Calls updateNode(reward_vector) on node and on every ancestor reached
        through parentNode.
        """
        # Explicit None check: does not depend on how node objects define truthiness.
        while node is not None:
            node.updateNode(reward_vector)
            node = node.parentNode

    def getSequenceForObjective(self, objective_index, k):
        """
        Returns the top-k sequences in the pareto front that has the best score for
        a given objective and their score vectors for all objectives

        The result is a dict mapping sequence -> stored front entry, ordered
        from highest to lowest score on objective_index. When the front holds
        fewer than k sequences, all of them are returned.
        """
        ranked = sorted(
            self.peptideParetoFront.items(),
            key=lambda item: float(item[1]['scores'][objective_index]),
            reverse=True,
        )
        return dict(ranked[:max(k, 0)])
| |
|
|