import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import gzip
import os
from pathlib import Path
from datetime import datetime
import urllib.request
import shutil
from tqdm import tqdm
import asyncio
from fastapi import WebSocket
import json

from scripts.model import Net


class TrainingConfig:
    def __init__(self, params_dict):
        self.block1 = params_dict['block1']
        self.block2 = params_dict['block2']
        self.block3 = params_dict['block3']
        self.optimizer = params_dict['optimizer']
        self.batch_size = params_dict['batch_size']
        self.epochs = params_dict['epochs']
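
# Example (illustrative): TrainingConfig is built from a plain dict; the
# values below are placeholders, not project defaults.
#
#   config = TrainingConfig({
#       'block1': 16, 'block2': 32, 'block3': 64,
#       'optimizer': 'Adam', 'batch_size': 64, 'epochs': 5,
#   })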


def generate_model_filename(config, model_type="single"):
    """Generate a filename that encodes the model configuration.

    model_type tags the file (e.g. "single", "model_a", "model_b").
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    arch = f"{config.block1}_{config.block2}_{config.block3}"
    opt = config.optimizer.lower()
    batch = str(config.batch_size)
    return f"{model_type}_arch_{arch}_opt_{opt}_batch_{batch}_{timestamp}.pth"


def download_and_extract_mnist_data():
    """Download and extract the MNIST dataset from a reliable mirror."""
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
    files = {
        "train_images": "train-images-idx3-ubyte.gz",
        "train_labels": "train-labels-idx1-ubyte.gz",
        "test_images": "t10k-images-idx3-ubyte.gz",
        "test_labels": "t10k-labels-idx1-ubyte.gz"
    }
    data_dir = Path("data/MNIST/raw")
    data_dir.mkdir(parents=True, exist_ok=True)
    for file_name in files.values():
        gz_file_path = data_dir / file_name
        extracted_file_path = data_dir / file_name.replace('.gz', '')
        # If the extracted file already exists, skip the download
        if extracted_file_path.exists():
            print(f"{extracted_file_path} already exists, skipping download.")
            continue
        # Download the compressed file
        print(f"Downloading {file_name}...")
        url = base_url + file_name
        try:
            urllib.request.urlretrieve(url, gz_file_path)
            print(f"Successfully downloaded {file_name}")
        except Exception as e:
            print(f"Failed to download {file_name}: {e}")
            raise RuntimeError(f"Could not download {file_name}") from e
        # Decompress the .gz file alongside it
        try:
            print(f"Extracting {file_name}...")
            with gzip.open(gz_file_path, 'rb') as f_in:
                with open(extracted_file_path, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)
            print(f"Successfully extracted {file_name}")
        except Exception as e:
            print(f"Failed to extract {file_name}: {e}")
            raise RuntimeError(f"Could not extract {file_name}") from e


def load_mnist_images(filename):
    """Load IDX3 image data, skipping the 16-byte header (magic, count, rows, cols)."""
    with open(filename, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    # Reshape to NCHW and scale pixel values to [0, 1]
    return data.reshape(-1, 1, 28, 28).astype(np.float32) / 255.0


def load_mnist_labels(filename):
    """Load IDX1 label data, skipping the 8-byte header (magic, count)."""
    with open(filename, 'rb') as f:
        return np.frombuffer(f.read(), np.uint8, offset=8)


class CustomMNISTDataset(Dataset):
    def __init__(self, images_path, labels_path, transform=None):
        self.images = load_mnist_images(images_path)
        self.labels = load_mnist_labels(labels_path)
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = torch.FloatTensor(self.images[idx])
        label = int(self.labels[idx])
        if self.transform:
            image = self.transform(image)
        return image, label
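
# Example (illustrative): the dataset yields (image, label) pairs, e.g.
#   ds = CustomMNISTDataset("data/MNIST/raw/train-images-idx3-ubyte",
#                           "data/MNIST/raw/train-labels-idx1-ubyte")
#   image, label = ds[0]   # image: FloatTensor of shape (1, 28, 28)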


def validate(model, test_loader, criterion, device):
    """Run a validation pass and return (average loss, accuracy in %)."""
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():  # no gradient computation during validation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += criterion(output, target).item()  # per-batch mean loss
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            num_batches += 1
    # Average the loss over batches; accuracy over all samples
    val_loss = val_loss / num_batches
    val_acc = 100. * correct / total
    return val_loss, val_acc
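
# Example (illustrative) call, assuming model, loaders, and device are set up:
#   val_loss, val_acc = validate(model, test_loader, nn.CrossEntropyLoss(), device)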


async def train(model, config, websocket=None, model_type="single"):
    print("\nStarting training...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = model.to(device)

    # Create the data directory if it doesn't exist
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)

    # Ensure data is downloaded and extracted
    print("Preparing dataset...")
    download_and_extract_mnist_data()

    # Paths to the extracted files
    train_images_path = "data/MNIST/raw/train-images-idx3-ubyte"
    train_labels_path = "data/MNIST/raw/train-labels-idx1-ubyte"
    test_images_path = "data/MNIST/raw/t10k-images-idx3-ubyte"
    test_labels_path = "data/MNIST/raw/t10k-labels-idx1-ubyte"

    # Data loading with the standard MNIST normalization
    transform = transforms.Compose([
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = CustomMNISTDataset(train_images_path, train_labels_path, transform=transform)
    test_dataset = CustomMNISTDataset(test_images_path, test_labels_path, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
    print(f"Dataset loaded. Training samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

    print("\nTraining Configuration:")
    print(f"Epochs: {config.epochs}")
    print(f"Optimizer: {config.optimizer}")
    print(f"Batch Size: {config.batch_size}")
    print(f"Network Architecture: {config.block1}-{config.block2}-{config.block3}")

    # Report model parameter counts
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\nModel Parameters:")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    print("\nStarting training loop...")
    best_val_acc = 0
    criterion = nn.CrossEntropyLoss()

    # Initialize the optimizer from the config (Adam uses its default lr=1e-3)
    if config.optimizer.lower() == 'adam':
        optimizer = optim.Adam(model.parameters())
    else:
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # Create the models directory if it doesn't exist
    models_dir = Path("scripts/training/models")
    models_dir.mkdir(parents=True, exist_ok=True)
    try:
        for epoch in range(config.epochs):
            model.train()
            total_loss = 0
            correct = 0
            total = 0
            progress_bar = tqdm(
                train_loader,
                desc=f"Epoch {epoch+1}/{config.epochs}",
                unit='batch',
                leave=True
            )
            for batch_idx, (data, target) in enumerate(progress_bar):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                # Accumulate batch accuracy
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)
                total_loss += loss.item()

                # Running metrics over the epoch so far
                current_loss = total_loss / (batch_idx + 1)
                current_acc = 100. * correct / total

                # Send a training update through the websocket, if connected
                if websocket:
                    try:
                        step = batch_idx + epoch * len(train_loader)
                        await websocket.send_json({
                            'type': 'training_update',
                            'data': {
                                'step': step,
                                'train_loss': current_loss,
                                'train_acc': current_acc,
                                'epoch': epoch
                            }
                        })
                    except Exception as e:
                        print(f"Error sending websocket update: {e}")

            # Validation phase
            model.eval()
            val_loss = 0
            val_correct = 0
            val_total = 0
            print("\nRunning validation...")
            with torch.no_grad():
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    val_loss += criterion(output, target).item()
                    pred = output.argmax(dim=1, keepdim=True)
                    val_correct += pred.eq(target.view_as(pred)).sum().item()
                    val_total += target.size(0)
            val_loss /= len(test_loader)
            val_acc = 100. * val_correct / val_total

            # Print epoch results
            print(f"\nEpoch {epoch+1}/{config.epochs} Results:")
            print(f"Training Loss: {current_loss:.4f} | Training Accuracy: {current_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")

            # Send a validation update through the websocket, if connected
            if websocket:
                try:
                    await websocket.send_json({
                        'type': 'validation_update',
                        'data': {
                            'step': (epoch + 1) * len(train_loader),
                            'val_loss': val_loss,
                            'val_acc': val_acc
                        }
                    })
                except Exception as e:
                    print(f"Error sending websocket update: {e}")

            # Save the best model, with the configuration encoded in the filename
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                print(f"\nNew best validation accuracy: {val_acc:.2f}%")
                model_filename = generate_model_filename(config, model_type)
                model_path = models_dir / model_filename
                print(f"Saving model as: {model_filename}")
                torch.save(model.state_dict(), model_path)
    except Exception as e:
        print(f"\nError during training: {e}")
        if websocket:
            await websocket.send_json({
                'type': 'training_error',
                'data': {
                    'message': str(e)
                }
            })
        raise  # re-raise with the original traceback

    print("\nTraining completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    if websocket:
        await websocket.send_json({
            'type': 'training_complete',
            'data': {
                'message': 'Training completed successfully!',
                'best_val_acc': best_val_acc
            }
        })
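
# Example (illustrative): running train() outside the websocket flow, assuming
# Net takes a `kernels` list as in start_comparison_training below:
#
#   config = TrainingConfig({'block1': 16, 'block2': 32, 'block3': 64,
#                            'optimizer': 'SGD', 'batch_size': 64, 'epochs': 5})
#   model = Net(kernels=[config.block1, config.block2, config.block3])
#   asyncio.run(train(model, config))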


def initialize_datasets(batch_size):
    """Initialize and return the train/test datasets and their dataloaders."""
    # Ensure data is downloaded and extracted
    print("Preparing dataset...")
    download_and_extract_mnist_data()

    # Paths to the extracted files
    train_images_path = "data/MNIST/raw/train-images-idx3-ubyte"
    train_labels_path = "data/MNIST/raw/train-labels-idx1-ubyte"
    test_images_path = "data/MNIST/raw/t10k-images-idx3-ubyte"
    test_labels_path = "data/MNIST/raw/t10k-labels-idx1-ubyte"

    # Data loading with the standard MNIST normalization
    transform = transforms.Compose([
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = CustomMNISTDataset(train_images_path, train_labels_path, transform=transform)
    test_dataset = CustomMNISTDataset(test_images_path, test_labels_path, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_dataset, test_dataset, train_loader, test_loader
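
# Example (illustrative):
#   train_ds, test_ds, train_loader, test_loader = initialize_datasets(batch_size=64)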


async def start_comparison_training(websocket: WebSocket, parameters: dict):
    print("\n=== Starting Comparison Training ===")
    print(f"Received parameters: {json.dumps(parameters, indent=2)}")
    try:
        # Create the models directory if it doesn't exist
        models_dir = Path("scripts/training/models")
        models_dir.mkdir(parents=True, exist_ok=True)

        # Validate parameters
        if not parameters.get('model_params'):
            print("Error: Missing model parameters")
            raise ValueError("Missing model parameters")
        if not parameters.get('dataset_params'):
            print("Error: Missing dataset parameters")
            raise ValueError("Missing dataset parameters")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        criterion = nn.CrossEntropyLoss()

        # Ensure data is downloaded and extracted before reading the raw files
        download_and_extract_mnist_data()

        # Build the training set once; both models share it
        train_dataset = CustomMNISTDataset(
            "data/MNIST/raw/train-images-idx3-ubyte",
            "data/MNIST/raw/train-labels-idx1-ubyte",
            transform=transforms.Compose([transforms.Normalize((0.1307,), (0.3081,))])
        )
        total_samples = len(train_dataset)

        # Best validation accuracy per model
        best_accuracies = {}

        # Train the two models in sequence
        for model_key, model_letter in [('model_a', 'A'), ('model_b', 'B')]:
            print(f"\n{'='*50}")
            print(f"Training Model {model_letter}")
            print(f"{'='*50}")
            model_params = parameters['model_params'][model_key]

            # Iterations per epoch for this model's batch size
            batch_size = model_params['batch_size']
            iterations_per_epoch = total_samples // batch_size
            total_iterations = iterations_per_epoch * model_params['epochs']

            # Print configuration details
            print("\nModel Configuration:")
            print(f"Architecture: {model_params['block1']}-{model_params['block2']}-{model_params['block3']}")
            print(f"Optimizer: {model_params['optimizer']}")
            print(f"Batch Size: {model_params['batch_size']}")
            print(f"Epochs: {model_params['epochs']}")
            print(f"Iterations per epoch: {iterations_per_epoch:,}")
            print(f"Total iterations: {total_iterations:,}")
            try:
                # Dataloaders with this model's batch size
                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
                test_dataset = CustomMNISTDataset(
                    "data/MNIST/raw/t10k-images-idx3-ubyte",
                    "data/MNIST/raw/t10k-labels-idx1-ubyte",
                    transform=transforms.Compose([transforms.Normalize((0.1307,), (0.3081,))])
                )
                test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
                print("\nDataset Information:")
                print(f"Training samples: {len(train_dataset):,}")
                print(f"Test samples: {len(test_dataset):,}")
                print(f"Steps per epoch: {len(train_loader):,}")

                # Initialize the model and move it to the device
                model = Net(kernels=[
                    model_params['block1'],
                    model_params['block2'],
                    model_params['block3']
                ]).to(device)

                # Report model parameter counts
                total_params = sum(p.numel() for p in model.parameters())
                trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
                print("\nModel Parameters:")
                print(f"Total parameters: {total_params:,}")
                print(f"Trainable parameters: {trainable_params:,}")

                # Initialize the optimizer (Adam uses its default lr=1e-3)
                if model_params['optimizer'].lower() == 'adam':
                    optimizer = optim.Adam(model.parameters())
                else:
                    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

                # Train the model
                current_iteration = 0
                best_acc = 0  # best validation accuracy, used for model saving
                for epoch in range(model_params['epochs']):
                    model.train()
                    total_loss = 0
                    correct = 0
                    total = 0

                    # Progress bar for this epoch
                    progress_bar = tqdm(
                        train_loader,
                        desc=f"Epoch {epoch+1}/{model_params['epochs']}",
                        unit='batch',
                        leave=True,
                        ncols=100
                    )
                    for batch_idx, (data, target) in enumerate(progress_bar):
                        data, target = data.to(device), target.to(device)
                        optimizer.zero_grad()
                        output = model(data)
                        loss = criterion(output, target)
                        loss.backward()
                        optimizer.step()

                        # Accumulate batch accuracy
                        pred = output.argmax(dim=1, keepdim=True)
                        correct += pred.eq(target.view_as(pred)).sum().item()
                        total += target.size(0)
                        total_loss += loss.item()

                        # Running metrics over the epoch so far
                        current_loss = total_loss / (batch_idx + 1)
                        current_acc = 100. * correct / total

                        # Update the progress bar
                        progress_bar.set_postfix({
                            'loss': f'{current_loss:.4f}',
                            'acc': f'{current_acc:.2f}%'
                        })

                        # Send a comparison-specific training update
                        current_iteration += 1
                        await websocket.send_json({
                            'status': 'training',
                            'model': model_letter,
                            'metrics': {
                                'iteration': current_iteration,
                                'total_iterations': total_iterations,
                                'loss': current_loss,
                                'accuracy': current_acc
                            },
                            'epoch': epoch,
                            'batch_size': batch_size,
                            'iterations_per_epoch': iterations_per_epoch
                        })
                    # Print the epoch summary
                    print(f"\nEpoch {epoch+1} Summary:")
                    print(f"Average Loss: {current_loss:.4f}")
                    print(f"Accuracy: {current_acc:.2f}%")

                    # Validation phase at the end of each epoch
                    model.eval()
                    val_loss = 0
                    val_correct = 0
                    val_total = 0
                    print("\nRunning validation...")
                    with torch.no_grad():
                        for data, target in test_loader:
                            data, target = data.to(device), target.to(device)
                            output = model(data)
                            val_loss += criterion(output, target).item()
                            pred = output.argmax(dim=1, keepdim=True)
                            val_correct += pred.eq(target.view_as(pred)).sum().item()
                            val_total += target.size(0)
                    val_loss /= len(test_loader)
                    val_acc = 100. * val_correct / val_total

                    # Save the model if it's the best so far, reusing the
                    # filename helper (identical format to the old inline f-string)
                    if val_acc > best_acc:
                        best_acc = val_acc
                        model_filename = generate_model_filename(TrainingConfig(model_params), model_key)
                        model_path = models_dir / model_filename
                        print(f"\nSaving Model {model_letter} with accuracy {val_acc:.2f}% as: {model_filename}")
                        torch.save(model.state_dict(), model_path)

                print(f"\nModel {model_letter} training completed")
                print(f"Best validation accuracy: {best_acc:.2f}%")

                # Record the best accuracy for this model
                best_accuracies[model_key] = best_acc
            except Exception as e:
                print(f"Error training Model {model_letter}: {e}")
                raise

        print("\nBoth models trained successfully")
        await websocket.send_json({
            'status': 'complete',
            'message': 'Training completed for both models',
            'model_a_acc': best_accuracies.get('model_a'),
            'model_b_acc': best_accuracies.get('model_b')
        })
    except Exception as e:
        error_msg = f"Error in comparison training: {e}"
        print(error_msg)
        await websocket.send_json({
            'status': 'error',
            'message': error_msg
        })
    finally:
        print("=== Comparison Training Ended ===\n")