"""
Processor class for ChatRex.
"""

from typing import Optional

import PIL
from PIL import Image

try:
    from typing import Unpack
except ImportError:
    from typing_extensions import Unpack

import re
from typing import List, Optional, Union

import numpy as np
import torch
import torchvision.transforms.functional as F
from transformers import AutoTokenizer
from transformers.image_utils import ImageInput
from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
                                           TextKwargs)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

logger = logging.get_logger(__name__)



# Placeholder strings and negative placeholder ids shared by the prompt-building
# and tokenization code below.

# NOTE(review): not referenced in this file; -100 equals the default
# ignore_index of torch.nn.CrossEntropyLoss, presumably used for label masking.
IGNORE_INDEX = -100
# NOTE(review): not referenced in this file; presumably a padding token id.
DEFAULT_PAD_TOKEN_INDEX = 0
# Id emitted by tokenizer_image_object_token in place of each DEFAULT_IMAGE_TOKEN
# chunk. Negative, presumably so it cannot collide with a real vocabulary id.
IMAGE_TOKEN_INDEX = -200
# Image placeholder; ChatRexProcessor.process strips any occurrences from the
# question and prepends exactly one.
DEFAULT_IMAGE_TOKEN = "<image>"

# For Objects
# Template for per-box tokens; "<i>" is replaced by the box index (e.g. "<obj0>").
DEFAULT_OBJECT_TOKEN = "<obj<i>>"
# Placeholder placed after each object token in the prompt; tokenized to
# DEFAULT_OBJECT_INDEX by tokenizer_image_object_token.
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"
# Id emitted by tokenizer_image_object_token in place of each
# DEFAULT_OBJECT_FEATURE_TOKEN chunk.
DEFAULT_OBJECT_INDEX = -300

# For Grounding
# NOTE(review): the grounding markers below are not referenced in this file.
DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"

def xyxy_to_xywh(boxes):
    """
    Convert boxes from xyxy to xywh format.

    Parameters:
    boxes (array-like): An array of shape (N, 4) where N is the number of boxes.
                        Each box is represented as [x_min, y_min, x_max, y_max].
                        A single box of shape (4,) and an empty input are also
                        accepted.

    Returns:
    numpy.ndarray: An array of shape (N, 4) where each box is represented as [x_min, y_min, w, h].
    """
    boxes = np.asarray(boxes)
    if boxes.size == 0:
        # An empty list becomes shape (0,), which cannot be indexed as boxes[:, k].
        boxes = boxes.reshape(0, 4)
    # Promote a single box of shape (4,) to shape (1, 4).
    boxes = np.atleast_2d(boxes)
    x_min, y_min, x_max, y_max = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    return np.stack([x_min, y_min, x_max - x_min, y_max - y_min], axis=1)


def xywh_to_xyxy(boxes):
    """
    Convert boxes from xywh to xyxy format.

    Parameters:
    boxes (array-like): An array of shape (N, 4) where N is the number of boxes.
                        Each box is represented as [x, y, width, height].
                        A single box of shape (4,) and an empty input are also
                        accepted.

    Returns:
    numpy.ndarray: An array of shape (N, 4) where each box is represented as [x_min, y_min, x_max, y_max].
    """
    boxes = np.asarray(boxes)
    if boxes.size == 0:
        # An empty list becomes shape (0,), which cannot be indexed as boxes[:, k].
        boxes = boxes.reshape(0, 4)
    # Promote a single box of shape (4,) to shape (1, 4).
    boxes = np.atleast_2d(boxes)
    x, y, width, height = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    return np.stack([x, y, x + width, y + height], axis=1)

def expand2square(pil_img, background_color):
    """Pad ``pil_img`` onto a square canvas filled with ``background_color``.

    The canvas side equals the longer image dimension and the image is
    centred along the shorter one (offset ``(long - short) // 2``). A square
    input is returned unchanged, as the same object.
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    offset = (0, (w - h) // 2) if w > h else ((h - w) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas

def pad_boxes(gt_boxes, old_size):
    """Shift ``[x, y, w, h]`` boxes into the frame of the square-padded image.

    The offsets mirror ``expand2square``, which centres the image on a square
    canvas whose side is the longer dimension. The shorter axis therefore gains
    ``(long - short) // 2`` pixels before the image. Only the x and y columns
    move; widths and heights are unchanged.

    Args:
        gt_boxes: Boxes of shape (N, 4) as ``[x, y, w, h]``. A single box of
            shape (4,) and an empty input are also accepted.
        old_size: ``(width, height)`` of the image before padding.

    Returns:
        numpy.ndarray: float32 array of shape (N, 4) with the shifted boxes.
    """
    old_w, old_h = old_size
    gt_boxes = np.array(gt_boxes).astype(np.float32)
    if gt_boxes.size == 0:
        # An empty list becomes shape (0,), which cannot be indexed as boxes[:, k].
        return gt_boxes.reshape(0, 4)
    # Promote a single box of shape (4,) to shape (1, 4).
    gt_boxes = np.atleast_2d(gt_boxes)

    # Only the leading padding matters for shifting box origins.
    if old_w > old_h:
        pad_left, pad_top = 0, (old_w - old_h) // 2
    else:
        pad_left, pad_top = (old_h - old_w) // 2, 0

    gt_boxes[:, 0] += pad_left  # x
    gt_boxes[:, 1] += pad_top  # y
    return gt_boxes


def resize_boxes(gt_boxes, old_size, new_size):
    """Rescale ``[x, y, w, h]`` boxes from a square image to a new resolution.

    Mind the differing argument orders:

    * ``old_size`` is ``(width, height)``.
    * ``new_size`` is ``(height, width)``.

    Both axes are divided by the longer old side, so ``old_size`` is expected
    to describe the square-padded image.

    Returns:
        numpy.ndarray: float32 copy of the boxes with x/w scaled by
        ``new_w / max(old_size)`` and y/h scaled by ``new_h / max(old_size)``.
    """
    old_w, old_h = old_size
    new_h, new_w = new_size
    boxes = np.array(gt_boxes).astype(np.float32)
    longest = max(old_w, old_h)
    sx = new_w / longest
    sy = new_h / longest

    # Columns 0 and 2 are horizontal (x, w); columns 1 and 3 are vertical (y, h).
    for column, factor in ((0, sx), (1, sy), (2, sx), (3, sy)):
        boxes[:, column] *= factor

    return boxes

def split_special_strings(input_string: str, special_strings: Optional[list[str]] = None):
    """Split the input string into a list of strings, keeping the special strings.

    Args:
        input_string (str): The input string to split.
        special_strings (list[str], optional): Substrings to split on. Each
            occurrence is kept as its own element of the result. When several
            special strings could match at the same position, the longest one
            wins. If None or empty (or containing only empty strings), no
            splitting happens.

        Example:

            input_string = "<image>\n<obj0><objfeat><obj1><objfeat>\n I am happy today."
            special_strings = ["<image>", "<objfeat>"]
            output = ['<image>', '\n<obj0>', '<objfeat>', '<obj1>', '<objfeat>', '\n I am happy today.']

    Returns:
        list: A list of non-empty strings, with the special strings separated
        from the rest of the input string.
    """
    # Empty alternatives would match the empty string and split between every
    # character, so they are dropped. Longest first, so that a special string
    # that is a prefix of another cannot shadow it in the regex alternation.
    ordered = sorted((s for s in special_strings or () if s), key=len, reverse=True)
    if not ordered:
        return [input_string] if input_string else []

    pattern = "|".join(map(re.escape, ordered))

    # The capturing group makes re.split keep the matched special strings.
    split_list = re.split(f"({pattern})", input_string)

    # Remove empty strings from the list
    return [s for s in split_list if s]

def tokenizer_image_object_token(prompt, tokenizer):
    """Tokenize ``prompt``, mapping placeholder strings to placeholder ids.

    * ``DEFAULT_IMAGE_TOKEN`` chunks become ``IMAGE_TOKEN_INDEX``.
    * ``DEFAULT_OBJECT_FEATURE_TOKEN`` chunks become ``DEFAULT_OBJECT_INDEX``.
    * Every other chunk is encoded with
      ``tokenizer.encode(chunk, add_special_tokens=False)``.

    The tokenizer's BOS id is prepended when the tokenizer defines one.

    Args:
        prompt (str): Text possibly containing the placeholder strings.
        tokenizer: Object exposing ``bos_token_id`` and ``encode``.

    Returns:
        list[int]: The resulting token ids.
    """
    placeholder_ids = {
        DEFAULT_IMAGE_TOKEN: IMAGE_TOKEN_INDEX,
        DEFAULT_OBJECT_FEATURE_TOKEN: DEFAULT_OBJECT_INDEX,
    }
    bos_token_id = tokenizer.bos_token_id
    # Hugging Face tokenizers report None when no BOS token is configured; a None
    # entry would make the id list unusable (e.g. torch.tensor fails on it).
    input_encode = [] if bos_token_id is None else [bos_token_id]
    for chunk in split_special_strings(prompt, list(placeholder_ids)):
        if chunk in placeholder_ids:
            input_encode.append(placeholder_ids[chunk])
        else:
            input_encode.extend(tokenizer.encode(chunk, add_special_tokens=False))
    return input_encode

class ChatRexProcessor(ProcessorMixin):
    """Prepare an image, its boxes and a question for ChatRex inference.

    Wraps an image processor and a tokenizer (the ``attributes`` managed by
    ``ProcessorMixin``). The entry point is :meth:`process`, which returns the
    tensors for a single sample.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor = None, tokenizer : AutoTokenizer = None, **kwargs):
        # NOTE(review): **kwargs is accepted but not forwarded to ProcessorMixin.
        super().__init__(image_processor, tokenizer)
        self._special_tokens = None
        # Prompt template; only INSTRUCTION is used by process().
        self.template = dict(
            SYSTEM=('A chat between a curious user and an artificial '
                    'intelligence assistant. The assistant gives '
                    'helpful, detailed, and polite answers to the '
                    'user\'s questions. {system}\n '),
            INSTRUCTION=('USER: {input} ASSISTANT:'),
            SEP='\n')

    def process(
        self,
        image: Union[str, Image.Image],
        bbox: List[List[int]],
        question: str,
    ):
        """Prepare input data for inference.

        Args:
            image (Union[str, Image.Image]): The image to process: a path to an
                image file, or an already loaded PIL image. It is converted to
                RGB if needed.
            bbox (List[List[int]]): A list of bounding boxes for the image, in
                pixel coordinates of the original image. Each bounding box should
                be in order of [x_min, y_min, x_max, y_max]. May be empty.
            question (str): The question to ask about the image. Any ``<image>``
                token in it is removed, and exactly one is prepended.

        Returns:
            dict: A dict with four entries:

            - ``pixel_values_aux``: the first ``pixel_values`` entry produced by
              ``self.image_processor.preprocess`` for the square-padded image,
              with a batch dimension of 1.
            - ``pixel_values``: the same tensor bilinearly resized to 336x336,
              with a batch dimension of 1.
            - ``gt_boxes``: tensor of shape (1, N, 4) with the boxes in xyxy
              order, expressed in the pixel frame of ``pixel_values_aux``.
            - ``input_ids``: tensor of shape (1, L) with the prompt token ids.
              Image and object-feature placeholders are encoded as
              ``IMAGE_TOKEN_INDEX`` and ``DEFAULT_OBJECT_INDEX``.
        """
        data_dict = {}

        # step1 load image
        if isinstance(image, str):
            # The context manager closes the file once the RGB copy exists.
            with Image.open(image) as img:
                image = img.convert("RGB")
        elif image.mode != "RGB":
            # expand2square fills with a 3-component colour, so normalise the mode.
            image = image.convert("RGB")
        ori_w, ori_h = F.get_image_size(image)
        # Pad to a square with the processor's mean colour to keep the aspect ratio.
        image = expand2square(
            image,
            tuple(int(x * 255) for x in self.image_processor.image_mean),
        )
        pad_w, pad_h = F.get_image_size(image)
        image_aux = self.image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]
        resize_h, resize_w = image_aux.shape[-2:]
        data_dict["pixel_values_aux"] = image_aux.unsqueeze(0)
        # Fixed-size low-resolution view; interpolate returns a new tensor.
        data_dict["pixel_values"] = torch.nn.functional.interpolate(
            image_aux.unsqueeze(0),
            size=[336, 336],
            mode="bilinear",
            align_corners=False,
        )

        # step2 load boxes
        # The pipeline is:
        #   1. original-image xyxy -> xywh
        #   2. shift by the square padding
        #   3. rescale to the preprocessed resolution
        #   4. convert back to xyxy
        bbox = np.asarray(bbox)
        if bbox.size == 0:
            # An empty list has shape (0,), which the box helpers cannot index.
            bbox = np.zeros((0, 4), dtype=np.float32)
        bbox = xyxy_to_xywh(bbox)
        bbox = pad_boxes(bbox, (ori_w, ori_h))
        bbox = resize_boxes(bbox, (pad_w, pad_h), (resize_h, resize_w))
        data_dict["gt_boxes"] = torch.tensor(xywh_to_xyxy(bbox)).unsqueeze(0)

        # step3 prepare question
        # The object prompt has the form "<obj0><objfeat><obj1><objfeat>...".
        # There is one feature placeholder per box, and none without boxes.
        obj_tokens = "".join(
            DEFAULT_OBJECT_TOKEN.replace("<i>", str(i)) + DEFAULT_OBJECT_FEATURE_TOKEN
            for i in range(len(bbox))
        )
        question = question.replace(DEFAULT_IMAGE_TOKEN, "")
        question = DEFAULT_IMAGE_TOKEN + "\n" + obj_tokens + "\n" + question
        inputs = self.template["INSTRUCTION"].format(input=question, round=1)

        # step4 tokenize question
        input_ids = tokenizer_image_object_token(inputs, self.tokenizer)
        data_dict["input_ids"] = torch.tensor(input_ids).unsqueeze(0)

        return data_dict

ChatRexProcessor.register_for_auto_class()