from models import IntuitionKillingMachine
from transforms import undo_box_transforms_batch, ToTensor, Normalize, SquarePad, Resize, NormalizeBoxCoords
from torchvision.transforms import Compose
from encoders import get_tokenizer
from PIL import Image, ImageDraw
from zipfile import ZipFile
from copy import copy
import pandas as pd
import torch

def parse_model_args(model_path):
    """Recover model hyper-parameters encoded in a checkpoint path.

    The path is expected to contain at least 13 underscore-separated
    fields, e.g.::

        <date>_<time>_<dataset>_<max_length>_<input_size>_<backbone>_
        <num_heads>_<num_layers>_<num_conv>_..._<mu>_<mask_pooling>_...

    (the meaning of the skipped positions is an assumption — confirm
    against the training script that produces these names).

    Args:
        model_path: checkpoint path whose name encodes the run settings.

    Returns:
        dict mapping hyper-parameter names to parsed values; numeric
        fields are converted to ``int``/``float`` and ``mask_pooling``
        to ``bool``.

    Raises:
        ValueError: if the path has fewer than 13 '_'-separated fields
            or a numeric field cannot be converted.
    """
    fields = model_path.split('_')
    if len(fields) < 13:
        raise ValueError(
            f'cannot parse model settings from path: {model_path!r}')
    (_, _, dataset, max_length, input_size, backbone,
     num_heads, num_layers, num_conv, _, _, mu, mask_pooling) = fields[:13]
    return {
        'dataset': dataset,
        'max_length': int(max_length),
        'input_size': int(input_size),
        'backbone': backbone,
        'num_heads': int(num_heads),
        'num_layers': int(num_layers),
        'num_conv': int(num_conv),
        'mu': float(mu),
        # the flag is stored as the literal string '1' or '0'
        'mask_pooling': mask_pooling == '1',
    }


class Prober:
    """Qualitative probe for a trained ``IntuitionKillingMachine``.

    Loads a referring-expressions dataframe, a zipped image dataset and
    a model checkpoint, and draws the model's predicted bounding box for
    an arbitrary (image, expression) pair.
    """

    def __init__(self,
                 df_path=None,
                 dataset_path=None,
                 model_checkpoint=None):
        """Build the model, transforms and data handles.

        Args:
            df_path: JSON file readable by ``pandas.read_json`` with at
                least the columns ``sample_idx``, ``bbox``, ``file_path``
                and ``sent``.
            dataset_path: zip archive containing the images referenced
                by ``file_path``.
            model_checkpoint: checkpoint path; hyper-parameters are
                parsed from its underscore-separated name
                (see ``parse_model_args``).
        """
        params = parse_model_args(model_checkpoint)
        # ImageNet normalization statistics (backbone is pretrained).
        mean = [0.485, 0.456, 0.406]
        sdev = [0.229, 0.224, 0.225]
        self.tokenizer = get_tokenizer()
        self.df = pd.read_json(df_path)[['sample_idx', 'bbox', 'file_path', 'sent']]
        # numeric image id is encoded in the file name: ".../123.jpg" -> 123
        self.df.loc[:, "image_id"] = self.df.file_path.apply(lambda x: int(x.split('/')[-1][:-4]))
        # paths inside the zip archive do not carry the repo prefix
        self.df.file_path = self.df.file_path.apply(lambda x: x.replace('refer/data/images/', ''))
        self.model = IntuitionKillingMachine(
            backbone=params['backbone'],
            pretrained=True,
            num_heads=params['num_heads'],
            num_layers=params['num_layers'],
            num_conv=params['num_conv'],
            segmentation_head=bool(params['mu'] > 0.0),
            mask_pooling=params['mask_pooling']
        )
        # inference only: freeze batch-norm / dropout behavior
        self.model.eval()
        self.transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(params['input_size'], params['input_size'])),
            NormalizeBoxCoords(),
        ])
        # tokenizer truncation length.
        # NOTE(review): params['max_length'] is parsed but unused here —
        # confirm whether it should replace this constant.
        self.max_length = 30
        self.zipfile = ZipFile(dataset_path, 'r')

    @torch.no_grad()
    def probe(self, idx, re, search_by_sample_id: bool = True):
        """Predict and draw the box for referring expression ``re``.

        Args:
            idx: dataframe row label (default) or an ``image_id`` when
                ``search_by_sample_id`` is False.
            re: the referring expression (natural-language string).
            search_by_sample_id: select the row by dataframe index
                rather than by ``image_id``.

        Returns:
            ``PIL.Image`` with the predicted box drawn in green.
        """
        if search_by_sample_id:
            img_path, target = self.df.loc[idx][['file_path', 'bbox']].values
        else:
            img_path, target = self.df[self.df.image_id == idx][['file_path', 'bbox']].values[0]
        img = Image.open(self.zipfile.open(img_path)).convert('RGB')
        W0, H0 = img.size
        sample = {
            'image': img,
            'image_size': (H0, W0),  # image original size
            'bbox': torch.tensor([copy(target)]),
            'bbox_raw': torch.tensor([copy(target)]),
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),  # visibility mask
            'mask_bbox': None,  # target bbox mask
        }
        sample = self.transform(sample)
        tok = self.tokenizer(re,
                             max_length=self.max_length,
                             return_tensors='pt',
                             truncation=True)
        inn = {'image': torch.stack([sample['image']]),
               'mask': torch.stack([sample['mask']]),
               'bbox': torch.stack([sample['bbox']]),
               'tok': tok}
        # map the predicted box back to original image coordinates
        output = undo_box_transforms_batch(self.model(inn)[0],
                                           [sample['tr_param']]).numpy().tolist()[0]
        draw = ImageDraw.Draw(img)
        draw.rectangle(output, outline="#00FF0000", width=3)
        return img

if __name__ == "__main__":
    prober = Prober(
        df_path = 'data/val-sim_metric.json',
        dataset_path = "data/saiapr_tc-12.zip",
        model_checkpoint= "cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt"
    )
    prober.probe(0, "tree")
    print("Done")