Keeby-smilyai committed
Commit 72d9cc3 · verified · 1 Parent(s): 88cf0cd

Create modeling_sam2.py

Files changed (1)
  1. modeling_sam2.py +155 -0
modeling_sam2.py ADDED
@@ -0,0 +1,155 @@
# coding=utf-8
# modeling_sam2.py
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig

# -----------------------------
# Config
# -----------------------------
class Sam2Config(PretrainedConfig):
    model_type = "sam2"

    def __init__(self, vocab_size: int = 50257, d_model: int = 384, n_layers: int = 6,
                 n_heads: int = 6, ff_mult: float = 4.0, dropout: float = 0.1,
                 pad_token_id: int = 50256, bos_token_id: int = 50256,
                 eos_token_id: int = 50256, **kwargs):
        # pad/bos/eos all default to 50256, the GPT-2 eos token.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ff_mult = ff_mult
        self.dropout = dropout
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id, **kwargs)

# -----------------------------
# Building blocks
# -----------------------------
class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # root-mean-square normalization over the feature dimension
        norm = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * x * norm


class MHA(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        B, T, C = x.shape
        # project and split into heads: (B, n_heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # causal mask: position i may only attend to positions <= i
        causal = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))
        if attn_mask is not None:
            # padding mask over keys, broadcast to (B, 1, 1, T)
            key_mask = attn_mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(~key_mask.bool(), float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(self.dropout(attn), v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # SwiGLU: silu(w1(x)) gated by w2(x), projected back down by w3
        return self.w3(self.dropout(F.silu(self.w1(x)) * self.w2(x)))


class Block(nn.Module):
    def __init__(self, d_model, n_heads, ff_mult, dropout=0.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = MHA(d_model, n_heads, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        self.ff = SwiGLU(d_model, int(ff_mult * d_model), dropout=dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # pre-norm residual blocks: attention, then feed-forward
        x = x + self.drop(self.attn(self.norm1(x), attn_mask=attn_mask))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x

# -----------------------------
# Main model
# -----------------------------
class Sam2PreTrainedModel(PreTrainedModel):
    config_class = Sam2Config
    base_model_prefix = "sam2"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)


class Sam2Model(Sam2PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]  # lm_head shares its weight with the token embedding

    def __init__(self, config: Sam2Config):
        super().__init__(config)
        # token embeddings only; this architecture uses no explicit positional encodings
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([
            Block(config.d_model, config.n_heads, config.ff_mult, dropout=config.dropout)
            for _ in range(config.n_layers)
        ])
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # weight tying
        self.dropout = nn.Dropout(config.dropout)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        x = self.dropout(self.embed(input_ids))
        for blk in self.blocks:
            x = blk(x, attn_mask=attention_mask)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # next-token prediction: predict token t+1 from positions <= t
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

# -----------------------------
# AutoModel registration
# -----------------------------
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("sam2", Sam2Config)
AutoModelForCausalLM.register(Sam2Config, Sam2Model)
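
For reference, a minimal usage sketch of the classes this file adds (not part of the commit). It assumes the file is importable as modeling_sam2, uses an illustrative toy batch, and writes to a hypothetical "sam2-tiny" directory:

    import torch
    from modeling_sam2 import Sam2Config, Sam2Model

    config = Sam2Config(vocab_size=50257, d_model=384, n_layers=6, n_heads=6)
    model = Sam2Model(config)

    # Toy batch: 2 sequences of 16 token ids, with the second one right-padded.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[1, 12:] = 0
    labels = input_ids.masked_fill(attention_mask == 0, -100)  # padding is ignored in the loss

    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(out.loss.item(), out.logits.shape)  # scalar loss, torch.Size([2, 16, 50257])

    # Round-trip through the standard save/load path.
    model.save_pretrained("sam2-tiny")
    reloaded = Sam2Model.from_pretrained("sam2-tiny")

Because the module registers the "sam2" type with AutoConfig and AutoModelForCausalLM at import time, the same checkpoint can also be reloaded through the Auto classes in any session that has imported this module, provided the "sam2" model type is not already claimed by the installed transformers version.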