"""
Util functions to process any incoming audio data to be processable by the model 
"""
import os
import torch
import torchaudio
# import wget
import requests

DEFAULT_SAMPLE_RATE=48000
DEFAULT_WAVE_LENGTH=3

def process_from_url(url):
    # download UI audio
    req_url = requests.get(url)

    with open('temp.wav', 'wb') as file:
        file.write(req_url.content)

    
    # filename = 'temp.wav'
    # audio = torchaudio.load(filename)

    # # remove wget file
    # os.remove(filename)

    # spec
    spec = process_from_filename('temp.wav')

    os.remove('temp.wav')
    return spec


def process_from_filename(filename, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=DEFAULT_WAVE_LENGTH):
    wav, sample_rate = torchaudio.load(filename)

    wav = process_raw_wav(wav, sample_rate, target_sample_rate, wav_length)

    spec = _wav_to_spec(wav, target_sample_rate)

    return spec

def process_raw_wav(wav, sample_rate, target_sample_rate, wav_length):
    num_samples = wav_length * target_sample_rate

    wav = _resample(wav, sample_rate, target_sample_rate)
    wav = _mix_down(wav)
    wav = _cut(wav, num_samples)
    wav = _pad(wav, num_samples)

    return wav

def _wav_to_spec(wav, target_sample_rate):
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=2048,
        hop_length=512,
        n_mels=128,
    )

    return mel_spectrogram(wav)

def _resample(wav, sample_rate, target_sample_rate):
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
        wav = resampler(wav)
    
    return wav

def _mix_down(wav):
    if wav.shape[0] > 1:
        wav = torch.mean(wav, dim=0, keepdim=True)
    
    return wav

def _cut(wav, num_samples):
    if wav.shape[1] > num_samples:
        wav = wav[:, :num_samples]
    
    return wav

def _pad(wav, num_samples):
    if wav.shape[1] < num_samples:
        missing_samples = num_samples - wav.shape[1]
        pad = (0, missing_samples)
        wav = torch.nn.functional.pad(wav, pad)
    
    return wav