import torch
import torch.nn as nn
import torchaudio
import whisper
import gradio as gr
from huggingface_hub import hf_hub_download


# Define the same model class used during training
class DialectClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(DialectClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input tensor
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Function to preprocess audio and extract Whisper encoder features
def preprocess_audio(file_path, whisper_model, device):
    def load_audio(path):
        waveform, sample_rate = torchaudio.load(path)
        # Resample to 16 kHz, the rate Whisper expects
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
        # Convert to single channel (mono) if necessary
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Pad or trim audio to 30 seconds
        desired_length = 16000 * 30  # 30 seconds at 16 kHz
        current_length = waveform.shape[1]
        if current_length < desired_length:
            # Pad with zeros
            padding = desired_length - current_length
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        elif current_length > desired_length:
            # Trim to the desired length
            waveform = waveform[:, :desired_length]
        return waveform

    audio = load_audio(file_path)
    # pad_or_trim guarantees exactly 30 s of samples, matching Whisper's expected input
    audio = whisper.pad_or_trim(audio.flatten())
    mel = whisper.log_mel_spectrogram(audio)
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)  # Add batch dimension and move to device
        features = whisper_model.encoder(mel)
    return features


# Download the trained classifier weights from the Hugging Face Hub
repo_id = "dipankar53/assamese_dialect_classifier_model"
model_filename = "dialect_classifier_model.pth"
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

label_to_idx = {
    "Darrangiya Accent": 0,
    "Kamrupiya Accent": 1,
    "Upper Assam": 2,
    "Nalbaria Accent": 3,
}

# Load the Whisper model on the available device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
whisper_model = whisper.load_model("medium", device=device)

# Load the trained classifier, inferring input_dim from the encoder's output shape
num_classes = len(label_to_idx)
sample_input = torch.randn(1, 80, 3000).to(device)  # 80 mel bins x 3000 frames (30 s)
with torch.no_grad():
    sample_output = whisper_model.encoder(sample_input)
input_dim = sample_output.view(1, -1).shape[1]  # Flatten and get dimension
model = DialectClassifier(input_dim, num_classes)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()


# Function to predict the dialect of a single audio file
def predict_dialect(audio_path):
    try:
        # Preprocess audio and extract features
        features = preprocess_audio(audio_path, whisper_model, device)
        features = features.view(1, -1)  # Flatten features
        # Perform prediction
        with torch.no_grad():
            outputs = model(features)
            _, predicted = torch.max(outputs, 1)
        # Map the predicted index back to its dialect label
        idx_to_label = {idx: label for label, idx in label_to_idx.items()}
        predicted_label = idx_to_label[predicted.item()]
        return f"Predicted Dialect: {predicted_label}"
    except Exception as e:
        return f"Error: {str(e)}"


# Define the Gradio interface
interface = gr.Interface(
    fn=predict_dialect,
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs="text",
    title="Assamese Dialect Prediction",
    description="Upload an Assamese audio file to predict its dialect.",
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()
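
# A minimal local sanity check (sketch): bypass the Gradio UI and call the
# predictor directly on an audio file. "sample.wav" is a hypothetical path used
# only for illustration; it is not part of the repository.
#
#     print(predict_dialect("sample.wav"))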