Spaces:
Sleeping
Sleeping
Feat: Completed logic for multiple models training and comparison
Browse files- app.py +62 -64
- scripts/training/train.py +266 -8
- static/js/train.js +6 -6
- static/js/train_compare.js +159 -6
- templates/train_compare.html +110 -40
app.py
CHANGED
@@ -7,10 +7,12 @@ from typing import List, Optional
|
|
7 |
import uvicorn
|
8 |
import torch
|
9 |
from scripts.model import Net
|
10 |
-
from scripts.training.train import train
|
11 |
from pathlib import Path
|
12 |
from fastapi import BackgroundTasks
|
13 |
import warnings
|
|
|
|
|
14 |
|
15 |
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms")
|
16 |
|
@@ -83,7 +85,9 @@ async def train_model(config: TrainingConfig, background_tasks: BackgroundTasks)
|
|
83 |
async def websocket_endpoint(websocket: WebSocket):
|
84 |
await websocket.accept()
|
85 |
try:
|
|
|
86 |
config_data = await websocket.receive_json()
|
|
|
87 |
|
88 |
model = Net(
|
89 |
kernels=[
|
@@ -93,28 +97,20 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
93 |
]
|
94 |
)
|
95 |
|
96 |
-
|
97 |
-
config =
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
)
|
106 |
|
107 |
print(f"Starting training with config: {config_data}")
|
108 |
|
109 |
try:
|
110 |
-
# Pass "single" as model_type for single model training
|
111 |
await train(model, config, websocket, model_type="single")
|
112 |
-
await websocket.send_json({
|
113 |
-
"type": "training_complete",
|
114 |
-
"data": {
|
115 |
-
"message": "Training completed successfully!"
|
116 |
-
}
|
117 |
-
})
|
118 |
except Exception as e:
|
119 |
print(f"Training error: {str(e)}")
|
120 |
await websocket.send_json({
|
@@ -128,68 +124,70 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
128 |
print("WebSocket disconnected")
|
129 |
except Exception as e:
|
130 |
print(f"WebSocket error: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
finally:
|
132 |
print("WebSocket connection closed")
|
133 |
|
134 |
@app.websocket("/ws/compare")
|
135 |
-
async def
|
136 |
-
|
|
|
137 |
try:
|
|
|
|
|
|
|
|
|
138 |
data = await websocket.receive_json()
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
model2_config.update(**data["model2"])
|
149 |
-
|
150 |
-
# Create models with respective configurations
|
151 |
-
model1 = Net(
|
152 |
-
kernels=[
|
153 |
-
model1_config.block1,
|
154 |
-
model1_config.block2,
|
155 |
-
model1_config.block3
|
156 |
-
]
|
157 |
-
)
|
158 |
-
|
159 |
-
model2 = Net(
|
160 |
-
kernels=[
|
161 |
-
model2_config.block1,
|
162 |
-
model2_config.block2,
|
163 |
-
model2_config.block3
|
164 |
-
]
|
165 |
-
)
|
166 |
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
await train(model2, model2_config, websocket, model_type="model_2")
|
171 |
-
|
172 |
await websocket.send_json({
|
173 |
-
|
174 |
-
|
175 |
-
"message": "Training completed successfully!"
|
176 |
-
}
|
177 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
except Exception as e:
|
179 |
-
print(f"
|
180 |
await websocket.send_json({
|
181 |
-
|
182 |
-
|
183 |
-
"message": f"Training failed: {str(e)}"
|
184 |
-
}
|
185 |
})
|
186 |
-
|
|
|
|
|
187 |
except WebSocketDisconnect:
|
188 |
print("WebSocket disconnected")
|
|
|
|
|
189 |
except Exception as e:
|
190 |
-
print(f"
|
191 |
finally:
|
192 |
-
print("WebSocket
|
193 |
|
194 |
# @app.post("/api/train_single")
|
195 |
# async def train_single_model(config: TrainingConfig):
|
|
|
7 |
import uvicorn
|
8 |
import torch
|
9 |
from scripts.model import Net
|
10 |
+
from scripts.training.train import train, start_comparison_training
|
11 |
from pathlib import Path
|
12 |
from fastapi import BackgroundTasks
|
13 |
import warnings
|
14 |
+
import asyncio
|
15 |
+
import json
|
16 |
|
17 |
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms")
|
18 |
|
|
|
85 |
async def websocket_endpoint(websocket: WebSocket):
|
86 |
await websocket.accept()
|
87 |
try:
|
88 |
+
print("WebSocket connection accepted for single model training")
|
89 |
config_data = await websocket.receive_json()
|
90 |
+
print(f"Received config data: {config_data}")
|
91 |
|
92 |
model = Net(
|
93 |
kernels=[
|
|
|
97 |
]
|
98 |
)
|
99 |
|
100 |
+
# Create TrainingConfig object for single model using **kwargs
|
101 |
+
config = TrainingConfig(**{
|
102 |
+
'block1': config_data['block1'],
|
103 |
+
'block2': config_data['block2'],
|
104 |
+
'block3': config_data['block3'],
|
105 |
+
'optimizer': config_data['optimizer'],
|
106 |
+
'batch_size': config_data['batch_size'],
|
107 |
+
'epochs': config_data['epochs']
|
108 |
+
})
|
|
|
109 |
|
110 |
print(f"Starting training with config: {config_data}")
|
111 |
|
112 |
try:
|
|
|
113 |
await train(model, config, websocket, model_type="single")
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
except Exception as e:
|
115 |
print(f"Training error: {str(e)}")
|
116 |
await websocket.send_json({
|
|
|
124 |
print("WebSocket disconnected")
|
125 |
except Exception as e:
|
126 |
print(f"WebSocket error: {str(e)}")
|
127 |
+
await websocket.send_json({
|
128 |
+
"type": "training_error",
|
129 |
+
"data": {
|
130 |
+
"message": f"WebSocket error: {str(e)}"
|
131 |
+
}
|
132 |
+
})
|
133 |
finally:
|
134 |
print("WebSocket connection closed")
|
135 |
|
136 |
@app.websocket("/ws/compare")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one model-comparison training session over a WebSocket.

    The client is expected to send a single JSON object of the form
    ``{"action": "start_training", "parameters": {...}}``. The parameters are
    passed to ``start_comparison_training``, which streams progress back on
    the same socket. Any malformed request is answered with
    ``{"status": "error", "message": ...}`` and the handler returns.

    NOTE(review): this handler has the same Python name as the single-model
    WebSocket handler defined earlier in this module, so the module-level
    name ends up bound to this function. Consider renaming one of them.
    """
    print("\n=== New WebSocket Connection ===")
    print("New WebSocket connection attempt")
    try:
        await websocket.accept()
        print("WebSocket connection accepted")

        print("Waiting for initial message...")
        data = await websocket.receive_json()
        print(f"Received initial message: {data}")

        # Valid JSON is not necessarily an object; reject lists/scalars
        # instead of failing on the key lookups below.
        if not isinstance(data, dict) or 'action' not in data:
            print("Error: Missing 'action' in message")
            await websocket.send_json({
                'status': 'error',
                'message': 'Missing action in request'
            })
            return

        action = data['action']
        if action != 'start_training':
            # Tell the client, otherwise it waits forever for progress.
            print(f"Unknown action received: {action}")
            await websocket.send_json({
                'status': 'error',
                'message': f'Unknown action: {action}'
            })
            return

        if 'parameters' not in data:
            print("Error: Missing 'parameters' in message")
            await websocket.send_json({
                'status': 'error',
                'message': 'Missing parameters in request'
            })
            return

        print("Starting training task")
        try:
            # Awaited directly: wrapping the coroutine in a task only to
            # await it on the next line adds nothing.
            await start_comparison_training(websocket, data['parameters'])
            print("Training task completed")
        except Exception as e:
            print(f"Error during training task: {str(e)}")
            await websocket.send_json({
                'status': 'error',
                'message': f'Training error: {str(e)}'
            })

    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except json.JSONDecodeError as e:
        print(f"JSON decode error: {str(e)}")
    except Exception as e:
        print(f"Unexpected error in websocket handler: {str(e)}")
    finally:
        print("=== WebSocket Connection Closed ===\n")
|
191 |
|
192 |
# @app.post("/api/train_single")
|
193 |
# async def train_single_model(config: TrainingConfig):
|
scripts/training/train.py
CHANGED
@@ -12,6 +12,18 @@ import urllib.request
|
|
12 |
import shutil
|
13 |
from tqdm import tqdm
|
14 |
import asyncio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
def generate_model_filename(config, model_type="single"):
|
17 |
"""Generate a filename based on model configuration
|
@@ -185,7 +197,6 @@ async def train(model, config, websocket=None, model_type="single"):
|
|
185 |
correct = 0
|
186 |
total = 0
|
187 |
|
188 |
-
# Create progress bar for each epoch
|
189 |
progress_bar = tqdm(
|
190 |
train_loader,
|
191 |
desc=f"Epoch {epoch+1}/{config.epochs}",
|
@@ -211,12 +222,6 @@ async def train(model, config, websocket=None, model_type="single"):
|
|
211 |
current_loss = total_loss / (batch_idx + 1)
|
212 |
current_acc = 100. * correct / total
|
213 |
|
214 |
-
# Update progress bar description
|
215 |
-
progress_bar.set_postfix({
|
216 |
-
'loss': f'{current_loss:.4f}',
|
217 |
-
'acc': f'{current_acc:.2f}%'
|
218 |
-
})
|
219 |
-
|
220 |
# Send training update through websocket
|
221 |
if websocket:
|
222 |
try:
|
@@ -226,7 +231,8 @@ async def train(model, config, websocket=None, model_type="single"):
|
|
226 |
'data': {
|
227 |
'step': step,
|
228 |
'train_loss': current_loss,
|
229 |
-
'train_acc': current_acc
|
|
|
230 |
}
|
231 |
})
|
232 |
except Exception as e:
|
@@ -284,8 +290,260 @@ async def train(model, config, websocket=None, model_type="single"):
|
|
284 |
|
285 |
except Exception as e:
|
286 |
print(f"\nError during training: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
raise e
|
288 |
|
289 |
print("\nTraining completed!")
|
290 |
print(f"Best validation accuracy: {best_val_acc:.2f}%")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
import shutil
|
13 |
from tqdm import tqdm
|
14 |
import asyncio
|
15 |
+
from fastapi import WebSocket
|
16 |
+
import json
|
17 |
+
from scripts.model import Net
|
18 |
+
|
19 |
+
class TrainingConfig:
    """Attribute-style view of one model's training settings.

    The constructor copies the required keys of ``params_dict`` onto
    attributes of the same name. A missing key raises ``KeyError``; keys
    other than the six listed below are ignored.
    """

    def __init__(self, params_dict):
        required = (
            'block1',
            'block2',
            'block3',
            'optimizer',
            'batch_size',
            'epochs',
        )
        for field_name in required:
            setattr(self, field_name, params_dict[field_name])
|
27 |
|
28 |
def generate_model_filename(config, model_type="single"):
|
29 |
"""Generate a filename based on model configuration
|
|
|
197 |
correct = 0
|
198 |
total = 0
|
199 |
|
|
|
200 |
progress_bar = tqdm(
|
201 |
train_loader,
|
202 |
desc=f"Epoch {epoch+1}/{config.epochs}",
|
|
|
222 |
current_loss = total_loss / (batch_idx + 1)
|
223 |
current_acc = 100. * correct / total
|
224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
# Send training update through websocket
|
226 |
if websocket:
|
227 |
try:
|
|
|
231 |
'data': {
|
232 |
'step': step,
|
233 |
'train_loss': current_loss,
|
234 |
+
'train_acc': current_acc,
|
235 |
+
'epoch': epoch
|
236 |
}
|
237 |
})
|
238 |
except Exception as e:
|
|
|
290 |
|
291 |
except Exception as e:
|
292 |
print(f"\nError during training: {e}")
|
293 |
+
if websocket:
|
294 |
+
await websocket.send_json({
|
295 |
+
'type': 'training_error',
|
296 |
+
'data': {
|
297 |
+
'message': str(e)
|
298 |
+
}
|
299 |
+
})
|
300 |
raise e
|
301 |
|
302 |
print("\nTraining completed!")
|
303 |
print(f"Best validation accuracy: {best_val_acc:.2f}%")
|
304 |
+
|
305 |
+
if websocket:
|
306 |
+
await websocket.send_json({
|
307 |
+
'type': 'training_complete',
|
308 |
+
'data': {
|
309 |
+
'message': 'Training completed successfully!',
|
310 |
+
'best_val_acc': best_val_acc
|
311 |
+
}
|
312 |
+
})
|
313 |
return None
|
314 |
+
|
315 |
+
def initialize_datasets(batch_size):
    """Build the MNIST train/test datasets and a DataLoader for each.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        A tuple ``(train_dataset, test_dataset, train_loader, test_loader)``.
        The training loader shuffles; the test loader does not.
    """
    print("Preparing dataset...")
    # Make sure the raw files are downloaded and extracted before opening them.
    download_and_extract_mnist_data()

    # Same normalisation for both splits.
    normalize = transforms.Compose([
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = CustomMNISTDataset(
        "data/MNIST/raw/train-images-idx3-ubyte",
        "data/MNIST/raw/train-labels-idx1-ubyte",
        transform=normalize,
    )
    test_dataset = CustomMNISTDataset(
        "data/MNIST/raw/t10k-images-idx3-ubyte",
        "data/MNIST/raw/t10k-labels-idx1-ubyte",
        transform=normalize,
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_dataset, test_dataset, train_loader, test_loader
|
339 |
+
|
340 |
+
def _evaluate_model(model, loader, criterion, device):
    """Return ``(mean_batch_loss, accuracy_percent)`` of *model* over *loader*."""
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss_sum += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            hits += pred.eq(target.view_as(pred)).sum().item()
            seen += target.size(0)
    return loss_sum / len(loader), 100. * hits / seen


def _build_optimizer(model, optimizer_name):
    """Return Adam for ``'adam'`` (case-insensitive), otherwise SGD with momentum."""
    if optimizer_name.lower() == 'adam':
        return optim.Adam(model.parameters())
    return optim.SGD(model.parameters(), lr=0.01, momentum=0.9)


async def start_comparison_training(websocket: WebSocket, parameters: dict):
    """Train model A and then model B on MNIST, streaming progress to the client.

    Args:
        websocket: Open connection that receives the JSON progress messages.
        parameters: Must contain non-empty ``'model_params'`` (with
            ``'model_a'`` and ``'model_b'`` entries, each providing
            ``block1``, ``block2``, ``block3``, ``optimizer``, ``batch_size``
            and ``epochs``) and non-empty ``'dataset_params'``.

    Messages sent:
        * per batch: ``{'status': 'training', 'model': 'A'|'B', 'metrics':
          {'iteration', 'total_iterations', 'loss', 'accuracy'}, 'epoch',
          'batch_size', 'iterations_per_epoch'}``
        * at the end: ``{'status': 'complete', 'message', 'model_a_acc',
          'model_b_acc'}``
        * on failure: ``{'status': 'error', 'message'}``

    Exceptions are not propagated; they are logged and reported to the
    client as an ``error`` message where the socket still allows it.
    """
    print("\n=== Starting Comparison Training ===")
    print(f"Received parameters: {json.dumps(parameters, indent=2)}")

    try:
        # Create models directory if it doesn't exist
        models_dir = Path("scripts/training/models")
        models_dir.mkdir(parents=True, exist_ok=True)

        # Validate parameters
        if not parameters.get('model_params'):
            print("Error: Missing model parameters")
            raise ValueError("Missing model parameters")

        if not parameters.get('dataset_params'):
            print("Error: Missing dataset parameters")
            raise ValueError("Missing dataset parameters")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        criterion = nn.CrossEntropyLoss()

        # The raw files are opened directly below, so make sure they exist
        # even when this is the first training run on a fresh checkout.
        print("Preparing dataset...")
        download_and_extract_mnist_data()

        # Both splits are identical for model A and model B: load them once.
        transform = transforms.Compose([transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = CustomMNISTDataset(
            "data/MNIST/raw/train-images-idx3-ubyte",
            "data/MNIST/raw/train-labels-idx1-ubyte",
            transform=transform
        )
        test_dataset = CustomMNISTDataset(
            "data/MNIST/raw/t10k-images-idx3-ubyte",
            "data/MNIST/raw/t10k-labels-idx1-ubyte",
            transform=transform
        )

        # Dictionary to store best accuracies
        best_accuracies = {}

        # Start training models
        for model_key, model_letter in [('model_a', 'A'), ('model_b', 'B')]:
            print(f"\n{'='*50}")
            print(f"Training Model {model_letter}")
            print(f"{'='*50}")

            try:
                model_params = parameters['model_params'][model_key]
                batch_size = model_params['batch_size']
                epochs = model_params['epochs']

                # Only the batch size differs between the two runs.
                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
                test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

                # Ask the loader for its length: it keeps the final partial
                # batch, so ``len(dataset) // batch_size`` would under-count
                # and the reported iteration could exceed the total.
                iterations_per_epoch = len(train_loader)
                total_iterations = iterations_per_epoch * epochs

                # Print configuration details
                print("\nModel Configuration:")
                print(f"Architecture: {model_params['block1']}-{model_params['block2']}-{model_params['block3']}")
                print(f"Optimizer: {model_params['optimizer']}")
                print(f"Batch Size: {batch_size}")
                print(f"Epochs: {epochs}")
                print(f"Iterations per epoch: {iterations_per_epoch:,}")
                print(f"Total iterations: {total_iterations:,}")

                print("\nDataset Information:")
                print(f"Training samples: {len(train_dataset):,}")
                print(f"Test samples: {len(test_dataset):,}")
                print(f"Steps per epoch: {len(train_loader):,}")

                # Initialize model and move to device
                model = Net(kernels=[
                    model_params['block1'],
                    model_params['block2'],
                    model_params['block3']
                ]).to(device)

                # Print model parameters
                total_params = sum(p.numel() for p in model.parameters())
                trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
                print("\nModel Parameters:")
                print(f"Total parameters: {total_params:,}")
                print(f"Trainable parameters: {trainable_params:,}")

                optimizer = _build_optimizer(model, model_params['optimizer'])

                # One checkpoint file per run: a later, better epoch
                # overwrites it instead of leaving a file per improvement.
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                model_filename = (
                    f"{model_key}_arch_{model_params['block1']}_{model_params['block2']}"
                    f"_{model_params['block3']}_opt_{model_params['optimizer'].lower()}"
                    f"_batch_{batch_size}_{timestamp}.pth"
                )
                model_path = models_dir / model_filename

                current_iteration = 0
                best_acc = 0  # Track best accuracy for model saving

                for epoch in range(epochs):
                    model.train()
                    total_loss = 0
                    correct = 0
                    total = 0
                    current_loss = 0.0
                    current_acc = 0.0

                    # Create progress bar for each epoch
                    progress_bar = tqdm(
                        train_loader,
                        desc=f"Epoch {epoch+1}/{epochs}",
                        unit='batch',
                        leave=True,
                        ncols=100
                    )

                    for batch_idx, (data, target) in enumerate(progress_bar):
                        data, target = data.to(device), target.to(device)
                        optimizer.zero_grad()
                        output = model(data)
                        loss = criterion(output, target)
                        loss.backward()
                        optimizer.step()

                        # Running accuracy/loss over the epoch so far
                        pred = output.argmax(dim=1, keepdim=True)
                        correct += pred.eq(target.view_as(pred)).sum().item()
                        total += target.size(0)
                        total_loss += loss.item()

                        current_loss = total_loss / (batch_idx + 1)
                        current_acc = 100. * correct / total

                        progress_bar.set_postfix({
                            'loss': f'{current_loss:.4f}',
                            'acc': f'{current_acc:.2f}%'
                        })

                        # Send comparison-specific training update
                        current_iteration += 1
                        await websocket.send_json({
                            'status': 'training',
                            'model': model_letter,
                            'metrics': {
                                'iteration': current_iteration,
                                'total_iterations': total_iterations,
                                'loss': current_loss,
                                'accuracy': current_acc
                            },
                            'epoch': epoch,
                            'batch_size': batch_size,
                            'iterations_per_epoch': iterations_per_epoch
                        })

                    # Print epoch summary
                    print(f"\nEpoch {epoch+1} Summary:")
                    print(f"Average Loss: {current_loss:.4f}")
                    print(f"Accuracy: {current_acc:.2f}%")

                    # Validation phase at the end of each epoch
                    print("\nRunning validation...")
                    val_loss, val_acc = _evaluate_model(model, test_loader, criterion, device)
                    print(f"Validation Loss: {val_loss:.4f}")
                    print(f"Validation Accuracy: {val_acc:.2f}%")

                    # Save model if it's the best so far
                    if val_acc > best_acc:
                        best_acc = val_acc
                        print(f"\nSaving Model {model_letter} with accuracy {val_acc:.2f}% as: {model_filename}")
                        torch.save(model.state_dict(), model_path)

                print(f"\nModel {model_letter} training completed")
                print(f"Best validation accuracy: {best_acc:.2f}%")

                # Save best accuracy for this model
                best_accuracies[model_key] = best_acc

            except Exception as e:
                print(f"Error training Model {model_letter}: {str(e)}")
                raise

        print("\nBoth models trained successfully")
        await websocket.send_json({
            'status': 'complete',
            'message': 'Training completed for both models',
            'model_a_acc': best_accuracies.get('model_a'),
            'model_b_acc': best_accuracies.get('model_b')
        })

    except Exception as e:
        error_msg = f"Error in comparison training: {str(e)}"
        print(error_msg)
        try:
            await websocket.send_json({
                'status': 'error',
                'message': error_msg
            })
        except Exception as send_error:
            # The failure may itself be a closed socket; do not let the
            # report attempt mask the original error with a second one.
            print(f"Could not report error to client: {send_error}")
    finally:
        print("=== Comparison Training Ended ===\n")
|
static/js/train.js
CHANGED
@@ -169,24 +169,24 @@ async function compareModels() {
|
|
169 |
|
170 |
function initializeComparisonCharts() {
|
171 |
const lossData = [{
|
172 |
-
name: 'Model
|
173 |
x: [],
|
174 |
y: [],
|
175 |
type: 'scatter'
|
176 |
}, {
|
177 |
-
name: 'Model
|
178 |
x: [],
|
179 |
y: [],
|
180 |
type: 'scatter'
|
181 |
}];
|
182 |
|
183 |
const accuracyData = [{
|
184 |
-
name: 'Model
|
185 |
x: [],
|
186 |
y: [],
|
187 |
type: 'scatter'
|
188 |
}, {
|
189 |
-
name: 'Model
|
190 |
x: [],
|
191 |
y: [],
|
192 |
type: 'scatter'
|
@@ -209,13 +209,13 @@ function displayComparisonResults(data) {
|
|
209 |
const logsDiv = document.getElementById('comparison-logs');
|
210 |
logsDiv.innerHTML = `
|
211 |
<div class="comparison-model">
|
212 |
-
<h4>Model
|
213 |
<p>Final Loss: ${data.model1_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
|
214 |
<p>Final Accuracy: ${data.model1_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
|
215 |
<p>Model Name: ${data.model1_results.model_name}</p>
|
216 |
</div>
|
217 |
<div class="comparison-model">
|
218 |
-
<h4>Model
|
219 |
<p>Final Loss: ${data.model2_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
|
220 |
<p>Final Accuracy: ${data.model2_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
|
221 |
<p>Model Name: ${data.model2_results.model_name}</p>
|
|
|
169 |
|
170 |
function initializeComparisonCharts() {
|
171 |
const lossData = [{
|
172 |
+
name: 'Model A Loss',
|
173 |
x: [],
|
174 |
y: [],
|
175 |
type: 'scatter'
|
176 |
}, {
|
177 |
+
name: 'Model B Loss',
|
178 |
x: [],
|
179 |
y: [],
|
180 |
type: 'scatter'
|
181 |
}];
|
182 |
|
183 |
const accuracyData = [{
|
184 |
+
name: 'Model A Accuracy',
|
185 |
x: [],
|
186 |
y: [],
|
187 |
type: 'scatter'
|
188 |
}, {
|
189 |
+
name: 'Model B Accuracy',
|
190 |
x: [],
|
191 |
y: [],
|
192 |
type: 'scatter'
|
|
|
209 |
const logsDiv = document.getElementById('comparison-logs');
|
210 |
logsDiv.innerHTML = `
|
211 |
<div class="comparison-model">
|
212 |
+
<h4>Model A</h4>
|
213 |
<p>Final Loss: ${data.model1_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
|
214 |
<p>Final Accuracy: ${data.model1_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
|
215 |
<p>Model Name: ${data.model1_results.model_name}</p>
|
216 |
</div>
|
217 |
<div class="comparison-model">
|
218 |
+
<h4>Model B</h4>
|
219 |
<p>Final Loss: ${data.model2_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
|
220 |
<p>Final Accuracy: ${data.model2_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
|
221 |
<p>Model Name: ${data.model2_results.model_name}</p>
|
static/js/train_compare.js
CHANGED
@@ -2,24 +2,24 @@ let ws;
|
|
2 |
|
3 |
function initializeComparisonCharts() {
|
4 |
const lossData = [{
|
5 |
-
name: 'Model
|
6 |
x: [],
|
7 |
y: [],
|
8 |
type: 'scatter'
|
9 |
}, {
|
10 |
-
name: 'Model
|
11 |
x: [],
|
12 |
y: [],
|
13 |
type: 'scatter'
|
14 |
}];
|
15 |
|
16 |
const accuracyData = [{
|
17 |
-
name: 'Model
|
18 |
x: [],
|
19 |
y: [],
|
20 |
type: 'scatter'
|
21 |
}, {
|
22 |
-
name: 'Model
|
23 |
x: [],
|
24 |
y: [],
|
25 |
type: 'scatter'
|
@@ -90,16 +90,169 @@ function displayComparisonResults(data) {
|
|
90 |
const logsDiv = document.getElementById('comparison-logs');
|
91 |
logsDiv.innerHTML = `
|
92 |
<div class="comparison-model">
|
93 |
-
<h4>Model
|
94 |
<p>Final Loss: ${data.model1_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
|
95 |
<p>Final Accuracy: ${data.model1_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
|
96 |
<p>Model Name: ${data.model1_results.model_name}</p>
|
97 |
</div>
|
98 |
<div class="comparison-model">
|
99 |
-
<h4>Model
|
100 |
<p>Final Loss: ${data.model2_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
|
101 |
<p>Final Accuracy: ${data.model2_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
|
102 |
<p>Model Name: ${data.model2_results.model_name}</p>
|
103 |
</div>
|
104 |
`;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
}
|
|
|
2 |
|
3 |
function initializeComparisonCharts() {
|
4 |
const lossData = [{
|
5 |
+
name: 'Model A Loss',
|
6 |
x: [],
|
7 |
y: [],
|
8 |
type: 'scatter'
|
9 |
}, {
|
10 |
+
name: 'Model B Loss',
|
11 |
x: [],
|
12 |
y: [],
|
13 |
type: 'scatter'
|
14 |
}];
|
15 |
|
16 |
const accuracyData = [{
|
17 |
+
name: 'Model A Accuracy',
|
18 |
x: [],
|
19 |
y: [],
|
20 |
type: 'scatter'
|
21 |
}, {
|
22 |
+
name: 'Model B Accuracy',
|
23 |
x: [],
|
24 |
y: [],
|
25 |
type: 'scatter'
|
|
|
90 |
const logsDiv = document.getElementById('comparison-logs');
|
91 |
logsDiv.innerHTML = `
|
92 |
<div class="comparison-model">
|
93 |
+
<h4>Model A</h4>
|
94 |
<p>Final Loss: ${data.model1_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
|
95 |
<p>Final Accuracy: ${data.model1_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
|
96 |
<p>Model Name: ${data.model1_results.model_name}</p>
|
97 |
</div>
|
98 |
<div class="comparison-model">
|
99 |
+
<h4>Model B</h4>
|
100 |
<p>Final Loss: ${data.model2_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
|
101 |
<p>Final Accuracy: ${data.model2_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
|
102 |
<p>Model Name: ${data.model2_results.model_name}</p>
|
103 |
</div>
|
104 |
`;
|
105 |
+
}
|
106 |
+
|
107 |
+
// Add these helper functions to get the parameters
|
108 |
+
/**
 * Collect and validate the training parameters of both models from the form.
 *
 * Model A is read from the `model1_*` inputs and model B from the
 * `model2_*` inputs.
 *
 * @returns {{model_a: Object, model_b: Object}} block1-3, optimizer,
 *     batch_size and epochs for each model.
 * @throws {Error} when a value is missing, empty or not a number (also
 *     logged to the console before being rethrown).
 */
function getModelParameters() {
    // Read one model's fields; `prefix` is the shared element-id prefix.
    const readModel = (prefix) => {
        const field = (suffix) => document.getElementById(`${prefix}_${suffix}`).value;
        // Explicit radix: the inputs are decimal numbers.
        const intField = (suffix) => parseInt(field(suffix), 10);
        return {
            block1: intField('kernel1'),
            block2: intField('kernel2'),
            block3: intField('kernel3'),
            optimizer: field('optimizer'),
            batch_size: intField('batch_size'),
            epochs: intField('epochs')
        };
    };

    try {
        const params = {
            model_a: readModel('model1'),
            model_b: readModel('model2')
        };

        // Validate that all values are present and valid. The empty string
        // is rejected too: an unselected optimizer is not a usable value.
        for (const [model, values] of Object.entries(params)) {
            for (const [key, value] of Object.entries(values)) {
                if (value === null || value === undefined || value === '' || Number.isNaN(value)) {
                    throw new Error(`Invalid value for ${model} ${key}: ${value}`);
                }
            }
        }

        console.log('Collected and validated model parameters:', params);
        return params;
    } catch (error) {
        console.error('Error in getModelParameters:', error);
        throw error;
    }
}
|
145 |
+
|
146 |
+
/**
 * Build the dataset section of the training request.
 *
 * The batch size is taken from model 1's batch-size input; shuffling is
 * always enabled.
 *
 * @returns {{batch_size: number, shuffle: boolean}}
 */
function getDatasetParameters() {
    const batchSizeInput = document.getElementById('model1_batch_size');
    const batchSize = parseInt(batchSizeInput.value);
    const datasetParams = { batch_size: batchSize, shuffle: true };
    return datasetParams;
}
|
152 |
+
|
153 |
+
// Update the WebSocket event listener
|
154 |
+
// Start a comparison run: validate the form, build the request, open the
// WebSocket and stream progress messages into updateTrainingProgress().
document.getElementById('startComparisonBtn').addEventListener('click', function() {
    console.log('Start Comparison button clicked');

    // Validate form inputs before proceeding (number inputs and selects)
    const formInputs = document.querySelectorAll('input[type="number"], select');
    let isValid = true;
    const formValues = {};

    formInputs.forEach(input => {
        console.log(`Checking input ${input.id}: ${input.value}`);
        formValues[input.id] = input.value;
        if (!input.value) {
            console.error(`Missing value for ${input.id}`);
            isValid = false;
        }
    });

    console.log('Form values:', formValues);

    if (!isValid) {
        alert('Please fill in all required fields');
        return;
    }

    // Build the request before touching the network, so invalid parameters
    // never open a connection and the error is reported to the user.
    let message;
    try {
        message = {
            action: 'start_training',
            parameters: {
                model_params: getModelParameters(),
                dataset_params: getDatasetParameters()
            }
        };
    } catch (error) {
        console.error('Error collecting training parameters:', error);
        alert('Error sending training parameters. Please check console for details.');
        return;
    }

    // Show comparison progress section
    document.getElementById('comparison-progress').classList.remove('hidden');
    initializeComparisonCharts();
    console.log('Initialized comparison charts');

    console.log('Attempting WebSocket connection...');
    // Pages served over HTTPS may only open secure WebSockets.
    const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
    // Named `socket` so it does not shadow the module-level `ws`.
    const socket = new WebSocket(`${scheme}://${window.location.host}/ws/compare`);

    socket.onopen = function() {
        console.log('WebSocket connection established');
        console.log('Preparing to send message:', JSON.stringify(message, null, 2));

        // The socket is already open when this handler runs, so no delay is
        // needed before sending.
        try {
            socket.send(JSON.stringify(message));
            console.log('Message sent successfully');
        } catch (error) {
            console.error('Error sending message:', error);
            alert('Error sending training parameters. Please check console for details.');
        }
    };

    socket.onmessage = function(event) {
        console.log('Received WebSocket message:', event.data);
        try {
            const data = JSON.parse(event.data);
            console.log('Parsed message data:', data);
            updateTrainingProgress(data);
        } catch (error) {
            console.error('Error processing message:', error);
        }
    };

    socket.onerror = function(error) {
        console.error('WebSocket error:', error);
        alert('Connection error occurred. Please check console for details.');
    };

    socket.onclose = function(event) {
        console.log('WebSocket connection closed. Code:', event.code, 'Reason:', event.reason);
    };
});
|
232 |
+
|
233 |
+
// Add the updateTrainingProgress function
|
234 |
+
/**
 * Apply one parsed server message to the comparison UI.
 *
 * The message is dispatched on `data.status`:
 *   - 'training': appends `data.metrics.loss` and `data.metrics.accuracy` to
 *     the reporting model's trace (trace 0 when `data.model === 'A'`, trace 1
 *     otherwise) and refreshes the status line with the 1-based epoch.
 *   - 'complete': marks the run as finished and hands `data.metrics` to
 *     displayComparisonResults.
 *   - 'error': logs `data.message` and shows it in an alert.
 * Any other status is ignored.
 *
 * @param {Object} data - Parsed WebSocket message.
 */
function updateTrainingProgress(data) {
    // Write the status line only when its element exists. An unguarded
    // `.textContent` assignment on a missing element throws, which in the
    // 'complete' branch would prevent the results from being displayed.
    const setProgressText = (text) => {
        const progressText = document.getElementById('training-progress-text');
        if (progressText) {
            progressText.textContent = text;
        }
    };

    if (data.status === 'training') {
        // Model A is plotted on trace 0; anything else goes to trace 1.
        const traceIndex = data.model === 'A' ? 0 : 1;

        // Update loss plot
        Plotly.extendTraces('comparison-loss-plot', {
            y: [[data.metrics.loss]],
        }, [traceIndex]);

        // Update accuracy plot
        Plotly.extendTraces('comparison-accuracy-plot', {
            y: [[data.metrics.accuracy]],
        }, [traceIndex]);

        // data.epoch is shown 1-based to the user.
        setProgressText(`Training ${data.model === 'A' ? 'Model A' : 'Model B'} - Epoch ${data.epoch + 1}`);
    } else if (data.status === 'complete') {
        setProgressText('Training Complete!');
        displayComparisonResults(data.metrics);
    } else if (data.status === 'error') {
        console.error('Training error:', data.message);
        alert(`Training error: ${data.message}`);
    }
}
|
templates/train_compare.html
CHANGED
@@ -11,9 +11,9 @@
|
|
11 |
<div class="container">
|
12 |
<h1>Compare Models</h1>
|
13 |
<div class="models-grid">
|
14 |
-
<!-- Model
|
15 |
<div class="model-config">
|
16 |
-
<h3>Model
|
17 |
<div class="network-config">
|
18 |
<h4>Network Architecture</h4>
|
19 |
<div class="block-config">
|
@@ -78,9 +78,9 @@
|
|
78 |
</div>
|
79 |
</div>
|
80 |
|
81 |
-
<!-- Model
|
82 |
<div class="model-config">
|
83 |
-
<h3>Model
|
84 |
<div class="network-config">
|
85 |
<h4>Network Architecture</h4>
|
86 |
<div class="block-config">
|
@@ -157,6 +157,18 @@
|
|
157 |
<div id="lossChart"></div>
|
158 |
<div id="accuracyChart"></div>
|
159 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
</div>
|
161 |
|
162 |
<style>
|
@@ -278,6 +290,28 @@
|
|
278 |
.config-item .section-title {
|
279 |
margin-bottom: 5px;
|
280 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
</style>
|
282 |
|
283 |
<script>
|
@@ -292,13 +326,13 @@
|
|
292 |
{
|
293 |
x: [],
|
294 |
y: [],
|
295 |
-
name: 'Model
|
296 |
type: 'scatter'
|
297 |
},
|
298 |
{
|
299 |
x: [],
|
300 |
y: [],
|
301 |
-
name: 'Model
|
302 |
type: 'scatter'
|
303 |
}
|
304 |
];
|
@@ -320,13 +354,13 @@
|
|
320 |
{
|
321 |
x: [],
|
322 |
y: [],
|
323 |
-
name: 'Model
|
324 |
type: 'scatter'
|
325 |
},
|
326 |
{
|
327 |
x: [],
|
328 |
y: [],
|
329 |
-
name: 'Model
|
330 |
type: 'scatter'
|
331 |
}
|
332 |
];
|
@@ -375,55 +409,91 @@
|
|
375 |
// Setup WebSocket connection
|
376 |
ws = new WebSocket(`ws://${window.location.host}/ws/compare`);
|
377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
ws.onmessage = function(event) {
|
|
|
379 |
const data = JSON.parse(event.data);
|
380 |
|
381 |
-
if (data.
|
382 |
-
const modelIndex = data.
|
|
|
|
|
|
|
383 |
|
384 |
-
// Update
|
385 |
Plotly.extendTraces('lossChart', {
|
386 |
-
x: [[
|
387 |
-
y: [[data.
|
388 |
}, [modelIndex]);
|
389 |
|
|
|
390 |
Plotly.extendTraces('accuracyChart', {
|
391 |
-
x: [[
|
392 |
-
y: [[data.
|
393 |
}, [modelIndex]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
394 |
}
|
395 |
-
else if (data.
|
396 |
-
|
|
|
397 |
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
});
|
406 |
-
|
407 |
-
Plotly.addTraces('accuracyChart', {
|
408 |
-
x: [data.data.step],
|
409 |
-
y: [data.data.val_acc],
|
410 |
-
name: `Model ${data.data.model_id} Validation Accuracy`,
|
411 |
-
mode: 'markers',
|
412 |
-
marker: { size: 8 }
|
413 |
-
});
|
414 |
}
|
415 |
-
else if (data.
|
|
|
|
|
416 |
document.getElementById('startComparison').disabled = false;
|
417 |
document.getElementById('stopComparison').disabled = true;
|
418 |
}
|
419 |
};
|
420 |
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
|
|
|
|
|
|
|
|
|
|
427 |
}
|
428 |
|
429 |
function stopComparison() {
|
|
|
11 |
<div class="container">
|
12 |
<h1>Compare Models</h1>
|
13 |
<div class="models-grid">
|
14 |
+
<!-- Model A Configuration -->
|
15 |
<div class="model-config">
|
16 |
+
<h3>Model A</h3>
|
17 |
<div class="network-config">
|
18 |
<h4>Network Architecture</h4>
|
19 |
<div class="block-config">
|
|
|
78 |
</div>
|
79 |
</div>
|
80 |
|
81 |
+
<!-- Model B Configuration -->
|
82 |
<div class="model-config">
|
83 |
+
<h3>Model B</h3>
|
84 |
<div class="network-config">
|
85 |
<h4>Network Architecture</h4>
|
86 |
<div class="block-config">
|
|
|
157 |
<div id="lossChart"></div>
|
158 |
<div id="accuracyChart"></div>
|
159 |
</div>
|
160 |
+
|
161 |
+
<!-- Training status line, updated from WebSocket progress messages -->
|
162 |
+
<div class="training-status">
|
163 |
+
<p id="training-progress"></p>
|
164 |
+
</div>
|
165 |
+
|
166 |
+
<!-- Link to the inference page; hidden until training completes -->
|
167 |
+
<div class="inference-controls" style="display: none;">
|
168 |
+
<button id="goToInference" onclick="window.location.href='/inference'" class="inference-button">
|
169 |
+
Try Model Inference
|
170 |
+
</button>
|
171 |
+
</div>
|
172 |
</div>
|
173 |
|
174 |
<style>
|
|
|
290 |
.config-item .section-title {
|
291 |
margin-bottom: 5px;
|
292 |
}
|
293 |
+
|
294 |
+
.training-status {
|
295 |
+
text-align: center;
|
296 |
+
margin: 20px 0;
|
297 |
+
font-weight: bold;
|
298 |
+
}
|
299 |
+
|
300 |
+
.inference-controls {
|
301 |
+
margin: 20px 0;
|
302 |
+
text-align: center;
|
303 |
+
}
|
304 |
+
|
305 |
+
.inference-button {
|
306 |
+
background-color: #28a745;
|
307 |
+
padding: 12px 24px;
|
308 |
+
font-size: 1.1em;
|
309 |
+
transition: background-color 0.3s;
|
310 |
+
}
|
311 |
+
|
312 |
+
.inference-button:hover {
|
313 |
+
background-color: #218838;
|
314 |
+
}
|
315 |
</style>
|
316 |
|
317 |
<script>
|
|
|
326 |
{
|
327 |
x: [],
|
328 |
y: [],
|
329 |
+
name: 'Model A Training Loss',
|
330 |
type: 'scatter'
|
331 |
},
|
332 |
{
|
333 |
x: [],
|
334 |
y: [],
|
335 |
+
name: 'Model B Training Loss',
|
336 |
type: 'scatter'
|
337 |
}
|
338 |
];
|
|
|
354 |
{
|
355 |
x: [],
|
356 |
y: [],
|
357 |
+
name: 'Model A Training Accuracy',
|
358 |
type: 'scatter'
|
359 |
},
|
360 |
{
|
361 |
x: [],
|
362 |
y: [],
|
363 |
+
name: 'Model B Training Accuracy',
|
364 |
type: 'scatter'
|
365 |
}
|
366 |
];
|
|
|
409 |
// Setup WebSocket connection
|
410 |
ws = new WebSocket(`ws://${window.location.host}/ws/compare`);
|
411 |
|
412 |
+
ws.onopen = function() {
|
413 |
+
console.log('WebSocket connection established');
|
414 |
+
// Only send the message after connection is established
|
415 |
+
const message = {
|
416 |
+
action: 'start_training',
|
417 |
+
parameters: {
|
418 |
+
model_params: {
|
419 |
+
model_a: model1Config,
|
420 |
+
model_b: model2Config
|
421 |
+
},
|
422 |
+
dataset_params: {
|
423 |
+
batch_size: model1Config.batch_size,
|
424 |
+
shuffle: true
|
425 |
+
}
|
426 |
+
}
|
427 |
+
};
|
428 |
+
console.log('Sending message:', message);
|
429 |
+
ws.send(JSON.stringify(message));
|
430 |
+
};
|
431 |
+
|
432 |
ws.onmessage = function(event) {
|
433 |
+
console.log('Received message:', event.data);
|
434 |
const data = JSON.parse(event.data);
|
435 |
|
436 |
+
if (data.status === 'training') {
|
437 |
+
const modelIndex = data.model === 'A' ? 0 : 1;
|
438 |
+
const iteration = data.metrics.iteration;
|
439 |
+
|
440 |
+
console.log(`Updating charts for model ${data.model} at iteration ${iteration}`);
|
441 |
|
442 |
+
// Update loss chart using iteration number
|
443 |
Plotly.extendTraces('lossChart', {
|
444 |
+
x: [[iteration]],
|
445 |
+
y: [[data.metrics.loss]]
|
446 |
}, [modelIndex]);
|
447 |
|
448 |
+
// Update accuracy chart using iteration number
|
449 |
Plotly.extendTraces('accuracyChart', {
|
450 |
+
x: [[iteration]],
|
451 |
+
y: [[data.metrics.accuracy]]
|
452 |
}, [modelIndex]);
|
453 |
+
|
454 |
+
// Update progress text with more detailed information
|
455 |
+
const progressText = document.getElementById('training-progress');
|
456 |
+
if (progressText) {
|
457 |
+
const progress = (data.metrics.iteration / data.metrics.total_iterations * 100).toFixed(1);
|
458 |
+
progressText.textContent =
|
459 |
+
`Training Model ${data.model} - ` +
|
460 |
+
`Epoch ${data.epoch + 1} - ` +
|
461 |
+
`Iteration ${data.metrics.iteration}/${data.metrics.total_iterations} ` +
|
462 |
+
`(${progress}%) - ` +
|
463 |
+
`Batch Size: ${data.batch_size}`;
|
464 |
+
}
|
465 |
}
|
466 |
+
else if (data.status === 'complete') {
|
467 |
+
document.getElementById('startComparison').disabled = false;
|
468 |
+
document.getElementById('stopComparison').disabled = true;
|
469 |
|
470 |
+
const progressText = document.getElementById('training-progress');
|
471 |
+
if (progressText) {
|
472 |
+
progressText.textContent = 'Training Complete!';
|
473 |
+
}
|
474 |
+
|
475 |
+
// Show the inference button
|
476 |
+
document.querySelector('.inference-controls').style.display = 'block';
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
}
|
478 |
+
else if (data.status === 'error') {
|
479 |
+
console.error('Training error:', data.message);
|
480 |
+
alert(`Training error: ${data.message}`);
|
481 |
document.getElementById('startComparison').disabled = false;
|
482 |
document.getElementById('stopComparison').disabled = true;
|
483 |
}
|
484 |
};
|
485 |
|
486 |
+
ws.onerror = function(error) {
|
487 |
+
console.error('WebSocket error:', error);
|
488 |
+
document.getElementById('startComparison').disabled = false;
|
489 |
+
document.getElementById('stopComparison').disabled = true;
|
490 |
+
};
|
491 |
+
|
492 |
+
ws.onclose = function(event) {
|
493 |
+
console.log('WebSocket connection closed:', event);
|
494 |
+
document.getElementById('startComparison').disabled = false;
|
495 |
+
document.getElementById('stopComparison').disabled = true;
|
496 |
+
};
|
497 |
}
|
498 |
|
499 |
function stopComparison() {
|