import argparse import hashlib import json import os import time from threading import Thread import logging import gradio as gr import torch from huggingface_hub import hf_hub_download from pathlib import Path from tinyllava.model.builder import load_pretrained_model from tinyllava.mm_utils import ( KeywordsStoppingCriteria, load_image_from_base64, process_images, tokenizer_image_token, get_model_name_from_path, ) from PIL import Image from io import BytesIO import base64 import torch from transformers import StoppingCriteria from tinyllava.constants import ( DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, ) from tinyllava.conversation import SeparatorStyle, conv_templates, default_conversation from transformers import TextIteratorStreamer from pathlib import Path DEFAULT_MODEL_PATH = "bczhou/TinyLLaVA-3.1B" DEFAULT_MODEL_NAME = "TinyLLaVA-3.1B" block_css = """ #buttons button { min-width: min(120px,100%); } """ title_markdown = """ # Privacy Aware Visual Language Models [[Code](https://github.com/laurenssam/Privacy-Aware-Visual-Language-Models)] | 📚 [[Paper](https://arxiv.org/abs/2405.17423)] """ def regenerate(state, image_process_mode): state.messages[-1][-1] = None prev_human_msg = state.messages[-2] if type(prev_human_msg[1]) in (tuple, list): prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) state.skip_next = False return (state, state.to_gradio_chatbot(), "", None) def clear_history(): state = default_conversation.copy() return (state, state.to_gradio_chatbot(), "", None) def add_text(state, text, image, image_process_mode): if len(text) <= 0 and image is None: state.skip_next = True return (state, state.to_gradio_chatbot(), "", None) text = text[:1536] # Hard cut-off if image is not None: text = text[:1200] # Hard cut-off for images if "" not in text: # text = '' + text text = text + "\n" text = (text, image, image_process_mode) if len(state.get_images(return_pil=True)) > 0: state = default_conversation.copy() state.append_message(state.roles[0], text) state.append_message(state.roles[1], None) state.skip_next = False return (state, state.to_gradio_chatbot(), "", None) def load_demo(): state = default_conversation.copy() return state @torch.inference_mode() def get_response(params): prompt = params["prompt"] ori_prompt = prompt images = params.get("images", None) num_image_tokens = 0 if images is not None and len(images) > 0: if len(images) > 0: if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): raise ValueError( "Number of images does not match number of tokens in prompt" ) images = [load_image_from_base64(image) for image in images] # images = process_images(images, image_processor, model.config) images = process_images(images, image_processor, model.config).to('cpu', dtype=torch.float) if type(images) is list: images = [ image.to(model.device) for image in images ] else: images = images.to(model.device) replace_token = DEFAULT_IMAGE_TOKEN if getattr(model.config, "mm_use_im_start_end", False): replace_token = ( DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN ) prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) num_image_tokens = ( prompt.count(replace_token) * model.get_vision_tower().num_patches ) else: images = None image_args = {"images": images} else: images = None image_args = {} temperature = 0.0 top_p = 1.0 max_context_length = getattr(model.config, "max_position_embeddings", 2048) max_new_tokens = 512 stop_str = params.get("stop", None) do_sample = False logger.info(prompt) input_ids = ( tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") .unsqueeze(0) .to(model.device).long() ) keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=5000 ) max_new_tokens = min( max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens ) images = images.to(dtype=torch.float) if max_new_tokens < 1: yield json.dumps( { "text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0, } ).encode() + b"\0" return # local inference # BUG: If stopping_criteria is set, an error occur: # RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0 generate_kwargs = dict( inputs=input_ids, do_sample=False, top_p=1.0, max_new_tokens=512, pad_token_id=tokenizer.pad_token_id, streamer=streamer, # stopping_criteria=[stopping_criteria], use_cache=True, **image_args, ) thread = Thread(target=model.generate, kwargs=generate_kwargs) thread.start() logger.debug(ori_prompt) logger.debug(generate_kwargs) generated_text = ori_prompt for new_text in streamer: generated_text += new_text if generated_text.endswith(stop_str): generated_text = generated_text[: -len(stop_str)] yield json.dumps({"text": generated_text, "error_code": 0}).encode() def http_bot(state): if state.skip_next: # This generate call is skipped due to invalid inputs yield (state, state.to_gradio_chatbot()) return if len(state.messages) == state.offset + 2: # First round of conversation template_name = 'phi' new_state = conv_templates[template_name].copy() new_state.append_message(new_state.roles[0], state.messages[-2][1]) new_state.append_message(new_state.roles[1], None) state = new_state # if "tinyllava" in model_name.lower(): # if "3.1b" in model_name.lower() or "phi" in model_name.lower(): # template_name = "phi" # elif "2.0b" in model_name.lower() or "stablelm" in model_name.lower(): # template_name = "phi" # elif "qwen" in model_name.lower(): # template_name = "qwen" # else: # template_name = "v1" # elif "llava" in model_name.lower(): # if "llama-2" in model_name.lower(): # template_name = "llava_llama_2" # elif "v1" in model_name.lower(): # if "mmtag" in model_name.lower(): # template_name = "v1_mmtag" # elif ( # "plain" in model_name.lower() # and "finetune" not in model_name.lower() # ): # template_name = "v1_mmtag" # else: # template_name = "llava_v1" # elif "mpt" in model_name.lower(): # template_name = "mpt" # else: # if "mmtag" in model_name.lower(): # template_name = "v0_mmtag" # elif ( # "plain" in model_name.lower() # and "finetune" not in model_name.lower() # ): # template_name = "v0_mmtag" # else: # template_name = "llava_v0" # elif "mpt" in model_name: # template_name = "mpt_text" # elif "llama-2" in model_name: # template_name = "llama_2" # else: # template_name = "vicuna_v1" # Construct prompt prompt = state.get_prompt() all_images = state.get_images(return_pil=True) all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images] # Make requests # pload = {"model": model_name, "prompt": prompt, "temperature": float(temperature), "top_p": float(top_p), # "max_new_tokens": min(int(max_new_tokens), 1536), "stop": ( # state.sep # if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] # else state.sep2 # ), "images": state.get_images()} pload = { "model": model_name, "prompt": prompt, "temperature": 0, "top_p": 1.0, "max_new_tokens": 512, "stop": ( state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2 ), "images": state.get_images()} state.messages[-1][-1] = "▌" yield (state, state.to_gradio_chatbot()) # for stream output = get_response(pload) print(output) for chunk in output: if chunk: data = json.loads(chunk.decode()) if data["error_code"] == 0: output = data["text"][len(prompt) :].strip() state.messages[-1][-1] = output + "▌" yield (state, state.to_gradio_chatbot()) else: output = data["text"] + f" (error_code: {data['error_code']})" state.messages[-1][-1] = output yield (state, state.to_gradio_chatbot()) return time.sleep(0.03) state.messages[-1][-1] = state.messages[-1][-1][:-1] yield (state, state.to_gradio_chatbot()) def build_demo(): textbox = gr.Textbox( show_label=False, placeholder="Enter text and press ENTER", container=False ) with gr.Blocks(title="TinyLLaVA", theme=gr.themes.Default(), css=block_css) as demo: state = gr.State() gr.Markdown(title_markdown) with gr.Row(): with gr.Column(scale=5): with gr.Row(elem_id="Model ID"): gr.Dropdown( choices=[DEFAULT_MODEL_NAME], value=DEFAULT_MODEL_NAME, interactive=True, label="Model ID", container=False, ) imagebox = gr.Image(type="pil") image_process_mode = gr.Radio( ["Crop", "Resize", "Pad", "Default"], value="Default", label="Preprocess for non-square image", visible=False, ) # # cur_dir = os.path.dirname(os.path.abspath(__file__)) # cur_dir = Path(__file__).parent # gr.Examples( # examples=[ # [ # f"{cur_dir}/examples/extreme_ironing.jpg", # "What is unusual about this image?", # ], # [ # f"{cur_dir}/examples/waterview.jpg", # "What are the things I should be cautious about when I visit here?", # ], # ], # inputs=[imagebox, textbox], # ) # with gr.Accordion("Parameters", open=False) as _: # temperature = gr.Slider( # minimum=0.0, # maximum=1.0, # value=0.2, # step=0.1, # interactive=True, # label="Temperature", # ) # top_p = gr.Slider( # minimum=0.0, # maximum=1.0, # value=0.7, # step=0.1, # interactive=True, # label="Top P", # ) # max_output_tokens = gr.Slider( # minimum=0, # maximum=1024, # value=512, # step=64, # interactive=True, # label="Max output tokens", # ) with gr.Column(scale=8): chatbot = gr.Chatbot(elem_id="chatbot", label="Chatbot", height=550) with gr.Row(): with gr.Column(scale=8): textbox.render() with gr.Column(scale=1, min_width=50): submit_btn = gr.Button(value="Send", variant="primary") with gr.Row(elem_id="buttons") as _: regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True) clear_btn = gr.Button(value="🗑️ Clear", interactive=True) # gr.Markdown(tos_markdown) # gr.Markdown(learn_more_markdown) # gr.Markdown(ack_markdown) regenerate_btn.click( regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox], queue=False, ).then( http_bot, [state], [state, chatbot] ) clear_btn.click( clear_history, None, [state, chatbot, textbox, imagebox], queue=False ) textbox.submit( add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox], queue=False, ).then( http_bot, [state], [state, chatbot] ) submit_btn.click( add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox], queue=False, ).then( http_bot, [state], [state, chatbot] ) demo.load(load_demo, None, [state], queue=False) return demo logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) logger.info(gr.__version__) from huggingface_hub import snapshot_download huggingface_path = Path(snapshot_download(repo_id="laurenssam/privacy-aware-visual-language-models")) model_name = str(huggingface_path / "tinyllava_demo") tokenizer, model, image_processor, context_len = load_pretrained_model( model_path=str(model_name), model_base="bczhou/TinyLLaVA-3.1B", model_name=get_model_name_from_path(model_name + "_lora_TinyLLaVA-3.1B"), device="cpu", load_4bit=False, load_8bit=False ) model = model.cpu().float() demo = build_demo() demo.queue() if __name__ == "__main__": demo.launch()