import argparse
import json
import logging
from math import ceil
import os
import random
import uuid
from collections import Counter, defaultdict
from typing import Callable
import time
import cv2

import more_itertools
import numpy as np
import torch
from coco_metric import compute_cider, postprocess_captioning_generation
from eval_datasets import VQADataset, GQADataset
from tqdm import tqdm

from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
from open_flamingo.eval.classification import (
    compute_per_sample_probs,
    compute_per_sample_loss,
)
from open_flamingo.eval.imagenet_utils import (
    openai_imagenet_classnames,
    IMAGENET_1K_CLASS_ID_TO_LABEL,
)

from open_flamingo.src.factory import create_model_and_transforms
from PIL import Image
from io import BytesIO
import base64
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
import string
from lavis.datasets.builders import load_dataset


def get_iou(box1, box2):
    # box1 and box2 should be in the format [x1, y1, x2, y2]
    intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
                   max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_box1 + area_box2 - intersection
    iou = intersection / union if union > 0 else 0
    return iou
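
# Worked example (illustrative only): boxes [0, 0, 2, 2] and [1, 1, 3, 3] overlap in a
# 1x1 region, so intersection = 1, union = 4 + 4 - 1 = 7, and get_iou(...) == 1 / 7.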

def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
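
# Example (illustrative only): padding a 640x480 image to a square yields a 640x640
# image with the original pasted at vertical offset (640 - 480) // 2 = 80, e.g.
#   expand2square(Image.new("RGB", (640, 480)), (0, 0, 0)).size == (640, 640)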

parser = argparse.ArgumentParser()
parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
parser.add_argument("--checkpoint_path", type=str, required=True)
parser.add_argument(
    "--results_file", type=str, default=None, help="JSON file to save results"
)

# Trial arguments
parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
parser.add_argument(
    "--num_trials",
    type=int,
    default=1,
    help="Number of trials to run for each shot using different demonstrations",
)
parser.add_argument(
    "--trial_seeds",
    nargs="+",
    type=int,
    default=[0],
    help="Seeds to use for each trial for picking demonstrations and eval sets",
)
parser.add_argument(
    "--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
)

parser.add_argument("--batch_size", type=int, default=8)

# Per-dataset evaluation flags
parser.add_argument(
    "--eval_coco",
    action="store_true",
    default=False,
    help="Whether to evaluate on COCO.",
)
parser.add_argument(
    "--eval_vqav2",
    action="store_true",
    default=False,
    help="Whether to evaluate on VQAV2.",
)
parser.add_argument(
    "--eval_ok_vqa",
    action="store_true",
    default=False,
    help="Whether to evaluate on OK-VQA.",
)
parser.add_argument(
    "--eval_imagenet",
    action="store_true",
    default=False,
    help="Whether to evaluate on ImageNet.",
)

parser.add_argument(
    "--eval_flickr30",
    action="store_true",
    default=False,
    help="Whether to evaluate on Flickr30.",
)

parser.add_argument(
    "--eval_refcoco",
    action="store_true",
    default=False,
    help="Whether to evaluate on RefCOCO.",
)

# Dataset arguments

## Flickr30 Dataset
parser.add_argument(
    "--flickr_image_dir_path",
    type=str,
    help="Path to the flickr30/flickr30k_images directory.",
    default=None,
)
parser.add_argument(
    "--flickr_annotations_json_path",
    type=str,
    help="Path to the dataset_flickr30k_coco_style.json file.",
    default=None,
)

## COCO Dataset
parser.add_argument(
    "--coco_image_dir_path",
    type=str,
    help="Path to the flickr30/flickr30k_images directory.",
    default=None,
)
parser.add_argument(
    "--coco_annotations_json_path",
    type=str,
    default=None,
)

## VQAV2 Dataset
parser.add_argument(
    "--vqav2_image_dir_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--vqav2_questions_json_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--vqav2_annotations_json_path",
    type=str,
    default=None,
)

## OK-VQA Dataset
parser.add_argument(
    "--ok_vqa_image_dir_path",
    type=str,
    help="Path to the vqav2/train2014 directory.",
    default=None,
)
parser.add_argument(
    "--ok_vqa_questions_json_path",
    type=str,
    help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
    default=None,
)
parser.add_argument(
    "--ok_vqa_annotations_json_path",
    type=str,
    help="Path to the v2_mscoco_train2014_annotations.json file.",
    default=None,
)

## Imagenet dataset
parser.add_argument("--imagenet_root", type=str, default="/tmp")

## RefCOCO dataset
parser.add_argument("--refcoco_tsvfile", type=str, default=None)

parser.add_argument(
    "--location_token_num",
    default=1000,
    type=int,
)
# distributed training
parser.add_argument(
    "--dist-url",
    default="env://",
    type=str,
    help="url used to set up distributed training",
)
parser.add_argument(
    "--dist-backend", default="nccl", type=str, help="distributed backend"
)
parser.add_argument(
    "--horovod",
    default=False,
    action="store_true",
    help="Use horovod for distributed training.",
)
parser.add_argument(
    "--no-set-device-rank",
    default=False,
    action="store_true",
    help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
)
parser.add_argument(
    "--dist",
    default=False,
    action="store_true",
)
parser.add_argument(
    "--lora",
    default=False,
    action="store_true",
)
parser.add_argument(
    "--lora_r",
    default=16,
    type=int,
    required=False,
)
parser.add_argument(
    "--legacy",
    default=False,
    action="store_true",
)
parser.add_argument(
    "--special",
    default=False,
    action="store_true",
)
parser.add_argument(
    "--id",
    default=0,
    type=int,
    required=False,
)

parser.add_argument(
    "--eval_gqa",
    default=False,
    action="store_true",
)
parser.add_argument(
    "--use_sam",
    default=None,
    type=str,
    required=False,
)
parser.add_argument(
    "--add_visual_token",
    default=False,
    action="store_true",
)
parser.add_argument(
    "--use_format_v2",
    default=False,
    action="store_true",
)
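
# Example invocation (sketch; the script name and all paths are placeholders to adapt):
#   python evaluate.py \
#       --checkpoint_path /path/to/checkpoint_100000.pt \
#       --lm_path facebook/opt-1.3b --lm_tokenizer_path facebook/opt-1.3b \
#       --eval_vqav2 \
#       --vqav2_image_dir_path /path/to/val2014 \
#       --vqav2_questions_json_path /path/to/v2_OpenEnded_mscoco_val2014_questions.json \
#       --vqav2_annotations_json_path /path/to/v2_mscoco_val2014_annotations.json \
#       --batch_size 8
# The checkpoint file name is expected to end in _<step>.pt and its parent directory to
# name the experiment, because both are parsed from --checkpoint_path below.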


class OKVQAPostProcess:
    """Lemmatize the nouns and verbs in OK-VQA answers with spaCy."""

    def __init__(self):
        self._lemmatizer = None

    def _lemmatize(self, answers):
        def apply(answer):
            doc = self.lemmatizer(answer)

            words = []
            for token in doc:
                if token.pos_ in ["NOUN", "VERB"]:
                    words.append(token.lemma_)
                else:
                    words.append(token.text)
            answer = " ".join(words)

            return answer

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        if self._lemmatizer is None:
            try:
                import spacy

                self._lemmatizer = spacy.load("en_core_web_sm")
            except (ImportError, OSError):
                logging.error(
                    """
                    Please install spacy and en_core_web_sm model to apply lemmatization.
                    python -m spacy download en_core_web_sm
                    OR
                    import spacy.cli
                    spacy.cli.download("en_core_web_sm")
                    """
                )
                exit(1)

        return self._lemmatizer
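
# Usage sketch for OKVQAPostProcess (illustrative only; requires spaCy and the
# en_core_web_sm model to be installed):
#   postprocessor = OKVQAPostProcess()
#   postprocessor._lemmatize(["riding bikes"])  # typically -> ["ride bike"]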
        

def main():
    args = parser.parse_args()
    if args.dist:
        args.local_rank, args.rank, args.world_size = world_info_from_env()
        print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
        device_id = init_distributed_device(args)
    else:
        args.rank = 0
        args.world_size = 1
        # args.device is read by the evaluation calls below even in the single-process case.
        args.device = 0 if torch.cuda.is_available() else -1
        print(f"rank: {args.rank} world_size: {args.world_size}")
    
    if "sam" in args.checkpoint_path:
        args.use_sam = "vit_l"

    args.add_visual_token = True
    if "lora" in args.checkpoint_path:
        args.lora = True


    # Default architecture flags; specific variants are toggled below based on
    # substrings in the checkpoint path.
    args.add_pe = False
    args.add_box = False
    args.relation = False
    if "debug" in args.checkpoint_path:
        # args.add_pe = True
        args.add_box = True
    if "box" in args.checkpoint_path:
        args.add_box = True
    if "pe" in args.checkpoint_path:
        args.add_pe = True
    if "rel" in args.checkpoint_path:
        args.relation = True
        args.add_pe = False
    if "previsual" in args.checkpoint_path:
        args.use_format_v2 = True
        args.relation = False

    # load model
    flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
        args.vision_encoder_path,
        args.vision_encoder_pretrained,
        args.lm_path,
        args.lm_tokenizer_path,
        location_token_num=args.location_token_num,
        lora=args.lora,
        lora_r=args.lora_r,
        use_sam=args.use_sam,
        add_visual_token=args.add_visual_token,
        use_format_v2=args.use_format_v2,
        add_box=args.add_box,
        add_pe=args.add_pe,
        add_relation=args.relation,
    )
    flamingo.use_format_v2 = args.use_format_v2
    if args.special:
        flamingo.special = True
    else:
        flamingo.special = False
    if args.legacy:
        flamingo.legacy = True
        print("use legacy evaluation")
    flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
    flamingo.expr_name = args.checkpoint_path.split("/")[-2]
    if args.rank == 0:
        print("legacy", True if hasattr(flamingo, "legacy") else False)
        print("step:", flamingo.step_num)
        print("expr:", flamingo.expr_name)
        print("use format v2:", flamingo.use_format_v2)
        print(args)
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
    # Strip the DistributedDataParallel "module." prefix from checkpoint keys.
    model_state_dict = {
        key.replace("module.", ""): value
        for key, value in checkpoint["model_state_dict"].items()
    }
    if "vision_encoder.logit_scale" in model_state_dict:
        # Older checkpoints carry unused CLIP weights; drop them before loading.
        del model_state_dict["vision_encoder.logit_scale"]
        del model_state_dict["vision_encoder.visual.proj"]
        del model_state_dict["vision_encoder.visual.ln_post.weight"]
        del model_state_dict["vision_encoder.visual.ln_post.bias"]
    flamingo.load_state_dict(model_state_dict, strict=True)
    results = defaultdict(list)
    if args.eval_coco:
        print("Evaluating on COCO...")
        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                cider_score = evaluate_coco_flickr(
                    model=flamingo,
                    tokenizer=tokenizer,
                    image_processor=image_processor,
                    batch_size=args.batch_size,
                    image_dir_path=args.coco_image_dir_path,
                    annotations_json_path=args.coco_annotations_json_path,
                    device=args.device,
                    seed=seed,
                    vis_embed_size=vis_embed_size,
                    rank=args.rank,
                    world_size=args.world_size,
                    id=args.id,
                )
                print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
                scores.append(cider_score)
            print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
            results["coco"].append(
                {"shots": shot, "trials": scores, "mean": np.mean(scores)}
            )

    if args.eval_ok_vqa:
        print("Evaluating on OK-VQA...")
        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                ok_vqa_score = evaluate_vqa(
                    model=flamingo,
                    tokenizer=tokenizer,
                    image_processor=image_processor,
                    batch_size=args.batch_size,
                    image_dir_path=args.ok_vqa_image_dir_path,
                    questions_json_path=args.ok_vqa_questions_json_path,
                    annotations_json_path=args.ok_vqa_annotations_json_path,
                    vqa_dataset="ok_vqa",
                    vis_embed_size=vis_embed_size,
                    rank=args.rank,
                    world_size=args.world_size,
                    id=args.id,
                )
                scores.append(ok_vqa_score)
            results["ok_vqa"].append(
                {"shots": shot, "trials": scores, "mean": np.mean(scores)}
            )

    if args.eval_vqav2:
        print("Evaluating on VQAv2...")
        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                vqa_score = evaluate_vqa(
                    model=flamingo,
                    tokenizer=tokenizer,
                    image_processor=image_processor,
                    batch_size=args.batch_size,
                    image_dir_path=args.vqav2_image_dir_path,
                    questions_json_path=args.vqav2_questions_json_path,
                    annotations_json_path=args.vqav2_annotations_json_path,
                    vqa_dataset="vqa",
                    vis_embed_size=vis_embed_size,
                    rank=args.rank,
                    world_size=args.world_size,
                    id=args.id,
                )
                scores.append(vqa_score)
            results["vqav2"].append(
                {"shots": shot, "trials": scores, "mean": np.mean(scores)}
            )

    if args.eval_gqa:
        print("Evaluating on GQA...")
        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                vqa_score = evaluate_vqa(
                    model=flamingo,
                    tokenizer=tokenizer,
                    image_processor=image_processor,
                    batch_size=args.batch_size,
                    vqa_dataset="gqa",
                    vis_embed_size=vis_embed_size,
                    rank=args.rank,
                    world_size=args.world_size,
                    id=args.id,
                )
                scores.append(vqa_score)
            results["gqa"].append(
                {"shots": shot, "trials": scores, "mean": np.mean(scores)}
            )

    if args.eval_imagenet:
        print("Evaluating on ImageNet...")
        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                imagenet_score = evaluate_imagenet(
                    model=flamingo,
                    tokenizer=tokenizer,
                    image_processor=image_processor,
                    batch_size=args.batch_size,
                    num_samples=args.num_samples,
                    num_shots=shot,
                    device=args.device,
                    seed=seed,
                    imagenet_root=args.imagenet_root,
                )
                print(
                    f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}"
                )
                scores.append(imagenet_score)
            print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}")
            results["imagenet"].append(
                {"shots": shot, "trials": scores, "mean": np.mean(scores)}
            )

    if args.eval_refcoco:
        print("Evaluating on RefCOCO...")
        refcoco_score = evaluate_refcoco(
            model=flamingo,
            tokenizer=tokenizer,
            image_processor=image_processor,
            batch_size=args.batch_size,
            device=args.device,
            tsvfile=args.refcoco_tsvfile,
            vis_embed_size=vis_embed_size,
            rank=args.rank,
            world_size=args.world_size,
            id=args.id,
        )
        results["refcoco"].append(
            {"score": refcoco_score}
        )
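
    # Persist aggregated results when requested (sketch; assumes rank 0 has run all the
    # enabled evaluations and that --results_file points to a writable path).
    if args.results_file is not None and args.rank == 0:
        with open(args.results_file, "w") as f:
            json.dump(results, f, indent=2, default=float)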

def prepare_batch_images(batch, image_processor):
    # Process each sample's image and stack the results along the batch dimension.
    image_tensors = [
        image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
        for b in batch
    ]
    return torch.cat(image_tensors, dim=0)
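
# Shape note (illustrative, assuming the default ViT-L-14 processor): each image is
# preprocessed to (3, 224, 224), so prepare_batch_images returns a tensor of shape
# (batch_size, 1, 1, 3, 224, 224), i.e. one frame of one image per sample.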

def get_outputs(
    model,
    batch_images,
    attention_mask,
    max_generation_length,
    min_generation_length,
    num_beams,
    length_penalty,
    input_ids,
    image_start_index_list=None,
    image_nums=None,
    bad_words_ids=None,
):
    with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
        outputs = model.generate(
            batch_images,
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_generation_length,
            min_length=min_generation_length,
            num_beams=num_beams,
            length_penalty=length_penalty,
            image_start_index_list=image_start_index_list,
            image_nums=image_nums,
            bad_words_ids=bad_words_ids,
        )

    outputs = outputs[:, len(input_ids[0]) :]
    return outputs
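
# Usage note (illustrative only): get_outputs returns only the newly generated tokens,
# since the prompt prefix is sliced off above, so the result can be decoded directly,
# e.g. tokenizer.batch_decode(outputs, skip_special_tokens=True).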


def evaluate_coco_flickr(
    model,
    tokenizer,
    image_processor,
    batch_size,
    image_dir_path,
    annotations_json_path,
    seed=42,
    max_generation_length=20,
    num_beams=1,
    length_penalty=-2.0,
    device=-1,
    is_flickr=False,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
):
    """Evaluate a model on COCO dataset.

    Args:
        model (nn.Module): model to evaluate
        tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
        image_processor : image processor for the model
        batch_size (int): batch size
        image_dir_path (str, optional): path to the directory containing the images.
        annotations_json_path (str, optional): path to the json file containing the annotations.
        seed (int, optional): seed for random number generator. Defaults to 42.
        max_generation_length (int, optional): maximum length of the generated caption. Defaults to 10.
        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
        num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
        query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
        num_shots (int, optional): number of in-context samples to use. Defaults to 8.
        device (int, optional): device to use. Defaults to -1.
        num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
        is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).

    Returns:
        float: CIDEr score

    """
    # eval_dataset = COCOFlickrDataset(
    #     image_dir_path=image_dir_path,
    #     annotations_path=annotations_json_path,
    #     is_flickr=is_flickr,
    # )
    coco_dataset = load_dataset("coco_caption")
    eval_dataset = coco_dataset["test"]


    model.eval().cuda()
    predictions = {}
    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
    # if "peft" in lang_encoder_name:
        # lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
    try:
        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
        endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
        pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
        bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    except Exception:
        # The tokenizer may lack some of these special tokens; later code assumes they exist.
        pass

    def get_prompt(sample):
        return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"

    tokenizer.padding_side = "left"
    cnt = 0
    if world_size > 1:
        torch.distributed.barrier()
    desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
    for ii, batch in enumerate(more_itertools.chunked(
        tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
    )):
        if ii % world_size != rank:
            continue
        cnt += len(batch)
        batch_images = prepare_batch_images(
            batch=batch,
            image_processor=image_processor,
        ).cuda()
        batch_text = [get_prompt(s) for s in batch]
        encodings = tokenizer(
            batch_text,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        input_ids = encodings["input_ids"].cuda()
        attention_mask = encodings["attention_mask"].cuda()
        skip_special_tokens = False
        if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
            if rank == 0:
                tqdm.write("use legacy model")
            skip_special_tokens = True
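            # Legacy OPT checkpoints use a shifted prompt layout: move <|#image#|> and
            # <|#endofimage#|> one position to the left (their old slots become a pad
            # and a BOS token, respectively) so the prompt matches that layout.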
            for i in range(len(input_ids)):
                media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
                endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
                input_ids[i, media_token_index - 1] = media_token_id
                input_ids[i, media_token_index] = pad_token_id
                input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
                input_ids[i, endofmedia_token_index] = bos_token_id
        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        image_start_index_list = [[x] for x in image_start_index_list]
        image_nums = [1] * len(input_ids)
        if "llama" in lang_encoder_name:
            attention_mask[input_ids == 0] = 0
        outputs = get_outputs(
            model=model,
            batch_images=batch_images,
            attention_mask=attention_mask,
            max_generation_length=30,
            min_generation_length=8,
            num_beams=5,
            length_penalty=0,
            input_ids=input_ids,
            image_start_index_list=image_start_index_list,
            image_nums=image_nums,
        )
        new_predictions = [
            postprocess_captioning_generation(out).replace('"', "")
            for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ]
        # if rank == 0:
        #     tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")

        for i, sample in enumerate(batch):
            predictions[int(sample["image_id"])] = {
                "caption": new_predictions[i],
            }
    results_path = (
        f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
        if is_flickr
        else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
    )
    with open(results_path, "w") as f:
        f.write(
            json.dumps(
                [
                    {"image_id": k, "caption": predictions[k]["caption"]}
                    for k in predictions
                ],
                indent=2,
            )
        )
    print("save to", results_path)
    del predictions
    time.sleep(10)
    if world_size > 1:
        torch.distributed.barrier()
    if rank == 0:
        print(f"evaluate on rank {rank}. world size is {world_size}")
        predictions = []
        for rank_i in range(world_size):
            part_results_path = (
                f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
                if is_flickr
                else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
            )
            print("load", part_results_path)
            predictions.extend(json.load(open(part_results_path)))
            os.remove(part_results_path)
        print("num:", len(predictions))
        results_path = (
            f"flickrresults_{lang_encoder_name}.json"
            if is_flickr
            else f"cocoresults_{lang_encoder_name}.json"
        )
        json.dump(predictions, open(results_path, "w"), indent=2)

        metrics = compute_cider(
            result_path=results_path,
            annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
        )
        os.makedirs("eval_results", exist_ok=True)
        acc = metrics["CIDEr"]
        with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
            f.write(json.dumps(predictions, indent=2))

        # delete the temporary file
        os.remove(results_path)
    else:
        metrics = {}
        metrics["CIDEr"] = 0.0

    return metrics["CIDEr"]


def evaluate_vqa(
    model,
    tokenizer,
    image_processor,
    batch_size,
    image_dir_path=None,
    questions_json_path=None,
    annotations_json_path=None,
    vqa_dataset="vqa",
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
):
    """
    Evaluate a model on VQA datasets. Currently supports VQA v2.0.

    Args:
        model (nn.Module): model to evaluate
        tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
        image_processor : image processor for the model
        batch_size (int): batch size
        image_dir_path (str): path to image directory
        questions_json_path (str): path to questions json file
        annotations_json_path (str): path to annotations json file
        seed (int, optional): random seed. Defaults to 42.
        max_generation_length (int, optional): max generation length. Defaults to 5.
        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
        num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
        query_set_size (int, optional): size of the query set. Defaults to 2048.
        num_shots (int, optional): number of shots to use. Defaults to 8.
        device (int, optional): device to use. Defaults to -1 (cpu).
        num_workers (int, optional): number of workers to use. Defaults to 4.
        vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
    Returns:
        float: accuracy score
    """
    if world_size > 1:
        torch.distributed.barrier()
    if vqa_dataset == "gqa":
        eval_dataset = GQADataset()
    else:
        eval_dataset = VQADataset(
            image_dir_path=image_dir_path,
            question_path=questions_json_path,
            annotations_path=annotations_json_path,
            vqa_dataset=vqa_dataset,
        )
    postprocessor = OKVQAPostProcess()
    try:
        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
        endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
        pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
        bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    except Exception:
        # The tokenizer may lack some of these special tokens; later code assumes they exist.
        pass
    def get_prompt(sample):
        return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
        # return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"

    model.eval().cuda()
    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
    if "peft" in lang_encoder_name:
        lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
    predictions = []
    tokenizer.padding_side = "left"
    if world_size > 1:
        torch.distributed.barrier()
    for ii, batch in enumerate(more_itertools.chunked(
        tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
    )):
        if ii % world_size != rank:
            continue
        batch_images = prepare_batch_images(
            batch=batch,
            image_processor=image_processor,
        ).cuda()
        batch_text = [get_prompt(s) for s in batch]
        encodings = tokenizer(
            batch_text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=2000,
        )
        input_ids = encodings["input_ids"].cuda()
        attention_mask = encodings["attention_mask"].cuda()
        skip_special_tokens = True
        if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
            if rank == 0:
                tqdm.write("use legacy model")
            for i in range(len(input_ids)):
                media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
                endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
                input_ids[i, media_token_index - 1] = media_token_id
                input_ids[i, media_token_index] = pad_token_id
                input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
                input_ids[i, endofmedia_token_index] = bos_token_id
        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        image_start_index_list = [[x] for x in image_start_index_list]
        image_nums = [1] * len(input_ids)
        if "llama" in lang_encoder_name:
            attention_mask[input_ids == 0] = 0
        outputs = get_outputs(
            model=model,
            batch_images=batch_images,
            attention_mask=attention_mask,
            max_generation_length=10,
            min_generation_length=1,
            num_beams=5,
            length_penalty=0,
            input_ids=input_ids,
            image_start_index_list=image_start_index_list,
            image_nums=image_nums,
        )
        # postprocess begin
        new_predictions = [
            out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
        ]
        if vqa_dataset == "ok_vqa":
            new_predictions = postprocessor._lemmatize(new_predictions)
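        # Answer-snapping heuristic (enabled via --special): if one of the ground-truth
        # answers occurs in the generated text, replace the prediction with that answer;
        # "cant"/"can" patterns are mapped to the answers "no"/"yes" respectively.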
        if model.special:
            for i in range(len(new_predictions)):
                for answer, _ in Counter(batch[i]['answers']).most_common():
                    if answer in new_predictions[i]:
                        new_predictions[i] = answer
                        break
                    if "cant" in new_predictions[i] and "no" == answer:
                        new_predictions[i] = answer
                        break
                    if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
                        new_predictions[i] = answer
                        break

        # if rank == 0:
        #     tqdm.write(f"{image_nums} {image_start_index_list}")
        #     for i in range(1):
        #         tqdm.write(f"ID: {batch[i]['question_id']} | gt QA: {batch[i]['question']} {Counter(batch[i]['answers']).most_common()}")
        #         tqdm.write("prompt: " + tokenizer.decode(input_ids[i]))
        #         tqdm.write("model output: " + new_predictions[i])

        predictions.extend(
            [
                {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
                for p, sample in zip(new_predictions, batch)
            ]
        )
    with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
        f.write(json.dumps(predictions))
    print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")

    time.sleep(10)
    if world_size > 1:
        torch.distributed.barrier()
    if rank == 0:
        print(f"evaluate on rank {rank}. world size is {world_size}")
        predictions = []
        for rank_i in range(world_size):
            print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
            predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
            os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
        print("num:", len(predictions))
        # save the predictions to a temporary file
        random_uuid = str(uuid.uuid4())
        with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
            f.write(json.dumps(predictions, indent=4))

        if vqa_dataset == "gqa":
            acc = compute_gqa_accuracy(predictions)
        else:
            acc = compute_vqa_accuracy(
                f"{vqa_dataset}results_{random_uuid}.json",
                questions_json_path,
                annotations_json_path,
                vqa_dataset=vqa_dataset,
            )
        print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
        os.makedirs("eval_results", exist_ok=True)
        with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
            f.write(json.dumps(predictions, indent=2))

        # delete the temporary file
        os.remove(f"{vqa_dataset}results_{random_uuid}.json")
    else:
        time.sleep(5)
        acc = 0.0
    if world_size > 1:
        torch.distributed.barrier()
    return acc


def evaluate_refcoco(
    model,
    tokenizer,
    image_processor,
    batch_size,
    tsvfile,
    max_generation_length=20,
    num_beams=3,
    length_penalty=-2.0,
    device=-1,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
):
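    """Debugging/visualization helper for RefCOCO-style grounding.

    Note: the TSV-driven evaluation is currently short-circuited. The loop loads a
    hard-coded image ("example.png"), runs a single grounding prompt with a fixed
    pre-box, writes the predicted box and the pre-box to output2.jpg, and then exits.
    """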
    model.eval().cuda()
    loc_token_ids = []
    for i in range(1000):
        loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    total = 0
    correct = 0
    ious = []
    if "refcocog" in tsvfile:
        dataset_name = "refcocog"
    elif "refcocoplus" in tsvfile:
        dataset_name = "refcocoplus"
    else:
        dataset_name = "refcoco"
    with open(tsvfile, "r") as f:
        lines = f.readlines()
        pbar = tqdm(lines, disable=(rank != 0))
        for ii, line in enumerate(pbar):
            if ii % world_size != rank:
                continue
            total += 1
            line = line.rstrip()
            uniq_id, image_id, text, region_coord, image = line.split("\t")

            # image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
            # image2 = Image.open("yolo.png").convert("RGB")
            # image1 = image1.resize((224, 224))
            # image2 = image2.resize((224, 224))
            # images = [image1, image2]

            # gt_box = np.array(list(map(float, region_coord.split(","))))
            # width = image.width
            # height = image.height
            # gt_box /= np.array([width, height, width, height])
            # batch_images = [image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0) for image in images]
            # batch_images = torch.cat(batch_images, dim=0)
            # image = Image.open("yolo_test.png").convert("RGB")
            image = Image.open("example.png").convert("RGB")
            image = image.resize((224, 224))
            batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text.rstrip('.')}<|#visual#|>"]
            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#endofattr#|>man<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|> is sitting on<|#object#|><|#previsual#|>"]
            # prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|>man<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|> is sitting on<|#object#|><|#previsual#|>"]
            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]


            encodings = tokenizer(
                prompt,
                padding="longest",
                truncation=True,
                return_tensors="pt",
                max_length=2000,
            )
            input_ids = encodings["input_ids"]
            attention_mask = encodings["attention_mask"]
            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
            image_start_index_list = [image_start_index_list]
            image_nums = [1]
            vision_x = batch_images.cuda()
            lang_x = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            print(image_start_index_list, image_nums)

            model.debug_id = 0
            # outputs = get_outputs(
            #     model=model,
            #     batch_images=vision_x,
            #     attention_mask=attention_mask,
            #     max_generation_length=20,
            #     min_generation_length=8,
            #     num_beams=5,
            #     length_penalty=0,
            #     input_ids=lang_x,
            #     image_start_index_list=image_start_index_list,
            #     image_nums=image_nums,
            # )
            # print(tokenizer.decode(outputs[0]))
            # exit()

            prebox = [93, 20, 155, 172] # man
            # prebox = [32, 82, 89, 213] # dog
            # prebox = [34, 49, 166, 164] # bike
            with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
                outputs = model(
                    vision_x=vision_x,
                    lang_x=lang_x,
                    attention_mask=attention_mask,
                    labels=None,
                    image_nums=image_nums,
                    image_start_index_list=image_start_index_list,
                    added_bbox_list=[torch.tensor(prebox).cuda().unsqueeze(0) / 224],
                    add_box=True,
                    debug_mode=True,
                )
            
            boxes = outputs["boxes"]
            scores = outputs["scores"]
            box = boxes[scores.argmax()]
            open_cv_image = np.array(image)
            # Convert RGB to BGR 
            open_cv_image = open_cv_image[:, :, ::-1].copy() 
            # cv2 expects integer point tuples; draw the predicted box and the pre-box.
            x1, y1, x2, y2 = (int(v) for v in box)
            open_cv_image = cv2.rectangle(open_cv_image, (x1, y1), (x2, y2), (255, 0, 0), 2)
            open_cv_image = cv2.rectangle(open_cv_image, tuple(prebox[:2]), tuple(prebox[2:]), (0, 0, 255), 2)
            cv2.imwrite("output2.jpg", open_cv_image)
            print(box)
            print(prebox)
            exit()

            # force_words = ["man", "table"]
            # force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids


            # sequences, hidden_states_for_each_step = get_outputs(
            #     model=model,
            #     batch_images=vision_x,
            #     attention_mask=attention_mask,
            #     max_generation_length=20,
            #     min_generation_length=8,
            #     num_beams=5,
            #     length_penalty=0,
            #     input_ids=lang_x,
            #     image_start_index_list=image_start_index_list,
            #     image_nums=image_nums,
            #     force_words_ids=force_words_ids,
            # )
            # sequence = sequences[0]
            # print(tokenizer.decode(sequence))
            # for i, token in enumerate(sequence):
            #     if token == model.visual_token_id:
            #         print(tokenizer.decode(sequence[:i+1]))
            #         if hasattr(model, "debug_id"):
            #             model.debug_id += 1
            #         else:
            #             model.debug_id = 0
            #         this_lang_x = torch.hstack([lang_x[0], sequence[:i+1]]).unsqueeze(0)
            #         this_attention_mask = torch.ones_like(this_lang_x).cuda()
            #         with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
            #             _ = model(
            #                 vision_x=vision_x,
            #                 lang_x=this_lang_x,
            #                 attention_mask=this_attention_mask,
            #                 labels=None,
            #                 image_nums=image_nums,
            #                 image_start_index_list=image_start_index_list,
            #                 added_bbox_list=None,
            #             )
            # exit()

    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
        f.write(json.dumps([total, correct]))
    if world_size > 1:
        torch.distributed.barrier()
    if rank == 0:
        total = 0
        correct = 0
        print(f"evaluate on rank {rank}. world size is {world_size}")
        for rank_i in range(world_size):
            [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
            total += total_part
            correct += correct_part
        score = correct / total if total > 0 else 0.0
        print("score:", score)
        os.makedirs("eval_results", exist_ok=True)
        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
            pass
    else:
        score = 0.0
    if world_size > 1:
        torch.distributed.barrier()
    return score


if __name__ == "__main__":
    main()