Shilpaj commited on
Commit
30d27e9
·
1 Parent(s): f9b762f

Feat: Completed logic for multiple models training and comparison

Browse files
app.py CHANGED
@@ -7,10 +7,12 @@ from typing import List, Optional
7
  import uvicorn
8
  import torch
9
  from scripts.model import Net
10
- from scripts.training.train import train
11
  from pathlib import Path
12
  from fastapi import BackgroundTasks
13
  import warnings
 
 
14
 
15
  warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms")
16
 
@@ -83,7 +85,9 @@ async def train_model(config: TrainingConfig, background_tasks: BackgroundTasks)
83
  async def websocket_endpoint(websocket: WebSocket):
84
  await websocket.accept()
85
  try:
 
86
  config_data = await websocket.receive_json()
 
87
 
88
  model = Net(
89
  kernels=[
@@ -93,28 +97,20 @@ async def websocket_endpoint(websocket: WebSocket):
93
  ]
94
  )
95
 
96
- from scripts.training.config import NetworkConfig
97
- config = NetworkConfig()
98
- config.update(
99
- block1=config_data['block1'],
100
- block2=config_data['block2'],
101
- block3=config_data['block3'],
102
- optimizer=config_data['optimizer'],
103
- batch_size=config_data['batch_size'],
104
- epochs=config_data['epochs']
105
- )
106
 
107
  print(f"Starting training with config: {config_data}")
108
 
109
  try:
110
- # Pass "single" as model_type for single model training
111
  await train(model, config, websocket, model_type="single")
112
- await websocket.send_json({
113
- "type": "training_complete",
114
- "data": {
115
- "message": "Training completed successfully!"
116
- }
117
- })
118
  except Exception as e:
119
  print(f"Training error: {str(e)}")
120
  await websocket.send_json({
@@ -128,68 +124,70 @@ async def websocket_endpoint(websocket: WebSocket):
128
  print("WebSocket disconnected")
129
  except Exception as e:
130
  print(f"WebSocket error: {str(e)}")
 
 
 
 
 
 
131
  finally:
132
  print("WebSocket connection closed")
133
 
134
  @app.websocket("/ws/compare")
135
- async def websocket_compare_endpoint(websocket: WebSocket):
136
- await websocket.accept()
 
137
  try:
 
 
 
 
138
  data = await websocket.receive_json()
139
- if data.get("type") == "start_comparison":
140
- from scripts.training.config import NetworkConfig
141
-
142
- # Create and train both models
143
- model1_config = NetworkConfig()
144
- model2_config = NetworkConfig()
145
-
146
- # Update configs with received data
147
- model1_config.update(**data["model1"])
148
- model2_config.update(**data["model2"])
149
-
150
- # Create models with respective configurations
151
- model1 = Net(
152
- kernels=[
153
- model1_config.block1,
154
- model1_config.block2,
155
- model1_config.block3
156
- ]
157
- )
158
-
159
- model2 = Net(
160
- kernels=[
161
- model2_config.block1,
162
- model2_config.block2,
163
- model2_config.block3
164
- ]
165
- )
166
 
167
- # Train both models with appropriate model_type
168
- try:
169
- await train(model1, model1_config, websocket, model_type="model_1")
170
- await train(model2, model2_config, websocket, model_type="model_2")
171
-
172
  await websocket.send_json({
173
- "type": "comparison_complete",
174
- "data": {
175
- "message": "Training completed successfully!"
176
- }
177
  })
 
 
 
 
 
 
 
 
 
 
 
178
  except Exception as e:
179
- print(f"Training error: {str(e)}")
180
  await websocket.send_json({
181
- "type": "training_error",
182
- "data": {
183
- "message": f"Training failed: {str(e)}"
184
- }
185
  })
186
-
 
 
187
  except WebSocketDisconnect:
188
  print("WebSocket disconnected")
 
 
189
  except Exception as e:
190
- print(f"WebSocket error: {str(e)}")
191
  finally:
192
- print("WebSocket connection closed")
193
 
194
  # @app.post("/api/train_single")
195
  # async def train_single_model(config: TrainingConfig):
 
7
  import uvicorn
8
  import torch
9
  from scripts.model import Net
10
+ from scripts.training.train import train, start_comparison_training
11
  from pathlib import Path
12
  from fastapi import BackgroundTasks
13
  import warnings
14
+ import asyncio
15
+ import json
16
 
17
  warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms")
18
 
 
85
  async def websocket_endpoint(websocket: WebSocket):
86
  await websocket.accept()
87
  try:
88
+ print("WebSocket connection accepted for single model training")
89
  config_data = await websocket.receive_json()
90
+ print(f"Received config data: {config_data}")
91
 
92
  model = Net(
93
  kernels=[
 
97
  ]
98
  )
99
 
100
+ # Create TrainingConfig object for single model using **kwargs
101
+ config = TrainingConfig(**{
102
+ 'block1': config_data['block1'],
103
+ 'block2': config_data['block2'],
104
+ 'block3': config_data['block3'],
105
+ 'optimizer': config_data['optimizer'],
106
+ 'batch_size': config_data['batch_size'],
107
+ 'epochs': config_data['epochs']
108
+ })
 
109
 
110
  print(f"Starting training with config: {config_data}")
111
 
112
  try:
 
113
  await train(model, config, websocket, model_type="single")
 
 
 
 
 
 
114
  except Exception as e:
115
  print(f"Training error: {str(e)}")
116
  await websocket.send_json({
 
124
  print("WebSocket disconnected")
125
  except Exception as e:
126
  print(f"WebSocket error: {str(e)}")
127
+ await websocket.send_json({
128
+ "type": "training_error",
129
+ "data": {
130
+ "message": f"WebSocket error: {str(e)}"
131
+ }
132
+ })
133
  finally:
134
  print("WebSocket connection closed")
135
 
136
  @app.websocket("/ws/compare")
137
+ async def websocket_endpoint(websocket: WebSocket):
138
+ print("\n=== New WebSocket Connection ===")
139
+ print("New WebSocket connection attempt")
140
  try:
141
+ await websocket.accept()
142
+ print("WebSocket connection accepted")
143
+
144
+ print("Waiting for initial message...")
145
  data = await websocket.receive_json()
146
+ print(f"Received initial message: {data}")
147
+
148
+ if 'action' not in data:
149
+ print("Error: Missing 'action' in message")
150
+ await websocket.send_json({
151
+ 'status': 'error',
152
+ 'message': 'Missing action in request'
153
+ })
154
+ return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ if data['action'] == 'start_training':
157
+ if 'parameters' not in data:
158
+ print("Error: Missing 'parameters' in message")
 
 
159
  await websocket.send_json({
160
+ 'status': 'error',
161
+ 'message': 'Missing parameters in request'
 
 
162
  })
163
+ return
164
+
165
+ print("Starting training task")
166
+ try:
167
+ training_task = asyncio.create_task(start_comparison_training(
168
+ websocket,
169
+ data['parameters']
170
+ ))
171
+ print("Training task created, awaiting completion...")
172
+ await training_task
173
+ print("Training task completed")
174
  except Exception as e:
175
+ print(f"Error during training task: {str(e)}")
176
  await websocket.send_json({
177
+ 'status': 'error',
178
+ 'message': f'Training error: {str(e)}'
 
 
179
  })
180
+ else:
181
+ print(f"Unknown action received: {data['action']}")
182
+
183
  except WebSocketDisconnect:
184
  print("WebSocket disconnected")
185
+ except json.JSONDecodeError as e:
186
+ print(f"JSON decode error: {str(e)}")
187
  except Exception as e:
188
+ print(f"Unexpected error in websocket handler: {str(e)}")
189
  finally:
190
+ print("=== WebSocket Connection Closed ===\n")
191
 
192
  # @app.post("/api/train_single")
193
  # async def train_single_model(config: TrainingConfig):
scripts/training/train.py CHANGED
@@ -12,6 +12,18 @@ import urllib.request
12
  import shutil
13
  from tqdm import tqdm
14
  import asyncio
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def generate_model_filename(config, model_type="single"):
17
  """Generate a filename based on model configuration
@@ -185,7 +197,6 @@ async def train(model, config, websocket=None, model_type="single"):
185
  correct = 0
186
  total = 0
187
 
188
- # Create progress bar for each epoch
189
  progress_bar = tqdm(
190
  train_loader,
191
  desc=f"Epoch {epoch+1}/{config.epochs}",
@@ -211,12 +222,6 @@ async def train(model, config, websocket=None, model_type="single"):
211
  current_loss = total_loss / (batch_idx + 1)
212
  current_acc = 100. * correct / total
213
 
214
- # Update progress bar description
215
- progress_bar.set_postfix({
216
- 'loss': f'{current_loss:.4f}',
217
- 'acc': f'{current_acc:.2f}%'
218
- })
219
-
220
  # Send training update through websocket
221
  if websocket:
222
  try:
@@ -226,7 +231,8 @@ async def train(model, config, websocket=None, model_type="single"):
226
  'data': {
227
  'step': step,
228
  'train_loss': current_loss,
229
- 'train_acc': current_acc
 
230
  }
231
  })
232
  except Exception as e:
@@ -284,8 +290,260 @@ async def train(model, config, websocket=None, model_type="single"):
284
 
285
  except Exception as e:
286
  print(f"\nError during training: {e}")
 
 
 
 
 
 
 
287
  raise e
288
 
289
  print("\nTraining completed!")
290
  print(f"Best validation accuracy: {best_val_acc:.2f}%")
 
 
 
 
 
 
 
 
 
291
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  import shutil
13
  from tqdm import tqdm
14
  import asyncio
15
+ from fastapi import WebSocket
16
+ import json
17
+ from scripts.model import Net
18
+
19
+ class TrainingConfig:
20
+ def __init__(self, params_dict):
21
+ self.block1 = params_dict['block1']
22
+ self.block2 = params_dict['block2']
23
+ self.block3 = params_dict['block3']
24
+ self.optimizer = params_dict['optimizer']
25
+ self.batch_size = params_dict['batch_size']
26
+ self.epochs = params_dict['epochs']
27
 
28
  def generate_model_filename(config, model_type="single"):
29
  """Generate a filename based on model configuration
 
197
  correct = 0
198
  total = 0
199
 
 
200
  progress_bar = tqdm(
201
  train_loader,
202
  desc=f"Epoch {epoch+1}/{config.epochs}",
 
222
  current_loss = total_loss / (batch_idx + 1)
223
  current_acc = 100. * correct / total
224
 
 
 
 
 
 
 
225
  # Send training update through websocket
226
  if websocket:
227
  try:
 
231
  'data': {
232
  'step': step,
233
  'train_loss': current_loss,
234
+ 'train_acc': current_acc,
235
+ 'epoch': epoch
236
  }
237
  })
238
  except Exception as e:
 
290
 
291
  except Exception as e:
292
  print(f"\nError during training: {e}")
293
+ if websocket:
294
+ await websocket.send_json({
295
+ 'type': 'training_error',
296
+ 'data': {
297
+ 'message': str(e)
298
+ }
299
+ })
300
  raise e
301
 
302
  print("\nTraining completed!")
303
  print(f"Best validation accuracy: {best_val_acc:.2f}%")
304
+
305
+ if websocket:
306
+ await websocket.send_json({
307
+ 'type': 'training_complete',
308
+ 'data': {
309
+ 'message': 'Training completed successfully!',
310
+ 'best_val_acc': best_val_acc
311
+ }
312
+ })
313
  return None
314
+
315
+ def initialize_datasets(batch_size):
316
+ """Initialize and return train and test datasets with dataloaders"""
317
+ # Ensure data is downloaded and extracted
318
+ print("Preparing dataset...")
319
+ download_and_extract_mnist_data()
320
+
321
+ # Paths to the extracted files
322
+ train_images_path = "data/MNIST/raw/train-images-idx3-ubyte"
323
+ train_labels_path = "data/MNIST/raw/train-labels-idx1-ubyte"
324
+ test_images_path = "data/MNIST/raw/t10k-images-idx3-ubyte"
325
+ test_labels_path = "data/MNIST/raw/t10k-labels-idx1-ubyte"
326
+
327
+ # Data loading
328
+ transform = transforms.Compose([
329
+ transforms.Normalize((0.1307,), (0.3081,))
330
+ ])
331
+
332
+ train_dataset = CustomMNISTDataset(train_images_path, train_labels_path, transform=transform)
333
+ test_dataset = CustomMNISTDataset(test_images_path, test_labels_path, transform=transform)
334
+
335
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
336
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
337
+
338
+ return train_dataset, test_dataset, train_loader, test_loader
339
+
340
+ async def start_comparison_training(websocket: WebSocket, parameters: dict):
341
+ print("\n=== Starting Comparison Training ===")
342
+ print(f"Received parameters: {json.dumps(parameters, indent=2)}")
343
+
344
+ try:
345
+ # Create models directory if it doesn't exist
346
+ models_dir = Path("scripts/training/models")
347
+ models_dir.mkdir(parents=True, exist_ok=True)
348
+
349
+ # Validate parameters
350
+ if not parameters.get('model_params'):
351
+ print("Error: Missing model parameters")
352
+ raise ValueError("Missing model parameters")
353
+
354
+ if not parameters.get('dataset_params'):
355
+ print("Error: Missing dataset parameters")
356
+ raise ValueError("Missing dataset parameters")
357
+
358
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
359
+ criterion = nn.CrossEntropyLoss()
360
+
361
+ # Calculate total training samples once
362
+ train_dataset = CustomMNISTDataset(
363
+ "data/MNIST/raw/train-images-idx3-ubyte",
364
+ "data/MNIST/raw/train-labels-idx1-ubyte",
365
+ transform=transforms.Compose([transforms.Normalize((0.1307,), (0.3081,))])
366
+ )
367
+ total_samples = len(train_dataset)
368
+
369
+ # Dictionary to store best accuracies
370
+ best_accuracies = {}
371
+
372
+ # Start training models
373
+ for model_key, model_letter in [('model_a', 'A'), ('model_b', 'B')]:
374
+ print(f"\n{'='*50}")
375
+ print(f"Training Model {model_letter}")
376
+ print(f"{'='*50}")
377
+
378
+ model_params = parameters['model_params'][model_key]
379
+
380
+ # Calculate iterations per epoch for this model
381
+ batch_size = model_params['batch_size']
382
+ iterations_per_epoch = total_samples // batch_size
383
+ total_iterations = iterations_per_epoch * model_params['epochs']
384
+
385
+ # Print configuration details
386
+ print("\nModel Configuration:")
387
+ print(f"Architecture: {model_params['block1']}-{model_params['block2']}-{model_params['block3']}")
388
+ print(f"Optimizer: {model_params['optimizer']}")
389
+ print(f"Batch Size: {model_params['batch_size']}")
390
+ print(f"Epochs: {model_params['epochs']}")
391
+ print(f"Iterations per epoch: {iterations_per_epoch:,}")
392
+ print(f"Total iterations: {total_iterations:,}")
393
+
394
+ try:
395
+ # Initialize datasets with model-specific batch size
396
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
397
+ test_dataset = CustomMNISTDataset(
398
+ "data/MNIST/raw/t10k-images-idx3-ubyte",
399
+ "data/MNIST/raw/t10k-labels-idx1-ubyte",
400
+ transform=transforms.Compose([transforms.Normalize((0.1307,), (0.3081,))])
401
+ )
402
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
403
+
404
+ print(f"\nDataset Information:")
405
+ print(f"Training samples: {len(train_dataset):,}")
406
+ print(f"Test samples: {len(test_dataset):,}")
407
+ print(f"Steps per epoch: {len(train_loader):,}")
408
+
409
+ # Initialize model and move to device
410
+ model = Net(kernels=[
411
+ model_params['block1'],
412
+ model_params['block2'],
413
+ model_params['block3']
414
+ ]).to(device)
415
+
416
+ # Print model parameters
417
+ total_params = sum(p.numel() for p in model.parameters())
418
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
419
+ print(f"\nModel Parameters:")
420
+ print(f"Total parameters: {total_params:,}")
421
+ print(f"Trainable parameters: {trainable_params:,}")
422
+
423
+ # Initialize optimizer
424
+ if model_params['optimizer'].lower() == 'adam':
425
+ optimizer = optim.Adam(model.parameters())
426
+ else:
427
+ optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
428
+
429
+ # Train the model
430
+ current_iteration = 0
431
+ best_acc = 0 # Track best accuracy for model saving
432
+
433
+ for epoch in range(model_params['epochs']):
434
+ model.train()
435
+ total_loss = 0
436
+ correct = 0
437
+ total = 0
438
+
439
+ # Create progress bar for each epoch
440
+ progress_bar = tqdm(
441
+ train_loader,
442
+ desc=f"Epoch {epoch+1}/{model_params['epochs']}",
443
+ unit='batch',
444
+ leave=True,
445
+ ncols=100
446
+ )
447
+
448
+ for batch_idx, (data, target) in enumerate(progress_bar):
449
+ data, target = data.to(device), target.to(device)
450
+ optimizer.zero_grad()
451
+ output = model(data)
452
+ loss = criterion(output, target)
453
+ loss.backward()
454
+ optimizer.step()
455
+
456
+ # Calculate batch accuracy
457
+ pred = output.argmax(dim=1, keepdim=True)
458
+ correct += pred.eq(target.view_as(pred)).sum().item()
459
+ total += target.size(0)
460
+ total_loss += loss.item()
461
+
462
+ # Calculate current metrics
463
+ current_loss = total_loss / (batch_idx + 1)
464
+ current_acc = 100. * correct / total
465
+
466
+ # Update progress bar description
467
+ progress_bar.set_postfix({
468
+ 'loss': f'{current_loss:.4f}',
469
+ 'acc': f'{current_acc:.2f}%'
470
+ })
471
+
472
+ # Send comparison-specific training update
473
+ current_iteration += 1
474
+ await websocket.send_json({
475
+ 'status': 'training',
476
+ 'model': model_letter,
477
+ 'metrics': {
478
+ 'iteration': current_iteration,
479
+ 'total_iterations': total_iterations,
480
+ 'loss': current_loss,
481
+ 'accuracy': current_acc
482
+ },
483
+ 'epoch': epoch,
484
+ 'batch_size': batch_size,
485
+ 'iterations_per_epoch': iterations_per_epoch
486
+ })
487
+
488
+ # Print epoch summary
489
+ print(f"\nEpoch {epoch+1} Summary:")
490
+ print(f"Average Loss: {current_loss:.4f}")
491
+ print(f"Accuracy: {current_acc:.2f}%")
492
+
493
+ # Add validation phase at the end of each epoch
494
+ model.eval()
495
+ val_loss = 0
496
+ val_correct = 0
497
+ val_total = 0
498
+
499
+ print("\nRunning validation...")
500
+ with torch.no_grad():
501
+ for data, target in test_loader:
502
+ data, target = data.to(device), target.to(device)
503
+ output = model(data)
504
+ val_loss += criterion(output, target).item()
505
+ pred = output.argmax(dim=1, keepdim=True)
506
+ val_correct += pred.eq(target.view_as(pred)).sum().item()
507
+ val_total += target.size(0)
508
+
509
+ val_loss /= len(test_loader)
510
+ val_acc = 100. * val_correct / val_total
511
+
512
+ # Save model if it's the best so far
513
+ if val_acc > best_acc:
514
+ best_acc = val_acc
515
+ # Generate filename with configuration
516
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
517
+ model_filename = f"{model_key}_arch_{model_params['block1']}_{model_params['block2']}_{model_params['block3']}_opt_{model_params['optimizer'].lower()}_batch_{model_params['batch_size']}_{timestamp}.pth"
518
+ model_path = models_dir / model_filename
519
+
520
+ print(f"\nSaving Model {model_letter} with accuracy {val_acc:.2f}% as: {model_filename}")
521
+ torch.save(model.state_dict(), model_path)
522
+
523
+ print(f"\nModel {model_letter} training completed")
524
+ print(f"Best validation accuracy: {best_acc:.2f}%")
525
+
526
+ # Save best accuracy for this model
527
+ best_accuracies[model_key] = best_acc
528
+
529
+ except Exception as e:
530
+ print(f"Error training Model {model_letter}: {str(e)}")
531
+ raise
532
+
533
+ print("\nBoth models trained successfully")
534
+ await websocket.send_json({
535
+ 'status': 'complete',
536
+ 'message': 'Training completed for both models',
537
+ 'model_a_acc': best_accuracies.get('model_a'),
538
+ 'model_b_acc': best_accuracies.get('model_b')
539
+ })
540
+
541
+ except Exception as e:
542
+ error_msg = f"Error in comparison training: {str(e)}"
543
+ print(error_msg)
544
+ await websocket.send_json({
545
+ 'status': 'error',
546
+ 'message': error_msg
547
+ })
548
+ finally:
549
+ print("=== Comparison Training Ended ===\n")
static/js/train.js CHANGED
@@ -169,24 +169,24 @@ async function compareModels() {
169
 
170
  function initializeComparisonCharts() {
171
  const lossData = [{
172
- name: 'Model 1 Loss',
173
  x: [],
174
  y: [],
175
  type: 'scatter'
176
  }, {
177
- name: 'Model 2 Loss',
178
  x: [],
179
  y: [],
180
  type: 'scatter'
181
  }];
182
 
183
  const accuracyData = [{
184
- name: 'Model 1 Accuracy',
185
  x: [],
186
  y: [],
187
  type: 'scatter'
188
  }, {
189
- name: 'Model 2 Accuracy',
190
  x: [],
191
  y: [],
192
  type: 'scatter'
@@ -209,13 +209,13 @@ function displayComparisonResults(data) {
209
  const logsDiv = document.getElementById('comparison-logs');
210
  logsDiv.innerHTML = `
211
  <div class="comparison-model">
212
- <h4>Model 1</h4>
213
  <p>Final Loss: ${data.model1_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
214
  <p>Final Accuracy: ${data.model1_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
215
  <p>Model Name: ${data.model1_results.model_name}</p>
216
  </div>
217
  <div class="comparison-model">
218
- <h4>Model 2</h4>
219
  <p>Final Loss: ${data.model2_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
220
  <p>Final Accuracy: ${data.model2_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
221
  <p>Model Name: ${data.model2_results.model_name}</p>
 
169
 
170
  function initializeComparisonCharts() {
171
  const lossData = [{
172
+ name: 'Model A Loss',
173
  x: [],
174
  y: [],
175
  type: 'scatter'
176
  }, {
177
+ name: 'Model B Loss',
178
  x: [],
179
  y: [],
180
  type: 'scatter'
181
  }];
182
 
183
  const accuracyData = [{
184
+ name: 'Model A Accuracy',
185
  x: [],
186
  y: [],
187
  type: 'scatter'
188
  }, {
189
+ name: 'Model B Accuracy',
190
  x: [],
191
  y: [],
192
  type: 'scatter'
 
209
  const logsDiv = document.getElementById('comparison-logs');
210
  logsDiv.innerHTML = `
211
  <div class="comparison-model">
212
+ <h4>Model A</h4>
213
  <p>Final Loss: ${data.model1_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
214
  <p>Final Accuracy: ${data.model1_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
215
  <p>Model Name: ${data.model1_results.model_name}</p>
216
  </div>
217
  <div class="comparison-model">
218
+ <h4>Model B</h4>
219
  <p>Final Loss: ${data.model2_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
220
  <p>Final Accuracy: ${data.model2_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
221
  <p>Model Name: ${data.model2_results.model_name}</p>
static/js/train_compare.js CHANGED
@@ -2,24 +2,24 @@ let ws;
2
 
3
  function initializeComparisonCharts() {
4
  const lossData = [{
5
- name: 'Model 1 Loss',
6
  x: [],
7
  y: [],
8
  type: 'scatter'
9
  }, {
10
- name: 'Model 2 Loss',
11
  x: [],
12
  y: [],
13
  type: 'scatter'
14
  }];
15
 
16
  const accuracyData = [{
17
- name: 'Model 1 Accuracy',
18
  x: [],
19
  y: [],
20
  type: 'scatter'
21
  }, {
22
- name: 'Model 2 Accuracy',
23
  x: [],
24
  y: [],
25
  type: 'scatter'
@@ -90,16 +90,169 @@ function displayComparisonResults(data) {
90
  const logsDiv = document.getElementById('comparison-logs');
91
  logsDiv.innerHTML = `
92
  <div class="comparison-model">
93
- <h4>Model 1</h4>
94
  <p>Final Loss: ${data.model1_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
95
  <p>Final Accuracy: ${data.model1_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
96
  <p>Model Name: ${data.model1_results.model_name}</p>
97
  </div>
98
  <div class="comparison-model">
99
- <h4>Model 2</h4>
100
  <p>Final Loss: ${data.model2_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
101
  <p>Final Accuracy: ${data.model2_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
102
  <p>Model Name: ${data.model2_results.model_name}</p>
103
  </div>
104
  `;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  }
 
2
 
3
  function initializeComparisonCharts() {
4
  const lossData = [{
5
+ name: 'Model A Loss',
6
  x: [],
7
  y: [],
8
  type: 'scatter'
9
  }, {
10
+ name: 'Model B Loss',
11
  x: [],
12
  y: [],
13
  type: 'scatter'
14
  }];
15
 
16
  const accuracyData = [{
17
+ name: 'Model A Accuracy',
18
  x: [],
19
  y: [],
20
  type: 'scatter'
21
  }, {
22
+ name: 'Model B Accuracy',
23
  x: [],
24
  y: [],
25
  type: 'scatter'
 
90
  const logsDiv = document.getElementById('comparison-logs');
91
  logsDiv.innerHTML = `
92
  <div class="comparison-model">
93
+ <h4>Model A</h4>
94
  <p>Final Loss: ${data.model1_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
95
  <p>Final Accuracy: ${data.model1_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
96
  <p>Model Name: ${data.model1_results.model_name}</p>
97
  </div>
98
  <div class="comparison-model">
99
+ <h4>Model B</h4>
100
  <p>Final Loss: ${data.model2_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
101
  <p>Final Accuracy: ${data.model2_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
102
  <p>Model Name: ${data.model2_results.model_name}</p>
103
  </div>
104
  `;
105
+ }
106
+
107
+ // Add these helper functions to get the parameters
108
+ function getModelParameters() {
109
+ try {
110
+ const params = {
111
+ model_a: {
112
+ block1: parseInt(document.getElementById('model1_kernel1').value),
113
+ block2: parseInt(document.getElementById('model1_kernel2').value),
114
+ block3: parseInt(document.getElementById('model1_kernel3').value),
115
+ optimizer: document.getElementById('model1_optimizer').value,
116
+ batch_size: parseInt(document.getElementById('model1_batch_size').value),
117
+ epochs: parseInt(document.getElementById('model1_epochs').value)
118
+ },
119
+ model_b: {
120
+ block1: parseInt(document.getElementById('model2_kernel1').value),
121
+ block2: parseInt(document.getElementById('model2_kernel2').value),
122
+ block3: parseInt(document.getElementById('model2_kernel3').value),
123
+ optimizer: document.getElementById('model2_optimizer').value,
124
+ batch_size: parseInt(document.getElementById('model2_batch_size').value),
125
+ epochs: parseInt(document.getElementById('model2_epochs').value)
126
+ }
127
+ };
128
+
129
+ // Validate that all values are present and valid
130
+ for (const model of ['model_a', 'model_b']) {
131
+ for (const [key, value] of Object.entries(params[model])) {
132
+ if (value === null || value === undefined || Number.isNaN(value)) {
133
+ throw new Error(`Invalid value for ${model} ${key}: ${value}`);
134
+ }
135
+ }
136
+ }
137
+
138
+ console.log('Collected and validated model parameters:', params);
139
+ return params;
140
+ } catch (error) {
141
+ console.error('Error in getModelParameters:', error);
142
+ throw error;
143
+ }
144
+ }
145
+
146
+ function getDatasetParameters() {
147
+ return {
148
+ batch_size: parseInt(document.getElementById('model1_batch_size').value), // Using model1's batch size for dataset
149
+ shuffle: true
150
+ };
151
+ }
152
+
153
+ // Update the WebSocket event listener
154
+ document.getElementById('startComparisonBtn').addEventListener('click', function() {
155
+ console.log('Start Comparison button clicked');
156
+
157
+ // Validate form inputs before proceeding
158
+ const formInputs = document.querySelectorAll('input[type="number"], select'); // Added select for optimizer
159
+ let isValid = true;
160
+ let formValues = {};
161
+
162
+ formInputs.forEach(input => {
163
+ console.log(`Checking input ${input.id}: ${input.value}`);
164
+ formValues[input.id] = input.value;
165
+ if (!input.value) {
166
+ console.error(`Missing value for ${input.id}`);
167
+ isValid = false;
168
+ }
169
+ });
170
+
171
+ console.log('Form values:', formValues); // Log all form values
172
+
173
+ if (!isValid) {
174
+ alert('Please fill in all required fields');
175
+ return;
176
+ }
177
+
178
+ // Show comparison progress section
179
+ document.getElementById('comparison-progress').classList.remove('hidden');
180
+ console.log('Initialized comparison charts');
181
+ initializeComparisonCharts();
182
+
183
+ console.log('Attempting WebSocket connection...');
184
+ const ws = new WebSocket(`ws://${window.location.host}/ws/compare`);
185
+
186
+ ws.onopen = function() {
187
+ console.log('WebSocket connection established');
188
+ const parameters = {
189
+ model_params: getModelParameters(),
190
+ dataset_params: getDatasetParameters()
191
+ };
192
+
193
+ const message = {
194
+ action: 'start_training',
195
+ parameters: parameters
196
+ };
197
+
198
+ console.log('Preparing to send message:', JSON.stringify(message, null, 2));
199
+
200
+ // Add a small delay to ensure WebSocket is ready
201
+ setTimeout(() => {
202
+ try {
203
+ ws.send(JSON.stringify(message));
204
+ console.log('Message sent successfully');
205
+ } catch (error) {
206
+ console.error('Error sending message:', error);
207
+ alert('Error sending training parameters. Please check console for details.');
208
+ }
209
+ }, 100);
210
+ };
211
+
212
+ ws.onmessage = function(event) {
213
+ console.log('Received WebSocket message:', event.data);
214
+ try {
215
+ const data = JSON.parse(event.data);
216
+ console.log('Parsed message data:', data);
217
+ updateTrainingProgress(data);
218
+ } catch (error) {
219
+ console.error('Error processing message:', error);
220
+ }
221
+ };
222
+
223
+ ws.onerror = function(error) {
224
+ console.error('WebSocket error:', error);
225
+ alert('Connection error occurred. Please check console for details.');
226
+ };
227
+
228
+ ws.onclose = function(event) {
229
+ console.log('WebSocket connection closed. Code:', event.code, 'Reason:', event.reason);
230
+ };
231
+ });
232
+
233
+ // Add the updateTrainingProgress function
234
+ function updateTrainingProgress(data) {
235
+ if (data.status === 'training') {
236
+ // Update loss plot
237
+ Plotly.extendTraces('comparison-loss-plot', {
238
+ y: [[data.metrics.loss]],
239
+ }, [data.model === 'A' ? 0 : 1]);
240
+
241
+ // Update accuracy plot
242
+ Plotly.extendTraces('comparison-accuracy-plot', {
243
+ y: [[data.metrics.accuracy]],
244
+ }, [data.model === 'A' ? 0 : 1]);
245
+
246
+ // Update progress text
247
+ const progressText = document.getElementById('training-progress-text');
248
+ progressText.textContent = `Training ${data.model === 'A' ? 'Model A' : 'Model B'} - Epoch ${data.epoch + 1}`;
249
+ } else if (data.status === 'complete') {
250
+ // Handle training completion
251
+ document.getElementById('training-progress-text').textContent = 'Training Complete!';
252
+ displayComparisonResults(data.metrics);
253
+ } else if (data.status === 'error') {
254
+ // Handle error
255
+ console.error('Training error:', data.message);
256
+ alert(`Training error: ${data.message}`);
257
+ }
258
  }
templates/train_compare.html CHANGED
@@ -11,9 +11,9 @@
11
  <div class="container">
12
  <h1>Compare Models</h1>
13
  <div class="models-grid">
14
- <!-- Model 1 Configuration -->
15
  <div class="model-config">
16
- <h3>Model 1</h3>
17
  <div class="network-config">
18
  <h4>Network Architecture</h4>
19
  <div class="block-config">
@@ -78,9 +78,9 @@
78
  </div>
79
  </div>
80
 
81
- <!-- Model 2 Configuration -->
82
  <div class="model-config">
83
- <h3>Model 2</h3>
84
  <div class="network-config">
85
  <h4>Network Architecture</h4>
86
  <div class="block-config">
@@ -157,6 +157,18 @@
157
  <div id="lossChart"></div>
158
  <div id="accuracyChart"></div>
159
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
160
  </div>
161
 
162
  <style>
@@ -278,6 +290,28 @@
278
  .config-item .section-title {
279
  margin-bottom: 5px;
280
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  </style>
282
 
283
  <script>
@@ -292,13 +326,13 @@
292
  {
293
  x: [],
294
  y: [],
295
- name: 'Model 1 Training Loss',
296
  type: 'scatter'
297
  },
298
  {
299
  x: [],
300
  y: [],
301
- name: 'Model 2 Training Loss',
302
  type: 'scatter'
303
  }
304
  ];
@@ -320,13 +354,13 @@
320
  {
321
  x: [],
322
  y: [],
323
- name: 'Model 1 Training Accuracy',
324
  type: 'scatter'
325
  },
326
  {
327
  x: [],
328
  y: [],
329
- name: 'Model 2 Training Accuracy',
330
  type: 'scatter'
331
  }
332
  ];
@@ -375,55 +409,91 @@
375
  // Setup WebSocket connection
376
  ws = new WebSocket(`ws://${window.location.host}/ws/compare`);
377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  ws.onmessage = function(event) {
 
379
  const data = JSON.parse(event.data);
380
 
381
- if (data.type === 'training_update') {
382
- const modelIndex = data.data.model_id - 1; // 0 for model1, 1 for model2
 
 
 
383
 
384
- // Update training metrics
385
  Plotly.extendTraces('lossChart', {
386
- x: [[data.data.step]],
387
- y: [[data.data.train_loss]]
388
  }, [modelIndex]);
389
 
 
390
  Plotly.extendTraces('accuracyChart', {
391
- x: [[data.data.step]],
392
- y: [[data.data.train_acc]]
393
  }, [modelIndex]);
 
 
 
 
 
 
 
 
 
 
 
 
394
  }
395
- else if (data.type === 'validation_update') {
396
- const modelIndex = data.data.model_id - 1;
 
397
 
398
- // Add validation points
399
- Plotly.addTraces('lossChart', {
400
- x: [data.data.step],
401
- y: [data.data.val_loss],
402
- name: `Model ${data.data.model_id} Validation Loss`,
403
- mode: 'markers',
404
- marker: { size: 8 }
405
- });
406
-
407
- Plotly.addTraces('accuracyChart', {
408
- x: [data.data.step],
409
- y: [data.data.val_acc],
410
- name: `Model ${data.data.model_id} Validation Accuracy`,
411
- mode: 'markers',
412
- marker: { size: 8 }
413
- });
414
  }
415
- else if (data.type === 'comparison_complete') {
 
 
416
  document.getElementById('startComparison').disabled = false;
417
  document.getElementById('stopComparison').disabled = true;
418
  }
419
  };
420
 
421
- // Start comparison
422
- ws.send(JSON.stringify({
423
- type: 'start_comparison',
424
- model1: model1Config,
425
- model2: model2Config
426
- }));
 
 
 
 
 
427
  }
428
 
429
  function stopComparison() {
 
11
  <div class="container">
12
  <h1>Compare Models</h1>
13
  <div class="models-grid">
14
+ <!-- Model A Configuration -->
15
  <div class="model-config">
16
+ <h3>Model A</h3>
17
  <div class="network-config">
18
  <h4>Network Architecture</h4>
19
  <div class="block-config">
 
78
  </div>
79
  </div>
80
 
81
+ <!-- Model B Configuration -->
82
  <div class="model-config">
83
+ <h3>Model B</h3>
84
  <div class="network-config">
85
  <h4>Network Architecture</h4>
86
  <div class="block-config">
 
157
  <div id="lossChart"></div>
158
  <div id="accuracyChart"></div>
159
  </div>
160
+
161
+ <!-- Add this after the charts container -->
162
+ <div class="training-status">
163
+ <p id="training-progress"></p>
164
+ </div>
165
+
166
+ <!-- Add this after the training-status div -->
167
+ <div class="inference-controls" style="display: none;">
168
+ <button id="goToInference" onclick="window.location.href='/inference'" class="inference-button">
169
+ Try Model Inference
170
+ </button>
171
+ </div>
172
  </div>
173
 
174
  <style>
 
290
  .config-item .section-title {
291
  margin-bottom: 5px;
292
  }
293
+
294
+ .training-status {
295
+ text-align: center;
296
+ margin: 20px 0;
297
+ font-weight: bold;
298
+ }
299
+
300
+ .inference-controls {
301
+ margin: 20px 0;
302
+ text-align: center;
303
+ }
304
+
305
+ .inference-button {
306
+ background-color: #28a745;
307
+ padding: 12px 24px;
308
+ font-size: 1.1em;
309
+ transition: background-color 0.3s;
310
+ }
311
+
312
+ .inference-button:hover {
313
+ background-color: #218838;
314
+ }
315
  </style>
316
 
317
  <script>
 
326
  {
327
  x: [],
328
  y: [],
329
+ name: 'Model A Training Loss',
330
  type: 'scatter'
331
  },
332
  {
333
  x: [],
334
  y: [],
335
+ name: 'Model B Training Loss',
336
  type: 'scatter'
337
  }
338
  ];
 
354
  {
355
  x: [],
356
  y: [],
357
+ name: 'Model A Training Accuracy',
358
  type: 'scatter'
359
  },
360
  {
361
  x: [],
362
  y: [],
363
+ name: 'Model B Training Accuracy',
364
  type: 'scatter'
365
  }
366
  ];
 
409
  // Setup WebSocket connection
410
  ws = new WebSocket(`ws://${window.location.host}/ws/compare`);
411
 
412
+ ws.onopen = function() {
413
+ console.log('WebSocket connection established');
414
+ // Only send the message after connection is established
415
+ const message = {
416
+ action: 'start_training',
417
+ parameters: {
418
+ model_params: {
419
+ model_a: model1Config,
420
+ model_b: model2Config
421
+ },
422
+ dataset_params: {
423
+ batch_size: model1Config.batch_size,
424
+ shuffle: true
425
+ }
426
+ }
427
+ };
428
+ console.log('Sending message:', message);
429
+ ws.send(JSON.stringify(message));
430
+ };
431
+
432
  ws.onmessage = function(event) {
433
+ console.log('Received message:', event.data);
434
  const data = JSON.parse(event.data);
435
 
436
+ if (data.status === 'training') {
437
+ const modelIndex = data.model === 'A' ? 0 : 1;
438
+ const iteration = data.metrics.iteration;
439
+
440
+ console.log(`Updating charts for model ${data.model} at iteration ${iteration}`);
441
 
442
+ // Update loss chart using iteration number
443
  Plotly.extendTraces('lossChart', {
444
+ x: [[iteration]],
445
+ y: [[data.metrics.loss]]
446
  }, [modelIndex]);
447
 
448
+ // Update accuracy chart using iteration number
449
  Plotly.extendTraces('accuracyChart', {
450
+ x: [[iteration]],
451
+ y: [[data.metrics.accuracy]]
452
  }, [modelIndex]);
453
+
454
+ // Update progress text with more detailed information
455
+ const progressText = document.getElementById('training-progress');
456
+ if (progressText) {
457
+ const progress = (data.metrics.iteration / data.metrics.total_iterations * 100).toFixed(1);
458
+ progressText.textContent =
459
+ `Training Model ${data.model} - ` +
460
+ `Epoch ${data.epoch + 1} - ` +
461
+ `Iteration ${data.metrics.iteration}/${data.metrics.total_iterations} ` +
462
+ `(${progress}%) - ` +
463
+ `Batch Size: ${data.batch_size}`;
464
+ }
465
  }
466
+ else if (data.status === 'complete') {
467
+ document.getElementById('startComparison').disabled = false;
468
+ document.getElementById('stopComparison').disabled = true;
469
 
470
+ const progressText = document.getElementById('training-progress');
471
+ if (progressText) {
472
+ progressText.textContent = 'Training Complete!';
473
+ }
474
+
475
+ // Show the inference button
476
+ document.querySelector('.inference-controls').style.display = 'block';
 
 
 
 
 
 
 
 
 
477
  }
478
+ else if (data.status === 'error') {
479
+ console.error('Training error:', data.message);
480
+ alert(`Training error: ${data.message}`);
481
  document.getElementById('startComparison').disabled = false;
482
  document.getElementById('stopComparison').disabled = true;
483
  }
484
  };
485
 
486
+ ws.onerror = function(error) {
487
+ console.error('WebSocket error:', error);
488
+ document.getElementById('startComparison').disabled = false;
489
+ document.getElementById('stopComparison').disabled = true;
490
+ };
491
+
492
+ ws.onclose = function(event) {
493
+ console.log('WebSocket connection closed:', event);
494
+ document.getElementById('startComparison').disabled = false;
495
+ document.getElementById('stopComparison').disabled = true;
496
+ };
497
  }
498
 
499
  function stopComparison() {