Tousifahamed committed on
Commit 5013b2b · verified · 1 Parent(s): 2f16a35

Upload 4 files

Files changed (4)
  1. app.py +40 -0
  2. model.py +301 -0
  3. model_weights_fp16.pt +3 -0
  4. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,40 @@
+ import torch
+ import torch.nn as nn
+ from model import TransformerModel  # or however you define your model classes
+ from transformers import AutoTokenizer
+ import gradio as gr
+
+ # Load the half-precision state_dict
+ checkpoint = torch.load("model_weights_fp16.pt", map_location="cpu")
+ state_dict_fp16 = checkpoint["model_state_dict"]
+
+ # Create the model with the same configuration the weights were trained with
+ model = TransformerModel(
+     vocab_size=49152,
+     hidden_size=576,
+     num_hidden_layers=30,
+     num_attention_heads=9,
+     intermediate_size=1536,
+     num_key_value_heads=3,
+     max_position_embeddings=2048,
+     rms_norm_eps=1e-5,
+     hidden_act="silu",
+     tie_word_embeddings=True,
+ )
+
+ # Convert the model to half precision (FP16 inference on CPU can be slow)
+ model.half()
+
+ # Load the half-precision weights (strict=False tolerates missing/unexpected keys)
+ model.load_state_dict(state_dict_fp16, strict=False)
+ model.eval()
+
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+
+ def generate_text(prompt, max_length=50):
+     input_ids = tokenizer.encode(prompt, return_tensors="pt")
+     with torch.no_grad():
+         output_ids = model.generate(input_ids, max_length=max_length, do_sample=True)
+     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+ gr.Interface(fn=generate_text, inputs="text", outputs="text").launch()
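Note: the Interface above only exposes the prompt, so max_length always stays at its default of 50. If you also want to control the generation length from the UI, a possible variant (a sketch, not part of this upload; the Slider/Textbox components and the demo name are illustrative) would pass both arguments of generate_text through Gradio components:

# Sketch only: alternative interface that also exposes max_length.
import gradio as gr
# generate_text is the function defined above in app.py.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(10, 200, value=50, step=1, label="Max length"),
    ],
    outputs=gr.Textbox(label="Generated text"),
)
demo.launch()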
model.py ADDED
@@ -0,0 +1,301 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.checkpoint import checkpoint
+ from typing import Optional
+
+ class RMSNorm(nn.Module):
+     """
+     Root Mean Square Layer Normalization (RMSNorm).
+     """
+     def __init__(self, hidden_size: int, eps: float = 1e-5):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         variance = x.pow(2).mean(-1, keepdim=True)
+         x = x * torch.rsqrt(variance + self.eps)
+         return self.weight * x
+
+ class RotaryPositionalEmbedding(nn.Module):
+     """
+     Rotary Positional Embedding (RoPE) for transformers.
+     """
+     def __init__(self, dim: int, theta: float = 10000.0):
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+
+     def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
+         """
+         Apply rotary positional embedding to the input tensor.
+
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, num_heads, head_dim).
+             seq_len (int): Sequence length.
+
+         Returns:
+             torch.Tensor: Output tensor with rotary positional embeddings applied.
+         """
+         batch_size, seq_len, num_heads, head_dim = x.shape
+
+         # Generate position indices
+         position = torch.arange(seq_len, dtype=torch.float32, device=x.device).unsqueeze(-1)
+
+         # Generate inverse frequencies for each pair of dimensions
+         freqs = torch.exp(
+             torch.arange(0, head_dim, 2, dtype=torch.float32, device=x.device) * -(math.log(self.theta) / head_dim)
+         )
+
+         # Compute sinusoids
+         sinusoid = position * freqs
+         sin = torch.sin(sinusoid)
+         cos = torch.cos(sinusoid)
+
+         # Reshape sin and cos to broadcast against the input tensor
+         sin = sin.unsqueeze(0).unsqueeze(2)  # Shape: (1, seq_len, 1, head_dim // 2)
+         cos = cos.unsqueeze(0).unsqueeze(2)  # Shape: (1, seq_len, 1, head_dim // 2)
+
+         # Rotate even/odd dimension pairs
+         x_rotated = x.clone()
+         x_rotated[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
+         x_rotated[..., 1::2] = x[..., 1::2] * cos + x[..., 0::2] * sin
+
+         return x_rotated
+
+ class TransformerBlock(nn.Module):
+     """
+     A single transformer block with self-attention and feed-forward layers.
+     """
+     def __init__(
+         self,
+         hidden_size: int,
+         num_attention_heads: int,
+         intermediate_size: int,
+         num_key_value_heads: int,
+         rms_norm_eps: float,
+         hidden_act: str = "silu",
+     ):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.head_dim = hidden_size // num_attention_heads
+
+         # Ensure the hidden size is divisible by the number of attention heads
+         if hidden_size % num_attention_heads != 0:
+             raise ValueError(
+                 f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
+             )
+
+         # Self-attention projections (grouped-query attention: fewer key/value heads than query heads)
+         self.q_proj = nn.Linear(hidden_size, hidden_size)
+         self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim)
+         self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim)
+         self.o_proj = nn.Linear(hidden_size, hidden_size)
+
+         # Feed-forward layers (gate/up/down projections)
+         self.gate_proj = nn.Linear(hidden_size, intermediate_size)
+         self.up_proj = nn.Linear(hidden_size, intermediate_size)
+         self.down_proj = nn.Linear(intermediate_size, hidden_size)
+
+         # Normalization layers
+         self.input_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
+         self.post_attention_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
+
+         # Activation function
+         self.act = nn.SiLU() if hidden_act == "silu" else nn.GELU()
+
+         # Rotary positional embedding
+         self.rope = RotaryPositionalEmbedding(self.head_dim)
+
+     def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         # Use gradient checkpointing only while training; during inference the
+         # block runs directly, avoiding redundant recomputation under no_grad.
+         if self.training and torch.is_grad_enabled():
+             return checkpoint(self._forward, x, attention_mask, use_reentrant=False)
+         return self._forward(x, attention_mask)
+
+     def _forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         # Self-attention
+         residual = x
+         x = self.input_norm(x)
+
+         batch_size, seq_len, _ = x.shape
+
+         # Project inputs to queries, keys, and values
+         q = self.q_proj(x).view(batch_size, seq_len, self.num_attention_heads, self.head_dim)
+         k = self.k_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
+         v = self.v_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
+
+         # Apply rotary positional embedding
+         q = self.rope(q, seq_len)
+         k = self.rope(k, seq_len)
+
+         # Move heads to dim 1, (batch, num_heads, seq_len, head_dim), as expected by
+         # F.scaled_dot_product_attention, and repeat key/value heads to match the
+         # number of query heads (grouped-query attention).
+         n_rep = self.num_attention_heads // self.num_key_value_heads
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2).repeat_interleave(n_rep, dim=1)
+         v = v.transpose(1, 2).repeat_interleave(n_rep, dim=1)
+
+         # Scaled dot-product attention (causal when no explicit mask is provided)
+         attn_output = F.scaled_dot_product_attention(
+             q, k, v, attn_mask=attention_mask, is_causal=attention_mask is None
+         )
+         attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+
+         # Add residual connection
+         x = residual + attn_output
+
+         # Feed-forward network
+         residual = x
+         x = self.post_attention_norm(x)
+         gate = self.act(self.gate_proj(x))
+         up = self.up_proj(x)
+         ff_output = self.down_proj(gate * up)
+
+         # Add residual connection
+         x = residual + ff_output
+
+         return x
+
+ class TransformerModel(nn.Module):
+     def __init__(
+         self,
+         vocab_size: int,
+         hidden_size: int,
+         num_hidden_layers: int,
+         num_attention_heads: int,
+         intermediate_size: int,
+         num_key_value_heads: int,
+         max_position_embeddings: int,
+         rms_norm_eps: float,
+         hidden_act: str = "silu",
+         tie_word_embeddings: bool = True,
+     ):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.max_position_embeddings = max_position_embeddings
+
+         # Embedding layers (skip quantization for these)
+         self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
+         self.embed_positions = nn.Embedding(max_position_embeddings, hidden_size)
+
+         # Transformer blocks
+         self.layers = nn.ModuleList([
+             TransformerBlock(
+                 hidden_size=hidden_size,
+                 num_attention_heads=num_attention_heads,
+                 intermediate_size=intermediate_size,
+                 num_key_value_heads=num_key_value_heads,
+                 rms_norm_eps=rms_norm_eps,
+                 hidden_act=hidden_act,
+             )
+             for _ in range(num_hidden_layers)
+         ])
+
+         # Final normalization layer
+         self.final_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
+
+         # Output layer (tied to input embeddings if specified)
+         self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
+         if tie_word_embeddings:
+             self.lm_head.weight = self.embed_tokens.weight
+
+     def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         # Embed tokens and positions
+         seq_len = input_ids.size(1)
+         position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
+         token_embeddings = self.embed_tokens(input_ids)
+         position_embeddings = self.embed_positions(position_ids)
+         x = token_embeddings + position_embeddings
+
+         # Pass through transformer layers
+         for layer in self.layers:
+             x = layer(x, attention_mask)
+
+         # Final normalization
+         x = self.final_norm(x)
+
+         # Output logits
+         logits = self.lm_head(x)
+         return logits
+
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_length: int = 50,
+         temperature: float = 1.0,
+         top_k: int = 50,
+         do_sample: bool = True,
+     ) -> torch.Tensor:
+         """
+         Generate text autoregressively.
+
+         Args:
+             input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len).
+             max_length (int): Maximum length of the generated sequence.
+             temperature (float): Sampling temperature. Higher values mean more random sampling.
+             top_k (int): Top-k sampling. Only the top-k tokens are considered.
+             do_sample (bool): Whether to sample from the distribution or take the argmax.
+
+         Returns:
+             torch.Tensor: Generated token IDs of shape (batch_size, max_length).
+         """
+         self.eval()
+         with torch.no_grad():
+             for _ in range(max_length - input_ids.size(1)):
+                 # Get the logits for the last token
+                 logits = self(input_ids)[:, -1, :]
+
+                 # Apply temperature
+                 logits = logits / temperature
+
+                 # Top-k sampling: mask out everything below the k-th largest logit
+                 if top_k > 0:
+                     top_k_values, top_k_indices = torch.topk(logits, top_k)
+                     logits[logits < top_k_values[:, -1].unsqueeze(-1)] = -float("Inf")
+
+                 # Convert logits to probabilities
+                 probs = F.softmax(logits, dim=-1)
+
+                 # Sample or take the argmax
+                 if do_sample:
+                     next_token = torch.multinomial(probs, num_samples=1)
+                 else:
+                     next_token = torch.argmax(probs, dim=-1, keepdim=True)
+
+                 # Append the next token to the input_ids
+                 input_ids = torch.cat([input_ids, next_token], dim=-1)
+
+         return input_ids
+
+ # Create the model based on the configuration
+ def create_model_from_config(config: dict) -> TransformerModel:
+     model_config = config["model"]["model_config"]
+     return TransformerModel(
+         vocab_size=model_config["vocab_size"],
+         hidden_size=model_config["hidden_size"],
+         num_hidden_layers=model_config["num_hidden_layers"],
+         num_attention_heads=model_config["num_attention_heads"],
+         intermediate_size=model_config["intermediate_size"],
+         num_key_value_heads=model_config["num_key_value_heads"],
+         max_position_embeddings=model_config["max_position_embeddings"],
+         rms_norm_eps=model_config["rms_norm_eps"],
+         hidden_act=model_config["hidden_act"],
+         tie_word_embeddings=model_config["tie_word_embeddings"],
+     )
+
+ # Example usage
+ if __name__ == "__main__":
+     import json
+
+     # Load the configuration file
+     with open("config_smollm2_135M.json", "r") as f:
+         config = json.load(f)
+
+     # Create the model
+     model = create_model_from_config(config)
+     print(model)
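The __main__ block above reads a config_smollm2_135M.json file with the nested model.model_config layout that create_model_from_config expects, but that file is not part of this upload. As a hedged illustration, the sketch below builds an equivalent config dict directly in Python; the values mirror the hyperparameters hard-coded in app.py, and the real JSON file may differ:

# Sketch only: a config dict with the structure create_model_from_config expects.
# Values mirror app.py; the actual config_smollm2_135M.json may differ.
from model import create_model_from_config

config = {
    "model": {
        "model_config": {
            "vocab_size": 49152,
            "hidden_size": 576,
            "num_hidden_layers": 30,
            "num_attention_heads": 9,
            "intermediate_size": 1536,
            "num_key_value_heads": 3,
            "max_position_embeddings": 2048,
            "rms_norm_eps": 1e-5,
            "hidden_act": "silu",
            "tie_word_embeddings": True,
        }
    }
}

model = create_model_from_config(config)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")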
model_weights_fp16.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:11150c6cc5e3e2602aa0b04724f581c96b936b71ef5c95e5a52da4001b28fa49
+ size 328474466
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ transformers
+ gradio