import os
import uuid
from typing import Union

import torch
from box import Box

from modules import models
from modules.utils.SeedContext import SeedContext


def create_speaker_from_seed(seed):
    """Sample a speaker embedding from the ChatTTS model under ``seed``.

    The model is obtained from ``models.load_chat_tts()`` and sampling happens
    inside ``SeedContext(seed, True)``, presumably so that a given seed always
    yields the same embedding (TODO confirm SeedContext semantics and the
    meaning of its second argument).

    Returns:
        Whatever ``sample_random_speaker()`` returns (the speaker embedding).
    """
    tts_model = models.load_chat_tts()
    with SeedContext(seed, True):
        sampled_embedding = tts_model.sample_random_speaker()
    return sampled_embedding


class Speaker:
    """A voice identity: an embedding (``emb``) plus descriptive metadata.

    Instances are pickled whole with ``torch.save`` and restored with
    :meth:`from_file`. ``seed`` is ``-2`` when the speaker was built from a
    tensor rather than from a seed.
    """

    @staticmethod
    def from_file(file_like):
        """Load a pickled ``Speaker`` from a path or file-like object.

        The loaded object is mapped to the CPU and passed through :meth:`fix`
        so that files written by older versions of this class gain any
        attributes they are missing.
        """
        # NOTE(review): torch.load unpickles arbitrary Python objects, so only
        # load speaker files from trusted sources. Recent PyTorch releases
        # default to ``weights_only=True``, which refuses pickled custom
        # classes such as this one -- confirm against the torch version in use.
        speaker = torch.load(file_like, map_location=torch.device("cpu"))
        speaker.fix()
        return speaker

    @staticmethod
    def from_tensor(tensor):
        """Build an unnamed speaker whose embedding is ``tensor`` (seed ``-2``)."""
        speaker = Speaker(seed_or_tensor=-2)
        speaker.emb = tensor
        return speaker

    @staticmethod
    def from_seed(seed: int):
        """Build a speaker whose embedding is sampled from ``seed``."""
        speaker = Speaker(seed_or_tensor=seed)
        speaker.emb = create_speaker_from_seed(seed)
        return speaker

    def __init__(
        self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe=""
    ):
        """Initialise a speaker with a fresh random id.

        Args:
            seed_or_tensor: Either an int seed -- ``emb`` then starts as
                ``None`` and is filled in afterwards by the caller -- or an
                embedding tensor, in which case ``seed`` is set to ``-2``.
            name: Display name.
            gender: Free-form gender label.
            describe: Free-form description.
        """
        self.id = uuid.uuid4()
        self.seed = -2 if isinstance(seed_or_tensor, torch.Tensor) else seed_or_tensor
        self.name = name
        self.gender = gender
        self.describe = describe
        self.emb = None if isinstance(seed_or_tensor, int) else seed_or_tensor

        # TODO replace emb => tokens
        self.tokens = []

    def to_json(self, with_emb=False):
        """Return the speaker's metadata as a ``Box``.

        Args:
            with_emb: When true and an embedding is present, include it as a
                nested Python list under ``"emb"``; otherwise ``"emb"`` is
                ``None``.
        """
        # A seed-only speaker has ``emb is None``; calling ``.tolist()`` on it
        # would raise AttributeError, so fall back to None instead.
        emb = self.emb.tolist() if with_emb and self.emb is not None else None
        return Box(
            **{
                "id": str(self.id),
                "seed": self.seed,
                "name": self.name,
                "gender": self.gender,
                "describe": self.describe,
                "emb": emb,
            }
        )

    def fix(self):
        """Back-fill attributes missing from instances unpickled from old files.

        Returns:
            ``True`` if at least one attribute was added, ``False`` otherwise.
        """
        # Factories rather than values so that a missing ``id`` gets a fresh
        # UUID and a missing ``tokens`` gets its own (unshared) list.
        defaults = (
            ("id", uuid.uuid4),
            ("seed", lambda: -2),
            ("name", lambda: ""),
            ("gender", lambda: "*"),
            ("describe", lambda: ""),
            ("tokens", list),
        )
        is_update = False
        for attr, make_default in defaults:
            if attr not in self.__dict__:
                setattr(self, attr, make_default())
                is_update = True
        return is_update

    def __hash__(self):
        """Hash by id so speakers can live in sets and dict keys."""
        return hash(str(self.id))

    def __eq__(self, other):
        """Two speakers are equal when their ids match; non-speakers never are."""
        if not isinstance(other, Speaker):
            return False
        return str(self.id) == str(other.id)


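# Each speaker is stored as one embedding file (.pt).
# Managing speakers means managing every speaker file under ./data/speakers/.
# A speaker can be created from a seed.
# The list can be refreshed by re-reading all speaker files.
# All speakers can be listed.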
class SpeakerManager:
    """In-memory index of the speaker files stored in ``speaker_dir``.

    ``self.speakers`` maps each ``*.pt`` file name (relative to
    ``speaker_dir``) to the ``Speaker`` loaded from that file.
    """

    def __init__(self):
        self.speakers = {}
        self.speaker_dir = "./data/speakers/"
        self.refresh_speakers()

    def _path(self, filename):
        """Return the path of ``filename`` inside ``speaker_dir``."""
        return os.path.join(self.speaker_dir, filename)

    def _save(self, speaker, filename):
        """Pickle ``speaker`` to ``filename`` in ``speaker_dir`` (creating the dir)."""
        os.makedirs(self.speaker_dir, exist_ok=True)
        torch.save(speaker, self._path(filename))

    def refresh_speakers(self):
        """Rebuild ``self.speakers`` from the ``*.pt`` files in ``speaker_dir``.

        A missing directory is treated as holding no speakers. Because the
        index is rebuilt from the current directory listing, files deleted
        since the last refresh drop out automatically.
        """
        try:
            file_names = os.listdir(self.speaker_dir)
        except FileNotFoundError:
            file_names = []
        # Build the new index first and assign it in one step, so a failing
        # load leaves the previous index intact rather than half-populated.
        self.speakers = {
            fname: Speaker.from_file(self._path(fname))
            for fname in file_names
            if fname.endswith(".pt")
        }

    def list_speakers(self) -> list[Speaker]:
        """Return all currently indexed speakers."""
        return list(self.speakers.values())

    def create_speaker_from_seed(self, seed, name="", gender="", describe=""):
        """Create, persist and return a speaker whose embedding comes from ``seed``.

        When ``name`` is empty the seed (as a string) is used as both the name
        and the file stem.
        """
        if name == "":
            # ``seed`` is an int; it must be stringified before building the
            # file name below.
            name = str(seed)
        speaker = Speaker(seed, name=name, gender=gender, describe=describe)
        speaker.emb = create_speaker_from_seed(seed)
        self._save(speaker, name + ".pt")
        self.refresh_speakers()
        return speaker

    def create_speaker_from_tensor(
        self, tensor, filename="", name="", gender="", describe=""
    ):
        """Create, persist and return a speaker from an embedding.

        Args:
            tensor: A ``torch.Tensor`` or a (nested) list convertible with
                ``torch.tensor``.
            filename: File stem (without ``.pt``); defaults to ``name``, and to
                the speaker's id when ``name`` is empty too.

        Raises:
            TypeError: If ``tensor`` is neither a tensor nor a list.
        """
        speaker = Speaker(
            seed_or_tensor=-2, name=name, gender=gender, describe=describe
        )
        if isinstance(tensor, torch.Tensor):
            speaker.emb = tensor
        elif isinstance(tensor, list):
            speaker.emb = torch.tensor(tensor)
        else:
            # Previously such input silently produced a speaker with no
            # embedding; fail loudly instead.
            raise TypeError(
                f"tensor must be a torch.Tensor or list, got {type(tensor).__name__}"
            )
        if filename == "":
            filename = name
        if filename == "":
            # Avoid writing a nameless ".pt" file.
            filename = str(speaker.id)
        self._save(speaker, filename + ".pt")
        self.refresh_speakers()
        return speaker

    def get_speaker(self, name) -> Union[Speaker, None]:
        """Return the first speaker whose ``name`` equals ``name``, or None."""
        return next(
            (spk for spk in self.speakers.values() if spk.name == name), None
        )

    def get_speaker_by_id(self, id) -> Union[Speaker, None]:
        """Return the speaker whose id matches ``id`` (compared as str), or None."""
        return next(
            (spk for spk in self.speakers.values() if str(spk.id) == str(id)),
            None,
        )

    def get_speaker_filename(self, id: str):
        """Return the file name holding the speaker with ``id``, or None."""
        return next(
            (
                fname
                for fname, spk in self.speakers.items()
                if str(spk.id) == str(id)
            ),
            None,
        )

    def update_speaker(self, speaker: Speaker):
        """Overwrite the file of an existing speaker (matched by id) and re-index.

        Raises:
            ValueError: If no indexed speaker has ``speaker.id``.
        """
        filename = self.get_speaker_filename(speaker.id)
        if not filename:
            raise ValueError("Speaker not found for update")
        self._save(speaker, filename)
        self.refresh_speakers()
        return speaker

    def save_all(self):
        """Write every indexed speaker back to the file it was loaded from."""
        # Iterate the mapping directly: each speaker goes to its own file, with
        # no per-speaker id lookup.
        for filename, speaker in self.speakers.items():
            self._save(speaker, filename)

    def __len__(self):
        return len(self.speakers)


# Shared module-level manager, created at import time; its constructor
# immediately loads every speaker file from its speaker directory.
speaker_mgr = SpeakerManager()