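"""Gradio demo: binary speech classification with a fine-tuned Wav2Vec2 model.

Loads dysarthria_classifier12.pth, a 2-label sequence-classification head on
facebook/wav2vec2-base-960h, and serves a file-upload interface that returns
the predicted class id.
"""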
import os

import gradio as gr
import soundfile as sf
import torch
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
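# CPU-only inference: the pretrained encoder gets a fresh 2-label
# classification head, whose weights are replaced by the checkpoint below.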
device = torch.device("cpu")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
model_path = "dysarthria_classifier12.pth"
# model_path = '/home/user/app/dysarthria_classifier12.pth'
# Load the fine-tuned weights if the checkpoint sits next to this script.
if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    model.load_state_dict(torch.load(model_path, map_location=device))
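# Assumption: the checkpoint was saved from this same architecture; weights
# with a different head shape would make load_state_dict raise a size mismatch.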
def predict(file_path):
    """Classify an uploaded audio file; returns the predicted class id (0 or 1)."""
    max_length = 100000  # fixed number of samples fed to the model (~6.25 s at 16 kHz)
    model.eval()
    with torch.no_grad():
        wav_data, _ = sf.read(file_path.name)
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)
        # Zero-pad short clips up to max_length; truncate longer ones.
        if max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:max_length]
        input_values = input_values.unsqueeze(0).to(device)  # restore the batch dimension
        logits = model(input_values=input_values).logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()
        return predicted_class_id
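# Note (assumption): the processor is told sampling_rate=16000, so uploads are
# assumed to already be 16 kHz mono. If other sample rates are possible, a
# safer sketch would resample on load, e.g.:
#   import librosa
#   wav_data, _ = librosa.load(file_path.name, sr=16000, mono=True)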
# The "file" input hands predict a tempfile-like object whose .name attribute
# is the path on disk (older Gradio API); the class id is rendered as text.
iface = gr.Interface(fn=predict, inputs="file", outputs="text")
iface.launch()
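# launch() serves the app locally; passing share=True (a standard Gradio
# option) would additionally expose a temporary public URL.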