from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torchvision
from matplotlib import font_manager
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from PIL import Image

# Per-channel normalization statistics, passed by the draw_* helpers to
# get_inverse_normalize_transform to turn normalized dataset images back into
# displayable pixel values.
# NOTE(review): presumably these match the Normalize applied when the dataset
# images were produced upstream — confirm.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Default distance threshold for get_color: distances below
# threshold * (1 - margin) are green, below threshold * (1 + margin) orange,
# otherwise red. Per its name, larger distances presumably suggest the chip
# belongs to a new (unseen) individual.
DISTANCE_THRESHOLD_NEW_INDIVIDUAL = 0.7


def get_inverse_normalize_transform(mean, std):
    """Return a ``Normalize`` transform that undoes ``Normalize(mean, std)``.

    ``Normalize`` maps each channel ``x -> (x - mean) / std``. Building it with
    ``mean' = -mean / std`` and ``std' = 1 / std`` gives
    ``(x + mean / std) * std = x * std + mean``, i.e. the pre-normalization
    values.

    Args:
        mean: Per-channel means used by the forward normalization.
        std: Per-channel standard deviations used by the forward normalization.
    """
    shifted_mean = [-m / s for m, s in zip(mean, std)]
    reciprocal_std = [1 / s for s in std]
    return torchvision.transforms.Normalize(mean=shifted_mean, std=reciprocal_std)


def get_color(
    distance: float,
    distance_threshold_new_individual: float = DISTANCE_THRESHOLD_NEW_INDIVIDUAL,
    margin: float = 0.10,
) -> str:
    """Map a match distance to a traffic-light color name.

    A band of relative width ``margin`` around the threshold is "unsure":

    * ``distance < threshold * (1 - margin)``  -> ``"green"``
    * ``distance < threshold * (1 + margin)``  -> ``"orange"``
    * anything else (including NaN)            -> ``"red"``
    """
    lower_bound = distance_threshold_new_individual * (1.0 - margin)
    upper_bound = distance_threshold_new_individual * (1 + margin)
    # Checked from the low end up so every comparison matches the "<" tests
    # above exactly (NaN fails both and falls through to red).
    if distance < lower_bound:
        return "green"
    if distance < upper_bound:
        return "orange"
    return "red"


def draw_extrated_chip(ax, chip_image) -> None:
    """Show ``chip_image`` on ``ax`` under an "Extracted chip" title, axes hidden.

    (The function name's spelling is kept as-is so existing callers keep working.)
    """
    heading = "Extracted chip"
    ax.set_title(heading)
    ax.set_axis_off()
    ax.imshow(chip_image)


def draw_closest_neighbors(
    fig: Figure,
    gs: GridSpec,
    i_start: int,
    k_closest_neighbors: int,
    indexed_k_nearest_individuals: dict,
) -> None:
    """Draw the overall closest chips, pooled across all individuals, in one row.

    Every neighbor of every individual is collected, sorted by ascending
    ``"distance"`` (stable, so ties keep input order), and the first
    ``k_closest_neighbors`` are drawn left to right in GridSpec row
    ``i_start``. Each title is ``"<bear_id>: <distance>"`` colored by
    ``get_color``.

    Args:
        fig: Figure that receives the new axes.
        gs: GridSpec laid out on ``fig``.
        i_start: Grid row to draw into.
        k_closest_neighbors: Maximum number of chips to draw.
        indexed_k_nearest_individuals: Maps bear_id -> sequence of neighbor
            dicts with at least ``"distance"`` and ``"dataset_image"`` keys.
            ``dataset_image`` is presumably a normalized channel-first torch
            tensor (``.numpy()`` and the ``(1, 2, 0)`` transpose imply it) —
            confirm against the caller.
    """
    inv_normalize = get_inverse_normalize_transform(
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
    )

    # Flatten into one pool, tagging each (copied) neighbor with its bear_id so
    # the caller's dicts are left untouched.
    pool = [
        {**candidate, "bear_id": bear_id}
        for bear_id, candidates in indexed_k_nearest_individuals.items()
        for candidate in candidates
    ]
    pool.sort(key=lambda candidate: candidate["distance"])

    for col, candidate in enumerate(pool[:k_closest_neighbors]):
        ax = fig.add_subplot(gs[i_start, col])
        distance = candidate["distance"]
        channels_first = inv_normalize(candidate["dataset_image"]).numpy()
        # imshow wants channels last.
        image = np.transpose(channels_first, (1, 2, 0))
        ax.set_axis_off()
        ax.set_title(
            label=f"{candidate['bear_id']}: {distance:.2f}",
            color=get_color(distance=distance),
        )
        ax.imshow(image)


def draw_top_k_individuals(
    fig: Figure,
    gs: GridSpec,
    i_start: int,
    i_end: int,
    indexed_k_nearest_individuals: dict,
    bear_ids: list[str],
    indexed_samples: dict,
) -> None:
    """Fill the per-individual section of the grid, one column per bear.

    Row ``i_start`` shows, for each bear in ``bear_ids``, the first (presumably
    nearest) entry of ``indexed_k_nearest_individuals[bear_id]``, titled with
    its distance and colored by ``get_color``. Rows ``i_start + 1`` up to (not
    including) ``i_end`` show that bear's ``indexed_samples`` images in order,
    one per row. A cell gets no axes when the bear has run out of samples or the
    sample path is falsy. Nothing is drawn when ``i_end <= i_start``.

    Args:
        fig: Figure that receives the new axes.
        gs: GridSpec laid out on ``fig``.
        i_start: Grid row for the nearest chip of each individual.
        i_end: Exclusive end row for the sample images.
        indexed_k_nearest_individuals: Maps bear_id -> sequence of neighbor
            dicts with ``"distance"`` and ``"dataset_image"`` keys.
        bear_ids: Individuals to draw, in column order.
        indexed_samples: Maps bear_id -> sequence of image paths openable by
            ``PIL.Image.open``.
    """
    inv_normalize = get_inverse_normalize_transform(
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
    )
    if i_end <= i_start:
        return

    # First row: the closest chip of each individual.
    for col, bear_id in enumerate(bear_ids):
        ax = fig.add_subplot(gs[i_start, col])
        best_match = indexed_k_nearest_individuals[bear_id][0]
        distance = best_match["distance"]
        channels_first = inv_normalize(best_match["dataset_image"]).numpy()
        image = np.transpose(channels_first, (1, 2, 0))
        ax.set_axis_off()
        ax.set_title(
            label=f"{bear_id}: {distance:.2f}",
            color=get_color(distance=distance),
        )
        ax.imshow(image)

    # Remaining rows: sample images of the same individuals.
    for row in range(i_start + 1, i_end):
        sample_index = row - i_start - 1
        for col, bear_id in enumerate(bear_ids):
            samples = indexed_samples[bear_id]
            if sample_index >= len(samples):
                continue
            filepath = samples[sample_index]
            if not filepath:
                continue
            ax = fig.add_subplot(gs[row, col])
            with Image.open(filepath) as image:
                ax.set_axis_off()
                ax.imshow(image)


def _draw_section_heading(
    fig: Figure,
    gs: GridSpec,
    row: int,
    text: str,
    font_properties: font_manager.FontProperties,
) -> None:
    """Draw ``text`` as an axis-less heading spanning every column of ``row``."""
    ax = fig.add_subplot(gs[row, :])
    ax.set_axis_off()
    ax.text(
        y=0.2,
        x=0,
        s=text,
        font_properties=font_properties,
    )


def bearid_ui(
    pil_image_chip: Image.Image,
    indexed_k_nearest_individuals: dict,
    indexed_samples: dict,
    save_filepath: Path,
    k_closest_neighbors: int = 5,
) -> None:
    """Main UI for identifying bears: render the summary figure and save it.

    Grid layout, top to bottom:

    * row 0: the extracted chip being identified;
    * row 1: "Closest faces" heading;
    * row 2: the ``k_closest_neighbors`` closest chips over all individuals;
    * row 3: "Closest N individuals" heading;
    * row 4: the closest chip of each candidate individual;
    * rows 5+: sample images of each candidate individual, one per row.

    Args:
        pil_image_chip: The chip being identified; shown in row 0.
        indexed_k_nearest_individuals: Maps bear_id -> sequence of neighbor
            dicts with at least ``"distance"`` and ``"dataset_image"`` keys.
            Key order sets the column order of the individuals section.
        indexed_samples: Maps bear_id -> sequence of image paths shown below
            that individual's closest chip.
        save_filepath: Destination of the saved figure.
        k_closest_neighbors: Number of chips in the "Closest faces" row.
    """
    # Assumption: the bear_ids are sorted by distance - if that's not something
    # we can rely on, we should just sort
    bear_ids = list(indexed_k_nearest_individuals.keys())

    i_chip = 0
    i_closest_neighbors = 2  # its heading occupies the row just above
    i_top_k_individual = 4  # its heading occupies the row just above

    # Wide enough for the longer of the two chip rows; at least one column so
    # the GridSpec (and the chip cell) stays valid when both are empty.
    ncols = max(len(bear_ids), k_closest_neighbors, 1)

    # Fixed rows up to and including the individuals' closest chips, plus one
    # row per sample of the individual with the most samples. The previous
    # "+ 3" undercounted by two rows, silently dropping the last two samples
    # of every individual. default=0 keeps empty indexed_samples from raising.
    max_samples = max((len(xs) for xs in indexed_samples.values()), default=0)
    nrows = i_top_k_individual + 1 + max_samples

    fig = plt.figure(constrained_layout=True, figsize=(3 * ncols, 3 * nrows))
    try:
        gs = GridSpec(nrows=nrows, ncols=ncols, figure=fig)
        font_properties_section = font_manager.FontProperties(size=35)

        # The chip under identification (row 0 was previously reserved but
        # left empty, and the chip argument was never drawn).
        draw_extrated_chip(
            ax=fig.add_subplot(gs[i_chip, 0]),
            chip_image=pil_image_chip,
        )

        _draw_section_heading(
            fig,
            gs,
            i_closest_neighbors - 1,
            "Closest faces",
            font_properties_section,
        )
        draw_closest_neighbors(
            fig=fig,
            gs=gs,
            i_start=i_closest_neighbors,
            k_closest_neighbors=k_closest_neighbors,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
        )

        # Filling out the grid with top k individuals and random samples
        _draw_section_heading(
            fig,
            gs,
            i_top_k_individual - 1,
            f"Closest {len(bear_ids)} individuals",
            font_properties_section,
        )
        draw_top_k_individuals(
            fig=fig,
            gs=gs,
            i_end=nrows,
            i_start=i_top_k_individual,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
            bear_ids=bear_ids,
            indexed_samples=indexed_samples,
        )

        fig.savefig(save_filepath, bbox_inches="tight")
    finally:
        # Close this specific figure even if drawing/saving fails, so repeated
        # calls do not accumulate open figures in pyplot's global registry.
        plt.close(fig)