damerajee committed
Commit 8dd438b · verified · 1 Parent(s): 475027c

Update modeling_Llamoe.py

Files changed (1)
  1. modeling_Llamoe.py +133 -162
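Since the commit touches the repo's custom modeling code, the usual way to exercise it is to load the checkpoint with remote code enabled. A minimal sketch, assuming the repo's config.json maps this file via auto_map; the repo id below is a placeholder, not taken from this page:

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "damerajee/Llamoe"  # placeholder repo id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # expected to resolve to the LlamoeForCausalLM defined in this file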
modeling_Llamoe.py CHANGED
@@ -33,10 +33,9 @@ from transformers.utils import (
33
  replace_return_docstrings,
34
  )
35
  from transformers.utils.import_utils import is_torch_fx_available
36
- from .configuration_Llamoe import LlamoeConfig
37
 
38
  from math import sqrt as math_sqrt
39
- _CONFIG_FOR_DOC = "LlamoeConfig"
40
 
41
 
42
  if is_flash_attn_2_available():
@@ -53,8 +52,10 @@ if is_torch_fx_available():
53
 
54
  _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
55
 
56
- def approx_gelu(x):
57
- return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
 
 
58
 
59
  def load_balancing_loss_func(
60
  gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
@@ -130,36 +131,40 @@ def load_balancing_loss_func(
130
 
131
 
132
 
 
 
 
133
  def _get_unpad_data(attention_mask):
134
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
135
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
136
  max_seqlen_in_batch = seqlens_in_batch.max().item()
137
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
138
  return (
139
  indices,
140
  cu_seqlens,
141
  max_seqlen_in_batch,
142
  )
143
 
 
 
144
  class LlamoeRMSNorm(nn.Module):
145
- def __init__(self, hidden_size, eps=1e-6):
146
- """
147
- LlamaRMSNorm is equivalent to T5LayerNorm
148
- """
149
  super().__init__()
150
- self.weight = nn.Parameter(torch.ones(hidden_size))
151
- self.variance_epsilon = eps
152
-
153
- def forward(self, hidden_states):
154
- input_dtype = hidden_states.dtype
155
- hidden_states = hidden_states.to(torch.float32)
156
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
157
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
158
- return self.weight * hidden_states.to(input_dtype)
159
 
 
 
 
 
160
 
161
- ALL_LAYERNORM_LAYERS.append(LlamoeRMSNorm)
 
 
 
 
162
 
 
163
 
164
  class LlamoeRotaryEmbedding(nn.Module):
165
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -205,63 +210,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
205
  k_embed = (k * cos) + (rotate_half(k) * sin)
206
  return q_embed, k_embed
207
 
208
-
209
-
210
- class LlamoeBlockSparseTop2MLP(nn.Module):
211
- def __init__(self, config: LlamoeConfig):
212
- super().__init__()
213
- self.ffn_dim = config.intermediate_size
214
- self.hidden_dim = config.hidden_size
215
-
216
- self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
217
- self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
218
- self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
219
-
220
- self.act_fn = approx_gelu
221
-
222
- def forward(self, hidden_states):
223
- current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
224
- current_hidden_states = self.w2(current_hidden_states)
225
- return current_hidden_states.to(hidden_states.dtype)
226
-
227
- class LlamoeSparseMoeBlock(nn.Module):
228
- def __init__(self, config):
229
- super().__init__()
230
- self.hidden_dim = config.hidden_size
231
- self.ffn_dim = config.intermediate_size
232
- self.num_experts = config.num_local_experts
233
- self.top_k = 2
234
-
235
- # gating
236
- self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
237
-
238
- self.experts = nn.ModuleList([LlamoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
239
-
240
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
241
- batch_size, sequence_length, hidden_dim = hidden_states.shape
242
- hidden_states = hidden_states.view(-1, hidden_dim)
243
-
244
- # router_logits: (batch * sequence_length, n_experts)
245
- router_logits = self.gate(hidden_states)
246
- routing_weights = F.softmax(router_logits, dim=1)
247
- topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
248
- topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
249
-
250
- hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
251
-
252
- y = torch.empty_like(hidden_states)
253
-
254
- flat_topk_idx = topk_idx.view(-1)
255
- for i in range(self.num_experts):
256
- expert = self.experts[i]
257
- expert_output = expert(hidden_states[flat_topk_idx == i])
258
- y[flat_topk_idx == i] = expert_output
259
-
260
- y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
261
-
262
- final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
263
- return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
264
-
265
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
266
  """
267
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -273,10 +223,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
273
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
274
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
275
 
276
-
277
  class LlamoeAttention(nn.Module):
278
  """Multi-headed attention from 'Attention Is All You Need' paper"""
279
 
 
280
  def __init__(self, config: LlamoeConfig, layer_idx: Optional[int] = None):
281
  super().__init__()
282
  self.config = config
@@ -291,14 +241,14 @@ class LlamoeAttention(nn.Module):
291
  self.attention_dropout = config.attention_dropout
292
  self.hidden_size = config.hidden_size
293
  self.num_heads = config.num_attention_heads
294
- self.head_dim = self.hidden_size // self.num_heads
295
  self.num_key_value_heads = config.num_key_value_heads
296
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
297
  self.max_position_embeddings = config.max_position_embeddings
298
  self.rope_theta = config.rope_theta
299
  self.is_causal = True
300
 
301
- if (self.head_dim * self.num_heads) != self.hidden_size:
302
  raise ValueError(
303
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
304
  f" and `num_heads`: {self.num_heads})."
@@ -307,19 +257,12 @@ class LlamoeAttention(nn.Module):
307
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
308
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
309
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
310
- self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
311
- self._init_rope()
312
-
313
- def _init_rope(self):
314
- if self.config.rope_scaling is None:
315
- self.rotary_emb = LlamoeRotaryEmbedding(
316
- self.head_dim,
317
- max_position_embeddings=self.max_position_embeddings,
318
- base=self.rope_theta,
319
- )
320
-
321
- else:
322
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
323
 
324
  def forward(
325
  self,
@@ -334,35 +277,17 @@ class LlamoeAttention(nn.Module):
334
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
335
  bsz, q_len, _ = hidden_states.size()
336
 
337
- if self.config.pretraining_tp > 1:
338
- key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
339
- query_slices = self.q_proj.weight.split(
340
- (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
341
- )
342
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
343
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
344
-
345
- query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
346
- query_states = torch.cat(query_states, dim=-1)
347
-
348
- key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
349
- key_states = torch.cat(key_states, dim=-1)
350
-
351
- value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
352
- value_states = torch.cat(value_states, dim=-1)
353
-
354
- else:
355
- query_states = self.q_proj(hidden_states)
356
- key_states = self.k_proj(hidden_states)
357
- value_states = self.v_proj(hidden_states)
358
 
359
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
360
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
361
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
362
 
363
  past_key_value = getattr(self, "past_key_value", past_key_value)
364
- cos, sin = self.rotary_emb(value_states, position_ids)
365
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
366
 
367
  if past_key_value is not None:
368
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -375,9 +300,10 @@ class LlamoeAttention(nn.Module):
375
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
376
 
377
  if attention_mask is not None: # no matter the length, we just slice it
378
- causal_mask = attention_mask
379
  if cache_position is not None:
380
  causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
 
 
381
  attn_weights = attn_weights + causal_mask
382
 
383
  # upcast attention to fp32
@@ -393,14 +319,8 @@ class LlamoeAttention(nn.Module):
393
 
394
  attn_output = attn_output.transpose(1, 2).contiguous()
395
 
396
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
397
-
398
- if self.config.pretraining_tp > 1:
399
- attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
400
- o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
401
- attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
402
- else:
403
- attn_output = self.o_proj(attn_output)
404
 
405
  if not output_attentions:
406
  attn_weights = None
@@ -408,9 +328,10 @@ class LlamoeAttention(nn.Module):
408
  return attn_output, attn_weights, past_key_value
409
 
410
 
411
- class LlamoeFlashAttention2(LlamoeAttention):
 
412
  """
413
- Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
414
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
415
  flash attention and deal with padding tokens in case the input contains any of them.
416
  """
@@ -423,6 +344,7 @@ class LlamoeFlashAttention2(LlamoeAttention):
423
  # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
424
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
425
 
 
426
  def forward(
427
  self,
428
  hidden_states: torch.Tensor,
@@ -449,8 +371,8 @@ class LlamoeFlashAttention2(LlamoeAttention):
449
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
450
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
451
 
452
- cos, sin = self.rotary_emb(value_states, position_ids)
453
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
454
 
455
  past_key_value = getattr(self, "past_key_value", past_key_value)
456
 
@@ -471,7 +393,7 @@ class LlamoeFlashAttention2(LlamoeAttention):
471
  # therefore the input hidden states gets silently casted in float32. Hence, we need
472
  # cast them back in the correct dtype just to be sure everything works as expected.
473
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
474
- # in fp32. (LlamaRMSNorm handles it correctly)
475
 
476
  input_dtype = query_states.dtype
477
  if input_dtype == torch.float32:
@@ -497,7 +419,7 @@ class LlamoeFlashAttention2(LlamoeAttention):
497
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
498
  )
499
 
500
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
501
  attn_output = self.o_proj(attn_output)
502
 
503
  if not output_attentions:
@@ -511,7 +433,6 @@ class LlamoeFlashAttention2(LlamoeAttention):
511
  """
512
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
513
  first unpad the input, then computes the attention scores and pad the final attention scores.
514
-
515
  Args:
516
  query_states (`torch.Tensor`):
517
  Input query states to be passed to Flash Attention API
@@ -530,7 +451,7 @@ class LlamoeFlashAttention2(LlamoeAttention):
530
  if not self._flash_attn_uses_top_left_mask:
531
  causal = self.is_causal
532
  else:
533
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
534
  causal = self.is_causal and query_length != 1
535
 
536
  # Contains at least one padding token in the sequence
@@ -603,6 +524,7 @@ class LlamoeFlashAttention2(LlamoeAttention):
603
  )
604
 
605
 
 
606
  class LlamoeSdpaAttention(LlamoeAttention):
607
  """
608
  Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -624,7 +546,7 @@ class LlamoeSdpaAttention(LlamoeAttention):
624
  if output_attentions:
625
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
626
  logger.warning_once(
627
- "LlamoeModel is using LlamoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
628
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
629
  )
630
  return super().forward(
@@ -660,9 +582,9 @@ class LlamoeSdpaAttention(LlamoeAttention):
660
  key_states = repeat_kv(key_states, self.num_key_value_groups)
661
  value_states = repeat_kv(value_states, self.num_key_value_groups)
662
 
663
-
664
- causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
665
- causal_mask = causal_mask.to(dtype=query_states.dtype)
666
 
667
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
668
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -671,11 +593,6 @@ class LlamoeSdpaAttention(LlamoeAttention):
671
  key_states = key_states.contiguous()
672
  value_states = value_states.contiguous()
673
 
674
- print("query:",query_states.shape)
675
- print("keys:",key_states.shape)
676
- print("values:",value_states.shape)
677
- print("causal_mask:",causal_mask.shape)
678
-
679
  attn_output = torch.nn.functional.scaled_dot_product_attention(
680
  query_states,
681
  key_states,
@@ -690,23 +607,80 @@ class LlamoeSdpaAttention(LlamoeAttention):
690
  attn_output = self.o_proj(attn_output)
691
 
692
  return attn_output, None, past_key_value
693
-
694
 
695
  LLAMOE_ATTENTION_CLASSES = {
696
  "eager": LlamoeAttention,
697
  "flash_attention_2": LlamoeFlashAttention2,
698
- "sdpa": LlamoeSdpaAttention
699
  }
700
701
 
 
 
702
  class LlamoeDecoderLayer(nn.Module):
703
  def __init__(self, config: LlamoeConfig, layer_idx: int):
704
  super().__init__()
705
  self.hidden_size = config.hidden_size
706
 
707
- self.self_attn = LLAMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
708
 
709
- self.block_sparse_moe = LlamoeBlockSparseTop2MLP(config)
710
  self.input_layernorm = LlamoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
711
  self.post_attention_layernorm = LlamoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
712
 
@@ -777,18 +751,15 @@ class LlamoeDecoderLayer(nn.Module):
777
  return outputs
778
 
779
 
780
-
781
- LLAMOE_START_DOCSTRING = r"""
782
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
783
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
784
  etc.)
785
-
786
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
787
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
788
  and behavior.
789
-
790
  Parameters:
791
- config ([`LlamaConfig`]):
792
  Model configuration class with all the parameters of the model. Initializing with a config file does not
793
  load the weights associated with the model, only the configuration. Check out the
794
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -796,11 +767,11 @@ LLAMOE_START_DOCSTRING = r"""
796
 
797
 
798
  @add_start_docstrings(
799
- "The bare Llamoe Model outputting raw hidden-states without any specific head on top.",
800
- LLAMOE_START_DOCSTRING,
801
  )
802
 
803
- class LlammoePreTrainedModel(PreTrainedModel):
804
  config_class = LlamoeConfig
805
  base_model_prefix = "model"
806
  supports_gradient_checkpointing = True
@@ -844,7 +815,7 @@ class LlammoePreTrainedModel(PreTrainedModel):
844
  layer.self_attn.past_key_value = None
845
 
846
 
847
- LLAMOE_INPUTS_DOCSTRING = r"""
848
  Args:
849
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
850
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -908,14 +879,14 @@ LLAMOE_INPUTS_DOCSTRING = r"""
908
 
909
  @add_start_docstrings(
910
  "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
911
- LLAMOE_START_DOCSTRING,
912
  )
913
 
914
- class LlamoeModel(LlammoePreTrainedModel):
915
  """
916
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamoeDecoderLayer`]
917
  Args:
918
- config: LlamoeConfig
919
  """
920
 
921
  def __init__(self, config: LlamoeConfig):
@@ -945,7 +916,7 @@ class LlamoeModel(LlammoePreTrainedModel):
945
  def set_input_embeddings(self, value):
946
  self.embed_tokens = value
947
 
948
- @add_start_docstrings_to_model_forward(LLAMOE_INPUTS_DOCSTRING)
949
  def forward(
950
  self,
951
  input_ids: torch.LongTensor = None,
@@ -1121,12 +1092,12 @@ class LlamoeModel(LlammoePreTrainedModel):
1121
 
1122
  return causal_mask
1123
 
1124
- class LlamoeForCausalLM(LlammoePreTrainedModel):
1125
  _tied_weights_keys = ["lm_head.weight"]
1126
 
1127
  def __init__(self, config):
1128
  super().__init__(config)
1129
- self.model = LlamoeModel(config)
1130
  self.vocab_size = config.vocab_size
1131
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1132
  self.router_aux_loss_coef = config.router_aux_loss_coef
@@ -1153,7 +1124,7 @@ class LlamoeForCausalLM(LlammoePreTrainedModel):
1153
  def get_decoder(self):
1154
  return self.model
1155
 
1156
- @add_start_docstrings_to_model_forward(LLAMOE_INPUTS_DOCSTRING)
1157
  @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1158
  # Ignore copy
1159
  def forward(
 
33
  replace_return_docstrings,
34
  )
35
  from transformers.utils.import_utils import is_torch_fx_available
36
+ from .configuration_gemmoe import LlamoeConfig
37
 
38
  from math import sqrt as math_sqrt
 
39
 
40
 
41
  if is_flash_attn_2_available():
 
52
 
53
  _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
54
 
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _CONFIG_FOR_DOC = "LlamoeConfig"
59
 
60
  def load_balancing_loss_func(
61
  gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
 
131
 
132
 
133
 
134
+ def approx_gelu(x):
135
+ return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
136
+
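The approx_gelu helper added here is the standard tanh-based GELU approximation. As a quick standalone check (not part of the commit), it should agree with PyTorch's built-in tanh-approximate GELU, available as F.gelu(..., approximate="tanh") in PyTorch 1.12+:

import math
import torch
import torch.nn.functional as F

def approx_gelu(x):
    # Same tanh-based approximation as the helper added in this commit.
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))

x = torch.randn(4, 8)
print(torch.allclose(approx_gelu(x), F.gelu(x, approximate="tanh"), atol=1e-6))  # True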
137
  def _get_unpad_data(attention_mask):
138
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
139
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
140
  max_seqlen_in_batch = seqlens_in_batch.max().item()
141
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
142
  return (
143
  indices,
144
  cu_seqlens,
145
  max_seqlen_in_batch,
146
  )
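To see what _get_unpad_data produces for the flash-attention unpad path, here is a small worked example (a standalone copy of the helper above, applied to a toy padding mask):

import torch
import torch.nn.functional as F

def _get_unpad_data(attention_mask):
    # Standalone copy of the helper shown above.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])  # two sequences of true lengths 3 and 2, padded to 4
indices, cu_seqlens, max_len = _get_unpad_data(mask)
print(indices)     # tensor([0, 1, 2, 4, 5]) -> positions of real tokens in the flattened batch
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32) -> cumulative sequence lengths
print(max_len)     # 3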
147
 
148
+
149
+
150
  class LlamoeRMSNorm(nn.Module):
151
+ def __init__(self, dim: int, eps: float = 1e-6):
 
 
 
152
  super().__init__()
153
+ self.eps = eps
154
+ self.weight = nn.Parameter(torch.zeros(dim))
 
 
 
 
 
 
 
155
 
156
+ def _norm(self, x):
157
+ x_float = x.float()
158
+ normed_x = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
159
+ return normed_x
160
 
161
+ def forward(self, x):
162
+ normed_x = self._norm(x)
163
+ # Downcast the result to the original dtype at the end
164
+ normed_x = normed_x.type_as(x)
165
+ return normed_x * (self.weight + 1)
166
 
167
+ ALL_LAYERNORM_LAYERS.append(LlamoeRMSNorm)
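The rewritten norm stores a zero-initialized weight and scales the normalized activations by (weight + 1), Gemma-style, whereas the removed version used a ones-initialized weight and scaled by weight directly. A minimal sketch (not part of the commit) showing the two parameterizations coincide at initialization; porting an existing checkpoint would mean shifting the old weight by -1:

import torch
import torch.nn as nn

class NewStyleRMSNorm(nn.Module):
    # Mirrors the LlamoeRMSNorm added above: zero-init weight, output scaled by (weight + 1).
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        x_f = x.float()
        normed = x_f * torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * (self.weight + 1)

class OldStyleRMSNorm(nn.Module):
    # Mirrors the removed T5/Llama-style norm: ones-init weight, output scaled by weight.
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x_f = x.float()
        normed = x_f * torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * normed.to(x.dtype)

x = torch.randn(2, 5, 16)
print(torch.allclose(NewStyleRMSNorm(16)(x), OldStyleRMSNorm(16)(x), atol=1e-6))  # True at init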
168
 
169
  class LlamoeRotaryEmbedding(nn.Module):
170
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
 
210
  k_embed = (k * cos) + (rotate_half(k) * sin)
211
  return q_embed, k_embed
212
 
213
+
214
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
215
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
216
  """
217
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 
223
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
224
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
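As the docstring states, repeat_kv expands grouped key/value heads to match the number of query heads. A standalone sketch checking the claimed equivalence with torch.repeat_interleave (the n_rep == 1 early return is assumed from the upstream Llama implementation, since it is not shown in this hunk):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Standalone version of the helper above.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 7, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
print(torch.equal(repeat_kv(kv, 3), torch.repeat_interleave(kv, repeats=3, dim=1)))  # True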
225
 
 
226
  class LlamoeAttention(nn.Module):
227
  """Multi-headed attention from 'Attention Is All You Need' paper"""
228
 
229
+ # Ignore copy
230
  def __init__(self, config: LlamoeConfig, layer_idx: Optional[int] = None):
231
  super().__init__()
232
  self.config = config
 
241
  self.attention_dropout = config.attention_dropout
242
  self.hidden_size = config.hidden_size
243
  self.num_heads = config.num_attention_heads
244
+ self.head_dim = config.head_dim
245
  self.num_key_value_heads = config.num_key_value_heads
246
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
247
  self.max_position_embeddings = config.max_position_embeddings
248
  self.rope_theta = config.rope_theta
249
  self.is_causal = True
250
 
251
+ if self.hidden_size % self.num_heads != 0:
252
  raise ValueError(
253
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
254
  f" and `num_heads`: {self.num_heads})."
 
257
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
258
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
259
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
260
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
261
+ self.rotary_emb = LlamoeRotaryEmbedding(
262
+ self.head_dim,
263
+ max_position_embeddings=self.max_position_embeddings,
264
+ base=self.rope_theta,
265
+ )
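With head_dim now read from config.head_dim instead of being derived as hidden_size // num_heads, the q/k/v projections are sized by num_heads * head_dim while o_proj maps back to hidden_size, so the two no longer have to match. A small shape check with hypothetical sizes (not taken from the actual config):

import torch
import torch.nn as nn

hidden_size, num_heads, head_dim = 2048, 16, 256  # hypothetical: num_heads * head_dim (4096) != hidden_size
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

x = torch.randn(1, 10, hidden_size)
q = q_proj(x).view(1, 10, num_heads, head_dim).transpose(1, 2)  # (bsz, num_heads, q_len, head_dim)
out = o_proj(q.transpose(1, 2).reshape(1, 10, num_heads * head_dim))
print(out.shape)  # torch.Size([1, 10, 2048])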
 
 
 
 
 
 
 
266
 
267
  def forward(
268
  self,
 
277
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
278
  bsz, q_len, _ = hidden_states.size()
279
 
280
+ query_states = self.q_proj(hidden_states)
281
+ key_states = self.k_proj(hidden_states)
282
+ value_states = self.v_proj(hidden_states)
283
 
284
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
285
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
286
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
287
 
288
  past_key_value = getattr(self, "past_key_value", past_key_value)
289
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
290
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
291
 
292
  if past_key_value is not None:
293
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
 
300
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
301
 
302
  if attention_mask is not None: # no matter the length, we just slice it
 
303
  if cache_position is not None:
304
  causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
305
+ else:
306
+ causal_mask = attention_mask
307
  attn_weights = attn_weights + causal_mask
308
 
309
  # upcast attention to fp32
 
319
 
320
  attn_output = attn_output.transpose(1, 2).contiguous()
321
 
322
+ attn_output = attn_output.view(bsz, q_len, -1)
323
+ attn_output = self.o_proj(attn_output)
 
 
 
 
 
 
324
 
325
  if not output_attentions:
326
  attn_weights = None
 
328
  return attn_output, attn_weights, past_key_value
329
 
330
 
331
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemmoe
332
+ class LlamoeFlashAttention2(LlamoeAttention):
333
  """
334
+ Llamoe flash attention module. This module inherits from `LlamoeAttention` as the weights of the module stays
335
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
336
  flash attention and deal with padding tokens in case the input contains any of them.
337
  """
 
344
  # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
345
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
346
 
347
+ # Ignore copy
348
  def forward(
349
  self,
350
  hidden_states: torch.Tensor,
 
371
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
372
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
373
 
374
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
375
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
376
 
377
  past_key_value = getattr(self, "past_key_value", past_key_value)
378
 
 
393
  # therefore the input hidden states gets silently casted in float32. Hence, we need
394
  # cast them back in the correct dtype just to be sure everything works as expected.
395
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
396
+ # in fp32. (GemmoeRMSNorm handles it correctly)
397
 
398
  input_dtype = query_states.dtype
399
  if input_dtype == torch.float32:
 
419
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
420
  )
421
 
422
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
423
  attn_output = self.o_proj(attn_output)
424
 
425
  if not output_attentions:
 
433
  """
434
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
435
  first unpad the input, then computes the attention scores and pad the final attention scores.
 
436
  Args:
437
  query_states (`torch.Tensor`):
438
  Input query states to be passed to Flash Attention API
 
451
  if not self._flash_attn_uses_top_left_mask:
452
  causal = self.is_causal
453
  else:
454
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in GemmoeFlashAttention2 __init__.
455
  causal = self.is_causal and query_length != 1
456
 
457
  # Contains at least one padding token in the sequence
 
524
  )
525
 
526
 
527
+ # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemmoe
528
  class LlamoeSdpaAttention(LlamoeAttention):
529
  """
530
  Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
 
546
  if output_attentions:
547
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
548
  logger.warning_once(
549
+ "GemmoeModel is using GemmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
550
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
551
  )
552
  return super().forward(
 
582
  key_states = repeat_kv(key_states, self.num_key_value_groups)
583
  value_states = repeat_kv(value_states, self.num_key_value_groups)
584
 
585
+ causal_mask = attention_mask
586
+ if attention_mask is not None and cache_position is not None:
587
+ causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
588
 
589
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
590
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
 
593
  key_states = key_states.contiguous()
594
  value_states = value_states.contiguous()
595
 
 
 
 
 
 
596
  attn_output = torch.nn.functional.scaled_dot_product_attention(
597
  query_states,
598
  key_states,
 
607
  attn_output = self.o_proj(attn_output)
608
 
609
  return attn_output, None, past_key_value
610
+
611
 
612
  LLAMOE_ATTENTION_CLASSES = {
613
  "eager": LlamoeAttention,
614
  "flash_attention_2": LlamoeFlashAttention2,
615
+ "sdpa": LlamoeSdpaAttention,
616
  }
617
 
618
+ class LlamoeBlockSparseTop2MLP(nn.Module):
619
+ def __init__(self, config: LlamoeConfig):
620
+ super().__init__()
621
+ self.ffn_dim = config.intermediate_size
622
+ self.hidden_dim = config.hidden_size
623
+
624
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
625
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
626
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
627
+
628
+ self.act_fn = approx_gelu
629
+
630
+ def forward(self, hidden_states):
631
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
632
+ current_hidden_states = self.w2(current_hidden_states)
633
+ return current_hidden_states.to(hidden_states.dtype)
634
+
635
+
636
+ class LlamoeSparseMoeBlock(nn.Module):
637
+ def __init__(self, config):
638
+ super().__init__()
639
+ self.hidden_dim = config.hidden_size
640
+ self.ffn_dim = config.intermediate_size
641
+ self.num_experts = config.num_local_experts
642
+ self.top_k = 2
643
+
644
+ # gating
645
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
646
+
647
+ self.experts = nn.ModuleList([LlamoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
648
+
649
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
650
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
651
+ hidden_states = hidden_states.view(-1, hidden_dim)
652
+
653
+ # router_logits: (batch * sequence_length, n_experts)
654
+ router_logits = self.gate(hidden_states)
655
+ routing_weights = F.softmax(router_logits, dim=1)
656
+ topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
657
+ topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
658
+
659
+ hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
660
+
661
+ y = torch.empty_like(hidden_states)
662
+
663
+ flat_topk_idx = topk_idx.view(-1)
664
+ for i in range(self.num_experts):
665
+ expert = self.experts[i]
666
+ expert_output = expert(hidden_states[flat_topk_idx == i])
667
+ y[flat_topk_idx == i] = expert_output
668
+
669
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
670
+
671
+ final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
672
+ return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
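The routing math in LlamoeSparseMoeBlock.forward is a softmax over the router logits, a top-2 selection, and renormalization of the two selected weights. A toy standalone sketch of just that routing step:

import torch
import torch.nn.functional as F

# 5 tokens, 4 experts, top-2 routing, mirroring the forward pass above.
tokens, num_experts, top_k = 5, 4, 2
router_logits = torch.randn(tokens, num_experts)

routing_weights = F.softmax(router_logits, dim=1)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)  # renormalize over the chosen experts

print(topk_idx)              # which 2 experts each token is routed to
print(topk_weight.sum(-1))   # each row sums to 1 after renormalization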
673
 
674
+
675
+ # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMOE,Llama->Gemmoe
676
  class LlamoeDecoderLayer(nn.Module):
677
  def __init__(self, config: LlamoeConfig, layer_idx: int):
678
  super().__init__()
679
  self.hidden_size = config.hidden_size
680
 
681
+ self.self_attn = LLAMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
682
 
683
+ self.block_sparse_moe = LlamoeSparseMoeBlock(config)
684
  self.input_layernorm = LlamoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
685
  self.post_attention_layernorm = LlamoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
686
 
 
751
  return outputs
752
 
753
 
754
+ Llamoe_START_DOCSTRING = r"""
 
755
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
756
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
757
  etc.)
 
758
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
759
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
760
  and behavior.
 
761
  Parameters:
762
+ config ([`LlamoeConfig`]):
763
  Model configuration class with all the parameters of the model. Initializing with a config file does not
764
  load the weights associated with the model, only the configuration. Check out the
765
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 
767
 
768
 
769
  @add_start_docstrings(
770
+ "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
771
+ Llamoe_START_DOCSTRING,
772
  )
773
 
774
+ class LlamoePreTrainedModel(PreTrainedModel):
775
  config_class = LlamoeConfig
776
  base_model_prefix = "model"
777
  supports_gradient_checkpointing = True
 
815
  layer.self_attn.past_key_value = None
816
 
817
 
818
+ Llamoe_INPUTS_DOCSTRING = r"""
819
  Args:
820
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
821
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
 
879
 
880
  @add_start_docstrings(
881
  "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
882
+ Llamoe_START_DOCSTRING,
883
  )
884
 
885
+ class LlamoeModel(LlamoePreTrainedModel):
886
  """
887
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamoeDecoderLayer`]
888
  Args:
889
+ config: LlamoeConfig
890
  """
891
 
892
  def __init__(self, config: LlamoeConfig):
 
916
  def set_input_embeddings(self, value):
917
  self.embed_tokens = value
918
 
919
+ @add_start_docstrings_to_model_forward(Llamoe_INPUTS_DOCSTRING)
920
  def forward(
921
  self,
922
  input_ids: torch.LongTensor = None,
 
1092
 
1093
  return causal_mask
1094
 
1095
+ class LlamoeForCausalLM(LlamoePreTrainedModel):
1096
  _tied_weights_keys = ["lm_head.weight"]
1097
 
1098
  def __init__(self, config):
1099
  super().__init__(config)
1100
+ self.model = LlamoeModel(config)
1101
  self.vocab_size = config.vocab_size
1102
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1103
  self.router_aux_loss_coef = config.router_aux_loss_coef
 
1124
  def get_decoder(self):
1125
  return self.model
1126
 
1127
+ @add_start_docstrings_to_model_forward(Llamoe_INPUTS_DOCSTRING)
1128
  @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1129
  # Ignore copy
1130
  def forward(