Upload modeling_e1.py with huggingface_hub
Browse files- modeling_e1.py +3 -274
modeling_e1.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 3 |
|
| 4 |
-
import numpy as np
|
| 5 |
-
import networkx as nx
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
|
@@ -10,15 +8,15 @@ from torch.nn.utils.rnn import pad_sequence
|
|
| 10 |
|
| 11 |
from einops import rearrange, repeat
|
| 12 |
from enum import Enum
|
| 13 |
-
from typing import Any, TypedDict, Callable,
|
| 14 |
from dataclasses import dataclass
|
| 15 |
from tokenizers import Tokenizer
|
| 16 |
from transformers import PretrainedConfig, PreTrainedModel
|
| 17 |
from transformers.activations import ACT2FN
|
| 18 |
from transformers.modeling_outputs import ModelOutput
|
| 19 |
from transformers.utils import logging
|
| 20 |
-
|
| 21 |
-
from embedding_mixin import EmbeddingMixin, Pooler
|
| 22 |
|
| 23 |
|
| 24 |
logger = logging.get_logger(__name__)
|
|
@@ -1356,275 +1354,6 @@ class DecoderLayer(nn.Module):
|
|
| 1356 |
return hidden_states, self_attn_weights, present_key_value
|
| 1357 |
|
| 1358 |
|
| 1359 |
-
### Support for embedding datasets with low code
|
| 1360 |
-
class _LegacyPooler:
    """Reduce token-level embeddings ``(b, L, d)`` to one vector per sequence.

    Every requested pooling type produces a ``(b, d)`` tensor. Calling the
    pooler concatenates those tensors along the last axis, in the order the
    pooling types were given, yielding ``(b, len(pooling_types) * d)``.
    """

    def __init__(self, pooling_types: List[str]):
        """Store the pooling names and build the name -> method dispatch table.

        Args:
            pooling_types: Names of the poolings to apply, in output order.
                Supported names are the keys of ``self.pooling_options``.
                An unknown name raises ``KeyError`` when the pooler is called.
        """
        self.pooling_types = pooling_types
        self.pooling_options = {
            'mean': self.mean_pooling,
            'max': self.max_pooling,
            'norm': self.norm_pooling,
            'median': self.median_pooling,
            'std': self.std_pooling,
            'var': self.var_pooling,
            'cls': self.cls_pooling,
            'parti': self._pool_parti,
        }

    def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
        """Collapse dimension 1 of ``attentions`` by taking the element-wise maximum."""
        return attentions.max(dim=1).values

    def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
        """Run PageRank on an attention matrix interpreted as a weighted digraph.

        Args:
            attention_matrix: Square array of edge weights, one row per token.
            personalization: Forwarded to ``nx.pagerank``.
            nstart: Forwarded to ``nx.pagerank``.
            prune_type: Accepted for interface compatibility; not used here.

        Returns:
            The mapping returned by ``nx.pagerank`` (node -> score).

        Raises:
            ValueError: If the graph's node count differs from the number of
                rows in ``attention_matrix``, or the graph has no edges.
        """
        graph = self._convert_to_graph(attention_matrix)
        num_tokens = attention_matrix.shape[0]
        if graph.number_of_nodes() != num_tokens:
            raise ValueError(
                f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {graph.number_of_nodes()} nodes for {num_tokens} tokens.")
        if graph.number_of_edges() == 0:
            raise ValueError("You don't seem to have any attention edges left in the graph.")

        return nx.pagerank(graph, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)

    def _convert_to_graph(self, matrix):
        """Build a ``networkx.DiGraph`` whose edge weights are the entries of ``matrix``."""
        return nx.from_numpy_array(matrix, create_using=nx.DiGraph)

    def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
        """Normalise importance scores of the unmasked tokens so they sum to 1.

        The input mapping is left untouched; masked entries are filtered into
        a new list instead of being deleted from the caller's dict.

        Args:
            dict_importance: Mapping of token index -> importance score.
            attention_mask: Optional per-token mask indexable by the mapping's
                keys; entries equal to 0 are dropped.

        Returns:
            1-D ``np.ndarray`` of normalised weights, in the mapping's key order.
        """
        if attention_mask is None:
            kept = list(dict_importance.values())
        else:
            kept = [score for idx, score in dict_importance.items() if attention_mask[idx] != 0]

        total = sum(kept)
        return np.array([score / total for score in kept])

    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):  # (b, L, d) -> (b, d)
        """Average token embeddings weighted by PageRank scores over attention.

        Args:
            emb: Token embeddings, ``(b, L, d)``.
            attentions: Attention weights; after the max over dimension 1 each
                batch item must be a square ``(L, L)`` matrix.
            attention_mask: Optional ``(b, L)`` mask. When omitted, every
                position is treated as valid.

        Returns:
            ``(b, d)`` tensor on ``emb``'s device with ``emb``'s dtype, so it
            can be concatenated with the other pooling outputs.

        Raises:
            ValueError: If ``attentions`` is None.
        """
        if attentions is None:
            raise ValueError("'parti' pooling requires attention weights, but attentions is None.")

        # detach/cpu first: .numpy() fails on CUDA tensors and on tensors that require grad.
        maxed_attentions = self._create_pooled_matrices_across_layers(attentions).detach().cpu().numpy()
        if attention_mask is None:
            attention_mask = torch.ones(emb.shape[:2], dtype=torch.long, device=emb.device)
        masks = attention_mask.detach().cpu()
        token_embs = emb.detach().cpu()

        pooled = []
        for tokens, attn, mask in zip(token_embs, maxed_attentions, masks):
            scores = self._page_rank(attn)
            weights = self._calculate_importance_weights(scores, mask)
            # Select exactly the positions the weights were computed for,
            # rather than assuming the valid tokens are the first ones.
            valid_tokens = tokens[mask.bool()].numpy()
            pooled.append(np.average(valid_tokens, weights=weights, axis=0))
        return torch.tensor(np.array(pooled)).to(device=emb.device, dtype=emb.dtype)

    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Mean over the sequence axis, counting only unmasked positions when a mask is given."""
        if attention_mask is None:
            return emb.mean(dim=1)
        mask = attention_mask.unsqueeze(-1)  # (b, L, 1)
        return (emb * mask).sum(dim=1) / mask.sum(dim=1)

    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Max over the sequence axis, ignoring masked positions when a mask is given."""
        if attention_mask is None:
            return emb.max(dim=1).values
        padded = ~attention_mask.unsqueeze(-1).bool()  # (b, L, 1)
        # Fill padding with the dtype's minimum instead of 0, so that padding
        # can never win against negative real values.
        return emb.masked_fill(padded, torch.finfo(emb.dtype).min).max(dim=1).values

    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """L2 norm over the sequence axis; masked positions are zeroed and so add nothing."""
        if attention_mask is None:
            return emb.norm(dim=1, p=2)
        mask = attention_mask.unsqueeze(-1)
        return (emb * mask).norm(dim=1, p=2)

    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Median over the sequence axis, ignoring masked positions when a mask is given."""
        if attention_mask is None:
            return emb.median(dim=1).values
        padded = ~attention_mask.unsqueeze(-1).bool()  # (b, L, 1)
        # Mark padding as NaN and let nanmedian skip it; zeroing the padding
        # would pull the median toward 0.
        return emb.masked_fill(padded, float('nan')).nanmedian(dim=1).values

    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Standard deviation over the sequence axis.

        With a mask this is the square root of ``var_pooling`` (divides by the
        number of unmasked positions); without one it is ``emb.std(dim=1)``.
        """
        if attention_mask is None:
            return emb.std(dim=1)
        return torch.sqrt(self.var_pooling(emb, attention_mask, **kwargs))

    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Variance over the sequence axis.

        NOTE(review): the masked branch divides by the count of unmasked
        positions, while the unmasked branch uses ``emb.var(dim=1)`` with
        torch's default correction, so the two branches can differ even for an
        all-ones mask. Kept as-is to avoid changing existing embeddings.
        """
        if attention_mask is None:
            return emb.var(dim=1)
        mask = attention_mask.unsqueeze(-1)  # (b, L, 1)
        count = mask.sum(dim=1)  # (b, 1)
        mean = ((emb * mask).sum(dim=1) / count).unsqueeze(1)  # (b, 1, d)
        squared_diff = (emb - mean) ** 2  # (b, L, d)
        # Only unmasked positions contribute to the sum.
        return (squared_diff * mask).sum(dim=1) / count  # (b, d)

    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Return the embedding at sequence position 0 for every batch item."""
        return emb[:, 0, :]

    def __call__(
        self,
        emb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attentions: Optional[torch.Tensor] = None
    ):  # [mean, max]
        """Apply every configured pooling and concatenate the results.

        Args:
            emb: Token embeddings, ``(b, L, d)``.
            attention_mask: Optional ``(b, L)`` mask forwarded to each pooling.
            attentions: Optional attention weights; needed by 'parti' pooling.

        Returns:
            ``(b, n_pooling_types * d)`` tensor.
        """
        pooled = [
            self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)  # (b, d)
            for pooling_type in self.pooling_types
        ]
        return torch.cat(pooled, dim=-1)  # (b, n_pooling_types * d)
|
| 1485 |
-
|
| 1486 |
-
|
| 1487 |
-
class _LegacyEmbeddingMixin:
    """Mixin adding bulk sequence embedding with optional SQLite / ``.pth`` caching.

    The host class must provide ``self.config.hidden_size``, ``parameters()``
    and an ``_embed`` implementation that ``embed_dataset`` can call as
    ``self._embed(list_of_sequences, return_attention_mask=True)`` and that
    returns ``(embeddings, attention_mask)``.
    """

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """Embedding hook to be overridden by the host model.

        ``**kwargs`` is accepted so that the keyword call made by
        ``embed_dataset`` reaches this stub and fails with
        ``NotImplementedError`` rather than a ``TypeError``.
        """
        raise NotImplementedError

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Read the set of sequences stored in the ``embeddings`` table of a SQLite database."""
        import sqlite3
        from contextlib import closing

        # sqlite3's own context manager only commits/rolls back; closing()
        # is what actually releases the connection.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.execute("SELECT sequence FROM embeddings")
            return {row[0] for row in cursor}

    def embed_dataset(
        self,
        sequences: List[str],
        #tokenizer: PreTrainedTokenizerBase, # For E1, the tokenizing is handled by _embed
        batch_size: int = 2,
        max_len: int = 512,
        truncate: bool = True,
        full_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        pooling_types: Optional[List[str]] = None,
        sql: bool = False,
        save: bool = True,
        sql_db_path: str = 'embeddings.db',
        save_path: str = 'embeddings.pth',
        **kwargs,
    ) -> Optional[dict[str, torch.Tensor]]:
        """Embed a dataset of protein sequences.

        Args:
            sequences: List of protein sequences
            batch_size: Batch size for processing
            max_len: Maximum sequence length (used only when ``truncate`` is True)
            truncate: Whether to cut each sequence to ``max_len`` characters
            full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
            embed_dtype: dtype the embeddings are cast to on the ``.pth`` path
            pooling_types: Pooling names passed to ``Pooler``; defaults to ``['mean']``.
                More than one pooling type will be concatenated together
            sql: Whether to store embeddings in SQLite database - will be stored in float32
            save: Whether to write the embedding dictionary to ``save_path``
            sql_db_path: Path to SQLite database
            save_path: Path of the ``.pth`` file used for caching and saving

        Returns:
            Dictionary mapping sequences to embeddings, or None if sql=True

        Note:
            - If sql=True, embeddings can only be stored in float32
            - sql is ideal if you need to stream a very large dataset for training in real-time
            - save=True is ideal if you can store the entire embedding dictionary in RAM
            - sql will be used if it is True and save is True or False
            - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
            - Sequences are de-duplicated, truncated to max_len (if truncate) and sorted by length in descending order

        Example:
            >>> embedding_dict = model.embed_dataset(
                    sequences=[
                        'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
                    ],
                    batch_size=2, # adjust for your GPU memory
                    max_len=512, # adjust for your needs
                    full_embeddings=False, # if True, no pooling is performed
                    embed_dtype=torch.float32, # cast to what dtype you want
                    pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
                    sql=False, # if True, embeddings will be stored in SQLite database
                    sql_db_path='embeddings.db',
                    save=True, # if True, embeddings will be saved as a .pth file
                    save_path='embeddings.pth',
                )
            >>> # embedding_dict maps each sequence to its embedding tensor
        """
        # A None default avoids sharing one mutable list across calls.
        if pooling_types is None:
            pooling_types = ['mean']

        unique_sequences = {seq[:max_len] if truncate else seq for seq in sequences}
        sequences = sorted(unique_sequences, key=len, reverse=True)
        hidden_size = self.config.hidden_size
        pooler = Pooler(pooling_types) if not full_embeddings else None

        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
            # Already pooled (2-D) or residue-wise output requested: pass through.
            if full_embeddings or residue_embeddings.ndim == 2:
                return residue_embeddings
            return pooler(residue_embeddings, attention_mask)

        def embed_batch(seqs: List[str], cast: Callable[[torch.Tensor], torch.Tensor]):
            """Embed one batch and return ``(sequence, cpu_embedding)`` pairs."""
            residue_embeddings, attention_mask = self._embed(seqs, return_attention_mask=True)
            embeddings = cast(get_embeddings(residue_embeddings, attention_mask))
            pairs = []
            for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                if full_embeddings:
                    # Drop padded positions so each entry is (num_tokens, hidden_size).
                    emb = emb[mask.bool()].reshape(-1, hidden_size)
                pairs.append((seq, emb.cpu()))
            return pairs

        if sql:
            import sqlite3
            conn = sqlite3.connect(sql_db_path)
            try:
                c = conn.cursor()
                c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
                conn.commit()
                already_embedded = self._read_sequences_from_db(sql_db_path)
                to_embed = [seq for seq in sequences if seq not in already_embedded]
                print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
                print(f"Embedding {len(to_embed)} new sequences")
                if to_embed:
                    with torch.no_grad():
                        for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
                            seqs = to_embed[batch_start:batch_start + batch_size]
                            # sql requires float32
                            for seq, emb in embed_batch(seqs, lambda t: t.float()):
                                c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.numpy().tobytes()))
                            # Commit per batch so finished work survives an interruption.
                            conn.commit()
                conn.commit()
            finally:
                # Release the connection even if embedding fails part-way.
                conn.close()
            return None

        embeddings_dict = {}
        if os.path.exists(save_path):
            embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True)
            to_embed = [seq for seq in sequences if seq not in embeddings_dict]
            print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
            print(f"Embedding {len(to_embed)} new sequences")
        else:
            to_embed = sequences
            print(f"Embedding {len(to_embed)} new sequences")

        if to_embed:
            with torch.no_grad():
                for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
                    seqs = to_embed[batch_start:batch_start + batch_size]
                    for seq, emb in embed_batch(seqs, lambda t: t.to(embed_dtype)):
                        embeddings_dict[seq] = emb

        if save:
            torch.save(embeddings_dict, save_path)

        return embeddings_dict
|
| 1626 |
-
|
| 1627 |
-
|
| 1628 |
class E1PreTrainedModel(PreTrainedModel):
|
| 1629 |
config_class = E1Config
|
| 1630 |
config: E1Config
|
|
|
|
| 1 |
import os
|
| 2 |
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 3 |
|
|
|
|
|
|
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
|
|
|
| 8 |
|
| 9 |
from einops import rearrange, repeat
|
| 10 |
from enum import Enum
|
| 11 |
+
from typing import Any, TypedDict, Callable, List
|
| 12 |
from dataclasses import dataclass
|
| 13 |
from tokenizers import Tokenizer
|
| 14 |
from transformers import PretrainedConfig, PreTrainedModel
|
| 15 |
from transformers.activations import ACT2FN
|
| 16 |
from transformers.modeling_outputs import ModelOutput
|
| 17 |
from transformers.utils import logging
|
| 18 |
+
|
| 19 |
+
from .embedding_mixin import EmbeddingMixin, Pooler
|
| 20 |
|
| 21 |
|
| 22 |
logger = logging.get_logger(__name__)
|
|
|
|
| 1354 |
return hidden_states, self_attn_weights, present_key_value
|
| 1355 |
|
| 1356 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1357 |
class E1PreTrainedModel(PreTrainedModel):
|
| 1358 |
config_class = E1Config
|
| 1359 |
config: E1Config
|