Shilpaj commited on
Commit
6f5f635
·
1 Parent(s): 0d84fb8

Feat: Logic for model training and inference

Browse files
scripts/__init__.py ADDED
File without changes
scripts/app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @app.post("/train")
2
+ async def train_model(config: dict):
3
+ network_config = NetworkConfig()
4
+ network_config.update(**config)
5
+
6
+ # Create model with configured architecture
7
+ model = Net(
8
+ block1=network_config.block1,
9
+ block2=network_config.block2,
10
+ block3=network_config.block3
11
+ )
12
+
13
+ # Start training with websocket updates
14
+ result = await train(model, network_config)
15
+ return result
16
+
17
+ @app.websocket("/ws/compare")
18
+ async def websocket_compare_endpoint(websocket: WebSocket):
19
+ await websocket.accept()
20
+ try:
21
+ while True:
22
+ data = await websocket.receive_json()
23
+ if data.get("type") == "start_comparison":
24
+ # Create and train both models
25
+ model1_config = NetworkConfig()
26
+ model2_config = NetworkConfig()
27
+
28
+ # Update configs with received data
29
+ model1_config.update(**data["model1"])
30
+ model2_config.update(**data["model2"])
31
+
32
+ # Create models with respective configurations
33
+ model1 = Net(
34
+ block1=model1_config.block1,
35
+ block2=model1_config.block2,
36
+ block3=model1_config.block3
37
+ )
38
+
39
+ model2 = Net(
40
+ block1=model2_config.block1,
41
+ block2=model2_config.block2,
42
+ block3=model2_config.block3
43
+ )
44
+
45
+ # Train both models concurrently
46
+ tasks = [
47
+ train(model1, model1_config, websocket),
48
+ train(model2, model2_config, websocket)
49
+ ]
50
+
51
+ results = await asyncio.gather(*tasks)
52
+
53
+ # Send completion message
54
+ await websocket.send_json({
55
+ "type": "comparison_complete",
56
+ "data": {
57
+ "model1": results[0],
58
+ "model2": results[1]
59
+ }
60
+ })
61
+
62
+ except Exception as e:
63
+ print(f"Error in websocket connection: {e}")
64
+ finally:
65
+ await websocket.close()
66
+
67
+ @app.post("/compare")
68
+ async def compare_models(request: Request):
69
+ data = await request.json()
70
+ return {"status": "started", "message": "Model comparison initiated"}
71
+
72
+ @app.get("/train/compare")
73
+ async def compare_page(request: Request):
74
+ return templates.TemplateResponse("train_compare.html", {"request": request})
scripts/inference/__init__.py ADDED
File without changes
scripts/inference/infer.py ADDED
File without changes
scripts/model.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class Net(nn.Module):
6
+ def __init__(self, kernels=[32, 64, 128]):
7
+ super(Net, self).__init__()
8
+ # First Convolutional Block
9
+ self.conv1 = nn.Conv2d(1, kernels[0], 3, padding=1)
10
+ self.bn1 = nn.BatchNorm2d(kernels[0])
11
+
12
+ # Second Convolutional Block
13
+ self.conv2 = nn.Conv2d(kernels[0], kernels[1], 3, padding=1)
14
+ self.bn2 = nn.BatchNorm2d(kernels[1])
15
+
16
+ # Third Convolutional Block
17
+ self.conv3 = nn.Conv2d(kernels[1], kernels[2], 3, padding=1)
18
+ self.bn3 = nn.BatchNorm2d(kernels[2])
19
+
20
+ self.pool = nn.MaxPool2d(2, 2)
21
+ self.dropout = nn.Dropout(0.25)
22
+
23
+ # Calculate the size after convolutions and pooling
24
+ # Input: 28x28 -> after three pooling layers: 7x7
25
+ # Final feature map size will be kernels[2] x 7 x 7
26
+ self.fc1 = nn.Linear(kernels[2] * 7 * 7, 256)
27
+ self.fc1_bn = nn.BatchNorm1d(256)
28
+ self.fc2 = nn.Linear(256, 10)
29
+
30
+ # Initialize weights
31
+ self._initialize_weights()
32
+
33
+ def forward(self, x):
34
+ # First conv block
35
+ x = self.conv1(x)
36
+ x = self.bn1(x)
37
+ x = F.relu(x)
38
+ x = self.pool(x) # 28x28 -> 14x14
39
+
40
+ # Second conv block
41
+ x = self.conv2(x)
42
+ x = self.bn2(x)
43
+ x = F.relu(x)
44
+ x = self.pool(x) # 14x14 -> 7x7
45
+
46
+ # Third conv block
47
+ x = self.conv3(x)
48
+ x = self.bn3(x)
49
+ x = F.relu(x)
50
+ # No pooling here to maintain spatial dimensions
51
+
52
+ # Flatten
53
+ x = x.view(-1, self.num_flat_features(x))
54
+ x = self.dropout(x)
55
+
56
+ # Fully connected layers
57
+ x = self.fc1(x)
58
+ x = self.fc1_bn(x)
59
+ x = F.relu(x)
60
+ x = self.dropout(x)
61
+
62
+ x = self.fc2(x)
63
+ return F.log_softmax(x, dim=1)
64
+
65
+ def num_flat_features(self, x):
66
+ size = x.size()[1:]
67
+ num_features = 1
68
+ for s in size:
69
+ num_features *= s
70
+ return num_features
71
+
72
+ def _initialize_weights(self):
73
+ for m in self.modules():
74
+ if isinstance(m, nn.Conv2d):
75
+ # Xavier initialization for CONV layers
76
+ nn.init.xavier_uniform_(m.weight)
77
+ if m.bias is not None:
78
+ nn.init.zeros_(m.bias)
79
+ elif isinstance(m, nn.BatchNorm2d):
80
+ nn.init.ones_(m.weight)
81
+ nn.init.zeros_(m.bias)
82
+ elif isinstance(m, nn.Linear):
83
+ # Xavier initialization for FC layers
84
+ nn.init.xavier_uniform_(m.weight)
85
+ nn.init.zeros_(m.bias)
scripts/training/__init__.py ADDED
File without changes
scripts/training/config.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BLOCK_OPTIONS = [8, 16, 32, 64, 128]
2
+
3
+ class NetworkConfig:
4
+ def __init__(self):
5
+ self.block1 = 32
6
+ self.block2 = 64
7
+ self.block3 = 128
8
+ self.batch_size = 64
9
+ self.optimizer = 'SGD'
10
+ self.epochs = 1
11
+
12
+ def update(self, **kwargs):
13
+ for key, value in kwargs.items():
14
+ if hasattr(self, key):
15
+ setattr(self, key, value)
scripts/training/model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv2d, MaxPool2d, Linear, Sequential, ReLU, LogSoftmax, Flatten
5
+
6
+
7
+ class Net(torch.nn.Module):
8
+ def __init__(self, block1=32, block2=64, block3=128):
9
+ """
10
+ Constructor
11
+ """
12
+ super(Net, self).__init__()
13
+
14
+ # Define model architecture with configurable blocks
15
+ self.conv1 = nn.Conv2d(1, block1, kernel_size=3)
16
+ self.conv2 = nn.Conv2d(block1, block2, kernel_size=3)
17
+ self.conv3 = nn.Conv2d(block2, block3, kernel_size=3)
18
+ self.conv4 = nn.Conv2d(block3, block3*2, kernel_size=3)
19
+
20
+ # Calculate the input size for the first fully connected layer
21
+ self.fc1 = nn.Linear(block3*2*16, 50)
22
+ self.fc2 = nn.Linear(50, 10)
23
+
24
+ def forward(self, x):
25
+ """
26
+ Forward pass for model training
27
+ :param x: Input layer
28
+ :return: Output of the model
29
+ """
30
+ x = F.relu(self.conv1(x))
31
+ x = F.relu(F.max_pool2d(self.conv2(x), 2))
32
+ x = F.relu(self.conv3(x))
33
+ x = F.relu(F.max_pool2d(self.conv4(x), 2))
34
+ x = x.view(x.size(0), -1)
35
+ x = F.relu(self.fc1(x))
36
+ x = self.fc2(x)
37
+ return x
scripts/training/train.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader, Dataset
5
+ from torchvision import transforms
6
+ import numpy as np
7
+ import gzip
8
+ import os
9
+ from pathlib import Path
10
+ from datetime import datetime
11
+ import urllib.request
12
+ import shutil
13
+ from tqdm import tqdm
14
+ import asyncio
15
+
16
+ def download_and_extract_mnist_data():
17
+ """Download and extract MNIST dataset from a reliable mirror"""
18
+ base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
19
+ files = {
20
+ "train_images": "train-images-idx3-ubyte.gz",
21
+ "train_labels": "train-labels-idx1-ubyte.gz",
22
+ "test_images": "t10k-images-idx3-ubyte.gz",
23
+ "test_labels": "t10k-labels-idx1-ubyte.gz"
24
+ }
25
+
26
+ data_dir = Path("data/MNIST/raw")
27
+ data_dir.mkdir(parents=True, exist_ok=True)
28
+
29
+ for file_name in files.values():
30
+ gz_file_path = data_dir / file_name
31
+ extracted_file_path = data_dir / file_name.replace('.gz', '')
32
+
33
+ # If the extracted file exists, skip downloading
34
+ if extracted_file_path.exists():
35
+ print(f"{extracted_file_path} already exists, skipping download.")
36
+ continue
37
+
38
+ # Download the file
39
+ print(f"Downloading {file_name}...")
40
+ url = base_url + file_name
41
+ try:
42
+ urllib.request.urlretrieve(url, gz_file_path)
43
+ print(f"Successfully downloaded {file_name}")
44
+ except Exception as e:
45
+ print(f"Failed to download {file_name}: {e}")
46
+ raise Exception(f"Could not download {file_name}")
47
+
48
+ # Extract the files
49
+ try:
50
+ print(f"Extracting {file_name}...")
51
+ with gzip.open(gz_file_path, 'rb') as f_in:
52
+ with open(extracted_file_path, 'wb') as f_out:
53
+ shutil.copyfileobj(f_in, f_out)
54
+ print(f"Successfully extracted {file_name}")
55
+ except Exception as e:
56
+ print(f"Failed to extract {file_name}: {e}")
57
+ raise Exception(f"Could not extract {file_name}")
58
+
59
+ def load_mnist_images(filename):
60
+ with open(filename, 'rb') as f:
61
+ data = np.frombuffer(f.read(), np.uint8, offset=16)
62
+ return data.reshape(-1, 1, 28, 28).astype(np.float32) / 255.0
63
+
64
+ def load_mnist_labels(filename):
65
+ with open(filename, 'rb') as f:
66
+ return np.frombuffer(f.read(), np.uint8, offset=8)
67
+
68
+ class CustomMNISTDataset(Dataset):
69
+ def __init__(self, images_path, labels_path, transform=None):
70
+ self.images = load_mnist_images(images_path)
71
+ self.labels = load_mnist_labels(labels_path)
72
+ self.transform = transform
73
+
74
+ def __len__(self):
75
+ return len(self.labels)
76
+
77
+ def __getitem__(self, idx):
78
+ image = torch.FloatTensor(self.images[idx])
79
+ label = int(self.labels[idx])
80
+
81
+ if self.transform:
82
+ image = self.transform(image)
83
+
84
+ return image, label
85
+
86
+ def validate(model, test_loader, criterion, device):
87
+ """Modified validate function to handle validation properly"""
88
+ model.eval()
89
+ val_loss = 0
90
+ correct = 0
91
+ total = 0
92
+ num_batches = 0
93
+
94
+ with torch.no_grad(): # Important: no gradient computation in validation
95
+ for data, target in test_loader:
96
+ data, target = data.to(device), target.to(device)
97
+ output = model(data)
98
+ val_loss += criterion(output, target).item() # Don't scale by batch size
99
+ _, predicted = output.max(1)
100
+ total += target.size(0)
101
+ correct += predicted.eq(target).sum().item()
102
+ num_batches += 1
103
+
104
+ # Average the loss by number of batches and accuracy by total samples
105
+ val_loss = val_loss / num_batches # Average loss across batches
106
+ val_acc = 100. * correct / total
107
+
108
+ return val_loss, val_acc
109
+
110
+ async def train(model, config, websocket=None):
111
+ print("\nStarting training...")
112
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
113
+ print(f"Using device: {device}")
114
+ model = model.to(device)
115
+
116
+ # Create data directory if it doesn't exist
117
+ data_dir = Path("data")
118
+ data_dir.mkdir(exist_ok=True)
119
+
120
+ # Ensure data is downloaded and extracted
121
+ print("Preparing dataset...")
122
+ download_and_extract_mnist_data()
123
+
124
+ # Paths to the extracted files
125
+ train_images_path = "data/MNIST/raw/train-images-idx3-ubyte"
126
+ train_labels_path = "data/MNIST/raw/train-labels-idx1-ubyte"
127
+ test_images_path = "data/MNIST/raw/t10k-images-idx3-ubyte"
128
+ test_labels_path = "data/MNIST/raw/t10k-labels-idx1-ubyte"
129
+
130
+ # Data loading
131
+ transform = transforms.Compose([
132
+ transforms.Normalize((0.1307,), (0.3081,))
133
+ ])
134
+
135
+ train_dataset = CustomMNISTDataset(train_images_path, train_labels_path, transform=transform)
136
+ test_dataset = CustomMNISTDataset(test_images_path, test_labels_path, transform=transform)
137
+
138
+ train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
139
+ test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
140
+
141
+ print(f"Dataset loaded. Training samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
142
+
143
+ # Initialize optimizer based on config
144
+ if config.optimizer.lower() == 'adam':
145
+ optimizer = optim.Adam(model.parameters())
146
+ else:
147
+ optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
148
+
149
+ criterion = nn.CrossEntropyLoss()
150
+
151
+ print("\nTraining Configuration:")
152
+ print(f"Optimizer: {config.optimizer}")
153
+ print(f"Batch Size: {config.batch_size}")
154
+ print(f"Network Architecture: {config.block1}-{config.block2}-{config.block3}")
155
+ print("\nStarting training loop...")
156
+
157
+ best_val_acc = 0
158
+ history = {
159
+ 'train_loss': [],
160
+ 'train_acc': [],
161
+ 'val_loss': [],
162
+ 'val_acc': []
163
+ }
164
+
165
+ try:
166
+ for epoch in range(config.epochs):
167
+ model.train()
168
+ total_loss = 0
169
+ correct = 0
170
+ total = 0
171
+
172
+ # Create progress bar for each epoch
173
+ progress_bar = tqdm(
174
+ train_loader,
175
+ desc=f"Epoch {epoch+1}/{config.epochs}",
176
+ unit='batch',
177
+ leave=True
178
+ )
179
+
180
+ for batch_idx, (data, target) in enumerate(progress_bar):
181
+ data, target = data.to(device), target.to(device)
182
+ optimizer.zero_grad()
183
+ output = model(data)
184
+ loss = criterion(output, target)
185
+ loss.backward()
186
+ optimizer.step()
187
+
188
+ # Calculate batch accuracy
189
+ pred = output.argmax(dim=1, keepdim=True)
190
+ correct += pred.eq(target.view_as(pred)).sum().item()
191
+ total += target.size(0)
192
+ total_loss += loss.item()
193
+
194
+ # Calculate current metrics
195
+ current_loss = total_loss / (batch_idx + 1)
196
+ current_acc = 100. * correct / total
197
+
198
+ # Update progress bar description
199
+ progress_bar.set_postfix({
200
+ 'loss': f'{current_loss:.4f}',
201
+ 'acc': f'{current_acc:.2f}%'
202
+ })
203
+
204
+ # Send training update through websocket
205
+ if websocket:
206
+ try:
207
+ await websocket.send_json({
208
+ 'type': 'training_update',
209
+ 'data': {
210
+ 'step': batch_idx + epoch * len(train_loader),
211
+ 'train_loss': current_loss,
212
+ 'train_acc': current_acc
213
+ }
214
+ })
215
+ except Exception as e:
216
+ print(f"Error sending websocket update: {e}")
217
+
218
+ # Calculate epoch metrics
219
+ train_loss = total_loss / len(train_loader)
220
+ train_acc = 100. * correct / total
221
+
222
+ # Validation phase
223
+ model.eval()
224
+ val_loss = 0
225
+ val_correct = 0
226
+ val_total = 0
227
+
228
+ print("\nRunning validation...")
229
+ with torch.no_grad():
230
+ for data, target in test_loader:
231
+ data, target = data.to(device), target.to(device)
232
+ output = model(data)
233
+ val_loss += criterion(output, target).item()
234
+ pred = output.argmax(dim=1, keepdim=True)
235
+ val_correct += pred.eq(target.view_as(pred)).sum().item()
236
+ val_total += target.size(0)
237
+
238
+ val_loss /= len(test_loader)
239
+ val_acc = 100. * val_correct / val_total
240
+
241
+ # Print epoch results
242
+ print(f"\nEpoch {epoch+1}/{config.epochs} Results:")
243
+ print(f"Training Loss: {train_loss:.4f} | Training Accuracy: {train_acc:.2f}%")
244
+ print(f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")
245
+
246
+ # Send validation update through websocket
247
+ if websocket:
248
+ try:
249
+ await websocket.send_json({
250
+ 'type': 'validation_update',
251
+ 'data': {
252
+ 'step': (epoch + 1) * len(train_loader),
253
+ 'val_loss': val_loss,
254
+ 'val_acc': val_acc
255
+ }
256
+ })
257
+ except Exception as e:
258
+ print(f"Error sending websocket update: {e}")
259
+
260
+ # Save best model
261
+ if val_acc > best_val_acc:
262
+ best_val_acc = val_acc
263
+ print(f"\nNew best validation accuracy: {val_acc:.2f}%")
264
+ print("Saving model...")
265
+ torch.save(model.state_dict(), 'best_model.pth')
266
+
267
+ except Exception as e:
268
+ print(f"\nError during training: {e}")
269
+ raise e
270
+
271
+ print("\nTraining completed!")
272
+ print(f"Best validation accuracy: {best_val_acc:.2f}%")
273
+ return history