import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import cv2
import numpy as np
import gradio as gr
import spaces

# Load the model and processor
def load_model(model_id="Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.float16):
    """Load the Qwen2-VL model together with its matching processor.

    Args:
        model_id: Hugging Face Hub id (or local path) used for BOTH the model
            and the processor, so the two can never get out of sync.
        torch_dtype: dtype for the model weights (float16 by default).

    Returns:
        ``(model, processor)``. No device move happens here; the request
        handlers call ``model.to("cuda")`` before generating.
    """
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
    )
    processor = AutoProcessor.from_pretrained(model_id)
    return model, processor

# Load the weights once at import time; every request handler below reuses
# these module-level objects.
model, processor = load_model()

# Guidelines prepended to every user request in process_content. NOTE: they
# are sent inside the *user* turn (as "<SYSTEM_PROMPT>\n\nUser request: ..."),
# not as a separate system-role chat message.
SYSTEM_PROMPT = """You are an AI assistant specialized in analyzing images and videos of code editors like VSCode. Your primary task is to identify and extract code snippets visible in these visual inputs. Follow these guidelines:

1. Focus on recognizing and extracting code, ignoring non-code elements like toolbars, file explorers, or terminals.
2. If multiple code snippets or files are visible, extract and separate them clearly.
3. Preserve the syntax, indentation, and structure of the code as seen in the image/video.
4. If you detect syntax errors or warning highlights in the code, mention them but do not attempt to correct the code.
5. If the code is partially visible or cut off, extract what you can see and indicate where the code might be incomplete.
6. Identify the programming language if possible, based on syntax or file extensions visible.
7. If asked about specific elements (e.g., "What's on line 20?"), focus on that particular part of the code.
8. Do not invent or assume code that isn't visible in the image/video.
9. If the image/video doesn't contain code or an IDE, politely inform the user.

Always strive for accuracy in code extraction, as the user may need to directly use or analyze the extracted code."""

@spaces.GPU
def process_content(content, predefined_prompt, custom_prompt):
    """Route an uploaded image or video to the matching Qwen2-VL handler.

    Args:
        content: Value produced by the ``gr.File`` input. Depending on the
            installed Gradio version this is either a plain file-path string
            or a tempfile-like object exposing the path as ``.name``; both are
            accepted. ``None`` when nothing was uploaded.
        predefined_prompt: Selected dropdown entry, or ``None``/empty.
        custom_prompt: Free-text prompt; used only when no predefined prompt
            is selected.

    Returns:
        The model's text response, or a user-facing message when the upload
        is missing or has an unsupported extension.
    """
    if content is None:
        return "Please upload an image or video file of a code editor."

    # Gradio 4+ passes a path string by default; older versions pass a
    # tempfile wrapper carrying the path in ``.name``. Accept either.
    path = getattr(content, "name", content)

    # Pick ONE prompt: the dropdown wins over the textbox. If both are empty,
    # fall back to a plain extraction request instead of sending
    # "User request: None" / an empty request to the model.
    user_prompt = (predefined_prompt or custom_prompt or "").strip()
    if not user_prompt:
        user_prompt = "Extract all visible code from this image or video."
    full_prompt = f"{SYSTEM_PROMPT}\n\nUser request: {user_prompt}"

    lowered = path.lower()
    if lowered.endswith(('.png', '.jpg', '.jpeg')):
        # Image.open is lazy and keeps the file handle open; copy() forces
        # the pixels into memory so the handle can be closed right away.
        with Image.open(path) as img:
            image = img.copy()
        return process_image(image, full_prompt)
    elif lowered.endswith(('.mp4', '.avi', '.mov')):
        return process_video(content, full_prompt)
    else:
        return "Unsupported file type. Please provide an image or video file of a code editor."

@spaces.GPU
def process_image(image, prompt, max_new_tokens=256):
    """Run Qwen2-VL on a single image and return the generated text.

    Args:
        image: Image placed in the chat message and handed to
            ``process_vision_info`` (process_content passes a PIL image).
        prompt: Complete text prompt (guidelines + user request).
        max_new_tokens: Generation cap. Defaults to the previously hard-coded
            256. Raise it for long files, since output stops at this limit.

    Returns:
        The decoded response with the prompt tokens stripped off.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    # The model is moved to the GPU per call rather than at load time —
    # presumably so CUDA is only touched inside @spaces.GPU calls; confirm
    # before moving this to module level.
    model.to("cuda")
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the echoed prompt tokens so only the newly generated part is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_text[0]

def _sample_video_frames(path, max_frames, frame_interval, max_resolution):
    """Decode every ``frame_interval``-th frame of *path* into RGB PIL images.

    Stops after ``max_frames`` frames or at end of stream. Frames whose longer
    side exceeds ``max_resolution`` are downscaled with the aspect ratio kept;
    smaller frames are left untouched. Returns an empty list when the file
    cannot be opened or yields no frames.
    """
    frames = []
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            return frames
        index = 0
        while len(frames) < max_frames:
            if index % frame_interval != 0:
                # grab() advances without decoding; cheaper than read() for
                # frames that would be thrown away anyway.
                if not cap.grab():
                    break
            else:
                ok, frame = cap.read()
                if not ok:
                    break
                h, w = frame.shape[:2]
                if max(h, w) > max_resolution:
                    if h > w:
                        new_h, new_w = max_resolution, max(1, int(w * max_resolution / h))
                    else:
                        new_h, new_w = max(1, int(h * max_resolution / w)), max_resolution
                    # INTER_AREA is OpenCV's recommended filter for shrinking.
                    frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
                # OpenCV decodes to BGR; PIL / the processor expect RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            index += 1
    finally:
        # Release the capture even if decoding or resizing raises.
        cap.release()
    return frames


@spaces.GPU
def process_video(video, prompt, max_frames=16, frame_interval=30, max_resolution=224, max_new_tokens=256):
    """Sample frames from a video and run Qwen2-VL on them as one video input.

    Args:
        video: Path string, or an object exposing the path as ``.name``
            (the two forms ``gr.File`` yields across Gradio versions).
        prompt: Complete text prompt (guidelines + user request).
        max_frames: Maximum number of frames sent to the model.
        frame_interval: Keep one frame out of every ``frame_interval``.
        max_resolution: Longest side, in pixels, of each sampled frame.
            Larger frames are downscaled; smaller ones are not upscaled.
        max_new_tokens: Generation cap (previously hard-coded to 256).

    Returns:
        The decoded model response, or a user-facing message when no frame
        could be read from the file.

    Raises:
        ValueError: if ``frame_interval`` is less than 1.
    """
    if frame_interval < 1:
        raise ValueError("frame_interval must be a positive integer")

    path = getattr(video, "name", video)
    frames = _sample_video_frames(path, max_frames, frame_interval, max_resolution)
    if not frames:
        # An empty frame list would otherwise fail deep inside the processor.
        return "Could not read any frames from the video. Please upload a valid video file of a code editor."

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "video", "video": frames},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    model.to("cuda")
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the echoed prompt tokens so only the newly generated part is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_text[0]

# Canned user requests shown in the dropdown. When one is selected it takes
# precedence over the custom-prompt textbox (see process_content). The text is
# appended verbatim after SYSTEM_PROMPT, so wording here is model-facing.
PREDEFINED_PROMPTS = [
    "Extract all visible code from this IDE screenshot.",
    "What programming language is used in this code editor image?",
    "Are there any syntax errors highlighted in this VSCode recording?",
    "Extract the function definition starting at line 15 in this image.",
    "What are the variable names defined in the visible code?",
    "Extract and separate all different code snippets or files visible in this IDE recording.",
    "Is there any commented code in this screenshot? If so, extract it.",
    "What coding style or convention is being used in this code (e.g., camelCase, snake_case)?",
    "Are there any import statements or library inclusions visible in this IDE image?",
    "Extract only the CSS code visible in this multi-file editor screenshot."
]

# Gradio interface: one upload plus two prompt inputs, answered by process_content.
iface = gr.Interface(
    fn=process_content,
    inputs=[
        # Limit the picker to exactly the extensions process_content routes,
        # so users are not invited to upload files that will be rejected.
        gr.File(
            label="Upload Image or Video of Code Editor",
            file_types=[".png", ".jpg", ".jpeg", ".mp4", ".avi", ".mov"],
        ),
        gr.Dropdown(choices=PREDEFINED_PROMPTS, label="Select a predefined prompt", type="value"),
        gr.Textbox(label="Or enter a custom prompt"),
    ],
    outputs="text",
    title="Code Extraction from IDE Screenshots/Recordings",
    description="Upload an image or video of a code editor and select a predefined prompt or enter a custom one to extract and analyze the code.",
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()