import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import cv2
import numpy as np
import os
import torch

# --- Runtime diagnostics and global torch configuration ----------------------
# Report the PyTorch build, then details of the visible CUDA device (if any).
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Device count: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name()}")
    # Make newly created float tensors default to CUDA.
    # NOTE(review): torch.set_default_tensor_type is deprecated since PyTorch 2.1
    # (torch.set_default_device / torch.set_default_dtype replace it). The model
    # is placed via device_map and inputs are moved explicitly, so this global
    # default may be unnecessary — confirm before removing.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Permit TF32 math in matmul and cuDNN kernels (applied whether or not CUDA is present).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


def load_model():
    """Load the Qwen2-VL-2B-Instruct model and its processor.

    Uses float16 weights when CUDA is available and float32 otherwise, and lets
    ``device_map="auto"`` decide where the weights are placed.

    Returns:
        ``(model, processor, device)`` where ``device`` is ``cuda`` when CUDA is
        available, else ``cpu``. Any exception is printed and
        ``(None, None, None)`` is returned instead of raising.
    """
    checkpoint = "Qwen/Qwen2-VL-2B-Instruct"
    try:
        has_cuda = torch.cuda.is_available()
        target = torch.device("cuda") if has_cuda else torch.device("cpu")
        print(f"Using device: {target}")

        # NOTE(review): with device_map="auto" the actual weight placement is
        # decided by transformers/accelerate; `target` only reflects CUDA availability.
        weights_dtype = torch.float16 if has_cuda else torch.float32
        vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
            checkpoint,
            torch_dtype=weights_dtype,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        vl_processor = AutoProcessor.from_pretrained(checkpoint)
    except Exception as err:
        print(f"Error loading model: {err}")
        return None, None, None
    return vl_model, vl_processor, target

# Load the model once at import time; the request handlers below read these
# globals. On failure load_model() returns (None, None, None).
model, processor, device = load_model()

# System-turn instructions sent ahead of every user message to the model.
SYSTEM_PROMPT = (
    "You are an expert technical analyst specializing in identifying bugs, fixing errors, "
    "and explaining code functions from visual inputs. When presented with an image or video:\n"
    "1. If you see code, analyze it for potential bugs or errors, and suggest fixes.\n"
    "2. If you see a function or algorithm, explain its purpose and how it works.\n"
    "3. If you see a technical diagram or flowchart, interpret its meaning and purpose.\n"
    "4. For any technical content, provide detailed explanations and insights.\n"
    "Always maintain a professional and technical tone in your responses."
)

def process_content(file, user_prompt):
    """Dispatch an uploaded file to the image or video analyzer by extension.

    Args:
        file: Value produced by the ``gr.File`` input. This is ``None`` when
            nothing was uploaded, a path string or ``os.PathLike`` (e.g. Gradio 4's
            default ``type="filepath"``), or a tempfile-like object exposing
            ``.name`` (older Gradio versions).
        user_prompt: The user's question, forwarded to the analyzer.

    Returns:
        The analyzer's answer, or a human-readable message when the upload is
        missing, has an unsupported extension, or cannot be opened as an image.
    """
    if file is None:
        return "No content provided. Please upload an image or video of technical content."

    # Accept both a plain path and a tempfile-style wrapper with `.name`.
    if isinstance(file, (str, os.PathLike)):
        file_path = os.fspath(file)
    else:
        file_path = file.name
    file_extension = os.path.splitext(file_path)[1].lower()

    if file_extension in ('.jpg', '.jpeg', '.png', '.bmp'):
        try:
            # Image.open is lazy and keeps the file open; copy() loads the
            # pixels into memory so the handle is closed when the `with` exits.
            with Image.open(file_path) as img:
                image = img.copy()
        except OSError as e:  # includes PIL's UnidentifiedImageError
            return f"Error opening image: {e}"
        return analyze_image(image, user_prompt)

    if file_extension in ('.mp4', '.avi', '.mov'):
        return analyze_video(file_path, user_prompt)

    return "Unsupported file type. Please provide an image (jpg, jpeg, png, bmp) or video (mp4, avi, mov) of technical content."

def analyze_image(image, prompt):
    """Ask the model about a single image.

    Builds a two-turn chat (system prompt, then the image plus the user's
    question) and returns whatever ``generate_response`` returns for it.
    """
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    question = f"Based on the system instructions, {prompt}"
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": question},
        ],
    }
    return generate_response([system_turn, user_turn])

def _frame_to_pil(frame, max_resolution):
    """Resize a BGR frame so its longer side is ``max_resolution`` and return an RGB PIL image."""
    h, w = frame.shape[:2]
    # Preserve aspect ratio; clamp the short side to >= 1 px so very elongated
    # frames never yield a zero-size resize target.
    if h > w:
        new_h, new_w = max_resolution, max(1, int(w * max_resolution / h))
    else:
        new_h, new_w = max(1, int(h * max_resolution / w)), max_resolution
    resized = cv2.resize(frame, (new_w, new_h))  # cv2 takes (width, height)
    # OpenCV frames are BGR; convert so Image.fromarray sees RGB channel order.
    return Image.fromarray(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))


def analyze_video(video_path, prompt, max_frames=16, frame_interval=30, max_resolution=224):
    """Sample frames from a video file and ask the model about them.

    Frames 0, ``frame_interval``, ``2 * frame_interval``, ... are kept until
    ``max_frames`` have been collected or the video ends. Each kept frame is
    resized so its longer side is ``max_resolution`` pixels and converted to an
    RGB PIL image.

    Args:
        video_path: Path to the video file.
        prompt: The user's question about the video.
        max_frames: Maximum number of frames sent to the model.
        frame_interval: Keep every N-th frame (values < 1 are treated as 1).
        max_resolution: Target length, in pixels, of each frame's longer side.

    Returns:
        The model's answer, or an error message string. Exceptions are caught
        and reported as ``"Error processing video: ..."``.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return "Error: Could not open video file."

        step = max(1, frame_interval)  # avoid modulo-by-zero for frame_interval <= 0
        frames = []
        frame_count = 0
        while len(frames) < max_frames:
            if frame_count % step == 0:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(_frame_to_pil(frame, max_resolution))
            # For discarded frames, grab() advances the stream without the
            # retrieve() step that read() also performs.
            elif not cap.grab():
                break
            frame_count += 1

        if not frames:
            return "Error: Could not read any frames from the video file."

        return generate_response([
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": frames},
                    {"type": "text", "text": f"Based on the system instructions, {prompt}"},
                ],
            }
        ])
    except Exception as e:
        return f"Error processing video: {e}"
    finally:
        if cap is not None:
            cap.release()
                

def generate_response(messages, max_new_tokens=512):
    """Generate the model's reply to a Qwen2-VL chat conversation.

    Args:
        messages: Chat turns as built by ``analyze_image`` / ``analyze_video``.
            Image and video entries are extracted with ``process_vision_info``.
        max_new_tokens: Maximum number of tokens to generate (default 512).

    Returns:
        The decoded reply for the single input sequence, or a string starting
        with ``"Error generating response: "`` if anything fails.
    """
    # load_model() sets both globals to None on failure; say so plainly rather
    # than surfacing an AttributeError on None.
    if model is None or processor is None:
        return "Error generating response: model failed to load; check the startup logs."

    try:
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Keep the processor's BatchFeature (a mapping) instead of rebuilding a
        # plain dict, so both `**inputs` and key access keep working after the
        # move. model.device is the device of the model's first parameter.
        inputs = inputs.to(model.device)

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_k=20,
                top_p=0.9,
                temperature=0.7,
            )

        # generate() output starts with the prompt tokens; keep only new tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]
    except Exception as e:
        return f"Error generating response: {e}"
    finally:
        # Release cached GPU memory after every request, including failed ones.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
# Gradio interface: a file upload plus a free-text question, answered as plain text.
_upload_box = gr.File(label="Upload Image or Video of Technical Content")
_question_box = gr.Textbox(
    label="Enter your technical question",
    placeholder="e.g., Identify any bugs in this code and suggest fixes",
    value="Analyze this technical content and provide insights.",
)

iface = gr.Interface(
    fn=process_content,
    inputs=[_upload_box, _question_box],
    outputs="text",
    title="Technical Content Analysis",
    description=(
        "Upload an image or video of code, diagrams, or technical content. "
        "Ask questions about bugs, errors, or explanations of functions."
    ),
)

# share=True additionally requests a publicly shareable link.
iface.launch(share=True)