# This code is modified from https://github.com/ZFTurbo/

import librosa
from tqdm import tqdm
import os
import torch
import numpy as np
import soundfile as sf
import torch.nn as nn

import warnings
warnings.filterwarnings("ignore")
from bs_roformer.bs_roformer import BSRoformer

class BsRoformer_Loader:
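    """Loader that builds a BS-Roformer separation model and runs chunked
    vocal/instrumental separation on audio files."""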
    def get_model_from_config(self):
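        """Construct a BSRoformer with a fixed hyperparameter set matching the
        pretrained checkpoint this loader is intended for."""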
        config = {
            "attn_dropout": 0.1,
            "depth": 12,
            "dim": 512,
            "dim_freqs_in": 1025,
            "dim_head": 64,
            "ff_dropout": 0.1,
            "flash_attn": True,
            "freq_transformer_depth": 1,
            "freqs_per_bands":(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
            "heads": 8,
            "linear_transformer_depth": 0,
            "mask_estimator_depth": 2,
            "multi_stft_hop_size": 147,
            "multi_stft_normalized": False,
            "multi_stft_resolution_loss_weight": 1.0,
            "multi_stft_resolutions_window_sizes":(4096, 2048, 1024, 512, 256),
            "num_stems": 1,
            "stereo": True,
            "stft_hop_length": 441,
            "stft_n_fft": 2048,
            "stft_normalized": False,
            "stft_win_length": 2048,
            "time_transformer_depth": 1,

        }


        model = BSRoformer(
            **dict(config)
        )

        return model
    

    def demix_track(self, model, mix, device):
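        """Run separation on a full mix by slicing it into fixed-size chunks,
        batching the chunks through the model, and overlap-adding the outputs
        with cross-fade windows so chunk borders do not click."""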
        C = 352800             # chunk length in samples (8 s at 44.1 kHz)
        N = 1                  # number of overlapping passes (num_overlap)
        fade_size = C // 10    # cross-fade length at chunk edges
        step = int(C // N)
        border = C - step      # 0 when N == 1 (no overlap)
        batch_size = 4

        length_init = mix.shape[-1]

        progress_bar = tqdm(total=length_init // step + 1)
        progress_bar.set_description("Processing")

        # Reflect-pad both ends so chunks near the borders are estimated as well as interior ones
        if length_init > 2 * border and (border > 0):
            mix = nn.functional.pad(mix, (border, border), mode='reflect')

        # Prepare the fade windows once up front; cross-fading chunk edges avoids clicks at segment boundaries
        window_size = C
        fadein = torch.linspace(0, 1, fade_size)
        fadeout = torch.linspace(1, 0, fade_size)
        window_start = torch.ones(window_size)
        window_middle = torch.ones(window_size)
        window_finish = torch.ones(window_size)
        window_start[-fade_size:] *= fadeout # First audio chunk, no fadein
        window_finish[:fade_size] *= fadein # Last audio chunk, no fadeout
        window_middle[-fade_size:] *= fadeout
        window_middle[:fade_size] *= fadein

        # Inference runs without gradient tracking; autocast enables mixed precision on CUDA.
        # Windowed chunk outputs accumulate into `result`, the window weights into `counter`.
        with torch.amp.autocast('cuda'):
            with torch.inference_mode():
                req_shape = (1, ) + tuple(mix.shape)

                result = torch.zeros(req_shape, dtype=torch.float32)
                counter = torch.zeros(req_shape, dtype=torch.float32)
                i = 0
                batch_data = []
                batch_locations = []
                while i < mix.shape[1]:
                    part = mix[:, i:i + C].to(device)
                    length = part.shape[-1]
                    if length < C:
                        if length > C // 2 + 1:
                            part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
                        else:
                            part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
                    if self.is_half:
                        part = part.half()
                    batch_data.append(part)
                    batch_locations.append((i, length))
                    i += step
                    progress_bar.update(1)

                    if len(batch_data) >= batch_size or (i >= mix.shape[1]):
                        arr = torch.stack(batch_data, dim=0)
                        x = model(arr)

                        for j in range(len(batch_locations)):
                            start, l = batch_locations[j]
                            # Choose the fade window per chunk: the first chunk gets no
                            # fade-in and the last chunk gets no fade-out.
                            if start == 0:
                                window = window_start
                            elif start + step >= mix.shape[1]:
                                window = window_finish
                            else:
                                window = window_middle
                            result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l]
                            counter[..., start:start+l] += window[..., :l]

                        batch_data = []
                        batch_locations = []

                estimated_sources = result / counter
                estimated_sources = estimated_sources.cpu().numpy()
                np.nan_to_num(estimated_sources, copy=False, nan=0.0)

                if length_init > 2 * border and (border > 0):
                    # Remove pad
                    estimated_sources = estimated_sources[..., border:-border]

        progress_bar.close()

        # With num_stems == 1 the model returns a single source, so only 'vocals' is
        # filled here; the instrumental is derived later by subtracting the vocals
        # from the original mix.
        return {k: v for k, v in zip(['vocals', 'other'], estimated_sources)}


    def run_folder(self, input, vocal_root, others_root, format):
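        """Separate a single audio file into vocals and instrumental and write both to disk.

        Despite the name, this method processes one input path, not a folder.
        """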
        self.model.eval()
        path = input

        os.makedirs(vocal_root, exist_ok=True)
        os.makedirs(others_root, exist_ok=True)

        try:
            mix, sr = librosa.load(path, sr=44100, mono=False)
        except Exception as e:
            print('Cannot read track: {}'.format(path))
            print('Error message: {}'.format(str(e)))
            return

        # Convert mono to stereo if needed
        if len(mix.shape) == 1:
            mix = np.stack([mix, mix], axis=0)

        mix_orig = mix.copy()

        mixture = torch.tensor(mix, dtype=torch.float32)
        res = self.demix_track(self.model, mixture, self.device)

        estimates = res['vocals'].T
        
        base_name = os.path.splitext(os.path.basename(path))[0]

        if format in ["wav", "flac"]:
            sf.write("{}/{}_vocals.{}".format(vocal_root, base_name, format), estimates, sr)
            sf.write("{}/{}_instrumental.{}".format(others_root, base_name, format), mix_orig.T - estimates, sr)
        else:
            # Write lossless WAV first, then transcode to the requested format with ffmpeg
            path_vocal = "%s/%s_vocals.wav" % (vocal_root, base_name)
            path_other = "%s/%s_instrumental.wav" % (others_root, base_name)
            sf.write(path_vocal, estimates, sr)
            sf.write(path_other, mix_orig.T - estimates, sr)
            opt_path_vocal = path_vocal[:-4] + ".%s" % format
            opt_path_other = path_other[:-4] + ".%s" % format
            if os.path.exists(path_vocal):
                os.system(
                    "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal)
                )
                if os.path.exists(opt_path_vocal):
                    try:
                        os.remove(path_vocal)
                    except OSError:
                        pass
            if os.path.exists(path_other):
                os.system(
                    "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other)
                )
                if os.path.exists(opt_path_other):
                    try:
                        os.remove(path_other)
                    except OSError:
                        pass

        # print("Elapsed time: {:.2f} sec".format(time.time() - start_time))


    def __init__(self, model_path, device, is_half):
        self.device = device
        self.extract_instrumental = True

        model = self.get_model_from_config()
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        self.is_half = is_half
        if is_half:
            self.model = model.half().to(device)
        else:
            self.model = model.to(device)


    def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):
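        # Thin wrapper kept for interface compatibility: note the (others_root, vocal_root)
        # argument order is the reverse of run_folder's, and is_hp3 is accepted but unused.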
        self.run_folder(input, vocal_root, others_root, format)
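

# ---------------------------------------------------------------------------
# Minimal usage sketch: how the loader is typically driven. The checkpoint
# path, input file, and output folders below are hypothetical placeholders;
# substitute your own.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    use_cuda = torch.cuda.is_available()
    loader = BsRoformer_Loader(
        model_path="path/to/bs_roformer_vocals.ckpt",  # hypothetical checkpoint path
        device="cuda" if use_cuda else "cpu",
        is_half=use_cuda,  # half precision only makes sense on GPU
    )
    loader._path_audio_(
        input="path/to/song.wav",     # hypothetical input file
        others_root="output/other",   # instrumental output folder
        vocal_root="output/vocal",    # vocals output folder
        format="wav",
    )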