# llava/eval/masp_eval/eval_case.py
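"""Qualitative evaluation script: load a pretrained LLaVA-based video captioning
model and print a detailed description for each folder of pre-extracted frames."""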
import argparse
import logging
import copy
import codecs
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import decord
import os
import json
import random
import requests
from tqdm import tqdm
import numpy as np
from llava.constants import MM_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images_v2
from llava.model import *
from llava.model.builder import load_pretrained_model
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from transformers import CLIPImageProcessor
from PIL import Image
from decord import VideoReader, cpu
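# Have decord hand back torch tensors directly.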
decord.bridge.set_bridge("torch")
def get_image(image_path):
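    """Load an image from disk as an RGB PIL image."""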
image = Image.open(image_path).convert('RGB')
return image
def load_frames(frames_dir):
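    """Load all frames in frames_dir sorted by their integer filename,
    skipping files whose names start with 'cuttime'."""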
results = []
image_files = [(int(os.path.splitext(img)[0]), img) for img in os.listdir(frames_dir) if not img.startswith('cuttime')]
image_files = sorted(image_files, key=lambda img: img[0])
for frame_name in image_files:
image_path = f"{frames_dir}/{frame_name[1]}"
image = get_image(image_path)
results.append(image)
return results
def uniform_sample(frames, num_segments):
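    """Select num_segments frames at uniformly spaced indices; indices repeat
    when num_segments exceeds the number of available frames."""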
indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
frames = [frames[ind] for ind in indices]
return frames
def run_inference(args, frame_folders):
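    """Load the pretrained model once, then generate and print a detailed
    description for every folder of frames in frame_folders."""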
disable_torch_init()
model_path = os.path.expanduser(args.model_path)
model_name = get_model_name_from_path(model_path)
tokenizer, model, _, context_len = load_pretrained_model(model_path, args.model_base, model_name, device_map={"":0})
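    # BLIP-2 style image preprocessing at the model's configured input size, in eval mode.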
image_processor = Blip2ImageTrainProcessor(
image_size=model.config.img_size,
is_training=False)
model_cfgs = model.config
for frame_folder in frame_folders:
question = "Describe the video in detail."
# Question input here
qs = question
# qs = DEFAULT_VIDEO_TOKEN + '\n' + qs
if model.config.mm_use_start_end:
qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_VIDEO_TOKEN + '\n' + qs
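        # Wrap the question in the chosen conversation template and tokenize it,
        # mapping the video placeholder to MM_TOKEN_INDEX.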
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        # Load the pre-extracted frames for this video.
        images = load_frames(frame_folder)
        # Resample to exactly num_segments frames (subsampling or repeating as needed).
        if len(images) != args.num_segments:
            images = uniform_sample(images, args.num_segments)
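        # Force the no-padding preprocessing path even if the config requests padding.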
if model_cfgs.image_aspect_ratio == 'pad':
model_cfgs.image_aspect_ratio = 'no_padding'
images_tensor = process_images_v2(images, image_processor, model_cfgs).half().cuda()
# print(images_tensor.shape)
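        # Stop generation when the conversation separator string appears.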
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
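        # Duplicate the frame tensor once per beam; the model receives a list of image tensors.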
images_tensors = [images_tensor.clone() for _ in range(args.num_beams)]
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images= images_tensors,
do_sample=True,
temperature=args.temperature,
top_p=args.top_p,
num_beams=args.num_beams,
no_repeat_ngram_size=args.no_repeat_ngram_size,
pad_token_id=tokenizer.eos_token_id,
max_new_tokens=1024,
use_cache=True,
stopping_criteria=[stopping_criteria])
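        # Decode the generation, strip the trailing stop string, and print the caption.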
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
print(outputs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--video_dir', help='Directory containing video files.', type=str, default="")
parser.add_argument('--validation_data', type=str,
default="/mnt/bn/yukunfeng-nasdrive/xiangchen/repo/benchmark_data/refine_chair_eval_gt_neg_1k.json")
parser.add_argument('--num_samples', help='Number of samples to predict', type=int, default=-1)
parser.add_argument("--model_path", type=str,
default="/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso185k_unfreeze_qformer_data_sampler/")
parser.add_argument("--model_base", type=str, default=None)
parser.add_argument("--conv_mode", type=str, default="v1")
parser.add_argument("--output_file", type=str, default="vid_top1k_res.json")
parser.add_argument("--num_segments", type=int, default=10)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--num_beams", type=int, default=1)
parser.add_argument("--no_repeat_ngram_size", type=int, default=3)
args = parser.parse_args()
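    # Hard-coded example: a single folder of pre-extracted frames to describe.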
frame_folders = ['/mnt/bn/algo-masp-nas-2/xiangchen/repo/LLaVA/tmp/cases/yj']
run_inference(args, frame_folders)