import argparse
import hashlib
import json
import os
import time
from threading import Thread
import logging
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from pathlib import Path
from tinyllava.model.builder import load_pretrained_model
from tinyllava.mm_utils import (
KeywordsStoppingCriteria,
load_image_from_base64,
process_images,
tokenizer_image_token,
get_model_name_from_path,
)
from PIL import Image
from io import BytesIO
import base64
from transformers import StoppingCriteria
from tinyllava.constants import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_TOKEN,
IMAGE_TOKEN_INDEX,
)
from tinyllava.conversation import SeparatorStyle, conv_templates, default_conversation
from transformers import TextIteratorStreamer
DEFAULT_MODEL_PATH = "bczhou/TinyLLaVA-3.1B"
DEFAULT_MODEL_NAME = "TinyLLaVA-3.1B"
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
title_markdown = """
# Privacy Aware Visual Language Models
[[Code](https://github.com/laurenssam/Privacy-Aware-Visual-Language-Models)] | 📚 [[Paper](https://arxiv.org/abs/2405.17423)]
"""
def regenerate(state, image_process_mode):
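    """Drop the last assistant reply and re-enable generation so the previous user turn is answered again."""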
state.messages[-1][-1] = None
prev_human_msg = state.messages[-2]
if type(prev_human_msg[1]) in (tuple, list):
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "", None)
def clear_history():
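    """Reset the conversation state and clear the chatbot, textbox, and image inputs."""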
state = default_conversation.copy()
return (state, state.to_gradio_chatbot(), "", None)
def add_text(state, text, image, image_process_mode):
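    """Truncate the user's message, attach the image with an `<image>` placeholder if needed,
    and append a new user/assistant turn to the conversation state."""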
if len(text) <= 0 and image is None:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "", None)
text = text[:1536] # Hard cut-off
if image is not None:
text = text[:1200] # Hard cut-off for images
if "<image>" not in text:
# text = '<Image><image></Image>' + text
text = text + "\n<image>"
text = (text, image, image_process_mode)
if len(state.get_images(return_pil=True)) > 0:
state = default_conversation.copy()
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "", None)
def load_demo():
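    """Create a fresh conversation state when the demo page loads."""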
state = default_conversation.copy()
return state
@torch.inference_mode()
def get_response(params):
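    """Run local streaming inference for one request and yield cumulative, JSON-encoded text chunks."""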
prompt = params["prompt"]
ori_prompt = prompt
images = params.get("images", None)
num_image_tokens = 0
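    # Decode any base64-encoded images, preprocess them, and expand the <image> placeholders in the prompt.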
if images is not None and len(images) > 0:
if len(images) > 0:
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
raise ValueError(
"Number of images does not match number of <image> tokens in prompt"
)
images = [load_image_from_base64(image) for image in images]
# images = process_images(images, image_processor, model.config)
images = process_images(images, image_processor, model.config).to('cpu', dtype=torch.float)
if type(images) is list:
images = [
image.to(model.device) for image in images
]
else:
images = images.to(model.device)
replace_token = DEFAULT_IMAGE_TOKEN
if getattr(model.config, "mm_use_im_start_end", False):
replace_token = (
DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
)
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
num_image_tokens = (
prompt.count(replace_token) * model.get_vision_tower().num_patches
)
else:
images = None
image_args = {"images": images}
else:
images = None
image_args = {}
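    # Greedy decoding: temperature 0, no nucleus sampling; the context budget comes from the model config.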
temperature = 0.0
top_p = 1.0
max_context_length = getattr(model.config, "max_position_embeddings", 2048)
max_new_tokens = 512
stop_str = params.get("stop", None)
do_sample = False
logger.info(prompt)
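    # Tokenize the prompt, mapping <image> placeholders to IMAGE_TOKEN_INDEX.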
input_ids = (
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
.unsqueeze(0)
.to(model.device).long()
)
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=5000
)
max_new_tokens = min(
max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens
)
    # `images` is None for text-only requests and a list when per-image tensors are returned;
    # only cast when we actually have a single batched tensor.
    if torch.is_tensor(images):
        images = images.to(dtype=torch.float)
if max_new_tokens < 1:
        yield json.dumps(
            {
                "text": ori_prompt
                + " Exceeds max token length. Please start a new conversation, thanks.",
                "error_code": 0,
            }
        ).encode()
return
# local inference
# BUG: If stopping_criteria is set, an error occur:
# RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
generate_kwargs = dict(
inputs=input_ids,
do_sample=False,
top_p=1.0,
max_new_tokens=512,
pad_token_id=tokenizer.pad_token_id,
streamer=streamer,
# stopping_criteria=[stopping_criteria],
use_cache=True,
**image_args,
)
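    # Run generation in a background thread; the TextIteratorStreamer yields decoded tokens as they are produced.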
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
logger.debug(ori_prompt)
logger.debug(generate_kwargs)
generated_text = ori_prompt
for new_text in streamer:
generated_text += new_text
if generated_text.endswith(stop_str):
generated_text = generated_text[: -len(stop_str)]
yield json.dumps({"text": generated_text, "error_code": 0}).encode()
def http_bot(state):
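    """Gradio callback: build the prompt from the conversation state, stream the model's reply,
    and progressively update the chatbot display."""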
if state.skip_next:
# This generate call is skipped due to invalid inputs
yield (state, state.to_gradio_chatbot())
return
if len(state.messages) == state.offset + 2:
# First round of conversation
template_name = 'phi'
new_state = conv_templates[template_name].copy()
new_state.append_message(new_state.roles[0], state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
state = new_state
# if "tinyllava" in model_name.lower():
# if "3.1b" in model_name.lower() or "phi" in model_name.lower():
# template_name = "phi"
# elif "2.0b" in model_name.lower() or "stablelm" in model_name.lower():
# template_name = "phi"
# elif "qwen" in model_name.lower():
# template_name = "qwen"
# else:
# template_name = "v1"
# elif "llava" in model_name.lower():
# if "llama-2" in model_name.lower():
# template_name = "llava_llama_2"
# elif "v1" in model_name.lower():
# if "mmtag" in model_name.lower():
# template_name = "v1_mmtag"
# elif (
# "plain" in model_name.lower()
# and "finetune" not in model_name.lower()
# ):
# template_name = "v1_mmtag"
# else:
# template_name = "llava_v1"
# elif "mpt" in model_name.lower():
# template_name = "mpt"
# else:
# if "mmtag" in model_name.lower():
# template_name = "v0_mmtag"
# elif (
# "plain" in model_name.lower()
# and "finetune" not in model_name.lower()
# ):
# template_name = "v0_mmtag"
# else:
# template_name = "llava_v0"
# elif "mpt" in model_name:
# template_name = "mpt_text"
# elif "llama-2" in model_name:
# template_name = "llama_2"
# else:
# template_name = "vicuna_v1"
# Construct prompt
prompt = state.get_prompt()
all_images = state.get_images(return_pil=True)
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
# Make requests
# pload = {"model": model_name, "prompt": prompt, "temperature": float(temperature), "top_p": float(top_p),
# "max_new_tokens": min(int(max_new_tokens), 1536), "stop": (
# state.sep
# if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
# else state.sep2
# ), "images": state.get_images()}
pload = {
"model": model_name,
"prompt": prompt,
"temperature": 0,
"top_p": 1.0,
"max_new_tokens": 512,
"stop": (
state.sep
if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
else state.sep2
), "images": state.get_images()}
    state.messages[-1][-1] = "▌"
yield (state, state.to_gradio_chatbot())
# for stream
output = get_response(pload)
for chunk in output:
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][len(prompt) :].strip()
                state.messages[-1][-1] = output + "▌"
yield (state, state.to_gradio_chatbot())
else:
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot())
return
time.sleep(0.03)
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot())
def build_demo():
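    """Assemble the Gradio Blocks UI and wire the send, regenerate, and clear callbacks."""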
textbox = gr.Textbox(
show_label=False, placeholder="Enter text and press ENTER", container=False
)
with gr.Blocks(title="TinyLLaVA", theme=gr.themes.Default(), css=block_css) as demo:
state = gr.State()
gr.Markdown(title_markdown)
with gr.Row():
with gr.Column(scale=5):
with gr.Row(elem_id="Model ID"):
gr.Dropdown(
choices=[DEFAULT_MODEL_NAME],
value=DEFAULT_MODEL_NAME,
interactive=True,
label="Model ID",
container=False,
)
imagebox = gr.Image(type="pil")
image_process_mode = gr.Radio(
["Crop", "Resize", "Pad", "Default"],
value="Default",
label="Preprocess for non-square image",
visible=False,
)
# # cur_dir = os.path.dirname(os.path.abspath(__file__))
# cur_dir = Path(__file__).parent
# gr.Examples(
# examples=[
# [
# f"{cur_dir}/examples/extreme_ironing.jpg",
# "What is unusual about this image?",
# ],
# [
# f"{cur_dir}/examples/waterview.jpg",
# "What are the things I should be cautious about when I visit here?",
# ],
# ],
# inputs=[imagebox, textbox],
# )
# with gr.Accordion("Parameters", open=False) as _:
# temperature = gr.Slider(
# minimum=0.0,
# maximum=1.0,
# value=0.2,
# step=0.1,
# interactive=True,
# label="Temperature",
# )
# top_p = gr.Slider(
# minimum=0.0,
# maximum=1.0,
# value=0.7,
# step=0.1,
# interactive=True,
# label="Top P",
# )
# max_output_tokens = gr.Slider(
# minimum=0,
# maximum=1024,
# value=512,
# step=64,
# interactive=True,
# label="Max output tokens",
# )
with gr.Column(scale=8):
chatbot = gr.Chatbot(elem_id="chatbot", label="Chatbot", height=550)
with gr.Row():
with gr.Column(scale=8):
textbox.render()
with gr.Column(scale=1, min_width=50):
submit_btn = gr.Button(value="Send", variant="primary")
with gr.Row(elem_id="buttons") as _:
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=True)
# gr.Markdown(tos_markdown)
# gr.Markdown(learn_more_markdown)
# gr.Markdown(ack_markdown)
regenerate_btn.click(
regenerate,
[state, image_process_mode],
[state, chatbot, textbox, imagebox],
queue=False,
).then(
http_bot, [state], [state, chatbot]
)
clear_btn.click(
clear_history, None, [state, chatbot, textbox, imagebox], queue=False
)
textbox.submit(
add_text,
[state, textbox, imagebox, image_process_mode],
[state, chatbot, textbox, imagebox],
queue=False,
).then(
http_bot, [state], [state, chatbot]
)
submit_btn.click(
add_text,
[state, textbox, imagebox, image_process_mode],
[state, chatbot, textbox, imagebox],
queue=False,
).then(
http_bot, [state], [state, chatbot]
)
demo.load(load_demo, None, [state], queue=False)
return demo
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
logger.info(gr.__version__)
from huggingface_hub import snapshot_download
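# Download the privacy-aware fine-tuned checkpoint from the Hugging Face Hub and load it as a
# LoRA adapter on top of the TinyLLaVA-3.1B base model, running on CPU in full precision.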
huggingface_path = Path(snapshot_download(repo_id="laurenssam/privacy-aware-visual-language-models"))
model_name = str(huggingface_path / "tinyllava_demo")
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path=str(model_name),
model_base="bczhou/TinyLLaVA-3.1B",
model_name=get_model_name_from_path(model_name + "_lora_TinyLLaVA-3.1B"),
device="cpu",
load_4bit=False,
load_8bit=False
)
model = model.cpu().float()
demo = build_demo()
demo.queue()
if __name__ == "__main__":
demo.launch()