Update modeling_imp.py
modeling_imp.py (+11 -0)
@@ -63,6 +63,17 @@ except:
 
 logger = logging.get_logger(__name__)
 
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # pylint: disable=E1102
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
 class PhiRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
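For context, _get_unpad_data is the helper that Flash Attention 2 code paths in transformers use to turn a padded attention mask into the flattened token indices, cumulative sequence lengths (cu_seqlens), and maximum sequence length expected by variable-length attention kernels. Below is a minimal, self-contained sketch of how the new helper behaves on a toy mask; it reproduces the function from the patch above (minus the pylint pragma) and assumes only that torch is installed, with the same import aliases the module already relies on.

import torch
import torch.nn.functional as F

def _get_unpad_data(attention_mask):
    # Per-sequence count of real (non-padding) tokens.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Positions of real tokens in the flattened (batch * seq_len) view.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths with a leading 0, as expected by varlen attention kernels.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

# Two sequences of lengths 3 and 1, right-padded to length 4.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 0, 0, 0]])
indices, cu_seqlens, max_len = _get_unpad_data(mask)
print(indices)     # tensor([0, 1, 2, 4])
print(cu_seqlens)  # tensor([0, 3, 4], dtype=torch.int32)
print(max_len)     # 3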