import matplotlib.pyplot as plt
import torch
import numpy as np
from PIL import Image
import cv2 as cv

from transformers import (
    DetrImageProcessor,
    DetrForSegmentation,
    MaskFormerImageProcessor,
    MaskFormerForInstanceSegmentation,
)
# rgb_to_id moved out of transformers.models.detr.feature_extraction_detr;
# import it from transformers.image_transforms instead
from transformers.image_transforms import rgb_to_id, id_to_rgb

TEST_IMAGE = Image.open(r"images/9999999_00783_d_0000358.jpg")
MODEL_NAME_DETR = "facebook/detr-resnet-50-panoptic"
MODEL_NAME_MASKFORMER = "facebook/maskformer-swin-large-coco"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#######
# Parameters
#######
image = TEST_IMAGE
model_name = MODEL_NAME_MASKFORMER

# Starting with MaskFormer

processor = MaskFormerImageProcessor.from_pretrained(model_name) # <class 'transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor'>
# dir(), with dunders and private helpers stripped:
#   center_crop, convert_segmentation_map_to_binary_masks, do_normalize, do_reduce_labels,
#   do_rescale, do_resize, encode_inputs, fetch_images, from_dict, from_json_file,
#   from_pretrained, get_image_processor_dict, ignore_index, image_mean, image_std,
#   model_input_names, normalize, pad, post_process_instance_segmentation,
#   post_process_panoptic_segmentation, post_process_segmentation,
#   post_process_semantic_segmentation, preprocess, push_to_hub, register_for_auto_class,
#   resample, rescale, rescale_factor, resize, save_pretrained, size, size_divisor,
#   to_dict, to_json_file, to_json_string

model = MaskFormerForInstanceSegmentation.from_pretrained(model_name) # <class 'transformers.models.maskformer.modeling_maskformer.MaskFormerForInstanceSegmentation'>
# dir() output for the model is too large to reproduce here
model.to(DEVICE)

# img = np.array(TEST_IMAGE)

inputs = processor(images=image, return_tensors="pt") # <class 'transformers.image_processing_utils.BatchFeature'>
# dir(), with dunders and private helpers stripped:
#   clear, convert_to_tensors, copy, data, fromkeys, get, items, keys, pop,
#   popitem, setdefault, to, update, values
inputs = inputs.to(DEVICE)  # BatchFeature.to moves every tensor in the batch


with torch.no_grad():  # inference only, so skip gradient bookkeeping
    outputs = model(**inputs) # <class 'transformers.models.maskformer.modeling_maskformer.MaskFormerForInstanceSegmentationOutput'>
# Each tensor-valued field of this output is a <class 'torch.Tensor'>
# dir(), with dunders stripped:
#   attentions, auxiliary_logits, class_queries_logits, clear, copy, encoder_hidden_states,
#   encoder_last_hidden_state, fromkeys, get, hidden_states, items, keys, loss,
#   masks_queries_logits, move_to_end, pixel_decoder_hidden_states,
#   pixel_decoder_last_hidden_state, pop, popitem, setdefault, to_tuple,
#   transformer_decoder_hidden_states, transformer_decoder_last_hidden_state, update, values
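
# Quick shape check on the two tensors the post-processors consume. The exact
# sizes depend on the checkpoint; the layout below is the general MaskFormer
# pattern (an assumption worth verifying, not read off this model's config):
#   class_queries_logits -> (batch, num_queries, num_labels + 1)
#   masks_queries_logits -> (batch, num_queries, mask_height, mask_width)
print(outputs.class_queries_logits.shape, outputs.masks_queries_logits.shape)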

# PIL's image.size is (width, height); the post-processor wants (height, width)
results = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
# <class 'dict'> with keys dict_keys(['segmentation', 'segments_info'])
# type(results["segmentation"]) --> <class 'torch.Tensor'>, an (H, W) map of segment ids
# type(results["segments_info"]) --> list of dicts describing each segment id


def show_mask_for_number(map_to_use, label_id):
    """
    map_to_use: pass in `results["segmentation"]`
    """
    # .cpu() is a no-op for tensors already on the CPU, so no device check is needed
    mask = (map_to_use.cpu().numpy() == label_id)

    visual_mask = (mask * 255).astype(np.uint8)
    visual_mask = Image.fromarray(visual_mask)
    plt.imshow(visual_mask)
    plt.show()
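
# Example usage (assumes segment id 1 exists in this image's result):
# show_mask_for_number(results["segmentation"], 1)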

def show_mask_for_number_over_image(map_to_use, label_id, image_object):
    """
    map_to_use: pass in `results["segmentation"]`
    """
    mask = (map_to_use.cpu().numpy() == label_id)

    visual_mask = (mask * 255).astype(np.uint8)
    visual_mask = Image.fromarray(visual_mask)
    plt.imshow(image_object)
    plt.imshow(visual_mask, alpha=0.25)  # translucent mask over the original image
    plt.show()


def get_coordinates_for_bb_simple(map_to_use, label_id):
    """
    map_to_use: pass in `results["segmentation"]`
    Returns ((x_min, y_min), (x_max, y_max)) of the tightest box around the mask.
    """
    mask = (map_to_use.cpu().numpy() == label_id)

    y_vals, x_vals = np.where(mask)
    x_max, x_min = x_vals.max(), x_vals.min()
    y_max, y_min = y_vals.max(), y_vals.min()
    return (x_min, y_min), (x_max, y_max)
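
# Equivalent vectorized form (a sketch; assumes the mask is non-empty):
#   coords = np.argwhere(mask)                     # rows of (y, x)
#   (y_min, x_min), (y_max, x_max) = coords.min(axis=0), coords.max(axis=0)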

def make_simple_box(left_top, right_bottom, map_size):
    """
    map_size: (height, width), e.g. `results["segmentation"].shape`
    """
    full_mask = np.full(map_size, False)
    left_x, top_y = left_top
    right_x, bottom_y = right_bottom
    # numpy indexes as [row (y), column (x)]
    full_mask[top_y, left_x:right_x] = True
    full_mask[bottom_y, left_x:right_x] = True
    full_mask[top_y:bottom_y, left_x] = True
    full_mask[top_y:bottom_y, right_x] = True

    visual_mask = (full_mask * 255).astype(np.uint8)
    visual_mask = Image.fromarray(visual_mask)
    plt.imshow(visual_mask)
    plt.show()
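
# Example usage (sketch): box for segment id 1 on a canvas the size of the
# segmentation map. Note `map_size` is (height, width), matching numpy order.
# lt, rb = get_coordinates_for_bb_simple(results["segmentation"], 1)
# make_simple_box(lt, rb, results["segmentation"].shape)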



def contour_map(map_to_use, label_id):
    """
    map_to_use: pass in `results["segmentation"]`
    Returns the OpenCV contours (a tuple of (N, 1, 2) point arrays in (x, y)
    order) and the contour hierarchy for the binary mask of `label_id`.
    """
    mask = (map_to_use.cpu().numpy() == label_id)

    visual_mask = (mask * 255).astype(np.uint8)
    contours, hierarchy = cv.findContours(visual_mask, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
    return contours, hierarchy


# https://docs.opencv.org/4.9.0/dd/d49/tutorial_py_contour_features.html
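
# A small helper sketched from the tutorial above (not part of the original
# exploration): contour area and centroid via image moments.
def contour_stats(contour):
    """Returns (area, (cx, cy)); the centroid is None for a zero-area contour."""
    m = cv.moments(contour)
    area = cv.contourArea(contour)
    if m["m00"] == 0:
        return area, None
    return area, (m["m10"] / m["m00"], m["m01"] / m["m00"])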


# Idea for determining whether two contours are close enough to merge:
#   https://dsp.stackexchange.com/questions/2564/opencv-c-connect-nearby-contours-based-on-distance-between-them
#   Bing Search: cv determine if 2 contours belong together

def find_if_close(contour1, contour2, c_dist=50):
    """
    True if any point of contour1 lies within c_dist of any point of contour2.
    Source: https://dsp.stackexchange.com/questions/2564/opencv-c-connect-nearby-contours-based-on-distance-between-them
    """
    for point1 in contour1:
        for point2 in contour2:
            # norms are non-negative, so no abs() is needed
            if np.linalg.norm(point1 - point2) < c_dist:
                return True
    return False


def make_new_bounding_box(bb1, bb2):
    """Union of two (x, y, w, h) boxes as returned by cv.boundingRect."""
    x1, y1, w1, h1 = bb1
    x2, y2, w2, h2 = bb2
    new_x = min(x1, x2)
    new_y = min(y1, y2)
    # the far edge is never left of / above the near edge, so no abs() is needed
    new_w = max(x1 + w1, x2 + w2) - new_x
    new_h = max(y1 + h1, y2 + h2) - new_y

    return (new_x, new_y, new_w, new_h)
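
# One way the two helpers above could combine (an assumption about intent,
# not code from the thread): union the bounding boxes of every contour that
# chains back, within `c_dist`, to the first contour.
def merged_bounding_box(contours, c_dist=50):
    if not contours:
        return None
    kept = [contours[0]]
    box = cv.boundingRect(contours[0])
    for contour in contours[1:]:
        if any(find_if_close(k, contour, c_dist) for k in kept):
            kept.append(contour)
            box = make_new_bounding_box(box, cv.boundingRect(contour))
    return box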

def map_bounding_box_draw(map_to_use, label_id, img_obj=TEST_IMAGE, v="cv"):
    """
    map_to_use: pass in `results["segmentation"]`
    v: version of bounding box
        cv, coord
    """
    # float mask so the box edges below can sit at half intensity
    mask = (map_to_use.cpu().numpy() == label_id).astype(np.float32)

    if v == "cv":
        contours, _ = contour_map(map_to_use, label_id)
        if not contours:
            print(f"No contours found for label_id {label_id}")
            return
        if len(contours) >= 2:
            x, y, w, h = make_new_bounding_box(cv.boundingRect(contours[0]), cv.boundingRect(contours[1]))
        else:
            x, y, w, h = cv.boundingRect(contours[0])
        left_x, top_y = x, y
        # cv.boundingRect is exclusive on the far edge; clamp to keep the indexing below in bounds
        right_x = min(x + w, mask.shape[1] - 1)
        bottom_y = min(y + h, mask.shape[0] - 1)
    elif v == "coord":
        lt, rb = get_coordinates_for_bb_simple(map_to_use, label_id)
        left_x, top_y = lt
        right_x, bottom_y = rb
    else:
        print(f"Unavailable `v` option: {v}")
        return

    mask[top_y, left_x:right_x] = .5
    mask[bottom_y, left_x:right_x] = .5
    mask[top_y:bottom_y, left_x] = .5
    mask[top_y:bottom_y, right_x] = .5

    visual_mask = (mask * 255).astype(np.uint8)
    visual_mask = Image.fromarray(visual_mask)
    plt.imshow(img_obj)
    plt.imshow(visual_mask, alpha=0.25)
    plt.show(block=False)


def manual_looking(id_number):
    """Print label, score, and contour count for one segment id, then overlay its mask."""
    c, _ = contour_map(results["segmentation"], id_number)
    # look the segment up by its "id" field rather than assuming ids are 1..N in list order
    segment = next(s for s in results["segments_info"] if s["id"] == id_number)
    print(f'{model.config.id2label[segment["label_id"]]}, {segment["score"]}, Contour Count: {len(c)}')
    show_mask_for_number_over_image(results["segmentation"], id_number, TEST_IMAGE)
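
# Example driver (sketch): step through every detected segment one at a time.
# for segment in results["segments_info"]:
#     manual_looking(segment["id"])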