MnistStudio / templates /train_single.html
Shilpaj's picture
Refactor: css file address
244431c
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Train Single Model - MNIST</title>
  <link rel="stylesheet" href="{{ url_for('static', path='/css/style.css') }}">
  <!-- Open the font origins early; fonts.gstatic.com serves the font files via CORS. -->
  <link rel="preconnect" href="https://fonts.googleapis.com">
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Roboto+Mono&display=swap" rel="stylesheet">
  <!-- Plotly.js draws the loss/accuracy charts (the page script calls Plotly.newPlot,
       Plotly.extendTraces and Plotly.purge). `defer` keeps parsing unblocked while
       still executing before DOMContentLoaded, when the charts are first created.
       Pinned to 2.x because the chart layouts pass `title` as a plain string. -->
  <script src="https://cdn.plot.ly/plotly-2.27.0.min.js" charset="utf-8" defer></script>
</head>
<body>
<main class="container">
  <h1>Train Single Model</h1>
  <!-- Network Architecture Section -->
  <!-- The element ids below are read by the page script when training starts. -->
  <div class="model-config">
    <h3>Model Configuration</h3>
    <div class="network-config">
      <h4>Network Architecture</h4>
      <div class="block-config">
        <div class="block">
          <label for="block1">Block-1:</label>
          <select id="block1" name="block1" class="form-select">
            <option value="8">8</option>
            <option value="16">16</option>
            <option value="32" selected>32</option>
            <option value="64">64</option>
            <option value="128">128</option>
          </select>
        </div>
        <div class="block">
          <label for="block2">Block-2:</label>
          <select id="block2" name="block2" class="form-select">
            <option value="8">8</option>
            <option value="16">16</option>
            <option value="32">32</option>
            <option value="64" selected>64</option>
            <option value="128">128</option>
          </select>
        </div>
        <div class="block">
          <label for="block3">Block-3:</label>
          <select id="block3" name="block3" class="form-select">
            <option value="8">8</option>
            <option value="16">16</option>
            <option value="32">32</option>
            <option value="64">64</option>
            <option value="128" selected>128</option>
          </select>
        </div>
      </div>
    </div>
    <div class="training-config">
      <div class="config-item">
        <label for="optimizer">Optimizer:</label>
        <select id="optimizer" name="optimizer">
          <option value="SGD" selected>SGD</option>
          <option value="Adam">Adam</option>
        </select>
      </div>
      <div class="config-item">
        <label for="batch_size">Batch Size:</label>
        <select id="batch_size" name="batch_size">
          <option value="32">32</option>
          <option value="64" selected>64</option>
          <option value="128">128</option>
        </select>
      </div>
      <div class="config-item">
        <label for="epochs">Epochs:</label>
        <select id="epochs" name="epochs">
          <option value="1">1</option>
          <option value="2">2</option>
          <option value="3">3</option>
        </select>
      </div>
    </div>
  </div>
  <!-- Training Controls -->
  <!-- Explicit type="button": these perform in-page actions, never form submission. -->
  <div class="controls">
    <button id="startTraining" type="button" onclick="startTraining()">Start Training</button>
    <button id="stopTraining" type="button" onclick="stopTraining()" disabled>Stop Training</button>
  </div>
  <!-- Training Progress: Plotly renders into these two containers. -->
  <div class="charts-container">
    <div id="lossChart"></div>
    <div id="accuracyChart"></div>
  </div>
  <!-- Inference Controls: hidden until the page script receives training_complete. -->
  <div class="inference-controls" style="display: none;">
    <button id="goToInference" type="button" onclick="window.location.href='/inference'" class="inference-button">
      Try Model Inference
    </button>
  </div>
</main>
<script>
  // The training socket the page is currently tracking, or null when idle.
  // Each socket's handlers compare themselves against this so that events from
  // a socket that was already stopped (e.g. its asynchronous `close` event)
  // cannot tear down a newer training run.
  let ws = null;

  /**
   * Build a fresh, empty scatter trace.
   * @param {string} name Legend label for the trace.
   * @returns {object} Plotly trace with empty x/y arrays.
   */
  function emptyTrace(name) {
    return { x: [], y: [], name: name, type: 'scatter' };
  }

  /**
   * Draw both charts with empty training/validation series.
   * Trace index 0 is the training series and index 1 the validation series;
   * the WebSocket message handler relies on that order.
   * Layout objects are built fresh on every call rather than shared, since
   * Plotly may mutate the layout object it is given.
   */
  function initCharts() {
    Plotly.newPlot(
      'lossChart',
      [emptyTrace('Training Loss'), emptyTrace('Validation Loss')],
      {
        title: 'Loss',
        xaxis: { title: 'Iterations', rangemode: 'nonnegative' },
        yaxis: { title: 'Loss', rangemode: 'nonnegative' }
      }
    );
    Plotly.newPlot(
      'accuracyChart',
      [emptyTrace('Training Accuracy'), emptyTrace('Validation Accuracy')],
      {
        title: 'Accuracy',
        xaxis: { title: 'Iterations', rangemode: 'nonnegative' },
        yaxis: { title: 'Accuracy (%)', range: [0, 100] }
      }
    );
  }

  /**
   * URL of the training WebSocket on the current host.
   * Uses `wss:` when the page itself was served over HTTPS, because browsers
   * refuse insecure `ws:` connections from secure pages (mixed content).
   */
  function trainingSocketUrl() {
    const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
    return `${scheme}://${window.location.host}/ws/train`;
  }

  /**
   * Collect the selected model/training configuration from the form controls.
   * @returns {object} Payload sent to the server once the socket opens.
   */
  function readTrainingConfig() {
    const intValue = (id) => parseInt(document.getElementById(id).value, 10);
    return {
      block1: intValue('block1'),
      block2: intValue('block2'),
      block3: intValue('block3'),
      optimizer: document.getElementById('optimizer').value,
      batch_size: intValue('batch_size'),
      epochs: intValue('epochs')
    };
  }

  /**
   * Append one point to the loss chart and one to the accuracy chart.
   * @param {number} traceIndex 0 for training, 1 for validation (see initCharts).
   */
  function appendPoint(traceIndex, step, loss, accuracy) {
    Plotly.extendTraces('lossChart', { x: [[step]], y: [[loss]] }, [traceIndex]);
    Plotly.extendTraces('accuracyChart', { x: [[step]], y: [[accuracy]] }, [traceIndex]);
  }

  /**
   * Start a training run. This resets the charts, opens the training WebSocket
   * and sends the selected configuration once connected.
   * Messages of type training_update / validation_update are plotted.
   * training_complete and training_error end the run.
   */
  function startTraining() {
    // Disable start button and enable stop button
    document.getElementById('startTraining').disabled = true;
    document.getElementById('stopTraining').disabled = false;

    // Clear previous charts and draw empty ones
    Plotly.purge('lossChart');
    Plotly.purge('accuracyChart');
    initCharts();

    const socket = new WebSocket(trainingSocketUrl());
    ws = socket;
    // True while `socket` is still the run this page is tracking.
    const isCurrent = () => ws === socket;

    socket.onopen = function () {
      console.log("WebSocket connection established");
      if (!isCurrent()) return;
      socket.send(JSON.stringify(readTrainingConfig()));
    };

    socket.onerror = function (error) {
      console.error("WebSocket error:", error);
      if (!isCurrent()) return; // the user already stopped this run
      stopTraining();
      alert("Error connecting to training server");
    };

    socket.onclose = function () {
      console.log("WebSocket connection closed");
      if (!isCurrent()) return; // already stopped; must not affect a newer run
      stopTraining();
    };

    socket.onmessage = function (event) {
      if (!isCurrent()) return; // late message from a stopped run
      let message;
      try {
        message = JSON.parse(event.data);
      } catch (err) {
        console.error("Ignoring malformed WebSocket message:", event.data, err);
        return;
      }
      const payload = message.data;
      switch (message.type) {
        case 'training_update':
          appendPoint(0, payload.step, payload.train_loss, payload.train_acc);
          break;
        case 'validation_update':
          appendPoint(1, payload.step, payload.val_loss, payload.val_acc);
          break;
        case 'training_complete':
          alert(payload.message);
          stopTraining();
          // Show the inference button
          document.querySelector('.inference-controls').style.display = 'block';
          break;
        case 'training_error':
          alert(payload.message);
          stopTraining();
          break;
      }
    };
  }

  /**
   * Stop the current run, if any, and restore the Start/Stop button states.
   * It is safe to call when nothing is running, and safe to call repeatedly.
   */
  function stopTraining() {
    const socket = ws;
    // Detach before closing so this socket's own close/error handlers become no-ops.
    ws = null;
    if (socket) {
      socket.close();
    }
    document.getElementById('startTraining').disabled = false;
    document.getElementById('stopTraining').disabled = true;
  }

  // Initialize empty charts once the DOM is ready
  document.addEventListener('DOMContentLoaded', initCharts);
</script>
<style>
  /* ---- Page shell ---- */
  .container {
    max-width: 1200px;
    padding: 20px;
    margin: 0 auto;
  }

  /* ---- Configuration panel ---- */
  .model-config {
    margin-bottom: 20px;
    padding: 20px;
    border: 1px solid #ddd;
    border-radius: 5px;
  }
  .network-config {
    margin-bottom: 20px;
  }
  .network-config h4 {
    font-size: 1.1em;
    margin: 0 0 15px;
  }
  /* Both rows of controls lay out their items side by side. */
  .block-config,
  .training-config {
    display: flex;
    gap: 20px;
  }
  .block-config {
    justify-content: space-between;
  }
  .block,
  .config-item {
    flex: 1;
  }
  .block label,
  .config-item label {
    display: block;
    font-weight: bold;
    margin-bottom: 5px;
  }
  select {
    width: 100%;
    padding: 8px;
    border: 1px solid #ddd;
    border-radius: 4px;
  }

  /* ---- Buttons ---- */
  .controls,
  .inference-controls {
    margin: 20px 0;
  }
  .inference-controls {
    text-align: center;
  }
  button {
    margin-right: 10px;
    padding: 10px 20px;
    border: none;
    border-radius: 4px;
    color: white;
    background-color: #007bff;
    cursor: pointer;
  }
  button:disabled {
    background-color: #ccc;
    cursor: not-allowed;
  }
  .inference-button {
    padding: 12px 24px;
    font-size: 1.1em;
    background-color: #28a745;
    transition: background-color 0.3s;
  }
  .inference-button:hover {
    background-color: #218838;
  }

  /* ---- Charts ---- */
  .charts-container {
    display: flex;
    flex-direction: column;
    gap: 20px;
    margin-top: 20px;
  }
  #lossChart,
  #accuracyChart {
    width: 100%;
    height: 400px;
  }
</style>
</body>
</html>