Tousifahamed committed on
Commit 3afe7b3 · verified · 1 Parent(s): 59470db

Upload 2 files

Files changed (2)
  1. app.py +73 -0
  2. model.py +304 -0
app.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ from transformers import AutoTokenizer
+ from model import TransformerModel  # Model definition from model.py
+ import gradio as gr
+ 
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+ 
+ # Load the model
+ def load_model(checkpoint_path):
+     # Initialize the model (the configuration must match the trained checkpoint)
+     model = TransformerModel(
+         vocab_size=49152,
+         hidden_size=576,
+         num_hidden_layers=30,
+         num_attention_heads=9,
+         intermediate_size=1536,
+         num_key_value_heads=3,
+         max_position_embeddings=2048,
+         rms_norm_eps=1e-5,
+         hidden_act="silu",
+         tie_word_embeddings=True,
+     )
+ 
+     # Load the checkpoint (a dict with a "model_state_dict" entry)
+     checkpoint = torch.load(checkpoint_path, map_location="cpu")
+     model.load_state_dict(checkpoint["model_state_dict"])
+     model.eval()
+     return model
+ 
+ # Load the model
+ model = load_model("checkpoint_5050_quantized.pt")
+ 
+ # Function to generate text
+ def generate_text(prompt, max_length=50, temperature=1.0, top_k=50):
+     # Encode the prompt
+     input_ids = tokenizer.encode(prompt, return_tensors="pt")
+ 
+     # Generate text (cast slider values, which Gradio may pass as floats, to ints)
+     with torch.no_grad():
+         output_ids = model.generate(
+             input_ids,
+             max_length=int(max_length),
+             temperature=temperature,
+             top_k=int(top_k),
+             do_sample=True,
+         )
+ 
+     # Decode the generated text
+     generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     return generated_text
+ 
+ # Gradio interface callback
+ def gradio_generate_text(prompt, max_length, temperature, top_k):
+     return generate_text(prompt, max_length, temperature, top_k)
+ 
+ # Create the Gradio app
+ interface = gr.Interface(
+     fn=gradio_generate_text,
+     inputs=[
+         gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
+         gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max Length"),
+         gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Temperature"),
+         gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k Sampling"),
+     ],
+     outputs=gr.Textbox(label="Generated Text"),
+     title="Text Generation with SMOL-LM2",
+     description="Generate text using the SMOL-LM2 model.",
+ )
+ 
+ # Launch the app
+ interface.launch()
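
Note on the checkpoint: load_model above expects checkpoint_5050_quantized.pt to be a dictionary containing a "model_state_dict" entry. The checkpoint file itself is not part of this commit; the following is only a minimal sketch of how a compatible file would be written (the exact saving code used for this checkpoint is an assumption):

    # Hypothetical saving side (not in this repo): store the state dict under the key app.py reads
    torch.save({"model_state_dict": model.state_dict()}, "checkpoint_5050_quantized.pt")
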
model.py ADDED
@@ -0,0 +1,304 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.checkpoint import checkpoint
+ from typing import Optional
+ 
+ class RMSNorm(nn.Module):
+     """
+     Root Mean Square Layer Normalization (RMSNorm).
+     """
+     def __init__(self, hidden_size: int, eps: float = 1e-5):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.eps = eps
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         variance = x.pow(2).mean(-1, keepdim=True)
+         x = x * torch.rsqrt(variance + self.eps)
+         return self.weight * x
+ 
+ class RotaryPositionalEmbedding(nn.Module):
+     """
+     Rotary Positional Embedding (RoPE) for transformers.
+     """
+     def __init__(self, dim: int, theta: float = 10000.0):
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+ 
+     def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
+         """
+         Apply rotary positional embedding to the input tensor.
+ 
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, num_heads, head_dim).
+             seq_len (int): Sequence length.
+ 
+         Returns:
+             torch.Tensor: Output tensor with rotary positional embeddings applied.
+         """
+         batch_size, seq_len, num_heads, head_dim = x.shape
+ 
+         # Generate position indices
+         position = torch.arange(seq_len, dtype=torch.float32, device=x.device).unsqueeze(-1)
+ 
+         # Generate inverse frequencies
+         freqs = torch.exp(
+             torch.arange(0, head_dim, 2, dtype=torch.float32, device=x.device)
+             * -(torch.log(torch.tensor(self.theta)) / head_dim)
+         )
+ 
+         # Compute sinusoids
+         sinusoid = position * freqs
+         sin = torch.sin(sinusoid)
+         cos = torch.cos(sinusoid)
+ 
+         # Reshape sin and cos to broadcast over batch and heads
+         sin = sin.unsqueeze(0).unsqueeze(2)  # Shape: (1, seq_len, 1, head_dim // 2)
+         cos = cos.unsqueeze(0).unsqueeze(2)  # Shape: (1, seq_len, 1, head_dim // 2)
+ 
+         # Rotate pairs of channels
+         x_rotated = x.clone()
+         x_rotated[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
+         x_rotated[..., 1::2] = x[..., 1::2] * cos + x[..., 0::2] * sin
+ 
+         return x_rotated
+ 
+ class TransformerBlock(nn.Module):
+     """
+     A single transformer block with self-attention and feed-forward layers.
+     """
+     def __init__(
+         self,
+         hidden_size: int,
+         num_attention_heads: int,
+         intermediate_size: int,
+         num_key_value_heads: int,
+         rms_norm_eps: float,
+         hidden_act: str = "silu",
+     ):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.head_dim = hidden_size // num_attention_heads
+ 
+         # Ensure the hidden size is divisible by the number of attention heads
+         if hidden_size % num_attention_heads != 0:
+             raise ValueError(
+                 f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
+             )
+ 
+         # Self-attention projections (grouped-query attention: fewer key/value heads than query heads)
+         self.q_proj = nn.Linear(hidden_size, hidden_size)
+         self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim)
+         self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim)
+         self.o_proj = nn.Linear(hidden_size, hidden_size)
+ 
+         # Feed-forward layers (gated MLP)
+         self.gate_proj = nn.Linear(hidden_size, intermediate_size)
+         self.up_proj = nn.Linear(hidden_size, intermediate_size)
+         self.down_proj = nn.Linear(intermediate_size, hidden_size)
+ 
+         # Normalization layers
+         self.input_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
+         self.post_attention_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
+ 
+         # Activation function
+         self.act = nn.SiLU() if hidden_act == "silu" else nn.GELU()
+ 
+         # Rotary positional embedding
+         self.rope = RotaryPositionalEmbedding(self.head_dim)
+ 
+     def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         # Use gradient checkpointing only while training; run the block directly for inference
+         if self.training and torch.is_grad_enabled():
+             return checkpoint(self._forward, x, attention_mask, use_reentrant=False)
+         return self._forward(x, attention_mask)
+ 
+     def _forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         # Self-attention
+         residual = x
+         x = self.input_norm(x)
+ 
+         batch_size, seq_len, _ = x.shape
+ 
+         # Project inputs to queries, keys, and values
+         q = self.q_proj(x).view(batch_size, seq_len, self.num_attention_heads, self.head_dim)
+         k = self.k_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
+         v = self.v_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
+ 
+         # Apply rotary positional embedding to queries and keys
+         q = self.rope(q, seq_len)
+         k = self.rope(k, seq_len)
+ 
+         # Move the head dimension forward: (batch, heads, seq_len, head_dim)
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+ 
+         # Repeat key/value heads to match the number of query heads (grouped-query attention)
+         if self.num_key_value_heads != self.num_attention_heads:
+             n_rep = self.num_attention_heads // self.num_key_value_heads
+             k = k.repeat_interleave(n_rep, dim=1)
+             v = v.repeat_interleave(n_rep, dim=1)
+ 
+         # Scaled dot-product attention; fall back to a causal mask when none is provided
+         attn_output = F.scaled_dot_product_attention(
+             q, k, v, attn_mask=attention_mask, is_causal=attention_mask is None
+         )
+         attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+ 
+         # Add residual connection
+         x = residual + attn_output
+ 
+         # Feed-forward network
+         residual = x
+         x = self.post_attention_norm(x)
+         gate = self.act(self.gate_proj(x))
+         up = self.up_proj(x)
+         ff_output = self.down_proj(gate * up)
+ 
+         # Add residual connection
+         x = residual + ff_output
+ 
+         return x
+ 
+ class TransformerModel(nn.Module):
+     """
+     The full transformer model with multiple layers.
+     """
+     def __init__(
+         self,
+         vocab_size: int,
+         hidden_size: int,
+         num_hidden_layers: int,
+         num_attention_heads: int,
+         intermediate_size: int,
+         num_key_value_heads: int,
+         max_position_embeddings: int,
+         rms_norm_eps: float,
+         hidden_act: str = "silu",
+         tie_word_embeddings: bool = True,
+     ):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.max_position_embeddings = max_position_embeddings
+ 
+         # Embedding layers
+         self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
+         self.embed_positions = nn.Embedding(max_position_embeddings, hidden_size)
+ 
+         # Transformer blocks
+         self.layers = nn.ModuleList([
+             TransformerBlock(
+                 hidden_size=hidden_size,
+                 num_attention_heads=num_attention_heads,
+                 intermediate_size=intermediate_size,
+                 num_key_value_heads=num_key_value_heads,
+                 rms_norm_eps=rms_norm_eps,
+                 hidden_act=hidden_act,
+             )
+             for _ in range(num_hidden_layers)
+         ])
+ 
+         # Final normalization layer
+         self.final_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
+ 
+         # Output layer (tied to the input embeddings if specified)
+         self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
+         if tie_word_embeddings:
+             self.lm_head.weight = self.embed_tokens.weight
+ 
+     def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         # Embed tokens and positions
+         seq_len = input_ids.size(1)
+         position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
+         token_embeddings = self.embed_tokens(input_ids)
+         position_embeddings = self.embed_positions(position_ids)
+         x = token_embeddings + position_embeddings
+ 
+         # Pass through the transformer layers
+         for layer in self.layers:
+             x = layer(x, attention_mask)
+ 
+         # Final normalization
+         x = self.final_norm(x)
+ 
+         # Output logits
+         logits = self.lm_head(x)
+         return logits
+ 
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_length: int = 50,
+         temperature: float = 1.0,
+         top_k: int = 50,
+         do_sample: bool = True,
+     ) -> torch.Tensor:
+         """
+         Generate text autoregressively.
+ 
+         Args:
+             input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len).
+             max_length (int): Maximum length of the generated sequence.
+             temperature (float): Sampling temperature. Higher values mean more random sampling.
+             top_k (int): Top-k sampling. Only the top-k tokens are considered.
+             do_sample (bool): Whether to sample from the distribution or take the argmax.
+ 
+         Returns:
+             torch.Tensor: Generated token IDs of shape (batch_size, max_length).
+         """
+         self.eval()
+         with torch.no_grad():
+             for _ in range(max_length - input_ids.size(1)):
+                 # Get the logits for the last position
+                 logits = self(input_ids)[:, -1, :]
+ 
+                 # Apply temperature
+                 logits = logits / temperature
+ 
+                 # Top-k filtering: mask out everything below the k-th largest logit
+                 if top_k > 0:
+                     top_k_values, _ = torch.topk(logits, top_k)
+                     logits[logits < top_k_values[:, -1].unsqueeze(-1)] = -float("Inf")
+ 
+                 # Convert logits to probabilities
+                 probs = F.softmax(logits, dim=-1)
+ 
+                 # Sample or take the argmax
+                 if do_sample:
+                     next_token = torch.multinomial(probs, num_samples=1)
+                 else:
+                     next_token = torch.argmax(probs, dim=-1, keepdim=True)
+ 
+                 # Append the next token to the running sequence
+                 input_ids = torch.cat([input_ids, next_token], dim=-1)
+ 
+         return input_ids
+ 
+ # Create the model based on the configuration
+ def create_model_from_config(config: dict) -> TransformerModel:
+     model_config = config["model"]["model_config"]
+     return TransformerModel(
+         vocab_size=model_config["vocab_size"],
+         hidden_size=model_config["hidden_size"],
+         num_hidden_layers=model_config["num_hidden_layers"],
+         num_attention_heads=model_config["num_attention_heads"],
+         intermediate_size=model_config["intermediate_size"],
+         num_key_value_heads=model_config["num_key_value_heads"],
+         max_position_embeddings=model_config["max_position_embeddings"],
+         rms_norm_eps=model_config["rms_norm_eps"],
+         hidden_act=model_config["hidden_act"],
+         tie_word_embeddings=model_config["tie_word_embeddings"],
+     )
+ 
+ # Example usage
+ if __name__ == "__main__":
+     import json
+ 
+     # Load the configuration file
+     with open("config_smollm2_135M.json", "r") as f:
+         config = json.load(f)
+ 
+     # Create the model
+     model = create_model_from_config(config)
+     print(model)
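
For reference, create_model_from_config reads the nested config["model"]["model_config"] dictionary. The config file is not included in this commit; a hypothetical config_smollm2_135M.json with that layout, shown here as the equivalent Python dict (key names come from the code above, and the values mirror the ones hard-coded in app.py, otherwise an assumption):

    config = {
        "model": {
            "model_config": {
                "vocab_size": 49152,
                "hidden_size": 576,
                "num_hidden_layers": 30,
                "num_attention_heads": 9,
                "intermediate_size": 1536,
                "num_key_value_heads": 3,
                "max_position_embeddings": 2048,
                "rms_norm_eps": 1e-5,
                "hidden_act": "silu",
                "tie_word_embeddings": True,
            }
        }
    }
    model = create_model_from_config(config)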