crpatel committed on
Commit
fb26382
·
1 Parent(s): 3fb38b7

gradio app

Files changed (7)
  1. SmolLm3.py +236 -0
  2. app.py +148 -0
  3. config_smollm2_135M.yaml +103 -0
  4. model_testing.py +79 -0
  5. model_weights_35000_step.pt +3 -0
  6. requirements.txt +12 -0
  7. train.py +351 -0
SmolLm3.py ADDED
@@ -0,0 +1,236 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import SiLU
5
+ import yaml
6
+ # from gptdataloader import create_dataloader_v1
7
+ # from chapter5 import calc_loss_loader, calculate_loss_batch
8
+
9
+
10
+ def _init_weights(module, std=0.041666666666666664):
11
+ if isinstance(module, nn.Linear):
12
+ module.weight.data.normal_(mean=0.0, std=std)
13
+ elif isinstance(module, nn.Embedding):
14
+ module.weight.data.normal_(mean=0.0, std=std)
15
+
16
+ class RotaryPositionalEmbedding(nn.Module):
17
+ """
18
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L240
19
+ Rotary Positional Embedding (RoPE) for transformers. Implementation derived from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
20
+ """
21
+ def __init__(self, dim: int, theta: float = 10000.0):
22
+ super().__init__()
23
+ self.dim = dim
24
+ self.theta = theta
25
+
26
+ def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
27
+ """
28
+ Apply rotary positional embedding to the input tensor.
29
+
30
+ Args:
31
+ x (torch.Tensor): Input tensor of shape # B, T, H, D
32
+ seq_len (int): Sequence length. #T
33
+
34
+ Returns:
35
+ torch.Tensor: Output tensor with rotary positional embeddings applied.
36
+ """
37
+ B, T, H, H_D = x.shape
38
+
39
+ # Generate position indices
40
+ position = torch.arange(T, dtype=torch.float32, device=x.device).unsqueeze(-1)
41
+
42
+ # Generate frequencies
43
+ freqs = torch.exp(
44
+ torch.arange(0, H_D, 2, dtype=torch.float32, device=x.device) *
45
+ -(torch.log(torch.tensor(self.theta)) / H_D)
46
+
47
+ )
48
+
49
+ # Compute sinusoids
50
+ sinusoid = position * freqs
51
+ sin = torch.sin(sinusoid)
52
+ cos = torch.cos(sinusoid)
53
+
54
+ # Reshape sin and cos to match the input tensor's shape
55
+ sin = sin.unsqueeze(0).unsqueeze(2) # Shape: (1, T, 1, D // 2)
56
+ cos = cos.unsqueeze(0).unsqueeze(2) # Shape: (1, T, 1, D // 2)
57
+
58
+ # Apply rotary embeddings
59
+ x_rotated = x.clone()
60
+ x_rotated[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
61
+ x_rotated[..., 1::2] = x[..., 1::2] * cos + x[..., 0::2] * sin
62
+
63
+ return x_rotated
64
+
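As a quick sanity check of the rotary module above (a minimal sketch, not part of the commit): each (even, odd) channel pair is rotated by an angle of position * theta^(-2i/d), so the norm of every pair is preserved.

import torch
from SmolLm3 import RotaryPositionalEmbedding

rope = RotaryPositionalEmbedding(dim=64, theta=10000.0)
x = torch.randn(1, 8, 9, 64)                       # (B, T, H, D), as in the docstring
x_rot = rope(x, seq_len=8)
pairs, pairs_rot = x.view(1, 8, 9, 32, 2), x_rot.view(1, 8, 9, 32, 2)
print(torch.allclose(pairs.norm(dim=-1), pairs_rot.norm(dim=-1), atol=1e-5))  # True: rotation preserves pair norms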
65
+ class LlamaAttention(nn.Module):
66
+ """
67
+ (self_attn): LlamaAttention(
68
+ (q_proj): Linear(in_features=576, out_features=576, bias=False)
69
+ (k_proj): Linear(in_features=576, out_features=192, bias=False)
70
+ (v_proj): Linear(in_features=576, out_features=192, bias=False)
71
+ (o_proj): Linear(in_features=576, out_features=576, bias=False)
72
+ )
73
+ """
74
+ def __init__(self, config, rotary_emb):
75
+ super().__init__()
76
+ self.config = config
77
+ self.num_attention_heads = self.config['num_attention_heads']
78
+ self.hidden_size = self.config['hidden_size']
79
+ # Ensure the hidden size is divisible by the number of attention heads
80
+ if self.hidden_size % self.num_attention_heads != 0:
81
+ raise ValueError(
82
+ f"hidden_size ({self.hidden_size}) must be divisible by num_attention_heads ({self.num_attention_heads})"
83
+ )
84
+ self.num_key_value_heads = self.config['num_key_value_heads']
85
+ self.head_dim = self.hidden_size // self.num_attention_heads
86
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) # D,D
87
+ self.k_proj = nn.Linear(self.hidden_size, self.head_dim*self.num_key_value_heads, bias=False) # D,D/H
88
+ self.v_proj = nn.Linear(self.hidden_size, self.head_dim*self.num_key_value_heads, bias=False) # D,D/H
89
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) # D,D
90
+
91
+ # Convert the mask to boolean type when creating it
92
+ # self.register_buffer("mask",
93
+ # torch.triu(torch.ones(self.config['max_position_embeddings'],
94
+ # self.config['max_position_embeddings']),
95
+ # diagonal=1)) # Convert to boolean
96
+
97
+ self.rotary_pos_emb = rotary_emb
98
+
99
+ def forward(self, x):
100
+ B, T, C = x.size()
101
+
102
+ q = self.q_proj(x) # B,T,D
103
+ k = self.k_proj(x) # B,T,D/H
104
+ v = self.v_proj(x) # B,T,D/H
105
+
106
+ q = q.view(B, T, self.num_attention_heads, self.head_dim) # B,T,H,D
107
+ k = k.view(B, T, self.num_key_value_heads, self.head_dim) # B,T,H,D
108
+ v = v.view(B, T, self.num_key_value_heads, self.head_dim) # B,T,H,D
109
+
110
+ q = q.transpose(1,2) # B,H,T,D
111
+ k = k.transpose(1,2) # B,num_key_value_heads,T,D
112
+ v = v.transpose(1,2) # B,num_key_value_heads,T,D
113
+
114
+ # apply rotary positional embedding (note: RotaryPositionalEmbedding.forward documents its input as (B, T, H, D), while q/k here are (B, H, T, D) after the transpose)
115
+ q = self.rotary_pos_emb(q, T)
116
+ k = self.rotary_pos_emb(k, T)
117
+
118
+ # Repeat k/v heads if num_key_value_heads < num_attention_heads
119
+ if self.num_key_value_heads != self.num_attention_heads:
120
+ k = k.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1) # B,kv_head,T,D -> B,H,T,D
121
+ v = v.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1) # B,kv_head,T,D -> B,H,T,D
122
+
123
+ # Manual attention Stats
124
+ # Q(B,H,T,D) @K.T(B,H,D,T) = Q.K_T (B,H,T,T)
125
+ # attn_scores = q @ k.transpose(-2,-1) # B,H,T,T
126
+ # mask_bool = self.mask[:T,:T].bool() # T,T
127
+ # attn_scores.masked_fill_(mask_bool, -torch.inf) # B,H,T,T
128
+ # attn_weights = F.softmax(attn_scores/k.size(-1)**0.5, dim=-1) # B,H,T,T
129
+ # context_vector = attn_weights @ v # B,H,T,T * B,H,T,D = B,H,T,D
130
+ # context_vector = context_vector.transpose(1,2) # B,T,H,D
131
+ # context_vector = context_vector.contiguous().view(B,T,C) # B,T,H,D -> B,T,D
132
+ # Manual attention Stats ENDS
133
+
134
+ # Scaled dot-product attention STARTS
135
+ attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
136
+ context_vector = attn_out.transpose(1,2).reshape(B,T,C)
137
+ # Scaled dot-product attention ENDS
138
+
139
+ context_vector = self.o_proj(context_vector)
140
+
141
+ return context_vector
142
+
143
+
144
+ class LlamaMLP(nn.Module):
145
+ """
146
+ (mlp): LlamaMLP(
147
+ (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
148
+ (up_proj): Linear(in_features=576, out_features=1536, bias=False)
149
+ (down_proj): Linear(in_features=1536, out_features=576, bias=False)
150
+ (act_fn): SiLU()
151
+ )
152
+ """
153
+ def __init__(self, config):
154
+ super().__init__()
155
+ self.config = config
156
+ self.gate_proj = nn.Linear(self.config['hidden_size'], self.config['intermediate_size'], bias=False)
157
+ self.up_proj = nn.Linear(self.config['hidden_size'], self.config['intermediate_size'], bias=False)
158
+ self.down_proj = nn.Linear(self.config['intermediate_size'], self.config['hidden_size'], bias=False)
159
+ self.act_fn = SiLU()
160
+ def forward(self, x):
161
+ gate = self.gate_proj(x)
162
+ up = self.up_proj(x)
163
+ down = self.down_proj(self.act_fn(gate)*up)
164
+ return down
165
+
166
+ class LlamaRMSNorm(nn.Module):
167
+ """
168
+ (norm): LlamaRMSNorm((576,), eps=1e-05)
169
+ # RMSNorm Formula:
170
+ # RMS(x) = sqrt((1 / d) * sum(x_i^2 for i in range(d)))
171
+ # x_normalized = x / RMS(x)
172
+ # output = gamma * x_normalized
173
+
174
+ """
175
+ def __init__(self, config):
176
+ super().__init__()
177
+ self.config = config
178
+ self.eps = self.config['rms_norm_eps']
179
+ self.weight = nn.Parameter(torch.ones(self.config['hidden_size']))
180
+ def forward(self, x):
181
+ rms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
182
+ return self.weight * rms * x
183
+
184
+ class LlamaDecoderLayer(nn.Module):
185
+ def __init__(self, config, rotary_emb):
186
+ super().__init__()
187
+ self.config = config
188
+ self.self_attn = LlamaAttention(self.config, rotary_emb)
189
+ self.mlp = LlamaMLP(self.config)
190
+ self.input_layernorm = LlamaRMSNorm(self.config)
191
+ self.post_attention_layernorm = LlamaRMSNorm(self.config)
192
+
193
+ def forward(self, x):
194
+ residual = x
195
+ x = self.input_layernorm(x)
196
+ x = self.self_attn(x)
197
+ x = x + residual
198
+
199
+ residual = x
200
+ x = self.post_attention_layernorm(x)
201
+ x = self.mlp(x)
202
+ x = x + residual
203
+ return x
204
+ # # x = x + self.self_attn(self.input_layernorm(x))
205
+ # # x = x + self.mlp(self.post_attention_layernorm(x))
206
+ # return x
207
+ class LlamaModel(nn.Module):
208
+ def __init__(self, config):
209
+ super().__init__()
210
+ self.init_method = config['init_method']
211
+ self.config = config['model_config']
212
+ self.embed_tokens = nn.Embedding(self.config['vocab_size'], self.config['hidden_size'])
213
+ self.rotary_emb = RotaryPositionalEmbedding(self.config['hidden_size'], self.config['rope_theta'])
214
+ self.layers = nn.ModuleList([LlamaDecoderLayer(self.config, self.rotary_emb) for _ in range(self.config['num_hidden_layers'])])
215
+ self.norm = LlamaRMSNorm(self.config)
216
+ self.lm_head = nn.Linear(self.config['hidden_size'], self.config['vocab_size'], bias=False)
217
+
218
+ if self.config['tie_word_embeddings']:
219
+ self.lm_head.weight = self.embed_tokens.weight
220
+
221
+ self.apply(lambda m: _init_weights(m, self.init_method['std']))
222
+
223
+ def forward(self, x, y=None):
224
+ x = self.embed_tokens(x)
225
+ for layer in self.layers:
226
+ x = layer(x)
227
+ x = self.norm(x)
228
+ logits = self.lm_head(x) # B,T,V
229
+ logits = logits.view(-1, logits.size(-1)) # Shape: [B*T, V]
230
+ if y is not None:
231
+ y = y.view(-1) # Shape: [B*T]
232
+ loss = torch.nn.functional.cross_entropy(logits, y)
233
+ return logits, loss
234
+ else:
235
+ return logits, None
236
+
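To see how the pieces above fit together, here is a minimal sketch (assuming config_smollm2_135M.yaml from this commit is in the working directory); the forward pass returns flattened logits of shape [B*T, vocab_size] and an optional cross-entropy loss.

import torch
import yaml
from SmolLm3 import LlamaModel

config = yaml.safe_load(open("config_smollm2_135M.yaml"))
model = LlamaModel(config['model'])
tokens = torch.randint(0, config['model']['model_config']['vocab_size'], (2, 16))  # dummy (B, T) token ids
logits, loss = model(tokens)             # no labels -> loss is None
print(logits.shape)                      # torch.Size([32, 49152]) = [B*T, vocab_size]
logits, loss = model(tokens, tokens)     # with labels, a cross-entropy loss is also returned
print(loss.item())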
app.py ADDED
@@ -0,0 +1,148 @@
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer
4
+ import yaml
5
+ from SmolLm3 import LlamaModel
6
+
7
+
8
+ def generate_helper(model, idx, max_new_tokens, context_length, temperature=1.0, top_k=None, eos_token=None, device=None):
9
+
10
+ model = model.to(device)
11
+ idx = idx.to(device)
12
+ model.eval()
13
+ for _ in range(max_new_tokens):
14
+ idx_cond = idx[:, -context_length:]
15
+ with torch.no_grad():
16
+ logits, _ = model(idx_cond) # Unpack both logits and loss (ignore loss)
17
+ logits = logits.view(idx_cond.shape[0], -1, model.config['vocab_size']) # Reshape to [batch, seq, vocab]
18
+
19
+ # Get the logits for the last token only
20
+ logits = logits[:, -1, :] # Shape: [batch_size, vocab_size]
21
+
22
+ if top_k is not None:
23
+ # top k sampling
24
+ top_logits, top_pos = torch.topk(logits, top_k)
25
+ min_logit = top_logits[:, -1].unsqueeze(-1)
26
+ logits = torch.where(logits < min_logit,
27
+ torch.tensor(float('-inf')).to(logits.device),
28
+ logits)
29
+
30
+ # temperature scaling
31
+ if temperature > 0.0:
32
+ logits /= temperature
33
+ probs = torch.softmax(logits, dim=-1)
34
+ idx_next = torch.multinomial(probs, num_samples=1)
35
+ else:
36
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True)
37
+
38
+ if idx_next.item() == eos_token:
39
+ break
40
+
41
+ idx = torch.cat((idx, idx_next), dim=1)
42
+ model.train()
43
+ return idx
44
+
45
+ def get_config(config_path):
46
+ config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader)
47
+ return config
48
+
49
+ def load_model_from_checkpoint(config_path, checkpoint_path, device):
50
+ config = get_config(config_path)
51
+ model = LlamaModel(config['model'])
52
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
53
+ state_dict = checkpoint['model_state_dict']
54
+ state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
55
+ model.load_state_dict(state_dict)
56
+ return model
57
+
58
+ def load_weights(config, weights_path, device):
59
+ model = LlamaModel(config['model'])
60
+ model.load_state_dict(torch.load(weights_path, map_location=torch.device(device)))
61
+ return model
62
+
63
+ def get_tokenizer(config):
64
+ tokenizer_path = config['tokenizer']['tokenizer_name_or_path']
65
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
66
+ tokenizer.pad_token = tokenizer.eos_token
67
+ vocab_size = tokenizer.vocab_size
68
+ return tokenizer, vocab_size
69
+
70
+ def generate_text(model, tokenizer, input_text, max_new_tokens, context_length, temperature, top_k, eos_token, device):
71
+ encoded_text = tokenizer.encode(input_text, return_tensors="pt").to(device)
72
+ generated_text = generate_helper(model,
73
+ idx=encoded_text,
74
+ max_new_tokens=max_new_tokens,
75
+ context_length=context_length,
76
+ temperature=temperature,
77
+ top_k=top_k,
78
+ eos_token=eos_token,
79
+ device=device)
80
+ return tokenizer.decode(generated_text.squeeze(0))
81
+
82
+
83
+
84
+ # Initialize model and tokenizer
85
+ def initialize_model():
86
+ config_path = "config_smollm2_135M.yaml"
87
+ checkpoint_path = "/Users/chiragtagadiya/Documents/Final_training_before_stop_smolllm3/checkpoints/model_37000_steps_avg_loss_2.85920_optimizer_lr_0.00000003.pth" # Update this path
88
+ weights_path = "model_weights_35000_step.pt"
89
+ device = "cuda" if torch.cuda.is_available() else "cpu"
90
+
91
+ # Load configuration
92
+ config = get_config(config_path)
93
+
94
+ # Load model
95
+ # model = load_model_from_checkpoint(config_path, checkpoint_path, device)
96
+ model = load_weights(config, weights_path, device)
97
+ model.to(device)
98
+ model.eval()
99
+
100
+ # Load tokenizer
101
+ tokenizer, vocab_size = get_tokenizer(config)
102
+
103
+ return model, tokenizer, device
104
+
105
+ def generate_response(prompt, max_new_tokens):
106
+ generated_text = generate_text(
107
+ model=model,
108
+ tokenizer=tokenizer,
109
+ input_text=prompt,
110
+ max_new_tokens=max_new_tokens,
111
+ context_length=256,
112
+ temperature=0.9,
113
+ top_k=2,
114
+ eos_token=tokenizer.eos_token_id,
115
+ device=device
116
+ )
117
+ return generated_text
118
+
119
+ # Initialize global variables
120
+ model, tokenizer, device = initialize_model()
121
+
122
+ # Create Gradio interface
123
+ iface = gr.Interface(
124
+ fn=generate_response,
125
+ inputs=[
126
+ gr.Textbox(
127
+ lines=3,
128
+ placeholder="Enter your prompt here...",
129
+ label="Input Prompt"
130
+ ),
131
+ gr.Slider(
132
+ minimum=50,
133
+ maximum=256,
134
+ value=100,
135
+ step=10,
136
+ label="Max New Tokens"
137
+ )
138
+ ],
139
+ outputs=gr.Textbox(
140
+ lines=5,
141
+ label="Generated Text"
142
+ ),
143
+ title="SmolLM Text Generator",
144
+ description="Enter a prompt and adjust the maximum number of tokens to generate text with the SmolLM model."
145
+ )
146
+
147
+ if __name__ == "__main__":
148
+ iface.launch()
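The sampling loop in generate_helper combines top-k filtering with temperature scaling. The same masking applied to a toy logits tensor (values invented for illustration):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
top_logits, _ = torch.topk(logits, k=2)
min_logit = top_logits[:, -1].unsqueeze(-1)                        # smallest surviving logit
masked = torch.where(logits < min_logit, torch.tensor(float('-inf')), logits)
probs = torch.softmax(masked / 0.9, dim=-1)                        # temperature 0.9, as in generate_response
next_token = torch.multinomial(probs, num_samples=1)
print(probs)                                                       # only the two largest logits get non-zero probability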
config_smollm2_135M.yaml ADDED
@@ -0,0 +1,103 @@
1
+ checkpoints:
2
+ checkpoint_interval: 2000
3
+ checkpoints_path: checkpoints
4
+ checkpoints_path_is_shared_file_system: false
5
+ resume_checkpoint_path: null
6
+ save_final_state: false
7
+ save_initial_state: false
8
+ data_stages:
9
+ - data:
10
+ dataset:
11
+ dataset_folder:
12
+ - datasets/smollm2-corpus
13
+ dataset_weights:
14
+ - 1.0
15
+ num_loading_workers: 0
16
+ seed: 8
17
+ name: stable phase
18
+ start_training_step: 1
19
+ general:
20
+ benchmark_csv_path: null
21
+ consumed_train_samples: null
22
+ ignore_sanity_checks: true
23
+ project: smollm2
24
+ run: smollm2-135M
25
+ seed: 8
26
+ step: null
27
+ logging:
28
+ iteration_step_info_interval: 1
29
+ log_level: info
30
+ log_level_replica: info
31
+ model:
32
+ ddp_bucket_cap_mb: 25
33
+ dtype: bfloat16
34
+ init_method:
35
+ std: 0.041666666666666664
36
+ make_vocab_size_divisible_by: 1
37
+ model_config:
38
+ bos_token_id: 0
39
+ eos_token_id: 0
40
+ hidden_act: silu
41
+ hidden_size: 576
42
+ initializer_range: 0.041666666666666664
43
+ intermediate_size: 1536
44
+ is_llama_config: true
45
+ max_position_embeddings: 2048
46
+ num_attention_heads: 9
47
+ num_hidden_layers: 30
48
+ num_key_value_heads: 3
49
+ pad_token_id: null
50
+ pretraining_tp: 1
51
+ rms_norm_eps: 1.0e-05
52
+ rope_interleaved: false
53
+ rope_scaling: null
54
+ rope_theta: 10000.0
55
+ tie_word_embeddings: true
56
+ use_cache: true
57
+ vocab_size: 49152
58
+ s3_bucket: smollm2-train-jan-25-era3
59
+ s3_checkpoint_folder: checkpoints
60
+ s3_log_folder: logs
61
+ s3_log_file_name: training.log
62
+ optimizer:
63
+ accumulate_grad_in_fp32: true
64
+ clip_grad: 1.0
65
+ learning_rate_scheduler:
66
+ learning_rate: 0.003
67
+ lr_decay_starting_step: 1600000
68
+ lr_decay_steps: 400000
69
+ lr_decay_style: linear
70
+ lr_warmup_steps: 2000
71
+ lr_warmup_style: linear
72
+ min_decay_lr: 0
73
+ optimizer_factory:
74
+ adam_beta1: 0.9
75
+ adam_beta2: 0.95
76
+ adam_eps: 1.0e-08
77
+ name: adamW
78
+ torch_adam_is_fused: true
79
+ weight_decay: 0.01
80
+ zero_stage: 0
81
+ parallelism:
82
+ dp: 64
83
+ expert_parallel_size: 1
84
+ pp: 1
85
+ pp_engine: 1f1b
86
+ recompute_layer: false
87
+ tp: 1
88
+ tp_linear_async_communication: true
89
+ tp_mode: REDUCE_SCATTER
90
+ tp_recompute_allgather: true
91
+ profiler: null
92
+ tokenizer:
93
+ tokenizer_max_length: null
94
+ tokenizer_name_or_path: HuggingFaceTB/cosmo2-tokenizer
95
+ tokenizer_revision: null
96
+ tokens:
97
+ batch_accumulation_per_replica: 1
98
+ limit_test_batches: 0
99
+ limit_val_batches: 0
100
+ micro_batch_size: 16 #16
101
+ sequence_length: 1024 #2048
102
+ train_steps: 2000000
103
+ val_check_interval: 1000
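A few quantities implied by this config, matching the projection shapes shown in the LlamaAttention docstring (plain arithmetic, not part of the commit):

hidden_size, n_heads, n_kv_heads = 576, 9, 3
head_dim = hidden_size // n_heads      # 64
kv_out = head_dim * n_kv_heads         # 192 -> out_features of k_proj and v_proj
group_size = n_heads // n_kv_heads     # 3 query heads share each key/value head (the repeat_interleave factor)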
model_testing.py ADDED
@@ -0,0 +1,79 @@
1
+ import argparse
2
+ from SmolLm3 import LlamaModel
3
+ import yaml
4
+ import torch
5
+ from transformers import AutoTokenizer
6
+ from train import generate
7
+
8
+ def get_config(config_path):
9
+ config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader)
10
+ return config
11
+
12
+ def load_model_from_checkpoint(config_path, checkpoint_path, device):
13
+ config = get_config(config_path)
14
+ model = LlamaModel(config['model'])
15
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
16
+ state_dict = checkpoint['model_state_dict']
17
+ state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
18
+ model.load_state_dict(state_dict)
19
+ return model
20
+
21
+ def get_tokenizer(config):
22
+ tokenizer_path = config['tokenizer']['tokenizer_name_or_path']
23
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
24
+ tokenizer.pad_token = tokenizer.eos_token
25
+ vocab_size = tokenizer.vocab_size
26
+ return tokenizer, vocab_size
27
+
28
+ def generate_text(model, tokenizer, input_text, max_new_tokens, context_length, temperature, top_k, eos_token, device):
29
+ encoded_text = tokenizer.encode(input_text, return_tensors="pt").to(device)
30
+ generated_text = generate(model,
31
+ idx=encoded_text,
32
+ max_new_tokens=max_new_tokens,
33
+ context_length=context_length,
34
+ temperature=temperature,
35
+ top_k=top_k,
36
+ eos_token=eos_token,
37
+ device=device)
38
+ return tokenizer.decode(generated_text.squeeze(0))
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser(description='Generate text using the SmolLM model')
42
+ parser.add_argument('--config_path', type=str, default="config_smollm2_135M.yaml",
43
+ help='Path to the config file')
44
+ parser.add_argument('--checkpoint_path', type=str, required=True,
45
+ help='Path to the model checkpoint')
46
+ parser.add_argument('--input_text', type=str, default="Bernoulli principle",
47
+ help='Input text prompt for generation')
48
+ parser.add_argument('--max_new_tokens', type=int, default=256,
49
+ help='Maximum number of new tokens to generate')
50
+ parser.add_argument('--context_length', type=int, default=256,
51
+ help='Context length for generation')
52
+ parser.add_argument('--temperature', type=float, default=0.7,
53
+ help='Temperature for sampling')
54
+ parser.add_argument('--top_k', type=int, default=5,
55
+ help='Top-k value for sampling')
56
+ parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu",
57
+ help='Device to run the model on (cuda/cpu)')
58
+
59
+ args = parser.parse_args()
60
+
61
+ config = get_config(args.config_path)
62
+ model = load_model_from_checkpoint(args.config_path, args.checkpoint_path, args.device)
63
+ print(model)
64
+ tokenizer, vocab_size = get_tokenizer(config)
65
+ print(tokenizer)
66
+ print(vocab_size)
67
+
68
+ generated_text = generate_text(
69
+ model,
70
+ tokenizer,
71
+ args.input_text,
72
+ args.max_new_tokens,
73
+ args.context_length,
74
+ args.temperature,
75
+ args.top_k,
76
+ tokenizer.eos_token_id,
77
+ args.device
78
+ )
79
+ print(generated_text)
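A direct (non-CLI) usage sketch of the helpers above, assuming a checkpoint saved by train.py exists and the repository's utils module (imported by train.py) is on the path; the checkpoint path below is a hypothetical placeholder:

import torch
from model_testing import get_config, get_tokenizer, load_model_from_checkpoint, generate_text

config_path = "config_smollm2_135M.yaml"
checkpoint_path = "checkpoints/model_XXXX_steps.pth"   # hypothetical; use a real .pth saved by train.py
device = "cuda" if torch.cuda.is_available() else "cpu"

config = get_config(config_path)
model = load_model_from_checkpoint(config_path, checkpoint_path, device)
tokenizer, _ = get_tokenizer(config)
print(generate_text(model, tokenizer, "Bernoulli principle", max_new_tokens=64,
                    context_length=256, temperature=0.7, top_k=5,
                    eos_token=tokenizer.eos_token_id, device=device))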
model_weights_35000_step.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a965c902af30b6148a95d2d404b6848829a94bc4815fd53d2a84be51707e7df
3
+ size 538169702
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ torch
2
+ torchtext
3
+ pandas
4
+ numpy==1.26.1
5
+ matplotlib
6
+ tqdm
7
+ # urllib
8
+ requests
9
+ boto3
10
+ datasets
11
+ transformers
12
+ gradio
train.py ADDED
@@ -0,0 +1,351 @@
1
+ from SmolLm3 import LlamaModel
2
+ import torch
3
+ import yaml
4
+ from transformers import AutoTokenizer
5
+ from torch.utils.data import DataLoader
6
+ import numpy as np
7
+ from datasets import load_dataset
8
+ import logging
9
+ import math
10
+
11
+ from utils import upload_file_to_s3
12
+ # At the start of training loop
13
+ # print(f"GPU Memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
14
+ # print(f"GPU Memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
15
+
16
+
17
+ logger = logging.getLogger(__name__)
18
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
19
+ file_handler = logging.FileHandler('training.log')
20
+ file_handler.setFormatter(formatter) # Set formatter on the handler, not the logger
21
+ logger.addHandler(file_handler)
22
+ logger.setLevel(logging.INFO)
23
+
24
+ def encode_text(examples, tokenizer, seq_length):
25
+ """Tokenize and prepare text examples for training."""
26
+ tokens = tokenizer(
27
+ examples["text"],
28
+ truncation=True,
29
+ padding="max_length",
30
+ max_length=seq_length + 1,
31
+ return_tensors="pt",
32
+ )
33
+ # Use clone().detach() as recommended
34
+ input_ids = tokens["input_ids"].squeeze(0).clone().detach()
35
+ input_ids = torch.clamp(input_ids, min=0, max=tokenizer.vocab_size - 1)
36
+ labels = input_ids.clone().detach()
37
+ labels = labels[1:].to(torch.int64)
38
+ input_ids = input_ids[:-1].to(torch.int64)
39
+
40
+ return {"input_ids": input_ids, "labels": labels}
41
+
42
+ def load_cosmopedia_dataset(batch_size=8, seq_length=1024, tokenizer=None):
43
+ """
44
+ Returns a torch dataloader for the cosmopedia dataset
45
+ """
46
+ # Set tokenizer parallelism explicitly
47
+ import os
48
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
49
+ logger.info("tokenizer parallelism set to false")
50
+ try:
51
+ # Increase timeout and retries for dataset loading
52
+ from datasets import config
53
+ config.HF_DATASETS_TIMEOUT = 300 # 5 minutes timeout
54
+ config.MAX_RETRIES = 10 # Increase retry attempts
55
+ logger.info("dataset loading config set")
56
+ train_dataset = load_dataset(
57
+ "HuggingFaceTB/smollm-corpus",
58
+ name="cosmopedia-v2",
59
+ split="train",
60
+ streaming=True,
61
+ )
62
+ logger.info("dataset loaded")
63
+
64
+ # Use partial to bind tokenizer and seq_length to the encode function
65
+ from functools import partial
66
+ encode_fn = partial(encode_text, tokenizer=tokenizer, seq_length=seq_length)
67
+
68
+ train_dataset = train_dataset.map(
69
+ encode_fn,
70
+ remove_columns=["text"],
71
+ batched=False
72
+ )
73
+ train_dataset = train_dataset.with_format("torch")
74
+
75
+ train_dataloader = DataLoader(
76
+ train_dataset,
77
+ batch_size=batch_size,
78
+ num_workers=2,
79
+ pin_memory=True,
80
+ prefetch_factor=4,
81
+ persistent_workers=True
82
+ )
83
+ return train_dataloader
84
+ except Exception as e:
85
+ logger.error(f"Error loading dataset: {str(e)}")
86
+ return None
87
+
88
+
89
+ def generate(model, idx, max_new_tokens, context_length, temperature=1.0, top_k=None, eos_token=None, device=None):
90
+ logger.info(f"Generating on device {device}")
91
+ model = model.to(device)
92
+ idx = idx.to(device)
93
+ model.eval()
94
+ for _ in range(max_new_tokens):
95
+ idx_cond = idx[:, -context_length:]
96
+ with torch.no_grad():
97
+ logits, _ = model(idx_cond) # Unpack both logits and loss (ignore loss)
98
+ logits = logits.view(idx_cond.shape[0], -1, model.config['vocab_size']) # Reshape to [batch, seq, vocab]
99
+
100
+ # Get the logits for the last token only
101
+ logits = logits[:, -1, :] # Shape: [batch_size, vocab_size]
102
+
103
+ if top_k is not None:
104
+ # top k sampling
105
+ top_logits, top_pos = torch.topk(logits, top_k)
106
+ min_logit = top_logits[:, -1].unsqueeze(-1)
107
+ logits = torch.where(logits < min_logit,
108
+ torch.tensor(float('-inf')).to(logits.device),
109
+ logits)
110
+
111
+ # temperature scaling
112
+ if temperature > 0.0:
113
+ logits /= temperature
114
+ probs = torch.softmax(logits, dim=-1)
115
+ idx_next = torch.multinomial(probs, num_samples=1)
116
+ else:
117
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True)
118
+
119
+ if idx_next.item() == eos_token:
120
+ break
121
+
122
+ idx = torch.cat((idx, idx_next), dim=1)
123
+ model.train()
124
+ return idx
125
+
126
+ def sync_device(device):
127
+ if device.startswith('cuda'):
128
+ torch.cuda.synchronize()
129
+ elif device == 'cpu':
130
+ torch.cpu.synchronize() if hasattr(torch.cpu, 'synchronize') else None
131
+ elif device.startswith('mps'): # For Apple Silicon
132
+ torch.mps.synchronize()
133
+
134
+ def print_gpu_memory(step_name=""):
135
+ """
136
+ Print GPU memory statistics with a specified step name
137
+ """
138
+ if torch.cuda.is_available():
139
+ logger.info(f"\nGPU Memory Stats {step_name}:")
140
+ logger.info(f"GPU Memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
141
+ logger.info(f"GPU Memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
142
+ logger.info(f"Max GPU Memory allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
143
+
144
+ # Learning rate scheduler
145
+ def get_lr_lambda(current_step, warmup_steps, max_steps, max_lr):
146
+ """
147
+ Modified learning rate scheduler with:
148
+ 1. Linear warmup for first 3000 steps
149
+ 2. Cosine decay from 3000 to 60000 steps
150
+ 3. Minimum learning rate of 1.5e-5 (5% of max_lr)
151
+ """
152
+ min_lr = max_lr * 0.05 # Minimum learning rate (5% of max_lr)
153
+
154
+ if current_step < warmup_steps:
155
+ # Linear warmup from 0 to max_lr
156
+ return float(current_step) / float(max(1, warmup_steps))
157
+ else:
158
+ # Cosine decay from max_lr to min_lr
159
+ progress = float(current_step - warmup_steps) / float(max(1, max_steps - warmup_steps))
160
+ return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))  # NOTE: this branch returns an absolute LR value, not a multiplier; LambdaLR multiplies base_lr by it
161
+
162
+
163
+ def train_model(config, model, train_loader, test_loader, optimizer, device, num_epochs, eval_freq, eval_iter, start_context="Jack Gisburn rather a cheap genius- ", tokenizer=None):
164
+ total_loss = 0
165
+ tokens_seen, global_step = 0, -1
166
+
167
+ # Adjusted gradient accumulation setup
168
+ actual_batch_size = config['tokens']['micro_batch_size'] # Now 16
169
+ effective_batch_size_multiplier = 2 # Reduced from 4 to maintain reasonable memory usage
170
+ target_batch_size = effective_batch_size_multiplier * config['tokens']['micro_batch_size']
171
+ gradient_accumulation_steps = target_batch_size // actual_batch_size
172
+
173
+ # Adjusted learning rate parameters for new batch size
174
+ max_lr = 3e-4 # Keep the same max learning rate
175
+ warmup_steps = 3000 # Increase warmup steps for longer training
176
+ max_steps = 60000 # Set to match 10 hours of training
177
+ min_lr = max_lr * 0.05 # Reduce minimum LR to 5% of max (was 10%)
178
+
179
+ # Create LambdaLR scheduler with the improved lambda function
180
+ lr_lambda = lambda step: get_lr_lambda(step, warmup_steps, max_steps, max_lr)
181
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
182
+
183
+ logger.info(f"Training with learning rate schedule:")
184
+ logger.info(f"Max LR: {max_lr}")
185
+ logger.info(f"Warmup Steps: {warmup_steps}")
186
+ logger.info(f"Max Steps: {max_steps}")
187
+ logger.info(f"Min LR: {max_lr * 0.05}")
188
+ logger.info(f"Gradient Accumulation Steps: {gradient_accumulation_steps}")
189
+ logger.info(f"Effective Batch Size: {actual_batch_size * gradient_accumulation_steps}")
190
+
191
+ print_gpu_memory("at start of training")
192
+
193
+ # Add these near the start of training loop
194
+ torch.cuda.empty_cache()
195
+ torch.backends.cudnn.benchmark = True
196
+
197
+ for epoch in range(num_epochs):
198
+ model.train()
199
+ optimizer.zero_grad() # Zero gradients at start of epoch
200
+
201
+ for batch_idx, batch in enumerate(train_loader):
202
+ input_batch = batch['input_ids'].to(device)
203
+ target_batch = batch['labels'].to(device)
204
+
205
+ # Forward pass
206
+ with torch.autocast(device_type=device, dtype=torch.bfloat16):
207
+ logits, original_loss = model(input_batch, target_batch)
208
+
209
+ # Scale loss for gradient accumulation
210
+ scaled_loss = original_loss / gradient_accumulation_steps
211
+ scaled_loss.backward()
212
+
213
+ # Add the original loss to total_loss for logging
214
+ total_loss += original_loss.item() # Don't multiply back up
215
+ tokens_seen += input_batch.numel()
216
+
217
+ # Calculate running average loss
218
+ total_batches = batch_idx + 1
219
+ avg_loss = total_loss / total_batches
220
+ if batch_idx % 25 == 0:
221
+ logger.info(f"Batch {batch_idx + 1}, Running Avg Loss: {avg_loss:.5f}")
222
+ # Only update weights after accumulating gradients
223
+ if (batch_idx + 1) % gradient_accumulation_steps == 0:
224
+ # Gradient clipping
225
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
226
+
227
+ optimizer.step()
228
+ scheduler.step() # Update learning rate
229
+ optimizer.zero_grad()
230
+ global_step += 1
231
+
232
+ # Evaluation block
233
+ if global_step % eval_freq == 0 and global_step > 0:
234
+ # Use total batches processed instead of global_step
235
+ current_lr = scheduler.get_last_lr()[0]
236
+ optimizer_lr = optimizer.param_groups[0]['lr']
237
+
238
+ print_gpu_memory(f"at step {global_step}")
239
+ logger.info(f"learning rate: {current_lr:.8f}")
240
+ logger.info(f"Ep {epoch+1} (Step {global_step:06d}): "
241
+ f"Avg loss {avg_loss:.3f} | {tokens_seen} tokens seen")
242
+ logger.info(f"optimizer lr: {optimizer_lr:.8f}")
243
+ logger.info(f"scheduler lr: {current_lr:.8f}")
244
+
245
+ # Generate sample text
246
+ encoded_text = tokenizer.encode(start_context, return_tensors="pt")
247
+ random_topk = np.random.randint(1, 10)
248
+ logger.info(f"random_topk: {random_topk}")
249
+ random_temperature = np.random.uniform(0.7, 0.9)
250
+ logger.info(f"random_temperature: {random_temperature}")
251
+ logger.info(f"global step {global_step} , batch_idx {batch_idx} => generating text")
252
+ generated_text = generate(model,
253
+ idx=encoded_text,
254
+ max_new_tokens=256,
255
+ context_length=256,
256
+ temperature=random_temperature,
257
+ top_k=random_topk,
258
+ eos_token=tokenizer.eos_token_id,
259
+ device=device)
260
+ logger.info(f"+++"*30)
261
+ logger.info(tokenizer.decode(generated_text.squeeze(0)))
262
+ logger.info(f"+++"*30)
263
+
264
+ # Save checkpoint
265
+ model_file_name = f"model_{global_step}_steps_avg_loss_{avg_loss:.5f}_optimizer_lr_{optimizer_lr:.8f}.pth"
266
+ torch.save({
267
+ 'step': global_step,
268
+ 'model_state_dict': model.state_dict(),
269
+ 'optimizer_state_dict': optimizer.state_dict(),
270
+ 'scheduler_state_dict': scheduler.state_dict(),
271
+ 'loss': avg_loss,
272
+ }, model_file_name)
273
+
274
+ s3_path = upload_file_to_s3(model_file_name, config['model']['model_config']['s3_bucket'],
275
+ config['model']['model_config']['s3_checkpoint_folder'])
276
+ logger.info(f"Model saved to S3: {s3_path}")
277
+
278
+ log_path = upload_file_to_s3(config['model']['model_config']['s3_log_file_name'], config['model']['model_config']['s3_bucket'],
279
+ config['model']['model_config']['s3_log_folder'])
280
+ logger.info(f"Log saved to S3: {log_path}")
281
+
282
+ if batch_idx % 100 == 0:
283
+ logger.info(f"Batch {batch_idx} finished")
284
+ logger.info(f"+++"*30)
285
+
286
+ logger.info("Training complete")
287
+
288
+ if __name__ == "__main__":
289
+ config = yaml.load(open("config_smollm2_135M.yaml", "r"), Loader=yaml.FullLoader)
290
+ logger.info(config)
291
+
292
+ # Set memory efficient settings
293
+ torch.set_float32_matmul_precision('high')
294
+ torch.backends.cudnn.benchmark = True
295
+ torch.backends.cuda.matmul.allow_tf32 = True
296
+
297
+ # Empty cache before model creation
298
+ torch.cuda.empty_cache()
299
+
300
+ model = LlamaModel(config['model'])
301
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
302
+
303
+ # Enable gradient checkpointing for memory efficiency
304
+ # model.gradient_checkpointing_enable()
305
+
306
+ model.to(device)
307
+ model = torch.compile(model)
308
+ logger.info(model)
309
+ logger.info("++"*30)
310
+
311
+ optimizer = torch.optim.AdamW(
312
+ model.parameters(),
313
+ lr=3e-4,
314
+ weight_decay=0.15,
315
+ betas=(0.9, 0.95)
316
+ )
317
+
318
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
319
+ tokenizer.pad_token = tokenizer.eos_token
320
+ vocab_size = tokenizer.vocab_size
321
+
322
+ # Adjusted batch size and sequence length
323
+ train_loader = load_cosmopedia_dataset(
324
+ batch_size=16, # Set to 16
325
+ seq_length=1024, # Kept at 1024
326
+ tokenizer=tokenizer
327
+ )
328
+
329
+ import time
330
+ t1 = time.time()
331
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
332
+
333
+ # Set environment variable for memory allocation
334
+ import os
335
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
336
+
337
+ train_model(
338
+ config,
339
+ model,
340
+ train_loader,
341
+ train_loader,
342
+ optimizer=optimizer,
343
+ device=device,
344
+ num_epochs=1,
345
+ eval_freq=1000, # evaluate and checkpoint every 1000 steps
346
+ eval_iter=1000,
347
+ start_context="Once Upon a Time far far away in a galaxy",
348
+ tokenizer=tokenizer
349
+ )
350
+ t2 = time.time()
351
+ logger.info(f"Time taken for training: {t2 - t1:.2f} seconds")