# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
from tqdm import tqdm
from collections import OrderedDict

from models.tts.base.tts_inferece import TTSInference
from models.tts.fastspeech2.fs2_dataset import FS2TestDataset, FS2TestCollator
from utils.util import load_config
from utils.io import save_audio
from models.tts.fastspeech2.fs2 import FastSpeech2
from models.vocoders.vocoder_inference import synthesis
from pathlib import Path
from processors.phone_extractor import phoneExtractor
from text.text_token_collation import phoneIDCollation
import numpy as np
import json


class FastSpeech2Inference(TTSInference):
    """Inference driver for a FastSpeech2 acoustic model plus a vocoder.

    Two entry points are provided:

    * ``inference_for_batches`` runs the model over ``self.test_dataloader``
      and writes one ``<Uid>.wav`` per utterance into ``args.output_dir``.
    * ``inference_for_single_utterance`` synthesizes ``args.text`` and
      returns the first element of the vocoder output.

    Attributes such as ``self.model``, ``self.test_dataloader``,
    ``self.test_dataset``, ``self.test_batch_size``, ``self.accelerator``,
    ``self.device`` and ``self.exp_dir`` are not set in this class;
    presumably they are provided by ``TTSInference`` -- verify against the
    base class.
    """

    def __init__(self, args, cfg):
        """Store the CLI arguments and config; ``args.mode`` is kept as ``infer_type``."""
        super().__init__(args, cfg)
        self.args = args
        self.cfg = cfg
        self.infer_type = args.mode

    def _build_model(self):
        """Instantiate the FastSpeech2 network from ``self.cfg`` and return it."""
        self.model = FastSpeech2(self.cfg)
        return self.model

    def load_model(self, state_dict):
        """Load weights from ``state_dict["model"]`` into ``self.model``.

        A leading ``"module."`` on a parameter name is stripped before
        loading (presumably left behind by a DataParallel/DDP wrapper at
        training time -- confirm against the training code).
        """
        prefix = "module."
        clean_dict = OrderedDict(
            (key[len(prefix) :] if key.startswith(prefix) else key, value)
            for key, value in state_dict["model"].items()
        )
        self.model.load_state_dict(clean_dict)

    def _build_test_dataset(self):
        """Return the dataset class and collator class used for batch inference."""
        return FS2TestDataset, FS2TestCollator

    @staticmethod
    def _parse_vocoder(vocoder_dir):
        """Locate the latest vocoder checkpoint and load the vocoder config.

        Args:
            vocoder_dir: Directory holding ``*.pt`` checkpoints and an
                ``args.json`` config file.

        Returns:
            A ``(vocoder_cfg, ckpt_path)`` tuple, where ``ckpt_path`` is the
            string path of the checkpoint with the highest step number.

        Raises:
            FileNotFoundError: If ``vocoder_dir`` contains no ``*.pt`` file.
        """
        vocoder_dir = os.path.abspath(vocoder_dir)
        ckpt_list = list(Path(vocoder_dir).glob("*.pt"))
        if not ckpt_list:
            raise FileNotFoundError(
                f"No vocoder checkpoint (*.pt) found in {vocoder_dir}"
            )

        def _step(ckpt):
            # The step is the number after the last "-" in the second-to-last
            # "_"-separated field of the file stem (different from the base
            # class, which uses *int(x.stem)*).
            return int(ckpt.stem.split("_")[-2].split("-")[-1])

        # max() returns the first maximal element, which is the same element a
        # stable descending sort would put at index 0.
        ckpt_path = str(max(ckpt_list, key=_step))
        vocoder_cfg = load_config(
            os.path.join(vocoder_dir, "args.json"), lowercase=True
        )
        return vocoder_cfg, ckpt_path

    @torch.inference_mode()
    def inference_for_batches(self):
        """Synthesize every utterance of the test set and save it as a wav file.

        Predicted mels are first cached as ``<Uid>.pt`` files in
        ``args.output_dir``, then reloaded and passed to the vocoder in one
        call; each cached file is removed once its ``<Uid>.wav`` is written.

        The utterance id of each prediction is looked up positionally in
        ``self.test_dataset.metadata``, so this assumes the dataloader yields
        samples in metadata order -- verify against the dataloader setup.
        """
        output_dir = self.args.output_dir
        metadata = self.test_dataset.metadata

        # Wrap the dataloader itself (not enumerate) so tqdm can see its length.
        for i, batch in enumerate(tqdm(self.test_dataloader)):
            y_pred, mel_lens, _ = self._inference_each_batch(batch)
            offset = i * self.test_batch_size
            # Iterate along the batch dimension; the last batch may be short.
            for j, (mel, mel_len) in enumerate(zip(y_pred, mel_lens)):
                # Drop the padded frames beyond the predicted length.
                mel = mel[: mel_len.item()].detach().cpu()
                uid = metadata[offset + j]["Uid"]
                torch.save(mel, os.path.join(output_dir, f"{uid}.pt"))

        vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
        res = synthesis(
            cfg=vocoder_cfg,
            vocoder_weight_file=vocoder_ckpt,
            n_samples=None,
            pred=[
                torch.load(
                    os.path.join(output_dir, "{}.pt".format(item["Uid"]))
                ).numpy()
                for item in metadata
            ],
        )
        for item, wav in zip(metadata, res):
            uid = item["Uid"]
            save_audio(
                os.path.join(output_dir, f"{uid}.wav"),
                wav.numpy(),
                self.cfg.preprocess.sample_rate,
                add_silence=True,
                turn_up=True,
            )
            # The cached mel is no longer needed once the audio is on disk.
            os.remove(os.path.join(output_dir, f"{uid}.pt"))

    @torch.inference_mode()
    def _inference_each_batch(self, batch_data):
        """Run the model on one collated batch.

        Every value of ``batch_data`` is moved (in place, in the dict) to
        ``self.accelerator.device`` before the forward pass.

        Returns:
            ``(postnet_output, mel_lens, 0)``: the model's ``"postnet_output"``
            entry, its ``"mel_lens"`` entry moved to CPU, and a constant ``0``
            placeholder.
        """
        device = self.accelerator.device
        pitch_control = self.args.pitch_control
        energy_control = self.args.energy_control
        duration_control = self.args.duration_control

        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for key, value in batch_data.items():
            batch_data[key] = value.to(device)

        output = self.model(
            batch_data,
            p_control=pitch_control,
            e_control=energy_control,
            d_control=duration_control,
        )
        pred_res = output["postnet_output"]
        mel_lens = output["mel_lens"].cpu()
        return pred_res, mel_lens, 0

    def inference_for_single_utterance(self):
        """Synthesize ``args.text`` and return the first vocoder output.

        Returns:
            ``audio[0]`` where ``audio`` is the result of ``synthesis`` on the
            model's ``"postnet_output"``.

        Raises:
            FileNotFoundError: If a symbols dictionary is required (phone
                extractor other than ``"lexicon"``) but the file is missing.
            KeyError: If ``args.speaker_name`` is not in the speaker-id map.
        """
        text = self.args.text
        pitch_control = self.args.pitch_control
        energy_control = self.args.energy_control
        duration_control = self.args.duration_control

        # get phone symbol file
        phone_symbol_file = None
        if self.cfg.preprocess.phone_extractor != "lexicon":
            phone_symbol_file = os.path.join(
                self.exp_dir, self.cfg.preprocess.symbols_dict
            )
            # Explicit check instead of assert, which is stripped under -O.
            if not os.path.exists(phone_symbol_file):
                raise FileNotFoundError(
                    f"Phone symbol file not found: {phone_symbol_file}"
                )

        # convert text to phone sequence
        phone_extractor = phoneExtractor(self.cfg)
        phone_seq = phone_extractor.extract_phone(text)  # phone_seq: list

        # convert phone sequence to phone id sequence
        phon_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )
        phone_seq = ["{"] + phone_seq + ["}"]
        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)

        # convert phone id sequence to a tensor
        phone_id_seq = torch.from_numpy(np.array(phone_id_seq))

        # get speaker id if multi-speaker training and use speaker id
        if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
            spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
            with open(spk2id_file, "r") as f:
                spk2id = json.load(f)
            speaker_id = torch.from_numpy(
                np.array([spk2id[self.args.speaker_name]], dtype=np.int32)
            )
        else:
            # Empty placeholder tensor when no speaker id is used.
            speaker_id = torch.Tensor(0).view(-1)

        with torch.no_grad():
            data = {
                "texts": phone_id_seq.to(self.device).unsqueeze(0),
                "text_len": torch.LongTensor([phone_id_seq.size(0)]).to(self.device),
                "spk_id": speaker_id.to(self.device),
            }

            output = self.model(
                data,
                p_control=pitch_control,
                e_control=energy_control,
                d_control=duration_control,
            )
            pred_res = output["postnet_output"]
            vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
            audio = synthesis(
                cfg=vocoder_cfg,
                vocoder_weight_file=vocoder_ckpt,
                n_samples=None,
                pred=pred_res,
            )
        return audio[0]