import matplotlib.pyplot as plt

from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F

from constants import COLORS
from utils import fig2img


def visualize_prediction(
    pil_img, output_dict, threshold=0.7, id2label=None, display_mask=False, mask=None
):
    """Draw thresholded detection boxes (and optionally a mask) over an image.

    Args:
        pil_img: Image to draw on; passed straight to ``ax.imshow``.
        output_dict: Mapping with ``"scores"``, ``"boxes"`` and ``"labels"``
            entries. Each is indexed with the boolean mask ``scores > threshold``
            and converted with ``.tolist()`` (presumably torch tensors — confirm
            against the caller). Boxes are unpacked as ``(xmin, ymin, xmax, ymax)``.
        threshold: Only detections whose score is strictly greater than this
            value are drawn.
        id2label: Optional mapping from label id to name. Every kept id must be
            a key (a missing id raises ``KeyError``). NOTE(review): the mapped
            names are not rendered; each box is annotated only with its running
            index and score.
        display_mask: When true and ``mask`` is given, overlay the mask.
        mask: Array-like mask; every element > 0 is drawn as foreground.

    Returns:
        The result of ``fig2img`` on the rendered figure.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        ax.imshow(pil_img)
        if display_mask and mask is not None:
            mask_arr = np.asarray(mask)
            # Build the binary overlay as uint8 whatever the input dtype:
            # Image.fromarray rejects e.g. int64/float64 arrays, and a bool
            # array would store 255 as True.
            new_mask = np.zeros(mask_arr.shape, dtype=np.uint8)
            new_mask[mask_arr > 0] = 255
            ax.imshow(Image.fromarray(new_mask), alpha=0.5, cmap="viridis")

        for idx, (score, (xmin, ymin, xmax, ymax)) in enumerate(
            zip(scores, boxes), start=1
        ):
            # Cycle through the palette by index so no detection is ever
            # dropped (a fixed-length `COLORS * 100` list truncates the zip).
            color = COLORS[(idx - 1) % len(COLORS)]
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=color,
                    linewidth=2,
                )
            )
            ax.text(
                xmin,
                ymin,
                f"[{idx}] {score:0.2f}",
                fontsize=8,
                bbox=dict(facecolor="yellow", alpha=0.5),
            )
        ax.axis("off")
        return fig2img(fig)
    finally:
        # Release the figure; otherwise pyplot keeps every call's figure alive.
        plt.close(fig)


def visualize_attention_map(pil_img, attention_map):
    """Overlay each attention head of the last entry of ``attention_map`` on an image.

    Args:
        pil_img: PIL image the head maps are drawn over.
        attention_map: Indexable whose last element is a 4-D tensor indexed as
            ``[batch, head, row, col]``; only batch item 0 is drawn.
            NOTE(review): presumably per-layer spatial attention maps — confirm
            against the producing model.

    Returns:
        The result of ``fig2img`` on a 1 x n_heads figure, one panel per head.
    """
    # .float() so half/bfloat16 maps can be converted to numpy for plotting.
    last_layer = attention_map[-1].detach().cpu().float()
    n_heads = last_layer.shape[1]

    # PIL reports (width, height). Stretch every head map over the same extent
    # imshow uses for the image itself, so the overlay lines up with the image
    # instead of resetting the view to the (smaller) map's own pixel grid.
    width, height = pil_img.size
    extent = (-0.5, width - 0.5, height - 0.5, -0.5)

    # squeeze=False keeps `axes` a 2-D array even when there is a single head.
    fig, axes = plt.subplots(
        nrows=1, ncols=n_heads, figsize=(n_heads * 4, 4), squeeze=False
    )
    try:
        for i, ax in enumerate(axes.flat):
            ax.imshow(pil_img)
            ax.imshow(
                last_layer[0, i].numpy(),
                alpha=0.7,
                cmap="viridis",
                extent=extent,
                # Smooth the upsampled low-resolution map when displayed.
                interpolation="bicubic",
            )
            ax.set_title(f"Head {i+1}")
            ax.axis("off")

        fig.tight_layout()
        return fig2img(fig)
    finally:
        # Release the figure; otherwise pyplot keeps every call's figure alive.
        plt.close(fig)