import argparse
import os

import numpy as np
import torch
import decord
from decord import VideoReader, cpu
from PIL import Image

from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images_v2
from llava.model import *
from llava.model.builder import load_pretrained_model
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor

# Make decord return frames as torch tensors.
decord.bridge.set_bridge("torch")



def get_image(image_path):
    image = Image.open(image_path).convert('RGB')
    return image
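

# Hypothetical helper (a sketch, not used in the main flow below): sample frames directly
# from a video file with decord's VideoReader instead of a pre-extracted frame folder.
def load_video_frames(video_path, num_segments):
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    indices = np.linspace(start=0, stop=len(vr) - 1, num=num_segments).astype(int)
    # With the torch bridge set above, get_batch returns a uint8 tensor of shape (T, H, W, C).
    frames = vr.get_batch(indices)
    return [Image.fromarray(frame.numpy()) for frame in frames]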


def load_frames(frames_dir):
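    """Load every frame image in frames_dir in frame-number order, skipping 'cuttime*' files."""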
    results = []
    image_files = [(int(os.path.splitext(img)[0]), img) for img in os.listdir(frames_dir) if not img.startswith('cuttime')]
    image_files = sorted(image_files, key=lambda img: img[0])
    for frame_name in image_files:
        image_path = f"{frames_dir}/{frame_name[1]}"
        image = get_image(image_path)
        results.append(image)
    return results




def uniform_sample(frames, num_segments):
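    """Uniformly sample num_segments frames; indices repeat when there are fewer frames than segments."""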
    indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    frames = [frames[ind] for ind in indices]
    return frames

    


def run_inference(args, frame_folders):
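    """Generate a detailed caption for each folder of pre-extracted video frames."""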
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, _, context_len = load_pretrained_model(model_path, args.model_base, model_name, device_map={"":0})
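    # BLIP-2 style frame preprocessor at the model's configured input size, using eval-time transforms.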
    image_processor = Blip2ImageTrainProcessor(
        image_size=model.config.img_size,
        is_training=False)
    model_cfgs = model.config

    
    for frame_folder in frame_folders:
        question = "Describe the video in detail."

        # Prepend the video token (optionally wrapped in start/end tokens) to the question.
        if model.config.mm_use_start_end:
            qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + question
        else:
            qs = DEFAULT_VIDEO_TOKEN + '\n' + question

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Tokenize the prompt; the video token is mapped to the MM_TOKEN_INDEX placeholder id.
        input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        images = load_frames(frame_folder)
        # Re-sample uniformly to exactly num_segments frames (frames repeat if too few are available).
        if len(images) != args.num_segments:
            images = uniform_sample(images, args.num_segments)

        # Process frames without padding, overriding a 'pad' aspect-ratio setting from the config.
        if model_cfgs.image_aspect_ratio == 'pad':
            model_cfgs.image_aspect_ratio = 'no_padding'
        images_tensor = process_images_v2(images, image_processor, model_cfgs).half().cuda()

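        # Stop generation once the conversation separator for this template is emitted.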
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
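        # One copy of the frame tensor per beam; generate is passed a list of image tensors.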
        images_tensors = [images_tensor.clone() for _ in range(args.num_beams)]
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=images_tensors,
                do_sample=True,
                temperature=args.temperature,
                top_p=args.top_p,
                num_beams=args.num_beams,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
                pad_token_id=tokenizer.eos_token_id,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria])


        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        # Drop the trailing stop string if it survived decoding.
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)].strip()
        print(outputs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--video_dir', help='Directory containing video files.', type=str, default="")
    parser.add_argument('--validation_data', type=str,
                        default="/mnt/bn/yukunfeng-nasdrive/xiangchen/repo/benchmark_data/refine_chair_eval_gt_neg_1k.json")
    parser.add_argument('--num_samples', help='Number of samples to predict', type=int, default=-1)
    parser.add_argument("--model_path", type=str,
                        default="/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso185k_unfreeze_qformer_data_sampler/")
    parser.add_argument("--model_base", type=str, default=None)
    parser.add_argument("--conv_mode", type=str, default="v1")
    parser.add_argument("--output_file", type=str, default="vid_top1k_res.json")
    parser.add_argument("--num_segments", type=int, default=10)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--no_repeat_ngram_size", type=int, default=3)

    args = parser.parse_args()
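    # Run on a hard-coded example frame folder; the --video_dir argument is not used here.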
    frame_folders = ['/mnt/bn/algo-masp-nas-2/xiangchen/repo/LLaVA/tmp/cases/yj']
    run_inference(args, frame_folders)