File size: 1,780 Bytes
73cab25
ba42b9f
 
 
 
73cab25
ba42b9f
 
 
 
 
 
 
 
1c411ce
df4dfab
e1cd816
ba42b9f
 
81c7e64
223eb95
 
7d73eec
 
 
ba42b9f
 
73cab25
ba42b9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import torch
import soundfile as sf
import os
import numpy as np

import os
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from collections import Counter

# Inference runs on CPU only.
device = torch.device("cpu")

# Processor and base architecture come from the public wav2vec2-base-960h
# checkpoint with a 2-way sequence-classification head; every weight is then
# overwritten by the fine-tuned state dict below (load_state_dict is strict by
# default, so any missing/unexpected key raises).
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)

# Fine-tuned weights, resolved relative to the current working directory.
# Use an absolute path (e.g. '/home/user/app/dysarthria_classifier12.pth') if
# the app is started from a different directory.
model_path = "dysarthria_classifier12.pth"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Acceptable for this bundled, trusted checkpoint; never point model_path at
# user-supplied files. On torch >= 1.13 consider passing weights_only=True.
model.load_state_dict(torch.load(model_path, map_location=device))
# Inference only: switch to eval mode once at startup (disables dropout).
model.eval()
def predict(file_path):
    """Classify an uploaded recording with the dysarthria classifier.

    The audio is read from disk, mixed down to mono, resampled to the 16 kHz
    rate the processor is told to expect, run through ``processor``, then
    zero-padded or truncated to exactly ``max_length`` samples before being
    fed to ``model``.

    Args:
        file_path: The value Gradio passes for a ``"file"`` input. Older
            Gradio versions pass a temp-file wrapper exposing ``.name``;
            newer versions pass the path as a plain ``str``. Both work.

    Returns:
        int: Index of the highest-scoring class (the model is built with
        ``num_labels=2``, so 0 or 1).

    Raises:
        ValueError: If no file was provided or the file has no samples.
    """
    max_length = 100000  # fixed model input length in samples (6.25 s at 16 kHz)
    target_rate = 16000

    if file_path is None:
        raise ValueError("No audio file was provided.")
    # Gradio < 4 hands over a tempfile wrapper; Gradio >= 4 hands over a str.
    path = getattr(file_path, "name", file_path)

    wav_data, sample_rate = sf.read(path)
    if wav_data.size == 0:
        raise ValueError("The uploaded audio file contains no samples.")
    # sf.read returns (frames, channels) for multi-channel files; the processor
    # would treat that 2-D array as a batch, so average channels to mono.
    if wav_data.ndim > 1:
        wav_data = wav_data.mean(axis=1)
    # The processor is told the audio is 16 kHz, so resample anything else.
    # Linear interpolation without an anti-alias filter: approximate, but it
    # keeps the time scale correct instead of silently speeding/slowing audio.
    if sample_rate != target_rate:
        n_out = max(1, int(round(len(wav_data) * target_rate / sample_rate)))
        src_t = np.arange(len(wav_data)) / sample_rate
        dst_t = np.arange(n_out) / target_rate
        wav_data = np.interp(dst_t, src_t, wav_data)

    model.eval()
    with torch.no_grad():
        inputs = processor(wav_data, sampling_rate=target_rate, return_tensors="pt", padding=True)

        input_values = inputs.input_values.squeeze(0)  # drop the batch dim of 1
        pad = max_length - input_values.shape[-1]
        if pad > 0:
            # new_zeros matches the dtype/device of the processed audio.
            input_values = torch.cat([input_values, input_values.new_zeros(pad)], dim=-1)
        else:
            input_values = input_values[..., :max_length]
        input_values = input_values.unsqueeze(0).to(device)  # restore batch dim

        logits = model(input_values=input_values).logits
        predicted_class_id = torch.argmax(logits.squeeze(), dim=-1).item()

    return predicted_class_id

# Web UI: one file-upload input, predicted class id shown as text.
iface = gr.Interface(predict, "file", "text")
iface.launch()