Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,6 @@ import torch
|
|
3 |
import torch.nn.functional as F
|
4 |
from torchvision import transforms
|
5 |
from PIL import Image
|
6 |
-
from model import FineTunedResNet
|
7 |
import os
|
8 |
import time
|
9 |
|
@@ -15,6 +14,31 @@ transform = transforms.Compose([
|
|
15 |
])
|
16 |
|
17 |
# Load the trained ResNet50 model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
model = FineTunedResNet(num_classes=4)
|
19 |
model_path = 'models/final_fine_tuned_resnet50.pth'
|
20 |
|
@@ -89,30 +113,30 @@ custom_css = """
|
|
89 |
color: #333333;
|
90 |
font-weight: bold;
|
91 |
font-size: 24px;
|
92 |
-
margin-bottom: 10px
|
93 |
}
|
94 |
.gradio-description {
|
95 |
color: #666666;
|
96 |
-
font-size: 16px
|
97 |
-
margin-bottom: 20px
|
98 |
}
|
99 |
.gradio-image {
|
100 |
-
border-radius: 10px
|
101 |
}
|
102 |
.gradio-button {
|
103 |
-
background-color: #007bff
|
104 |
-
color: #ffffff
|
105 |
-
border: none
|
106 |
-
padding: 10px 20px
|
107 |
-
border-radius: 5px
|
108 |
-
cursor: pointer
|
109 |
}
|
110 |
.gradio-button:hover {
|
111 |
-
background-color: #0056b3
|
112 |
}
|
113 |
.gradio-label {
|
114 |
-
color: #007bff
|
115 |
-
font-weight: bold
|
116 |
}
|
117 |
"""
|
118 |
|
|
|
3 |
import torch.nn.functional as F
|
4 |
from torchvision import transforms
|
5 |
from PIL import Image
|
|
|
6 |
import os
|
7 |
import time
|
8 |
|
|
|
14 |
])
|
15 |
|
16 |
# Load the trained ResNet50 model
|
17 |
+
class FineTunedResNet(nn.Module):
|
18 |
+
def __init__(self, num_classes=4):
|
19 |
+
super(FineTunedResNet, self).__init__()
|
20 |
+
self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) # Load pre-trained ResNet50
|
21 |
+
|
22 |
+
# Replace the fully connected layer with more layers and batch normalization
|
23 |
+
self.resnet.fc = nn.Sequential(
|
24 |
+
nn.Linear(self.resnet.fc.in_features, 1024), # First additional layer
|
25 |
+
nn.BatchNorm1d(1024),
|
26 |
+
nn.ReLU(),
|
27 |
+
nn.Dropout(0.5),
|
28 |
+
nn.Linear(1024, 512), # Second additional layer
|
29 |
+
nn.BatchNorm1d(512),
|
30 |
+
nn.ReLU(),
|
31 |
+
nn.Dropout(0.5),
|
32 |
+
nn.Linear(512, 256), # Third additional layer
|
33 |
+
nn.BatchNorm1d(256),
|
34 |
+
nn.ReLU(),
|
35 |
+
nn.Dropout(0.5),
|
36 |
+
nn.Linear(256, num_classes) # Output layer
|
37 |
+
)
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
return self.resnet(x)
|
41 |
+
|
42 |
model = FineTunedResNet(num_classes=4)
|
43 |
model_path = 'models/final_fine_tuned_resnet50.pth'
|
44 |
|
|
|
113 |
color: #333333;
|
114 |
font-weight: bold;
|
115 |
font-size: 24px;
|
116 |
+
margin-bottom: 10px.
|
117 |
}
|
118 |
.gradio-description {
|
119 |
color: #666666;
|
120 |
+
font-size: 16px.
|
121 |
+
margin-bottom: 20px.
|
122 |
}
|
123 |
.gradio-image {
|
124 |
+
border-radius: 10px.
|
125 |
}
|
126 |
.gradio-button {
|
127 |
+
background-color: #007bff.
|
128 |
+
color: #ffffff.
|
129 |
+
border: none.
|
130 |
+
padding: 10px 20px.
|
131 |
+
border-radius: 5px.
|
132 |
+
cursor: pointer.
|
133 |
}
|
134 |
.gradio-button:hover {
|
135 |
+
background-color: #0056b3.
|
136 |
}
|
137 |
.gradio-label {
|
138 |
+
color: #007bff.
|
139 |
+
font-weight: bold.
|
140 |
}
|
141 |
"""
|
142 |
|