import torch
import numpy as np
import gradio as gr
import yaml
import librosa
from tqdm.auto import tqdm
import spaces

import look2hear.models
from ml_collections import ConfigDict

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_audio(file_path):
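    """Load an audio file at 44.1 kHz, preserving channels.

    Returns a (channels, samples) float tensor ((samples,) for mono) and
    the sample rate.
    """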
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    return torch.from_numpy(audio), samplerate


def get_config(config_path):
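    """Parse a YAML model config into an ml_collections.ConfigDict."""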
    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
        return config


def _getWindowingArray(window_size, fade_size):
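    """Build the per-chunk weighting window for overlap-add stitching.

    Not a crossfade: the head keeps full gain and the last `fade_size`
    samples are zeroed, so each chunk's degraded ending is discarded and
    reconstructed from the next overlapping chunk instead.
    """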
    window = torch.ones(window_size)
    window[-fade_size:] = 0
    return window


description = '''
This is an unofficial Space for Apollo, an audio restoration model: https://github.com/JusperLee/Apollo
'''


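# Load the four Apollo checkpoints and move them to the selected device.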
apollo_config = get_config('configs/apollo.yaml')
apollo_vocal2_config = get_config('configs/config_apollo_vocal.yaml')
apollo_uni_config = get_config('configs/config_apollo_uni.yaml')
apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).to(device)
apollo_vocal = look2hear.models.BaseModel.from_pretrain('weights/apollo_vocal.bin', **apollo_config['model']).to(device)
apollo_vocal2 = look2hear.models.BaseModel.from_pretrain('weights/apollo_vocal2.bin', **apollo_vocal2_config['model']).to(device)
apollo_uni = look2hear.models.BaseModel.from_pretrain('weights/apollo_model_uni.ckpt', **apollo_uni_config['model']).to(device)


models = {
    'apollo': apollo_model,
    'apollo_vocal': apollo_vocal,
    'apollo_vocal2': apollo_vocal2,
    'apollo_uni': apollo_uni
}

choices = [
    ('MP3 restore', 'apollo'),
    ('Apollo vocal', 'apollo_vocal'),
    ('Apollo vocal2', 'apollo_vocal2'),
    ('Apollo universal', 'apollo_uni')
]

@spaces.GPU
def enhance(choice, audio):
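    """Restore `audio` with the selected Apollo model.

    The input is processed in overlapping 10-second chunks whose outputs
    are blended back together with overlap-add.
    """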
    print(choice)
    model = models[choice]
    test_data, samplerate = load_audio(audio)
    C = 10 * samplerate  # chunk size: 10 seconds, in samples
    N = 2  # overlap factor: each sample is covered by up to N chunks
    step = C // N  # hop size between chunk starts (50% overlap)
    fade_size = 3 * samplerate  # zero-weighted chunk tail: 3 seconds
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
    
    border = C - step  # amount of reflect padding applied to each end
    
    # handle mono inputs correctly
    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0) 

    # Reflect-pad both ends so the edge chunks see full context; remember
    # whether padding was added so it can be trimmed off at the end.
    padded = test_data.shape[1] > 2 * border and border > 0
    if padded:
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)

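    # Overlap-add accumulators: `result` sums the windowed model outputs and
    # `counter` sums the window weights, so dividing them normalizes overlaps.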
    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    i = 0
    progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)

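    # Slide a C-sample window across the input in hops of `step` samples.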
    while i < test_data.shape[1]:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            if length > C // 2 + 1:
                # reflect padding requires pad < input length, which holds here
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                # chunk too short to reflect; zero-pad up to the full chunk size
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='constant', value=0)

        # Add a batch dimension, run the model, and drop the batch (and any
        # singleton channel) dimension from the output.
        chunk = part.unsqueeze(0).to(device)
        with torch.no_grad():
            out = model(chunk).squeeze(0).squeeze(0).cpu()

        # Clone the shared window so the edge-chunk adjustments below do not
        # leak into later iterations.
        window = windowingArray.clone()
        if i == 0:  # first chunk: keep the head at full gain (no fade-in)
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # last chunk(s): keep the tail, nothing follows to cover it
            window[-fade_size:] = 1

        result[..., i:i+length] += out[..., :length] * window[..., :length]
        counter[..., i:i+length] += window[..., :length]

        i += step
        progress_bar.update(step)

    progress_bar.close()

    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    # Positions with zero total window weight divide 0/0; replace NaNs with silence.
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Trim the reflect padding added earlier
    if padded:
        final_output = final_output[..., border:-border]
    
    return samplerate, final_output.T  # Gradio expects (samples, channels)


if __name__ == "__main__":
    i = gr.Interface(
        fn=enhance,
        description=description,
        inputs=[
            gr.Dropdown(label="Model", choices=choices, value=choices[0][1]),
            gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=3000, waveform_options={'waveform_progress_color': '#3C82F6'}),
        ],
        outputs=[
            gr.Audio(
                label="Output Audio",
                autoplay=False,
                streaming=False,
                type="numpy",
            ),
        ],
        allow_flagging='never',
        cache_examples=False,
        title='Apollo audio restoration',
    )
    i.queue(max_size=20, default_concurrency_limit=4)
    i.launch(share=False, server_name="0.0.0.0")