padmanabhbosamia
committed on
Upload 5 files
- config.py +195 -0
- config.yaml +81 -0
- model.py +522 -0
- smol-lm2-final.ckpt +3 -0
- train_script.py +264 -0
config.py
ADDED
@@ -0,0 +1,195 @@
import os
import yaml
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ModelConfig:
    type: str = "custom"
    name: str = "smollm2_transformer"
    tokenizer_name: str = "HuggingFaceTB/SmolLM2-135M"
    vocab_size: int = 49152
    hidden_size: int = 576
    num_attention_heads: int = 9
    num_key_value_heads: int = 3
    num_hidden_layers: int = 30
    intermediate_size: int = 1536
    hidden_act: str = "gelu"
    max_position_embeddings: int = 512
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-5
    use_cache: bool = True
    pad_token_id: Optional[int] = None
    max_length: int = 512

    def __post_init__(self):
        # Ensure numeric values are proper types
        self.vocab_size = int(self.vocab_size)
        self.hidden_size = int(self.hidden_size)
        self.num_attention_heads = int(self.num_attention_heads)
        self.num_key_value_heads = int(self.num_key_value_heads)
        self.num_hidden_layers = int(self.num_hidden_layers)
        self.intermediate_size = int(self.intermediate_size)
        self.max_position_embeddings = int(self.max_position_embeddings)
        self.initializer_range = float(self.initializer_range)
        self.rms_norm_eps = float(self.rms_norm_eps)
        self.max_length = int(self.max_length)

@dataclass
class OptimizerConfig:
    type: str = "adamW"
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1e-8
    torch_adam_is_fused: bool = True
    clip_grad: float = 1.0
    accumulate_grad_in_fp32: bool = True

    def __post_init__(self):
        # Ensure numeric values are proper floats
        self.weight_decay = float(self.weight_decay)
        self.adam_beta1 = float(self.adam_beta1)
        self.adam_beta2 = float(self.adam_beta2)
        self.adam_eps = float(self.adam_eps)
        self.clip_grad = float(self.clip_grad)

@dataclass
class SchedulerConfig:
    type: str = "one_cycle"
    learning_rate: float = 0.003
    warmup_steps: int = 100
    max_lr: float = 0.003
    pct_start: float = 0.02
    anneal_strategy: str = "cos"
    cycle_momentum: bool = False
    div_factor: float = 25.0
    final_div_factor: float = 1000.0

@dataclass
class TrainingConfig:
    output_dir: str = "./results"
    batch_size: int = 2
    micro_batch_size: int = 1
    gradient_accumulation_steps: int = 4
    sequence_length: int = 512
    learning_rate: float = 0.003
    max_steps: int = 5050
    first_phase_steps: int = 5000
    second_phase_steps: int = 50
    sample_frequency: int = 500
    second_phase_sample_frequency: int = 10
    logging_dir: str = "./logs"
    logging_steps: int = 1
    save_steps: int = 500
    checkpoint_dir: str = "checkpoints"
    sample_prompt: str = "Explain what machine learning is:"
    max_generate_length: int = 100

@dataclass
class HardwareConfig:
    precision: str = "16-mixed"
    accelerator: str = "gpu"
    devices: int = 1
    strategy: str = "auto"
    gradient_clip: float = 1.0

@dataclass
class DatasetConfig:
    name: str
    path: str
    subset: str
    weight: float
    split_ratio: float = 1.0  # Default to using the full dataset

@dataclass
class DataLoadingConfig:
    num_workers: int = 2
    batch_size: int = 32
    pin_memory: bool = True
    prefetch_factor: int = 2
    persistent_workers: bool = True

@dataclass
class DataConfig:
    datasets: List[DatasetConfig] = field(default_factory=list)
    loading: DataLoadingConfig = field(default_factory=DataLoadingConfig)

class SmolLM2Config:
    def __init__(self, config_path: str = None):
        self.model = ModelConfig()
        self.optimizer = OptimizerConfig()
        self.scheduler = SchedulerConfig()
        self.training = TrainingConfig()
        self.hardware = HardwareConfig()
        self.data = DataConfig()

        # Default dataset configuration
        self.data.datasets = [
            DatasetConfig(
                name="wikitext",
                path="wikitext",
                subset="wikitext-2-raw-v1",
                weight=1.0
            )
        ]

        if config_path and os.path.exists(config_path):
            self.load_from_yaml(config_path)

    def load_from_yaml(self, config_path: str):
        with open(config_path, 'r') as f:
            config_dict = yaml.safe_load(f)

        # Update configurations from yaml
        if 'model' in config_dict:
            for k, v in config_dict['model'].items():
                setattr(self.model, k, v)

        if 'optimizer' in config_dict:
            for k, v in config_dict['optimizer'].items():
                setattr(self.optimizer, k, v)

        if 'scheduler' in config_dict:
            for k, v in config_dict['scheduler'].items():
                setattr(self.scheduler, k, v)

        if 'training' in config_dict:
            for k, v in config_dict['training'].items():
                setattr(self.training, k, v)

        if 'hardware' in config_dict:
            for k, v in config_dict['hardware'].items():
                setattr(self.hardware, k, v)

        if 'data' in config_dict:
            for k, v in config_dict['data'].items():
                if k == 'datasets':
                    # Replace the default dataset list so the YAML-defined datasets take effect
                    self.data.datasets = [DatasetConfig(**dataset) for dataset in v]
                elif k == 'loading':
                    # Use distinct names to avoid shadowing the outer loop variables
                    for lk, lv in config_dict['data']['loading'].items():
                        setattr(self.data.loading, lk, lv)

    def save_to_yaml(self, config_path: str):
        config_dict = {
            'model': self.model.__dict__,
            'optimizer': self.optimizer.__dict__,
            'scheduler': self.scheduler.__dict__,
            'training': self.training.__dict__,
            'hardware': self.hardware.__dict__,
            # Convert nested dataclasses to plain dicts so yaml.dump can serialize them
            'data': {
                'datasets': [dataset.__dict__ for dataset in self.data.datasets],
                'loading': self.data.loading.__dict__
            }
        }

        with open(config_path, 'w') as f:
            yaml.dump(config_dict, f, default_flow_style=False)

    def __repr__(self):
        return f"SmolLM2Config(\n" \
               f"  model={self.model}\n" \
               f"  optimizer={self.optimizer}\n" \
               f"  scheduler={self.scheduler}\n" \
               f"  training={self.training}\n" \
               f"  hardware={self.hardware}\n" \
               f"  data={self.data}\n" \
               f")"
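For orientation, here is a minimal usage sketch (not part of the upload): it assumes config.py and config.yaml sit in the working directory, loads the defaults, applies the YAML overrides, and writes the merged result back to an illustrative output path.

from config import SmolLM2Config

# Load defaults, then override them from config.yaml if it exists.
config = SmolLM2Config("config.yaml")
print(config.model.hidden_size, config.model.num_hidden_layers)

# Round-trip the merged configuration to disk for inspection
# ("config_resolved.yaml" is an arbitrary, illustrative filename).
config.save_to_yaml("config_resolved.yaml")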
config.yaml
ADDED
@@ -0,0 +1,81 @@
model:
  type: "custom"
  name: "smollm2_transformer"
  tokenizer_name: "gpt2"
  vocab_size: 50257
  hidden_size: 256
  num_attention_heads: 4
  num_key_value_heads: 2
  num_hidden_layers: 6
  intermediate_size: 512
  hidden_act: "gelu"
  max_position_embeddings: 256
  initializer_range: 0.02
  rms_norm_eps: 1.0e-5
  use_cache: true
  pad_token_id: null

optimizer:
  type: "adamW"
  weight_decay: 0.01
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-8
  torch_adam_is_fused: true
  clip_grad: 1.0
  accumulate_grad_in_fp32: true

scheduler:
  type: "one_cycle"
  learning_rate: 0.001
  warmup_steps: 50
  max_lr: 0.001
  pct_start: 0.02
  anneal_strategy: "cos"
  cycle_momentum: false
  div_factor: 25.0
  final_div_factor: 1000.0

training:
  output_dir: "./results"
  batch_size: 4
  micro_batch_size: 2
  gradient_accumulation_steps: 2
  sequence_length: 256
  learning_rate: 0.001
  max_steps: 5050                    # Total steps (5000 + 50)
  first_phase_steps: 5000            # Initial training phase
  second_phase_steps: 50             # Fine-tuning phase
  sample_frequency: 100              # Sample every 100 steps in first phase
  second_phase_sample_frequency: 5   # Sample more frequently in second phase
  logging_dir: "./logs"
  logging_steps: 1
  save_steps: 100
  checkpoint_dir: "checkpoints"
  sample_prompt: "Explain what machine learning is:"
  max_generate_length: 50

hardware:
  precision: "16-mixed"
  accelerator: "gpu"
  devices: 1
  strategy: "auto"
  gradient_clip: 1.0
  cuda_memory_fraction: 0.9
  allow_tf32: true
  benchmark: true
  deterministic: false

data:
  datasets:
    - name: "wikitext"
      path: "wikitext"
      subset: "wikitext-103-raw-v1"
      split_ratio: 0.01              # Use only 1% of the dataset
      weight: 1.0
  loading:
    num_workers: 2
    batch_size: 16
    pin_memory: true
    prefetch_factor: 2
    persistent_workers: true
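Because this YAML swaps the tokenizer to gpt2 (50257 tokens) while the defaults in config.py assume HuggingFaceTB/SmolLM2-135M (49152 tokens), a quick consistency check can catch a mismatch between tokenizer_name and vocab_size before training. The following is only an illustrative sketch, not part of the upload; it assumes transformers is installed and the tokenizer can be downloaded.

from transformers import AutoTokenizer
from config import SmolLM2Config

config = SmolLM2Config("config.yaml")
tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_name)

# The embedding table is sized from config.model.vocab_size, so it must cover
# every id the tokenizer can emit (gpt2 -> 50257, SmolLM2-135M -> 49152).
assert tokenizer.vocab_size <= config.model.vocab_size, (
    f"tokenizer has {tokenizer.vocab_size} tokens but the model only allocates "
    f"{config.model.vocab_size} embeddings"
)
print("tokenizer and model vocab sizes are consistent")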
model.py
ADDED
@@ -0,0 +1,522 @@
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import AutoTokenizer
import torch.nn as nn
import math
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import os

def _init_weights(module, std=0.02):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = float(eps)  # Ensure eps is a float
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * norm * self.weight

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = int(max_position_embeddings)  # Convert to int
        self.base = base

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        t = torch.arange(self.max_position_embeddings).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :])

    def forward(self, x, seq_len=None):
        # Convert seq_len to int and ensure it's a valid value
        seq_len = int(seq_len) if seq_len is not None else x.size(1)
        if seq_len > self.max_position_embeddings:
            seq_len = self.max_position_embeddings

        return (
            self.cos_cached[:, :, :seq_len, :],
            self.sin_cached[:, :, :seq_len, :]
        )

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # Ensure proper broadcasting
    cos = cos[:, :, :q.size(2), :]  # [batch, 1, seq_len, dim]
    sin = sin[:, :, :q.size(2), :]  # [batch, 1, seq_len, dim]

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_attention_heads

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(self, hidden_states, cos, sin, attention_mask=None):
        batch_size, seq_length, _ = hidden_states.shape

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape for attention computation
        q = q.view(batch_size, seq_length, self.num_attention_heads, self.head_dim)
        k = k.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim)
        v = v.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim)

        # Transpose for attention computation
        q = q.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
        k = k.transpose(1, 2)  # [batch, num_kv_heads, seq_len, head_dim]
        v = v.transpose(1, 2)  # [batch, num_kv_heads, seq_len, head_dim]

        # Apply rotary embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Repeat k/v heads if num_key_value_heads < num_attention_heads
        if self.num_key_value_heads != self.num_attention_heads:
            k = k.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1)
            v = v.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1)

        # Compute attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)

        # Compute output
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous()  # [batch, seq_len, num_heads, head_dim]
        output = output.view(batch_size, seq_length, -1)

        return self.o_proj(output)

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

class DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = Attention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, cos, sin, attention_mask=None):
        # Self attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, cos, sin, attention_mask)
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

class SmolLM2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # Initialize transformer layers
        self.layers = nn.ModuleList([
            DecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])

        # Final layer norm
        self.norm = RMSNorm(config.hidden_size, eps=float(config.rms_norm_eps))

        # Initialize rotary embeddings
        self.rotary_emb = RotaryEmbedding(
            config.hidden_size // config.num_attention_heads,
            max_position_embeddings=config.max_position_embeddings
        )

        # Initialize weights
        self.apply(lambda p: _init_weights(p, std=config.initializer_range))

    def forward(self, input_ids, attention_mask=None):
        try:
            # Ensure inputs are on the correct device
            device = input_ids.device
            batch_size, seq_length = input_ids.shape

            # Input validation
            if seq_length > self.config.max_position_embeddings:
                raise ValueError(f"Input sequence length {seq_length} exceeds maximum position embeddings {self.config.max_position_embeddings}")

            # Get embeddings
            hidden_states = self.embed_tokens(input_ids)

            # Get position embeddings
            cos, sin = self.rotary_emb(hidden_states, seq_length)

            # Generate attention mask if none provided
            if attention_mask is None:
                attention_mask = torch.ones(
                    (batch_size, seq_length),
                    dtype=torch.bool,
                    device=device
                )
            else:
                # Convert to boolean if it's not already and ensure contiguous memory
                attention_mask = attention_mask.bool().contiguous()

            # Create causal mask
            causal_mask = torch.triu(
                torch.ones((seq_length, seq_length), device=device),
                diagonal=1
            ).bool()

            # Create attention mask [batch_size, 1, seq_length, seq_length]
            attention_mask = attention_mask.view(batch_size, 1, 1, seq_length)
            attention_mask = attention_mask.expand(batch_size, 1, seq_length, seq_length)

            # Prepare causal mask
            causal_mask = causal_mask.view(1, 1, seq_length, seq_length)

            # Combine masks
            mask = attention_mask & ~causal_mask

            # Convert boolean mask to float mask
            mask = mask.to(dtype=hidden_states.dtype)
            mask = (1.0 - mask) * torch.finfo(hidden_states.dtype).min

            # Apply transformer layers
            for layer in self.layers:
                hidden_states = layer(hidden_states, cos, sin, mask)

            # Apply final normalization
            hidden_states = self.norm(hidden_states)

            # Project back to vocabulary
            logits = F.linear(hidden_states, self.embed_tokens.weight)

            return logits

        except Exception as e:
            print(f"\nForward pass error:")
            print(f"Input shape: {input_ids.shape}")
            print(f"Device: {input_ids.device}")
            print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
            print(f"Error: {str(e)}")
            raise

    def generate(
        self,
        input_ids,
        attention_mask=None,
        max_length=100,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        num_return_sequences=1,
        do_sample=True,
        pad_token_id=None,
        bos_token_id=None,
        eos_token_id=None
    ):
        try:
            batch_size = input_ids.shape[0]
            current_length = input_ids.shape[1]
            device = input_ids.device

            # Input validation
            if current_length >= self.config.max_position_embeddings:
                raise ValueError(f"Input sequence length {current_length} exceeds maximum position embeddings {self.config.max_position_embeddings}")

            # Ensure we don't exceed maximum position embeddings
            max_length = min(max_length, self.config.max_position_embeddings)

            # Initialize attention mask if None
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids, dtype=torch.bool, device=device)

            for _ in range(max_length - current_length):
                # Forward pass
                outputs = self(input_ids, attention_mask)
                next_token_logits = outputs[:, -1, :] / temperature

                # Apply top-k filtering
                if top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Apply top-p filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                # Sample from the filtered distribution
                if do_sample:
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(next_token_logits, dim=-1)

                # Append new tokens
                input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
                attention_mask = torch.cat([attention_mask, torch.ones_like(next_tokens.unsqueeze(-1))], dim=-1)

                # Stop if we've hit special tokens
                if (pad_token_id is not None and (next_tokens == pad_token_id).all()) or \
                   (eos_token_id is not None and (next_tokens == eos_token_id).all()):
                    break

            return input_ids

        except Exception as e:
            print(f"\nGeneration error:")
            print(f"Input shape: {input_ids.shape}")
            print(f"Device: {input_ids.device}")
            print(f"Error: {str(e)}")
            raise

class TextDataset(Dataset):
    def __init__(self, config, split="train"):
        self.config = config

        # Load dataset from HuggingFace
        full_dataset = load_dataset(
            config.data.datasets[0].path,
            config.data.datasets[0].subset,
            split=split
        )

        # Apply split ratio if less than 1
        if config.data.datasets[0].split_ratio < 1.0:
            num_samples = int(len(full_dataset) * config.data.datasets[0].split_ratio)
            self.dataset = full_dataset.select(range(num_samples))
        else:
            self.dataset = full_dataset

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Get text from dataset
        text = self.dataset[idx]["text"]

        # Tokenize
        encodings = self.tokenizer(
            text,
            truncation=True,
            max_length=self.config.model.max_position_embeddings,
            padding="max_length",
            return_tensors="pt"
        )

        return {
            "input_ids": encodings.input_ids.squeeze(),
            "attention_mask": encodings.attention_mask.squeeze(),
            "labels": encodings.input_ids.squeeze()
        }

class SmolLM2Lightning(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.config = config

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Initialize the base model
        self.model = SmolLM2(config.model)

    def forward(self, input_ids, attention_mask=None):
        return self.model(input_ids, attention_mask)

    def training_step(self, batch, batch_idx):
        try:
            input_ids = batch["input_ids"]
            labels = batch["labels"]
            attention_mask = batch.get("attention_mask", None)

            # Ensure tensors are contiguous and on the correct device
            inputs = input_ids[..., :-1].contiguous()
            labels = input_ids[..., 1:].contiguous()

            if attention_mask is not None:
                attention_mask = attention_mask[..., :-1].contiguous()

            # Forward pass
            logits = self(inputs, attention_mask)

            # Calculate loss
            loss = F.cross_entropy(
                logits.view(-1, self.config.model.vocab_size),
                labels.view(-1),
                ignore_index=self.config.model.pad_token_id if self.config.model.pad_token_id is not None else -100,
                reduction='mean'
            )

            # Detach loss for logging
            loss_value = loss.detach().float()

            # Log metrics
            self.log('train_loss', loss_value, prog_bar=True, on_step=True, sync_dist=True)

            return loss

        except Exception as e:
            print(f"\nTraining step error:")
            print(f"Input shape: {input_ids.shape if input_ids is not None else 'None'}")
            print(f"Device: {input_ids.device if input_ids is not None else 'None'}")
            print(f"Error: {str(e)}")
            raise

    def validation_step(self, batch, batch_idx):
        try:
            input_ids = batch["input_ids"]
            labels = batch["labels"]
            attention_mask = batch.get("attention_mask", None)

            # Ensure tensors are contiguous and on the correct device
            inputs = input_ids[..., :-1].contiguous()
            labels = input_ids[..., 1:].contiguous()

            if attention_mask is not None:
                attention_mask = attention_mask[..., :-1].contiguous()

            # Forward pass
            logits = self(inputs, attention_mask)

            # Calculate loss
            loss = F.cross_entropy(
                logits.view(-1, self.config.model.vocab_size),
                labels.view(-1),
                ignore_index=self.config.model.pad_token_id if self.config.model.pad_token_id is not None else -100,
                reduction='mean'
            )

            # Detach loss for logging
            loss_value = loss.detach().float()

            # Log metrics
            self.log('val_loss', loss_value, prog_bar=True, on_epoch=True, sync_dist=True)

            return loss

        except Exception as e:
            print(f"\nValidation step error:")
            print(f"Input shape: {input_ids.shape if input_ids is not None else 'None'}")
            print(f"Device: {input_ids.device if input_ids is not None else 'None'}")
            print(f"Error: {str(e)}")
            raise

    def configure_optimizers(self):
        # Create optimizer with explicit type conversion
        optimizer = AdamW(
            self.parameters(),
            lr=float(self.config.scheduler.learning_rate),
            weight_decay=float(self.config.optimizer.weight_decay),
            betas=(float(self.config.optimizer.adam_beta1),
                   float(self.config.optimizer.adam_beta2)),
            eps=float(self.config.optimizer.adam_eps),
        )

        # Create scheduler
        scheduler = OneCycleLR(
            optimizer,
            max_lr=float(self.config.scheduler.max_lr),
            total_steps=int(self.config.training.max_steps),
            pct_start=float(self.config.scheduler.pct_start),
            anneal_strategy=self.config.scheduler.anneal_strategy,
            cycle_momentum=bool(self.config.scheduler.cycle_momentum),
            div_factor=float(self.config.scheduler.div_factor),
            final_div_factor=float(self.config.scheduler.final_div_factor),
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

    def generate(self, *args, **kwargs):
        return self.model.generate(*args, **kwargs)

    def train_dataloader(self):
        dataset = TextDataset(self.config, split="train")
        return DataLoader(
            dataset,
            batch_size=self.config.training.batch_size,
            shuffle=True,
            num_workers=self.config.data.loading.num_workers,
            pin_memory=self.config.data.loading.pin_memory,
            persistent_workers=True,
            prefetch_factor=self.config.data.loading.prefetch_factor,
            drop_last=True  # Drop incomplete batches
        )

    def val_dataloader(self):
        dataset = TextDataset(self.config, split="validation")
        return DataLoader(
            dataset,
            batch_size=self.config.training.batch_size,
            shuffle=False,
            num_workers=self.config.data.loading.num_workers,
            pin_memory=self.config.data.loading.pin_memory,
            persistent_workers=True,
            prefetch_factor=self.config.data.loading.prefetch_factor
        )
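A minimal smoke test for the model code follows; it is only a sketch under the assumption that config.py, config.yaml, and model.py are importable from the working directory. It runs a forward pass and a short greedy generation with randomly initialized weights, which is enough to verify shapes and the sampling loop without any training.

import torch
from config import SmolLM2Config
from model import SmolLM2

config = SmolLM2Config("config.yaml")
model = SmolLM2(config.model).eval()

# Random token ids standing in for a tokenized prompt.
input_ids = torch.randint(0, config.model.vocab_size, (1, 16))

with torch.no_grad():
    logits = model(input_ids)                                  # [batch, seq_len, vocab_size]
    generated = model.generate(input_ids, max_length=32, do_sample=False)

print(logits.shape, generated.shape)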
smol-lm2-final.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7992a6cb4ad6ca593be88b64f9e4359f771afaeabf6da719bd6aab480461fb08
size 197102570
train_script.py
ADDED
@@ -0,0 +1,264 @@
import os
import torch
import wandb
import shutil
from config import SmolLM2Config
from model import SmolLM2Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.loggers import WandbLogger

# Set CUDA environment variables before any other CUDA operations
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

def setup_training():
    """Setup training environment"""
    try:
        if torch.cuda.is_available():
            # Configure CUDA settings
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True
            torch.set_float32_matmul_precision('high')

            # Set default device
            device = torch.device('cuda:0')
            torch.cuda.set_device(device)

            # Print GPU info
            print(f"Using GPU: {torch.cuda.get_device_name()}")
            print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
            return device
    except Exception as e:
        print(f"CUDA setup error: {str(e)}")

    print("Using CPU")
    return torch.device('cpu')

def cleanup_training():
    """Cleanup training resources"""
    try:
        # Move model to CPU before cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Clean up wandb
        try:
            wandb.finish()
        except:
            pass

    except Exception as e:
        print(f"Cleanup error: {str(e)}")

# Setup CUDA at module level
device = setup_training()

class GenerationMonitorCallback(Callback):
    def __init__(self, prompt="Explain what machine learning is:", sample_every_n_steps=500):
        super().__init__()
        self.prompt = prompt
        self.sample_every_n_steps = sample_every_n_steps

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        try:
            if (trainer.global_step + 1) % self.sample_every_n_steps == 0:
                # Switch to eval mode
                pl_module.eval()

                with torch.no_grad():
                    # Tokenize prompt
                    inputs = pl_module.tokenizer(
                        self.prompt,
                        return_tensors="pt",
                        truncation=True,
                        max_length=pl_module.config.model.max_position_embeddings,
                        padding=True
                    ).to(pl_module.device)

                    try:
                        # Generate text with error handling
                        outputs = pl_module.generate(
                            input_ids=inputs.input_ids,
                            attention_mask=inputs.attention_mask,
                            max_length=100,
                            temperature=0.7,
                            top_p=0.9,
                            top_k=50,
                            do_sample=True,
                            pad_token_id=pl_module.tokenizer.pad_token_id,
                            bos_token_id=pl_module.tokenizer.bos_token_id,
                            eos_token_id=pl_module.tokenizer.eos_token_id
                        )

                        # Decode generated text
                        generated_text = pl_module.tokenizer.decode(outputs[0], skip_special_tokens=True)

                        # Print results
                        print(f"\n=== Generation at step {trainer.global_step + 1} ===")
                        print(f"Prompt: {self.prompt}")
                        print(f"Generated: {generated_text}\n")

                    except RuntimeError as e:
                        print(f"\nError during generation at step {trainer.global_step + 1}: {str(e)}")
                        print(f"Input shape: {inputs.input_ids.shape}")
                        print(f"Input device: {inputs.input_ids.device}")

                # Switch back to train mode
                pl_module.train()

        except Exception as e:
            print(f"\nCallback error at step {trainer.global_step + 1}: {str(e)}")

def init_wandb(project_name, run_name):
    """Initialize WandB with error handling and cleanup"""
    try:
        # Try to clean up any existing wandb directory
        wandb_dir = os.path.join(os.getcwd(), "wandb")
        if os.path.exists(wandb_dir):
            try:
                shutil.rmtree(wandb_dir)
                print("Cleaned up existing wandb directory")
            except Exception as e:
                print(f"Warning: Could not clean up wandb directory: {str(e)}")

        # Create fresh wandb directory with proper permissions
        os.makedirs(wandb_dir, exist_ok=True)

        # Initialize WandB logger
        logger = WandbLogger(
            project=project_name,
            name=run_name,
            save_dir=os.getcwd(),
            settings=wandb.Settings(start_method="thread")
        )
        return logger

    except Exception as e:
        print(f"Error initializing WandB: {str(e)}")
        print("Continuing without WandB logging...")
        return None

def main():
    device = setup_training()

    try:
        # Load configuration
        config = SmolLM2Config("config.yaml")

        # Initialize model
        model = SmolLM2Lightning(config)

        # Phase 1: Initial Training
        print("\n=== Starting Phase 1 Training ===")

        # Initialize wandb logger for phase 1 with error handling
        wandb_logger = init_wandb("smol-lm2", "training_run_phase1")

        # Setup checkpoint callback for phase 1
        checkpoint_callback = ModelCheckpoint(
            dirpath=config.training.checkpoint_dir,
            filename="smol-lm2-phase1-{epoch:02d}-{train_loss:.2f}",
            save_top_k=3,
            monitor="train_loss",
            mode="min",
            every_n_train_steps=config.training.save_steps
        )

        # Setup generation monitoring callback for phase 1
        generation_callback = GenerationMonitorCallback(
            prompt=config.training.sample_prompt,
            sample_every_n_steps=config.training.sample_frequency
        )

        # Initialize trainer for phase 1
        trainer_phase1 = pl.Trainer(
            max_steps=config.training.first_phase_steps,
            accelerator=config.hardware.accelerator,
            devices=config.hardware.devices,
            precision=config.hardware.precision,
            logger=wandb_logger,
            callbacks=[checkpoint_callback, generation_callback],
            gradient_clip_val=config.hardware.gradient_clip,
            accumulate_grad_batches=config.training.gradient_accumulation_steps,
            log_every_n_steps=config.training.logging_steps,
            deterministic=False,
            benchmark=True,
            strategy='auto',  # Let PyTorch Lightning handle device strategy
        )

        # Train phase 1 with error handling
        try:
            trainer_phase1.fit(model)
        except Exception as e:
            print(f"Error during phase 1 training: {str(e)}")
            raise

        # Save phase 1 checkpoint
        phase1_checkpoint_path = os.path.join(config.training.checkpoint_dir, "smol-lm2-phase1-final.ckpt")
        trainer_phase1.save_checkpoint(phase1_checkpoint_path)
        print(f"Phase 1 completed. Model saved to {phase1_checkpoint_path}")

        # Clear GPU memory between phases
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Phase 2: Fine-tuning
        print("\n=== Starting Phase 2 Training ===")

        # Load the model from phase 1 checkpoint with error handling
        try:
            model = SmolLM2Lightning.load_from_checkpoint(phase1_checkpoint_path, config=config)
        except Exception as e:
            print(f"Error loading checkpoint for phase 2: {str(e)}")
            raise

        # Initialize wandb logger for phase 2 with error handling
        wandb_logger = init_wandb("smol-lm2", "training_run_phase2")

        # Setup generation monitoring callback with higher frequency for phase 2
        generation_callback = GenerationMonitorCallback(
            prompt=config.training.sample_prompt,
            sample_every_n_steps=config.training.second_phase_sample_frequency
        )

        # Initialize trainer for phase 2
        trainer_phase2 = pl.Trainer(
            max_steps=config.training.second_phase_steps,
            accelerator=config.hardware.accelerator,
            devices=config.hardware.devices,
            precision=config.hardware.precision,
            logger=wandb_logger,
            callbacks=[generation_callback],
            gradient_clip_val=config.hardware.gradient_clip,
            accumulate_grad_batches=config.training.gradient_accumulation_steps,
            log_every_n_steps=config.training.logging_steps,
            deterministic=False,
            benchmark=True,
        )

        # Train phase 2 with error handling
        try:
            trainer_phase2.fit(model)
        except Exception as e:
            print(f"Error during phase 2 training: {str(e)}")
            raise

        # Save final model
        final_checkpoint_path = os.path.join(config.training.checkpoint_dir, "smol-lm2-final.ckpt")
        trainer_phase2.save_checkpoint(final_checkpoint_path)
        print(f"Phase 2 completed. Final model saved to {final_checkpoint_path}")

    except Exception as e:
        print(f"\nTraining failed with error: {str(e)}")
        if torch.cuda.is_available():
            print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
            print(f"CUDA memory cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
        raise
    finally:
        cleanup_training()

if __name__ == "__main__":
    main()
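To use the uploaded smol-lm2-final.ckpt outside of training, a hedged inference sketch is shown below. It mirrors how train_script.py reloads the phase 1 checkpoint via SmolLM2Lightning.load_from_checkpoint, and assumes the checkpoint and config.yaml have been downloaded into the working directory alongside config.py and model.py.

import torch
from config import SmolLM2Config
from model import SmolLM2Lightning

config = SmolLM2Config("config.yaml")

# The checkpoint in this upload was written by trainer_phase2.save_checkpoint(...).
ckpt_path = "smol-lm2-final.ckpt"
model = SmolLM2Lightning.load_from_checkpoint(ckpt_path, config=config, map_location="cpu")
model.eval()

prompt = config.training.sample_prompt
inputs = model.tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=config.training.max_generate_length,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        do_sample=True,
        eos_token_id=model.tokenizer.eos_token_id,
    )

print(model.tokenizer.decode(output_ids[0], skip_special_tokens=True))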