"""Train ResNet-50 on ImageNet with PyTorch Lightning, logging via loguru."""

import glob
import os
import re
from datetime import datetime
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from loguru import logger


class CustomProgressBar(TQDMProgressBar):
    def __init__(self):
        super().__init__()
        self.enable = True

    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        logger.info(f"\n{'=' * 20} Epoch {trainer.current_epoch} {'=' * 20}")


class ImageNetModule(LightningModule):
    def __init__(
        self,
        learning_rate: float = 0.1,
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        batch_size: int = 256,
        num_workers: int = 16,
        max_epochs: int = 90,
        train_path: str = "path/to/imagenet",
        val_path: str = "path/to/imagenet",
        checkpoint_dir: str = "checkpoints",
    ):
        super().__init__()
        # self.save_hyperparameters()

        # Model
        self.model = models.resnet50(weights=None)

        # Training parameters
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.max_epochs = max_epochs
        self.train_path = train_path
        self.val_path = val_path
        self.checkpoint_dir = checkpoint_dir

        # Metrics tracking
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.best_val_acc = 0.0

        # Set up transforms
        self.train_transforms = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        self.val_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)

        # Calculate top-1 accuracy (percent) for this batch
        _, predicted = torch.max(outputs, 1)
        correct = (predicted == labels).sum().item()
        accuracy = (correct / labels.size(0)) * 100

        # Log metrics for this step
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', accuracy, on_step=False, on_epoch=True, prog_bar=True)

        self.training_step_outputs.append({
            'loss': loss.detach(),
            'acc': torch.tensor(accuracy),
        })
        return loss

    def on_train_epoch_end(self):
        if not self.training_step_outputs:
            logger.warning("No training outputs available for this epoch")
            return

        avg_loss = torch.stack([x['loss'] for x in self.training_step_outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in self.training_step_outputs]).mean()

        # Get current learning rate
        current_lr = self.optimizers().param_groups[0]['lr']

        logger.info(f"Training metrics - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}, LR: {current_lr:.6f}")
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)

        # Calculate top-1 accuracy (percent) for this batch
        _, predicted = torch.max(outputs, 1)
        correct = (predicted == labels).sum().item()
        accuracy = (correct / labels.size(0)) * 100

        # Log metrics for this step
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', accuracy, on_step=False, on_epoch=True, prog_bar=True)
        self.validation_step_outputs.append({
            'val_loss': loss.detach(),
            'val_acc': torch.tensor(accuracy),
        })
        return {'val_loss': loss, 'val_acc': accuracy}

    def on_validation_epoch_end(self):
        avg_loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in self.validation_step_outputs]).mean()

        # Log final validation metrics
        self.log('val_loss_epoch', avg_loss)
        self.log('val_acc_epoch', avg_acc)

        # Save checkpoint if validation accuracy improves
        if avg_acc > self.best_val_acc:
            self.best_val_acc = avg_acc
            checkpoint_path = os.path.join(
                self.checkpoint_dir,
                f"resnet50-epoch{self.current_epoch:02d}-acc{avg_acc:.4f}.ckpt"
            )
            self.trainer.save_checkpoint(checkpoint_path)
            logger.info(f"New best validation accuracy: {avg_acc:.4f}. Saved checkpoint to {checkpoint_path}")

        logger.info(f"Validation metrics - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")
        self.validation_step_outputs.clear()

    def train_dataloader(self):
        train_dataset = datasets.ImageFolder(self.train_path, transform=self.train_transforms)
        return DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        val_dataset = datasets.ImageFolder(self.val_path, transform=self.val_transforms)
        return DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.learning_rate,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )

        # OneCycleLR scheduler, stepped once per optimizer step
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.learning_rate,
            epochs=self.max_epochs,
            steps_per_epoch=len(self.train_dataloader()),
            pct_start=0.3,
            anneal_strategy='cos',
            div_factor=25.0,
            cycle_momentum=True,
            base_momentum=0.85,
            max_momentum=0.95,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }


def setup_logging(log_dir="logs"):
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{timestamp}.log")

    logger.remove()
    # Console sink (loguru messages already end with a newline, so suppress print's own)
    logger.add(
        lambda msg: print(msg, end=""),
        format="{time:YYYY-MM-DD HH:mm:ss} | {message}",
        colorize=True,
        level="INFO",
    )
    # File sink with rotation and retention
    logger.add(
        log_file,
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
        level="INFO",
        rotation="100 MB",
        retention="30 days",
    )
    logger.info(f"Logging setup complete. Logs will be saved to: {log_file}")
    return log_file


def find_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Find the latest checkpoint file using various possible naming patterns."""
    # Look for checkpoint files with different possible patterns
    patterns = [
        "*.ckpt",                  # Generic checkpoint files
        "resnet50-epoch*.ckpt",    # Our custom format
        "*epoch=*.ckpt",           # PyTorch Lightning default format
        "checkpoint_epoch*.ckpt",  # Another common format
    ]

    all_checkpoints = []
    for pattern in patterns:
        checkpoint_pattern = os.path.join(checkpoint_dir, pattern)
        all_checkpoints.extend(glob.glob(checkpoint_pattern))

    if not all_checkpoints:
        logger.info("No existing checkpoints found.")
        return None

    def extract_info(checkpoint_path: str) -> Tuple[int, float]:
        """Extract epoch and optional accuracy from checkpoint filename."""
        filename = os.path.basename(checkpoint_path)

        # Try different patterns to extract the epoch number
        epoch_patterns = [
            r'epoch=(\d+)',     # matches epoch=X
            r'epoch(\d+)',      # matches epochX
            r'epoch[_-](\d+)',  # matches epoch_X or epoch-X
        ]
        epoch = None
        for pattern in epoch_patterns:
            match = re.search(pattern, filename)
            if match:
                epoch = int(match.group(1))
                break

        # If no epoch found, fall back to the file modification time
        if epoch is None:
            epoch = int(os.path.getmtime(checkpoint_path))

        # Try to extract accuracy if present
        acc_match = re.search(r'acc[_-]?([\d.]+)', filename)
        acc = float(acc_match.group(1)) if acc_match else 0.0

        return epoch, acc

    try:
        latest_checkpoint = max(all_checkpoints, key=lambda x: extract_info(x)[0])
        epoch, acc = extract_info(latest_checkpoint)
        logger.info(f"Found latest checkpoint: {latest_checkpoint}")
        logger.info(f"Epoch: {epoch}" + (f", Accuracy: {acc:.4f}" if acc > 0 else ""))
        return latest_checkpoint
    except Exception as e:
        logger.error(f"Error processing checkpoints: {str(e)}")
        # If there's any error in parsing, return the most recently modified file
        latest_checkpoint = max(all_checkpoints, key=os.path.getmtime)
        logger.info(f"Falling back to most recently modified checkpoint: {latest_checkpoint}")
        return latest_checkpoint


def main():
    checkpoint_dir = "/home/ec2-user/ebs/volumes/era_session9"
    log_file = setup_logging(log_dir=checkpoint_dir)

    logger.info("Starting training with configuration:")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA device count: {torch.cuda.device_count()}")
        logger.info(f"CUDA devices: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}")

    # Find latest checkpoint
    # latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
    latest_checkpoint = "/home/ec2-user/ebs/volumes/era_session9/resnet50-epoch18-acc53.7369.ckpt"

    model = ImageNetModule(
        learning_rate=0.156,
        batch_size=256,
        num_workers=16,
        max_epochs=60,
        train_path="/home/ec2-user/ebs/volumes/imagenet/ILSVRC/Data/CLS-LOC/train",
        val_path="/home/ec2-user/ebs/volumes/imagenet/imagenet_validation",
        checkpoint_dir=checkpoint_dir,
    )

    logger.info("Model configuration:")
    logger.info(f"Learning rate: {model.learning_rate}")
    logger.info(f"Batch size: {model.batch_size}")
    logger.info(f"Number of workers: {model.num_workers}")
    logger.info(f"Max epochs: {model.max_epochs}")

    progress_bar = CustomProgressBar()

    trainer = Trainer(
        max_epochs=60,
        accelerator="gpu",
        devices=4,
        strategy="ddp",
        precision=16,
        callbacks=[progress_bar],
        enable_progress_bar=True,
    )

    logger.info("Starting training")
    try:
        if latest_checkpoint:
            logger.info(f"Resuming training from checkpoint: {latest_checkpoint}")
{latest_checkpoint}") trainer.fit(model, ckpt_path=latest_checkpoint) else: logger.info("Starting training from scratch") trainer.fit(model) logger.info("Training completed successfully") except Exception as e: logger.error(f"Training failed with error: {str(e)}") raise finally: logger.info(f"Training session ended. Log file: {log_file}") if __name__ == "__main__": main() # pass