eunhwanpark-motiftech commited on
Commit
a3eb76a
·
verified ·
1 Parent(s): 45c496d

fix-decoding (#5)

Browse files

- Update modeling_motif.py (1d1f01bba1e93bbed14c810b5b7487d271cae810)

Files changed (1) hide show
  1. modeling_motif.py +2 -2
modeling_motif.py CHANGED
@@ -399,7 +399,7 @@ class MotifAttention(nn.Module):
399
  "removed and `position_embeddings` will be mandatory.")
400
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
401
  else:
402
- cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
403
  if use_cache else position_embeddings)
404
 
405
  query_states, key_states = apply_rotary_pos_emb(query_states,
@@ -534,7 +534,7 @@ class MotifFlashAttention2(MotifAttention):
534
  "removed and `position_embeddings` will be mandatory.")
535
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
536
  else:
537
- cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
538
  if use_cache else position_embeddings)
539
 
540
  query_states, key_states = apply_rotary_pos_emb(query_states,
 
399
  "removed and `position_embeddings` will be mandatory.")
400
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
401
  else:
402
+ cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_seq_length(q_len, self.layer_idx))
403
  if use_cache else position_embeddings)
404
 
405
  query_states, key_states = apply_rotary_pos_emb(query_states,
 
534
  "removed and `position_embeddings` will be mandatory.")
535
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
536
  else:
537
+ cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_seq_length(q_len, self.layer_idx))
538
  if use_cache else position_embeddings)
539
 
540
  query_states, key_states = apply_rotary_pos_emb(query_states,