Commit 416db21 · verified · 1 Parent(s): ac8a885
Committed by padmanabhbosamia

Upload 5 files

Files changed (5):
  1. config.py +195 -0
  2. config.yaml +81 -0
  3. model.py +522 -0
  4. smol-lm2-final.ckpt +3 -0
  5. train_script.py +264 -0
config.py ADDED
@@ -0,0 +1,195 @@
import os
import yaml
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ModelConfig:
    type: str = "custom"
    name: str = "smollm2_transformer"
    tokenizer_name: str = "HuggingFaceTB/SmolLM2-135M"
    vocab_size: int = 49152
    hidden_size: int = 576
    num_attention_heads: int = 9
    num_key_value_heads: int = 3
    num_hidden_layers: int = 30
    intermediate_size: int = 1536
    hidden_act: str = "gelu"
    max_position_embeddings: int = 512
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-5
    use_cache: bool = True
    pad_token_id: Optional[int] = None
    max_length: int = 512

    def __post_init__(self):
        # Ensure numeric values are proper types
        self.vocab_size = int(self.vocab_size)
        self.hidden_size = int(self.hidden_size)
        self.num_attention_heads = int(self.num_attention_heads)
        self.num_key_value_heads = int(self.num_key_value_heads)
        self.num_hidden_layers = int(self.num_hidden_layers)
        self.intermediate_size = int(self.intermediate_size)
        self.max_position_embeddings = int(self.max_position_embeddings)
        self.initializer_range = float(self.initializer_range)
        self.rms_norm_eps = float(self.rms_norm_eps)
        self.max_length = int(self.max_length)

@dataclass
class OptimizerConfig:
    type: str = "adamW"
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1e-8
    torch_adam_is_fused: bool = True
    clip_grad: float = 1.0
    accumulate_grad_in_fp32: bool = True

    def __post_init__(self):
        # Ensure numeric values are proper floats
        self.weight_decay = float(self.weight_decay)
        self.adam_beta1 = float(self.adam_beta1)
        self.adam_beta2 = float(self.adam_beta2)
        self.adam_eps = float(self.adam_eps)
        self.clip_grad = float(self.clip_grad)

@dataclass
class SchedulerConfig:
    type: str = "one_cycle"
    learning_rate: float = 0.003
    warmup_steps: int = 100
    max_lr: float = 0.003
    pct_start: float = 0.02
    anneal_strategy: str = "cos"
    cycle_momentum: bool = False
    div_factor: float = 25.0
    final_div_factor: float = 1000.0

@dataclass
class TrainingConfig:
    output_dir: str = "./results"
    batch_size: int = 2
    micro_batch_size: int = 1
    gradient_accumulation_steps: int = 4
    sequence_length: int = 512
    learning_rate: float = 0.003
    max_steps: int = 5050
    first_phase_steps: int = 5000
    second_phase_steps: int = 50
    sample_frequency: int = 500
    second_phase_sample_frequency: int = 10
    logging_dir: str = "./logs"
    logging_steps: int = 1
    save_steps: int = 500
    checkpoint_dir: str = "checkpoints"
    sample_prompt: str = "Explain what machine learning is:"
    max_generate_length: int = 100

@dataclass
class HardwareConfig:
    precision: str = "16-mixed"
    accelerator: str = "gpu"
    devices: int = 1
    strategy: str = "auto"
    gradient_clip: float = 1.0

@dataclass
class DatasetConfig:
    name: str
    path: str
    subset: str
    weight: float
    split_ratio: float = 1.0  # Default to using full dataset

@dataclass
class DataLoadingConfig:
    num_workers: int = 2
    batch_size: int = 32
    pin_memory: bool = True
    prefetch_factor: int = 2
    persistent_workers: bool = True

@dataclass
class DataConfig:
    datasets: List[DatasetConfig] = field(default_factory=list)
    loading: DataLoadingConfig = field(default_factory=DataLoadingConfig)

class SmolLM2Config:
    def __init__(self, config_path: str = None):
        self.model = ModelConfig()
        self.optimizer = OptimizerConfig()
        self.scheduler = SchedulerConfig()
        self.training = TrainingConfig()
        self.hardware = HardwareConfig()
        self.data = DataConfig()

        # Default dataset configuration
        self.data.datasets = [
            DatasetConfig(
                name="wikitext",
                path="wikitext",
                subset="wikitext-2-raw-v1",
                weight=1.0
            )
        ]

        if config_path and os.path.exists(config_path):
            self.load_from_yaml(config_path)

    def load_from_yaml(self, config_path: str):
        with open(config_path, 'r') as f:
            config_dict = yaml.safe_load(f)

        # Update configurations from yaml
        if 'model' in config_dict:
            for k, v in config_dict['model'].items():
                setattr(self.model, k, v)

        if 'optimizer' in config_dict:
            for k, v in config_dict['optimizer'].items():
                setattr(self.optimizer, k, v)

        if 'scheduler' in config_dict:
            for k, v in config_dict['scheduler'].items():
                setattr(self.scheduler, k, v)

        if 'training' in config_dict:
            for k, v in config_dict['training'].items():
                setattr(self.training, k, v)

        if 'hardware' in config_dict:
            for k, v in config_dict['hardware'].items():
                setattr(self.hardware, k, v)

        if 'data' in config_dict:
            for k, v in config_dict['data'].items():
                if k == 'datasets':
                    # Datasets listed in the YAML replace the defaults,
                    # since downstream code only reads datasets[0]
                    self.data.datasets = [DatasetConfig(**dataset) for dataset in v]
                elif k == 'loading':
                    for lk, lv in config_dict['data']['loading'].items():
                        setattr(self.data.loading, lk, lv)

    def save_to_yaml(self, config_path: str):
        config_dict = {
            'model': self.model.__dict__,
            'optimizer': self.optimizer.__dict__,
            'scheduler': self.scheduler.__dict__,
            'training': self.training.__dict__,
            'hardware': self.hardware.__dict__,
            # Serialize the nested data config as plain dicts so yaml.dump
            # does not choke on the dataclass instances
            'data': {
                'datasets': [d.__dict__ for d in self.data.datasets],
                'loading': self.data.loading.__dict__,
            },
        }

        with open(config_path, 'w') as f:
            yaml.dump(config_dict, f, default_flow_style=False)

    def __repr__(self):
        return f"SmolLM2Config(\n" \
               f" model={self.model}\n" \
               f" optimizer={self.optimizer}\n" \
               f" scheduler={self.scheduler}\n" \
               f" training={self.training}\n" \
               f" hardware={self.hardware}\n" \
               f" data={self.data}\n" \
               f")"
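
A minimal usage sketch for the config loader above, assuming config.py and config.yaml sit in the working directory; the derived names (head_dim, effective_batch) and the output filename are illustrative, not part of the upload:

    from config import SmolLM2Config

    # Load the dataclass defaults, then apply YAML overrides if the file exists
    cfg = SmolLM2Config("config.yaml")

    # Quantities the rest of the repo relies on
    head_dim = cfg.model.hidden_size // cfg.model.num_attention_heads
    effective_batch = cfg.training.batch_size * cfg.training.gradient_accumulation_steps
    print(cfg.model.tokenizer_name, head_dim, effective_batch)

    # Round-trip the merged configuration back to disk
    cfg.save_to_yaml("config_resolved.yaml")
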
config.yaml ADDED
@@ -0,0 +1,81 @@
model:
  type: "custom"
  name: "smollm2_transformer"
  tokenizer_name: "gpt2"
  vocab_size: 50257
  hidden_size: 256
  num_attention_heads: 4
  num_key_value_heads: 2
  num_hidden_layers: 6
  intermediate_size: 512
  hidden_act: "gelu"
  max_position_embeddings: 256
  initializer_range: 0.02
  rms_norm_eps: 1.0e-5
  use_cache: true
  pad_token_id: null

optimizer:
  type: "adamW"
  weight_decay: 0.01
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-8
  torch_adam_is_fused: true
  clip_grad: 1.0
  accumulate_grad_in_fp32: true

scheduler:
  type: "one_cycle"
  learning_rate: 0.001
  warmup_steps: 50
  max_lr: 0.001
  pct_start: 0.02
  anneal_strategy: "cos"
  cycle_momentum: false
  div_factor: 25.0
  final_div_factor: 1000.0

training:
  output_dir: "./results"
  batch_size: 4
  micro_batch_size: 2
  gradient_accumulation_steps: 2
  sequence_length: 256
  learning_rate: 0.001
  max_steps: 5050                    # Total steps (5000 + 50)
  first_phase_steps: 5000            # Initial training phase
  second_phase_steps: 50             # Fine-tuning phase
  sample_frequency: 100              # Sample every 100 steps in first phase
  second_phase_sample_frequency: 5   # Sample more frequently in second phase
  logging_dir: "./logs"
  logging_steps: 1
  save_steps: 100
  checkpoint_dir: "checkpoints"
  sample_prompt: "Explain what machine learning is:"
  max_generate_length: 50

hardware:
  precision: "16-mixed"
  accelerator: "gpu"
  devices: 1
  strategy: "auto"
  gradient_clip: 1.0
  cuda_memory_fraction: 0.9
  allow_tf32: true
  benchmark: true
  deterministic: false

data:
  datasets:
    - name: "wikitext"
      path: "wikitext"
      subset: "wikitext-103-raw-v1"
      split_ratio: 0.01              # Use only 1% of the dataset
      weight: 1.0
  loading:
    num_workers: 2
    batch_size: 16
    pin_memory: true
    prefetch_factor: 2
    persistent_workers: true
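
A quick sanity check of the values above, assuming the file is loaded through SmolLM2Config from config.py; the assertions only restate arithmetic implied by the model and training sections (hidden_size must split evenly across heads, and key/value heads must divide the query heads for the grouped attention in model.py):

    from config import SmolLM2Config

    cfg = SmolLM2Config("config.yaml")
    m, t = cfg.model, cfg.training

    # 256 hidden dims over 4 heads -> 64-dim heads; 4 query heads share 2 kv heads
    assert m.hidden_size % m.num_attention_heads == 0           # head_dim = 64
    assert m.num_attention_heads % m.num_key_value_heads == 0   # GQA group size = 2

    # Each optimizer step sees batch_size * gradient_accumulation_steps sequences: 4 * 2 = 8
    print("effective batch:", t.batch_size * t.gradient_accumulation_steps)
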
model.py ADDED
@@ -0,0 +1,522 @@
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import AutoTokenizer
import torch.nn as nn
import math
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import os

def _init_weights(module, std=0.02):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = float(eps)  # Ensure eps is a float
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * norm * self.weight

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = int(max_position_embeddings)  # Convert to int
        self.base = base

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        t = torch.arange(self.max_position_embeddings).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :])

    def forward(self, x, seq_len=None):
        # Convert seq_len to int and ensure it's a valid value
        seq_len = int(seq_len) if seq_len is not None else x.size(1)
        if seq_len > self.max_position_embeddings:
            seq_len = self.max_position_embeddings

        return (
            self.cos_cached[:, :, :seq_len, :],
            self.sin_cached[:, :, :seq_len, :]
        )

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # Ensure proper broadcasting
    cos = cos[:, :, :q.size(2), :]  # [batch, 1, seq_len, dim]
    sin = sin[:, :, :q.size(2), :]  # [batch, 1, seq_len, dim]

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_attention_heads

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(self, hidden_states, cos, sin, attention_mask=None):
        batch_size, seq_length, _ = hidden_states.shape

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape for attention computation
        q = q.view(batch_size, seq_length, self.num_attention_heads, self.head_dim)
        k = k.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim)
        v = v.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim)

        # Transpose for attention computation
        q = q.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
        k = k.transpose(1, 2)  # [batch, num_kv_heads, seq_len, head_dim]
        v = v.transpose(1, 2)  # [batch, num_kv_heads, seq_len, head_dim]

        # Apply rotary embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Repeat k/v heads if num_key_value_heads < num_attention_heads
        if self.num_key_value_heads != self.num_attention_heads:
            k = k.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1)
            v = v.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1)

        # Compute attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)

        # Compute output
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous()  # [batch, seq_len, num_heads, head_dim]
        output = output.view(batch_size, seq_length, -1)

        return self.o_proj(output)

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

class DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = Attention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, cos, sin, attention_mask=None):
        # Self attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, cos, sin, attention_mask)
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

class SmolLM2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # Initialize transformer layers
        self.layers = nn.ModuleList([
            DecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])

        # Final layer norm
        self.norm = RMSNorm(config.hidden_size, eps=float(config.rms_norm_eps))

        # Initialize rotary embeddings
        self.rotary_emb = RotaryEmbedding(
            config.hidden_size // config.num_attention_heads,
            max_position_embeddings=config.max_position_embeddings
        )

        # Initialize weights
        self.apply(lambda p: _init_weights(p, std=config.initializer_range))

    def forward(self, input_ids, attention_mask=None):
        try:
            # Ensure inputs are on the correct device
            device = input_ids.device
            batch_size, seq_length = input_ids.shape

            # Input validation
            if seq_length > self.config.max_position_embeddings:
                raise ValueError(f"Input sequence length {seq_length} exceeds maximum position embeddings {self.config.max_position_embeddings}")

            # Get embeddings
            hidden_states = self.embed_tokens(input_ids)

            # Get position embeddings
            cos, sin = self.rotary_emb(hidden_states, seq_length)

            # Generate attention mask if none provided
            if attention_mask is None:
                attention_mask = torch.ones(
                    (batch_size, seq_length),
                    dtype=torch.bool,
                    device=device
                )
            else:
                # Convert to boolean if it's not already and ensure contiguous memory
                attention_mask = attention_mask.bool().contiguous()

            # Create causal mask
            causal_mask = torch.triu(
                torch.ones((seq_length, seq_length), device=device),
                diagonal=1
            ).bool()

            # Create attention mask [batch_size, 1, seq_length, seq_length]
            attention_mask = attention_mask.view(batch_size, 1, 1, seq_length)
            attention_mask = attention_mask.expand(batch_size, 1, seq_length, seq_length)

            # Prepare causal mask
            causal_mask = causal_mask.view(1, 1, seq_length, seq_length)

            # Combine masks
            mask = attention_mask & ~causal_mask

            # Convert boolean mask to float mask
            mask = mask.to(dtype=hidden_states.dtype)
            mask = (1.0 - mask) * torch.finfo(hidden_states.dtype).min

            # Apply transformer layers
            for layer in self.layers:
                hidden_states = layer(hidden_states, cos, sin, mask)

            # Apply final normalization
            hidden_states = self.norm(hidden_states)

            # Project back to vocabulary
            logits = F.linear(hidden_states, self.embed_tokens.weight)

            return logits

        except Exception as e:
            print(f"\nForward pass error:")
            print(f"Input shape: {input_ids.shape}")
            print(f"Device: {input_ids.device}")
            print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
            print(f"Error: {str(e)}")
            raise

    def generate(
        self,
        input_ids,
        attention_mask=None,
        max_length=100,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        num_return_sequences=1,
        do_sample=True,
        pad_token_id=None,
        bos_token_id=None,
        eos_token_id=None
    ):
        try:
            batch_size = input_ids.shape[0]
            current_length = input_ids.shape[1]
            device = input_ids.device

            # Input validation
            if current_length >= self.config.max_position_embeddings:
                raise ValueError(f"Input sequence length {current_length} exceeds maximum position embeddings {self.config.max_position_embeddings}")

            # Ensure we don't exceed maximum position embeddings
            max_length = min(max_length, self.config.max_position_embeddings)

            # Initialize attention mask if None
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids, dtype=torch.bool, device=device)

            for _ in range(max_length - current_length):
                # Forward pass
                outputs = self(input_ids, attention_mask)
                next_token_logits = outputs[:, -1, :] / temperature

                # Apply top-k filtering
                if top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Apply top-p filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                # Sample from the filtered distribution
                if do_sample:
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(next_token_logits, dim=-1)

                # Append new tokens
                input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
                attention_mask = torch.cat([attention_mask, torch.ones_like(next_tokens.unsqueeze(-1))], dim=-1)

                # Stop if we've hit special tokens
                if (pad_token_id is not None and (next_tokens == pad_token_id).all()) or \
                   (eos_token_id is not None and (next_tokens == eos_token_id).all()):
                    break

            return input_ids

        except Exception as e:
            print(f"\nGeneration error:")
            print(f"Input shape: {input_ids.shape}")
            print(f"Device: {input_ids.device}")
            print(f"Error: {str(e)}")
            raise

class TextDataset(Dataset):
    def __init__(self, config, split="train"):
        self.config = config

        # Load dataset from HuggingFace
        full_dataset = load_dataset(
            config.data.datasets[0].path,
            config.data.datasets[0].subset,
            split=split
        )

        # Apply split ratio if less than 1
        if config.data.datasets[0].split_ratio < 1.0:
            num_samples = int(len(full_dataset) * config.data.datasets[0].split_ratio)
            self.dataset = full_dataset.select(range(num_samples))
        else:
            self.dataset = full_dataset

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Get text from dataset
        text = self.dataset[idx]["text"]

        # Tokenize
        encodings = self.tokenizer(
            text,
            truncation=True,
            max_length=self.config.model.max_position_embeddings,
            padding="max_length",
            return_tensors="pt"
        )

        return {
            "input_ids": encodings.input_ids.squeeze(),
            "attention_mask": encodings.attention_mask.squeeze(),
            "labels": encodings.input_ids.squeeze()
        }

class SmolLM2Lightning(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.config = config

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Initialize the base model
        self.model = SmolLM2(config.model)

    def forward(self, input_ids, attention_mask=None):
        return self.model(input_ids, attention_mask)

    def training_step(self, batch, batch_idx):
        try:
            input_ids = batch["input_ids"]
            labels = batch["labels"]
            attention_mask = batch.get("attention_mask", None)

            # Ensure tensors are contiguous and on the correct device
            inputs = input_ids[..., :-1].contiguous()
            labels = input_ids[..., 1:].contiguous()

            if attention_mask is not None:
                attention_mask = attention_mask[..., :-1].contiguous()

            # Forward pass
            logits = self(inputs, attention_mask)

            # Calculate loss
            loss = F.cross_entropy(
                logits.view(-1, self.config.model.vocab_size),
                labels.view(-1),
                ignore_index=self.config.model.pad_token_id if self.config.model.pad_token_id is not None else -100,
                reduction='mean'
            )

            # Detach loss for logging
            loss_value = loss.detach().float()

            # Log metrics
            self.log('train_loss', loss_value, prog_bar=True, on_step=True, sync_dist=True)

            return loss

        except Exception as e:
            print(f"\nTraining step error:")
            print(f"Input shape: {input_ids.shape if input_ids is not None else 'None'}")
            print(f"Device: {input_ids.device if input_ids is not None else 'None'}")
            print(f"Error: {str(e)}")
            raise

    def validation_step(self, batch, batch_idx):
        try:
            input_ids = batch["input_ids"]
            labels = batch["labels"]
            attention_mask = batch.get("attention_mask", None)

            # Ensure tensors are contiguous and on the correct device
            inputs = input_ids[..., :-1].contiguous()
            labels = input_ids[..., 1:].contiguous()

            if attention_mask is not None:
                attention_mask = attention_mask[..., :-1].contiguous()

            # Forward pass
            logits = self(inputs, attention_mask)

            # Calculate loss
            loss = F.cross_entropy(
                logits.view(-1, self.config.model.vocab_size),
                labels.view(-1),
                ignore_index=self.config.model.pad_token_id if self.config.model.pad_token_id is not None else -100,
                reduction='mean'
            )

            # Detach loss for logging
            loss_value = loss.detach().float()

            # Log metrics
            self.log('val_loss', loss_value, prog_bar=True, on_epoch=True, sync_dist=True)

            return loss

        except Exception as e:
            print(f"\nValidation step error:")
            print(f"Input shape: {input_ids.shape if input_ids is not None else 'None'}")
            print(f"Device: {input_ids.device if input_ids is not None else 'None'}")
            print(f"Error: {str(e)}")
            raise

    def configure_optimizers(self):
        # Create optimizer with explicit type conversion
        optimizer = AdamW(
            self.parameters(),
            lr=float(self.config.scheduler.learning_rate),
            weight_decay=float(self.config.optimizer.weight_decay),
            betas=(float(self.config.optimizer.adam_beta1),
                   float(self.config.optimizer.adam_beta2)),
            eps=float(self.config.optimizer.adam_eps),
        )

        # Create scheduler
        scheduler = OneCycleLR(
            optimizer,
            max_lr=float(self.config.scheduler.max_lr),
            total_steps=int(self.config.training.max_steps),
            pct_start=float(self.config.scheduler.pct_start),
            anneal_strategy=self.config.scheduler.anneal_strategy,
            cycle_momentum=bool(self.config.scheduler.cycle_momentum),
            div_factor=float(self.config.scheduler.div_factor),
            final_div_factor=float(self.config.scheduler.final_div_factor),
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

    def generate(self, *args, **kwargs):
        return self.model.generate(*args, **kwargs)

    def train_dataloader(self):
        dataset = TextDataset(self.config, split="train")
        return DataLoader(
            dataset,
            batch_size=self.config.training.batch_size,
            shuffle=True,
            num_workers=self.config.data.loading.num_workers,
            pin_memory=self.config.data.loading.pin_memory,
            persistent_workers=True,
            prefetch_factor=self.config.data.loading.prefetch_factor,
            drop_last=True  # Drop incomplete batches
        )

    def val_dataloader(self):
        dataset = TextDataset(self.config, split="validation")
        return DataLoader(
            dataset,
            batch_size=self.config.training.batch_size,
            shuffle=False,
            num_workers=self.config.data.loading.num_workers,
            pin_memory=self.config.data.loading.pin_memory,
            persistent_workers=True,
            prefetch_factor=self.config.data.loading.prefetch_factor
        )
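
A minimal smoke-test sketch for the modules above, using a deliberately tiny ModelConfig so the forward pass runs on CPU in a few seconds; it only checks that the logits come back as [batch, seq_len, vocab_size] from the weight-tied output projection:

    import torch
    from config import ModelConfig
    from model import SmolLM2

    # Tiny, illustrative dimensions; only the field names match the real config
    tiny = ModelConfig(vocab_size=128, hidden_size=32, num_attention_heads=4,
                       num_key_value_heads=2, num_hidden_layers=2,
                       intermediate_size=64, max_position_embeddings=64)

    model = SmolLM2(tiny).eval()
    input_ids = torch.randint(0, tiny.vocab_size, (2, 16))   # [batch=2, seq_len=16]

    with torch.no_grad():
        logits = model(input_ids)                            # causal mask built internally

    assert logits.shape == (2, 16, tiny.vocab_size)
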
smol-lm2-final.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7992a6cb4ad6ca593be88b64f9e4359f771afaeabf6da719bd6aab480461fb08
size 197102570
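
The .ckpt entry above is a Git LFS pointer, so the ~197 MB of weights live on the Hub rather than in the diff. One way to pull and restore them, sketched under the assumption that "user/smol-lm2" is a placeholder repository id and that config.yaml matches the checkpoint's architecture:

    from huggingface_hub import hf_hub_download
    from config import SmolLM2Config
    from model import SmolLM2Lightning

    # Placeholder repo id; substitute the actual repository this commit belongs to
    ckpt_path = hf_hub_download(repo_id="user/smol-lm2", filename="smol-lm2-final.ckpt")

    # Mirrors how train_script.py reloads a phase-1 checkpoint
    config = SmolLM2Config("config.yaml")
    model = SmolLM2Lightning.load_from_checkpoint(ckpt_path, config=config)
    model.eval()
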
train_script.py ADDED
@@ -0,0 +1,264 @@
import os
import torch
import wandb
import shutil
from config import SmolLM2Config
from model import SmolLM2Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.loggers import WandbLogger

# Set CUDA environment variables before any other CUDA operations
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

def setup_training():
    """Setup training environment"""
    try:
        if torch.cuda.is_available():
            # Configure CUDA settings
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True
            torch.set_float32_matmul_precision('high')

            # Set default device
            device = torch.device('cuda:0')
            torch.cuda.set_device(device)

            # Print GPU info
            print(f"Using GPU: {torch.cuda.get_device_name()}")
            print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
            return device
    except Exception as e:
        print(f"CUDA setup error: {str(e)}")

    print("Using CPU")
    return torch.device('cpu')

def cleanup_training():
    """Cleanup training resources"""
    try:
        # Move model to CPU before cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Clean up wandb
        try:
            wandb.finish()
        except Exception:
            pass

    except Exception as e:
        print(f"Cleanup error: {str(e)}")

# Setup CUDA at module level
device = setup_training()

class GenerationMonitorCallback(Callback):
    def __init__(self, prompt="Explain what machine learning is:", sample_every_n_steps=500):
        super().__init__()
        self.prompt = prompt
        self.sample_every_n_steps = sample_every_n_steps

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        try:
            if (trainer.global_step + 1) % self.sample_every_n_steps == 0:
                # Switch to eval mode
                pl_module.eval()

                with torch.no_grad():
                    # Tokenize prompt
                    inputs = pl_module.tokenizer(
                        self.prompt,
                        return_tensors="pt",
                        truncation=True,
                        max_length=pl_module.config.model.max_position_embeddings,
                        padding=True
                    ).to(pl_module.device)

                    try:
                        # Generate text with error handling
                        outputs = pl_module.generate(
                            input_ids=inputs.input_ids,
                            attention_mask=inputs.attention_mask,
                            max_length=100,
                            temperature=0.7,
                            top_p=0.9,
                            top_k=50,
                            do_sample=True,
                            pad_token_id=pl_module.tokenizer.pad_token_id,
                            bos_token_id=pl_module.tokenizer.bos_token_id,
                            eos_token_id=pl_module.tokenizer.eos_token_id
                        )

                        # Decode generated text
                        generated_text = pl_module.tokenizer.decode(outputs[0], skip_special_tokens=True)

                        # Print results
                        print(f"\n=== Generation at step {trainer.global_step + 1} ===")
                        print(f"Prompt: {self.prompt}")
                        print(f"Generated: {generated_text}\n")

                    except RuntimeError as e:
                        print(f"\nError during generation at step {trainer.global_step + 1}: {str(e)}")
                        print(f"Input shape: {inputs.input_ids.shape}")
                        print(f"Input device: {inputs.input_ids.device}")

                # Switch back to train mode
                pl_module.train()

        except Exception as e:
            print(f"\nCallback error at step {trainer.global_step + 1}: {str(e)}")

def init_wandb(project_name, run_name):
    """Initialize WandB with error handling and cleanup"""
    try:
        # Try to clean up any existing wandb directory
        wandb_dir = os.path.join(os.getcwd(), "wandb")
        if os.path.exists(wandb_dir):
            try:
                shutil.rmtree(wandb_dir)
                print("Cleaned up existing wandb directory")
            except Exception as e:
                print(f"Warning: Could not clean up wandb directory: {str(e)}")

        # Create fresh wandb directory with proper permissions
        os.makedirs(wandb_dir, exist_ok=True)

        # Initialize WandB logger
        logger = WandbLogger(
            project=project_name,
            name=run_name,
            save_dir=os.getcwd(),
            settings=wandb.Settings(start_method="thread")
        )
        return logger

    except Exception as e:
        print(f"Error initializing WandB: {str(e)}")
        print("Continuing without WandB logging...")
        return None

def main():
    device = setup_training()

    try:
        # Load configuration
        config = SmolLM2Config("config.yaml")

        # Initialize model
        model = SmolLM2Lightning(config)

        # Phase 1: Initial Training
        print("\n=== Starting Phase 1 Training ===")

        # Initialize wandb logger for phase 1 with error handling
        wandb_logger = init_wandb("smol-lm2", "training_run_phase1")

        # Setup checkpoint callback for phase 1
        checkpoint_callback = ModelCheckpoint(
            dirpath=config.training.checkpoint_dir,
            filename="smol-lm2-phase1-{epoch:02d}-{train_loss:.2f}",
            save_top_k=3,
            monitor="train_loss",
            mode="min",
            every_n_train_steps=config.training.save_steps
        )

        # Setup generation monitoring callback for phase 1
        generation_callback = GenerationMonitorCallback(
            prompt=config.training.sample_prompt,
            sample_every_n_steps=config.training.sample_frequency
        )

        # Initialize trainer for phase 1
        trainer_phase1 = pl.Trainer(
            max_steps=config.training.first_phase_steps,
            accelerator=config.hardware.accelerator,
            devices=config.hardware.devices,
            precision=config.hardware.precision,
            logger=wandb_logger,
            callbacks=[checkpoint_callback, generation_callback],
            gradient_clip_val=config.hardware.gradient_clip,
            accumulate_grad_batches=config.training.gradient_accumulation_steps,
            log_every_n_steps=config.training.logging_steps,
            deterministic=False,
            benchmark=True,
            strategy='auto',  # Let PyTorch Lightning handle device strategy
        )

        # Train phase 1 with error handling
        try:
            trainer_phase1.fit(model)
        except Exception as e:
            print(f"Error during phase 1 training: {str(e)}")
            raise

        # Save phase 1 checkpoint
        phase1_checkpoint_path = os.path.join(config.training.checkpoint_dir, "smol-lm2-phase1-final.ckpt")
        trainer_phase1.save_checkpoint(phase1_checkpoint_path)
        print(f"Phase 1 completed. Model saved to {phase1_checkpoint_path}")

        # Clear GPU memory between phases
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Phase 2: Fine-tuning
        print("\n=== Starting Phase 2 Training ===")

        # Load the model from phase 1 checkpoint with error handling
        try:
            model = SmolLM2Lightning.load_from_checkpoint(phase1_checkpoint_path, config=config)
        except Exception as e:
            print(f"Error loading checkpoint for phase 2: {str(e)}")
            raise

        # Initialize wandb logger for phase 2 with error handling
        wandb_logger = init_wandb("smol-lm2", "training_run_phase2")

        # Setup generation monitoring callback with higher frequency for phase 2
        generation_callback = GenerationMonitorCallback(
            prompt=config.training.sample_prompt,
            sample_every_n_steps=config.training.second_phase_sample_frequency
        )

        # Initialize trainer for phase 2
        trainer_phase2 = pl.Trainer(
            max_steps=config.training.second_phase_steps,
            accelerator=config.hardware.accelerator,
            devices=config.hardware.devices,
            precision=config.hardware.precision,
            logger=wandb_logger,
            callbacks=[generation_callback],
            gradient_clip_val=config.hardware.gradient_clip,
            accumulate_grad_batches=config.training.gradient_accumulation_steps,
            log_every_n_steps=config.training.logging_steps,
            deterministic=False,
            benchmark=True,
        )

        # Train phase 2 with error handling
        try:
            trainer_phase2.fit(model)
        except Exception as e:
            print(f"Error during phase 2 training: {str(e)}")
            raise

        # Save final model
        final_checkpoint_path = os.path.join(config.training.checkpoint_dir, "smol-lm2-final.ckpt")
        trainer_phase2.save_checkpoint(final_checkpoint_path)
        print(f"Phase 2 completed. Final model saved to {final_checkpoint_path}")

    except Exception as e:
        print(f"\nTraining failed with error: {str(e)}")
        if torch.cuda.is_available():
            print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
            print(f"CUDA memory cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
        raise
    finally:
        cleanup_training()

if __name__ == "__main__":
    main()
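
After the two training phases finish, offline sampling can mirror what GenerationMonitorCallback does during training. A hedged sketch, assuming the run has produced checkpoints/smol-lm2-final.ckpt (the path main() saves to) and that config.yaml is unchanged:

    import torch
    from config import SmolLM2Config
    from model import SmolLM2Lightning

    config = SmolLM2Config("config.yaml")
    model = SmolLM2Lightning.load_from_checkpoint("checkpoints/smol-lm2-final.ckpt", config=config)
    model.eval()

    # Same prompt and sampling parameters the training-time callback uses
    inputs = model.tokenizer("Explain what machine learning is:", return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=100,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
            do_sample=True,
            eos_token_id=model.tokenizer.eos_token_id,
        )
    print(model.tokenizer.decode(output_ids[0], skip_special_tokens=True))
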