import os
import logging
from config import WEAVE_PROJECT, WANDB_API_KEY
import weave
from model_utils import get_model_summary, install_flash_attn

# Install required package
install_flash_attn()

weave.init(WEAVE_PROJECT)

# Resolve the logging level from the VISION_AGENT_LOG_LEVEL environment variable
def get_logging_level(default_level=logging.INFO):
    log_level_str = os.getenv('VISION_AGENT_LOG_LEVEL', '').upper()
    levels = {
        'DEBUG': logging.DEBUG,
        'INFO': logging.INFO,
        'WARNING': logging.WARNING,
        'ERROR': logging.ERROR,
        'CRITICAL': logging.CRITICAL,
    }
    return levels.get(log_level_str, default_level)

# Initialize logger
logging.basicConfig(level=get_logging_level(), format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('vision_agent')
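# Example: export VISION_AGENT_LOG_LEVEL=DEBUG (the variable read by
# get_logging_level above) for detailed logs, e.g.:
#   VISION_AGENT_LOG_LEVEL=DEBUG python app.py   # assumes this file is app.py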

from huggingface_hub import login
import gradio as gr
from pillow_heif import register_heif_opener
register_heif_opener()
from vision_agent.tools import owl_v2, grounding_dino, florencev2_object_detection, overlay_bounding_boxes

# Log in to the Hugging Face Hub with the token from the environment
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token, add_to_git_credential=True)
else:
    logger.warning("HF_TOKEN is not set; skipping Hugging Face Hub login.")

import numpy as np
from PIL import Image
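
# Shared helper for the three detector functions below. Each vision_agent
# detection tool is assumed (matching the original per-detector code) to
# return a list of dicts with 'label', 'score', and a normalized 'bbox'
# of (x1, y1, x2, y2); this converts them into the pixel-space
# ((x1, y1, x2, y2), caption) tuples that gr.AnnotatedImage expects.
def build_annotations(image, detections):
    height, width = image.shape[:2]
    annotations = []
    for detection in detections:
        label = detection['label']
        score = detection['score']
        x1, y1, x2, y2 = detection['bbox']
        # Convert normalized coordinates to pixel coordinates
        x1, y1, x2, y2 = int(x1 * width), int(y1 * height), int(x2 * width), int(y2 * height)
        annotations.append(((x1, y1, x2, y2), f"{label} {score:.2f}"))
    return annotations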

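# @weave.op() records each decorated call's inputs and outputs to the
# Weave project initialized via weave.init() above.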
@weave.op()
def detect_object_owlv2(image, seg_input, debug: bool = True):
    """
    Detects a brain tumor in the given image and returns the annotated image.

    Parameters:
        image: The input image (as numpy array provided by Gradio).
        seg_input: The segmentation input (not used in this function, but required for Gradio).
        debug (bool): Flag to enable logging for debugging purposes.

    Returns:
        tuple: (numpy array of image, list of (label, (x1, y1, x2, y2)) tuples)
    """

    # Step 2: Detect brain tumor using owl_v2
    prompt = seg_input
    detections = owl_v2(prompt, image)

    # Step 3: Overlay bounding boxes on the image
    image_with_bboxes = overlay_bounding_boxes(image, detections)

    # Prepare annotations for AnnotatedImage output
    annotations = []
    for detection in detections:
        label = detection['label']
        score = detection['score']
        bbox = detection['bbox']
        x1, y1, x2, y2 = bbox
        # Convert normalized coordinates to pixel coordinates
        height, width = image.shape[:2]
        x1, y1, x2, y2 = int(x1*width), int(y1*height), int(x2*width), int(y2*height)
        annotations.append(((x1, y1, x2, y2), f"{label} {score:.2f}"))

    # Convert image to numpy array if it's not already
    if isinstance(image_with_bboxes, Image.Image):
        image_with_bboxes = np.array(image_with_bboxes)

    return (image_with_bboxes, annotations)

@weave.op()
def detect_object_dino(image, seg_input, debug: bool = True):
    """
    Detects a brain tumor in the given image and returns the annotated image.

    Parameters:
        image: The input image (as numpy array provided by Gradio).
        seg_input: The segmentation input (not used in this function, but required for Gradio).
        debug (bool): Flag to enable logging for debugging purposes.

    Returns:
        tuple: (numpy array of image, list of (label, (x1, y1, x2, y2)) tuples)
    """

    # Step 2: Detect brain tumor using grounding_dino
    prompt = seg_input
    detections = grounding_dino(prompt, image)

    # Step 3: Overlay bounding boxes on the image
    image_with_bboxes = overlay_bounding_boxes(image, detections)

    # Prepare annotations for AnnotatedImage output
    annotations = []
    for detection in detections:
        label = detection['label']
        score = detection['score']
        bbox = detection['bbox']
        x1, y1, x2, y2 = bbox
        # Convert normalized coordinates to pixel coordinates
        height, width = image.shape[:2]
        x1, y1, x2, y2 = int(x1*width), int(y1*height), int(x2*width), int(y2*height)
        annotations.append(((x1, y1, x2, y2), f"{label} {score:.2f}"))

    # Convert image to numpy array if it's not already
    if isinstance(image_with_bboxes, Image.Image):
        image_with_bboxes = np.array(image_with_bboxes)

    return (image_with_bboxes, annotations)

@weave.op()
def detect_object_florence2(image, seg_input, debug: bool = True):
    """
    Detects a brain tumor in the given image and returns the annotated image.

    Parameters:
        image: The input image (as numpy array provided by Gradio).
        seg_input: The segmentation input (not used in this function, but required for Gradio).
        debug (bool): Flag to enable logging for debugging purposes.

    Returns:
        tuple: (numpy array of image, list of (label, (x1, y1, x2, y2)) tuples)
    """

    # Step 2: Detect brain tumor using florencev2 - NO PROMPT
    detections = florencev2_object_detection(image)

    # Step 3: Overlay bounding boxes on the image
    image_with_bboxes = overlay_bounding_boxes(image, detections)

    # Prepare annotations for AnnotatedImage output
    annotations = []
    for detection in detections:
        label = detection['label']
        score = detection['score']
        bbox = detection['bbox']
        x1, y1, x2, y2 = bbox
        # Convert normalized coordinates to pixel coordinates
        height, width = image.shape[:2]
        x1, y1, x2, y2 = int(x1*width), int(y1*height), int(x2*width), int(y2*height)
        annotations.append(((x1, y1, x2, y2), f"{label} {score:.2f}"))

    # Convert image to numpy array if it's not already
    if isinstance(image_with_bboxes, Image.Image):
        image_with_bboxes = np.array(image_with_bboxes)

    return (image_with_bboxes, annotations)

def handle_model_summary(model_name):
    model_summary, error_message = get_model_summary(model_name)
    if error_message:
        # Route failures to the error box rather than the architecture box
        return "", error_message
    return model_summary, ""

INTRO_TEXT = """# 🔬🧠 OmniScience -- Agentic Imaging Analysis 🤖🧫

- These results come from the base, non-finetuned models.
"""

with gr.Blocks(theme="sudeepshouche/minimalist") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Object Detection - Owl V2"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="numpy")
                seg_input = gr.Text(label="Entities to Segment/Detect")
        
            with gr.Column():
                annotated_image = gr.AnnotatedImage(label="Output")

        seg_btn = gr.Button("Submit")
        examples = [
            ["./examples/BloodImage_00099_jpg.rf.0a65e56401cdd71253e7bc04917c3558.jpg", "detect blood cell"],
            ["./examples/15_242_212_25_25_jpg.rf.f6bbadf4260dd2c1f5b4ace1b09b0a1b.jpg", "detect liver disease"],
            ["./examples/194_jpg.rf.3e3dd592d034bb5ee27a978553819f42.jpg", "detect brain tumor"],
            ["./examples/239_jpg.rf.3dcc0799277fb78a2ab21db7761ccaeb.jpg", "detect brain tumor"],
            ["./examples/2871_jpg.rf.3b6eadfbb369abc2b3bcb52b406b74f2.jpg", "detect brain tumor"],
            ["./examples/2921_jpg.rf.3b952f91f27a6248091e7601c22323ad.jpg", "detect brain tumor"],
        ]
        gr.Examples(
            examples=examples,
            inputs=[image, seg_input],
        )
        seg_inputs = [image, seg_input]
        seg_outputs = [annotated_image]
        seg_btn.click(
            fn=detect_object_owlv2,
            inputs=seg_inputs,
            outputs=seg_outputs,
        )

    with gr.Tab("Object Detection - DINO"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="numpy")
                seg_input = gr.Text(label="Entities to Segment/Detect")
        
            with gr.Column():
                annotated_image = gr.AnnotatedImage(label="Output")

        seg_btn = gr.Button("Submit")
        examples = [
            ["./examples/BloodImage_00099_jpg.rf.0a65e56401cdd71253e7bc04917c3558.jpg", "detect blood cell"],
            ["./examples/15_242_212_25_25_jpg.rf.f6bbadf4260dd2c1f5b4ace1b09b0a1b.jpg", "detect liver disease"],
            ["./examples/194_jpg.rf.3e3dd592d034bb5ee27a978553819f42.jpg", "detect brain tumor"],
            ["./examples/239_jpg.rf.3dcc0799277fb78a2ab21db7761ccaeb.jpg", "detect brain tumor"],
            ["./examples/2871_jpg.rf.3b6eadfbb369abc2b3bcb52b406b74f2.jpg", "detect brain tumor"],
            ["./examples/2921_jpg.rf.3b952f91f27a6248091e7601c22323ad.jpg", "detect brain tumor"],
        ]
        gr.Examples(
            examples=examples,
            inputs=[image, seg_input],
        )
        seg_inputs = [image, seg_input]
        seg_outputs = [annotated_image]
        seg_btn.click(
            fn=detect_object_dino,
            inputs=seg_inputs,
            outputs=seg_outputs,
        )

    with gr.Tab("Object Detection - Florence2"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="numpy")
                seg_input = gr.Text(label="Entities to Segment/Detect")
        
            with gr.Column():
                annotated_image = gr.AnnotatedImage(label="Output")

        seg_btn = gr.Button("Submit")
        # "<OD>" is Florence-2's object-detection task token; it is shown
        # here for reference, but detect_object_florence2 ignores the text
        # input and always runs prompt-free detection.
        examples = [
            ["./examples/BloodImage_00099_jpg.rf.0a65e56401cdd71253e7bc04917c3558.jpg", "<OD>"],
            ["./examples/15_242_212_25_25_jpg.rf.f6bbadf4260dd2c1f5b4ace1b09b0a1b.jpg", "<OD>"],
            ["./examples/194_jpg.rf.3e3dd592d034bb5ee27a978553819f42.jpg", "<OD>"],
            ["./examples/239_jpg.rf.3dcc0799277fb78a2ab21db7761ccaeb.jpg", "<OD>"],
            ["./examples/2871_jpg.rf.3b6eadfbb369abc2b3bcb52b406b74f2.jpg", "<OD>"],
            ["./examples/2921_jpg.rf.3b952f91f27a6248091e7601c22323ad.jpg", "<OD>"],
        ]
        gr.Examples(
            examples=examples,
            inputs=[image, seg_input],
        )
        seg_inputs = [image, seg_input]
        seg_outputs = [annotated_image]
        seg_btn.click(
            fn=detect_object_florence2,
            inputs=seg_inputs,
            outputs=seg_outputs,
        )

    with gr.Tab("Model Explorer"):
        gr.Markdown("## Retrieve and Display Model Architecture")
        model_name_input = gr.Textbox(label="Model Name", placeholder="Enter the model name to retrieve its architecture...")
        vision_examples = gr.Examples(
            examples=[
                ["facebook/sam-vit-huge"],
                ["google/owlv2-base-patch16-ensemble"],
                ["IDEA-Research/grounding-dino-base"],
                ["microsoft/Florence-2-large-ft"],
                ["google/paligemma-3b-mix-224"],
                ["llava-hf/llava-v1.6-mistral-7b-hf"],
                ["vikhyatk/moondream2"],
                ["microsoft/Phi-3-vision-128k-instruct"],
                ["HuggingFaceM4/idefics2-8b-chatty"]
            ],
            inputs=model_name_input
        )
        model_submit_button = gr.Button("Submit")
        model_output = gr.Textbox(label="Model Architecture", lines=20, placeholder="Model architecture will appear here...", show_copy_button=True)
        error_output = gr.Textbox(label="Error", lines=10, placeholder="Exceptions will appear here...", show_copy_button=True)
        model_submit_button.click(fn=handle_model_summary, inputs=model_name_input, outputs=[model_output, error_output])

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)