"""Two-phase training script for SmolLM2 using PyTorch Lightning and Weights & Biases."""

import os
import shutil

import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.loggers import WandbLogger

from config import SmolLM2Config
from model import SmolLM2Lightning
from env_setup import setup_environment, cleanup_environment

# Set CUDA environment variables before any other CUDA operations
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'


def setup_training():
    """Set up the training environment and return the device to use."""
    try:
        if torch.cuda.is_available():
            # Configure CUDA settings
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True
            torch.set_float32_matmul_precision('high')

            # Set default device
            device = torch.device('cuda:0')
            torch.cuda.set_device(device)

            # Print GPU info
            print(f"Using GPU: {torch.cuda.get_device_name()}")
            print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

            return device
    except Exception as e:
        print(f"CUDA setup error: {str(e)}")

    # Fall back to CPU if CUDA is unavailable or setup failed
    print("Using CPU")
    return torch.device('cpu')


def cleanup_training():
    """Clean up training resources."""
    try:
        # Release cached GPU memory before exiting
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Clean up wandb
        try:
            wandb.finish()
        except Exception:
            pass
    except Exception as e:
        print(f"Cleanup error: {str(e)}")


# Setup CUDA at module level
device = setup_training()


class GenerationMonitorCallback(Callback):
    """Periodically generates a sample completion so training progress can be inspected."""

    def __init__(self, prompt="Explain what machine learning is:", sample_every_n_steps=500):
        super().__init__()
        self.prompt = prompt
        self.sample_every_n_steps = sample_every_n_steps

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        try:
            if (trainer.global_step + 1) % self.sample_every_n_steps == 0:
                # Switch to eval mode
                pl_module.eval()

                with torch.no_grad():
                    # Tokenize prompt
                    inputs = pl_module.tokenizer(
                        self.prompt,
                        return_tensors="pt",
                        truncation=True,
                        max_length=pl_module.config.model.max_position_embeddings,
                        padding=True
                    ).to(pl_module.device)

                    try:
                        # Generate text with error handling
                        outputs = pl_module.generate(
                            input_ids=inputs.input_ids,
                            attention_mask=inputs.attention_mask,
                            max_length=100,
                            temperature=0.7,
                            top_p=0.9,
                            top_k=50,
                            do_sample=True,
                            pad_token_id=pl_module.tokenizer.pad_token_id,
                            bos_token_id=pl_module.tokenizer.bos_token_id,
                            eos_token_id=pl_module.tokenizer.eos_token_id
                        )

                        # Decode generated text
                        generated_text = pl_module.tokenizer.decode(outputs[0], skip_special_tokens=True)

                        # Print results
                        print(f"\n=== Generation at step {trainer.global_step + 1} ===")
                        print(f"Prompt: {self.prompt}")
                        print(f"Generated: {generated_text}\n")
                    except RuntimeError as e:
                        print(f"\nError during generation at step {trainer.global_step + 1}: {str(e)}")
                        print(f"Input shape: {inputs.input_ids.shape}")
                        print(f"Input device: {inputs.input_ids.device}")

                # Switch back to train mode
                pl_module.train()
        except Exception as e:
            print(f"\nCallback error at step {trainer.global_step + 1}: {str(e)}")


def init_wandb(project_name, run_name):
    """Initialize WandB with error handling and cleanup."""
    try:
        # Try to clean up any existing wandb directory
        wandb_dir = os.path.join(os.getcwd(), "wandb")
        if os.path.exists(wandb_dir):
            try:
                shutil.rmtree(wandb_dir)
                print("Cleaned up existing wandb directory")
            except Exception as e:
                print(f"Warning: Could not clean up wandb directory: {str(e)}")

        # Create fresh wandb directory with proper permissions
        os.makedirs(wandb_dir, exist_ok=True)

        # Initialize WandB logger
        logger = WandbLogger(
            project=project_name,
            name=run_name,
            save_dir=os.getcwd(),
            settings=wandb.Settings(start_method="thread")
        )
        return logger
    except Exception as e:
        print(f"Error initializing WandB: {str(e)}")
        print("Continuing without WandB logging...")
        return None


def main():
    device = setup_training()

    try:
        # Load configuration
        config = SmolLM2Config("config.yaml")

        # Initialize model
        model = SmolLM2Lightning(config)

        # Phase 1: Initial Training
        print("\n=== Starting Phase 1 Training ===")

        # Initialize wandb logger for phase 1 with error handling
        wandb_logger = init_wandb("smol-lm2", "training_run_phase1")

        # Setup checkpoint callback for phase 1
        checkpoint_callback = ModelCheckpoint(
            dirpath=config.training.checkpoint_dir,
            filename="smol-lm2-phase1-{epoch:02d}-{train_loss:.2f}",
            save_top_k=3,
            monitor="train_loss",
            mode="min",
            every_n_train_steps=config.training.save_steps
        )

        # Setup generation monitoring callback for phase 1
        generation_callback = GenerationMonitorCallback(
            prompt=config.training.sample_prompt,
            sample_every_n_steps=config.training.sample_frequency
        )

        # Initialize trainer for phase 1
        trainer_phase1 = pl.Trainer(
            max_steps=config.training.first_phase_steps,
            accelerator=config.hardware.accelerator,
            devices=config.hardware.devices,
            precision=config.hardware.precision,
            logger=wandb_logger,
            callbacks=[checkpoint_callback, generation_callback],
            gradient_clip_val=config.hardware.gradient_clip,
            accumulate_grad_batches=config.training.gradient_accumulation_steps,
            log_every_n_steps=config.training.logging_steps,
            deterministic=False,
            benchmark=True,
            strategy='auto',  # Let PyTorch Lightning handle device strategy
        )

        # Train phase 1 with error handling
        try:
            trainer_phase1.fit(model)
        except Exception as e:
            print(f"Error during phase 1 training: {str(e)}")
            raise

        # Save phase 1 checkpoint
        phase1_checkpoint_path = os.path.join(config.training.checkpoint_dir, "smol-lm2-phase1-final.ckpt")
        trainer_phase1.save_checkpoint(phase1_checkpoint_path)
        print(f"Phase 1 completed. Model saved to {phase1_checkpoint_path}")

        # Clear GPU memory between phases
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Phase 2: Fine-tuning
        print("\n=== Starting Phase 2 Training ===")

        # Load the model from phase 1 checkpoint with error handling
        try:
            model = SmolLM2Lightning.load_from_checkpoint(phase1_checkpoint_path, config=config)
        except Exception as e:
            print(f"Error loading checkpoint for phase 2: {str(e)}")
            raise

        # Initialize wandb logger for phase 2 with error handling
        wandb_logger = init_wandb("smol-lm2", "training_run_phase2")

        # Setup generation monitoring callback with higher frequency for phase 2
        generation_callback = GenerationMonitorCallback(
            prompt=config.training.sample_prompt,
            sample_every_n_steps=config.training.second_phase_sample_frequency
        )

        # Initialize trainer for phase 2
        trainer_phase2 = pl.Trainer(
            max_steps=config.training.second_phase_steps,
            accelerator=config.hardware.accelerator,
            devices=config.hardware.devices,
            precision=config.hardware.precision,
            logger=wandb_logger,
            callbacks=[generation_callback],
            gradient_clip_val=config.hardware.gradient_clip,
            accumulate_grad_batches=config.training.gradient_accumulation_steps,
            log_every_n_steps=config.training.logging_steps,
            deterministic=False,
            benchmark=True,
        )

        # Train phase 2 with error handling
        try:
            trainer_phase2.fit(model)
        except Exception as e:
            print(f"Error during phase 2 training: {str(e)}")
            raise

        # Save final model
        final_checkpoint_path = os.path.join(config.training.checkpoint_dir, "smol-lm2-final.ckpt")
        trainer_phase2.save_checkpoint(final_checkpoint_path)
        print(f"Phase 2 completed. Final model saved to {final_checkpoint_path}")

    except Exception as e:
        print(f"\nTraining failed with error: {str(e)}")
        if torch.cuda.is_available():
            print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
            print(f"CUDA memory cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
        raise
    finally:
        cleanup_training()


if __name__ == "__main__":
    main()
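# ---------------------------------------------------------------------------
# For reference: a minimal sketch of the config.yaml layout this script reads
# through SmolLM2Config. The key names below are taken from the attribute
# accesses above (config.model.*, config.training.*, config.hardware.*); the
# values are illustrative assumptions, not the project's actual defaults, and
# the real schema lives in config.py.
#
# model:
#   max_position_embeddings: 2048
# training:
#   checkpoint_dir: "checkpoints"
#   save_steps: 500
#   logging_steps: 10
#   gradient_accumulation_steps: 4
#   first_phase_steps: 5000
#   second_phase_steps: 50
#   sample_prompt: "Explain what machine learning is:"
#   sample_frequency: 500
#   second_phase_sample_frequency: 10
# hardware:
#   accelerator: "gpu"
#   devices: 1
#   precision: "16-mixed"
#   gradient_clip: 1.0
# ---------------------------------------------------------------------------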