import argparse
import os
import time

import cv2
import numpy as np
import requests
import torch
import wget
import yolov7
from mobile_sam import SamPredictor, sam_model_registry
from PIL import Image
from tqdm import tqdm
from transformers import YolosForObjectDetection, YolosImageProcessor

from images_to_video import VideoCreator
from video_to_images import ImageCreator


def download_mobile_sam_weight(path):
    """Download the MobileSAM checkpoint to ``path`` unless a file already exists there.

    Missing parent directories of ``path`` are created first.

    Args:
        path: destination file path for the checkpoint.

    Raises:
        NameError: if the file name of ``path`` does not match the MobileSAM
            checkpoint URL, i.e. there is no known pretrained weight to fetch.
    """
    if os.path.exists(path):
        return
    sam_weights = "https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/weights/mobile_sam.pt"
    # Create every missing parent directory, including the first component of
    # a relative path such as "models/mobile_sam.pt".
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    model_name = os.path.basename(path)
    if model_name in sam_weights:
        wget.download(sam_weights, path)
    else:
        raise NameError(
            "There is no pretrained weight to download for %s, you need to provide a path to MobileSAM weights."
            % model_name
        )


def get_closest_bbox(bbox_list, bbox_target):
    """
    Given a list of bounding boxes, find the one that is closest to the target bounding box.

    Closeness is the Euclidean distance between the boxes' coordinate vectors;
    ties resolve to the earliest box in ``bbox_list``.

    Args:
        bbox_list: sequence (or 2-D array) of bounding boxes, each with the same
            number of coordinates as ``bbox_target``
        bbox_target: target bounding box
    Returns:
        the element of ``bbox_list`` closest to ``bbox_target``; if ``bbox_list``
        is empty, ``bbox_target`` itself as a NumPy array

    """
    target = np.asarray(bbox_target)
    if len(bbox_list) == 0:
        # No detections (e.g. the object left the frame): fall back to the
        # target box instead of failing with IndexError.
        return target
    dists = np.linalg.norm(np.asarray(bbox_list) - target, axis=-1)
    return bbox_list[int(np.argmin(dists))]


def get_bboxes(image_file, image, model, image_processor, threshold=0.9):
    """Run the object detector on one frame and return its boxes as a NumPy array.

    Args:
        image_file: path of the frame on disk; only used when ``image_processor``
            is None, in which case ``model`` is called with the path directly.
        image: the same frame as a PIL image; only used by the
            ``image_processor`` branch (its ``.size`` gives the output scale).
        model: detector. Without a processor it is called as
            ``model(image_file)`` and ``result.pred[0]`` is read; with a
            processor it is called with the processor's output as kwargs.
        image_processor: Hugging Face image processor, or None for the
            yolov7-style branch.
        threshold: score threshold passed to
            ``post_process_object_detection``; not applied in the yolov7-style
            branch (filtering there depends on attributes set on the model).

    Returns:
        np.ndarray with one row per detected box. In the yolov7-style branch
        these are the first four prediction columns (presumably x1, y1, x2, y2
        — confirm against the yolov7 package).
    """
    # Inference only: avoid building an autograd graph for every frame.
    with torch.no_grad():
        if image_processor is None:
            results = model(image_file)
            predictions = results.pred[0]
            # .cpu() so .numpy() also works if the detector ran on a GPU.
            return predictions[:, :4].detach().cpu().numpy()

        inputs = image_processor(images=image, return_tensors="pt")
        outputs = model(**inputs)
        # PIL reports (width, height); post-processing expects (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = image_processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]

    return results["boxes"].detach().cpu().numpy()


def _hex_to_rgb(hex_color):
    """Parse a "#RRGGBB" colour string (leading "#" optional) into a uint8 array [R, G, B]."""
    digits = hex_color.lstrip("#")
    return np.array([int(digits[i : i + 2], 16) for i in (0, 2, 4)], dtype=np.uint8)


def _select_device():
    """Pick the torch device for MobileSAM: "mps" if available, else "cuda", else "cpu"."""
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def _load_detector(tracker_name):
    """Load the detector for ``tracker_name`` and return ``(model, image_processor)``.

    "yolov7" loads the yolov7-tiny model (image_processor is None); any other
    name loads the Hugging Face YOLOS-tiny model together with its processor.
    """
    if tracker_name == "yolov7":
        model = yolov7.load("kadirnar/yolov7-tiny-v0.1", hf_model=True)
        model.conf = 0.25  # NMS confidence threshold
        model.iou = 0.45  # NMS IoU threshold
        model.classes = None
        return model, None
    model = YolosForObjectDetection.from_pretrained("hustvl/yolos-tiny")
    image_processor = YolosImageProcessor.from_pretrained("hustvl/yolos-tiny")
    return model, image_processor


def _composite_frame(image_np, mask, background_rgb, reverse_mask):
    """Return ``image_np`` with every pixel outside the kept region set to ``background_rgb``.

    The kept region is ``mask`` itself, or its complement when ``reverse_mask``
    is true. ``image_np`` is an H x W x 3 uint8 RGB frame; ``mask`` is an
    H x W mask, assumed boolean (SamPredictor.predict output).
    """
    h, w = mask.shape[-2:]
    keep = mask.reshape(h, w, 1).astype(bool)
    if reverse_mask:
        keep = ~keep
    return np.where(keep, image_np, background_rgb.reshape(1, 1, -1)).astype(
        np.uint8, copy=False
    )


def segment_video(
    video_filename,
    dir_frames,
    image_start,
    image_end,
    bbox_file,
    skip_vid2im,
    mobile_sam_weights,
    auto_detect=False,
    tracker_name="yolov7",
    background_color="#009000",
    output_dir="output_frames",
    output_video="output.mp4",
    pbar=False,
    reverse_mask=False,
):
    """Segment an object in every frame of a video with MobileSAM and write a composited video.

    For each frame, a box prompt is chosen. By default it is the detector box
    closest to the box stored in ``bbox_file``; with ``auto_detect`` it is a box
    covering the whole frame. MobileSAM is prompted with that box. Pixels
    outside the resulting mask (inside it, when ``reverse_mask``) are painted
    with ``background_color``, and the frame is saved as PNG in ``output_dir``.
    Finally ``output_dir`` is assembled into ``output_video`` at the source
    video's frame rate.

    Args:
        video_filename: input video; also opened to read its frame rate.
        dir_frames: directory holding the input frames.
        image_start: start index into the sorted list of frames in ``dir_frames``.
        image_end: end index (exclusive) into that list; 0 means "through the last frame".
        bbox_file: text file with whitespace-separated integer box coordinates;
            only read when ``auto_detect`` is false.
        skip_vid2im: if true, reuse the frames already in ``dir_frames`` instead
            of extracting them from ``video_filename``.
        mobile_sam_weights: path to the MobileSAM checkpoint (downloaded if missing).
        auto_detect: prompt SAM with the full frame instead of running a detector.
        tracker_name: "yolov7" for yolov7-tiny; any other value selects YOLOS-tiny.
        background_color: "#RRGGBB" fill colour.
        output_dir: directory for the output PNG frames (created if missing).
        output_video: path of the output video.
        pbar: show tqdm progress bars instead of periodic ETA prints.
        reverse_mask: paint the segmented object instead of the background.
    """
    if not skip_vid2im:
        # NOTE(review): ImageCreator also receives image_start/image_end; if it
        # already writes only that range, the slice below offsets it a second
        # time — confirm against ImageCreator.
        ImageCreator(
            video_filename,
            dir_frames,
            image_start=image_start,
            image_end=image_end,
            pbar=pbar,
        ).get_images()

    # Read the source frame rate so the output video plays at the same speed.
    vid = cv2.VideoCapture(video_filename)
    fps = vid.get(cv2.CAP_PROP_FPS)
    vid.release()
    background_rgb = _hex_to_rgb(background_color)

    # The reference box is only needed when a detector is used.
    bbox_orig = None
    if not auto_detect:
        with open(bbox_file, "r") as f:
            bbox_orig = [int(coord) for coord in f.read().split()]

    download_mobile_sam_weight(mobile_sam_weights)
    all_frames = sorted(os.listdir(dir_frames))
    frames = all_frames[image_start:] if image_end == 0 else all_frames[image_start:image_end]
    n_frames = len(frames)

    sam = sam_model_registry["vit_t"](checkpoint=mobile_sam_weights)
    sam.to(device=_select_device())
    sam.eval()
    predictor = SamPredictor(sam)

    model, image_processor = None, None
    if not auto_detect:
        model, image_processor = _load_detector(tracker_name)

    os.makedirs(output_dir, exist_ok=True)
    zfill_max = len(str(n_frames))

    pb = tqdm(frames) if pbar else frames
    init_time = time.time()
    for idx, frame in enumerate(pb):
        processed_frames = idx + 1
        image_file = os.path.join(dir_frames, frame)
        # Force 3-channel RGB in case a frame is not already RGB, and close the
        # file handle right away.
        with Image.open(image_file) as img:
            image_pil = img.convert("RGB")
        image_np = np.array(image_pil)

        if auto_detect:
            input_box = np.array([0, 0, image_np.shape[1], image_np.shape[0]])
        else:
            bboxes = get_bboxes(image_file, image_pil, model, image_processor)
            input_box = np.array(get_closest_bbox(bboxes, bbox_orig))

        predictor.set_image(image_np)
        masks, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=True,
        )
        # NOTE(review): several candidate masks are returned; index 0 is used
        # as-is rather than the best-scoring one.
        out_rgb = _composite_frame(image_np, masks[0], background_rgb, reverse_mask)

        # Write each frame immediately instead of keeping every composited
        # frame in memory. cv2.imwrite expects BGR, while PIL decoded RGB.
        cv2.imwrite(
            os.path.join(output_dir, f"frame_{str(idx).zfill(zfill_max)}.png"),
            cv2.cvtColor(out_rgb, cv2.COLOR_RGB2BGR),
        )

        if not pbar and processed_frames % 10 == 0:
            remaining_time = int(
                (time.time() - init_time)
                / processed_frames
                * (n_frames - processed_frames)
            )
            remaining_time_str = f"{remaining_time//60}m {remaining_time%60}s"
            print(
                f"Processed frame {processed_frames}/{n_frames} - Remaining time: {remaining_time_str}"
            )

    # Round rather than truncate, e.g. 29.97 fps -> 30 instead of 29.
    VideoCreator(output_dir, output_video, pbar=pbar).create_video(fps=int(round(fps)))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_filename",
        default="assets/example.mp4",
        type=str,
        help="path to the video",
    )
    parser.add_argument(
        "--dir_frames",
        type=str,
        default="frames",
        help="path to the directory in which all input frames will be stored",
    )
    parser.add_argument(
        "--image_start", type=int, default=0, help="first image to be stored"
    )
    parser.add_argument(
        "--image_end",
        type=int,
        default=0,
        help="last image to be stored, last one if 0",
    )
    parser.add_argument(
        "--bbox_file",
        type=str,
        default="bbox.txt",
        help="path to the bounding box text file",
    )
    parser.add_argument(
        "--skip_vid2im",
        action="store_true",
        help="skip extracting the video frames and reuse the images already in --dir_frames",
    )
    parser.add_argument(
        "--mobile_sam_weights",
        type=str,
        default="./models/mobile_sam.pt",
        help="path to MobileSAM weights",
    )

    parser.add_argument(
        "--tracker_name",
        type=str,
        default="yolov7",
        help="tracker name",
        choices=["yolov7", "yoloS"],
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default="output_frames",
        help="directory to store the output frames",
    )

    parser.add_argument(
        "--output_video",
        type=str,
        default="output.mp4",
        help="path to store the output video",
    )
    parser.add_argument(
        "--auto_detect",
        action="store_true",
        help="skip the object detector and prompt the model with a box covering the whole frame instead of the box in --bbox_file",
    )
    parser.add_argument(
        "--background_color",
        type=str,
        default="#009000",
        help="background color for the output (hex)",
    )
    parser.add_argument(
        "--pbar",
        action="store_true",
        help="show progress bars instead of periodic remaining-time messages",
    )
    parser.add_argument(
        "--reverse_mask",
        action="store_true",
        help="paint the segmented object with the background color instead of the background",
    )
    args = parser.parse_args()

    # Pass the optional settings by keyword so they cannot be bound to the
    # wrong parameter (segment_video orders them auto_detect, tracker_name,
    # background_color, output_dir, output_video, pbar, reverse_mask).
    segment_video(
        args.video_filename,
        args.dir_frames,
        args.image_start,
        args.image_end,
        args.bbox_file,
        args.skip_vid2im,
        args.mobile_sam_weights,
        auto_detect=args.auto_detect,
        tracker_name=args.tracker_name,
        background_color=args.background_color,
        output_dir=args.output_dir,
        output_video=args.output_video,
        pbar=args.pbar,
        reverse_mask=args.reverse_mask,
    )