GinnM committed on
Commit dfce08c · verified · 1 Parent(s): 44c054e

Upload TransformerForMaskedLM

Files changed (3)
  1. config.json +2 -1
  2. model.safetensors +3 -0
  3. modeling_transformer.py +600 -0
config.json CHANGED
@@ -7,7 +7,8 @@
   "attention_probs_dropout_prob": 0.1,
   "attn_impl": "sdpa",
   "auto_map": {
-    "AutoConfig": "configuration_transformer.TransformerConfig"
+    "AutoConfig": "configuration_transformer.TransformerConfig",
+    "AutoModelForMaskedLM": "modeling_transformer.TransformerForMaskedLM"
   },
   "decoder_start_token_id": 1,
   "decoder_vocab_size": 36,
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:96a62f322c3c4c8399ab9432c38bb392a175bd82dfdd246de5de4a4bf43d1616
size 1208482648
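The weights file is stored as a Git LFS pointer, so only the hash and size (about 1.2 GB) live in the repository; from_pretrained resolves the pointer automatically. With the new AutoModelForMaskedLM entry in auto_map above, the checkpoint can be loaded through the Auto classes once remote code is trusted. A minimal loading sketch; the repository id below is a placeholder, not taken from this commit:

from transformers import AutoConfig, AutoModelForMaskedLM

repo_id = "GinnM/your-repo-id"  # placeholder; substitute the actual Hub repository

# trust_remote_code is required because the config and model classes ship inside the repo
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)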
modeling_transformer.py ADDED
@@ -0,0 +1,600 @@
from .configuration_transformer import TransformerConfig
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from typing import Tuple, List
from itertools import chain
from transformers.modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
    CausalLMOutput,
)
from torch.utils.checkpoint import checkpoint
from transformers.modeling_utils import PreTrainedModel
import math

try:
    from flash_attn import flash_attn_varlen_func, flash_attn_func
except ImportError:
    pass


A_LARGE_NEGATIVE_NUMER = -1e10


def create_4d_mask(attn_mask, return_type="bool", x=None, causal=False):
    # Positions attend only to positions carrying the same mask value, so the 2D mask acts
    # like a segment id; the result is a [B, 1, L, L] mask broadcastable over attention heads.
    B, L = attn_mask.shape
    device = attn_mask.device
    mask_4d = torch.eq(attn_mask[:, None, :, None], attn_mask[:, None, None, :])
    if causal:
        causal_mask = torch.tril(torch.ones(L, L, device=device)).unsqueeze(0).unsqueeze(0)
        # cast to bool so the logical AND with the boolean segment mask is well-defined
        mask_4d = mask_4d & causal_mask.bool()
    if return_type == "bool":
        return mask_4d.to(torch.bool)
    elif return_type == "float":
        # additive mask: 0 where attention is allowed, a large negative bias where it is not
        mask_4d = mask_4d.to(x.dtype)
        return mask_4d * 0 + (1 - mask_4d) * A_LARGE_NEGATIVE_NUMER


def rotate_half(x):
    return torch.cat((-x[..., x.shape[-1] // 2 :], x[..., : x.shape[-1] // 2]), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def apply_rotary_pos_emb_1(x, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)


class RMSNorm(nn.Module):

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        init_device = None
        self.weight = Parameter(
            torch.empty(dim, device=init_device, dtype=torch.float32)
        )
        init.ones_(self.weight)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class TokenEmbedding(nn.Module):

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        if config.embedding_layer_norm:
            if config.layernorm_type == "layernorm":
                self.rms_norm = nn.LayerNorm(
                    config.hidden_size, eps=config.layer_norm_eps, bias=False
                )  # For name compatibility
            elif config.layernorm_type == "rmsnorm":
                self.rms_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = config.position_embedding_type
        self.padding_idx = config.pad_token_id
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        embeddings = self.word_embeddings(input_ids)
        if self.config.embedding_shrinking and self.training:
            # Embedding shrinking (https://keg.cs.tsinghua.edu.cn/jietang/publications/ICLR23-GLM-130B.pdf)
            embeddings = embeddings * 0.1 + embeddings.detach() * 0.9
        if self.config.embedding_layer_norm:
            embeddings = self.rms_norm(embeddings)
        return embeddings


class RotaryEmbedding(nn.Module):

    def __init__(self, dim: int, b: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (
            b ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        with torch.autocast(device_type=x.device.type, enabled=False):
            freqs = inv_freq_expanded.float() @ position_ids_expanded.float()
            freqs = freqs.transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(x.dtype), sin.to(x.dtype)


class SelfAttention(nn.Module):

    def __init__(self, config: TransformerConfig, causal: bool = False):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.output = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.causal = causal

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )  # [B, L, D] -> [B, L, num_heads, head_size]
        x = x.view(new_x_shape)
        return x.permute(
            0, 2, 1, 3
        )  # [B, L, num_heads, head_size] -> [B, num_heads, L, head_size] for broadcasting later

    def naive_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        rotary_embeddings: torch.Tensor = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor]:
        B, L, D = hidden_states.size()
        query_states = self.query(hidden_states)
        key_states = self.key(hidden_states)
        value_states = self.value(hidden_states)
        key_states = self.transpose_for_scores(key_states).contiguous()  # [B, L, D] -> [B, num_heads, L, head_size]
        query_states = self.transpose_for_scores(query_states).contiguous()
        value_states = self.transpose_for_scores(value_states).contiguous()
        if rotary_embeddings is not None:
            cos, sin = rotary_embeddings
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin
            )
        scale_factor = self.attention_head_size**-0.5
        attention_scores = torch.matmul(
            query_states, key_states.transpose(-1, -2)
        )  # [B, num_heads, L, L]
        attention_scores = attention_scores * scale_factor
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask
        attention_probs = nn.functional.softmax(
            attention_scores, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attention_probs = self.dropout(attention_probs)
        context = torch.matmul(
            attention_probs, value_states
        )  # [B, num_heads, L, head_size]
        context = context.permute(
            0, 2, 1, 3
        ).contiguous()  # [B, L, num_heads, head_size]
        context = context.view(B, L, D)
        context = self.output(context)
        context = self.output_dropout(context)
        return_attention_probs = attention_probs.detach() if output_attentions else None
        return context, return_attention_probs

    def sdpa_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        rotary_embeddings: torch.Tensor = None,
    ) -> Tuple[torch.Tensor]:
        B, L, D = hidden_states.size()
        query_states = self.query(hidden_states)
        key_states = self.key(hidden_states)
        value_states = self.value(hidden_states)
        key_states = self.transpose_for_scores(key_states).contiguous()
        query_states = self.transpose_for_scores(query_states).contiguous()
        value_states = self.transpose_for_scores(value_states).contiguous()
        if rotary_embeddings is not None:
            cos, sin = rotary_embeddings
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin
            )
        scale_factor = self.attention_head_size**-0.5
        dropout_p = self.config.attention_probs_dropout_prob if self.training else 0
        context = F.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            dropout_p=dropout_p,
            scale=scale_factor,
        )  # [B, num_heads, L, head_size]
        context = context.permute(
            0, 2, 1, 3
        ).contiguous()  # [B, L, num_heads, head_size]
        context = context.view(B, L, D)
        context = self.output(context)
        context = self.output_dropout(context)
        return_attention_probs = None
        return context, return_attention_probs

    def flash_attn_forward(
        self,
        hidden_states: torch.Tensor,
        rotary_embeddings: torch.Tensor = None,
        lengths: List[List[int]] = None,
    ) -> Tuple[torch.Tensor]:
        B, L, D = hidden_states.size()
        NH = self.num_attention_heads
        H = self.attention_head_size

        scale_factor = self.attention_head_size**-0.5
        query_states = self.query(hidden_states)
        key_states = self.key(hidden_states)
        value_states = self.value(hidden_states)

        if lengths is not None:
            # flash_attn_varlen_func expects packed [total_tokens, num_heads, head_size] tensors
            query_states = query_states.view(B * L, NH, H).contiguous()
            key_states = key_states.view(B * L, NH, H).contiguous()
            value_states = value_states.view(B * L, NH, H).contiguous()
            if rotary_embeddings is not None:
                cos, sin = rotary_embeddings
                cos = cos.view(B * L, 1, H)
                sin = sin.view(B * L, 1, H)
                query_states = (query_states * cos) + (rotate_half(query_states) * sin)
                key_states = (key_states * cos) + (rotate_half(key_states) * sin)
            lengths = [0, ] + list(chain(*lengths))
            lengths = torch.tensor(lengths, dtype=torch.int, device=query_states.device)
            max_seqlen = torch.max(lengths)
            cum_seqlen = torch.cumsum(lengths, dim=0, dtype=torch.int)
            context = flash_attn_varlen_func(
                q=query_states,
                k=key_states,
                v=value_states,
                cu_seqlens_q=cum_seqlen,
                cu_seqlens_k=cum_seqlen,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=self.causal,
                return_attn_probs=False,
                softmax_scale=scale_factor,
            )
        else:
            query_states = query_states.view(B, L, NH, H).contiguous()
            key_states = key_states.view(B, L, NH, H).contiguous()
            value_states = value_states.view(B, L, NH, H).contiguous()
            if rotary_embeddings is not None:
                cos, sin = rotary_embeddings
                query_states, key_states = apply_rotary_pos_emb(
                    query_states, key_states, cos, sin, unsqueeze_dim=2
                )
            context = flash_attn_func(
                q=query_states,
                k=key_states,
                v=value_states,
                softmax_scale=scale_factor,
                causal=self.causal,
            )
        context = context.view(B, L, D).contiguous()
        context = self.output(context)
        context = self.output_dropout(context)
        return_attention_probs = None
        return context, return_attention_probs

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        lengths: List[torch.Tensor] = None,
        rotary_embeddings: torch.Tensor = None,
        output_attentions: bool = False,
    ):
        if self.config.attn_impl == "naive":
            return self.naive_forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                rotary_embeddings=rotary_embeddings,
                output_attentions=output_attentions,
            )
        elif self.config.attn_impl == "sdpa":
            return self.sdpa_forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                rotary_embeddings=rotary_embeddings,
            )
        elif self.config.attn_impl == "flash_attn":
            return self.flash_attn_forward(
                hidden_states=hidden_states,
                rotary_embeddings=rotary_embeddings,
                lengths=lengths,
            )
        raise ValueError(f"Unknown attn_impl: {self.config.attn_impl}")


class FeedForwardNetwork(nn.Module):

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.w1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        if config.act_fn == "gelu":
            self.act_fn = nn.GELU()
        elif config.act_fn == "silu":
            self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act_fn(self.w1(hidden_states)))


class TransFormerLayer(nn.Module):

    def __init__(self, config: TransformerConfig, causal=False):
        super().__init__()
        self.config = config
        self.causal = causal
        self.attention = SelfAttention(config, causal=causal)
        self.ffn = FeedForwardNetwork(config)
        if config.layernorm_type == "layernorm":
            self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
            self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
        else:
            self.pre_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
            self.post_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        lengths: List[torch.Tensor] = None,
        rotary_embeddings: torch.Tensor = None,
        output_attentions: bool = False,
    ):
        residual = hidden_states
        hidden_states = self.pre_norm(hidden_states)
        hidden_states, attn_probs = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            lengths=lengths,
            rotary_embeddings=rotary_embeddings,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.post_norm(hidden_states)
        hidden_states = self.ffn(hidden_states)
        hidden_states = residual + hidden_states
        return (hidden_states, attn_probs)


class TransformerCore(nn.Module):

    def __init__(self, config: TransformerConfig, causal=False):
        super().__init__()
        self.config = config
        self.layer = []
        for _ in range(config.num_hidden_layers):
            sub_layer = TransFormerLayer(config, causal=causal)
            self.layer.append(sub_layer)
        self.layer = nn.ModuleList(self.layer)
        if self.config.layernorm_type == "layernorm":
            self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
        else:
            self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        lengths: List[torch.Tensor] = None,
        rotary_embeddings: torch.Tensor = None,
        output_attentions: bool = False,
        output_hidden_states=False,
    ):
        all_hidden_states = []
        all_self_attentions = []
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states.append(hidden_states.detach().cpu())
            if torch.is_grad_enabled() and hidden_states.requires_grad and self.gradient_checkpointing:
                # checkpoint passes these positionally, so the order must match
                # TransFormerLayer.forward: (hidden_states, attention_mask, lengths,
                # rotary_embeddings, output_attentions)
                hidden_states, attn_probs = checkpoint(
                    layer_module,
                    hidden_states,
                    attention_mask,
                    lengths,
                    rotary_embeddings,
                    output_attentions,
                    use_reentrant=False,
                )
            else:
                hidden_states, attn_probs = layer_module(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    lengths=lengths,
                    rotary_embeddings=rotary_embeddings,
                    output_attentions=output_attentions,
                )
            if output_attentions:
                all_self_attentions.append(attn_probs.detach().cpu() if attn_probs is not None else None)
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states.append(hidden_states.detach().cpu())
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class BaseTransformerModel(PreTrainedModel):

    config_class = TransformerConfig
    base_model_prefix = "transformer"


class TransformerModel(BaseTransformerModel):

    def __init__(self, config: TransformerConfig, causal=False):
        super().__init__(config)
        self.config = config
        self.rotary_embedding = RotaryEmbedding(dim=config.hidden_size // config.num_attention_heads)
        self.token_embedding = TokenEmbedding(config)
        self.transformer = TransformerCore(config, causal=causal)
        self.causal = causal

    def enable_gradient_checkpointing(self):
        self.transformer.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.transformer.gradient_checkpointing = False

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        lengths: torch.Tensor = None,
        position_ids: torch.Tensor = None,
        output_attentions=False,
        output_hidden_states=False,
    ) -> BaseModelOutput:
        embeddings = self.token_embedding(input_ids)
        if position_ids is None:
            position_ids = torch.arange(input_ids.size(1)).to(input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if position_ids.shape != input_ids.shape:
            raise ValueError("Position IDs must have the same shape as input_ids")
        rotary_embeddings = self.rotary_embedding(embeddings, position_ids)

        if attention_mask is not None:
            if self.config.attn_impl == "flash_attn":
                raise ValueError("Flash attention does not support specifying attention mask")
            attention_mask = create_4d_mask(
                attention_mask,
                return_type="float",
                x=embeddings,
                causal=self.causal,
            )

        outputs = self.transformer(
            hidden_states=embeddings,
            attention_mask=attention_mask,
            lengths=lengths,
            rotary_embeddings=rotary_embeddings,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        return BaseModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class TransformerForMaskedLM(BaseTransformerModel):

    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        self.model = TransformerModel(config, causal=False)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        lengths: torch.Tensor = None,
        position_ids: torch.Tensor = None,
        labels: torch.Tensor = None,
        output_attentions=False,
        output_hidden_states=False,
    ) -> MaskedLMOutput:
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            lengths=lengths,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        prediction_scores = self.lm_head(sequence_output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            labels = labels.to(prediction_scores.device)
            loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size).float(), labels.view(-1)
            )
        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=sequence_output,
            attentions=outputs.attentions,
        )


class TransformerForCausalLM(BaseTransformerModel):

    def __init__(self, config):
        super().__init__(config)
        self.model = TransformerModel(config, causal=True)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.init_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        lengths: torch.Tensor = None,
        position_ids: torch.Tensor = None,
        labels: torch.Tensor = None,
        output_attentions=False,
        output_hidden_states=False,
        reduction="mean",
    ) -> CausalLMOutput:
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            lengths=lengths,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        prediction_scores = self.lm_head(sequence_output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(reduction=reduction)
            labels = labels.to(prediction_scores.device)
            loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size).to(torch.float32),
                labels.view(-1),
            )
        return CausalLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


TransformerForMaskedLM.register_for_auto_class("AutoModelForMaskedLM")
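For reference, a minimal forward-pass sketch against TransformerForMaskedLM as defined above. Shapes and token ids are illustrative; note that create_4d_mask compares attention-mask values for equality, so an all-ones mask yields full bidirectional attention, and the repository id is again a placeholder:

import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("GinnM/your-repo-id", trust_remote_code=True)  # placeholder id
model.eval()

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))  # toy batch: [batch, seq_len]
attention_mask = torch.ones_like(input_ids)  # same value everywhere -> every position may attend to every other
labels = input_ids.clone()  # toy labels; set positions to -100 to exclude them from the loss

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

print(outputs.loss)          # mean cross-entropy over the label positions
print(outputs.logits.shape)  # [2, 16, vocab_size]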