import json
import os.path
from functools import lru_cache
from typing import Union, List

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download, HfFileSystem

try:
    from typing import Literal
except (ModuleNotFoundError, ImportError):
    from typing_extensions import Literal

from imgutils.data import MultiImagesTyping, load_images, ImageTyping
from imgutils.utils import open_onnx_model

# Module-level Hugging Face Hub filesystem client; used below to glob the
# ``deepghs/ccip_onnx`` repository for the available model directories.
hf_fs = HfFileSystem()


def _normalize(data, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)):
    mean, std = np.asarray(mean), np.asarray(std)
    return (data - mean[:, None, None]) / std[:, None, None]


def _preprocess_image(image: Image.Image, size: int = 384):
    """
    Turn a PIL image into a normalized channel-first array for the CCIP models.

    The image is resized to a ``size`` x ``size`` square with bilinear resampling,
    scaled to ``[0, 1]``, moved to channel-first layout and normalized with
    :func:`_normalize`'s default statistics.

    :param image: Source image; presumably already RGB (callers load with ``mode='RGB'``).
    :param size: Side length of the square the image is resized to.
    :return: The normalized ``(C, size, size)`` array.
    """
    resized = image.resize((size, size), resample=Image.BILINEAR)
    # noinspection PyTypeChecker
    pixels = np.array(resized)
    # HWC -> CHW, then map 8-bit intensities into [0, 1].
    chw = pixels.transpose(2, 0, 1).astype(np.float32) / 255.0
    return _normalize(chw)


@lru_cache()
def _open_feat_model(model):
    """
    Download (if needed) and open the feature-extraction ONNX model of ``model``.

    The file ``{model}/model_feat.onnx`` is fetched from the ``deepghs/ccip_onnx``
    repository; the opened session is memoized per model name.
    """
    model_path = hf_hub_download('deepghs/ccip_onnx', f'{model}/model_feat.onnx')
    return open_onnx_model(model_path)


@lru_cache()
def _open_metric_model(model):
    """
    Download (if needed) and open the pairwise-metric ONNX model of ``model``.

    The file ``{model}/model_metrics.onnx`` is fetched from the ``deepghs/ccip_onnx``
    repository; the opened session is memoized per model name.
    """
    model_path = hf_hub_download('deepghs/ccip_onnx', f'{model}/model_metrics.onnx')
    return open_onnx_model(model_path)


@lru_cache()
def _open_metrics(model):
    """
    Load ``{model}/metrics.json`` from the ``deepghs/ccip_onnx`` repository.

    The parsed object is memoized per model name, so every caller receives the
    same instance. Treat it as read-only, because mutating it would affect
    later calls.

    :param model: Name of the model sub-directory in the repository.
    :return: The parsed JSON content (read by ``ccip_default_threshold`` for its ``'threshold'`` key).
    """
    path = hf_hub_download('deepghs/ccip_onnx', f'{model}/metrics.json')
    # Decode explicitly as UTF-8: without ``encoding`` open() uses the platform
    # locale encoding, which is not UTF-8 on every system (e.g. Windows).
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


@lru_cache()
def _open_cluster_metrics(model):
    """
    Load ``{model}/cluster.json`` from the ``deepghs/ccip_onnx`` repository.

    The parsed object is memoized per model name, so every caller receives the
    same instance. Treat it as read-only, because mutating it would affect
    later calls.

    :param model: Name of the model sub-directory in the repository.
    :return: The parsed JSON content.
    """
    path = hf_hub_download('deepghs/ccip_onnx', f'{model}/cluster.json')
    # Decode explicitly as UTF-8: without ``encoding`` open() uses the platform
    # locale encoding, which is not UTF-8 on every system (e.g. Windows).
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Model used by the public functions below when no ``model`` argument is given.
_DEFAULT_MODEL_NAMES = 'ccip-caformer-24-randaug-pruned'

# Every sub-directory of ``deepghs/ccip_onnx`` that contains a ``model.ckpt`` is
# treated as a model, named after that directory. Building this list queries the
# Hugging Face Hub through ``hf_fs`` when the module is imported.
_VALID_MODEL_NAMES = [
    os.path.basename(os.path.dirname(ckpt_path))
    for ckpt_path in hf_fs.glob('deepghs/ccip_onnx/*/model.ckpt')
]


def ccip_extract_feature(image: ImageTyping, size: int = 384, model: str = _DEFAULT_MODEL_NAMES):
    """
    Compute the CCIP feature vector of the character shown in a single anime image.

    This is a convenience wrapper that runs :func:`ccip_batch_extract_features`
    on a one-element batch and unwraps the result.

    :param image: An anime image containing one character.
    :type image: ImageTyping

    :param size: Side length the image is resized to before inference. (default: ``384``)
    :type size: int

    :param model: Name of the CCIP model to run. (default: ``ccip-caformer-24-randaug-pruned``)
                  The available model names are: ``ccip-caformer-24-randaug-pruned``,
                  ``ccip-caformer-6-randaug-pruned_fp32``, ``ccip-caformer-5_fp32``.
    :type model: str

    :return: The character's feature vector.
    :rtype: numpy.ndarray

    Examples::
        >>> from imgutils.metrics import ccip_extract_feature
        >>>
        >>> feat = ccip_extract_feature('ccip/1.jpg')
        >>> feat.shape, feat.dtype
        ((768,), dtype('float32'))
    """
    features = ccip_batch_extract_features([image], size, model)
    return features[0]


def ccip_batch_extract_features(images: MultiImagesTyping, size: int = 384, model: str = _DEFAULT_MODEL_NAMES):
    """
    Compute CCIP feature vectors for several images in one model invocation.

    Each image is loaded as RGB, preprocessed with :func:`_preprocess_image`, and the
    resulting arrays are stacked into a single ``float32`` batch fed to the model's
    ``input`` tensor. The model's ``output`` tensor is returned.

    :param images: The images to extract features from.
    :type images: MultiImagesTyping

    :param size: Side length each image is resized to before inference. (default: ``384``)
    :type size: int

    :param model: Name of the CCIP model to run. (default: ``ccip-caformer-24-randaug-pruned``)
                  The available model names are: ``ccip-caformer-24-randaug-pruned``,
                  ``ccip-caformer-6-randaug-pruned_fp32``, ``ccip-caformer-5_fp32``.
    :type model: str

    :return: One feature vector per input image.
    :rtype: numpy.ndarray

    Examples::
        >>> from imgutils.metrics import ccip_batch_extract_features
        >>>
        >>> feat = ccip_batch_extract_features(['ccip/1.jpg', 'ccip/2.jpg', 'ccip/6.jpg'])
        >>> feat.shape, feat.dtype
        ((3, 768), dtype('float32'))
    """
    rgb_images = load_images(images, mode='RGB')
    batch = np.stack([_preprocess_image(img, size=size) for img in rgb_images])
    # Normalization yields float64; the ONNX model is fed float32.
    batch = batch.astype(np.float32)
    session = _open_feat_model(model)
    (features,) = session.run(['output'], {'input': batch})
    return features


# A CCIP input: either an image (``ImageTyping``, e.g. a path or image object)
# or a feature vector that has already been extracted.
_FeatureOrImage = Union[ImageTyping, np.ndarray]


def _p_feature(x: _FeatureOrImage, size: int = 384, model: str = _DEFAULT_MODEL_NAMES):
    """
    Resolve ``x`` to a feature vector.

    A ``numpy.ndarray`` is assumed to already be a feature and is returned unchanged.
    Anything else is treated as an image, and its feature is extracted with
    :func:`ccip_extract_feature` using ``size`` and ``model``.
    """
    if not isinstance(x, np.ndarray):
        return ccip_extract_feature(x, size, model)
    return x


def ccip_default_threshold(model: str = _DEFAULT_MODEL_NAMES) -> float:
    """
    Return the recommended difference threshold of a CCIP model.

    The value is the ``threshold`` entry of the model's ``metrics.json`` published
    in the ``deepghs/ccip_onnx`` Hugging Face repository.

    :param model: Name of the CCIP model. (default: ``ccip-caformer-24-randaug-pruned``)
                  The available model names are: ``ccip-caformer-24-randaug-pruned``,
                  ``ccip-caformer-6-randaug-pruned_fp32``, ``ccip-caformer-5_fp32``.
    :type model: str

    :return: The model's default threshold.
    :rtype: float

    Examples::
        >>> from imgutils.metrics import ccip_default_threshold
        >>>
        >>> ccip_default_threshold()
        0.17847511429108218
        >>> ccip_default_threshold('ccip-caformer-6-randaug-pruned_fp32')
        0.1951224011983088
        >>> ccip_default_threshold('ccip-caformer-5_fp32')
        0.18397327797685215
    """
    metrics = _open_metrics(model)
    return metrics['threshold']


def ccip_difference(x: _FeatureOrImage, y: _FeatureOrImage,
                    size: int = 384, model: str = _DEFAULT_MODEL_NAMES) -> float:
    """
    Measure how different the characters in two inputs are.

    Each input may be an image or an already extracted feature vector. The value is
    the off-diagonal entry of the 2x2 matrix from :func:`ccip_batch_differences`.

    :param x: Image or feature vector of the first character.
    :type x: Union[ImageTyping, np.ndarray]

    :param y: Image or feature vector of the second character.
    :type y: Union[ImageTyping, np.ndarray]

    :param size: Side length images are resized to before feature extraction. (default: ``384``)
    :type size: int

    :param model: Name of the CCIP model to run. (default: ``ccip-caformer-24-randaug-pruned``)
                  The available model names are: ``ccip-caformer-24-randaug-pruned``,
                  ``ccip-caformer-6-randaug-pruned_fp32``, ``ccip-caformer-5_fp32``.
    :type model: str

    :return: The difference between the two characters as a Python float.
    :rtype: float

    Examples::
        >>> from imgutils.metrics import ccip_difference
        >>>
        >>> ccip_difference('ccip/1.jpg', 'ccip/2.jpg')  # same character
        0.16583099961280823
        >>>
        >>> # different characters
        >>> ccip_difference('ccip/1.jpg', 'ccip/6.jpg')
        0.42947039008140564
        >>> ccip_difference('ccip/1.jpg', 'ccip/7.jpg')
        0.4037521779537201
        >>> ccip_difference('ccip/2.jpg', 'ccip/6.jpg')
        0.4371533691883087
        >>> ccip_difference('ccip/2.jpg', 'ccip/7.jpg')
        0.40748104453086853
        >>> ccip_difference('ccip/6.jpg', 'ccip/7.jpg')
        0.392294704914093
    """
    pair_matrix = ccip_batch_differences([x, y], size, model)
    # Entry (0, 1) is the x-vs-y difference; .item() converts it to a Python float.
    return pair_matrix[0, 1].item()


def ccip_batch_differences(images: List[_FeatureOrImage],
                           size: int = 384, model: str = _DEFAULT_MODEL_NAMES) -> np.ndarray:
    """
    Compute pairwise differences between several characters.

    Every entry of ``images`` is resolved to a feature vector. Images are extracted
    one at a time via :func:`ccip_extract_feature`, and arrays are used as-is. The
    vectors are stacked into one ``float32`` batch and passed to the model's metric
    network, whose ``output`` tensor is returned.

    :param images: Images or feature vectors of the characters to compare.
    :type images: List[Union[ImageTyping, np.ndarray]]

    :param size: Side length images are resized to before feature extraction. (default: ``384``)
    :type size: int

    :param model: Name of the CCIP model to run. (default: ``ccip-caformer-24-randaug-pruned``)
                  The available model names are: ``ccip-caformer-24-randaug-pruned``,
                  ``ccip-caformer-6-randaug-pruned_fp32``, ``ccip-caformer-5_fp32``.
    :type model: str

    :return: Matrix of pairwise differences (``n x n`` for ``n`` inputs, as in the example below).
    :rtype: np.ndarray

    Examples::
        >>> from imgutils.metrics import ccip_batch_differences
        >>>
        >>> ccip_batch_differences(['ccip/1.jpg', 'ccip/2.jpg', 'ccip/6.jpg', 'ccip/7.jpg'])
        array([[6.5350548e-08, 1.6583106e-01, 4.2947042e-01, 4.0375218e-01],
               [1.6583106e-01, 9.8025822e-08, 4.3715334e-01, 4.0748104e-01],
               [4.2947042e-01, 4.3715334e-01, 3.2675274e-08, 3.9229470e-01],
               [4.0375218e-01, 4.0748104e-01, 3.9229470e-01, 6.5350548e-08]],
              dtype=float32)
    """
    features = [_p_feature(item, size, model) for item in images]
    stacked = np.stack(features).astype(np.float32)
    (diffs,) = _open_metric_model(model).run(['output'], {'input': stacked})
    return diffs