import torch
import whisper
import torchaudio
import gradio as gr
import torch.nn as nn
from huggingface_hub import hf_hub_download

# Define the same model class used during training
class DialectClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(DialectClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input tensor
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
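
# Illustrative sanity check (not executed here; the small dimensions below are
# placeholders, not the real feature size used by this app):
#   clf = DialectClassifier(input_dim=16, num_classes=4)
#   logits = clf(torch.randn(2, 16))  # -> shape (2, 4)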


# Function to preprocess audio and extract Whisper encoder features
def preprocess_audio(file_path, whisper_model, device):
    def load_audio(file_path):
        waveform, sample_rate = torchaudio.load(file_path)
        # Resample to 16 kHz if necessary
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
        # Convert to single channel (mono) if necessary
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Pad or trim audio to 30 seconds
        desired_length = 16000 * 30  # 30 seconds at 16 kHz
        current_length = waveform.shape[1]
        if current_length < desired_length:
            # Pad with zeros
            padding = desired_length - current_length
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        elif current_length > desired_length:
            # Trim to desired length
            waveform = waveform[:, :desired_length]
        return waveform

    audio = load_audio(file_path)
    audio = whisper.pad_or_trim(audio.flatten())
    mel = whisper.log_mel_spectrogram(audio).to_dense()
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)  # Add batch dimension and move to device
        features = whisper_model.encoder(mel)
    return features
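
# Example usage (illustrative; "sample.wav" is a placeholder path, not a file in this repo):
#   feats = preprocess_audio("sample.wav", whisper_model, device)
#   print(feats.shape)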


# Download the trained classifier weights from the Hugging Face Hub
repo_id = "dipankar53/assamese_dialect_classifier_model"
model_filename = "dialect_classifier_model.pth"
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

# Label-to-index mapping used during training
label_to_idx = {"Darrangiya Accent": 0, "Kamrupiya Accent": 1, "Upper Assam": 2, "Nalbaria Accent": 3}

# Load the Whisper model used as a feature extractor
whisper_model = whisper.load_model("medium")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Infer the classifier's input dimension from a dummy encoder forward pass, then load the trained classifier
num_classes = len(label_to_idx)
sample_input = torch.randn(1, 80, 3000).to(device)  # Dummy log-mel input: 80 mel bins x 3000 frames (30 s)
with torch.no_grad():
    sample_output = whisper_model.encoder(sample_input)
input_dim = sample_output.view(1, -1).shape[1]  # Flatten and get dimension
model = DialectClassifier(input_dim, num_classes)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
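
# Note: with the Whisper "medium" checkpoint, the dummy forward pass above should yield
# sample_output of shape (1, 1500, 1024), i.e. input_dim = 1500 * 1024 = 1_536_000, which
# must match the dimension the saved classifier weights were trained with.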


# Function to predict the dialect of a single audio file
def predict_dialect(audio_path):
    try:
        # Preprocess audio and extract features
        features = preprocess_audio(audio_path, whisper_model, device)
        features = features.view(1, -1)  # Flatten features

        # Perform prediction
        with torch.no_grad():
            outputs = model(features)
            _, predicted = torch.max(outputs, 1)

        # Map predicted index back to dialect label
        idx_to_label = {idx: label for label, idx in label_to_idx.items()}
        predicted_label = idx_to_label[predicted.item()]
        return f"Predicted Dialect: {predicted_label}"
    except Exception as e:
        return f"Error: {str(e)}"


# Define Gradio interface
interface = gr.Interface(
    fn=predict_dialect,
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs="text",
    title="Assamese Dialect Prediction",
    description="Upload an Assamese audio file to predict its dialect.",
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()