# NOTE(review): the three lines below were upload-page metadata accidentally
# pasted into the source ("crystal-technologies's picture", "Upload 1653 files",
# commit "714d948") — not valid Python; preserved here as a comment.
from CircumSpect.vqa.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from CircumSpect.vqa.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from CircumSpect.vqa.conversation_obj import conv_templates_obj, SeparatorStyle_obj
from CircumSpect.vqa.conversation_vqa import conv_templates, SeparatorStyle
from transformers import AutoTokenizer, BitsAndBytesConfig
from CircumSpect.vqa.utils import disable_torch_init
from Perceptrix.streamer import TextStreamer
from CircumSpect.vqa.model import *
from utils import setup_device
from io import BytesIO
from PIL import Image
import requests
import torch
import os
# Pick the torch device (e.g. "cuda" or "cpu") used for all model and tensor
# placement decisions below.
device = setup_device()
def load_image(image_file):
    """Load an image from a local path or an http(s) URL as an RGB PIL Image.

    Args:
        image_file: Filesystem path or URL string pointing at the image.

    Returns:
        PIL.Image.Image in RGB mode.

    Raises:
        requests.HTTPError: if a URL fetch returns an error status.
        OSError: if the file/bytes cannot be decoded as an image.
    """
    # Match full URL scheme prefixes. The original `startswith('http') or
    # startswith('https')` had a redundant second test (any 'https...' string
    # already starts with 'http') and misclassified local paths that merely
    # begin with "http" (e.g. "http_images/a.jpg") as URLs.
    if image_file.startswith(('http://', 'https://')):
        # Timeout avoids hanging forever on a dead host; raise_for_status
        # surfaces HTTP errors instead of letting PIL choke on an error page.
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
# --- One-time model initialisation (executes at import time) -----------------
disable_torch_init()

# Allow overriding the vision-language checkpoint via the VLM_MODEL env var;
# fall back to the bundled CRYSTAL-vision checkpoint.
model_name = os.environ.get('VLM_MODEL')
model_path = "models/CRYSTAL-vision" if model_name is None else model_name
model_base = None
conv_mode = None
temperature = 0.2
max_new_tokens = 512
model_name = get_model_name_from_path(model_path)
image_processor = None

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = LlavaMPTForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    device_map="auto",
    # fp32 on CPU (half-precision CPU kernels are poorly supported),
    # fp16 otherwise.
    torch_dtype=torch.float32 if str(device) == "cpu" else torch.float16,
    # 4-bit NF4 quantisation is only applicable on CUDA devices.
    # NOTE: the original also passed low_cpu_mem_usage=True here, which is not
    # a BitsAndBytesConfig field (it belongs to from_pretrained, where it is
    # already supplied above) — removed.
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ) if str(device) == "cuda" else None,
    offload_folder="offloads",
)

# Register the special multimodal tokens the model expects, then resize the
# embedding matrix to cover any newly added vocabulary entries.
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
# Lazily load the vision tower when the checkpoint did not materialise it,
# then move it onto the target device.
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
# NOTE(review): the tower is cast to float32 even on CUDA while the language
# model above runs in float16 — presumably deliberate (stability); confirm.
vision_tower.to(device=device, dtype=torch.float32)
image_processor = vision_tower.image_processor
# Context window for generation; fall back to 2048 when the model config
# does not declare max_sequence_length.
if hasattr(model.config, "max_sequence_length"):
    context_len = model.config.max_sequence_length
else:
    context_len = 2048
# Infer the conversation template from the checkpoint name.
# The original ended with `if conv_mode is not None and conv_mode != conv_mode:`
# — a vestige of LLaVA's CLI check against a user-supplied --conv-mode. Since
# `conv_mode != conv_mode` is always False, the warning branch was unreachable
# and the else-branch was a no-op self-assignment; both removed.
if 'llama-2' in model_name.lower():
    conv_mode = "llava_llama_2"
elif "v1" in model_name.lower():
    conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
    conv_mode = "mpt"
else:
    conv_mode = "llava_v0"
# Default VQA conversation object and the display role labels used when
# printing replies.
conv = conv_templates[conv_mode].copy()
if "mpt" in model_name.lower():
    roles = ('User', 'Assistant')
else:
    roles = conv.roles
# Streams decoded tokens to stdout as they are produced and also writes the
# reply to "vlm-reply.txt" (save_file argument).
streamer = TextStreamer(tokenizer, skip_prompt=True,
                        skip_special_tokens=True, save_file="vlm-reply.txt")
def answer_question(question, image_file, temperature=0.2, max_new_tokens=1024):
    """Answer a free-form question about an image using the VQA template.

    Args:
        question: Natural-language question about the image.
        image_file: Local path or http(s) URL of the image.
        temperature: Sampling temperature (default 0.2, as before).
        max_new_tokens: Generation budget in tokens (default 1024, as before).

    Returns:
        The decoded model reply, stripped of surrounding whitespace.
    """
    conv = conv_templates[conv_mode].copy()
    image = load_image(image_file)
    # Cast pixel values to half precision on GPU to match the fp16 language
    # model; keep fp32 on CPU.
    pixel_values = image_processor.preprocess(
        image, return_tensors='pt')['pixel_values']
    if str(device) == "cpu":
        image_tensor = pixel_values.to(device)
    else:
        image_tensor = pixel_values.half().to(device)
    print(f"{roles[1]}: ", end="")
    # Prepend the image placeholder token(s); tokenizer_image_token later
    # splices IMAGE_TOKEN_INDEX in at this position. The original's
    # `if image is not None` / else branch was dead code: load_image always
    # returns an image or raises, so only the first branch could run.
    if model.config.mm_use_im_start_end:
        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + \
            DEFAULT_IM_END_TOKEN + '\n' + question
    else:
        inp = DEFAULT_IMAGE_TOKEN + '\n' + question
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
    # Stop generation when the template's separator string appears.
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria(
        [stop_str], tokenizer, input_ids)
    # Truncate the streamed-reply file before generation begins.
    with open("./database/vlm-reply.txt", 'w') as clear_file:
        clear_file.write("")
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria])
    # Decode only the newly generated tokens (everything after the prompt).
    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    conv.messages[-1][-1] = outputs
    return outputs
# Conversation template for the object-description flow.
# NOTE(review): this rebinds the module-level `roles` that was set above for
# answer_question — verify the two flows are meant to share one label pair.
conv_obj = conv_templates_obj[conv_mode].copy()
if "mpt" in model_name.lower():
    roles = ('User', 'Assistant')
else:
    roles = conv_obj.roles
def find_object_description(question, image_file, temperature=0.2, max_new_tokens=1024):
    """Describe/locate objects in an image using the object conversation template.

    Mirrors answer_question but uses conv_templates_obj / SeparatorStyle_obj
    and does not truncate the streamed-reply file first.

    Args:
        question: Natural-language prompt about objects in the image.
        image_file: Local path or http(s) URL of the image.
        temperature: Sampling temperature (default 0.2, as before).
        max_new_tokens: Generation budget in tokens (default 1024, as before).

    Returns:
        The decoded model reply, stripped of surrounding whitespace.
    """
    conv_obj = conv_templates_obj[conv_mode].copy()
    image = load_image(image_file)
    # Cast pixel values to half precision on GPU to match the fp16 language
    # model; keep fp32 on CPU.
    pixel_values = image_processor.preprocess(
        image, return_tensors='pt')['pixel_values']
    if str(device) == "cpu":
        image_tensor = pixel_values.to(device)
    else:
        image_tensor = pixel_values.half().to(device)
    print(f"{roles[1]}: ", end="")
    # Prepend the image placeholder token(s). As in answer_question, the
    # original's `if image is not None` / else branch was dead code — only
    # the image branch could ever run.
    if model.config.mm_use_im_start_end:
        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + \
            DEFAULT_IM_END_TOKEN + '\n' + question
    else:
        inp = DEFAULT_IMAGE_TOKEN + '\n' + question
    conv_obj.append_message(conv_obj.roles[0], inp)
    conv_obj.append_message(conv_obj.roles[1], None)
    prompt = conv_obj.get_prompt()
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
    # Stop generation when the template's separator string appears.
    stop_str = conv_obj.sep if conv_obj.sep_style != SeparatorStyle_obj.TWO else conv_obj.sep2
    stopping_criteria = KeywordsStoppingCriteria(
        [stop_str], tokenizer, input_ids)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria])
    # Decode only the newly generated tokens (everything after the prompt).
    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    conv_obj.messages[-1][-1] = outputs
    return outputs
if __name__ == "__main__":
    # Smoke test: ask a question about a known public image.
    test_image = "https://llava-vl.github.io/static/images/view.jpg"
    test_prompt = "What is this image about?"
    print("RUNNING TEST\n\tTest Image: https://llava-vl.github.io/static/images/view.jpg\n\tPrompt: What is this image about?")
    answer_question(test_prompt, test_image)