Text Generation
Safetensors
imp
custom_code
iambestfeed commited on
Commit
cfc5695
·
verified ·
1 Parent(s): dcf1608

Update modeling_imp.py

Browse files
Files changed (1) hide show
  1. modeling_imp.py +11 -0
modeling_imp.py CHANGED
@@ -63,6 +63,17 @@ except:
63
 
64
  logger = logging.get_logger(__name__)
65
 
 
 
 
 
 
 
 
 
 
 
 
66
  # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
67
  class PhiRotaryEmbedding(nn.Module):
68
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
 
63
 
64
  logger = logging.get_logger(__name__)
65
 
66
+ def _get_unpad_data(attention_mask):
67
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
68
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
69
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
70
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) # pylint: disable=E1102
71
+ return (
72
+ indices,
73
+ cu_seqlens,
74
+ max_seqlen_in_batch,
75
+ )
76
+
77
  # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
78
  class PhiRotaryEmbedding(nn.Module):
79
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):