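"""Gradio app for Assamese dialect classification.

Extracts features from 30-second audio clips with the Whisper "medium"
encoder and classifies them into four Assamese dialects using a small
feed-forward network whose weights are downloaded from the Hugging Face Hub.
"""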
import torch
import whisper
import torchaudio
import gradio as gr
import torch.nn as nn
from huggingface_hub import hf_hub_download


class DialectClassifier(nn.Module):
    """Feed-forward classifier over flattened Whisper encoder features."""

    def __init__(self, input_dim, num_classes):
        super(DialectClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Flatten the encoder output to (batch, features) before the dense layers.
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def preprocess_audio(file_path, whisper_model, device):
    """Load an audio file and return Whisper encoder features for it."""

    def load_audio(file_path):
        waveform, sample_rate = torchaudio.load(file_path)
        # Whisper expects 16 kHz mono input.
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Pad or trim to exactly 30 seconds, Whisper's fixed input window.
        desired_length = 16000 * 30
        current_length = waveform.shape[1]
        if current_length < desired_length:
            padding = desired_length - current_length
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        elif current_length > desired_length:
            waveform = waveform[:, :desired_length]

        return waveform

    audio = load_audio(file_path)
    audio = whisper.pad_or_trim(audio.flatten())
    mel = whisper.log_mel_spectrogram(audio)

    # Run the Whisper encoder without gradients to obtain acoustic features.
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)
        features = whisper_model.encoder(mel)
    return features


# Download the trained classifier weights from the Hugging Face Hub.
repo_id = "dipankar53/assamese_dialect_classifier_model"
model_filename = "dialect_classifier_model.pth"
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

# Mapping from dialect label to the class index used by the classifier.
label_to_idx = {
    "Darrangiya Accent": 0,
    "Kamrupiya Accent": 1,
    "Upper Assam": 2,
    "Nalbaria Accent": 3,
}

# Whisper "medium" serves as the feature extractor.
whisper_model = whisper.load_model("medium")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Infer the flattened feature dimension by passing a dummy 30-second
# log-mel spectrogram (80 mel bins x 3000 frames) through the encoder.
num_classes = len(label_to_idx)
sample_input = torch.randn(1, 80, 3000).to(device)
with torch.no_grad():
    sample_output = whisper_model.encoder(sample_input)
input_dim = sample_output.view(1, -1).shape[1]

model = DialectClassifier(input_dim, num_classes)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()


def predict_dialect(audio_path):
    """Gradio callback: return the predicted dialect for an uploaded audio file."""
    try:
        features = preprocess_audio(audio_path, whisper_model, device)
        features = features.view(1, -1)

        with torch.no_grad():
            outputs = model(features)
            _, predicted = torch.max(outputs, 1)

        # Map the predicted class index back to its dialect label.
        idx_to_label = {idx: label for label, idx in label_to_idx.items()}
        predicted_label = idx_to_label[predicted.item()]

        return f"Predicted Dialect: {predicted_label}"
    except Exception as e:
        return f"Error: {str(e)}"


interface = gr.Interface(
    fn=predict_dialect,
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs="text",
    title="Assamese Dialect Prediction",
    description="Upload an Assamese audio file to predict its dialect.",
)


if __name__ == "__main__":
    interface.launch()
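    # Note: launch() serves the app locally by default; launch(share=True)
    # would additionally create a temporary public Gradio link if desired.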