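"""FastAPI app for training small CNN classifiers (Net), comparing two training
configurations over WebSockets, and running inference on drawn digit images.
(Summary docstring added for orientation; inferred from the handlers below.)
"""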
from fastapi import FastAPI, Request, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from typing import List, Optional
import uvicorn
import torch
from scripts.model import Net
from scripts.training.train import train, start_comparison_training
from pathlib import Path
import warnings
import asyncio
import json
import numpy as np

warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms")

app = FastAPI()

# Mount static files with a name parameter
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Model configurations
class TrainingConfig(BaseModel):
    block1: int
    block2: int
    block3: int
    optimizer: str
    batch_size: int
    epochs: int = 1


class ComparisonConfig(BaseModel):
    model1: TrainingConfig
    model2: TrainingConfig

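# Illustrative only: a request body matching ComparisonConfig might look like the
# following (field names come from the models above; the values are placeholders,
# not taken from the original project):
#
#   {
#       "model1": {"block1": 16, "block2": 32, "block3": 64,
#                  "optimizer": "adam", "batch_size": 64, "epochs": 1},
#       "model2": {"block1": 32, "block2": 64, "block3": 128,
#                  "optimizer": "sgd", "batch_size": 128, "epochs": 1}
#   }
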
def get_available_models():
    models_dir = Path("scripts/training/models")
    if not models_dir.exists():
        models_dir.mkdir(exist_ok=True, parents=True)
    return [f.stem for f in models_dir.glob("*.pth")]


# Global variable to store the training task
training_task = None

# NOTE: the route decorators and paths below are assumptions (they were not present
# in the source as extracted); adjust them to match the actual app.
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.get("/train", response_class=HTMLResponse)
async def train_page(request: Request):
    return templates.TemplateResponse("train.html", {"request": request})


@app.get("/inference", response_class=HTMLResponse)
async def inference_page(request: Request):
    available_models = get_available_models()
    return templates.TemplateResponse(
        "inference.html",
        {
            "request": request,
            "available_models": available_models
        }
    )

@app.post("/api/train")  # route path assumed
async def train_model(config: TrainingConfig, background_tasks: BackgroundTasks):
    try:
        # Create a model instance with the requested architecture
        model = Net(
            kernels=[config.block1, config.block2, config.block3]
        )
        # Store the training configuration; this endpoint only acknowledges it
        # (no training is started here)
        training_config = {
            "optimizer": config.optimizer,
            "batch_size": config.batch_size
        }
        return {"status": "success", "message": "Training configuration received"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.websocket("/ws/train_single")  # route path assumed
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        print("WebSocket connection accepted for single model training")
        config_data = await websocket.receive_json()
        print(f"Received config data: {config_data}")
        model = Net(
            kernels=[
                config_data['block1'],
                config_data['block2'],
                config_data['block3']
            ]
        )
        # Build a TrainingConfig object for the single model
        config = TrainingConfig(**{
            'block1': config_data['block1'],
            'block2': config_data['block2'],
            'block3': config_data['block3'],
            'optimizer': config_data['optimizer'],
            'batch_size': config_data['batch_size'],
            'epochs': config_data['epochs']
        })
        print(f"Starting training with config: {config_data}")
        try:
            await train(model, config, websocket, model_type="single")
        except Exception as e:
            print(f"Training error: {str(e)}")
            await websocket.send_json({
                "type": "training_error",
                "data": {
                    "message": f"Training failed: {str(e)}"
                }
            })
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        print(f"WebSocket error: {str(e)}")
        await websocket.send_json({
            "type": "training_error",
            "data": {
                "message": f"WebSocket error: {str(e)}"
            }
        })
    finally:
        print("WebSocket connection closed")

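# Illustrative client message for the single-model WebSocket above. The key names
# come from the handler; the values are placeholders, not from the original project:
#
#   {"block1": 16, "block2": 32, "block3": 64,
#    "optimizer": "adam", "batch_size": 64, "epochs": 1}
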
@app.websocket("/ws/train_compare")  # route path assumed
async def websocket_endpoint(websocket: WebSocket):
    print("\n=== New WebSocket Connection ===")
    print("New WebSocket connection attempt")
    try:
        await websocket.accept()
        print("WebSocket connection accepted")
        print("Waiting for initial message...")
        data = await websocket.receive_json()
        print(f"Received initial message: {data}")
        if 'action' not in data:
            print("Error: Missing 'action' in message")
            await websocket.send_json({
                'status': 'error',
                'message': 'Missing action in request'
            })
            return
        if data['action'] == 'start_training':
            if 'parameters' not in data:
                print("Error: Missing 'parameters' in message")
                await websocket.send_json({
                    'status': 'error',
                    'message': 'Missing parameters in request'
                })
                return
            print("Starting training task")
            try:
                training_task = asyncio.create_task(start_comparison_training(
                    websocket,
                    data['parameters']
                ))
                print("Training task created, awaiting completion...")
                await training_task
                print("Training task completed")
            except Exception as e:
                print(f"Error during training task: {str(e)}")
                await websocket.send_json({
                    'status': 'error',
                    'message': f'Training error: {str(e)}'
                })
        else:
            print(f"Unknown action received: {data['action']}")
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except json.JSONDecodeError as e:
        print(f"JSON decode error: {str(e)}")
    except Exception as e:
        print(f"Unexpected error in websocket handler: {str(e)}")
    finally:
        print("=== WebSocket Connection Closed ===\n")

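# Illustrative client message for the comparison WebSocket above. The "action" and
# "parameters" keys come from the handler; the contents of "parameters" depend on
# start_comparison_training, so the nested values are placeholders:
#
#   {"action": "start_training",
#    "parameters": {"model1": {...}, "model2": {...}}}
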
# @app.post("/api/train_single")
# async def train_single_model(config: TrainingConfig):
#     try:
#         model = Net(kernels=config.kernels)
#         # Start training without passing the websocket
#         await train(model, config)
#         return {"status": "success"}
#     except Exception as e:
#         # Log the error for debugging
#         print(f"Error during training: {str(e)}")
#         # Return a JSON response with the error message
#         raise HTTPException(status_code=500, detail=f"Error during training: {str(e)}")

@app.post("/api/train_compare")  # route path assumed
async def train_compare_models(config: ComparisonConfig):
    try:
        # Train both models. TrainingConfig stores the kernel sizes as
        # block1/block2/block3 (there is no `kernels` field), and train() is a
        # coroutine, so both calls must be awaited.
        model1 = Net(kernels=[config.model1.block1, config.model1.block2, config.model1.block3])
        model2 = Net(kernels=[config.model2.block1, config.model2.block2, config.model2.block3])
        results1 = await train(model1, config.model1)
        results2 = await train(model2, config.model2)
        return {
            "status": "success",
            "model1_results": results1,
            "model2_results": results2
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

def parse_model_filename(filename):
    """Extract configuration from model filename"""
    # Example filename: single_arch_32_64_128_opt_adam_batch_64_20240322_123456.pth
    try:
        parts = filename.split('_')
        # Find architecture values
        arch_index = parts.index('arch')
        block1 = int(parts[arch_index + 1])
        block2 = int(parts[arch_index + 2])
        block3 = int(parts[arch_index + 3])
        # Find optimizer
        opt_index = parts.index('opt')
        optimizer = parts[opt_index + 1]
        # Find batch size
        batch_index = parts.index('batch')
        batch_size = int(parts[batch_index + 1])
        return {
            'block1': block1,
            'block2': block2,
            'block3': block3,
            'optimizer': optimizer,
            'batch_size': batch_size
        }
    except Exception as e:
        print(f"Error parsing model filename: {e}")
        return None

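# For the example filename above ("single_arch_32_64_128_opt_adam_batch_64_..."),
# parse_model_filename should return:
#   {'block1': 32, 'block2': 64, 'block3': 128, 'optimizer': 'adam', 'batch_size': 64}
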
@app.post("/api/inference")  # route path assumed
async def perform_inference(data: dict):
    try:
        model_name = data.get("model_name")
        if not model_name:
            raise HTTPException(status_code=400, detail="No model selected")
        model_path = Path("scripts/training/models") / f"{model_name}.pth"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail=f"Model not found: {model_path}")
        # Parse model configuration from filename
        config = parse_model_filename(model_name)
        if not config:
            raise HTTPException(status_code=500, detail="Could not parse model configuration")
        # Create model with the correct configuration
        model = Net(
            kernels=[
                config['block1'],
                config['block2'],
                config['block3']
            ]
        )
        # Load model weights
        model.load_state_dict(torch.load(str(model_path), map_location=torch.device('cpu'), weights_only=True))
        model.eval()
        # Process image data and get prediction
        image_data = data.get("image")
        if not image_data:
            raise HTTPException(status_code=400, detail="No image data provided")
        # Convert the base64-encoded image to a tensor and run the model
        try:
            # Remove the data URL prefix (e.g. "data:image/png;base64,")
            image_data = image_data.split(',')[1]
            import base64
            import io
            from PIL import Image
            import torchvision.transforms as transforms
            # Decode base64 to image
            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes)).convert('L')  # Convert to grayscale
            # Resize using PIL directly with LANCZOS
            image = image.resize((28, 28), Image.LANCZOS)
            # Invert the image (subtract from 255 to invert grayscale)
            image = Image.fromarray(255 - np.array(image))
            # Preprocess with the standard MNIST mean/std
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])
            # Convert to tensor and add batch dimension
            image_tensor = transform(image).unsqueeze(0)
            # Get prediction
            with torch.no_grad():
                output = model(image_tensor)
                prediction = output.argmax(dim=1).item()
            # Add configuration info to response
            return {
                "prediction": prediction,
                "model_config": {
                    "architecture": f"{config['block1']}-{config['block2']}-{config['block3']}",
                    "optimizer": config['optimizer'],
                    "batch_size": config['batch_size']
                }
            }
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
    except HTTPException:
        # Preserve the deliberate 400/404 responses instead of re-wrapping them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

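# Illustrative request body for the inference endpoint above. The key names come from
# the handler; the model name reuses the example filename (without ".pth", which the
# handler appends) and the image data is a truncated placeholder:
#
#   {"model_name": "single_arch_32_64_128_opt_adam_batch_64_20240322_123456",
#    "image": "data:image/png;base64,iVBORw0KGgo..."}
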
@app.get("/train/single", response_class=HTMLResponse)  # route path assumed
async def train_single_page(request: Request):
    return templates.TemplateResponse("train_single.html", {"request": request})

@app.get("/train/compare", response_class=HTMLResponse)  # route path assumed
async def train_compare_page(request: Request):
    return templates.TemplateResponse("train_compare.html", {"request": request})

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)