lhallee committed on
Commit
3d7bb18
·
verified ·
1 Parent(s): c1bc3cb

Upload modeling_e1.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_e1.py +3 -274
modeling_e1.py CHANGED
@@ -1,8 +1,6 @@
1
  import os
2
  os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
3
 
4
- import numpy as np
5
- import networkx as nx
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
@@ -10,15 +8,15 @@ from torch.nn.utils.rnn import pad_sequence
10
 
11
  from einops import rearrange, repeat
12
  from enum import Enum
13
- from typing import Any, TypedDict, Callable, Optional, List
14
  from dataclasses import dataclass
15
  from tokenizers import Tokenizer
16
  from transformers import PretrainedConfig, PreTrainedModel
17
  from transformers.activations import ACT2FN
18
  from transformers.modeling_outputs import ModelOutput
19
  from transformers.utils import logging
20
- from tqdm.auto import tqdm
21
- from embedding_mixin import EmbeddingMixin, Pooler
22
 
23
 
24
  logger = logging.get_logger(__name__)
@@ -1356,275 +1354,6 @@ class DecoderLayer(nn.Module):
1356
  return hidden_states, self_attn_weights, present_key_value
1357
 
1358
 
1359
- ### Support for embedding datasets with low code
1360
- class _LegacyPooler:
1361
- def __init__(self, pooling_types: List[str]):
1362
- self.pooling_types = pooling_types
1363
- self.pooling_options = {
1364
- 'mean': self.mean_pooling,
1365
- 'max': self.max_pooling,
1366
- 'norm': self.norm_pooling,
1367
- 'median': self.median_pooling,
1368
- 'std': self.std_pooling,
1369
- 'var': self.var_pooling,
1370
- 'cls': self.cls_pooling,
1371
- 'parti': self._pool_parti,
1372
- }
1373
-
1374
- def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
1375
- maxed_attentions = torch.max(attentions, dim=1)[0]
1376
- return maxed_attentions
1377
-
1378
- def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
1379
- # Run PageRank on the attention matrix converted to a graph.
1380
- # Raises exceptions if the graph doesn't match the token sequence or has no edges.
1381
- # Returns the PageRank scores for each token node.
1382
- G = self._convert_to_graph(attention_matrix)
1383
- if G.number_of_nodes() != attention_matrix.shape[0]:
1384
- raise Exception(
1385
- f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
1386
- if G.number_of_edges() == 0:
1387
- raise Exception(f"You don't seem to have any attention edges left in the graph.")
1388
-
1389
- return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
1390
-
1391
- def _convert_to_graph(self, matrix):
1392
- # Convert a matrix (e.g., attention scores) to a directed graph using networkx.
1393
- # Each element in the matrix represents a directed edge with a weight.
1394
- G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
1395
- return G
1396
-
1397
- def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
1398
- # Remove keys where attention_mask is 0
1399
- if attention_mask is not None:
1400
- for k in list(dict_importance.keys()):
1401
- if attention_mask[k] == 0:
1402
- del dict_importance[k]
1403
-
1404
- #dict_importance[0] # remove cls
1405
- #dict_importance[-1] # remove eos
1406
- total = sum(dict_importance.values())
1407
- return np.array([v / total for _, v in dict_importance.items()])
1408
-
1409
- def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
1410
- maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
1411
- # emb is (b, L, d), maxed_attentions is (b, L, L)
1412
- emb_pooled = []
1413
- for e, a, mask in zip(emb, maxed_attentions, attention_mask):
1414
- dict_importance = self._page_rank(a)
1415
- importance_weights = self._calculate_importance_weights(dict_importance, mask)
1416
- num_tokens = int(mask.sum().item())
1417
- emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0))
1418
- pooled = torch.tensor(np.array(emb_pooled))
1419
- return pooled
1420
-
1421
- def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1422
- if attention_mask is None:
1423
- return emb.mean(dim=1)
1424
- else:
1425
- attention_mask = attention_mask.unsqueeze(-1)
1426
- return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
1427
-
1428
- def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1429
- if attention_mask is None:
1430
- return emb.max(dim=1).values
1431
- else:
1432
- attention_mask = attention_mask.unsqueeze(-1)
1433
- return (emb * attention_mask).max(dim=1).values
1434
-
1435
- def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1436
- if attention_mask is None:
1437
- return emb.norm(dim=1, p=2)
1438
- else:
1439
- attention_mask = attention_mask.unsqueeze(-1)
1440
- return (emb * attention_mask).norm(dim=1, p=2)
1441
-
1442
- def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1443
- if attention_mask is None:
1444
- return emb.median(dim=1).values
1445
- else:
1446
- attention_mask = attention_mask.unsqueeze(-1)
1447
- return (emb * attention_mask).median(dim=1).values
1448
-
1449
- def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1450
- if attention_mask is None:
1451
- return emb.std(dim=1)
1452
- else:
1453
- # Compute variance correctly over non-masked positions, then take sqrt
1454
- var = self.var_pooling(emb, attention_mask, **kwargs)
1455
- return torch.sqrt(var)
1456
-
1457
- def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1458
- if attention_mask is None:
1459
- return emb.var(dim=1)
1460
- else:
1461
- # Correctly compute variance over only non-masked positions
1462
- attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1)
1463
- # Compute mean over non-masked positions
1464
- mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
1465
- mean = mean.unsqueeze(1) # (b, 1, d)
1466
- # Compute squared differences from mean, only over non-masked positions
1467
- squared_diff = (emb - mean) ** 2 # (b, L, d)
1468
- # Sum squared differences over non-masked positions and divide by count
1469
- var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
1470
- return var
1471
-
1472
- def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1473
- return emb[:, 0, :]
1474
-
1475
- def __call__(
1476
- self,
1477
- emb: torch.Tensor,
1478
- attention_mask: Optional[torch.Tensor] = None,
1479
- attentions: Optional[torch.Tensor] = None
1480
- ): # [mean, max]
1481
- final_emb = []
1482
- for pooling_type in self.pooling_types:
1483
- final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d)
1484
- return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
1485
-
1486
-
1487
- class _LegacyEmbeddingMixin:
1488
- def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1489
- raise NotImplementedError
1490
-
1491
- @property
1492
- def device(self) -> torch.device:
1493
- """Get the device of the model."""
1494
- return next(self.parameters()).device
1495
-
1496
- def _read_sequences_from_db(self, db_path: str) -> set[str]:
1497
- """Read sequences from SQLite database."""
1498
- import sqlite3
1499
- sequences = []
1500
- with sqlite3.connect(db_path) as conn:
1501
- c = conn.cursor()
1502
- c.execute("SELECT sequence FROM embeddings")
1503
- while True:
1504
- row = c.fetchone()
1505
- if row is None:
1506
- break
1507
- sequences.append(row[0])
1508
- return set(sequences)
1509
-
1510
- def embed_dataset(
1511
- self,
1512
- sequences: List[str],
1513
- #tokenizer: PreTrainedTokenizerBase, # For E1, the tokenizing is handled by _embed
1514
- batch_size: int = 2,
1515
- max_len: int = 512,
1516
- truncate: bool = True,
1517
- full_embeddings: bool = False,
1518
- embed_dtype: torch.dtype = torch.float32,
1519
- pooling_types: List[str] = ['mean'],
1520
- sql: bool = False,
1521
- save: bool = True,
1522
- sql_db_path: str = 'embeddings.db',
1523
- save_path: str = 'embeddings.pth',
1524
- **kwargs,
1525
- ) -> Optional[dict[str, torch.Tensor]]:
1526
- """Embed a dataset of protein sequences.
1527
-
1528
- Args:
1529
- sequences: List of protein sequences
1530
- batch_size: Batch size for processing
1531
- max_len: Maximum sequence length
1532
- full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
1533
- pooling_type: Type of pooling ('mean' or 'cls')
1534
- sql: Whether to store embeddings in SQLite database - will be stored in float32
1535
- sql_db_path: Path to SQLite database
1536
-
1537
- Returns:
1538
- Dictionary mapping sequences to embeddings, or None if sql=True
1539
-
1540
- Note:
1541
- - If sql=True, embeddings can only be stored in float32
1542
- - sql is ideal if you need to stream a very large dataset for training in real-time
1543
- - save=True is ideal if you can store the entire embedding dictionary in RAM
1544
- - sql will be used if it is True and save is True or False
1545
- - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
1546
- - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
1547
-
1548
- Example:
1549
- >>> embedder = EmbeddingMixin()
1550
- >>> embedding_dict = embedder.embed_dataset(
1551
- sequences=[
1552
- 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
1553
- ],
1554
- batch_size=2, # adjust for your GPU memory
1555
- max_len=512, # adjust for your needs
1556
- full_embeddings=False, # if True, no pooling is performed
1557
- embed_dtype=torch.float32, # cast to what dtype you want
1558
- pooling_type=['mean', 'cls'], # more than one pooling type will be concatenated together
1559
- sql=False, # if True, embeddings will be stored in SQLite database
1560
- sql_db_path='embeddings.db',
1561
- save=True, # if True, embeddings will be saved as a .pth file
1562
- save_path='embeddings.pth',
1563
- )
1564
- >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
1565
- """
1566
- sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
1567
- sequences = sorted(sequences, key=len, reverse=True)
1568
- hidden_size = self.config.hidden_size
1569
- pooler = Pooler(pooling_types) if not full_embeddings else None
1570
-
1571
- def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1572
- if full_embeddings or residue_embeddings.ndim == 2: # if already pooled or want residue-wise embeddings
1573
- return residue_embeddings
1574
- else:
1575
- return pooler(residue_embeddings, attention_mask)
1576
-
1577
- if sql:
1578
- import sqlite3
1579
- conn = sqlite3.connect(sql_db_path)
1580
- c = conn.cursor()
1581
- c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
1582
- already_embedded = self._read_sequences_from_db(sql_db_path)
1583
- to_embed = [seq for seq in sequences if seq not in already_embedded]
1584
- print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
1585
- print(f"Embedding {len(to_embed)} new sequences")
1586
- if len(to_embed) > 0:
1587
- with torch.no_grad():
1588
- for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
1589
- seqs = to_embed[batch_start:batch_start + batch_size]
1590
- input_ids, attention_mask = self._embed(seqs, return_attention_mask=True)
1591
- embeddings = get_embeddings(input_ids, attention_mask).float() # sql requires float32
1592
- for seq, emb, mask in zip(seqs, embeddings, attention_mask):
1593
- if full_embeddings:
1594
- emb = emb[mask.bool()].reshape(-1, hidden_size)
1595
- c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.cpu().numpy().tobytes()))
1596
- conn.commit()
1597
- conn.commit()
1598
- conn.close()
1599
- return None
1600
-
1601
- embeddings_dict = {}
1602
- if os.path.exists(save_path):
1603
- embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True)
1604
- to_embed = [seq for seq in sequences if seq not in embeddings_dict]
1605
- print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
1606
- print(f"Embedding {len(to_embed)} new sequences")
1607
- else:
1608
- to_embed = sequences
1609
- print(f"Embedding {len(to_embed)} new sequences")
1610
-
1611
- if len(to_embed) > 0:
1612
- with torch.no_grad():
1613
- for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
1614
- seqs = to_embed[batch_start:batch_start + batch_size]
1615
- last_hidden_state, attention_mask = self._embed(seqs, return_attention_mask=True)
1616
- embeddings = get_embeddings(last_hidden_state, attention_mask).to(embed_dtype)
1617
- for seq, emb, mask in zip(seqs, embeddings, attention_mask):
1618
- if full_embeddings:
1619
- emb = emb[mask.bool()].reshape(-1, hidden_size)
1620
- embeddings_dict[seq] = emb.cpu()
1621
-
1622
- if save:
1623
- torch.save(embeddings_dict, save_path)
1624
-
1625
- return embeddings_dict
1626
-
1627
-
1628
  class E1PreTrainedModel(PreTrainedModel):
1629
  config_class = E1Config
1630
  config: E1Config
 
1
  import os
2
  os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
3
 
 
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
 
8
 
9
  from einops import rearrange, repeat
10
  from enum import Enum
11
+ from typing import Any, TypedDict, Callable, List
12
  from dataclasses import dataclass
13
  from tokenizers import Tokenizer
14
  from transformers import PretrainedConfig, PreTrainedModel
15
  from transformers.activations import ACT2FN
16
  from transformers.modeling_outputs import ModelOutput
17
  from transformers.utils import logging
18
+
19
+ from .embedding_mixin import EmbeddingMixin, Pooler
20
 
21
 
22
  logger = logging.get_logger(__name__)
 
1354
  return hidden_states, self_attn_weights, present_key_value
1355
 
1356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1357
  class E1PreTrainedModel(PreTrainedModel):
1358
  config_class = E1Config
1359
  config: E1Config