Spaces:
Sleeping
Sleeping
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Train Single Model - MNIST</title>
    <link rel="stylesheet" href="{{ url_for('static', path='/css/style.css') }}">
    <!-- Open the font connections early; the stylesheet below pulls font files from gstatic. -->
    <link rel="preconnect" href="https://fonts.googleapis.com">
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Roboto+Mono&display=swap" rel="stylesheet">
    <!-- Plotly draws the loss/accuracy charts (the page script calls Plotly.newPlot / Plotly.extendTraces).
         Deferred scripts run before DOMContentLoaded, so it is ready when the charts are first created. -->
    <script src="https://cdn.plot.ly/plotly-2.35.2.min.js" charset="utf-8" defer></script>
</head>
<body>
<main class="container">
    <h1>Train Single Model</h1>

    <!-- Model configuration: the page script reads these controls by id when training starts. -->
    <section class="model-config" aria-labelledby="modelConfigHeading">
        <h2 id="modelConfigHeading">Model Configuration</h2>

        <!-- Network Architecture Section: one size selector per network block -->
        <div class="network-config">
            <h3>Network Architecture</h3>
            <div class="block-config">
                <div class="block">
                    <label for="block1">Block-1:</label>
                    <select id="block1" name="block1" class="form-select">
                        <option value="8">8</option>
                        <option value="16">16</option>
                        <option value="32" selected>32</option>
                        <option value="64">64</option>
                        <option value="128">128</option>
                    </select>
                </div>
                <div class="block">
                    <label for="block2">Block-2:</label>
                    <select id="block2" name="block2" class="form-select">
                        <option value="8">8</option>
                        <option value="16">16</option>
                        <option value="32">32</option>
                        <option value="64" selected>64</option>
                        <option value="128">128</option>
                    </select>
                </div>
                <div class="block">
                    <label for="block3">Block-3:</label>
                    <select id="block3" name="block3" class="form-select">
                        <option value="8">8</option>
                        <option value="16">16</option>
                        <option value="32">32</option>
                        <option value="64">64</option>
                        <option value="128" selected>128</option>
                    </select>
                </div>
            </div>
        </div>

        <!-- Training hyper-parameters -->
        <div class="training-config">
            <div class="config-item">
                <label for="optimizer">Optimizer:</label>
                <select id="optimizer" name="optimizer">
                    <option value="SGD" selected>SGD</option>
                    <option value="Adam">Adam</option>
                </select>
            </div>
            <div class="config-item">
                <label for="batch_size">Batch Size:</label>
                <select id="batch_size" name="batch_size">
                    <option value="32">32</option>
                    <option value="64" selected>64</option>
                    <option value="128">128</option>
                </select>
            </div>
            <div class="config-item">
                <label for="epochs">Epochs:</label>
                <select id="epochs" name="epochs">
                    <option value="1">1</option>
                    <option value="2">2</option>
                    <option value="3">3</option>
                </select>
            </div>
        </div>
    </section>

    <!-- Training Controls (handlers are global functions defined in the page script) -->
    <div class="controls">
        <button type="button" id="startTraining" onclick="startTraining()">Start Training</button>
        <button type="button" id="stopTraining" onclick="stopTraining()" disabled>Stop Training</button>
    </div>

    <!-- Training Progress: Plotly renders into these containers -->
    <div class="charts-container">
        <div id="lossChart"></div>
        <div id="accuracyChart"></div>
    </div>

    <!-- Inference Controls: hidden until the page script reveals them after a completed run.
         This navigates to another page, so it is a link rather than a button. -->
    <div class="inference-controls" hidden>
        <a id="goToInference" class="inference-button" href="/inference">Try Model Inference</a>
    </div>
</main>
<script>
    // Socket of the training run currently in progress, or null when idle.
    // Socket callbacks compare themselves against this so that events from a
    // superseded/stopped socket cannot affect a newer run.
    let ws = null;

    /**
     * Fresh x-axis layout shared by both charts. A new object is returned on each
     * call because Plotly may write computed values (e.g. autorange) back into the
     * layout it was given, so the two plots must not share one object.
     */
    function iterationAxis() {
        return { title: { text: 'Iterations' }, rangemode: 'nonnegative' };
    }

    /**
     * Two empty traces for one metric: index 0 = training, index 1 = validation.
     * handleTrainingMessage relies on this index order when extending traces.
     * @param {string} metricName - e.g. 'Loss' -> 'Training Loss' / 'Validation Loss'.
     */
    function metricTraces(metricName) {
        return [
            { x: [], y: [], name: 'Training ' + metricName, type: 'scatter' },
            { x: [], y: [], name: 'Validation ' + metricName, type: 'scatter' }
        ];
    }

    /** Draw empty loss and accuracy charts; Plotly.newPlot replaces any existing plot. */
    function createCharts() {
        Plotly.newPlot('lossChart', metricTraces('Loss'), {
            title: { text: 'Loss' },
            xaxis: iterationAxis(),
            yaxis: { title: { text: 'Loss' }, rangemode: 'nonnegative' }
        });
        Plotly.newPlot('accuracyChart', metricTraces('Accuracy'), {
            title: { text: 'Accuracy' },
            xaxis: iterationAxis(),
            // Fixed 0-100 range: presumably the server reports accuracy in percent.
            yaxis: { title: { text: 'Accuracy (%)' }, range: [0, 100] }
        });
    }

    /** Collect the selected architecture and training settings from the form controls. */
    function readTrainingConfig() {
        const value = (id) => document.getElementById(id).value;
        const intValue = (id) => parseInt(value(id), 10);
        return {
            block1: intValue('block1'),
            block2: intValue('block2'),
            block3: intValue('block3'),
            optimizer: value('optimizer'),
            batch_size: intValue('batch_size'),
            epochs: intValue('epochs')
        };
    }

    /**
     * Show or hide the inference controls. Both the `hidden` attribute and the
     * inline display are set so this works whichever way the element starts hidden.
     */
    function setInferenceVisible(visible) {
        const controls = document.querySelector('.inference-controls');
        if (!controls) {
            return;
        }
        controls.hidden = !visible;
        controls.style.display = visible ? 'block' : 'none';
    }

    /** Append one point to trace `traceIndex` (0 = training, 1 = validation) of both charts. */
    function appendMetrics(traceIndex, step, loss, accuracy) {
        Plotly.extendTraces('lossChart', { x: [[step]], y: [[loss]] }, [traceIndex]);
        Plotly.extendTraces('accuracyChart', { x: [[step]], y: [[accuracy]] }, [traceIndex]);
    }

    /**
     * Dispatch one message from the training server. Expected shape:
     * {type: 'training_update'|'validation_update'|'training_complete'|'training_error', data: {...}}.
     */
    function handleTrainingMessage(event) {
        let message;
        try {
            message = JSON.parse(event.data);
        } catch (err) {
            console.error("Ignoring malformed message from training server:", event.data);
            return;
        }
        if (!message || typeof message !== 'object') {
            return;
        }
        const payload = message.data || {};

        switch (message.type) {
            case 'training_update':
                appendMetrics(0, payload.step, payload.train_loss, payload.train_acc);
                break;
            case 'validation_update':
                appendMetrics(1, payload.step, payload.val_loss, payload.val_acc);
                break;
            case 'training_complete':
                alert(payload.message);
                stopTraining();
                setInferenceVisible(true);
                break;
            case 'training_error':
                alert(payload.message);
                stopTraining();
                break;
            default:
                // Unknown message types are ignored.
                break;
        }
    }

    /** Reset the charts, open the training WebSocket and send the selected configuration. */
    function startTraining() {
        document.getElementById('startTraining').disabled = true;
        document.getElementById('stopTraining').disabled = false;
        // A model from a previous run is about to be replaced.
        setInferenceVisible(false);
        createCharts();

        // Defensive: discard any socket left from an earlier run; its callbacks
        // become no-ops because it is no longer the active socket.
        if (ws) {
            const previous = ws;
            ws = null;
            previous.close();
        }

        // Match the page scheme: browsers block plain ws:// from an https:// page.
        const scheme = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
        const socket = new WebSocket(`${scheme}//${window.location.host}/ws/train`);
        ws = socket;

        socket.onopen = function() {
            console.log("WebSocket connection established");
            if (socket === ws) {
                socket.send(JSON.stringify(readTrainingConfig()));
            }
        };

        socket.onerror = function(error) {
            console.error("WebSocket error:", error);
            // Ignore errors from a socket the user already stopped/replaced.
            if (socket !== ws) {
                return;
            }
            stopTraining();
            alert("Error connecting to training server");
        };

        socket.onclose = function() {
            console.log("WebSocket connection closed");
            // Only reset the UI if this is still the active run; otherwise a late
            // close event would tear down a newer socket.
            if (socket === ws) {
                stopTraining();
            }
        };

        socket.onmessage = function(event) {
            if (socket === ws) {
                handleTrainingMessage(event);
            }
        };
    }

    /** Close the active training socket (if any) and restore the Start/Stop buttons. */
    function stopTraining() {
        const socket = ws;
        // Clear first so the closing socket's own callbacks recognise it as stale.
        ws = null;
        if (socket) {
            socket.close();
        }
        document.getElementById('startTraining').disabled = false;
        document.getElementById('stopTraining').disabled = true;
    }

    // Show empty charts as soon as the document is parsed.
    document.addEventListener('DOMContentLoaded', createCharts);
</script>
<style>
    /* Page layout */
    .container {
        max-width: 1200px;
        margin: 0 auto;
        padding: 20px;
    }

    /* Configuration panel */
    .model-config {
        padding: 20px;
        border: 1px solid #ddd;
        border-radius: 5px;
        margin-bottom: 20px;
    }
    .network-config {
        margin-bottom: 20px;
    }
    .network-config h3,
    .network-config h4 {
        margin: 0 0 15px 0;
        font-size: 1.1em;
    }
    .block-config {
        display: flex;
        justify-content: space-between;
        gap: 20px;
    }
    .block {
        flex: 1;
    }
    .block label {
        display: block;
        margin-bottom: 5px;
        font-weight: bold;
    }
    .training-config {
        display: flex;
        gap: 20px;
    }
    .config-item {
        flex: 1;
    }
    .config-item label {
        display: block;
        margin-bottom: 5px;
        font-weight: bold;
    }
    select {
        /* Include padding/border in the 100% width so selects never overflow their column. */
        box-sizing: border-box;
        width: 100%;
        padding: 8px;
        border: 1px solid #ddd;
        border-radius: 4px;
    }

    /* Start/Stop controls */
    .controls {
        margin: 20px 0;
    }
    button {
        padding: 10px 20px;
        margin-right: 10px;
        border: none;
        border-radius: 4px;
        background-color: #007bff;
        color: white;
        cursor: pointer;
    }
    button:disabled {
        background-color: #ccc;
        cursor: not-allowed;
    }

    /* Charts */
    .charts-container {
        display: flex;
        flex-direction: column;
        gap: 20px;
        margin-top: 20px;
    }
    #lossChart, #accuracyChart {
        height: 400px;
        width: 100%;
    }

    /* Inference call-to-action: self-contained so it renders the same as a <button> or an <a>. */
    .inference-controls {
        margin: 20px 0;
        text-align: center;
    }
    .inference-button {
        display: inline-block;
        padding: 12px 24px;
        border: none;
        border-radius: 4px;
        background-color: #28a745;
        color: white;
        font-size: 1.1em;
        text-decoration: none;
        cursor: pointer;
        transition: background-color 0.3s;
    }
    .inference-button:hover,
    .inference-button:focus-visible {
        background-color: #218838;
    }

    /* Stack the selector rows on narrow screens instead of squeezing them. */
    @media (max-width: 600px) {
        .block-config,
        .training-config {
            flex-direction: column;
        }
    }
</style>
</body> | |
</html> |