spycoder commited on
Commit
23804b3
·
1 Parent(s): 0816085

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -62,13 +62,16 @@ def predict(file_path):
62
  logits = model(**inputs).logits
63
  logits = logits.squeeze()
64
  predicted_class_id = torch.argmax(logits, dim=-1).item()
65
-
66
- return predicted_class_id
 
 
 
67
  gr.Interface(
68
  fn=predict,
69
  inputs=[
70
- gr.inputs.Audio(source="microphone", type="filepath", optional=True),
71
- gr.inputs.Audio(source="upload", type="filepath", optional=True),
72
  ],
73
  outputs="text",
74
  title=title,
 
62
  logits = model(**inputs).logits
63
  logits = logits.squeeze()
64
  predicted_class_id = torch.argmax(logits, dim=-1).item()
65
+ if(predicted_class_id==0):
66
+ ans = "no_parkinson"
67
+ else:
68
+ ans = "parkinson"
69
+ return ans
70
  gr.Interface(
71
  fn=predict,
72
  inputs=[
73
+ gr.inputs.Audio(source="microphone", type="file", optional=True),
74
+ gr.inputs.Audio(source="upload", type="file", optional=True),
75
  ],
76
  outputs="text",
77
  title=title,