import os
from typing import Any, Dict, List

import fsspec
import numpy as np
import torch
from coqpit import Coqpit

from TTS.config import check_config_and_model_args
from TTS.tts.utils.managers import BaseIDManager


class LanguageManager(BaseIDManager):
    """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
    in a way that can be queried by language.

    The language-name -> integer-id mapping is kept in ``self.name_to_id``.

    Args:
        language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by
        TTS models. Defaults to "".
        config (Coqpit, optional): Coqpit config that contains the language information in the datasets field.
        Defaults to None.

    Examples:
        >>> manager = LanguageManager(language_ids_file_path=language_ids_file_path)
        >>> language_id_mapper = manager.name_to_id
    """

    def __init__(
        self,
        language_ids_file_path: str = "",
        config: Coqpit = None,
    ):
        # The base initializer receives the id file path (presumably it loads the mapping from it
        # when the path is non-empty -- confirm against BaseIDManager).
        super().__init__(id_file_path=language_ids_file_path)

        # A config, when given, takes precedence: it replaces ``name_to_id`` with the mapping
        # derived from the config's datasets.
        if config:
            self.set_language_ids_from_config(config)

    @property
    def num_languages(self) -> int:
        """Number of languages currently registered in ``name_to_id``."""
        return len(self.name_to_id)

    @property
    def language_names(self) -> List:
        """List of the language names (the keys of ``name_to_id``)."""
        return list(self.name_to_id)

    @staticmethod
    def parse_language_ids_from_config(c: Coqpit) -> Dict:
        """Build the language-name -> id mapping from the datasets listed in a config.

        The distinct ``language`` values of ``c.datasets`` are sorted and numbered from 0, so the
        resulting ids do not depend on the order of the datasets.

        Args:
            c (Coqpit): Config

        Returns:
            Dict: Language ID mapping (language name -> integer id).

        Raises:
            ValueError: If a dataset has no ``language`` entry.
        """
        languages = set()
        for dataset in c.datasets:
            if "language" not in dataset:
                raise ValueError(f"Dataset {dataset['name']} has no language specified.")
            languages.add(dataset["language"])
        return {name: i for i, name in enumerate(sorted(languages))}

    def set_language_ids_from_config(self, c: Coqpit) -> None:
        """Set language IDs from config samples.

        Overwrites ``self.name_to_id`` with the mapping parsed from ``c``.

        Args:
            c (Coqpit): Config.
        """
        self.name_to_id = self.parse_language_ids_from_config(c)

    @staticmethod
    def parse_ids_from_data(items: List, parse_key: str) -> Any:
        """Not supported by this manager; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def set_ids_from_data(self, items: List, parse_key: str) -> Any:
        """Not supported by this manager; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def save_ids_to_file(self, file_path: str) -> None:
        """Save language IDs to a json file.

        Args:
            file_path (str): Path to the output file.
        """
        self._save_json(file_path, self.name_to_id)

    @staticmethod
    def init_from_config(config: Coqpit) -> "LanguageManager":
        """Initialize the language manager from a Coqpit config.

        A manager is only created when ``use_language_embedding`` is enabled (as checked by
        ``check_config_and_model_args``). If the config names a ``language_ids_file``, the manager is
        built from that file; otherwise the ids are derived from the config's datasets.

        Args:
            config (Coqpit): Coqpit config.

        Returns:
            LanguageManager: The new manager, or ``None`` when language embedding is not enabled.
        """
        language_manager = None
        if check_config_and_model_args(config, "use_language_embedding", True):
            if config.get("language_ids_file", None):
                language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
            else:
                # Only fall back to the datasets when no ids file is configured; previously this
                # assignment ran unconditionally and discarded the file-based manager.
                language_manager = LanguageManager(config=config)
        return language_manager


def _set_file_path(path):
    """Locate ``language_ids.json`` next to or inside ``path``.

    Two locations are probed, in this order: the parent directory of ``path`` and then ``path``
    itself. This is a band aid for the different paths returned in restored and continued training.

    Returns:
        The first candidate path that exists on the filesystem backing ``path``, or ``None`` if
        neither exists.
    """
    file_name = "language_ids.json"
    candidates = (
        os.path.join(os.path.dirname(path), file_name),  # "restore" location: one level above
        os.path.join(path, file_name),  # "continue" location: directly under path
    )
    filesystem = fsspec.get_mapper(path).fs
    # The generator is lazy, so the second location is only probed when the first is missing.
    return next((candidate for candidate in candidates if filesystem.exists(candidate)), None)


def get_language_balancer_weights(items: list) -> torch.Tensor:
    """Compute per-sample weights that balance the languages in ``items``.

    Each sample is weighted by the inverse of the number of samples sharing its language, and the
    resulting weight vector is scaled to unit L2 norm.

    Args:
        items (list): Samples; each must be indexable with the key ``"language"``.

    Returns:
        torch.Tensor: 1-D float32 tensor with one weight per item, in the order of ``items``.
    """
    language_names = np.array([item["language"] for item in items])
    # A single np.unique call gives, for every sample, the index of its language among the sorted
    # unique names and, for every language, its sample count -- no per-sample list search and no
    # per-language scan of the whole array.
    _, language_ids, language_count = np.unique(language_names, return_inverse=True, return_counts=True)
    weight_language = 1.0 / language_count
    # get weight for each sample
    dataset_samples_weight = weight_language[language_ids]
    # normalize
    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
    return torch.from_numpy(dataset_samples_weight).float()