import os

import cv2 as cv
import moviepy.editor as mpe
import numpy as np
import supervision as sv
import torch
from hyper import hp
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from PIL import Image
from tqdm import tqdm


def detect(frame, model, processor, confidence_threshold):
    """
    args:
        image: PIL image
        model: PreTrainedModel
        processor: PreTrainedProcessor
        confidence_threshold: float
    returns:
        results: dict with keys "boxes", "labels", "scores"


    examples:
    [
        {
            "scores": tensor([0.9980, 0.9039, 0.7575, 0.9033]),
            "labels": tensor([86, 64, 67, 67]),
            "boxes": tensor(
                [
                    [1.1582e03, 1.1893e03, 1.9373e03, 1.9681e03],
                    [2.4274e02, 1.3234e02, 2.5919e03, 1.9628e03],
                    [1.1107e-01, 1.5105e03, 3.1980e03, 2.1076e03],
                    [7.1036e-01, 1.7360e03, 3.1970e03, 2.1100e03],
                ]
            ),
        }
    ]
    """
    inputs = processor(images=frame, return_tensors="pt").to(hp.device)
    with torch.no_grad():
        outputs = model(**inputs)
    target_sizes = torch.tensor([frame.size[::-1]])

    results = processor.post_process_object_detection(
        outputs=outputs, threshold=confidence_threshold, target_sizes=target_sizes
    )
    return results


def get_len_frames(viedo_path):
    """
    args:
        viedo_path: str
    returns:
        int: the number of frames in the video
    examples:
        get_len_frames("../demo_video/aerial.mp4") # 1478
    """
    video_info = sv.VideoInfo.from_video_path(viedo_path)
    return video_info.total_frames


def track(detected_result, tracker: sv.ByteTrack):
    """
    args:
        detected_result: dict with keys "boxes", "labels", "scores"
        tracker: sv.ByteTrack
    returns:
        tracked_result: dict with keys "boxes", "labels", "scores"
    examples:
        from transformers import DetrImageProcessor, DetrForObjectDetection

        processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
        model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

        tracker = sv.ByteTrack()

        image = Image.open("ZJF990.jpg")

        detected_result = detect(image, model, processor, hp.confidence_threshold)
        tracked_result = track(detected_result, tracker)

        print(detected_result)
        print(tracked_result)


        [
            {
                "scores": tensor([0.9980, 0.9039, 0.7575, 0.9033]),
                "labels": tensor([86, 64, 67, 67]),
                "boxes": tensor(
                    [
                        [1.1582e03, 1.1893e03, 1.9373e03, 1.9681e03],
                        [2.4274e02, 1.3234e02, 2.5919e03, 1.9628e03],
                        [1.1107e-01, 1.5105e03, 3.1980e03, 2.1076e03],
                        [7.1036e-01, 1.7360e03, 3.1970e03, 2.1100e03],
                    ]
                ),
            }
        ]



        Detections(
            xyxy=array(
                [
                    [1.1581914e03, 1.1892766e03, 1.9372931e03, 1.9680990e03],
                    [2.4273552e02, 1.3233553e02, 2.5918860e03, 1.9628494e03],
                    [1.1106834e-01, 1.5105106e03, 3.1980032e03, 2.1075664e03],
                    [7.1036065e-01, 1.7359819e03, 3.1970449e03, 2.1100107e03],
                ],
                dtype=float32,
            ),
            mask=None,
            confidence=array([0.9980374, 0.9038882, 0.7575455, 0.9032779], dtype=float32),
            class_id=array([86, 64, 67, 67]),
            tracker_id=array([1, 2, 3, 4]),
            data={},
        )


    """

    detections = sv.Detections.from_transformers(detected_result[0])
    detections = tracker.update_with_detections(detections)
    return detections


def annotate_image(
    frame,
    detections,
    labels,
    mask_annotator: sv.MaskAnnotator,
    bbox_annotator: sv.BoxAnnotator,
    label_annotator: sv.LabelAnnotator,
) -> np.ndarray:
    out_frame = mask_annotator.annotate(frame, detections)
    out_frame = bbox_annotator.annotate(out_frame, detections)
    out_frame = label_annotator.annotate(out_frame, detections, labels=labels)
    return out_frame


def detect_and_track(
    video_path,
    model,
    processor,
    tracker,
    confidence_threshold,
    mask_annotator: sv.MaskAnnotator,
    bbox_annotator: sv.BoxAnnotator,
    label_annotator: sv.LabelAnnotator,
):
    video_info = sv.VideoInfo.from_video_path(video_path)
    fps = video_info.fps
    len_frames = video_info.total_frames

    frames_loader = sv.get_video_frames_generator(video_path, end=len_frames)

    result_file_name = "output.mp4"
    original_file_name = "original.mp4"
    combined_file_name = "combined.mp4"
    result_file_path = os.path.join("./output/", result_file_name)
    original_file_path = os.path.join("./output/", original_file_name)
    combined_file_name = os.path.join("./output/", combined_file_name)

    concated_frames = []
    original_frames = []
    for frame in tqdm(frames_loader, total=len_frames):
        results = detect(Image.fromarray(frame), model, processor, confidence_threshold)
        tracked_results = track(results, tracker)
        frame = cv.cvtColor(frame, cv.COLOR_RGB2BGR)
        original_frames.append(frame.copy())
        scores = tracked_results.confidence.tolist()
        labels = tracked_results.class_id.tolist()

        frame = annotate_image(
            frame,
            tracked_results,
            labels=[
                str(f"{model.config.id2label[label]}:{score:.2f}")
                for label, score in zip(labels, scores)
            ],
            mask_annotator=mask_annotator,
            bbox_annotator=bbox_annotator,
            label_annotator=label_annotator,
        )
        concated_frames.append(frame)  # Add the processed frame to the list

    # Create a MoviePy video clip from the list of frames
    original_video = mpe.ImageSequenceClip(original_frames, fps=fps)
    original_video.write_videofile(original_file_path, codec="libx264", fps=fps)
    concated_video = mpe.ImageSequenceClip(concated_frames, fps=fps)
    concated_video.write_videofile(result_file_path, codec="libx264", fps=fps)

    combined_video = combine_frames(original_frames, concated_frames, fps)
    combined_video.write_videofile(combined_file_name, codec="libx264", fps=fps)
    return result_file_path, combined_file_name


def combine_frames(frames_list1, frames_list2, fps):
    """
    args:
        frames_list1: list of PIL images
        frames_list2: list of PIL images
    returns:
        final_clip: moviepy video clip
    """
    clip1 = ImageSequenceClip(frames_list1, fps=fps)
    clip2 = ImageSequenceClip(frames_list2, fps=fps)

    final_clip = mpe.clips_array([[clip1, clip2]])

    return final_clip