import gradio as gr
import torch
import numpy as np
from transformers import OwlViTProcessor, OwlViTForObjectDetection, ResNetModel
from torchvision import transforms
from PIL import Image
import cv2
import torch.nn.functional as F
import tempfile
import os

# Load models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-50 backbone used to embed image crops for appearance comparison
resnet = ResNetModel.from_pretrained("microsoft/resnet-50").to(device)
resnet.eval()

# OWL-ViT for image-guided (query-image) object detection
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
model.eval()

# Preprocess a PIL image into a (1, 3, 224, 224) tensor normalized with ImageNet statistics
def preprocess_image(image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # convert("RGB") guards against grayscale or RGBA uploads
    return transform(image.convert("RGB")).unsqueeze(0)

def extract_embedding(image):
    image_tensor = preprocess_image(image).to(device)
    with torch.no_grad():
        output = resnet(image_tensor)
        # pooler_output is (1, 2048, 1, 1); flatten to (1, 2048) for the metrics below
        embedding = output.pooler_output.flatten(1)
    return embedding

def cosine_similarity(embedding1, embedding2):
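    # Cosine similarity along the feature dimension; values near 1 indicate a close match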
    return F.cosine_similarity(embedding1, embedding2)

def l2_distance(embedding1, embedding2):
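    # Euclidean (L2) distance between embeddings; smaller means more similar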
    return torch.norm(embedding1 - embedding2, p=2)

def save_array_to_temp_image(arr):
    # Expects a BGR (OpenCV channel order) array; saves it as an RGB PNG and returns the path
    rgb_arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(rgb_arr)
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
    temp_file.close()
    img.save(temp_file.name)
    return temp_file.name

def detect_and_crop(target_image, query_image, threshold=0.6, nms_threshold=0.3):
    # Run OWL-ViT image-guided detection: find regions of target_image that match query_image
    target_sizes = torch.tensor([target_image.size[::-1]])
    inputs = processor(images=target_image, query_images=query_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)

    # PIL images are RGB; swap to BGR so the crops follow OpenCV channel order
    # (save_array_to_temp_image converts them back to RGB before saving)
    img = cv2.cvtColor(np.array(target_image), cv2.COLOR_RGB2BGR)
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
    
    results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes)
    boxes, scores = results[0]["boxes"], results[0]["scores"]

    if len(boxes) == 0:
        return []

    filtered_boxes = []
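    # Crop each detected box from the frame; skip degenerate (zero-area) crops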
    for box in boxes:
        x1, y1, x2, y2 = [int(i) for i in box.tolist()]
        cropped_img = img[y1:y2, x1:x2]
        if cropped_img.size != 0:
            filtered_boxes.append(cropped_img)

    return filtered_boxes

def process_video(video_path, query_image, skipframes=0):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return []

    # The query embedding is identical for every frame, so compute it once up front
    query_embedding = extract_embedding(query_image)

    frame_count = 0
    all_results = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Sample every (skipframes + 1)-th frame
        if frame_count % (skipframes + 1) == 0:
            frame_file = save_array_to_temp_image(frame)
            frame_img = Image.open(frame_file).convert("RGB")  # force load so the temp file can be deleted
            os.remove(frame_file)
            result_frames = detect_and_crop(frame_img, query_image)
            for res in result_frames:
                saved_res = save_array_to_temp_image(res)
                crop_img = Image.open(saved_res).convert("RGB")
                os.remove(saved_res)
                crop_embedding = extract_embedding(crop_img)
                dist = l2_distance(query_embedding, crop_embedding).item()
                cos = cosine_similarity(query_embedding, crop_embedding).item()
                all_results.append({'l2_dist': dist, 'cos': cos})
        frame_count += 1
    cap.release()
    return all_results

def process_videos_and_compare(image, video, skipframes=5, threshold=0.47):
    # Median of an already-sorted list (callers pass sorted(...))
    def median(values):
        n = len(values)
        return (values[n // 2 - 1] + values[n // 2]) / 2 if n % 2 == 0 else values[n // 2]

    results = process_video(video, image, skipframes)
    if results:
        l2_dists = [item['l2_dist'] for item in results]
        cosines = [item['cos'] for item in results]
        avg_l2_dist = sum(l2_dists) / len(l2_dists)
        avg_cos = sum(cosines) / len(cosines)
        median_l2_dist = median(sorted(l2_dists))
        median_cos = median(sorted(cosines))
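        # Cosine distance is 1 - cosine similarity; presence is judged on avg_cos vs. threshold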
        result = {
            "avg_l2_dist": avg_l2_dist,
            "avg_cos": avg_cos,
            "median_l2_dist": median_l2_dist,
            "median_cos": median_cos,
            "avg_cos_dist": 1 - avg_cos,
            "median_cos_dist": 1 - median_cos,
            "is_present": avg_cos >= threshold
        }
    else:
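        # No detections in any sampled frame: report worst-case sentinel values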
        result = {
            "avg_l2_dist": float('inf'),
            "avg_cos": 0,
            "median_l2_dist": float('inf'),
            "median_cos": 0,
            "avg_cos_dist": float('inf'),
            "median_cos_dist": float('inf'),
            "is_present": False
        }
    return result

def interface(video, image, skipframes, threshold):
    result = process_videos_and_compare(image, video, skipframes, threshold)
    return result
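
# A minimal sketch of calling the pipeline without the Gradio UI (file names are
# hypothetical, shown only to illustrate the expected argument types):
#
#     query = Image.open("query.png")
#     summary = process_videos_and_compare(query, "clip.mp4", skipframes=5, threshold=0.47)
#     print(summary["is_present"], summary["avg_cos"])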

iface = gr.Interface(
    fn=interface,
    inputs=[
        gr.Video(label="Upload a Video"),
        gr.Image(type="pil", label="Upload a Query Image"),
        gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Skip Frames"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.47, label="Threshold")
    ],
    outputs=[
        gr.JSON(label="Result")
    ],
    title="Object Detection in Video",
    description="""
    **Instructions:**

    1. **Upload a Video**: Select a video file to upload. 
    2. **Upload a Query Image**: Select an image file that contains the object you want to detect in the video.
    3. **Set Skip Frames**: Adjust the slider to set how many frames to skip between processed frames.
    4. **Set Threshold**: Adjust the slider to set the threshold for cosine similarity to determine if the object is present in the video.
    5. **View Results**: The result will show the average and median distances and similarities, and whether the object is present in the video based on the threshold.
    """
)

if __name__ == "__main__":
    iface.launch()