import copy
import numpy as np
import torch

import sys
sys.path.append("./")
from models import sam_model_registry
from models.grasp_mods import modify_forward
from models.utils.transforms import ResizeLongestSide

from gradio_image_prompter import ImagePrompter
from structures.grasp_box import GraspCoder
img_resize = ResizeLongestSide(1024)
import cv2

import gradio as gr

from models.grasp_mods import add_inference_method

# Use the GPU when torch can see one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "vit_b"

# Per-channel normalisation constants, shaped (3, 1, 1) so they broadcast
# against channel-first (C, H, W) image arrays.
mean = np.array([103.53, 116.28, 123.675]).reshape(3, 1, 1)
std = np.array([57.375, 57.12, 58.395]).reshape(3, 1, 1)

# Build the model, move it to the chosen device, then attach the forward /
# inference entry points provided by models.grasp_mods.
sam = sam_model_registry[model_type]()
sam.to(device=device)
sam.forward = modify_forward(sam)
sam.infer = add_inference_method(sam)

# Checkpoint to restore; an empty string skips loading entirely.
pretrained_model_path = "./epoch_39_step_415131.pth"

if pretrained_model_path:
    # NOTE: torch.load unpickles the file, so only point this at trusted
    # checkpoints.
    sd = torch.load(pretrained_model_path, map_location="cpu")
    # Drop a leading "module." (7 characters) from every key that carries it
    # so the names line up with the bare model's parameters.
    new_sd = {
        (key[7:] if key.startswith("module.") else key): value
        for key, value in sd.items()
    }
    sam.load_state_dict(new_sd)

# Inference only: switch the model to evaluation mode.
sam.eval()

def _box_prompts(points):
    """Return the ``[x1, y1, x2, y2]`` boxes contained in the prompter entries.

    Every entry is unpacked as ``(x1, y1, type1, x2, y2, type2)``; only entries
    with ``type1 == 2`` and ``type2 == 3`` are kept. ``None`` is treated as
    "no entries".
    """
    if points is None:
        return []
    return [
        [x1, y1, x2, y2]
        for x1, y1, type1, x2, y2, type2 in points
        if type1 == 2 and type2 == 3
    ]


def predict(input, topk):
    """Draw the predicted mask and the top-k grasps for every box prompt.

    Args:
        input: Prompter payload. ``input["image"]`` is read as an
            ``(H, W, 3)`` array and ``input["points"]`` as a sequence of
            six-value prompt entries (see ``_box_prompts``).
        topk: Number of highest-scoring grasps to draw per box. It is coerced
            with ``int()`` because it is used as a slice bound.

    Returns:
        The input image unchanged when no box prompt is present; otherwise an
        ``(H, W, 3)`` ``uint8`` image at the input resolution with the masks
        and grasp rectangles painted on it.
    """
    orig_image = input["image"]
    boxes = _box_prompts(input["points"])
    if not boxes:
        # Nothing to predict. Return the untouched height-width-channel image;
        # the channel-first working copy made below is not displayable.
        return orig_image

    topk = int(topk)
    orig_size = orig_image.shape[:2]

    # Normalise in channel-first layout and add a batch dimension.
    chw_image = orig_image.transpose(2, 0, 1)
    image = (chw_image - mean) / std
    image = torch.tensor(image).float().to(device).unsqueeze(0)
    t_image = img_resize.apply_image_torch(image)
    t_orig_size = t_image.shape[-2:]

    # Pad image and validity mask on the right/bottom up to 1024 x 1024; the
    # mask is 1 over real pixels and 0 over padding.
    side = 1024
    pad_w = side - t_orig_size[1]
    pad_h = side - t_orig_size[0]
    pixel_mask = torch.ones(1, t_orig_size[0], t_orig_size[1], device=device)
    t_image = torch.nn.functional.pad(t_image, (0, pad_w, 0, pad_h))
    pixel_mask = torch.nn.functional.pad(pixel_mask, (0, pad_w, 0, pad_h))

    # Map the box prompts into the resized image's coordinate frame.
    t_boxes = img_resize.apply_boxes(np.array(boxes), orig_size)
    box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
    batched_inputs = [{"image": t_image[0], "boxes": box_torch, "pixel_mask": pixel_mask}]
    with torch.no_grad():
        outputs = sam.infer(batched_inputs, multimask_output=False)

    # Undo the normalisation to get a drawable canvas. Round and clip BEFORE
    # the uint8 cast: clipping after the cast cannot undo a wrap-around.
    canvas = batched_inputs[0]["image"].cpu().numpy() * std + mean
    canvas = np.clip(np.rint(canvas.transpose(1, 2, 0)), 0, 255).astype(np.uint8)
    # cv2 drawing needs a C-contiguous buffer; do the copy once, up front.
    canvas = np.ascontiguousarray(canvas)

    mask_color = np.array([0, 255, 0])[None, None, :]
    for i, mask_logits in enumerate(outputs.pred_masks):
        # Binary mask, expanded to three channels so it can index the canvas.
        # NOTE(review): assumes each mask is (1, H, W) matching the padded
        # canvas -- confirm against sam.infer.
        pred_mask = mask_logits.detach().sigmoid().cpu().numpy() > 0.5
        pred_mask = pred_mask.transpose(1, 2, 0).repeat(3, axis=2)

        # Rank candidates by the first logit column, best first, keep top-k.
        pred_logits = outputs.logits[i].detach().cpu().numpy()
        top_ind = pred_logits[:, 0].argsort()[::-1][:topk]
        pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
        coded_grasp = GraspCoder(t_orig_size[0], t_orig_size[1], None, grasp_annos_reformat=pred_grasp)
        coded_grasp.decode()
        decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)

        # 50/50 blend of the mask colour into the masked pixels.
        canvas[pred_mask] = canvas[pred_mask] * 0.5 + (pred_mask * mask_color)[pred_mask] * 0.5

        # Each grasp supplies four (x, y) corners; two opposite edges are drawn
        # thin in one colour, the other two thicker in another.
        for grasp in decoded_grasp:
            # .tolist() yields plain Python ints, which cv2 accepts everywhere.
            p0, p1, p2, p3 = (tuple(p) for p in grasp.astype(int)[:8].reshape(4, 2).tolist())
            cv2.line(canvas, p0, p1, (255, 0, 0), 1)
            cv2.line(canvas, p2, p3, (255, 0, 0), 1)
            cv2.line(canvas, p1, p2, (0, 0, 255), 2)
            cv2.line(canvas, p3, p0, (0, 0, 255), 2)

    # Drop the padding, then scale back to the resolution of the input image.
    canvas = canvas[:t_orig_size[0], :t_orig_size[1]]
    return cv2.resize(canvas, (orig_size[1], orig_size[0]))

if __name__ == "__main__":
    # Build the demo page (instructions, image prompter, top-k slider, output
    # image, run button) and start the Gradio server.
    with gr.Blocks(title="GraspAnything") as app:
        gr.Markdown("""
        # GraspAnything <br>
        Upload an image and draw a box around the object you want to grasp. Set top k to be the number of grasps you want to predict for each object.
        """)
        with gr.Column():
            prompter = ImagePrompter(show_label=False)
            top_k = gr.Slider(label="Top K Grasps", minimum=1, maximum=20, step=1, value=3)
        with gr.Column():
            image_output = gr.Image()
        # Clicking the button runs `predict` on the prompter payload and the
        # slider value, and shows the returned image.
        run_button = gr.Button("Generate!")
        run_button.click(
            fn=predict,
            inputs=[prompter, top_k],
            outputs=[image_output],
        )
    app.launch()