File size: 3,833 Bytes
73cab25
ba42b9f
 
 
 
73cab25
ba42b9f
 
 
 
 
 
 
 
1c411ce
df4dfab
e1cd816
ba42b9f
5914cfd
2cbb9da
5914cfd
223eb95
5914cfd
 
 
aa1c032
 
b9659a7
aa1c032
b9659a7
33a5bcf
 
 
 
 
 
 
 
aa1c032
 
 
82be3cc
 
 
 
 
aa1c032
82be3cc
 
 
 
 
 
 
b1ac211
82be3cc
 
 
 
aa1c032
 
ab2603c
 
ba42b9f
ab2603c
e99bdfa
55acf81
ab2603c
 
 
 
 
 
 
 
7654888
 
 
 
73cab25
82be3cc
23804b3
 
 
 
ab2603c
5914cfd
 
0816085
acab7b5
 
0816085
5914cfd
 
 
 
5549008
5914cfd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
import torch
import soundfile as sf
import os
import numpy as np

import os
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from collections import Counter

device = torch.device("cpu")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
# facebook/wav2vec2-base-960h is published as a CTC checkpoint, so the 2-label
# sequence-classification head created here presumably starts newly initialised;
# the fine-tuned weights loaded below are what make predictions meaningful.
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
model_path = "dysarthria_classifier12.pth"

if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only ship checkpoints you trust here.
    model.load_state_dict(torch.load(model_path, map_location=device))
else:
    # Previously this case was silent and the app served an untrained head.
    print(f"WARNING: {model_path} not found; the classifier head is untrained and predictions will not be meaningful")

# Inference-only app: disable dropout etc. once at start-up.
model.eval()


# Heading shown at the top of the Gradio page.
# NOTE(review): this says "supranuclear palsy (SP)", but predict() labels results
# "parkinson"/"no_parkinson" and the checkpoint is named "dysarthria_classifier*.pth"
# -- confirm which condition the model actually targets and align the wording.
# NOTE(review): the UI also accepts microphone input, and mp3 decoding through
# soundfile presumably depends on the installed libsndfile version -- verify.
title = "Upload an mp3 file for supranuclear palsy (SP) detection! (Thai Language)"
# Text shown under the title listing the Thai prompts the model was trained on.
# Each line ends with a literal "\n" escape followed by a real newline, which
# presumably makes every prompt render as its own paragraph.
# NOTE(review): the <img> points at the Rick_and_Morty_QA course-demo image and
# looks like a leftover from the template this app was copied from.
description = """
The model was trained on Thai audio recordings with the following sentences, so submit audio recordings for one of these sentences:\n
ชาวไร่ตัดต้นสนทำท่อนซุง\n
ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
เพียงแค่ฝนตกลงที่หน้าต่างในบางครา\n
“อาาาาาาาาาาา”\n
“อีีีีีีีีี”\n
“อาาาา” (ดังขึ้นเรื่อยๆ)\n
“อาา อาาา อาาาาา”\n
<img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
"""

def actualpredict(file_path, max_length=100000):
    """Classify one audio recording and return the predicted class index.

    The waveform is read with ``soundfile``, downmixed to mono if needed,
    normalised by the Wav2Vec2 processor, then zero-padded or truncated to
    exactly ``max_length`` samples before being fed to the model.

    Args:
        file_path: The audio file, given either as a str / ``os.PathLike``
            path (what Gradio's ``type="filepath"`` delivers) or as an object
            exposing its path on ``.name`` (e.g. a temp-file wrapper).
        max_length: Number of samples the model input is padded or truncated
            to. Defaults to 100000, the length ``predict`` has always intended.

    Returns:
        int: Index of the highest-scoring logit (0 or 1 for this 2-label model).
    """
    # Previously only ``file_path.name`` was supported, which fails on the plain
    # str path Gradio passes; accept both forms.
    if isinstance(file_path, (str, os.PathLike)):
        path = file_path
    else:
        path = file_path.name

    wav_data, _ = sf.read(path)
    # sf.read returns a (frames, channels) array for multi-channel files; the
    # Wav2Vec2 processor expects a single mono waveform, so average channels.
    if wav_data.ndim > 1:
        wav_data = wav_data.mean(axis=1)

    model.eval()
    with torch.no_grad():
        # NOTE(review): the file's real sample rate is discarded and 16 kHz is
        # asserted to the processor; non-16 kHz input is not resampled. Confirm
        # this matches how the training data was prepared.
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)

        # Drop the batch dimension so padding/truncation act on the time axis.
        input_values = inputs.input_values.squeeze(0)
        shortfall = max_length - input_values.shape[-1]
        if shortfall > 0:
            padding = torch.zeros((shortfall,), dtype=input_values.dtype)
            input_values = torch.cat([input_values, padding], dim=-1)
        else:
            input_values = input_values[:max_length]
        input_values = input_values.unsqueeze(0).to(device)

        logits = model(input_values=input_values).logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()
    return predicted_class_id


def predict(microphone, file_upload):
    """Gradio callback: classify a microphone recording or an uploaded file.

    Args:
        microphone: Path to the microphone recording, or None if unused.
        file_upload: Path to the uploaded audio file, or None if unused.

    Returns:
        str: An error message when neither input is given. Otherwise the
        label ("no_parkinson" or "parkinson") prefixed by ``" "``, or by a
        warning when both inputs were supplied.
    """
    if microphone is None and file_upload is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    warn_output = " "
    if microphone is not None and file_upload is not None:
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
        )

    # Honour the warning above: the microphone recording takes precedence.
    # (Previously the upload silently overrode it, contradicting the message.)
    file_path = microphone if microphone is not None else file_upload

    predicted_class_id = actualpredict(file_path)
    ans = "no_parkinson" if predicted_class_id == 0 else "parkinson"
    return warn_output + ans
# Build and launch the web UI: two optional audio sources feed predict(), whose
# returned string is shown in a text box.
# NOTE(review): gr.inputs.* is the legacy Gradio 3 component namespace (removed
# in Gradio 4); migrate to gr.Audio(...) if the Gradio version is upgraded.
microphone_input = gr.inputs.Audio(source="microphone", type="filepath", optional=True)
upload_input = gr.inputs.Audio(source="upload", type="filepath", optional=True)

demo = gr.Interface(
    fn=predict,
    inputs=[microphone_input, upload_input],
    outputs="text",
    title=title,
    description=description,
)
demo.launch()

# iface = gr.Interface(fn=predict, inputs="file", outputs="text")
# iface.launch()