Spaces:
Sleeping
Sleeping
Sreekanth Tangirala
commited on
Commit
·
c773c40
1
Parent(s):
de2aabe
change to progress and epochs to 20
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import torch
|
|
3 |
import torchvision.transforms as transforms
|
4 |
from PIL import Image
|
5 |
from torchvision.models import resnet50
|
|
|
6 |
|
7 |
# Load model
|
8 |
model = resnet50(pretrained=False)
|
|
|
3 |
import torchvision.transforms as transforms
|
4 |
from PIL import Image
|
5 |
from torchvision.models import resnet50
|
6 |
+
import torch.nn as nn
|
7 |
|
8 |
# Load model
|
9 |
model = resnet50(pretrained=False)
|
train.py
CHANGED
@@ -65,8 +65,8 @@ def train_model(model, trainloader, epochs=100, device='cuda'):
|
|
65 |
|
66 |
best_acc = 0.0
|
67 |
|
68 |
-
# Create epoch progress bar
|
69 |
-
epoch_pbar = tqdm(range(epochs), desc='Training')
|
70 |
|
71 |
for epoch in epoch_pbar:
|
72 |
model.train()
|
@@ -74,8 +74,11 @@ def train_model(model, trainloader, epochs=100, device='cuda'):
|
|
74 |
correct = 0
|
75 |
total = 0
|
76 |
|
77 |
-
# Create batch progress bar
|
78 |
-
batch_pbar = tqdm(trainloader,
|
|
|
|
|
|
|
79 |
|
80 |
for inputs, labels in batch_pbar:
|
81 |
inputs, labels = inputs.to(device), labels.to(device)
|
@@ -97,20 +100,18 @@ def train_model(model, trainloader, epochs=100, device='cuda'):
|
|
97 |
epoch_acc = 100. * correct / total
|
98 |
avg_loss = running_loss/len(trainloader)
|
99 |
|
100 |
-
# Update epoch
|
101 |
-
epoch_pbar.
|
102 |
-
'loss': f'{avg_loss:.3f}',
|
103 |
-
'accuracy': f'{epoch_acc:.2f}%'
|
104 |
-
})
|
105 |
|
106 |
scheduler.step(epoch_acc)
|
107 |
|
108 |
if epoch_acc > best_acc:
|
109 |
best_acc = epoch_acc
|
110 |
save_model(model, 'best_model.pth')
|
|
|
111 |
|
112 |
if epoch_acc > 70:
|
113 |
-
|
114 |
break
|
115 |
|
116 |
if __name__ == "__main__":
|
@@ -125,4 +126,4 @@ if __name__ == "__main__":
|
|
125 |
model = get_model(num_classes=10)
|
126 |
|
127 |
# Train model
|
128 |
-
train_model(model, trainloader, epochs=
|
|
|
65 |
|
66 |
best_acc = 0.0
|
67 |
|
68 |
+
# Create epoch progress bar without a description (we'll use it for stats only)
|
69 |
+
epoch_pbar = tqdm(range(epochs), desc='Training Progress', position=0)
|
70 |
|
71 |
for epoch in epoch_pbar:
|
72 |
model.train()
|
|
|
74 |
correct = 0
|
75 |
total = 0
|
76 |
|
77 |
+
# Create batch progress bar with position below epoch bar
|
78 |
+
batch_pbar = tqdm(trainloader,
|
79 |
+
desc=f'Epoch {epoch+1}',
|
80 |
+
position=1,
|
81 |
+
leave=True)
|
82 |
|
83 |
for inputs, labels in batch_pbar:
|
84 |
inputs, labels = inputs.to(device), labels.to(device)
|
|
|
100 |
epoch_acc = 100. * correct / total
|
101 |
avg_loss = running_loss/len(trainloader)
|
102 |
|
103 |
+
# Update epoch status with more detailed format
|
104 |
+
epoch_pbar.write(f'Epoch {epoch+1}: Loss: {avg_loss:.3f} | Accuracy: {epoch_acc:.2f}%')
|
|
|
|
|
|
|
105 |
|
106 |
scheduler.step(epoch_acc)
|
107 |
|
108 |
if epoch_acc > best_acc:
|
109 |
best_acc = epoch_acc
|
110 |
save_model(model, 'best_model.pth')
|
111 |
+
epoch_pbar.write(f'New best accuracy: {epoch_acc:.2f}%')
|
112 |
|
113 |
if epoch_acc > 70:
|
114 |
+
epoch_pbar.write(f"\nReached target accuracy of 70%!")
|
115 |
break
|
116 |
|
117 |
if __name__ == "__main__":
|
|
|
126 |
model = get_model(num_classes=10)
|
127 |
|
128 |
# Train model
|
129 |
+
train_model(model, trainloader, epochs=20, device=device)
|