|
from CircumSpect.vqa.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN |
|
from CircumSpect.vqa.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria |
|
from CircumSpect.vqa.conversation_obj import conv_templates_obj, SeparatorStyle_obj |
|
from CircumSpect.vqa.conversation_vqa import conv_templates, SeparatorStyle |
|
from transformers import AutoTokenizer, BitsAndBytesConfig |
|
from CircumSpect.vqa.utils import disable_torch_init |
|
from Perceptrix.streamer import TextStreamer |
|
from CircumSpect.vqa.model import * |
|
from utils import setup_device |
|
from io import BytesIO |
|
from PIL import Image |
|
import requests |
|
import torch |
|
import os |
|
|
|
# Compute device chosen by the project's setup_device() helper. The rest of
# this module branches on str(device) == "cpu" / "cuda" to pick dtypes and
# quantisation.
device = setup_device()
|
|
|
|
|
def load_image(image_file, timeout=30):
    """Load an image from a local path or an http(s) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL beginning with ``http://`` or
            ``https://``.
        timeout: Seconds to wait on the remote server before giving up
            (only used for URLs).

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        requests.HTTPError: If the URL responds with an error status.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    # Match on the full scheme so local paths such as "http_cache/x.jpg"
    # are not mistaken for URLs.
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly instead of handing an HTML error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # convert() returns a new, fully loaded image, so the source handle can be
    # closed right away.
    with Image.open(source) as image:
        return image.convert('RGB')
|
|
|
|
|
# Project helper from CircumSpect.vqa.utils. Presumably it skips random
# weight initialisation because weights are loaded from the checkpoint.
# TODO confirm.
disable_torch_init()

# Checkpoint location; override with the VLM_MODEL environment variable.
model_path = os.environ.get('VLM_MODEL', "models/CRYSTAL-vision")

# Not read anywhere in this file; kept in case other modules import it.
model_base = None

# NOTE(review): neither value is read by answer_question /
# find_object_description, which pass their own sampling settings to
# model.generate. Kept for importers.
temperature = 0.2
max_new_tokens = 512

model_name = get_model_name_from_path(model_path)

image_processor = None

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)

# 4-bit NF4 quantisation is only requested on CUDA.
# NOTE(review): this is an exact string match, so a device that stringifies
# as e.g. "cuda:0" would load unquantised. Confirm what setup_device()
# returns.
quantization_config = None
if str(device) == "cuda":
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

model = LlavaMPTForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.float32 if str(device) == "cpu" else torch.float16,
    quantization_config=quantization_config,
    offload_folder="offloads",
)

# Register the image-related special tokens the checkpoint's config asks for.
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
# Grow the embedding table to cover any tokens added above.
model.resize_token_embeddings(len(tokenizer))

vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower.to(device=device, dtype=torch.float32)
image_processor = vision_tower.image_processor

context_len = getattr(model.config, "max_sequence_length", 2048)

# Pick the prompt template family from the checkpoint name.
if 'llama-2' in model_name.lower():
    conv_mode = "llava_llama_2"
elif "v1" in model_name.lower():
    conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
    conv_mode = "mpt"
else:
    conv_mode = "llava_v0"

conv = conv_templates[conv_mode].copy()
if "mpt" in model_name.lower():
    roles = ('User', 'Assistant')
else:
    roles = conv.roles

# Shared streamer that echoes generated text while model.generate runs.
# NOTE(review): save_file is "vlm-reply.txt", but answer_question truncates
# "./database/vlm-reply.txt". Confirm TextStreamer resolves save_file under
# ./database/.
streamer = TextStreamer(tokenizer, skip_prompt=True,
                        skip_special_tokens=True, save_file="vlm-reply.txt")
|
|
|
|
|
def answer_question(question, image_file, temperature=0.2, max_new_tokens=1024):
    """Ask the vision-language model a free-form question about an image.

    Tokens are streamed through the module-level ``streamer`` while they are
    generated, and the full decoded reply is returned.

    Args:
        question: Prompt text for the model.
        image_file: Local path or http(s) URL of the image (see ``load_image``).
        temperature: Sampling temperature passed to ``model.generate``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The model's reply, with surrounding whitespace and a trailing
        conversation separator removed.
    """
    conv = conv_templates[conv_mode].copy()

    image = load_image(image_file)
    pixel_values = image_processor.preprocess(
        image, return_tensors='pt')['pixel_values']
    if str(device) != "cpu":
        # Off-CPU the model is loaded with torch_dtype=float16.
        pixel_values = pixel_values.half()
    image_tensor = pixel_values.to(device)

    print(f"{roles[1]}: ", end="")

    # load_image always returns an image, so every prompt carries the image
    # placeholder. Use the getattr-guarded flag rather than
    # model.config.mm_use_im_start_end, which raises if the config lacks it.
    if mm_use_im_start_end:
        image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_prefix = DEFAULT_IMAGE_TOKEN
    conv.append_message(conv.roles[0], image_prefix + '\n' + question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)

    # Templates with the TWO separator style end assistant turns with sep2.
    stop_str = conv.sep2 if conv.sep_style == SeparatorStyle.TWO else conv.sep
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    # Truncate the reply file before a new answer is generated (presumably the
    # streamer appends to it; see its save_file).
    with open("./database/vlm-reply.txt", 'w'):
        pass

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    # The keyword criterion presumably stops only after the separator has been
    # emitted, and decode() keeps special tokens, so strip it from the reply.
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)].rstrip()
    return outputs
|
|
|
|
|
# Object-description template instance. find_object_description builds its
# own copy per call, so this one only serves to derive `roles`.
# NOTE: this rebinds the module-level `roles` assigned earlier. Both query
# functions read whichever value is bound last, at call time.
conv_obj = conv_templates_obj[conv_mode].copy()
roles = ('User', 'Assistant') if "mpt" in model_name.lower() else conv_obj.roles
|
|
|
|
|
def find_object_description(question, image_file, temperature=0.2, max_new_tokens=1024):
    """Query the model about an image using the object-description templates.

    This mirrors ``answer_question`` but builds the prompt from
    ``conv_templates_obj`` / ``SeparatorStyle_obj``. It does not truncate the
    reply file first.

    Args:
        question: Prompt text for the model.
        image_file: Local path or http(s) URL of the image (see ``load_image``).
        temperature: Sampling temperature passed to ``model.generate``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The model's reply, with surrounding whitespace and a trailing
        conversation separator removed.
    """
    conv_obj = conv_templates_obj[conv_mode].copy()

    image = load_image(image_file)
    pixel_values = image_processor.preprocess(
        image, return_tensors='pt')['pixel_values']
    if str(device) != "cpu":
        # Off-CPU the model is loaded with torch_dtype=float16.
        pixel_values = pixel_values.half()
    image_tensor = pixel_values.to(device)

    print(f"{roles[1]}: ", end="")

    # load_image always returns an image, so every prompt carries the image
    # placeholder. Use the getattr-guarded flag rather than
    # model.config.mm_use_im_start_end, which raises if the config lacks it.
    if mm_use_im_start_end:
        image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_prefix = DEFAULT_IMAGE_TOKEN
    conv_obj.append_message(conv_obj.roles[0], image_prefix + '\n' + question)
    conv_obj.append_message(conv_obj.roles[1], None)
    prompt = conv_obj.get_prompt()

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)

    # Templates with the TWO separator style end assistant turns with sep2.
    stop_str = conv_obj.sep2 if conv_obj.sep_style == SeparatorStyle_obj.TWO else conv_obj.sep
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    # NOTE(review): unlike answer_question, the reply file is not truncated
    # here, so streamed text may accumulate. Confirm this is intended.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    # The keyword criterion presumably stops only after the separator has been
    # emitted, and decode() keeps special tokens, so strip it from the reply.
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)].rstrip()
    return outputs
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: ask a generic question about a public sample image.
    test_image = "https://llava-vl.github.io/static/images/view.jpg"
    test_prompt = "What is this image about?"
    print(f"RUNNING TEST\n\tTest Image: {test_image}\n\tPrompt: {test_prompt}")
    answer_question(test_prompt, test_image)
|
|