resberry commited on
Commit
5761260
·
verified ·
1 Parent(s): bc62586

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -14
app.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
6
- from model import FineTunedResNet
7
  import os
8
  import time
9
 
@@ -15,6 +14,31 @@ transform = transforms.Compose([
15
  ])
16
 
17
  # Load the trained ResNet50 model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  model = FineTunedResNet(num_classes=4)
19
  model_path = 'models/final_fine_tuned_resnet50.pth'
20
 
@@ -89,30 +113,30 @@ custom_css = """
89
  color: #333333;
90
  font-weight: bold;
91
  font-size: 24px;
92
- margin-bottom: 10px;
93
  }
94
  .gradio-description {
95
  color: #666666;
96
- font-size: 16px;
97
- margin-bottom: 20px;
98
  }
99
  .gradio-image {
100
- border-radius: 10px;
101
  }
102
  .gradio-button {
103
- background-color: #007bff;
104
- color: #ffffff;
105
- border: none;
106
- padding: 10px 20px;
107
- border-radius: 5px;
108
- cursor: pointer;
109
  }
110
  .gradio-button:hover {
111
- background-color: #0056b3;
112
  }
113
  .gradio-label {
114
- color: #007bff;
115
- font-weight: bold;
116
  }
117
  """
118
 
 
3
  import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
 
6
  import os
7
  import time
8
 
 
14
  ])
15
 
16
  # Load the trained ResNet50 model
17
+ class FineTunedResNet(nn.Module):
18
+ def __init__(self, num_classes=4):
19
+ super(FineTunedResNet, self).__init__()
20
+ self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) # Load pre-trained ResNet50
21
+
22
+ # Replace the fully connected layer with more layers and batch normalization
23
+ self.resnet.fc = nn.Sequential(
24
+ nn.Linear(self.resnet.fc.in_features, 1024), # First additional layer
25
+ nn.BatchNorm1d(1024),
26
+ nn.ReLU(),
27
+ nn.Dropout(0.5),
28
+ nn.Linear(1024, 512), # Second additional layer
29
+ nn.BatchNorm1d(512),
30
+ nn.ReLU(),
31
+ nn.Dropout(0.5),
32
+ nn.Linear(512, 256), # Third additional layer
33
+ nn.BatchNorm1d(256),
34
+ nn.ReLU(),
35
+ nn.Dropout(0.5),
36
+ nn.Linear(256, num_classes) # Output layer
37
+ )
38
+
39
+ def forward(self, x):
40
+ return self.resnet(x)
41
+
42
  model = FineTunedResNet(num_classes=4)
43
  model_path = 'models/final_fine_tuned_resnet50.pth'
44
 
 
113
  color: #333333;
114
  font-weight: bold;
115
  font-size: 24px;
116
+ margin-bottom: 10px.
117
  }
118
  .gradio-description {
119
  color: #666666;
120
+ font-size: 16px.
121
+ margin-bottom: 20px.
122
  }
123
  .gradio-image {
124
+ border-radius: 10px.
125
  }
126
  .gradio-button {
127
+ background-color: #007bff.
128
+ color: #ffffff.
129
+ border: none.
130
+ padding: 10px 20px.
131
+ border-radius: 5px.
132
+ cursor: pointer.
133
  }
134
  .gradio-button:hover {
135
+ background-color: #0056b3.
136
  }
137
  .gradio-label {
138
+ color: #007bff.
139
+ font-weight: bold.
140
  }
141
  """
142