#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)

from functools import lru_cache
from typing import Optional, Tuple

import ffmpeg
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from pydub import AudioSegment

from unet import UNet


def load_audio(filename, sample_rate=44100) -> torch.Tensor:
    """Decode an audio file with ffmpeg into a 2-channel float32 tensor.

    The first audio stream reported by ffprobe determines the channel count
    used to de-interleave the raw samples that ffmpeg writes to stdout.

    Args:
      filename:
        Path of the audio file; passed to both ffprobe and ffmpeg.
      sample_rate:
        Sample rate in Hz that ffmpeg is asked to resample to.

    Returns:
      A float32 tensor of shape (num_samples, 2). Mono input is duplicated
      into both channels; for input with more than two channels only the
      first two are kept.

    Raises:
      ValueError: If ffprobe reports no streams, or none of them is audio.
      ffmpeg.Error: If ffprobe or ffmpeg exits with a non-zero status.
    """
    probe = ffmpeg.probe(filename)
    streams = probe.get("streams")
    if not streams:
        raise ValueError("No stream was found with ffprobe")

    # Use a default so that a file without audio gives a clear error instead
    # of a bare StopIteration.
    metadata = next(
        (stream for stream in streams if stream.get("codec_type") == "audio"),
        None,
    )
    if metadata is None:
        raise ValueError("No audio stream was found with ffprobe")
    n_channels = metadata["channels"]

    # run() raises ffmpeg.Error on a non-zero exit status, so a failed decode
    # cannot be mistaken for an empty or truncated waveform.
    # ac pins the output channel count to the value the reshape below relies on.
    buffer, _ = (
        ffmpeg.input(filename)
        .output("pipe:", format="f32le", ar=sample_rate, ac=n_channels)
        .run(capture_stdout=True, capture_stderr=True)
    )

    # np.frombuffer returns a read-only view of the bytes object; astype()
    # makes a writable, native-endian copy that torch can safely wrap.
    samples = (
        np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels).astype(np.float32)
    )
    waveform = torch.from_numpy(samples)

    if n_channels == 1:
        # Duplicate the single channel: (num_samples, 1) -> (num_samples, 2)
        waveform = waveform.tile(1, 2)
    elif n_channels > 2:
        # Keep only the first two channels
        waveform = waveform[:, :2]

    return waveform


def separate(
    vocals: torch.nn.Module,
    accompaniment: torch.nn.Module,
    waveform: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a waveform into a vocals track and an accompaniment track.

    Both models are applied to the magnitude spectrogram of the input. Their
    squared outputs are turned into two ratio masks that sum to one, and each
    mask is applied to the complex STFT before it is inverted.

    Args:
      vocals:
        Module mapping a (num_segments, num_channels, 512, 1024) magnitude
        spectrogram to a tensor of the same shape.
      accompaniment:
        Module with the same input/output contract as ``vocals``.
      waveform:
        A 2-D tensor of shape (num_samples, num_channels).

    Returns:
      A tuple ``(vocals_wave, accompaniment_wave)``. Each is a 2-D tensor of
      shape (num_channels, num_output_samples), where the length is whatever
      ``torch.istft`` produces with its default ``center`` setting.

    Raises:
      ValueError: If ``waveform`` is not 2-D.
    """
    n_fft = 4096
    hop_length = 1024
    # Only the lowest frequency bins are passed to the models.
    num_model_bins = 1024
    # The models consume the spectrogram in fixed-length segments of frames.
    segment_frames = 512
    mask_eps = 1e-10
    # NOTE(review): presumably mirrors Spleeter's window compensation
    # factor; kept as in the original implementation -- TODO confirm.
    output_scale = 2 / 3

    if waveform.dim() != 2:
        raise ValueError(
            "Expected waveform of shape (num_samples, num_channels), "
            f"got {tuple(waveform.shape)}"
        )

    # Build the window once, on the same device as the input, and reuse it
    # for the forward transform and both inverse transforms.
    window = torch.hann_window(n_fft, periodic=True, device=waveform.device)

    # Append n_fft zero samples at the end of the time axis.
    waveform = torch.nn.functional.pad(waveform, (0, 0, 0, n_fft))

    # torch.stft requires a 2-D input of shape (N, T), so we transpose waveform
    stft = torch.stft(
        waveform.t(),
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=False,
        onesided=True,
        return_complex=True,
    )
    # stft: complex tensor of shape (num_channels, n_fft // 2 + 1, num_frames)
    num_bins, num_frames = stft.shape[1], stft.shape[2]

    y = stft.permute(2, 1, 0)[:, :num_model_bins, :].abs()
    # (num_frames, num_model_bins, num_channels)

    # Pad the frame axis up to the next multiple of segment_frames. No padding
    # is added when it already is a multiple, which avoids running both
    # models on an extra all-zero segment whose output would be discarded.
    pad_size = -num_frames % segment_frames
    y = torch.nn.functional.pad(y, (0, 0, 0, 0, 0, pad_size))

    y = y.reshape(-1, segment_frames, y.shape[1], y.shape[2])
    # (num_segments, segment_frames, num_model_bins, num_channels)

    y = y.permute(0, 3, 1, 2)
    # (num_segments, num_channels, segment_frames, num_model_bins)

    vocals_power = vocals(y) ** 2
    accompaniment_power = accompaniment(y) ** 2
    total_power = (vocals_power + accompaniment_power) + mask_eps

    # Ratio masks; the eps terms keep them finite and make them sum to 1.
    masks = (
        (vocals_power + mask_eps / 2) / total_power,
        (accompaniment_power + mask_eps / 2) / total_power,
    )

    ans = []
    for mask in masks:
        # Bins the models never saw get a zero mask.
        mask = torch.nn.functional.pad(mask, (0, num_bins - num_model_bins))
        # (num_segments, num_channels, segment_frames, num_bins)

        mask = mask.permute(0, 2, 3, 1)
        # (num_segments, segment_frames, num_bins, num_channels)

        mask = mask.reshape(-1, mask.shape[2], mask.shape[3])
        # (num_segments * segment_frames, num_bins, num_channels)

        # Drop the frames that were added as segment padding.
        mask = mask[:num_frames]

        mask = mask.permute(2, 1, 0)
        # (num_channels, num_bins, num_frames), same layout as stft

        wave = (
            torch.istft(
                mask * stft,
                n_fft,
                hop_length,
                window=window,
                onesided=True,
            )
            * output_scale
        )
        ans.append(wave)

    return ans[0], ans[1]


@lru_cache(maxsize=10)
def get_file(
    repo_id: str,
    filename: str,
    subfolder: str = "2stems",
) -> str:
    """Fetch a file from a Hugging Face Hub repository.

    Results are memoized (up to 10 distinct argument combinations), so
    repeated calls with the same arguments do not call the hub again.

    Args:
      repo_id:
        Repository identifier passed to ``hf_hub_download``.
      filename:
        Name of the file inside ``subfolder``.
      subfolder:
        Folder of the repository that contains the file.

    Returns:
      Whatever ``hf_hub_download`` returns for the requested file.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)


@lru_cache(maxsize=10)
def load_model(name: str):
    """Build a UNet and load the checkpoint called ``name`` into it.

    The checkpoint is fetched from the "2stems" folder of the
    "csukuangfj/spleeter-torch" repository. Because the function is
    memoized, calling it again with the same ``name`` returns the very same
    module object.

    Args:
      name:
        File name of the checkpoint, e.g. "vocals.pt".

    Returns:
      The UNet in evaluation mode with the loaded weights, on CPU.
    """
    model = UNet()
    checkpoint = get_file("csukuangfj/spleeter-torch", name, subfolder="2stems")

    # NOTE(review): torch.load unpickles the file, so the checkpoint source
    # must be trusted.
    model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
    model.eval()

    return model


def _export_mp3(wave: torch.Tensor, path: str) -> None:
    """Write a (2, num_samples) float waveform to ``path`` as a 128k MP3."""
    # Clamp before the cast: a sample of 1.0 or more scales to 32768 or more,
    # which does not fit in int16 and would otherwise overflow.
    pcm = (wave.t() * 32768).clamp(-32768, 32767).to(torch.int16)

    # tobytes() emits the (num_samples, 2) array in C order, i.e. with the
    # two channels interleaved sample by sample.
    sound = AudioSegment(
        data=pcm.numpy().tobytes(),
        sample_width=2,
        frame_rate=44100,
        channels=2,
    )
    sound.export(path, format="mp3", bitrate="128k")


@torch.no_grad()
def main(filename: str = "./yesterday-once-more-carpenters.mp3"):
    """Separate ``filename`` into vocals and accompaniment MP3 files.

    The results are written to "vocals.mp3" and "accompaniment.mp3" in the
    current working directory.

    Args:
      filename:
        Path of the audio file to separate.
    """
    vocals = load_model("vocals.pt")
    accompaniment = load_model("accompaniment.pt")

    waveform = load_audio(filename)
    assert waveform.shape[1] == 2, waveform.shape

    vocals_wave, accompaniment_wave = separate(vocals, accompaniment, waveform)

    _export_mp3(vocals_wave, "vocals.mp3")
    _export_mp3(accompaniment_wave, "accompaniment.mp3")


# Run the separation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()