updated
Browse files
app.py
CHANGED
@@ -119,7 +119,62 @@ import torchvision.transforms as transforms
|
|
119 |
from PIL import Image
|
120 |
import json
|
121 |
import os
|
122 |
-
from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
# Load the model
|
125 |
model_path = 'models/leaf_disease_res50_model_epoch_10.pth'
|
|
|
119 |
from PIL import Image
|
120 |
import json
|
121 |
import os
|
122 |
+
# from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
|
123 |
+
|
124 |
+
class ResNet9(ImageClassificationBase):
|
125 |
+
def __init__(self, in_channels, num_diseases):
|
126 |
+
super().__init__()
|
127 |
+
|
128 |
+
self.conv1 = ConvBlock(in_channels, 64)
|
129 |
+
self.conv2 = ConvBlock(64, 128, pool=True)
|
130 |
+
self.res1 = torch.nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))
|
131 |
+
|
132 |
+
self.conv3 = ConvBlock(128, 256, pool=True)
|
133 |
+
self.conv4 = ConvBlock(256, 512, pool=True)
|
134 |
+
self.res2 = torch.nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))
|
135 |
+
|
136 |
+
self.classifier = torch.nn.Sequential(torch.nn.MaxPool2d(4),
|
137 |
+
torch.nn.Flatten(),
|
138 |
+
torch.nn.Linear(512, num_diseases))
|
139 |
+
|
140 |
+
def forward(self, xb):
|
141 |
+
out = self.conv1(xb)
|
142 |
+
out = self.conv2(out)
|
143 |
+
out = self.res1(out) + out
|
144 |
+
out = self.conv3(out)
|
145 |
+
out = self.conv4(out)
|
146 |
+
out = self.res2(out) + out
|
147 |
+
out = self.classifier(out)
|
148 |
+
return out
|
149 |
+
|
150 |
+
CLASS_NAMES = [
|
151 |
+
'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy',
|
152 |
+
'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy',
|
153 |
+
|
154 |
+
]
|
155 |
+
|
156 |
+
def predict_image(image_path, model):
|
157 |
+
transform = transforms.Compose([
|
158 |
+
transforms.Resize((256, 256)),
|
159 |
+
transforms.ToTensor(),
|
160 |
+
])
|
161 |
+
|
162 |
+
img = Image.open(image_path).convert('RGB')
|
163 |
+
img_tensor = transform(img).unsqueeze(0)
|
164 |
+
|
165 |
+
with torch.no_grad():
|
166 |
+
outputs = model(img_tensor)
|
167 |
+
_, predicted = torch.max(outputs, 1)
|
168 |
+
|
169 |
+
return CLASS_NAMES[predicted.item()]
|
170 |
+
|
171 |
+
|
172 |
+
def load_model(model_path):
|
173 |
+
model = torch.load(model_path, map_location=torch.device('cpu'))
|
174 |
+
model.eval()
|
175 |
+
return model
|
176 |
+
|
177 |
+
|
178 |
|
179 |
# Load the model
|
180 |
model_path = 'models/leaf_disease_res50_model_epoch_10.pth'
|