Feat: Upload app files
- app.py +145 -0
- config.py +149 -0
- inference.py +102 -0
- last.ckpt +3 -0
- requirements.txt +15 -0
- smollmv2.py +243 -0
- smollv2_lightning.py +498 -0
app.py
ADDED
@@ -0,0 +1,145 @@
#! /usr/bin/env python3
"""
This script is a simple text generator using the SmollmV2 model.
It uses Gradio to create a web interface for generating text.
"""
# Third-Party Imports
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import GPT2Tokenizer
import spaces
import os
from pathlib import Path

# Local imports
from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2


def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file.
    """
    # Create the checkpoints directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Check if the combined model already exists
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file

    # Ensure the model parts exist
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")

    # Combine the parts
    parts = sorted(Path(model_dir).glob("last.ckpt.part_*"))
    if not parts:
        raise FileNotFoundError("No model parts found")

    print("Combining model parts...")
    with open(output_file, 'wb') as outfile:
        for part in parts:
            print(f"Processing part: {part}")
            with open(part, 'rb') as infile:
                outfile.write(infile.read())

    print(f"Model combined successfully: {output_file}")
    return output_file


def load_model():
    """
    Load the SmollmV2 model and tokenizer.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Combine the model parts and get the checkpoint path
    checkpoint_path = combine_model_parts()

    # Load the model from the combined checkpoint using the Lightning module
    model = LitSmollmv2.load_from_checkpoint(
        checkpoint_path,
        model_config=SmollmConfig,
        strict=False
    )

    model.to(device)
    model.eval()

    # Initialize the tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, device


@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.
    """
    # Ensure num_tokens doesn't exceed the model's block size
    num_tokens = min(num_tokens, SmollmConfig.block_size)

    # Tokenize the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Generate tokens one at a time
    for _ in range(num_tokens):
        # Get the model's predictions
        with torch.no_grad():
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits, _ = model.model(input_ids)

        # Get the next-token probabilities
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)

        # Apply top-p (nucleus) sampling
        if top_p > 0:
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_keep = cumsum_probs <= top_p
            sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
            sorted_indices_to_keep[..., 0] = 1
            indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep)
            probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
            probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample the next token
        next_token = torch.multinomial(probs, num_samples=1)

        # Append to input_ids
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Stop if we generate an EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

    # Decode and return the generated text
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return generated_text


# Load the model globally
model, tokenizer, device = load_model()

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", value="Once upon a time"),
        gr.Slider(minimum=1, maximum=SmollmConfig.block_size, value=100, step=1, label="Number of tokens to generate"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (higher = more random)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmollmV2 Text Generator",
    description="Generate text using the SmollmV2 model",
    allow_flagging="never",
    cache_examples=True
)

if __name__ == "__main__":
    demo.launch()
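Note that combine_model_parts() expects the checkpoint to already be present as ordered chunks named last.ckpt.part_* inside split_models/. The script that produces those chunks is not part of this upload; the sketch below is one way such parts could be generated (the helper name, chunk size, and zero-padded suffix are assumptions), chosen so that sorted(Path(model_dir).glob("last.ckpt.part_*")) reassembles them in the right order.

#! /usr/bin/env python3
# Hypothetical helper (not part of this upload): split a large checkpoint into
# chunks that combine_model_parts() in app.py can glob and reassemble.
import os
from pathlib import Path


def split_checkpoint(src="checkpoints/last.ckpt", out_dir="split_models", chunk_mb=100):
    """Write src as out_dir/last.ckpt.part_000, part_001, ... in chunk_mb-sized pieces."""
    os.makedirs(out_dir, exist_ok=True)
    chunk_size = chunk_mb * 1024 * 1024
    with open(src, "rb") as infile:
        index = 0
        while True:
            chunk = infile.read(chunk_size)
            if not chunk:
                break
            # Zero-padded suffix keeps lexicographic order equal to byte order
            part_path = Path(out_dir) / f"last.ckpt.part_{index:03d}"
            with open(part_path, "wb") as outfile:
                outfile.write(chunk)
            index += 1


if __name__ == "__main__":
    split_checkpoint()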
config.py
ADDED
@@ -0,0 +1,149 @@
#!/usr/bin/env python3
"""
Configuration classes for the GPT model
Author: Shilpaj Bhalerao
Date: 2025-01-19
"""
# Standard Library Imports
from dataclasses import dataclass, field


@dataclass
class RoPEConfig:
    """
    Configuration for Rotary Position Embeddings
    """
    base: int = 10000                   # Base for the angle calculations
    scaling_factor: float = 1.0         # Scaling factor for rotary embeddings
    head_dim_fraction: float = 0.3125   # Set to get exactly kv_dim=24 (216 total)
    round_multiple: int = 8             # Round kv_dim to the nearest multiple of this number


@dataclass
class SmollmConfig:
    """
    Configuration for the Smollm training setup
    """
    # Model configuration
    block_size: int = 2048      # max sequence length
    vocab_size: int = 49152     # vocabulary size
    n_layer: int = 30           # number of transformer layers
    n_head: int = 9             # number of attention heads
    n_embd: int = 576           # embedding dimension
    mlp_ratio: float = 2.67     # Based on the MLP implementation (1536/576)
    dropout: float = 0.0        # No dropout used in the implementation

    # Training configuration
    batch_size: int = 1                 # Minimum batch size (from smollv2_lightning.py)
    num_workers: int = 0                # No additional workers, to save memory
    shuffle_buffer_size: int = 1000     # Shuffle buffer size for the dataset
    max_length: int = 2048              # Sequence length for training
    learning_rate: float = 3e-5         # From LitGPT initialization
    weight_decay: float = 1e-4          # From LitGPT initialization

    # Generation configuration
    max_new_tokens: int = 100           # From the generation code in training_step

    # Training control
    seed: int = 1337
    max_steps: int = 5000
    clear_cache_every: int = 1000       # Clear the GPU cache every N steps, 0 to disable

    # Generation parameters
    context_length: int = 10            # Number of tokens to use as context
    temperature: float = 1.0            # Sampling temperature
    top_k: int = 50                     # Top-k sampling parameter


@dataclass
class CheckpointConfig:
    """
    Configuration for checkpointing
    """
    checkpoint_dir: str = "checkpoints"
    checkpoint_every: int = 500             # Save a checkpoint every 500 steps
    save_last: bool = True
    save_top_k: int = 1                     # Changed from checkpoint_save_top_k
    save_weights_only: bool = True          # Changed from checkpoint_save_weights_only
    monitor: str = "train_loss"             # Monitor training loss for checkpointing
    mode: str = "min"                       # Mode for the monitored metric
    save_on_train_epoch_end: bool = False   # Whether to save at the end of each training epoch


@dataclass
class LoggingConfig:
    """
    Configuration for logging
    """
    log_every: int = 50          # Log metrics every 50 steps
    generate_every: int = 500    # Generate sample text every 500 steps
    log_metrics: bool = True
    log_progress_bar: bool = True
    log_model_summary: bool = True


@dataclass
class OptimizerConfig:
    """
    Configuration for the optimizer
    """
    optimizer: str = "AdamW"            # Using the AdamW optimizer
    learning_rate: float = 3e-5
    weight_decay: float = 1e-4
    max_lr: float = 3e-4                # max_lr = learning_rate * 10
    div_factor: float = 25.0            # From the OneCycleLR config
    final_div_factor: float = 100.0     # From the OneCycleLR config
    pct_start: float = 0.2              # From the OneCycleLR config

    # Additional optimizer settings
    optimizer_kwargs: dict = field(default_factory=lambda: {
        'betas': (0.9, 0.95),   # Default betas for AdamW
        'eps': 1e-8,            # Default epsilon value
    })
    three_phase: bool = False           # Use a three-phase learning rate schedule
    anneal_strategy: str = 'linear'     # Learning rate annealing strategy


@dataclass
class DataConfig:
    """
    Configuration for the dataset and tokenizer
    """
    # Dataset configuration
    dataset_path: str = "HuggingFaceTB/smollm-corpus"
    dataset_name: str = "cosmopedia-v2"

    # Tokenizer configuration
    tokenizer_path: str = "HuggingFaceTB/cosmo2-tokenizer"

    # DataLoader configuration
    batch_size: int = 32
    num_workers: int = 4
    shuffle_buffer_size: int = 1000
    max_length: int = 512

    # Dataset splits
    validation_split: float = 0.1       # 10% for validation
    pin_memory: bool = True
    streaming: bool = True              # Use streaming mode for the dataset


@dataclass
class TrainerConfig:
    """
    Configuration for the PyTorch Lightning Trainer
    """
    accelerator: str = 'auto'
    devices: int = 1
    precision: str = '16-mixed'
    log_every_n_steps: int = 10
    strategy: str = 'auto'
    deterministic: bool = False
    benchmark: bool = True
    enable_progress_bar: bool = True
    enable_model_summary: bool = True
    profiler: str = 'simple'
    gradient_clip_val: float = 1.0
    accumulate_grad_batches: int = 2
    val_check_interval: int = 1000      # Run validation every N training steps
    check_val_every_n_epoch: int = None # Disable epoch-based validation
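As a quick sanity check of the geometry these dataclasses imply, the illustrative snippet below (not part of the upload) derives the per-head and MLP hidden sizes the model code relies on; the values are taken from SmollmConfig above.

#!/usr/bin/env python3
# Illustrative check of the dimensions implied by SmollmConfig.
from config import SmollmConfig

cfg = SmollmConfig()
head_dim = cfg.n_embd // cfg.n_head                 # 576 / 9 = 64
mlp_hidden = int(cfg.n_embd * cfg.mlp_ratio) - 1    # matches MLP in smollmv2.py -> 1536
print(f"head_dim={head_dim}, mlp_hidden={mlp_hidden}, block_size={cfg.block_size}")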
inference.py
ADDED
@@ -0,0 +1,102 @@
#! /usr/bin/env python
"""
Inference script for the SmollmV2 model
Author: Shilpaj Bhalerao
Date: 2025-01-25
"""
# Third-Party Imports
import torch
from transformers import GPT2Tokenizer

# Local Imports
from smollv2_lightning import LitSmollmv2
from config import SmollmConfig, DataConfig


def load_model(checkpoint_path):
    """
    Load the trained model from a checkpoint.
    """
    model = LitSmollmv2.load_from_checkpoint(
        checkpoint_path,
        model_config=SmollmConfig,
        strict=False
    )
    model.eval()
    return model


def generate_text(model, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9):
    """
    Generate text using the loaded model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Initialize the tokenizer the same way as in CosmopediaDataModule
    tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token

    # Tokenize the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Generate tokens one at a time
    for _ in range(max_new_tokens):
        # Get the model's predictions
        with torch.no_grad():
            logits, _ = model.model(input_ids)

        # Get the next-token probabilities
        logits = logits[:, -1, :] / temperature
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # Sample from the distribution with top-p (nucleus) filtering
        if top_p > 0:
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_keep = cumsum_probs <= top_p
            sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
            sorted_indices_to_keep[..., 0] = 1
            indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep)
            probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
            probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample the next token
        next_token = torch.multinomial(probs, num_samples=1)

        # Append to input_ids
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Stop if we generate an EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

    # Decode and return the generated text
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return generated_text


def main():
    # Path to your checkpoint
    checkpoint_path = "./checkpoints/last.ckpt"

    # Load the model
    model = load_model(checkpoint_path)
    print("Model loaded successfully!")

    # Example prompts for generation
    prompts = [
        "Once upon a time",
        "The future of artificial intelligence",
        "In the distant galaxy"
    ]

    # Generate text for each prompt
    for prompt in prompts:
        print("\nPrompt:", prompt)
        generated = generate_text(prompt=prompt, model=model)
        print("Generated:", generated)
        print("-" * 50)


if __name__ == "__main__":
    main()
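Besides running main() directly, the two helpers above can be reused from another script. The short sketch below (illustrative; the prompt and sampling values are arbitrary) loads the same default checkpoint path used in main() and generates from a custom prompt.

#! /usr/bin/env python
# Illustrative reuse of inference.py's helpers with custom sampling settings.
from inference import load_model, generate_text

model = load_model("./checkpoints/last.ckpt")
text = generate_text(model, "The tiny transformer said", max_new_tokens=50, temperature=0.7, top_p=0.95)
print(text)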
last.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4c7f0b043f2a6492e6f20568c0842d06c64fe20c95ddb03ca3a7fcab5f57e2d4
size 811285105
requirements.txt
ADDED
@@ -0,0 +1,15 @@
# Core ML libraries
torch>=2.0.0
transformers>=4.30.0
lightning>=2.0.0

# Web UI
gradio>=5.13.1

# HuggingFace Space utilities
huggingface-hub>=0.19.0
spaces>=0.19.0

# Optional dependencies for better performance
accelerate>=0.20.0
bitsandbytes>=0.41.0
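Before launching the Space locally, it can be useful to confirm these pins are actually installed. The snippet below is an optional check (not part of the upload) that prints the installed version of each listed distribution.

#!/usr/bin/env python3
# Optional sanity check: print installed versions of the pinned packages.
from importlib.metadata import version, PackageNotFoundError

for pkg in ["torch", "transformers", "lightning", "gradio",
            "huggingface-hub", "spaces", "accelerate", "bitsandbytes"]:
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")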
smollmv2.py
ADDED
@@ -0,0 +1,243 @@
#! /usr/bin/env python
"""
SmollmV2 model implementation
Author: Shilpaj Bhalerao
Date: 2025-01-19
"""
# Third-Party Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Local Imports
from config import SmollmConfig, RoPEConfig


class RoPEAttention:
    """
    Rotary Position Embedding attention with support for different Q/K dimensions
    """
    def __init__(self, head_dim, kv_dim, base=RoPEConfig.base):
        """
        Initialize rotary embeddings
        Args:
            head_dim: Dimension of the query head
            kv_dim: Dimension of the key/value head
            base: Base for the angle calculations (default: 10000)
        """
        super().__init__()

        # Generate the theta parameter for rotary embeddings for both Q and K dimensions
        inv_freq_k = 1.0 / (base ** (torch.arange(0, kv_dim, 2).float() / kv_dim))
        self.register_buffer('inv_freq_k', inv_freq_k)

        self.head_dim = head_dim
        self.kv_dim = kv_dim
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def _update_cos_sin_cache(self, x, seq_len):
        """Update cached cos and sin values for the given sequence length"""
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq_k)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq_k)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]

    def _rotate_half(self, x):
        """Rotate half the hidden dims of the input."""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def __call__(self, q, k):
        """
        Apply rotary embeddings to the input queries and keys
        Args:
            q: Query tensor of shape (batch, n_head, seq_len, head_dim)
            k: Key tensor of shape (batch, n_head, seq_len, kv_dim)
        Returns:
            q_rot: Rotated query tensor
            k_rot: Rotated key tensor
        """
        seq_len = q.shape[2]
        self._update_cos_sin_cache(k, seq_len)

        # Apply rotary embeddings to keys
        k_cos = self.cos_cached[..., :self.kv_dim]
        k_sin = self.sin_cached[..., :self.kv_dim]
        k_rot = (k * k_cos) + (self._rotate_half(k) * k_sin)

        # For queries, only apply rotation to the part that interacts with keys
        q_part = q[..., :self.kv_dim]
        q_cos = self.cos_cached[..., :self.kv_dim]
        q_sin = self.sin_cached[..., :self.kv_dim]
        q_rot_part = (q_part * q_cos) + (self._rotate_half(q_part) * q_sin)

        # Combine the rotated part with the unrotated parts of the query
        q_rot = torch.cat([q_rot_part, q[..., self.kv_dim:]], dim=-1)

        return q_rot, k_rot

    def register_buffer(self, name, tensor):
        """Helper function to register a buffer"""
        setattr(self, name, tensor)


class CausalSelfAttention(nn.Module):
    """
    Causal self-attention mechanism with reduced KV dimensions and RoPE
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Calculate dimensions
        self.head_dim = config.n_embd // config.n_head  # 576 / 9 = 64
        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # Make kv_dim divisible by n_head (189 = 9 * 21 is the closest to 192 that is divisible by 9)
        self.kv_dim = 189
        self.kv_dim_per_head = self.kv_dim // self.n_head  # 21

        # Separate projections with reduced dimensions for k, v
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.k_proj = nn.Linear(config.n_embd, self.kv_dim, bias=False)  # 189 dimensions
        self.v_proj = nn.Linear(config.n_embd, self.kv_dim, bias=False)  # 189 dimensions

        # Output projection
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        # Rotary embeddings
        self.rope = RoPEAttention(self.head_dim, self.kv_dim_per_head)

    def forward(self, x):
        B, T, C = x.size()

        # Calculate query, key, values
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape with exact dimensions
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.kv_dim_per_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.kv_dim_per_head).transpose(1, 2)

        # Apply rotary embeddings
        q, k = self.rope(q, k)

        # Pad k and v to match the q dimension for attention
        k_pad = torch.zeros_like(q)
        v_pad = torch.zeros_like(q)
        k_pad[..., :self.kv_dim_per_head] = k
        v_pad[..., :self.kv_dim_per_head] = v

        # Flash attention
        y = F.scaled_dot_product_attention(q, k_pad, v_pad, is_causal=True)

        # Reshape back
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection
        y = self.o_proj(y)
        return y


class MLP(nn.Module):
    """
    MLP (Multi-Layer Perceptron) layer with a gate/up/down projection structure
    """
    def __init__(self, config):
        super().__init__()
        hidden_dim = int(config.n_embd * config.mlp_ratio) - 1
        self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
        self.down_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        # SwiGLU activation as used in PaLM, Llama, etc.
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        x = F.silu(gate) * up
        x = self.down_proj(x)
        return x


class Block(nn.Module):
    """
    Transformer block
    """
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=False)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=False)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class SmollmV2(nn.Module):
    """
    SmollmV2 model
    """
    def __init__(self, config=SmollmConfig()):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd, bias=False),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight sharing between the token embedding and the LM head
        self.transformer.wte.weight = self.lm_head.weight

        # Weight initialization
        self.apply(self._init_weights)

        # Compile the model if the torch version supports it
        if hasattr(torch, 'compile'):
            self.forward = torch.compile(self.forward)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.04)

    def forward(self, idx, targets=None):
        # idx is of shape (B, T)
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        # Forward the token embeddings (positions are handled by RoPE inside attention)
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)
        x = tok_emb
        # Forward the blocks of the transformer
        for block in self.transformer.h:
            x = block(x)
        # Forward the final layernorm and the classifier
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
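A small shape check can confirm the forward pass wiring. The sketch below is illustrative only: it shrinks depth and vocabulary via dataclasses.replace while keeping n_embd=576 and n_head=9, because CausalSelfAttention hard-codes kv_dim=189 for that geometry; note that SmollmV2 wraps its forward in torch.compile when available, so the first call may take a moment.

#! /usr/bin/env python
# Illustrative shape check for SmollmV2 with a shrunken config.
import dataclasses
import torch
from config import SmollmConfig
from smollmv2 import SmollmV2

tiny_cfg = dataclasses.replace(SmollmConfig(), n_layer=2, vocab_size=1000, block_size=128)
model = SmollmV2(tiny_cfg)
idx = torch.randint(0, tiny_cfg.vocab_size, (1, 16))   # (B, T)
logits, loss = model(idx, targets=idx)
print(logits.shape, loss.item())                        # expect torch.Size([1, 16, 1000]) and a scalar loss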
smollv2_lightning.py
ADDED
@@ -0,0 +1,498 @@
#!/usr/bin/env python
"""
Lightning module for SmollmV2 model training
"""

# Standard Library Imports
import os
from typing import Tuple

# Third-Party Imports
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator
import time
import numpy as np
from contextlib import nullcontext
import torch.nn.functional as F

# Local Imports
from config import (SmollmConfig, OptimizerConfig, CheckpointConfig,
                    LoggingConfig, TrainerConfig)
from smollmv2 import SmollmV2
from cosmopedia_datamodule import CosmopediaDataModule


class LitSmollmv2(pl.LightningModule):
    """
    Lightning module for SmollmV2 model training
    """
    def __init__(
            self,
            learning_rate=OptimizerConfig.learning_rate,
            weight_decay=OptimizerConfig.weight_decay,
            total_epochs=None,
            total_steps=None,
            interupt_steps=SmollmConfig.max_steps,
            compile_model=True
    ):
        """
        Constructor
        :param learning_rate: Learning rate for the optimizer
        :param weight_decay: Weight decay for the optimizer
        :param total_epochs: Total number of epochs (optional)
        :param total_steps: Total number of steps (optional)
        :param compile_model: Whether to compile the model for faster training
        Note: Provide either total_epochs or total_steps, not both
        """
        super().__init__()
        self.save_hyperparameters()

        if total_epochs is None and total_steps is None:
            raise ValueError("Must provide either total_epochs or total_steps")
        if total_epochs is not None and total_steps is not None:
            raise ValueError("Provide either total_epochs or total_steps, not both")

        # Set seeds from config
        torch.manual_seed(SmollmConfig.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(SmollmConfig.seed)

        # Initialize the model
        self.model = SmollmV2(SmollmConfig())

        # Compile the model if requested and supported
        if compile_model and hasattr(torch, 'compile'):
            print("Compiling model for faster training...")
            self.model = torch.compile(self.model)

        # Print total model parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"Total model parameters: {total_params:,}\n")

        # OneCycleLR parameters from OptimizerConfig
        self.max_lr = OptimizerConfig.max_lr
        self.div_factor = OptimizerConfig.div_factor
        self.final_div_factor = OptimizerConfig.final_div_factor
        self.pct_start = OptimizerConfig.pct_start
        self.total_epochs = total_epochs
        self.total_steps = total_steps

        # Add performance monitoring attributes
        self.iter_num = 0
        self.iter_time = 0.0
        self.tokens_processed = 0
        self.interupt_steps = interupt_steps

    def on_load_checkpoint(self, checkpoint):
        """Restore iter_num when loading from a checkpoint"""
        if 'iter_num' in checkpoint:
            self.iter_num = checkpoint['iter_num']

    def on_save_checkpoint(self, checkpoint):
        """Save iter_num in the checkpoint"""
        checkpoint['iter_num'] = self.iter_num

    def forward(self, x, targets=None):
        """
        Method to forward the input through the model
        """
        return self.model(x, targets)

    def training_step(self, batch, batch_idx):
        """
        Method to perform a training step with performance monitoring
        """
        try:
            # Stop training at max steps from config
            if self.iter_num >= self.interupt_steps:
                self.trainer.should_stop = True
                return None

            # Start timing
            t0 = time.time()

            # Process batch
            input_ids = batch['input_ids']
            labels = batch['labels']
            attention_mask = batch['attention_mask']

            # Clear cache before the forward pass
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Forward pass
            logits, loss = self(input_ids, targets=labels)

            # Calculate tokens processed
            tokens_per_iter = np.prod(input_ids.shape)
            self.tokens_processed += tokens_per_iter

            # Ensure CUDA synchronization after the forward pass
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            # Calculate iteration time
            dt = time.time() - t0
            self.iter_time += dt

            # Log metrics
            self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
            self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], on_step=True)

            # Generate a sample prediction
            if self.iter_num % LoggingConfig.generate_every == 0:
                # Get a sample input from the batch
                context_length = SmollmConfig.context_length  # Number of tokens to use as context
                sample_input = input_ids[0:1, :context_length]

                # Generate a prediction
                self.model.eval()
                with torch.no_grad():
                    max_new_tokens = SmollmConfig.max_new_tokens
                    temperature = SmollmConfig.temperature
                    top_k = SmollmConfig.top_k

                    for _ in range(max_new_tokens):
                        # Get model predictions
                        logits, _ = self(sample_input)
                        logits = logits[:, -1, :] / temperature

                        # Apply top-k sampling
                        if top_k is not None:
                            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                            logits[logits < v[:, [-1]]] = -float('Inf')

                        probs = F.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                        sample_input = torch.cat([sample_input, next_token], dim=1)

                    # Convert tokens to text using the tokenizer from the datamodule
                    try:
                        input_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, :10].tolist())
                        generated_text = self.trainer.datamodule.tokenizer.decode(sample_input[0, 10:].tolist())
                        print(f"\nStep {self.iter_num} - Sample Generation:")
                        print(f"Input: {input_text}")
                        print(f"Generated: {generated_text}\n")
                    except Exception as e:
                        print(f"Error decoding text: {str(e)}")

                self.model.train()  # Set back to training mode

            # Log performance metrics
            if self.iter_num % LoggingConfig.log_every == 0:
                tokens_per_sec = self.tokens_processed / self.iter_time if self.iter_time > 0 else 0

                self.log('tokens_per_sec', tokens_per_sec, on_step=True)
                self.log('iter_time_ms', dt * 1000, on_step=True)

                print(f"\nstep {self.iter_num} | loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")

                if torch.cuda.is_available():
                    self.log('gpu_memory', torch.cuda.memory_allocated() / 1e9, on_step=True)
                    self.log('gpu_memory_reserved', torch.cuda.memory_reserved() / 1e9, on_step=True)
                    print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB")

                # Clear the GPU cache periodically if enabled
                if SmollmConfig.clear_cache_every > 0 and self.iter_num % SmollmConfig.clear_cache_every == 0:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                self.tokens_processed = 0
                self.iter_time = 0.0

            self.iter_num += 1
            return loss

        except RuntimeError as e:
            if "out of memory" in str(e):
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                print(f"WARNING: out of memory - {str(e)}")
                return None
            raise e

    def validation_step(self, batch, batch_idx):
        """
        Method to perform a validation step
        """
        # Start timing for validation
        t0 = time.time()

        # Ensure CUDA synchronization for accurate timing
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Process batch - updated for the Cosmopedia format
        input_ids = batch['input_ids']
        labels = batch['labels']
        attention_mask = batch['attention_mask']

        # Forward pass
        logits, loss = self(input_ids, targets=labels)

        # Ensure CUDA synchronization after the forward pass
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Calculate validation time
        dt = time.time() - t0

        # Log metrics
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        if batch_idx == 0:  # Only print for the first batch
            print(f"\nValidation - loss: {loss.item():.4f} | dt: {dt*1000:.2f}ms")
            if torch.cuda.is_available():
                print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB")

        return loss

    def configure_optimizers(self):
        """
        Method to configure the optimizer and scheduler
        """
        # Create an instance of OptimizerConfig
        optim_config = OptimizerConfig()

        optimizer = getattr(optim, optim_config.optimizer)(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            **optim_config.optimizer_kwargs
        )

        # Calculate total steps
        if self.total_steps is None:
            total_steps = len(self.trainer.datamodule.train_dataloader()) * self.total_epochs
        else:
            total_steps = self.total_steps

        scheduler = {
            'scheduler': optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.max_lr,
                total_steps=total_steps,
                pct_start=self.pct_start,
                div_factor=self.div_factor,
                final_div_factor=self.final_div_factor,
                three_phase=optim_config.three_phase,
                anneal_strategy=optim_config.anneal_strategy
            ),
            'interval': 'step'
        }

        return [optimizer], [scheduler]

    def on_train_epoch_end(self):
        """
        Called at the end of the training epoch
        """
        # Reset performance counters at epoch end
        self.tokens_processed = 0
        self.iter_time = 0.0


def plot_learning_rate(log_dir):
    """
    Plot the learning rate from TensorBoard logs
    """
    event_files = []
    for root, dirs, files in os.walk(log_dir):
        for file in files:
            if "events.out.tfevents" in file:
                event_files.append(os.path.join(root, file))

    lr_data = []
    steps = []

    for event_file in event_files:
        ea = event_accumulator.EventAccumulator(
            event_file,
            size_guidance={'scalars': 0}
        )
        ea.Reload()

        if 'lr' in ea.Tags()['scalars']:
            events = ea.Scalars('lr')
            for event in events:
                lr_data.append(event.value)
                steps.append(event.step)

    if lr_data:
        plt.figure(figsize=(10, 6))
        plt.plot(steps, lr_data, '-', linewidth=2)
        plt.title('Learning Rate Schedule')
        plt.xlabel('Training Steps')
        plt.ylabel('Learning Rate')
        plt.grid(True)
        plt.margins(x=0.02)
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
        plt.savefig('learning_rate_schedule.png', dpi=300, bbox_inches='tight')
        plt.close()


def train_model(epochs=None, steps=None, ckpt_path=None, interupt_steps=SmollmConfig.max_steps):
    """
    Train the model for the specified number of epochs or steps
    :param epochs: Number of epochs to train (optional)
    :param steps: Number of steps to train (optional)
    :param ckpt_path: Path to a checkpoint for resuming training
    :param interupt_steps: Number of steps after which to interrupt training
    Note: Provide either epochs or steps, not both
    """
    # Set compilation mode for PyTorch 2.0+
    if hasattr(torch, 'compile'):
        torch._dynamo.config.suppress_errors = True
        torch._dynamo.config.verbose = False

    torch.set_float32_matmul_precision('high')

    # Initialize the data module with reduced workers and batch size
    data_module = CosmopediaDataModule(
        batch_size=SmollmConfig.batch_size,        # Reduced from 32
        num_workers=SmollmConfig.num_workers,      # Reduced from 4
        shuffle_buffer_size=SmollmConfig.shuffle_buffer_size,
        max_length=SmollmConfig.block_size
    )

    # Initialize the model
    model = LitSmollmv2(total_epochs=epochs, total_steps=steps, interupt_steps=interupt_steps)

    # Setup callbacks with reduced frequency
    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints',
        filename='smollmv2-{step:05d}-{val_loss:.2f}',
        save_top_k=CheckpointConfig.save_top_k,                  # Save only the best model
        monitor=CheckpointConfig.monitor,                        # Monitor training loss instead of validation loss
        mode=CheckpointConfig.mode,
        save_last=CheckpointConfig.save_last,
        every_n_train_steps=CheckpointConfig.checkpoint_every,   # Reduced checkpoint frequency
        save_on_train_epoch_end=CheckpointConfig.save_on_train_epoch_end
    )

    lr_monitor = LearningRateMonitor(logging_interval='step')

    # Setup logger
    logger = TensorBoardLogger("lightning_logs", name="smollmv2", log_graph=True)

    # Add a gradient scaler for mixed precision training
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

    # Initialize the trainer with performance monitoring
    trainer_kwargs = {
        'accelerator': TrainerConfig.accelerator,
        'devices': TrainerConfig.devices,
        'callbacks': [checkpoint_callback, lr_monitor],
        'logger': logger,
        'precision': TrainerConfig.precision,
        'log_every_n_steps': TrainerConfig.log_every_n_steps,
        'strategy': TrainerConfig.strategy,
        'deterministic': TrainerConfig.deterministic,
        'benchmark': TrainerConfig.benchmark,
        'enable_progress_bar': TrainerConfig.enable_progress_bar,
        'enable_model_summary': TrainerConfig.enable_model_summary,
        'profiler': TrainerConfig.profiler,
        'gradient_clip_val': TrainerConfig.gradient_clip_val,
        'accumulate_grad_batches': TrainerConfig.accumulate_grad_batches,
        'val_check_interval': TrainerConfig.val_check_interval,
        'check_val_every_n_epoch': TrainerConfig.check_val_every_n_epoch
    }

    # Add either max_epochs or max_steps
    if epochs is not None:
        trainer_kwargs['max_epochs'] = epochs
    else:
        trainer_kwargs['max_steps'] = steps

    trainer = pl.Trainer(**trainer_kwargs)

    # Train with performance monitoring
    print("\nStarting training with performance monitoring...")
    print("Format: step | loss | iteration time | tokens per second | GPU memory\n")

    # Enable garbage collection
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        trainer.fit(model, data_module, ckpt_path=ckpt_path)
    except KeyboardInterrupt:
        print("\nTraining interrupted by user. Saving checkpoint...")
        if not os.path.exists('checkpoints'):
            os.makedirs('checkpoints')
        trainer.save_checkpoint("checkpoints/interrupted_training.ckpt")
        print("Checkpoint saved. Exiting...")
    except Exception as e:
        print(f"An error occurred during training: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise e

    return checkpoint_callback.best_model_path


def get_latest_checkpoint():
    """
    Find the latest checkpoint in the checkpoints directory
    """
    checkpoint_dir = 'checkpoints'
    if not os.path.exists(checkpoint_dir):
        return None

    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')]
    if not checkpoints:
        return None

    latest_checkpoint = max(
        [os.path.join(checkpoint_dir, f) for f in checkpoints],
        key=os.path.getmtime
    )
    return latest_checkpoint


def main(interupt_steps=SmollmConfig.max_steps):
    """
    Main function to handle the training workflow
    """
    # Ask the user for the training mode
    mode = input("Train by epochs or steps? (e/s): ").lower()

    if mode == 'e':
        total_epochs = int(input("Enter number of epochs: "))
        steps = None
    else:
        steps = int(input("Enter number of steps: "))
        total_epochs = None

    try:
        latest_checkpoint = get_latest_checkpoint()

        if latest_checkpoint and os.path.exists(latest_checkpoint):
            print(f"\nFound existing checkpoint: {latest_checkpoint}")
            user_input = input("Resume training from checkpoint? (y/n): ").lower()

            if user_input == 'y':
                print(f"\nResuming training from checkpoint: {latest_checkpoint}")
                train_model(epochs=total_epochs, steps=steps, ckpt_path=latest_checkpoint, interupt_steps=interupt_steps)
            else:
                print("\nStarting fresh training...")
                best_model_path = train_model(epochs=total_epochs, steps=steps, interupt_steps=interupt_steps)
        else:
            print("\nNo checkpoints found. Starting fresh training...")
            best_model_path = train_model(epochs=total_epochs, steps=steps, interupt_steps=interupt_steps)

        print("\nGenerating learning rate plot...")
        plot_learning_rate("lightning_logs")
        print("Learning rate plot saved as 'learning_rate_schedule.png'")

    except Exception as e:
        print(f"An error occurred during training: {str(e)}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
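main() above drives training interactively via input() prompts. For scripted runs, train_model() can be called directly, as in the sketch below (illustrative; the step count mirrors SmollmConfig.max_steps). Note that smollv2_lightning.py imports cosmopedia_datamodule, which is not part of this upload, so that module must be available on the path for any training run.

#!/usr/bin/env python
# Illustrative non-interactive training entry point using the helpers above.
from smollv2_lightning import train_model, plot_learning_rate

if __name__ == "__main__":
    best_ckpt = train_model(steps=5000, interupt_steps=5000)   # step-based run
    print(f"Best checkpoint: {best_ckpt}")
    plot_learning_rate("lightning_logs")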