import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
import cv2
import numpy as np
from PIL import Image
def progress_bar_html(label: str) -> str:
    """
    Returns an HTML snippet for a thin progress bar with a label.
    The progress bar is styled as a dark animated bar.
    """
    return f'''
    <div style="display:flex;align-items:center;">
      <span style="margin-right:10px;font-size:14px;">{label}</span>
      <div style="width:110px;height:5px;background:#888;border-radius:2px;overflow:hidden;">
        <div style="width:100%;height:100%;background:#1a1a1a;animation:loading 1.5s linear infinite;"></div>
      </div>
    </div>
    <style>@keyframes loading {{ 0% {{ transform: translateX(-100%); }} 100% {{ transform: translateX(100%); }} }}</style>
    '''
def downsample_video(video_path):
    """
    Downsamples the video to 10 evenly spaced frames.
    Each frame is converted to a PIL Image along with its timestamp.
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    # Sample 10 evenly spaced frames.
    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct" 
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
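# Load the model weights in bfloat16 and move them to the GPU for inference.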
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()
@spaces.GPU
def model_inference(input_dict, history):
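    """
    Chat handler for the gr.ChatInterface below.
    Queries starting with "@video-infer" are treated as video questions: the first
    uploaded file is downsampled into frames that are sent to the model with their
    timestamps. Any other query treats the uploaded files as images. The model's
    output is streamed back to the chat as it is generated.
    """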
    text = input_dict["text"]
    files = input_dict["files"]
    if text.strip().lower().startswith("@video-infer"):
        # Remove the tag from the query.
        text = text.strip()[len("@video-infer"):].strip()
        if not files:
            gr.Error("Please upload a video file along with your @video-infer query.")
            return
        # Assume the first file is a video.
        video_path = files[0]
        frames = downsample_video(video_path)
        if not frames:
            gr.Error("Could not process video.")
            return
        # Build messages: start with the text prompt.
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": text}]
            }
        ]
        # Append each frame with a timestamp label.
        for image, timestamp in frames:
            messages[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
            messages[0]["content"].append({"type": "image", "image": image})
        # Collect only the images from the frames.
        video_images = [image for image, _ in frames]
        # Prepare the prompt.
        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[prompt],
            images=video_images,
            return_tensors="pt",
            padding=True,
        ).to("cuda")
        # Set up streaming generation.
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        buffer = ""
        yield progress_bar_html("Processing video with Qwen2.5-VL")
        for new_text in streamer:
            buffer += new_text
            time.sleep(0.01)
            yield buffer
        return
    if len(files) > 1:
        images = [load_image(image) for image in files]
    elif len(files) == 1:
        images = [load_image(files[0])]
    else:
        images = []
    if text == "" and not images:
        gr.Error("Please input a query and optionally image(s).")
        return
    if text == "" and images:
        gr.Error("Please input a text query along with the image(s).")
        return
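    # Build a single user turn: all uploaded images first, then the text query.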
    messages = [
        {
            "role": "user",
            "content": [
                *[{"type": "image", "image": image} for image in images],
                {"type": "text", "text": text},
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    yield progress_bar_html("Processing with Qwen2.5-VL")
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
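# Example prompts; the referenced media files are expected to live in example_images/ next to this script.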
examples = [
    [{"text": "Describe the Image?", "files": ["example_images/document.jpg"]}],
    [{"text": "@video-infer Explain the content of the Advertisement", "files": ["example_images/videoplayback.mp4"]}],
    [{"text": "@video-infer Explain the content of the video in detail", "files": ["example_images/breakfast.mp4"]}],
    [{"text": "@video-infer Explain the content of the video.", "files": ["example_images/sky.mp4"]}],
]
demo = gr.ChatInterface(
    fn=model_inference,
    description="# **Qwen2.5-VL-7B-Instruct `@video-infer for video understanding`**",
    examples=examples,
    fill_height=True,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)
demo.launch(debug=True)