4darsh-Dev commited on
Commit
c4b9624
·
verified ·
1 Parent(s): 2f686da
Files changed (1) hide show
  1. app.py +56 -1
app.py CHANGED
@@ -119,7 +119,62 @@ import torchvision.transforms as transforms
119
  from PIL import Image
120
  import json
121
  import os
122
- from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  # Load the model
125
  model_path = 'models/leaf_disease_res50_model_epoch_10.pth'
 
119
  from PIL import Image
120
  import json
121
  import os
122
+ # from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
123
+
124
+ class ResNet9(ImageClassificationBase):
125
+ def __init__(self, in_channels, num_diseases):
126
+ super().__init__()
127
+
128
+ self.conv1 = ConvBlock(in_channels, 64)
129
+ self.conv2 = ConvBlock(64, 128, pool=True)
130
+ self.res1 = torch.nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))
131
+
132
+ self.conv3 = ConvBlock(128, 256, pool=True)
133
+ self.conv4 = ConvBlock(256, 512, pool=True)
134
+ self.res2 = torch.nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))
135
+
136
+ self.classifier = torch.nn.Sequential(torch.nn.MaxPool2d(4),
137
+ torch.nn.Flatten(),
138
+ torch.nn.Linear(512, num_diseases))
139
+
140
+ def forward(self, xb):
141
+ out = self.conv1(xb)
142
+ out = self.conv2(out)
143
+ out = self.res1(out) + out
144
+ out = self.conv3(out)
145
+ out = self.conv4(out)
146
+ out = self.res2(out) + out
147
+ out = self.classifier(out)
148
+ return out
149
+
150
+ CLASS_NAMES = [
151
+ 'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy',
152
+ 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy',
153
+
154
+ ]
155
+
156
+ def predict_image(image_path, model):
157
+ transform = transforms.Compose([
158
+ transforms.Resize((256, 256)),
159
+ transforms.ToTensor(),
160
+ ])
161
+
162
+ img = Image.open(image_path).convert('RGB')
163
+ img_tensor = transform(img).unsqueeze(0)
164
+
165
+ with torch.no_grad():
166
+ outputs = model(img_tensor)
167
+ _, predicted = torch.max(outputs, 1)
168
+
169
+ return CLASS_NAMES[predicted.item()]
170
+
171
+
172
+ def load_model(model_path):
173
+ model = torch.load(model_path, map_location=torch.device('cpu'))
174
+ model.eval()
175
+ return model
176
+
177
+
178
 
179
  # Load the model
180
  model_path = 'models/leaf_disease_res50_model_epoch_10.pth'