import spaces
import os
import torch
from transformers import AutoFeatureExtractor, WhisperModel, AutoModelForSpeechSeq2Seq
import numpy as np
import torchaudio
import librosa

import gradio as gr
from modules import load_audio, MosPredictor, denorm


# Local file expected to hold the MOSA-Net+ state_dict (it is passed to load_state_dict below).
mos_checkpoint = "ckpt_mosa_net_plus"

print('Loading MOSANET+ checkpoint...')
# Everything is loaded on the CPU at import time; predict_mos() moves the
# models to a GPU per request when one is available.
device = torch.device("cpu")
model = MosPredictor().to(device)
model.load_state_dict(
    torch.load(mos_checkpoint, map_location=device)
)
# Inference-only service: keep the predictor in evaluation mode.
model.eval()

print('Loading Whisper checkpoint...')
# The feature extractor turns audio into log-mel input features; the
# seq2seq model supplies the Whisper encoder used for cross-domain features.
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3")
model_asli = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-large-v3",
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)

@spaces.GPU
def predict_mos(wavefile: str):
    """Predict MOSA-Net+ quality and intelligibility scores for one audio file.

    The audio is read twice:
    - with torchaudio, to get the raw waveform and its STFT magnitude;
    - with ``load_audio``, as input to the Whisper feature extractor.

    The waveform, the spectrogram and the Whisper encoder states are then
    passed to ``MosPredictor``.

    Args:
        wavefile: Path to the uploaded or recorded audio file, as produced by
            ``gr.Audio(type='filepath')``. Gradio passes ``None`` when the
            user submits without audio.

    Returns:
        A text line with the quality score (passed through ``denorm``) and
        the intelligibility score, each formatted to two decimals.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if wavefile is None:
        raise gr.Error("Please upload or record an audio file first.")

    # The models were loaded on the CPU at import time; move them to the GPU
    # when the @spaces.GPU context makes one available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Compare torch.device with torch.device. A plain str never compares
    # equal to a torch.device, so the previous check was always true.
    if model_asli.device != device:
        model_asli.to(device)

    print('Starting prediction...')
    # STFT magnitude of the first channel, shaped (1, 1, frames, 257)
    # (n_fft=512 gives 257 frequency bins).
    # NOTE(review): the file's native sample rate is used as-is (no resampling),
    # and the name "lps" suggests a log-power spectrum although a linear
    # magnitude is computed here. Confirm both against the training pipeline.
    wav = torchaudio.load(wavefile)[0]
    magnitude = np.abs(
        librosa.stft(wav[0].detach().numpy(), n_fft=512, hop_length=256, win_length=512)
    ).T
    lps = torch.from_numpy(np.expand_dims(magnitude, axis=0)).unsqueeze(1)

    # Whisper features: only the encoder's last hidden state is used, so run
    # the encoder alone instead of a full encoder-decoder forward pass (which
    # also computed decoder states and vocabulary logits that were discarded).
    audio = load_audio(wavefile)
    input_features = feature_extractor(audio, return_tensors="pt").input_features
    input_features = input_features.to(device)

    with torch.no_grad():
        whisper_feat = model_asli.get_encoder()(input_features).last_hidden_state

    print('Model features shapes...')
    print(whisper_feat.shape)
    print(wav.shape)
    print(lps.shape)

    # Prediction. This is inference only, so no autograd graph is needed.
    # NOTE(review): the full (possibly multi-channel) waveform is passed here,
    # while the spectrogram uses only channel 0. Confirm the expected input.
    wav = wav.to(device)
    lps = lps.to(device)
    with torch.no_grad():
        quality_out, intell_out, _frame1, _frame2 = model(wav, lps, whisper_feat)
    quality_pred = quality_out.cpu().numpy()[0]
    intell_pred = intell_out.cpu().numpy()[0]

    print("predictions")
    qa_text = f"Quality: {denorm(quality_pred)[0]:.2f}  Intelligibility: {intell_pred[0]:.2f}"
    print(qa_text)
    return qa_text


# HTML/Markdown snippets rendered around the predictor on the Gradio page.
# Markdown joins consecutive source lines with a space, so words must not be
# hyphen-split across lines (that previously rendered as "MOSA- NET").
title = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
    <div
        style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
    > <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
        MOSA-Net Whisper features
    </h1> </div>
</div>
"""

description = """
This is a demo of [MOSA-Net+](https://github.com/dhimasryan/MOSA-Net-Cross-Domain/tree/main/MOSA_Net%2B), an improved version of MOSA-Net
that predicts human-based speech quality and intelligibility. MOSA-Net+ uses Whisper to generate cross-domain features. The model employs a
CNN-BLSTM architecture with an attention mechanism and is trained using a multi-task learning approach to predict subjective listening test
scores.

MOSA-Net+ was tested in the noisy-and-enhanced track of the VoiceMOS Challenge 2023, where it obtained the top-ranked performance among nine systems ([full paper](https://arxiv.org/abs/2309.12766)).
"""

article = """
If this model contributes to your research, please cite the following works:

R. E. Zezario, S. -W. Fu, F. Chen, C. -S. Fuh, H. -M. Wang and Y. Tsao, "Deep Learning-Based Non-Intrusive Multi-Objective Speech Assessment Model With Cross-Domain Features," in IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 54-70, 2023, doi: 10.1109/TASLP.2022.3205757.

R. E. Zezario, Y.-W. Chen, S.-W. Fu, Y. Tsao, H.-M. Wang, C.-S. Fuh, "A Study on Incorporating Whisper for Robust Speech Assessment," IEEE ICME 2024, July 2024 (Top Performance on Track 3 - VoiceMOS Challenge 2023).

Demo contributed by [@wetdog](https://github.com/wetdog)
"""
# Page layout: header and description, the interactive predictor, then the
# citation text.
demo = gr.Blocks()
with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Interface(
        fn=predict_mos,
        inputs=gr.Audio(type='filepath'),
        outputs="text",
        # Gradio documents string values here ("never"/"auto"/"manual");
        # a bool False is only coerced to "never" with a warning.
        allow_flagging="never",
    )
    gr.Markdown(article)

# Cap the number of queued requests for this Space.
demo.queue(max_size=10)
demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)