# Gradio demo: classify Thai speech recordings with a fine-tuned wav2vec2 sequence-classification model.
import gradio as gr
import torch
import soundfile as sf
import os
import numpy as np
import os
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from collections import Counter
# All inference runs on the CPU.
device = torch.device("cpu")

# Processor and backbone come from the public wav2vec2-base-960h checkpoint. That
# checkpoint is presumably CTC-only, so the 2-label classification head is expected
# to be freshly initialised here and only meaningful once the fine-tuned weights
# below are loaded.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h", num_labels=2
).to(device)

model_path = "dysarthria_classifier12.pth"
# Earlier checkpoints that were used at some point:
# model_path = 'model_weights2.pth'
# model_path = '/home/user/app/dysarthria_classifier10.pth'
if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    # torch.load unpickles the file: only load checkpoints that ship with this app.
    model.load_state_dict(torch.load(model_path, map_location=device))
else:
    # Previously this case was silent, leaving the app running with an untrained head.
    print(
        f"WARNING: fine-tuned weights {model_path} not found; "
        "predictions will come from the untrained classification head."
    )
# Heading passed to gr.Interface(title=...).
# NOTE(review): the title says supranuclear palsy (SP), but `predict` reports
# "parkinson"/"no_parkinson"; confirm which condition the model actually targets.
title = "Upload an mp3 file for supranuclear palsy (SP) detection! (Thai Language)"
# Text passed to gr.Interface(description=...): the prompts users should record.
# English glosses of the Thai prompts, in order:
#   - "The farmer cuts down the pine tree to make logs."
#   - "The blue crab runs back and forth on the leaves." (emphasises lip use)
#   - "The crow keeps carrying snakes and chickens in its beak." (emphasises palate use)
#   - "Just the rain falling on the window now and then."
#   - a sustained "aaa", a sustained "eee", an "aaa" that gets gradually louder,
#     and "ah" repeated with increasing length.
# NOTE(review): the trailing <img> points at an unrelated course-demo image
# (Rick_and_Morty_QA); looks like a template leftover -- confirm before keeping.
description = """
The model was trained on Thai audio recordings with the following sentences, so submit audio recordings for one of these sentences:\n
ชาวไร่ตัดต้นสนทำท่อนซุง\n
ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
เพียงแค่ฝนตกลงที่หน้าต่างในบางครา\n
“อาาาาาาาาาาา”\n
“อีีีีีีีีี”\n
“อาาาา” (ดังขึ้นเรื่อยๆ)\n
“อาา อาาา อาาาาา”\n
<img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
"""
def actualpredict(file_path, max_length=100000):
    """Classify one audio file with the module-level model and return the class id.

    Args:
        file_path: Path to the audio file, either a ``str``/path-like (what Gradio's
            ``type="filepath"`` provides) or an object exposing a ``.name`` path
            attribute (e.g. a tempfile wrapper).
        max_length: Number of samples the processed waveform is zero-padded or
            truncated to before inference. Defaults to 100000, the value the app
            previously set (unreachably) inside ``predict``.

    Returns:
        int: Index of the highest-scoring logit (0 or 1 for the ``num_labels=2``
        model loaded at module level).
    """
    # Accept both plain paths and file-like objects carrying a .name attribute.
    path = getattr(file_path, "name", file_path)
    model.eval()
    with torch.no_grad():
        wav_data, _ = sf.read(path)
        # soundfile returns (frames, channels) for multi-channel audio; the processor
        # would treat that as a batch, so downmix to a single 1-D mono waveform.
        if wav_data.ndim > 1:
            wav_data = wav_data.mean(axis=1)
        # NOTE(review): the file's native sample rate (discarded above) is never
        # checked -- audio not recorded at 16 kHz is fed in as if it were. Confirm
        # against how the training data was prepared before adding resampling.
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)

        # Fixed-length model input: zero-pad at the end, or truncate.
        shortfall = max_length - input_values.shape[-1]
        if shortfall > 0:
            padding = torch.zeros((shortfall,), dtype=input_values.dtype)
            input_values = torch.cat([input_values, padding], dim=-1)
        else:
            input_values = input_values[..., :max_length]

        input_values = input_values.unsqueeze(0).to(device)
        logits = model(input_values=input_values).logits.squeeze()
        return torch.argmax(logits, dim=-1).item()
def predict(microphone, file_upload):
    """Gradio callback: classify a microphone recording or an uploaded audio file.

    Args:
        microphone: Filepath of the microphone recording, or ``None``.
        file_upload: Filepath of the uploaded audio file, or ``None``.

    Returns:
        str: ``"no_parkinson"`` or ``"parkinson"``, prefixed by a warning when both
        inputs were supplied (otherwise by a single space, as before); or an error
        message when neither input was supplied.
    """
    if microphone is None and file_upload is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    warn_output = " "
    if microphone is not None and file_upload is not None:
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
        )

    # Honour the warning above: the microphone recording takes precedence.
    # (Previously the upload silently overrode it.)
    file_path = microphone if microphone is not None else file_upload

    predicted_class_id = actualpredict(file_path)
    # NOTE(review): these labels say Parkinson's while the UI title says
    # supranuclear palsy; confirm which condition class 1 represents.
    ans = "no_parkinson" if predicted_class_id == 0 else "parkinson"
    return warn_output + ans
# Build the web UI: two optional audio sources feed `predict`, which returns text.
# NOTE(review): `gr.inputs.*` with `source=`/`optional=` is the legacy Gradio API
# (deprecated in 3.x, removed in 4.x); pin the Gradio version or migrate.
mic_input = gr.inputs.Audio(source="microphone", type="filepath", optional=True)
upload_input = gr.inputs.Audio(source="upload", type="filepath", optional=True)

demo = gr.Interface(
    fn=predict,
    inputs=[mic_input, upload_input],
    outputs="text",
    title=title,
    description=description,
)
demo.launch()