from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any, Tuple
import numpy as np
from utils.load_yaml import HpsYaml
from utils.util import AttrDict
import torch

# Constants
# Model folders, relative to the current working directory.
EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"


def _load_model_enum(enum_name: str, models_dir: str, pattern: str, label: str):
    """Build an Enum whose members are the model files found under *models_dir*.

    Each member maps a display name to the ``Path`` of one file matched by the
    recursive glob *pattern*. The display name is the bare file name; when the
    same file name occurs in several sub-folders, the path relative to
    *models_dir* is used for those files instead, because ``Enum`` refuses
    duplicate member names.

    Prints ``"Loaded <label> models: <count>"`` once the Enum is built.

    Raises:
        FileNotFoundError: if *models_dir* is not an existing directory.
    """
    if not os.path.isdir(models_dir):
        raise FileNotFoundError(f"Model folder {models_dir} doesn't exist.")

    root = Path(models_dir)
    files = list(root.glob(pattern))

    # Count how often each bare file name occurs so collisions can be detected.
    name_counts = {}
    for file in files:
        name_counts[file.name] = name_counts.get(file.name, 0) + 1

    members = [
        (
            file.name
            if name_counts[file.name] == 1
            else file.relative_to(root).as_posix(),
            file,
        )
        for file in files
    ]
    model_enum = Enum(enum_name, members)
    print(f"Loaded {label} models: " + str(len(model_enum)))
    return model_enum


# Discovery runs at import time; a missing folder aborts the import.
extractors = _load_model_enum('extractors', EXT_MODELS_DIRT, "**/*.pt", "extractor")
convertors = _load_model_enum('convertors', CONV_MODELS_DIRT, "**/*.pth", "convertor")
encoders = _load_model_enum('encoders', ENC_MODELS_DIRT, "**/*.pt", "encoders")

# String-valued enum of trainable model types; used as the type of `Input.model`.
# Only the PPG2Mel voice-conversion model is defined.
class Model(str, Enum):
    VC_PPG2MEL = "ppg2mel"

# String-valued enum of dataset identifiers.
# NOTE(review): no field of `Input` in the visible source is typed with this
# enum — confirm whether it is still needed.
class Dataset(str, Enum):
    AIDATATANG_200ZH = "aidatatang_200zh"
    AIDATATANG_200ZH_S = "aidatatang_200zh_s"

class Input(BaseModel):
    # Form schema consumed by `train_vc`. Each field carries a Chinese `alias`
    # and `description`; presumably these are rendered as UI labels/help text —
    # confirm against the GUI layer. English glosses are given per field below.

    # Disabled custom input renderer, kept for reference:
    # def render_input_ui(st, input) -> Dict: 
    #     input["selected_dataset"] = st.selectbox(
    #         'Select dataset', 
    #         ("aidatatang_200zh", "aidatatang_200zh_s")
    #     )
    # return input

    # title: "Model type". Defaults to the only defined model, PPG2Mel.
    model: Model = Field(
        Model.VC_PPG2MEL, title="模型类型",
    )
    # Disabled field, kept for reference:
    # datasets_root: str = Field(
    #     ..., alias="Preprocessed data root directory", description="Input directory (relative/absolute); not applicable to the ppg2mel model",
    #     format=True,
    #     example="..\\trainning_data\\"
    # )

    # alias: "Output directory (optional)"; description: "recommended to leave
    # empty and keep the default".
    # NOTE(review): declared with `...` (required) although the label says
    # optional, and `train_vc` does not read it in the visible source — confirm.
    output_root: str = Field(
        ..., alias="输出目录(可选)", description="建议不填,保持默认",
        format=True,
        example=""
    )
    # alias: "Continue-training mode"; description: "choose yes to continue
    # training from the model selected below".
    continue_mode: bool = Field(
        True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
    )
    # alias: "GPU training"; description: "choose yes to train on the GPU".
    gpu: bool = Field(
        True, alias="GPU训练", description="选择“是”,则使用GPU训练",
    )
    # alias: "Print details"; description: "choose yes to output more details".
    verbose: bool = Field(
        True, alias="打印详情", description="选择“是”,输出更多详情",
    )
    # TODO: Move to hidden fields by default
    # alias: "Conversion model"; description: "select the voice-conversion
    # model file". Choices come from the module-level `convertors` enum.
    convertor: convertors = Field(
        ..., alias="转换模型", 
        description="选择语音转换模型文件."
    )
    # alias: "Feature-extraction model"; description: "select the PPG feature
    # extraction model file". Choices come from the `extractors` enum.
    extractor: extractors = Field(
        ..., alias="特征提取模型", 
        description="选择PPG特征提取模型文件."
    )
    # alias: "Speech-encoder model"; description: "select the speech encoder
    # model file". Choices come from the `encoders` enum.
    encoder: encoders = Field(
        ..., alias="语音编码模型", 
        description="选择语音编码模型文件."
    )
    # alias: "Number of processes"; description: "applies to ppg2mel".
    njobs: int = Field(
        8, alias="进程数", description="适用于ppg2mel",
    )
    # alias: "Initial random seed"; description: "applies to ppg2mel".
    seed: int = Field(
        default=0, alias="初始随机数", description="适用于ppg2mel",
    )
    # alias: "New model name"; description: "only takes effect when training
    # from scratch; ignored when continue-training is selected".
    model_name: str = Field(
        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
        example="test"
    )
    # alias: "New model config"; description: same as for `model_name`.
    # NOTE(review): `model_config` is a reserved attribute name in pydantic v2;
    # this file uses the v1 API (`__root__`), so it is a plain field here.
    model_config: str = Field(
        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
    )

# Model pairing raw bytes (`content`) with a `mel` value of unconstrained type.
# NOTE(review): presumably `content` is encoded audio and `mel` its
# mel-spectrogram — confirm; neither is used elsewhere in the visible source.
class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    # Custom-root model: the whole payload is one (str, int) tuple.
    __root__: Tuple[str, int]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Render the result as a sub-header on the given Streamlit app.

        If this method is implemented, it will be used instead of the default
        Output UI renderer. The `input` argument is accepted but not used.
        """
        label, total = self.__root__
        message = f"Dataset {label} done processed total of {total}"
        streamlit_app.subheader(message)

# Run one-shot voice-conversion (PPG2Mel) training from the submitted form.
#
# Parameters:
#   input: validated `Input` form values.
# Returns:
#   `Output` whose root tuple is (run name, 0); the second element is a
#   placeholder until a useful return code exists.
# Raises:
#   FileNotFoundError: in continue mode, when no ``*.yaml`` config exists under
#   the selected convertor checkpoint's folder.
#
# NOTE(review): the docstring below is kept verbatim because GUI frameworks
# that introspect the function may display it — confirm before editing.
def train_vc(input: Input) -> Output:
    """Train VC(训练 VC)"""

    print(">>> OneShot VC training ...")
    params = AttrDict()
    params.update({
        "gpu": input.gpu,
        "cpu": not input.gpu,
        "njobs": input.njobs,
        "seed": input.seed,
        "verbose": input.verbose,
        "load": input.convertor.value,
        "warm_start": False,
    })
    if input.continue_mode:
        # Derive run name, config and folders from the selected checkpoint:
        # <ckpdir>/<run name>/<checkpoint file>
        checkpoint = Path(input.convertor.value)
        run_name = checkpoint.parent.name
        params.name = run_name
        # Use the first yaml config found next to (or below) the checkpoint.
        model_config_fpaths = list(checkpoint.parent.rglob("*.yaml"))
        if not model_config_fpaths:
            raise FileNotFoundError("No model yaml config found for convertor")
        config_fpath = model_config_fpaths[0]
        config = HpsYaml(config_fpath)
        params.ckpdir = checkpoint.parent.parent
        params.config = config_fpath
        params.logdir = os.path.join(checkpoint.parent, "log")
    else:
        # Fresh training: name and config come from the "new model" fields.
        run_name = input.model_name
        params.name = run_name
        params.config = input.model_config
        # Make the config dict dot visitable
        config = HpsYaml(input.model_config)
        # NOTE(review): `ckpdir` and `logdir` are not set on this path, and
        # `load` still points at the selected convertor — confirm what Solver
        # expects when training from scratch.

    # Seed every RNG in use so runs are reproducible.
    np.random.seed(input.seed)
    torch.manual_seed(input.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(input.seed)

    mode = "train"
    # Imported lazily so the training stack is only loaded when training runs.
    from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
    solver = Solver(config, params, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")

    # TODO: pass useful return code
    return Output(__root__=(run_name, 0))