import os
import glob
import time
import threading
import requests
import wikipedia
import torch
import cv2
import numpy as np
from io import BytesIO
from PIL import Image
import base64  # Added import

import gradio as gr
from ultralytics import YOLO
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from diffusers import MarigoldDepthPipeline  # Updated import for depth model
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet

# Set environment variable for PyTorch MPS fallback before importing torch
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Initialize Models
def initialize_models():
    models = {}
    
    # Device detection
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'
    models['device'] = device

    print(f"Using device: {device}")

    # Initialize the RoBERTa model for question answering
    try:
        models['qa_pipeline'] = pipeline(
            "question-answering", model="deepset/roberta-base-squad2", device=0 if device == 'cuda' else -1)
        print("RoBERTa QA pipeline initialized.")
    except Exception as e:
        print(f"Error initializing the RoBERTa model: {e}")
        models['qa_pipeline'] = None

    # Initialize the Gemma model
    try:
        models['gemma_tokenizer'] = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
        models['gemma_model'] = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2-2b-it",
            device_map="auto",
            torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32
        )
        print("Gemma model initialized.")
    except Exception as e:
        print(f"Error initializing the Gemma model: {e}")
        models['gemma_model'] = None

    # Initialize the depth estimation model using MarigoldDepthPipeline exactly as per your sample
    try:
        if device == 'cuda':
            models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
                "prs-eth/marigold-depth-lcm-v1-0",
                variant="fp16",
                torch_dtype=torch.float16
            ).to('cuda')
        else:
            # For CPU or MPS devices, keep on 'cpu' to avoid unsupported operators
            models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
                "prs-eth/marigold-depth-lcm-v1-0",
                torch_dtype=torch.float32
            ).to('cpu')
        print("Depth estimation model initialized.")
    except Exception as e:
        error_message = f"Error initializing the depth estimation model: {e}"
        print(error_message)
        models['depth_pipe'] = None
        models['depth_init_error'] = error_message  # Store the error message

    # Initialize the upscaling model
    try:
        upscaler_model_path = 'weights/RealESRGAN_x4plus.pth'  # Ensure this path is correct
        if not os.path.exists(upscaler_model_path):
            print(f"Upscaling model weights not found at {upscaler_model_path}. Please download them.")
            models['upscaler'] = None
        else:
            # Define the model architecture
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                            num_block=23, num_grow_ch=32, scale=4)

            # Initialize RealESRGANer
            models['upscaler'] = RealESRGANer(
                scale=4,
                model_path=upscaler_model_path,
                model=model,
                pre_pad=0,
                half=(device == 'cuda'),
                device=device
            )
            print("Real-ESRGAN upscaler initialized.")
    except Exception as e:
        print(f"Error initializing the upscaling model: {e}")
        models['upscaler'] = None

    # Initialize YOLO model
    try:
        source_weights_path = "/Users/David/Downloads/WheelOfFortuneLab-DavidDriscoll/Eurybia1.3/mbari_315k_yolov8.pt"
        if not os.path.exists(source_weights_path):
            print(f"YOLO weights not found at {source_weights_path}. Please download them.")
            models['yolo_model'] = None
        else:
            models['yolo_model'] = YOLO(source_weights_path)
            print("YOLO model initialized.")
    except Exception as e:
        print(f"Error initializing YOLO model: {e}")
        models['yolo_model'] = None

    return models

models = initialize_models()

# Utility Functions
def search_class_description(class_name):
    wikipedia.set_lang("en")
    wikipedia.set_rate_limiting(True)
    description = ""

    try:
        page = wikipedia.page(class_name)
        if page:
            description = page.content[:5000]  # Get more content
    except Exception as e:
        print(f"Error fetching description for {class_name}: {e}")

    return description

def search_class_image(class_name):
    wikipedia.set_lang("en")
    wikipedia.set_rate_limiting(True)
    img_url = ""

    try:
        page = wikipedia.page(class_name)
        if page:
            for img in page.images:
                if img.lower().endswith(('.jpg', '.jpeg', '.png', '.gif')):
                    img_url = img
                    break
    except Exception as e:
        print(f"Error fetching image for {class_name}: {e}")

    return img_url

def process_image(image):
    if models['yolo_model'] is None:
        return None, "YOLO model is not initialized.", "YOLO model is not initialized.", [], None
    
    try:
        if image is None:
            return None, "No image uploaded.", "No image uploaded.", [], None

        # Convert Gradio Image to OpenCV format
        image_np = np.array(image)
        if image_np.dtype != np.uint8:
            image_np = image_np.astype(np.uint8)

        if len(image_np.shape) != 3 or image_np.shape[2] != 3:
            return None, "Invalid image format. Please upload a RGB image.", "Invalid image format. Please upload a RGB image.", [], None

        image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

        # Store the original image before drawing bounding boxes
        original_image_cv = image_cv.copy()
        original_image_pil = Image.fromarray(cv2.cvtColor(original_image_cv, cv2.COLOR_BGR2RGB))

        # Perform YOLO prediction
        results = models['yolo_model'].predict(
            source=image_cv, conf=0.075)[0]  # Lowered the threshold

        bounding_boxes = []
        image_processed = image_cv.copy()

        if results.boxes is not None:
            for box in results.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                class_name = models['yolo_model'].names[int(box.cls)]
                confidence = box.conf.item() * 100  # Convert to percentage

                bounding_boxes.append({
                    "coords": (x1, y1, x2, y2),
                    "class_name": class_name,
                    "confidence": confidence
                })

                cv2.rectangle(image_processed, (x1, y1), (x2, y2), (0, 0, 255), 2)
                cv2.putText(image_processed, f'{class_name} {confidence:.2f}%',
                            (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX,
                            0.9, (0, 0, 255), 2)

        # Convert back to PIL Image
        processed_image = Image.fromarray(cv2.cvtColor(image_processed, cv2.COLOR_BGR2RGB))

        # Prepare detection info
        if bounding_boxes:
            detection_info = "\n".join(
                [f'{box["class_name"]}: {box["confidence"]:.2f}%' for box in bounding_boxes]
            )
        else:
            detection_info = "No detections found."

        # Prepare detection details as Markdown
        if bounding_boxes:
            details = ""
            for idx, box in enumerate(bounding_boxes):
                class_name = box['class_name']
                confidence = box['confidence']
                description = search_class_description(class_name)
                img_url = search_class_image(class_name)
                img_md = ""
                if img_url:
                    try:
                        headers = {
                            'User-Agent': 'MyApp/1.0 (https://example.com/contact; myemail@example.com)'
                        }
                        response = requests.get(img_url, headers=headers, timeout=10)
                        img_data = response.content
                        img = Image.open(BytesIO(img_data)).convert("RGB")
                        img.thumbnail((400, 400))  # Resize for faster loading
                        buffered = BytesIO()
                        img.save(buffered, format="PNG")
                        img_str = base64.b64encode(buffered.getvalue()).decode()
                        img_md = f"![{class_name}](data:image/png;base64,{img_str})\n\n"
                    except Exception as e:
                        print(f"Error fetching image for {class_name}: {e}")
                details += f"### {idx+1}. {class_name} ({confidence:.2f}%)\n\n"
                if description:
                    details += f"{description}\n\n"
                if img_md:
                    details += f"{img_md}\n\n"
            detection_details_md = details
        else:
            detection_details_md = "No detections to show."

        return processed_image, detection_info, detection_details_md, bounding_boxes, original_image_pil
    except Exception as e:
        print(f"Error processing image: {e}")
        return None, f"Error processing image: {e}", f"Error processing image: {e}", [], None

def ask_eurybia(question, state):
    if not question.strip():
        return "Please enter a valid question.", state

    if not state['bounding_boxes']:
        return "No detected objects to ask about.", state

    # Combine descriptions of all detected objects as context
    context = ""
    for box in state['bounding_boxes']:
        description = search_class_description(box['class_name'])
        if description:
            context += description + "\n"

    if not context.strip():
        return "No sufficient context available to answer the question.", state

    try:
        if models['qa_pipeline'] is None:
            return "QA pipeline is not initialized.", state

        answer = models['qa_pipeline'](question=question, context=context)
        answer_text = answer['answer'].strip()
        if not answer_text:
            return "I couldn't find an answer to that question based on the detected objects.", state
        return answer_text, state
    except Exception as e:
        print(f"Error during question answering: {e}")
        return f"Error during question answering: {e}", state

def enhance_image(cropped_image_pil):
    if models['upscaler'] is None:
        return None, "Upscaling model is not initialized."
    
    try:
        input_image = cropped_image_pil.convert("RGB")
        img = np.array(input_image)

        # Run the model to enhance the image
        output, _ = models['upscaler'].enhance(img, outscale=4)

        enhanced_image = Image.fromarray(output)

        return enhanced_image, "Image enhanced successfully."
    except Exception as e:
        print(f"Error during image enhancement: {e}")
        return None, f"Error during image enhancement: {e}"

def run_depth_prediction(original_image):
    if models['depth_pipe'] is None:
        error_msg = models.get('depth_init_error', "Depth estimation model is not initialized.")
        return None, error_msg

    try:
        if original_image is None:
            return None, "No image uploaded for depth prediction."

        # Prepare the image
        input_image = original_image.convert("RGB")

        # Run the depth pipeline
        result = models['depth_pipe'](input_image)

        # Access the depth prediction
        depth_prediction = result.prediction  # Adjust based on sample code

        # Visualize the depth map
        vis_depth = models['depth_pipe'].image_processor.visualize_depth(depth_prediction)
        
        # Ensure vis_depth is a list and extract the first image
        if isinstance(vis_depth, list) and len(vis_depth) > 0:
            vis_depth_image = vis_depth[0]
        else:
            vis_depth_image = vis_depth  # Fallback if not a list

        return vis_depth_image, "Depth prediction completed."
    except Exception as e:
        print(f"Error during depth prediction: {e}")
        return None, f"Error during depth prediction: {e}"

# Gradio Interface Components
with gr.Blocks() as demo:
    gr.Markdown("# Eurybia Mini - Object Detection and Analysis Tool")

    with gr.Tab("Upload & Process"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image")
                process_button = gr.Button("Process Image")
                clear_button = gr.Button("Clear")
            with gr.Column():
                processed_image = gr.Image(type="pil", label="Processed Image")
                detection_info = gr.Textbox(label="Detection Information", lines=10)

    with gr.Tab("Detection Details"):
        with gr.Accordion("Click to see detection details", open=False):
            detection_details_md = gr.Markdown("No detections to show.")

    with gr.Tab("Ask Eurybia"):
        with gr.Row():
            with gr.Column():
                question_input = gr.Textbox(label="Ask a question about the detected objects")
                ask_button = gr.Button("Ask Eurybia")
            with gr.Column():
                answer_output = gr.Markdown(label="Eurybia's Answer")

    with gr.Tab("Depth Estimation"):
        with gr.Row():
            with gr.Column():
                depth_button = gr.Button("Run Depth Prediction")
            with gr.Column():
                depth_output = gr.Image(type="pil", label="Depth Map")
                depth_status = gr.Textbox(label="Status", lines=2)
        
        # Display error message if depth estimation model failed to initialize
        if models.get('depth_init_error'):
            gr.Markdown(f"**Depth Estimation Initialization Error:** {models['depth_init_error']}")

    with gr.Tab("Enhance Detected Objects"):
        if models['yolo_model'] is not None and models['upscaler'] is not None:
            with gr.Row():
                detected_objects = gr.Dropdown(choices=[], label="Select Detected Object", interactive=True)
                enhance_btn = gr.Button("Enhance Image")
            with gr.Column():
                enhanced_image = gr.Image(type="pil", label="Enhanced Image")
                enhance_status = gr.Textbox(label="Status", lines=2)
        else:
            gr.Markdown("**Warning:** YOLO model or Upscaling model is not initialized. Image enhancement functionality will be unavailable.")

    with gr.Tab("Credits"):
        gr.Markdown("""
# Credits and Licensing Information

This project utilizes various open-source libraries, tools, pretrained models, and datasets. Below is the list of components used and their respective credits/licenses:

## Libraries
- **Python** - Python Software Foundation License (PSF License)
- **Gradio** - Licensed under the Apache License 2.0
- **Torch (PyTorch)** - Licensed under the BSD 3-Clause License
- **OpenCV (cv2)** - Licensed under the Apache License 2.0
- **NumPy** - Licensed under the BSD License
- **Pillow (PIL)** - Licensed under the HPND License
- **Requests** - Licensed under the Apache License 2.0
- **Wikipedia API** - Licensed under the MIT License
- **Transformers** - Licensed under the Apache License 2.0
- **Diffusers** - Licensed under the Apache License 2.0
- **Real-ESRGAN** - Licensed under the MIT License
- **BasicSR** - Licensed under the Apache License 2.0
- **Ultralytics YOLO** - Licensed under the GPL-3.0 License

## Pretrained Models
- **deepset/roberta-base-squad2 (RoBERTa)** - Model provided by Hugging Face under the Apache License 2.0.
- **google/gemma-2-2b-it** - Model provided by Hugging Face under the Apache License 2.0.
- **prs-eth/marigold-depth-lcm-v1-0** - Licensed under the Apache License 2.0.
- **Real-ESRGAN model weights (RealESRGAN_x4plus.pth)** - Distributed under the MIT License.
- **FathomNet MBARI 315K YOLOv8 Model**:
  - **Dataset**: Sourced from [FathomNet](https://fathomnet.org).
  - **Model**: Derived from MBARI’s curated dataset of 315,000 marine annotations.
  - **License**: Dataset and models adhere to MBARI’s Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).

## Datasets
- **FathomNet MBARI Dataset**:
  - A large-scale dataset for marine biodiversity image annotations.
  - All content adheres to the [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/).

## Acknowledgments
- **Ultralytics YOLO**: For the YOLOv8 architecture used for object detection.
- **FathomNet and MBARI**: For providing the marine dataset and annotations that support object detection in underwater imagery.
- **Gradio**: For providing an intuitive interface for machine learning applications.
- **Hugging Face**: For pretrained models and pipelines (e.g., Transformers, Diffusers).
- **Real-ESRGAN**: For image enhancement and upscaling models.
- **Wikipedia API**: For fetching object descriptions and images.
""")

    # Hidden state to store bounding boxes, original and processed images
    state = gr.State({"bounding_boxes": [], "last_image": None, "original_image": None})

    # Event Handlers
    def on_process_image(image, state):
        processed_img, info, details, bounding_boxes, original_image_pil = process_image(image)
        if processed_img is not None:
            # Update the state with new bounding boxes and images
            state['bounding_boxes'] = bounding_boxes
            state['last_image'] = processed_img
            state['original_image'] = original_image_pil
            # Update the dropdown choices for detected objects
            choices = [f"{idx+1}. {box['class_name']} ({box['confidence']:.2f}%)" for idx, box in enumerate(bounding_boxes)]
        else:
            choices = []
        return processed_img, info, details, gr.update(choices=choices), state

    process_button.click(
        on_process_image,
        inputs=[image_input, state],
        outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
    )

    def on_clear(state):
        state = {"bounding_boxes": [], "last_image": None, "original_image": None}
        return None, "No detections found.", "No detections to show.", gr.update(choices=[]), state

    clear_button.click(
        on_clear,
        inputs=state,
        outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
    )

    def on_ask_eurybia(question, state):
        answer, state = ask_eurybia(question, state)
        return answer, state

    ask_button.click(
        on_ask_eurybia,
        inputs=[question_input, state],
        outputs=[answer_output, state]
    )

    def on_depth_prediction(state):
        original_image = state.get('original_image')
        depth_img, status = run_depth_prediction(original_image)
        return depth_img, status

    depth_button.click(
        on_depth_prediction,
        inputs=state,
        outputs=[depth_output, depth_status]
    )

    def on_enhance_image(selected_object, state):
        if not selected_object:
            return None, "No object selected.", state

        try:
            idx = int(selected_object.split('.')[0]) - 1
            box = state['bounding_boxes'][idx]
            class_name = box['class_name']
            x1, y1, x2, y2 = box['coords']

            if not state.get('last_image'):
                return None, "Processed image is not available.", state

            # Ensure processed_image is stored in state
            processed_img_pil = state['last_image']
            if not isinstance(processed_img_pil, Image.Image):
                return None, "Processed image is in an unsupported format.", state

            # Convert processed_image to OpenCV format with checks
            processed_img_cv = np.array(processed_img_pil)
            if processed_img_cv.dtype != np.uint8:
                processed_img_cv = processed_img_cv.astype(np.uint8)

            if len(processed_img_cv.shape) != 3 or processed_img_cv.shape[2] != 3:
                return None, "Invalid processed image format.", state

            processed_img_cv = cv2.cvtColor(processed_img_cv, cv2.COLOR_RGB2BGR)

            # Crop the detected object from the processed image
            cropped_img_cv = processed_img_cv[y1:y2, x1:x2]
            if cropped_img_cv.size == 0:
                return None, "Cropped image is empty.", state

            cropped_img_pil = Image.fromarray(cv2.cvtColor(cropped_img_cv, cv2.COLOR_BGR2RGB))

            # Enhance the cropped image
            enhanced_img, status = enhance_image(cropped_img_pil)
            return enhanced_img, status, state
        except Exception as e:
            return None, f"Error: {e}", state

    if models['yolo_model'] is not None and models['upscaler'] is not None:
        enhance_btn.click(
            on_enhance_image,
            inputs=[detected_objects, state],
            outputs=[enhanced_image, enhance_status, state]
        )

    # Optional: Add a note if the depth model isn't initialized
    if models['depth_pipe'] is None and not models.get('depth_init_error'):
        gr.Markdown("**Warning:** Depth estimation model is not initialized. Depth prediction functionality will be unavailable.")

    # Optional: Add a note if the upscaler isn't initialized
    if models['upscaler'] is None:
        gr.Markdown("**Warning:** Upscaling model is not initialized. Image enhancement functionality will be unavailable.")

# Launch the Gradio app
demo.launch()