spycoder committed on
Commit
82be3cc
·
1 Parent(s): b9659a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -16
app.py CHANGED
@@ -39,8 +39,24 @@ The model was trained on Thai audio recordings with the following sentences, so
39
  <img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
40
  """
41
 
 
 
 
 
 
42
 
 
 
 
 
 
 
 
43
 
 
 
 
 
44
 
45
 
46
  def predict(microphone,file_upload):
@@ -62,22 +78,7 @@ def predict(microphone,file_upload):
62
  if(file_upload is not None):
63
  file_path = file_upload
64
 
65
- model.eval()
66
- with torch.no_grad():
67
- wav_data, _ = sf.read(file_path.name)
68
- inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
69
-
70
- input_values = inputs.input_values.squeeze(0)
71
- if max_length - input_values.shape[-1] > 0:
72
- input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
73
- else:
74
- input_values = input_values[:max_length]
75
- input_values = input_values.unsqueeze(0).to(device)
76
- inputs = {"input_values": input_values}
77
-
78
- logits = model(**inputs).logits
79
- logits = logits.squeeze()
80
- predicted_class_id = torch.argmax(logits, dim=-1).item()
81
  if(predicted_class_id==0):
82
  ans = "no_parkinson"
83
  else:
 
39
  <img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
40
  """
41
 
42
def actualpredict(file_path):
    """Run the classifier on one audio file and return the predicted class id.

    Args:
        file_path: Either a path string or an object exposing the path on a
            ``.name`` attribute (e.g. an uploaded temp-file wrapper).

    Returns:
        The integer index of the highest-scoring logit. The caller maps
        ``0`` to ``"no_parkinson"``.

    Relies on the module-level ``model``, ``processor``, ``max_length`` and
    ``device``.
    """
    # Plain strings have no ``.name``; every other input keeps the original
    # ``.name`` lookup.
    audio_path = file_path if isinstance(file_path, str) else file_path.name

    # NOTE(review): the file's own sample rate is discarded and 16 kHz is
    # assumed below; recordings at another rate are not resampled -- confirm
    # that inputs are always 16 kHz.
    wav_data, _ = sf.read(audio_path)

    # soundfile returns (frames, channels) for multi-channel audio; the
    # padding/truncation below needs a 1-D waveform, so downmix to mono.
    if wav_data.ndim > 1:
        wav_data = wav_data.mean(axis=1)

    inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
    input_values = inputs.input_values.squeeze(0)

    # Force the waveform to exactly ``max_length`` samples: zero-pad short
    # clips on the right, cut long ones.
    missing = max_length - input_values.shape[-1]
    if missing > 0:
        # new_zeros keeps the padding on the same dtype/device as the signal.
        input_values = torch.cat([input_values, input_values.new_zeros((missing,))], dim=-1)
    else:
        # Slice the same (last) axis that the length check above measured.
        input_values = input_values[..., :max_length]
    input_values = input_values.unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(input_values=input_values).logits
    return torch.argmax(logits.squeeze(), dim=-1).item()
60
 
61
 
62
  def predict(microphone,file_upload):
 
78
  if(file_upload is not None):
79
  file_path = file_upload
80
 
81
+ predicted_class_id = actualpredict(file_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  if(predicted_class_id==0):
83
  ans = "no_parkinson"
84
  else: