mmek commited on
Commit
77c1868
·
1 Parent(s): 953be88

add model transforms

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -14,8 +14,11 @@ categories = ("Aculus Olearius", "Healthy", "Peacock Spot")
14
  def classify_health(input_img):
15
  input_img = transforms.ToTensor()(input_img)
16
  with torch.no_grad():
17
- image = data_transforms(input_img).unsqueeze(0)
18
- probs = torch.nn.functional.softmax(model(image), dim=0)
 
 
 
19
  idx = probs.argmax(dim=1)
20
  return dict(zip(categories, map(float, probs[0])))
21
 
 
14
  def classify_health(input_img):
15
  input_img = transforms.ToTensor()(input_img)
16
  with torch.no_grad():
17
+ image = data_transforms(input_img).unsqueeze(0)
18
+ output = model(image)
19
+ print(output)
20
+ print(output.shape)
21
+ probs = torch.nn.functional.softmax(model(image)[0], dim=0)
22
  idx = probs.argmax(dim=1)
23
  return dict(zip(categories, map(float, probs[0])))
24