import os
import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import gc

# SAM checkpoint file for each supported backbone; the keys are the model
# types looked up in ``sam_model_registry``.
models = dict(
	vit_b='./checkpoints/sam_vit_b_01ec64.pth',
	vit_l='./checkpoints/sam_vit_l_0b3195.pth',
	vit_h='./checkpoints/sam_vit_h_4b8939.pth',
)

# Example inputs. Each entry is [absolute image path, example index, []].
# ``run_inference`` maps an int input back to ``image_examples[index][0]``.
# The trailing empty list is presumably the initial selected-points state;
# confirm against the Gradio examples wiring.
_EXAMPLE_DIR = os.path.dirname(__file__)
_EXAMPLE_FILES = (
	"./images/53960-scaled.jpg",
	"./images/2388455-scaled.jpg",
	"./images/1.jpg",
	"./images/2.jpg",
	"./images/3.jpg",
	"./images/4.jpg",
	"./images/5.jpg",
	"./images/6.jpg",
	"./images/7.jpg",
	"./images/8.jpg",
)
image_examples = [
	[os.path.join(_EXAMPLE_DIR, file_name), position, []]
	for position, file_name in enumerate(_EXAMPLE_FILES)
]


def plot_boxes(img, boxes):
	"""Draw every box in ``boxes`` onto ``img`` with a random outline colour.

	Args:
		img: float image array. It is multiplied by 255 and cast to uint8, so it
			is presumably scaled to [0, 1].
		boxes: iterable of 4-element boxes. Each is unpacked as
			``x0, y0, x1, y1`` and truncated to ints.

	Returns:
		A new RGB ``PIL.Image`` with one 6-px-wide rectangle per box.
	"""
	canvas = Image.fromarray(np.uint8(img * 255)).convert('RGB')
	pen = ImageDraw.Draw(canvas)
	for coords in boxes:
		# Draw the colour before unpacking, matching the RNG call order.
		outline = tuple(np.random.randint(0, 255, size=3).tolist())
		left, top, right, bottom = coords
		corners = [int(left), int(top), int(right), int(bottom)]
		pen.rectangle(corners, outline=outline, width=6)
	return canvas


def segment_one(img, mask_generator, seed=None):
	"""Run automatic mask generation on one image and paint each mask a random colour.

	Args:
		img: image array passed to ``mask_generator.generate``. It is divided by
			255 when blending, so it is presumably uint8 in [0, 255].
		mask_generator: object whose ``generate(img)`` returns a list of dicts with
			``'segmentation'`` (H x W mask) and ``'area'`` keys.
		seed: if given, mask colours come from a private
			``np.random.RandomState(seed)``. These are the same colours that seeding
			NumPy's global generator with ``seed`` would give, but global state is
			not changed. If ``None``, the global generator is used.

	Returns:
		``(result, mask_all)``:
		- ``mask_all`` is an (H, W, 3) float array, 1.0 wherever no mask covers
		  the pixel.
		- ``result`` is ``img / 255 * 0.3 + mask_all * 0.7``.
	"""
	# Use a private generator so a fixed seed gives reproducible colours without
	# reseeding the process-wide RNG that other code (e.g. plot_boxes) draws from.
	rng = np.random if seed is None else np.random.RandomState(seed)
	masks = mask_generator.generate(img)
	# Paint the largest masks first so that smaller masks, painted later, stay
	# visible on top of them.
	sorted_anns = sorted(masks, key=lambda ann: ann['area'], reverse=True)
	mask_all = np.ones((img.shape[0], img.shape[1], 3))
	for ann in sorted_anns:
		color_mask = rng.random_sample((1, 3)).tolist()[0]
		# One vectorised assignment fills all three channels of the masked pixels.
		mask_all[ann['segmentation'] == True] = color_mask  # noqa: E712 (elementwise compare)
	result = img / 255 * 0.3 + mask_all * 0.7
	return result, mask_all


def _segment_video(video_path, mask_generator, progress, output_path='output.mp4'):
	"""Segment every frame of a video and write the colour overlay to ``output_path``.

	Each frame goes through ``segment_one`` with a fixed seed (2023), so every
	frame uses the same sequence of random mask colours.

	Args:
		video_path: path readable by ``cv2.VideoCapture``.
		mask_generator: object exposing ``generate(image)``, as used by ``segment_one``.
		progress: Gradio progress tracker whose ``tqdm`` wraps the frame loop.
		output_path: file the H.264 result video is written to.

	Returns:
		``output_path``.

	Raises:
		ValueError: if the video cannot be opened.
	"""
	cap = cv2.VideoCapture(video_path)
	if not cap.isOpened():
		cap.release()
		raise ValueError('Cannot open video: {}'.format(video_path))
	frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
	W, H = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
	# Keep fractional rates such as 29.97; truncating to int makes playback drift.
	fps = cap.get(cv2.CAP_PROP_FPS)
	out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc('x', '2', '6', '4'), fps, (W, H), isColor=True)
	try:
		for _ in progress.tqdm(range(frames_num),
		                       desc='Processing video ({} frames, size {}x{})'.format(frames_num, W, H)):
			ret, frame = cap.read()
			# The reported frame count can be inaccurate for some containers.
			# Stop at the real end rather than pass a None frame to SAM.
			if not ret:
				break
			# OpenCV decodes BGR, but SAM expects RGB (run_inference converts images
			# the same way). Convert the overlay back to BGR for the writer.
			rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
			result, _ = segment_one(rgb, mask_generator, seed=2023)
			out.write(cv2.cvtColor((result * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
	finally:
		# Release both handles even if segmentation fails, so the output file is
		# finalised and the capture is not leaked.
		out.release()
		cap.release()
	return output_path


def generator_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
                        min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh,
                        input_x, progress=gr.Progress()):
	"""Segment everything in an image or a video with SAM's automatic mask generator.

	Args:
		device: torch device the SAM model is moved to.
		model_type: key into ``models`` / ``sam_model_registry``.
		points_per_side, pred_iou_thresh, stability_score_thresh,
		min_mask_region_area, stability_score_offset, box_nms_thresh,
		crop_n_layers, crop_nms_thresh: passed unchanged to
			``SamAutomaticMaskGenerator``.
		input_x: an image array, or a video file path (str).
		progress: Gradio progress tracker, used for videos only.

	Returns:
		- For an image: ``(result, mask_all)`` from ``segment_one``.
		- For a video: the path of the written overlay video (``'output.mp4'``).
		- For any other input type: ``None``.
	"""
	sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
	mask_generator = SamAutomaticMaskGenerator(
		sam,
		points_per_side=points_per_side,
		pred_iou_thresh=pred_iou_thresh,
		stability_score_thresh=stability_score_thresh,
		stability_score_offset=stability_score_offset,
		box_nms_thresh=box_nms_thresh,
		crop_n_layers=crop_n_layers,
		crop_nms_thresh=crop_nms_thresh,
		crop_overlap_ratio=512 / 1500,
		crop_n_points_downscale_factor=1,
		point_grids=None,
		min_mask_region_area=min_mask_region_area,
		output_mode='binary_mask'
	)

	if isinstance(input_x, np.ndarray):  # image
		return segment_one(input_x, mask_generator)
	if isinstance(input_x, str):  # video path
		return _segment_video(input_x, mask_generator, progress)
	return None


def _owlvit_detect_boxes(device, image, queries, threshold):
	"""Detect boxes matching the text ``queries`` in ``image`` with OWL-ViT.

	The model is loaded from the local checkpoint directory for this call only.
	It is released, and the CUDA cache emptied, before returning.

	Args:
		device: torch device the model and inputs are moved to.
		image: image array given to the OWL-ViT processor. ``image.shape[:2]`` is
			used as the target size for the post-processed boxes.
		queries: non-empty list of text queries.
		threshold: score threshold for ``post_process_object_detection``.

	Returns:
		numpy array with one row of 4 box coordinates per detection.
		``plot_boxes`` unpacks each row as ``x0, y0, x1, y1``.
	"""
	processor = OwlViTProcessor.from_pretrained('./checkpoints/models--google--owlvit-base-patch32')
	owlvit_model = OwlViTForObjectDetection.from_pretrained("./checkpoints/models--google--owlvit-base-patch32").to(device)
	inputs = processor(text=[queries], images=image, return_tensors="pt").to(device)
	# Inference only. Without no_grad the forward pass keeps an autograd graph,
	# and its activations, alive on the device.
	with torch.no_grad():
		outputs = owlvit_model(**inputs)
	target_size = torch.Tensor([image.shape[:2]]).to(device)
	results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_size,
	                                                  threshold=threshold)
	boxes = results[0]["boxes"].detach().cpu().numpy()  # first (and only) image in the batch

	# Free the OWL-ViT weights and activations before SAM decoding.
	owlvit_model.cpu()
	del owlvit_model, outputs, inputs, results
	gc.collect()
	torch.cuda.empty_cache()
	return boxes


def predictor_inference(device, model_type, input_x, input_text, selected_points, owl_vit_threshold=0.1):
	"""Segment an image with SAM, prompted by text queries and/or clicked points.

	Text queries are turned into boxes by OWL-ViT. Each clicked point is sent to
	SAM as its own prompt.

	Args:
		device: torch device for SAM (and OWL-ViT).
		model_type: key into ``models`` / ``sam_model_registry``
			('vit_b', 'vit_l' or 'vit_h').
		input_x: image array given to ``SamPredictor.set_image``
			(``run_inference`` converts example images to RGB).
		input_text: comma-separated text queries. '' (or None) means no text prompt.
		selected_points: sequence of ``(point, label)`` pairs. ``point`` is a
			2-element coordinate; SAM presumably expects (x, y) order, so
			confirm against the caller.
		owl_vit_threshold: minimum OWL-ViT score for a detected box to be kept.

	Returns:
		``(img, mask_all)``:
		- ``mask_all`` is an (H, W, 3) float array, 1.0 where no mask was
		  predicted and a random colour per mask elsewhere.
		- ``img`` is ``input_x / 255 * 0.3 + mask_all * 0.7``. When a text
		  prompt was given, it is a PIL image with the detected boxes drawn.
		- If OWL-ViT detects no box and no point was given, SAM is not called
		  and ``mask_all`` stays all ones.
	"""
	sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
	predictor = SamPredictor(sam)
	predictor.set_image(input_x)  # compute the image embedding once for all prompts
	image_hw = input_x.shape[:2]

	# Split on commas, trim whitespace and drop empty entries:
	# "cat, dog," -> ["cat", "dog"].
	queries = [q.strip() for q in (input_text or '').split(',') if q.strip()]
	boxes = None
	transformed_boxes = None
	if queries:
		print([queries])
		boxes = _owlvit_detect_boxes(device, input_x, queries, owl_vit_threshold)
		if len(boxes) > 0:
			# Map boxes from original-image coordinates to SAM's resized input frame.
			transformed_boxes = predictor.transform.apply_boxes_torch(torch.Tensor(boxes).to(device), image_hw)
			print(transformed_boxes.size(), boxes.shape)

	if selected_points:
		# unsqueeze(1) gives each point its own prompt: coords (N, 1, 2), labels (N, 1).
		points = torch.Tensor([p for p, _ in selected_points]).to(device).unsqueeze(1)
		labels = torch.Tensor([int(l) for _, l in selected_points]).to(device).unsqueeze(1)
		transformed_points = predictor.transform.apply_coords_torch(points, image_hw)
		print(points.size(), transformed_points.size(), labels.size(), input_x.shape, points)
	else:
		transformed_points, labels = None, None

	mask_all = np.ones((input_x.shape[0], input_x.shape[1], 3))
	# Only call SAM when there is at least one prompt; otherwise there is nothing to segment.
	# NOTE(review): points are batched one per prompt and boxes one per detection,
	# so combining both presumably needs equal counts in predict_torch. Confirm the intended UX.
	if transformed_boxes is not None or transformed_points is not None:
		masks, _, _ = predictor.predict_torch(
			point_coords=transformed_points,
			point_labels=labels,
			boxes=transformed_boxes,
			multimask_output=False,
		)
		masks = masks.detach().cpu().numpy()
		for ann in masks:
			# multimask_output=False -> a single mask per prompt, hence ann[0].
			mask_all[ann[0] == True] = np.random.random((1, 3)).tolist()[0]  # noqa: E712
	img = input_x / 255 * 0.3 + mask_all * 0.7
	if boxes is not None:
		img = plot_boxes(img, boxes)  # image + mask + boxes

	# Drop SAM references before emptying the cache so its memory can actually be returned.
	del predictor, sam
	gc.collect()
	torch.cuda.empty_cache()

	return img, mask_all


def run_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh, min_mask_region_area,
                  stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh, owl_vit_threshold, input_x,
                  input_text, selected_points):
	"""Route a request to prompted segmentation or automatic segmentation.

	Images with a text prompt or clicked points go to ``predictor_inference``.
	Everything else goes to ``generator_inference``, including every video.

	Args:
		device, model_type: forwarded to both back-ends.
		points_per_side ... crop_nms_thresh: forwarded to ``generator_inference``.
		owl_vit_threshold: forwarded to ``predictor_inference``.
		input_x: one of:
			- an image array;
			- a video path (str);
			- an int index into ``image_examples``, loaded here as RGB.
		input_text: comma-separated text prompt. None or whitespace-only means
			no text prompt.
		selected_points: clicked ``(point, label)`` pairs. None means no points.

	Returns:
		Whatever the chosen back-end returns.

	Raises:
		FileNotFoundError: if a selected example image cannot be read.
	"""
	# An int input is the index of an example in ``image_examples``.
	if isinstance(input_x, int):
		example_path = image_examples[input_x][0]
		input_x = cv2.imread(example_path)
		if input_x is None:  # cv2.imread returns None instead of raising
			raise FileNotFoundError('Cannot read example image: {}'.format(example_path))
		input_x = cv2.cvtColor(input_x, cv2.COLOR_BGR2RGB)

	input_text = input_text or ''
	selected_points = selected_points or []
	is_video = isinstance(input_x, str)
	has_text = input_text.strip() != ''
	# Prompts (text or points) are only supported for images. Videos always use
	# the automatic generator, since predictor_inference needs an image array.
	if not is_video and (has_text or len(selected_points) != 0):
		print('use predictor_inference')
		print('prompt text: ', input_text)
		print('prompt points length: ', len(selected_points))
		return predictor_inference(device, model_type, input_x, input_text if has_text else '',
		                           selected_points, owl_vit_threshold)
	print('use generator_inference')
	return generator_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
	                           min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers,
	                           crop_nms_thresh, input_x)