import os
import shutil

import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.loggers import WandbLogger

from config import SmolLM2Config
from model import SmolLM2Lightning
from env_setup import setup_environment, cleanup_environment

# Set CUDA environment variables before any other CUDA operations
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'
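
# NOTE: this script reads its settings from config.yaml via SmolLM2Config. The
# attribute accesses below (config.training.*, config.hardware.*, config.model.*)
# imply a layout roughly like the sketch here. Field names are taken from this
# script; values are left as placeholders, and the authoritative schema is
# whatever SmolLM2Config in config.py actually parses.
#
#   training:
#     checkpoint_dir: ...
#     save_steps: ...
#     logging_steps: ...
#     gradient_accumulation_steps: ...
#     first_phase_steps: ...
#     second_phase_steps: ...
#     sample_prompt: ...
#     sample_frequency: ...
#     second_phase_sample_frequency: ...
#   hardware:
#     accelerator: ...
#     devices: ...
#     precision: ...
#     gradient_clip: ...
#   model:
#     max_position_embeddings: ...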

def setup_training():
    """Set up the training environment and return the device to train on."""
    try:
        if torch.cuda.is_available():
            # Configure CUDA settings
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True
            torch.set_float32_matmul_precision('high')
            # Set default device
            device = torch.device('cuda:0')
            torch.cuda.set_device(device)
            # Print GPU info
            print(f"Using GPU: {torch.cuda.get_device_name()}")
            print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
            return device
    except Exception as e:
        print(f"CUDA setup error: {str(e)}")
    # Fall back to CPU when CUDA is unavailable or its setup fails
    print("Using CPU")
    return torch.device('cpu')

def cleanup_training():
    """Clean up training resources."""
    try:
        # Release cached GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Finish any active wandb run
        try:
            wandb.finish()
        except Exception:
            pass
    except Exception as e:
        print(f"Cleanup error: {str(e)}")

# Setup CUDA at module level
device = setup_training()

class GenerationMonitorCallback(Callback):
    """Periodically samples text from the model during training."""

    def __init__(self, prompt="Explain what machine learning is:", sample_every_n_steps=500):
        super().__init__()
        self.prompt = prompt
        self.sample_every_n_steps = sample_every_n_steps

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        try:
            if (trainer.global_step + 1) % self.sample_every_n_steps != 0:
                return
            # Switch to eval mode for sampling
            pl_module.eval()
            try:
                with torch.no_grad():
                    # Tokenize prompt
                    inputs = pl_module.tokenizer(
                        self.prompt,
                        return_tensors="pt",
                        truncation=True,
                        max_length=pl_module.config.model.max_position_embeddings,
                        padding=True
                    ).to(pl_module.device)
                    try:
                        # Generate text with error handling
                        generated = pl_module.generate(
                            input_ids=inputs.input_ids,
                            attention_mask=inputs.attention_mask,
                            max_length=100,
                            temperature=0.7,
                            top_p=0.9,
                            top_k=50,
                            do_sample=True,
                            pad_token_id=pl_module.tokenizer.pad_token_id,
                            bos_token_id=pl_module.tokenizer.bos_token_id,
                            eos_token_id=pl_module.tokenizer.eos_token_id
                        )
                        # Decode generated text
                        generated_text = pl_module.tokenizer.decode(generated[0], skip_special_tokens=True)
                        # Print results
                        print(f"\n=== Generation at step {trainer.global_step + 1} ===")
                        print(f"Prompt: {self.prompt}")
                        print(f"Generated: {generated_text}\n")
                    except RuntimeError as e:
                        print(f"\nError during generation at step {trainer.global_step + 1}: {str(e)}")
                        print(f"Input shape: {inputs.input_ids.shape}")
                        print(f"Input device: {inputs.input_ids.device}")
            finally:
                # Always switch back to train mode
                pl_module.train()
        except Exception as e:
            print(f"\nCallback error at step {trainer.global_step + 1}: {str(e)}")

def init_wandb(project_name, run_name):
    """Initialize WandB with error handling and cleanup."""
    try:
        # Try to clean up any existing wandb directory
        wandb_dir = os.path.join(os.getcwd(), "wandb")
        if os.path.exists(wandb_dir):
            try:
                shutil.rmtree(wandb_dir)
                print("Cleaned up existing wandb directory")
            except Exception as e:
                print(f"Warning: Could not clean up wandb directory: {str(e)}")
        # Create a fresh wandb directory
        os.makedirs(wandb_dir, exist_ok=True)
        # Initialize WandB logger
        logger = WandbLogger(
            project=project_name,
            name=run_name,
            save_dir=os.getcwd(),
            settings=wandb.Settings(start_method="thread")
        )
        return logger
    except Exception as e:
        print(f"Error initializing WandB: {str(e)}")
        print("Continuing without WandB logging...")
        return None

def main():
    device = setup_training()
    try:
        # Load configuration
        config = SmolLM2Config("config.yaml")
        # Initialize model
        model = SmolLM2Lightning(config)

        # Phase 1: Initial Training
        print("\n=== Starting Phase 1 Training ===")
        # Initialize wandb logger for phase 1 with error handling
        wandb_logger = init_wandb("smol-lm2", "training_run_phase1")
        # Setup checkpoint callback for phase 1
        checkpoint_callback = ModelCheckpoint(
            dirpath=config.training.checkpoint_dir,
            filename="smol-lm2-phase1-{epoch:02d}-{train_loss:.2f}",
            save_top_k=3,
            monitor="train_loss",
            mode="min",
            every_n_train_steps=config.training.save_steps
        )
        # Setup generation monitoring callback for phase 1
        generation_callback = GenerationMonitorCallback(
            prompt=config.training.sample_prompt,
            sample_every_n_steps=config.training.sample_frequency
        )
        # Initialize trainer for phase 1
        trainer_phase1 = pl.Trainer(
            max_steps=config.training.first_phase_steps,
            accelerator=config.hardware.accelerator,
            devices=config.hardware.devices,
            precision=config.hardware.precision,
            logger=wandb_logger,
            callbacks=[checkpoint_callback, generation_callback],
            gradient_clip_val=config.hardware.gradient_clip,
            accumulate_grad_batches=config.training.gradient_accumulation_steps,
            log_every_n_steps=config.training.logging_steps,
            deterministic=False,
            benchmark=True,
            strategy='auto',  # Let PyTorch Lightning handle the device strategy
        )
        # Train phase 1 with error handling
        try:
            trainer_phase1.fit(model)
        except Exception as e:
            print(f"Error during phase 1 training: {str(e)}")
            raise

        # Save phase 1 checkpoint
        phase1_checkpoint_path = os.path.join(config.training.checkpoint_dir, "smol-lm2-phase1-final.ckpt")
        trainer_phase1.save_checkpoint(phase1_checkpoint_path)
        print(f"Phase 1 completed. Model saved to {phase1_checkpoint_path}")

        # Clear GPU memory between phases
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Phase 2: Fine-tuning
        print("\n=== Starting Phase 2 Training ===")
        # Load the model from the phase 1 checkpoint with error handling
        try:
            model = SmolLM2Lightning.load_from_checkpoint(phase1_checkpoint_path, config=config)
        except Exception as e:
            print(f"Error loading checkpoint for phase 2: {str(e)}")
            raise
        # Initialize wandb logger for phase 2 with error handling
        wandb_logger = init_wandb("smol-lm2", "training_run_phase2")
        # Setup generation monitoring callback with higher frequency for phase 2
        generation_callback = GenerationMonitorCallback(
            prompt=config.training.sample_prompt,
            sample_every_n_steps=config.training.second_phase_sample_frequency
        )
        # Initialize trainer for phase 2
        trainer_phase2 = pl.Trainer(
            max_steps=config.training.second_phase_steps,
            accelerator=config.hardware.accelerator,
            devices=config.hardware.devices,
            precision=config.hardware.precision,
            logger=wandb_logger,
            callbacks=[generation_callback],
            gradient_clip_val=config.hardware.gradient_clip,
            accumulate_grad_batches=config.training.gradient_accumulation_steps,
            log_every_n_steps=config.training.logging_steps,
            deterministic=False,
            benchmark=True,
        )
        # Train phase 2 with error handling
        try:
            trainer_phase2.fit(model)
        except Exception as e:
            print(f"Error during phase 2 training: {str(e)}")
            raise

        # Save final model
        final_checkpoint_path = os.path.join(config.training.checkpoint_dir, "smol-lm2-final.ckpt")
        trainer_phase2.save_checkpoint(final_checkpoint_path)
        print(f"Phase 2 completed. Final model saved to {final_checkpoint_path}")
    except Exception as e:
        print(f"\nTraining failed with error: {str(e)}")
        if torch.cuda.is_available():
            print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
            print(f"CUDA memory cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
        raise
    finally:
        cleanup_training()
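

# Illustrative helper (not called anywhere in this script): one way to load the
# final checkpoint written by main() and sample from it, reusing only APIs already
# used above (SmolLM2Lightning.load_from_checkpoint, model.tokenizer,
# model.generate). The function name and its defaults are this sketch's own
# choices, not part of the project.
def sample_from_final_checkpoint(prompt="Explain what machine learning is:",
                                 config_path="config.yaml"):
    config = SmolLM2Config(config_path)
    # Mirrors the path main() uses for the final checkpoint
    checkpoint_path = os.path.join(config.training.checkpoint_dir, "smol-lm2-final.ckpt")
    model = SmolLM2Lightning.load_from_checkpoint(checkpoint_path, config=config)
    model.eval()
    with torch.no_grad():
        inputs = model.tokenizer(prompt, return_tensors="pt").to(model.device)
        generated = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=100,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
            do_sample=True,
            pad_token_id=model.tokenizer.pad_token_id,
            eos_token_id=model.tokenizer.eos_token_id,
        )
    return model.tokenizer.decode(generated[0], skip_special_tokens=True)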

if __name__ == "__main__":
    main()