import torch
import whisper
import torchaudio
import gradio as gr
import torch.nn as nn
from huggingface_hub import hf_hub_download


# Define the same model class used during training
class DialectClassifier(nn.Module):
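    # Simple three-layer feed-forward classifier over flattened Whisper encoder features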
    def __init__(self, input_dim, num_classes):
        super(DialectClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input tensor
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Function to preprocess audio and extract features
def preprocess_audio(file_path, whisper_model, device):
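    # Returns encoder features of shape (1, n_audio_ctx, d_model), e.g. (1, 1500, 1024) for "medium"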
    def load_audio(file_path):
        waveform, sample_rate = torchaudio.load(file_path)
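        # Resample to 16 kHz, the sample rate Whisper expects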
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

        # Convert to single channel (mono) if necessary
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Pad or trim audio to 30 seconds
        desired_length = 16000 * 30  # 30 seconds at 16 kHz
        current_length = waveform.shape[1]

        if current_length < desired_length:
            # Pad with zeros
            padding = desired_length - current_length
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        elif current_length > desired_length:
            # Trim to desired length
            waveform = waveform[:, :desired_length]

        return waveform

    audio = load_audio(file_path)
    audio = whisper.pad_or_trim(audio.flatten())  # Safety pad/trim to Whisper's fixed 30 s input
    mel = whisper.log_mel_spectrogram(audio)  # 80 mel bins x 3000 frames

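    # Run the Whisper encoder (no gradients needed) to extract acoustic features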
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)  # Add batch dimension and move to device
        features = whisper_model.encoder(mel)
    return features

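# Download the trained classifier checkpoint from the Hugging Face Hub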
repo_id = "dipankar53/assamese_dialect_classifier_model"
model_filename = "dialect_classifier_model.pth"
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

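# Mapping from dialect label to class index (must match the order used during training)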
label_to_idx = {"Darrangiya Accent": 0, "Kamrupiya Accent": 1, "Upper Assam": 2, "Nalbaria Accent": 3}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Whisper model whose encoder is used as the feature extractor
whisper_model = whisper.load_model("medium", device=device)

# Infer the classifier input dimension by passing a dummy 30-second log-Mel
# spectrogram (80 mel bins x 3000 frames) through the Whisper encoder
num_classes = len(label_to_idx)
sample_input = torch.randn(1, 80, 3000).to(device)
with torch.no_grad():
    sample_output = whisper_model.encoder(sample_input)
input_dim = sample_output.view(1, -1).shape[1]  # Flatten and get the feature dimension

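# Instantiate the classifier and load the trained weights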
model = DialectClassifier(input_dim, num_classes)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# Function to predict the dialect of a single audio file
def predict_dialect(audio_path):
    try:
        # Preprocess audio and extract features
        features = preprocess_audio(audio_path, whisper_model, device)
        features = features.view(1, -1)  # Flatten features

        # Perform prediction
        with torch.no_grad():
            outputs = model(features)
            _, predicted = torch.max(outputs, 1)

        # Map predicted index back to dialect label
        idx_to_label = {idx: label for label, idx in label_to_idx.items()}
        predicted_label = idx_to_label[predicted.item()]

        return f"Predicted Dialect: {predicted_label}"
    except Exception as e:
        return f"Error: {str(e)}"

# Define Gradio interface
interface = gr.Interface(
    fn=predict_dialect,
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs="text",
    title="Assamese Dialect Prediction",
    description="Upload an Assamese audio file to predict its dialect.",
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()