import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

import pandas as pd
import matplotlib.patches as patches
import numpy as np
from PIL import Image
from zipfile import ZipFile
import gradio as gr

class SampleClass:
    """Browse per-sample predictions of two models ('baseline' and 'extended').

    Loads the test/val prediction tables and the image archive once, then
    serves one sample at a time for a Gradio UI: the image with ground-truth
    and predicted boxes drawn on it, a text summary, and the matching
    segmentation image from the archive.
    """

    # Annotators among whom the filtered samples are split, in this order.
    USERS = ('luciana', 'mauri', 'jorge', 'nano')

    def __init__(self):
        self.test_df = pd.read_json("data/full_pred_test_w_plurals_w_iou.json")
        self.val_df = pd.read_json("data/full_pred_val_w_plurals_w_iou.json")
        # Kept open for the object's lifetime; members are read on demand.
        self.zip_file = ZipFile("data/saiapr_tc-12.zip", 'r')
        # Last table filtered by explorateSamples (kept for callers that read it).
        self.filtered_df = None

    def __get(self, img_path):
        """Return the image stored at ``img_path`` inside the zip archive.

        Pixel data is loaded eagerly so the archive member can be closed right
        away instead of leaving an open handle behind on every request.
        """
        with self.zip_file.open(img_path) as member:
            img = Image.open(member)
            img.load()
        return img

    def __loadPredictions(self, split, model):
        """Return the prediction table for ``split`` with columns aligned to ``model``.

        The chosen model's ``<model>_hit``/``_pred``/``_iou`` columns become
        ``hit``/``predictions``/``iou``; the other model's become
        ``hit_other``/``predictions_other``/``iou_other``. The stored tables
        are left untouched (``rename`` returns a new frame).

        Raises:
            ValueError: for an unknown ``split`` or ``model``. These are explicit
                checks rather than ``assert`` so they still run under ``python -O``.
        """
        tables = {'test': self.test_df, 'val': self.val_df}
        if split not in tables:
            raise ValueError(f"Unknown split {split!r}; expected 'test' or 'val'")
        if model not in ('baseline', 'extended'):
            raise ValueError(f"Unknown model {model!r}; expected 'baseline' or 'extended'")

        other = 'extended' if model == 'baseline' else 'baseline'
        return tables[split].rename(columns={
            f'{model}_hit': 'hit',
            f'{model}_pred': 'predictions',
            f'{model}_iou': 'iou',
            f'{other}_hit': 'hit_other',
            f'{other}_pred': 'predictions_other',
            f'{other}_iou': 'iou_other',
        })

    @staticmethod
    def __navigate(all_ids, username, current_id, users=USERS):
        """Pick the sample to show for ``username`` and the one after it.

        ``all_ids`` is split into ``len(users)`` contiguous, near-equal parts
        (``np.array_split``); the i-th part belongs to ``users[i]``.

        Args:
            all_ids: sequence of sample ids, in display order.
            username: one of ``users``.
            current_id: id requested by the UI. If it is not in the user's
                part, navigation restarts at the first sample.
            users: annotator names, defaulting to ``USERS``.

        Returns:
            (sample_id, next_id, progress):
            - ``sample_id``: the id to display.
            - ``next_id``: the id after it, clamped to the last one.
            - ``progress``: a one-entry ``{"pos/last": fraction}`` dict.

        Raises:
            KeyError: if ``username`` is not in ``users``.
            ValueError: if no samples are assigned to ``username``.
        """
        parts = np.array_split(list(all_ids), len(users))
        assigned = {user: part.tolist() for user, part in zip(users, parts)}
        ids = assigned[username]
        if not ids:
            raise ValueError(f"No samples assigned to {username!r} for this selection")

        try:
            pos = ids.index(current_id)
        except ValueError:
            pos = 0

        last = len(ids) - 1
        next_id = ids[min(pos + 1, last)]
        # With a single sample the user is already at the end; avoid 0/0.
        progress = {f"{pos}/{last}": pos / last if last else 1.0}
        return ids[pos], next_id, progress

    def __getSample(self, sample_id, df=None):
        """Render one sample and collect its display metadata.

        Args:
            sample_id: value of the ``sample_idx`` column to display.
            df: table to look the sample up in; defaults to ``self.filtered_df``.
                Passing it explicitly avoids depending on shared instance state
                when several requests are served concurrently.

        Returns:
            (info, img, img_seg):
            - ``info``: a dict of display fields.
            - ``img``: a contiguous ``uint8`` RGB array of shape (h, w, 3)
              with the rendered figure.
            - ``img_seg``: the segmented image from the archive.

        Raises:
            ValueError: if ``sample_id`` is not present in ``df``.
        """
        if df is None:
            df = self.filtered_df
        sample = df[df.sample_idx == sample_id]
        if sample.empty:
            raise ValueError(f"Sample {sample_id!r} not found in the current selection")

        def first(column):
            # Every field comes from the first (expected only) matching row.
            return sample[column].values[0]

        sent = first('sent')
        categories = [(name, first(name))
                      for name in ('intrinsic', 'spatial', 'ordinal', 'relational', 'plural')]
        prediction = {0: ' FAIL ', 1: ' CORRECT '}

        info = {'Expresion': sent,
                'Idx Sample': str(sample_id),
                'IoU': str(round(first('iou'), 2)) + "(" + prediction[first('hit')] + ")",
                'IoU other': str(round(first('iou_other'), 2)) + "(" + prediction[first('hit_other')] + ")",
                'Pos Tags': str(first('pos_tags')),
                'PluralTks ': first('plural_tks'),
                'Categories': ",".join([name for name, flag in categories if flag])
                }

        # Boxes are (x1, y1, x2, y2) corners, drawn as outline rectangles.
        boxes = [(first('bbox'), 'blue'),                # ground truth
                 (first('predictions'), 'lightgreen'),   # selected model
                 (first('predictions_other'), 'red')]    # the other model

        # Map the stored path onto the archive layout (from the 'saiapr_tc-12'
        # directory onward); segmentations live in a parallel directory tree.
        img_path = "saiapr_tc-12" + first('file_path').split("saiapr_tc-12")[1]
        img_seg_path = img_path.replace("images", "segmented_images")

        fig, ax = plt.subplots(1)
        try:
            ax.imshow(self.__get(img_path), interpolation='bilinear')
            for (x1, y1, x2, y2), color in boxes:
                ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                               linewidth=2, edgecolor=color,
                                               facecolor='none'))
            ax.axis('off')
            # Use the figure/axes objects, not pyplot's global "current" state.
            ax.set_title(sent, fontsize=12)
            fig.tight_layout()
            fig.canvas.draw()
            # buffer_rgba() replaces tostring_rgb() (removed in matplotlib 3.10);
            # drop alpha to keep the RGB output and copy out of the canvas buffer.
            img = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[:, :, :3])
        finally:
            # Always release the figure so pyplot does not accumulate them.
            plt.close(fig)

        return info, img, self.__get(img_seg_path)

    def explorateSamples(self,
            username,
            predictions,
            category,
            model,
            split,
            next_idx_sample):
        """Gradio callback: show the requested sample for one annotator.

        Args:
            username: annotator name; see ``USERS``.
            predictions: 'fail' or 'correct'; filters on the selected model's hit.
            category: column name whose value must equal 1 (e.g. 'spatial').
            model: 'baseline' or 'extended'.
            split: 'test' or 'val'.
            next_idx_sample: sample id to display (converted with ``int``).

        Returns:
            A 5-tuple:
            - a ``gr.Number`` update carrying the next sample id;
            - the progress dict (presumably shown in a ``gr.Label`` — confirm
              against the UI);
            - the rendered RGB array;
            - the info text (the expression itself is omitted since it is the
              figure title);
            - the segmentation image.
        """
        next_idx_sample = int(next_idx_sample)
        hit = {'fail': 0, 'correct': 1}
        df = self.__loadPredictions(split, model)
        filtered = df[(df[category] == 1) & (df.hit == hit[predictions])]
        self.filtered_df = filtered

        sample_id, next_id, progress = self.__navigate(
            filtered.sample_idx.to_list(), username, next_idx_sample)
        info, img, img_seg = self.__getSample(sample_id, filtered)
        info = "".join([str(k) + ":\t" + str(v) + "\n"
                        for k, v in list(info.items())[1:]]).strip()

        return (gr.Number.update(value=next_id), progress, img, info, img_seg)