import gradio as gr
import torch
import cv2
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw
from transformers import AutoProcessor
from modeling_florence2 import Florence2ForConditionalGeneration
import io
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Polygon  
import numpy as np
import random
import json


# Read the model configuration that sits next to this script. The values
# extracted below are interpolated into the Markdown shown in the UI.
with open("config.json", "r") as f:
    config = json.load(f)

# Text (encoder/decoder) hyper-parameters.
_text_cfg = config["text_config"]
d_model = _text_cfg["d_model"]
num_layers = _text_cfg["encoder_layers"]
attention_heads = _text_cfg["encoder_attention_heads"]
vocab_size = _text_cfg["vocab_size"]
max_length = _text_cfg["max_length"]
beam_size = _text_cfg["num_beams"]
dropout = _text_cfg["dropout"]
activation_function = _text_cfg["activation_function"]
no_repeat_ngram_size = _text_cfg["no_repeat_ngram_size"]

# Vision hyper-parameters; only the first patch_size entry is displayed.
_vision_cfg = config["vision_config"]
patch_size = _vision_cfg["patch_size"][0]
temporal_embeddings = _vision_cfg["visual_temporal_embedding"]["max_temporal_embeddings"]

# Heading rendered at the top of the page.
title = """# 🙋🏻‍♂️Welcome to Tonic's PLeIAs/📸📈✍🏻Florence-PDF"""

# Introductory Markdown. This must be an f-string: it embeds
# {no_repeat_ngram_size}, which would otherwise be shown to users as a
# literal placeholder instead of the configured value.
description = f"""
This application showcases the **PLeIAs/📸📈✍🏻Florence-PDF** model, a powerful AI system designed for both **text and image generation tasks**. The model is capable of handling complex tasks such as object detection, image captioning, OCR (Optical Character Recognition), and detailed region-based image analysis.

### Model Usage and Flexibility

- **No Repeat N-Grams**: To reduce repetition in text generation, the model is configured with a **no_repeat_ngram_size** of **{no_repeat_ngram_size}**, ensuring more diverse and meaningful outputs.
- **Sampling Strategies**: 🙏🏻PLeIAs/📸📈✍🏻Florence-PDF offers flexible sampling strategies, including **top-k** and **top-p (nucleus) sampling**, allowing for both creative and constrained generation based on user needs.

📸📈✍🏻Florence-PDF is a robust model capable of handling various **text and image** tasks with high precision and flexibility, making it a valuable tool for both academic research and practical applications.

### **How to Use**:
1. **Upload an Image**: Select an image for processing.
2. **Choose a Task**: Pick a task from the dropdown menu, such as "Caption", "Object Detection", "OCR", etc.
3. **Process**: Click the "Process" button to let PLeIAs/📸📈✍🏻Florence-PDF analyze the image and generate the output.
4. **View Results**: Depending on the task, you’ll either see a processed image (e.g., with bounding boxes or labels) or a text-based result (e.g., a generated caption or extracted text).

You can reset the interface anytime by clicking the **Reset** button.

### **Available Tasks**:
- **✍🏻Caption**: Generate a concise description of the image.
- **📸Object Detection**: Identify and label objects within the image.
- **📸✍🏻OCR**: Extract text from the image.
- **📸Region Proposal**: Detect key regions in the image for detailed captioning.
"""

# Model overview Markdown; the placeholders are filled from the values read
# out of config.json above.
model_presentation = f"""
The **🙏🏻PLeIAs/📸📈✍🏻Florence-PDF** model is a state-of-the-art model for conditional generation tasks, designed to be highly effective for both **text** and **vision** tasks. It is built as an **encoder-decoder** architecture, which allows for enhanced flexibility and performance in generating outputs based on diverse inputs.

### Key Features

- **Model Architecture**: 🙏🏻PLeIAs/📸📈✍🏻Florence-PDF uses an encoder-decoder structure, which makes it effective in tasks like **text generation**, **summarization**, and **translation**. It has **{num_layers} layers** for both the encoder and decoder, with a model dimension (`d_model`) of **{d_model}**.
- **Conditional Generation**: The model can generate text conditionally, with a maximum length of **{max_length} tokens** for each generated sequence, making it ideal for tasks that require concise output.
- **Beam Search**: 🙏🏻PLeIAs/📸📈✍🏻Florence-PDF supports **beam search** with up to **{beam_size} beams**, enabling more diverse and accurate text generation by exploring multiple potential outputs before selecting the best one.
- **Tokenization**: It includes a tokenizer with a vocabulary size of **{vocab_size} tokens**. Special tokens such as the **bos_token_id (0)** and **eos_token_id (2)** help control the generation process by marking the beginning and end of a sequence.
- **Attention Mechanism**: Both the encoder and decoder utilize **{attention_heads} attention heads** per layer, ensuring that the model can focus on relevant parts of the input when generating text.
- **Dropout and Activation**: 🙏🏻PLeIAs/📸📈✍🏻Florence-PDF employs a **{activation_function} activation function** and a **dropout rate of {dropout}**, which enhances model performance by preventing overfitting and improving generalization.
- **Training Configuration**: The model uses **float32** precision for training, and it supports fine-tuning for specific tasks by setting `finetuning_task` appropriately.

### Vision Integration

In addition to text tasks, 🙏🏻PLeIAs/📸📈✍🏻Florence-PDF also incorporates **vision capabilities**:
- **Patch-based Image Processing**: The vision component operates on image patches with a patch size of **{patch_size}x{patch_size}**.
- **Temporal Embedding**: Visual tasks benefit from temporal embeddings with up to **{temporal_embeddings} steps**, making Florence-2 well-suited for video analysis.

"""

# Community / credits banner shown in the "Join Us" accordion.
joinus = """🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""

# Help text for the advanced-settings sliders. The defaults quoted here must
# match the `value=` of the corresponding sliders in the UI definition.
how_to_use = """The advanced settings allow you to fine-tune the text generation process. Here's what each setting does and how to use it:

### Top-k (Default: 50)
Top-k sampling limits the next token selection to the k most likely tokens.

- **Lower values** (e.g., 10) make the output more focused and deterministic.
- **Higher values** (e.g., 100) allow for more diverse outputs.

**Example:** For a creative writing task, try setting top-k to 80 for more varied language.

### Top-p (Default: 1.0)
Top-p (or nucleus) sampling selects from the smallest set of tokens whose cumulative probability exceeds p.

- **Lower values** (e.g., 0.5) make the output more focused and coherent.
- **Higher values** (e.g., 0.9) allow for more diverse and potentially creative outputs.

**Example:** For a factual caption, set top-p to 0.7 to balance accuracy and creativity.

### Repetition Penalty (Default: 1.0)
This penalizes repetition in the generated text.

- **Values closer to 1.0** have minimal effect on repetition.
- **Higher values** (e.g., 1.5) more strongly discourage repetition.

**Example:** If you notice repeated phrases, try increasing to 1.2 for more varied text.

### Number of Beams (Default: 3)
Beam search explores multiple possible sequences in parallel.

- **Higher values** (e.g., 5) can lead to better quality but slower generation.
- **Lower values** (e.g., 1) are faster but may produce lower quality results.

**Example:** For complex tasks like dense captioning, try increasing to 5 beams.

### Max Tokens (Default: 1000)
This sets the maximum length of the generated text.

- **Lower values** (e.g., 100) for concise outputs.
- **Higher values** (e.g., 1000) for more detailed descriptions.

**Example:** For a detailed image description, set max tokens to 800 for a comprehensive output.

Remember, these settings interact with each other, so experimenting with different combinations can lead to interesting results!
"""
# Use CUDA with float16 when a GPU is available; otherwise fall back to the
# CPU with float32.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Load the model weights onto the selected device, and the matching
# processor used for prompt/image encoding and output decoding.
model = Florence2ForConditionalGeneration.from_pretrained(
    "PleIAs/Florence-PDF",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
).to(device)
processor = AutoProcessor.from_pretrained(
    "PleIAs/Florence-PDF",
    trust_remote_code=True,
)

# Dropdown label -> task prompt token sent to the model. Insertion order is
# the order in which the tasks are listed in the dropdown.
TASK_PROMPTS = {
    "✍🏻Caption": "<CAPTION>",
    "✍🏻✍🏻Caption": "<DETAILED_CAPTION>",
    "✍🏻✍🏻✍🏻Caption": "<MORE_DETAILED_CAPTION>",
    "📸Object Detection": "<OD>",
    "📸Dense Region Caption": "<DENSE_REGION_CAPTION>",
    "📸✍🏻OCR": "<OCR>",
    "📸✍🏻OCR with Region": "<OCR_WITH_REGION>",
    "📸Region Proposal": "<REGION_PROPOSAL>",
}

# Tasks whose result is rendered as an annotated image.
IMAGE_TASKS = [
    "📸Object Detection",
    "📸Dense Region Caption",
    "📸Region Proposal",
    "📸✍🏻OCR with Region",
]

# Tasks whose result is rendered as text. "OCR with Region" appears in both
# lists.
TEXT_TASKS = [
    "✍🏻Caption",
    "✍🏻✍🏻Caption",
    "✍🏻✍🏻✍🏻Caption",
    "📸✍🏻OCR",
    "📸✍🏻OCR with Region",
]

# Palette that draw_ocr_bboxes picks outline colours from.
colormap = [
    'blue', 'orange', 'green', 'purple', 'brown',
    'pink', 'gray', 'olive', 'cyan', 'red',
    'lime', 'indigo', 'violet', 'aqua', 'magenta',
    'coral', 'gold', 'tan', 'skyblue',
]

def fig_to_pil(fig):
    """Render a Matplotlib figure to a PIL image via an in-memory PNG.

    The figure is closed once it has been rendered. Figures created through
    ``plt.subplots()`` stay registered with pyplot until they are closed, so
    without this every processed request would leave a figure behind.

    Args:
        fig: The Matplotlib figure to render.

    Returns:
        A PIL image opened on the PNG bytes of the rendered figure.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png')
    finally:
        # Release the figure even if rendering fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

def plot_bbox(image, data, use_quad_boxes=False):
    """Draw labelled boxes from a prediction dict on top of an image.

    All drawing goes through the axes created here (``ax``) rather than
    pyplot's implicit "current axes", so labels cannot end up on a figure
    belonging to another concurrent call.

    Args:
        image: Image to display as the background.
        data: Mapping with a ``'labels'`` list and either ``'quad_boxes'``
            (when ``use_quad_boxes`` is true) or ``'bboxes'``. Missing keys
            are treated as empty lists.
        use_quad_boxes: If true, each box is a sequence of vertex coordinates
            that is reshaped into (x, y) pairs and drawn as a polygon.
            Otherwise each box is ``(x1, y1, x2, y2)`` and drawn as a
            rectangle.

    Returns:
        The Matplotlib figure containing the annotated image, axes hidden.
    """
    fig, ax = plt.subplots()
    ax.imshow(image)

    labels = data.get('labels', [])
    if use_quad_boxes:
        for quad_box, label in zip(data.get('quad_boxes', []), labels):
            quad_box = np.array(quad_box).reshape(-1, 2)
            poly = Polygon(quad_box, linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(poly)
            # Anchor the label at the first vertex of the quadrilateral.
            ax.text(quad_box[0][0], quad_box[0][1], label, color='white', fontsize=8,
                    bbox=dict(facecolor='red', alpha=0.5))
    else:
        for bbox, label in zip(data.get('bboxes', []), labels):
            x1, y1, x2, y2 = bbox
            rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            ax.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))

    ax.axis('off')

    return fig

def draw_ocr_bboxes(image, prediction):
    """Draw OCR polygons and their labels directly onto ``image``.

    Each box gets a colour picked at random from ``colormap``. The image is
    modified in place and also returned.

    Args:
        image: PIL image to draw on.
        prediction: Mapping with ``'quad_boxes'`` and ``'labels'`` entries.

    Returns:
        The same ``image`` object, now annotated.
    """
    scale = 1
    canvas = ImageDraw.Draw(image)
    quads = prediction['quad_boxes']
    texts = prediction['labels']

    for quad, text in zip(quads, texts):
        outline_color = random.choice(colormap)
        scaled_quad = (np.array(quad) * scale).tolist()
        canvas.polygon(scaled_quad, width=3, outline=outline_color)
        # Place the label slightly offset from the first two coordinates.
        text_origin = (scaled_quad[0] + 8, scaled_quad[1] + 2)
        canvas.text(text_origin, f"{text}", align="right", fill=outline_color)

    return image

def draw_bounding_boxes(image, quad_boxes, labels, color=(0, 255, 0), thickness=2):
    """Draw closed polygons with text labels on an image using OpenCV.

    Args:
        image: Image passed to the cv2 drawing functions.
        quad_boxes: Iterable of boxes; each is converted to int32 and
            reshaped to ``(-1, 1, 2)`` points.
        labels: Sequence indexed in step with ``quad_boxes``.
        color: Colour used for both the outline and the label text.
        thickness: Line thickness used for both the outline and the text.

    Returns:
        The image as returned by the last ``cv2.polylines`` call, or the
        input unchanged when ``quad_boxes`` is empty.
    """
    for idx, quad in enumerate(quad_boxes):
        vertices = np.array(quad, dtype=np.int32).reshape((-1, 1, 2))
        image = cv2.polylines(image, [vertices], isClosed=True, color=color, thickness=thickness)

        # Label goes 10 px above the box's first two coordinates.
        anchor_x = int(quad[0])
        anchor_y = int(quad[1]) - 10
        cv2.putText(image, labels[idx], (anchor_x, anchor_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness)

    return image

def process_image(image, task):
    """Run a task on an image with fixed, non-sampling decoding settings.

    Uses beam search with 3 beams, no sampling, and at most 1024 new tokens.

    Args:
        image: Input image; its ``width`` and ``height`` are passed to the
            post-processing step.
        task: A key of ``TASK_PROMPTS``.

    Returns:
        The result of the processor's ``post_process_generation``.
    """
    task_token = TASK_PROMPTS[task]

    model_inputs = processor(text=task_token, images=image, return_tensors="pt")
    model_inputs = model_inputs.to(device, torch_dtype)

    output_ids = model.generate(
        **model_inputs,
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False
    )

    # Special tokens are kept in the decoded text handed to post-processing.
    decoded = processor.batch_decode(output_ids, skip_special_tokens=False)
    return processor.post_process_generation(
        decoded[0],
        task=task_token,
        image_size=(image.width, image.height),
    )


def main_process(image, task, top_k, top_p, repetition_penalty, num_beams, max_tokens):
    """Run a task on an image using user-supplied sampling settings.

    The numeric settings come from UI sliders, which may deliver whole
    numbers as floats. Count-like settings are therefore coerced to ``int``
    and the others to ``float`` before being handed to ``model.generate``.

    Args:
        image: Input image; its ``width`` and ``height`` are passed to the
            post-processing step.
        task: A key of ``TASK_PROMPTS``.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus sampling threshold.
        repetition_penalty: Repetition penalty factor.
        num_beams: Number of beams.
        max_tokens: Maximum number of newly generated tokens.

    Returns:
        The result of the processor's ``post_process_generation``.

    Raises:
        KeyError: If ``task`` is not a key of ``TASK_PROMPTS``.
    """
    prompt = TASK_PROMPTS[task]
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_tokens),
        num_beams=int(num_beams),
        do_sample=True,
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty)
    )

    # Special tokens are kept in the decoded text handed to post-processing.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
    return parsed_answer

def process_and_update(image, task, top_k, top_p, repetition_penalty, num_beams, max_tokens):
    """Run the selected task and build the updates for both output widgets.

    Args:
        image: Uploaded image, or ``None`` when nothing was uploaded.
        task: A key of ``TASK_PROMPTS``.
        top_k, top_p, repetition_penalty, num_beams, max_tokens: Generation
            settings forwarded unchanged to ``main_process``.

    Returns:
        A 4-tuple ``(image_value, image_update, text_value, text_update)``:
        the new value and a visibility update for the image output, followed
        by the same pair for the text output.
    """
    if image is None:
        return None, gr.update(visible=False), "Please upload an image first.", gr.update(visible=True)

    result = main_process(image, task, top_k, top_p, repetition_penalty, num_beams, max_tokens)
    prompt = TASK_PROMPTS[task]

    if task in IMAGE_TASKS:
        payload = result.get(prompt, {})

        if task == "📸✍🏻OCR with Region":
            output_image = fig_to_pil(plot_bbox(image, payload, use_quad_boxes=True))
            # The region OCR result carries its text in 'labels' (one entry
            # per quad box). 'recognized_text' is still honoured if present.
            text_output = payload.get('recognized_text')
            if text_output is None:
                region_texts = [str(label) for label in payload.get('labels', [])]
                text_output = "\n".join(region_texts) if region_texts else 'No text found'
            # This task is listed in both IMAGE_TASKS and TEXT_TASKS, so the
            # text box is shown alongside the annotated image.
            return output_image, gr.update(visible=True), text_output, gr.update(visible=True)

        output_image = fig_to_pil(plot_bbox(image, payload))
        return output_image, gr.update(visible=True), None, gr.update(visible=False)

    # Text-only tasks: show the generated string itself rather than the repr
    # of the whole result dict; fall back to the repr for unexpected shapes.
    answer = result.get(prompt) if isinstance(result, dict) else None
    text_output = answer if isinstance(answer, str) else str(result)
    return None, gr.update(visible=False), text_output, gr.update(visible=True)

def reset_outputs():
    """Clear both outputs, hiding the image and showing the text box.

    Returns:
        A 4-tuple ``(image_value, image_update, text_value, text_update)``
        with both values set to ``None``.
    """
    hide_image = gr.update(visible=False)
    show_text = gr.update(visible=True)
    return None, hide_image, None, show_text

# Assemble the Gradio interface. Components are created inside nested layout
# context managers, so the statement order below defines the page structure.
with gr.Blocks(title="Tonic's 🙏🏻PLeIAs/📸📈✍🏻Florence-PDF") as iface:
    with gr.Column():
        # Page heading.
        with gr.Row():
            gr.Markdown(title)
        # Two columns of introductory Markdown: model overview and usage.
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown(model_presentation)
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown(description)
        # Community links, expanded by default.
        with gr.Row():
            with gr.Accordion("🫱🏻‍🫲🏻Join Us", open=True):
                gr.Markdown(joinus)
        with gr.Row():
            # First column: image upload, task selection, buttons, settings.
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Input Image")
                task_dropdown = gr.Dropdown(list(TASK_PROMPTS.keys()), label="Task", value="✍🏻Caption")
                with gr.Row():
                    submit_button = gr.Button("📸📈✍🏻Process")
                    reset_button = gr.Button("♻️Reset")
                # Generation settings, collapsed by default; the nested
                # accordion holds the help text for the sliders.
                with gr.Accordion("🧪Advanced Settings", open=False):
                    with gr.Accordion("🏗️How To Use", open=True):
                        gr.Markdown(how_to_use)
                    top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.01, label="Top-p")
                    repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.0, step=0.01, label="Repetition Penalty")
                    num_beams = gr.Slider(minimum=1, maximum=6, value=3, step=1, label="Number of Beams")
                    max_tokens = gr.Slider(minimum=1, maximum=1024, value=1000, step=1, label="Max Tokens")
            # Second column: both outputs start hidden.
            with gr.Column(scale=1):
                output_image = gr.Image(label="🙏🏻PLeIAs/📸📈✍🏻Florence-PDF", visible=False)
                output_text = gr.Textbox(label="🙏🏻PLeIAs/📸📈✍🏻Florence-PDF", visible=False)

    # Each output component is listed twice because the handlers return four
    # items: (image value, image visibility update, text value, text
    # visibility update).
    submit_button.click(
        fn=process_and_update,
        inputs=[image_input, task_dropdown, top_k, top_p, repetition_penalty, num_beams, max_tokens],
        outputs=[output_image, output_image, output_text, output_text]
    )

    # Same four-item output convention as the submit handler.
    reset_button.click(
        fn=reset_outputs,
        inputs=[],
        outputs=[output_image, output_image, output_text, output_text]
    )

    # When the task changes, toggle visibility up front: the image output for
    # tasks in IMAGE_TASKS, the text output for tasks in TEXT_TASKS.
    task_dropdown.change(
        fn=lambda task: (gr.update(visible=task in IMAGE_TASKS), gr.update(visible=task in TEXT_TASKS)),
        inputs=[task_dropdown],
        outputs=[output_image, output_text]
    )

# Launch the interface; this runs at module level, i.e. whenever the file is
# executed or imported.
iface.launch()