import math
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align
from transformers import (CONFIG_MAPPING, AutoConfig, AutoModel,
                          AutoModelForCausalLM, GenerationConfig,
                          PretrainedConfig, PreTrainedModel, StoppingCriteria,
                          StoppingCriteriaList)
from transformers.activations import ACT2FN
from transformers.utils import logging, strtobool

from .convnext import ConvNextVisionEncoder

logger = logging.get_logger(__name__)

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN_INDEX = 0
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"

# For Objects
DEFAULT_OBJECT_TOKEN = "<obj<i>>"
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"
DEFAULT_OBJECT_INDEX = -300

# For Grounding
DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"

def is_fsdp_enabled():
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
    )




def get_token_slices(
    input_ids: torch.Tensor,
) -> Tuple[List[Dict[str, Any]], bool]:
    """
    Get slices of tokens based on special markers in the input tensor.

    Args:
        input_ids (torch.Tensor): A 1-D tensor of token IDs where IMAGE_TOKEN_INDEX marks an
            image token, DEFAULT_OBJECT_INDEX marks an object token, and all other values are
            text tokens.

    Returns:
        Tuple[List[Dict[str, Any]], bool]: A list of dictionaries, each containing the type of
            the token slice ('text', 'image', 'object') and its span as [start, end) indices,
            plus a flag indicating whether the sequence contains any object tokens.
    """
    # define type markers and corresponding types
    type_map = {IMAGE_TOKEN_INDEX: "image", DEFAULT_OBJECT_INDEX: "object"}

    # find the positions of special markers
    image_indices = torch.where(input_ids == IMAGE_TOKEN_INDEX)[0]
    object_indices = torch.where(input_ids == DEFAULT_OBJECT_INDEX)[0]
    has_object = len(object_indices) > 0

    # merge all the positions of special markers
    special_indices = torch.cat((image_indices, object_indices))
    special_indices, _ = torch.sort(special_indices)
    special_tokens = input_ids[special_indices]

    slices = []
    start_idx = 0

    for i, idx in enumerate(special_indices):
        idx = idx.item()
        if start_idx < idx:
            slices.append({"type": "text", "span": [start_idx, idx]})
        token_type = type_map[special_tokens[i].item()]
        slices.append({"type": token_type, "span": [idx, idx + 1]})
        start_idx = idx + 1

    if start_idx < len(input_ids):
        slices.append({"type": "text", "span": [start_idx, len(input_ids)]})

    return slices, has_object
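
# Illustrative sketch (kept as a comment so nothing runs at import time): for a
# sequence mixing text tokens with one image placeholder and one object
# placeholder, the helper splits it into typed spans. Token values are
# hypothetical.
#
#   ids = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 3, DEFAULT_OBJECT_INDEX, 4])
#   slices, has_object = get_token_slices(ids)
#   # slices == [
#   #     {"type": "text",   "span": [0, 2]},
#   #     {"type": "image",  "span": [2, 3]},
#   #     {"type": "text",   "span": [3, 4]},
#   #     {"type": "object", "span": [4, 5]},
#   #     {"type": "text",   "span": [5, 6]},
#   # ]
#   # has_object is True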


def prepare_inputs_labels_for_multimodal(
    llm,
    input_ids: torch.LongTensor = None,
    position_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    labels: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    bbox_feats=None,
    extra_llm_input_embed: nn.Embedding = None,
    **kwargs,
):
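    """Replace image/object placeholder tokens with precomputed embeddings.

    Text tokens are embedded with the LLM's input embedding table, image
    placeholders (IMAGE_TOKEN_INDEX) are replaced by the projected image
    features in ``pixel_values``, and object placeholders (DEFAULT_OBJECT_INDEX)
    are replaced by the corresponding rows of ``bbox_feats``. The resulting
    variable-length sequences are right-padded to the batch maximum, and
    ``attention_mask``, ``position_ids`` and ``labels`` are rebuilt to match.
    The returned dict feeds the LLM through ``inputs_embeds`` (``input_ids`` is
    set to None).
    """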
    if pixel_values is None:
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "inputs_embeds": None,
            "labels": labels,
        }

    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(
            0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
        )
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # remove the padding using attention_mask -- TODO: double check
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]

    new_inputs_embeds = []
    new_labels = []
    cur_image_idx = 0
    cur_object_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        if num_images == 0:
            cur_pixel_values = pixel_values[cur_image_idx]
            cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
            cur_inputs_embeds = torch.cat(
                [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0
            )
            new_inputs_embeds.append(cur_inputs_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            cur_object_idx += 1
            continue

        cur_labels = labels[batch_idx]
        token_slices, has_object = get_token_slices(cur_input_ids)
        result_input_embeddings = []
        result_output_labels = []
        cur_gt_bbox_index = 0
        for token_slice in token_slices:
            slice_type = token_slice["type"]
            slice_span = token_slice["span"]
            if slice_type == "text":
                cur_input_ids_noim = cur_input_ids[slice_span[0] : slice_span[1]]
                cur_labels_noim = cur_labels[slice_span[0] : slice_span[1]]
                cur_input_embeds = llm.get_input_embeddings()(cur_input_ids_noim)
                result_input_embeddings.append(cur_input_embeds)
                result_output_labels.append(cur_labels_noim)
            elif slice_type == "image":
                cur_input_embeds = pixel_values[cur_image_idx]
                result_input_embeddings.append(cur_input_embeds)
                result_output_labels.append(
                    torch.full(
                        (cur_input_embeds.shape[0],),
                        IGNORE_INDEX,
                        device=cur_labels.device,
                        dtype=cur_labels.dtype,
                    )
                )
                cur_image_idx += 1
            elif slice_type == "object":
                try:
                    result_input_embeddings.append(
                        bbox_feats[cur_object_idx][cur_gt_bnox_indice].unsqueeze(0)
                    )
                except:
                    raise ValueError(
                        f"current boxe_feats.shape: {bbox_feats[cur_object_idx].shape}, "
                    )
                cur_gt_bnox_indice += 1
                result_output_labels.append(
                    torch.full(
                        (1,),
                        IGNORE_INDEX,
                        device=cur_labels.device,
                        dtype=cur_labels.dtype,
                    )
                )
        cur_object_idx += 1
        result_input_embeddings = torch.cat(result_input_embeddings)
        result_output_labels = torch.cat(result_output_labels)
        assert len(result_output_labels) == len(result_input_embeddings)
        new_inputs_embeds.append(result_input_embeddings)
        new_labels.append(result_output_labels)

    # Combine them
    max_len = max(x.shape[0] for x in new_inputs_embeds)
    batch_size = len(new_inputs_embeds)

    new_inputs_embeds_padded = []
    new_labels_padded = torch.full(
        (batch_size, max_len),
        IGNORE_INDEX,
        dtype=new_labels[0].dtype,
        device=new_labels[0].device,
    )
    attention_mask = torch.zeros(
        (batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
    )
    position_ids = torch.zeros(
        (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
    )

    for i, (cur_new_embed, cur_new_labels) in enumerate(
        zip(new_inputs_embeds, new_labels)
    ):
        cur_len = cur_new_embed.shape[0]
        new_inputs_embeds_padded.append(
            torch.cat(
                (
                    cur_new_embed,
                    torch.zeros(
                        (max_len - cur_len, cur_new_embed.shape[1]),
                        dtype=cur_new_embed.dtype,
                        device=cur_new_embed.device,
                    ),
                ),
                dim=0,
            )
        )
        if cur_len > 0:
            new_labels_padded[i, :cur_len] = cur_new_labels
            attention_mask[i, :cur_len] = True
            position_ids[i, :cur_len] = torch.arange(
                0, cur_len, dtype=position_ids.dtype, device=position_ids.device
            )

    new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None

    return {
        "input_ids": None,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "inputs_embeds": new_inputs_embeds,
        "labels": new_labels,
    }

class StopWordStoppingCriteria(StoppingCriteria):
    """StopWord stopping criteria."""

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        cur_text = self.tokenizer.decode(input_ids[0])
        cur_text = cur_text.replace('\r', '').replace('\n', '')
        return cur_text[-self.length:] == self.stop_word

def get_stop_criteria(
    tokenizer,
    stop_words=None,
):
    stop_criteria = StoppingCriteriaList()
    for word in stop_words or []:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria
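
# Illustrative usage sketch (assumes a Hugging Face tokenizer and a loaded LLM,
# both hypothetical here): stop generation once the decoded text ends with a
# given stop word, e.g. a chat template's end marker.
#
#   stop_criteria = get_stop_criteria(tokenizer=tokenizer, stop_words=["</s>"])
#   output_ids = llm.generate(**inputs, stopping_criteria=stop_criteria)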

class DualPathFuseModule(nn.Module):
    """Fuse low-resolution (CLIP) tokens with high-resolution (ConvNeXt) features.

    The high-resolution map is projected to the low-resolution channel dimension
    (slow path), the low-resolution tokens are refined with a depthwise
    convolution (fast path), and a learned scalar gate controls how much
    high-resolution signal is added back.
    """

    def __init__(self, low_res_dim, high_res_dim, zero_init=True):
        super().__init__()

        self.slow_conv = nn.Conv2d(high_res_dim, high_res_dim, 1)
        self.slow_proj = nn.Conv2d(high_res_dim, low_res_dim, 1)

        self.fast_conv = nn.Conv2d(
            low_res_dim, low_res_dim, 7, padding=3, groups=low_res_dim
        )
        self.fast_proj = nn.Conv2d(low_res_dim, low_res_dim, 1)

        self.gate = nn.Sequential(
            nn.Linear(low_res_dim * 2, low_res_dim // 2),
            nn.GELU(),
            nn.Linear(low_res_dim // 2, 1),
        )

        nn.init.xavier_uniform_(self.slow_conv.weight)
        nn.init.xavier_uniform_(self.fast_conv.weight)
        nn.init.zeros_(self.slow_conv.bias)
        nn.init.zeros_(self.fast_conv.bias)
        if zero_init:
            nn.init.zeros_(self.slow_proj.weight)
            nn.init.zeros_(self.fast_proj.weight)
        else:
            nn.init.xavier_uniform_(self.slow_proj.weight)
            nn.init.xavier_uniform_(self.fast_proj.weight)
        nn.init.zeros_(self.slow_proj.bias)
        nn.init.zeros_(self.fast_proj.bias)

    def forward(self, low_res_feat, high_res_feat, sampler=None):
        b, c, h, w = high_res_feat.shape  # (2, 1536, 24, 24)
        _, _, d = low_res_feat.shape  # (2, 576, 1024)
        high_res_feat = self.slow_proj(
            F.gelu(self.slow_conv(high_res_feat))
        )  # (2, 1024, 24, 24)
        high_res_feat = high_res_feat.view(b, d, -1).transpose(1, 2)  # (2, 576, 1024)
        dst_size = int(math.sqrt(low_res_feat.shape[1]))  # 24
        low_res_feat = low_res_feat.transpose(1, 2).view(
            b, d, dst_size, dst_size
        )  # (2, 1024, 24, 24)
        low_res_feat = low_res_feat + self.fast_proj(
            F.gelu(self.fast_conv(low_res_feat))
        )
        low_res_feat = low_res_feat.view(b, d, dst_size * dst_size).transpose(
            1, 2
        )  # (2, 576, 1024)
        gate = self.gate(
            torch.cat([low_res_feat, high_res_feat], -1).mean(1)
        ).unsqueeze(
            1
        )  # (2, 1, 1)
        low_res_feat = low_res_feat + high_res_feat * gate.tanh()
        return low_res_feat
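
# Shape walkthrough (illustrative, reusing the sizes from the inline comments
# above): the slow path projects the high-resolution ConvNeXt map down to the
# low-resolution channel dimension, the fast path refines the low-resolution
# tokens with a depthwise convolution, and a learned scalar gate controls how
# much high-resolution signal is added back.
#
#   fuser = DualPathFuseModule(low_res_dim=1024, high_res_dim=1536)
#   low_res_feat = torch.randn(2, 576, 1024)      # (B, 24*24 tokens, C_low)
#   high_res_feat = torch.randn(2, 1536, 24, 24)  # (B, C_high, H, W)
#   fused = fuser(low_res_feat, high_res_feat)    # -> (2, 576, 1024)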

class ProjectorConfig(PretrainedConfig):
    model_type = "projector"
    _auto_class = "AutoConfig"

    def __init__(
        self,
        visual_hidden_size=4096,
        llm_hidden_size=4096,
        depth=2,
        hidden_act="gelu",
        bias=True,
        **kwargs,
    ):
        self.visual_hidden_size = visual_hidden_size
        self.llm_hidden_size = llm_hidden_size
        self.depth = depth
        self.hidden_act = hidden_act
        self.bias = bias
        super().__init__(**kwargs)

class ProjectorModel(PreTrainedModel):
    _auto_class = "AutoModel"
    config_class = ProjectorConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []

    def __init__(self, config: ProjectorConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        modules = [
            nn.Linear(
                config.visual_hidden_size, config.llm_hidden_size, bias=config.bias
            )
        ]
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size, config.llm_hidden_size, bias=config.bias
                )
            )
        self.model = nn.Sequential(*modules)

    def enable_input_require_grads(self):

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        self.model.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ProjectorModel):
            module.gradient_checkpointing = value

    def forward(self, x):
        layer_outputs = self.model(x)
        return layer_outputs
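
# Illustrative usage sketch (hypothetical sizes): a depth-2 MLP that maps vision
# features into the LLM embedding space.
#
#   proj_cfg = ProjectorConfig(visual_hidden_size=1024, llm_hidden_size=4096, depth=2)
#   projector = ProjectorModel(proj_cfg)
#   feats = torch.randn(2, 576, 1024)   # (B, tokens, visual_hidden_size)
#   projected = projector(feats)        # -> (2, 576, 4096)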


def gen_sineembed_for_position(pos_tensor, dim_of_pos_feats):
    """Generate sine/cosine position embeddings from a point or box tensor.

    Args:
        pos_tensor (torch.Tensor): shape [batch_size, N, 2] or [batch_size, N, 4]. The last
            dimension is (cx, cy) or (cx, cy, w, h) in normalized coordinates in the range [0, 1].
        dim_of_pos_feats (int): number of sine/cosine channels per coordinate.

    Returns:
        pos (torch.Tensor): shape [batch_size, N, 2 * dim_of_pos_feats] for point inputs, or
            [batch_size, N, 4 * dim_of_pos_feats] for box inputs.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(
        dim_of_pos_feats, dtype=torch.float32, device=pos_tensor.device
    )
    dim_t = 10000 ** (2 * (dim_t // 2) / dim_of_pos_feats)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack(
        (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    pos_y = torch.stack(
        (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack(
            (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
        ).flatten(2)

        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack(
            (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
        ).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
    return pos
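
# Illustrative sketch: each coordinate of a normalized (cx, cy, w, h) box gets
# ``dim_of_pos_feats`` sine/cosine channels, so box inputs produce a
# 4 * dim_of_pos_feats embedding. The visual prompt encoder below passes
# pos_embedding_dim // 4 = 720 to obtain a 2880-dim embedding.
#
#   boxes = torch.rand(1, 5, 4)                    # (batch, N, [cx, cy, w, h]) in [0, 1]
#   pos = gen_sineembed_for_position(boxes, 720)   # -> (1, 5, 2880)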


class MultiLevelROIVisualPrompt(nn.Module):
    """RoI-based visual prompt encoder over multi-level feature maps.

    Args:
        output_size (Optional[int]): The output size of the RoI Align operator. Default is None.
        channel_per_level (List[int]): Channels of each feature level. Default is [192, 384, 768, 1536].
        spatial_scale (Optional[float]): Scale factor mapping box coordinates to the largest
            feature map. Default is None.
        visual_prompt_hidden_size (int): The hidden size of the visual prompt. Default is 1024.
        add_pos_embedding (bool): Whether to add a box position embedding. Default is False.
        pos_embedding_dim (int): The dimension of the position embedding. Default is 1024.
    """

    def __init__(
        self,
        output_size: Optional[int] = None,
        channel_per_level: List[int] = [192, 384, 768, 1536],
        spatial_scale: Optional[float] = None,
        visual_prompt_hidden_size: int = 1024,
        add_pos_embedding: bool = False,
        pos_embedding_dim: int = 1024,
    ):
        super().__init__()
        self.output_size = output_size
        self.channel_per_level = channel_per_level
        self.spatial_scale = spatial_scale
        self.add_pos_embedding = add_pos_embedding
        self.pos_embedding_dim = pos_embedding_dim

    def __call__(
        self,
        multi_level_features: List[torch.Tensor],
        boxes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Performs Region of Interest (RoI) Align operator on multi-level features. The RoI
        feature on each scale will go through a different linear layer for projection. Different
        RoI features will be summed up and then average pooled.

        Args:
            multi_level_features (List[Tensor[N, C, H, W]]): Feature maps from different levels.
            boxes (Tensor[K, 5] or List[Tensor[L, 4]]): The box coordinates in (x1, y1, x2, y2)
                format from which the regions will be taken.
        Returns:
            Tensor[1, K, C]: Pooled RoI features of shape (1, K, C), where K is the number of RoIs.
        """
        boxes[0] = boxes[0].float()
        concat_multi_level_feature = []
        max_height = max([feature.shape[2] for feature in multi_level_features])
        max_width = max([feature.shape[3] for feature in multi_level_features])
        # interpolate to the same size
        for level, feature in enumerate(multi_level_features):
            if level != 0:
                concat_multi_level_feature.append(
                    F.interpolate(
                        feature.float(),
                        size=(max_height, max_width),
                        mode="bilinear",
                        align_corners=False,
                    )
                )
            else:
                concat_multi_level_feature.append(feature.float())
        concat_multi_level_feature = torch.cat(concat_multi_level_feature, dim=1)

        
        out_box_feat = roi_align(
            concat_multi_level_feature,
            boxes,
            output_size=self.output_size,
            spatial_scale=self.spatial_scale,
        )

        # Average Pooling -> n,c -> 1,n,c
        out_box_feat = out_box_feat.mean(dim=(2, 3)).reshape(
            1, out_box_feat.shape[0], out_box_feat.shape[1]
        )
        if self.add_pos_embedding:
            # boxes are in unnormalized xyxy format, so normalize them first;
            # clone to avoid mutating the caller's tensor in place
            boxes = boxes[0].clone().to(out_box_feat.dtype)  # (N, 4)
            original_img_width = max_width / self.spatial_scale
            original_img_height = max_height / self.spatial_scale
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / original_img_width
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / original_img_height
            # convert from (x1, y1, x2, y2) to (cx, cy, w, h)
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            boxes[:, 0] = boxes[:, 0] + boxes[:, 2] / 2
            boxes[:, 1] = boxes[:, 1] + boxes[:, 3] / 2
            pos_embed = gen_sineembed_for_position(
                boxes.unsqueeze(0), self.pos_embedding_dim // 4
            )
            out_box_feat = out_box_feat + pos_embed

        return out_box_feat
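
# Illustrative usage sketch (hypothetical feature shapes that mirror the
# ConvNeXt stages configured below): the multi-level maps are upsampled to the
# largest resolution, concatenated along channels (192 + 384 + 768 + 1536 =
# 2880), RoI-aligned per box, average pooled, and optionally augmented with a
# box position embedding.
#
#   vp_encoder = MultiLevelROIVisualPrompt(
#       output_size=7,
#       channel_per_level=[192, 384, 768, 1536],
#       spatial_scale=192 / 768,
#       add_pos_embedding=True,
#       pos_embedding_dim=2880,
#   )
#   feats = [torch.randn(1, c, s, s)
#            for c, s in zip([192, 384, 768, 1536], [192, 96, 48, 24])]
#   boxes = [torch.tensor([[10.0, 20.0, 200.0, 300.0]])]  # xyxy, input-image pixels
#   vp_feat = vp_encoder(feats, boxes)                    # -> (1, 1, 2880)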



class ChatRexAuxConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ChatRexAuxForConditionalGeneration`] model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
            The config object or dictionary of the low-resolution vision backbone.
        vision_aux_config (`Union[AutoConfig, dict]`, *optional*):
            The config object or dictionary of the auxiliary high-resolution vision backbone.
        visual_prompt_encoder_config (`Union[AutoConfig, dict]`, *optional*):
            The config object or dictionary of the visual prompt encoder.
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
            The config object or dictionary of the text backbone.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        image_token_index (`int`, *optional*, defaults to 32000):
            The image token index to encode the image prompt.
        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The activation function used by the multimodal projector.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Can be one of `"default"` or `"full"`.
        vision_feature_layer (`int`, *optional*, defaults to -2):
            The index of the layer to select the vision feature.
        projector_depth (`int`, *optional*, defaults to 2):
            The number of linear layers in the multimodal projector.
        visual_prompt_hidden_size (`int`, *optional*, defaults to 2880):
            The hidden size of the visual prompt features.

    Example:

    ```python
    >>> from transformers import CLIPVisionConfig, LlamaConfig

    >>> # Initializing a CLIP-vision config
    >>> vision_config = CLIPVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a ChatRexAux configuration
    >>> configuration = ChatRexAuxConfig(vision_config=vision_config, text_config=text_config)

    >>> # Initializing a model from that configuration
    >>> model = ChatRexAuxForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "chatrex"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        vision_aux_config=None,
        visual_prompt_encoder_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        projector_depth=2,
        visual_prompt_hidden_size=2880,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.projector_depth = projector_depth
        self.visual_prompt_hidden_size = visual_prompt_hidden_size
        self.visual_prompt_encoder_config = visual_prompt_encoder_config

        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError(
                "vision_feature_select_strategy should be one of 'default', 'full'."
                f"Got: {vision_feature_select_strategy}"
            )

        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer

        if isinstance(vision_config, dict):
            vision_config["model_type"] = (
                vision_config["model_type"]
                if "model_type" in vision_config
                else "clip_vision_model"
            )
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            vision_config = CONFIG_MAPPING["clip_vision_model"](
                intermediate_size=4096,
                hidden_size=1024,
                patch_size=14,
                image_size=336,
                num_hidden_layers=24,
                num_attention_heads=16,
                vocab_size=32000,
                projection_dim=768,
            )

        self.vision_config = vision_config
        self.vision_aux_config = vision_aux_config

        if isinstance(text_config, dict):
            text_config["model_type"] = (
                text_config["model_type"] if "model_type" in text_config else "llama"
            )
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        self.text_config = text_config

        super().__init__(**kwargs)


class ChatRexAuxPreTrainedModel(PreTrainedModel):
    config_class = ChatRexAuxConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlavaVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    # def _init_weights(self, module):
    #     # important: this ported version of Llava isn't meant for training from scratch - only
    #     # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
    #     # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
    #     std = (
    #         self.config.initializer_range
    #         if hasattr(self.config, "initializer_range")
    #         else self.config.text_config.initializer_range
    #     )

    #     if hasattr(module, "class_embedding"):
    #         module.class_embedding.data.normal_(mean=0.0, std=std)

    #     if isinstance(module, (nn.Linear, nn.Conv2d)):
    #         module.weight.data.normal_(mean=0.0, std=std)
    #         if module.bias is not None:
    #             module.bias.data.zero_()
    #     elif isinstance(module, nn.Embedding):
    #         module.weight.data.normal_(mean=0.0, std=std)
    #         if module.padding_idx is not None:
    #             module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self):
        """
        Retrieve the language model's attribute to check whether the model
        supports SDPA or not.
        """
        return self.llm._supports_sdpa


class ChatRexAuxForConditionalGeneration(ChatRexAuxPreTrainedModel):

    def __init__(self, config: ChatRexAuxConfig):
        super().__init__(config)
        # low-resolution vision encoder
        self.vision_encoder = AutoModel.from_config(config.vision_config)
        # high-resolution vision encoder
        self.vision_encoder_aux = ConvNextVisionEncoder()

        # vision projector
        projector_config = ProjectorConfig(
            visual_hidden_size=config.vision_config.hidden_size,
            llm_hidden_size=config.text_config.hidden_size,
            depth=config.projector_depth,
        )
        self.projector = ProjectorModel(projector_config)

        # visual prompt feature projector
        vp_projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_prompt_hidden_size,
            llm_hidden_size=config.text_config.hidden_size,
            depth=config.projector_depth,
        )
        self.vp_projector = ProjectorModel(vp_projector_config)

        # fuser
        self.fuser = DualPathFuseModule(
            low_res_dim=config.vision_config.hidden_size,
            high_res_dim=1536,
        )

        # visual prompt encoder
        self.vp_encoder = MultiLevelROIVisualPrompt(
            output_size=7,
            channel_per_level=[192, 384, 768, 1536],
            spatial_scale=192 / 768,
            add_pos_embedding=True,
            pos_embedding_dim=2880,
        )

        # generation config
        self.gen_config = None

        self.vocab_size = config.text_config.vocab_size
        self.llm = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.pad_token_id = (
            self.config.pad_token_id if self.config.pad_token_id is not None else -1
        )
        self.post_init()

        
    def _prepare_data_for_llm(self, data):
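        """Encode images and visual prompts, then build the LLM inputs.

        Low-resolution CLIP/SigLIP features and high-resolution ConvNeXt
        features are fused by ``self.fuser`` and projected into the LLM
        embedding space. If ``gt_boxes`` is present, per-box features are
        extracted with the multi-level RoI visual prompt encoder and projected
        as well. Finally, placeholder tokens are replaced by these embeddings
        via ``prepare_inputs_labels_for_multimodal``.
        """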
        if "pixel_values" in data:
            visual_outputs = self.vision_encoder(
                data["pixel_values"].to(self.vision_encoder.dtype),
                output_hidden_states=True,
            )
            if type(self.vision_encoder).__name__ in [
                "CLIPVisionModel",
                "CLIPVisionModelAnyRes",
            ]:
                visual_outputs = visual_outputs.hidden_states[-2][
                    :, 1:
                ]
            elif type(self.vision_encoder).__name__ == "SiglipVisionModel":
                visual_outputs = visual_outputs.hidden_states[-2]
            else:
                raise NotImplementedError

            # aux encoder
            if self.vision_encoder_aux is not None:
                pixels_aux = []
                for pixels in data["pixel_values_aux"]:
                    if pixels.dim() == 3:
                        pixels = pixels.unsqueeze(0)
                    elif pixels.dim() == 4:
                        pixels = pixels.permute(1, 0, 2, 3)
                    pixels_aux.append(pixels)
                visual_outputs_aux = torch.cat(
                    pixels_aux, dim=0
                )  # shape (2, 3, 768, 768)
                aux_output = self.vision_encoder_aux(
                    visual_outputs_aux
                )
                visual_outputs_aux = aux_output["image_features"]
                last_feat = aux_output["last_feat"]  # (B, 1536, 24, 24)
            # fuser
            fuse_features = self.fuser(
                low_res_feat=visual_outputs, high_res_feat=last_feat
            )  # (2, 576, 1024)
            pixel_values = self.projector(fuse_features)
            data["pixel_values"] = pixel_values

            # extract visual prompt features
            bbox_visual_outputs = []
            if "gt_boxes" in data:
                for batch_idx, boxes in enumerate(data["gt_boxes"]):
                    if len(boxes) == 0:
                        bbox_visual_outputs.append(None)
                        continue
                    multi_level_aux_features = [
                        visual_output_aux[batch_idx].unsqueeze(0)
                        for visual_output_aux in visual_outputs_aux
                    ]
                    boxes = boxes.to(torch.float32)
                    out_vp_feat = self.vp_encoder(
                        multi_level_aux_features,
                        [boxes],
                    ).squeeze(0)
                    out_vp_feat = out_vp_feat.to(pixel_values.dtype)
                    out_vp_feat = self.vp_projector(out_vp_feat)
                    bbox_visual_outputs.append(out_vp_feat)
                # b,n,c
                data["bbox_feats"] = bbox_visual_outputs
                
            data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
        return data

    
    def generate(self, data_dict: Dict[str, Any], gen_config=None, tokenizer=None):
        """Perform inference on the given data.

        Args:
            data_dict (Dict[str, Any]): The preprocessed inputs (input ids, pixel values,
                auxiliary pixel values and, optionally, ground-truth boxes).
            gen_config (GenerationConfig, optional): Generation config to use; falls back
                to ``self.gen_config`` when None.
            tokenizer: Tokenizer used for decoding and for building stopping criteria.

        Returns:
            str: The answer to the question.
        """
        data_dict = self._prepare_data_for_llm(data_dict)
        data_dict["inputs_embeds"] = data_dict["inputs_embeds"].to(self.llm.dtype)
        stop_criteria = get_stop_criteria(
            tokenizer=tokenizer, stop_words=[]
        )
        generate_output = self.llm.generate(
            **data_dict,
            generation_config=self.gen_config if gen_config is None else gen_config,
            streamer=None,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
        )
        prediction = tokenizer.decode(
            generate_output[0], skip_special_tokens=False
        ).strip()
        prediction = prediction.replace("<s>", "").replace("</s>", "").strip()
        return prediction


AutoConfig.register("chatrex", ChatRexAuxConfig)
AutoModelForCausalLM.register(ChatRexAuxConfig, ChatRexAuxForConditionalGeneration)
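

# Illustrative end-to-end sketch (hypothetical checkpoint path and inputs): the
# registrations above make the model loadable through the Auto classes with
# trust_remote_code, after which `generate` takes a preprocessed data_dict
# containing `input_ids`, `pixel_values`, `pixel_values_aux` and, optionally,
# `gt_boxes`.
#
#   model = AutoModelForCausalLM.from_pretrained("path/to/chatrex", trust_remote_code=True)
#   answer = model.generate(data_dict, gen_config=gen_config, tokenizer=tokenizer)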