# RTE Build
# Deployment
# a099612
"""Template Demo for IBM Granite Hugging Face spaces."""
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from threading import Thread
import gradio as gr
import PIL
import spaces
import torch
from PIL.Image import Image as PILImage
from PIL.Image import Resampling
from transformers import (
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
LlavaNextForConditionalGeneration,
LlavaNextProcessor,
TextIteratorStreamer,
)
from themes.research_monochrome import theme
# Parent of this file's directory; the bundled sample documents live under <dir_>/data.
dir_ = Path(__file__).parent.parent
# Human-readable date such as "March 5, 2025" (only used by the commented-out SYS_PROMPT below).
# Built from date fields instead of strftime("%-d"): the "-" flag is a platform extension that
# the Windows C runtime does not support (strftime raises ValueError there).
today_date = "{0:%B} {0.day}, {0.year}".format(datetime.today())  # noqa: DTZ002
# Checkpoint that is actually loaded and served (see multimodal_generate_v2).
MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
# Earlier preview checkpoint; not referenced elsewhere in this file.
MODEL_ID_PREVIEW = "ibm-granite/granite-vision-3.1-2b-preview"
# SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
# Today's Date: {today_date}.
# You are Granite, developed by IBM. You are a helpful AI assistant"""
# Page title; must describe the checkpoint in MODEL_ID, i.e. the model users actually talk to.
TITLE = "IBM Granite VISION 3.2 2b"
DESCRIPTION = (
    "Try one of the sample prompts below or write your own. Remember, "
    "AI models can make mistakes."
)
# Currently unused: the input-trimming code in generate() that would read it is commented out.
MAX_INPUT_TOKEN_LENGTH = 4096
# Default generation settings; also the initial values of the "Advanced Settings" sliders.
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05
# Selectable documents. Each row is [image location, suggested questions]; the location is
# either a URL or a path to a file bundled under <dir_>/data.
_REBUS_TSHIRT_URL = (
    "https://www.ibm.com/design/language/static/159e89b3d8d6efcb5db43f543df36b23/a5df1/rebusgallery_tshirt.png"
)
_P2_REPORT_PATH = str(dir_ / "data" / "p2-report.png")

sample_data = [
    [_REBUS_TSHIRT_URL, ["What are the three symbols on the tshirt?"]],
    [
        _P2_REPORT_PATH,
        [
            "What's the difference in rental income between 2020 and 2019?",
            "Which table entries are less in 2020 than 2019?",
        ],
    ],
    [_REBUS_TSHIRT_URL, ["What's this?"]],
]
# Choose the compute device: CUDA if present, else Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Loaded lazily on the first request by multimodal_generate_v2 (None until then).
processor: LlavaNextProcessor = None
model: LlavaNextForConditionalGeneration = None
# Downscaled copy of the image currently shown in the UI; maintained by image_changed.
selected_image: PILImage = None
def image_changed(im: PILImage):
    """Remember a downscaled copy of the newly displayed image for the chat handlers.

    Writes the module-level ``selected_image``; a ``None`` image clears it. The resize is
    applied to a copy, so the image object owned by the UI component is not modified.

    Args:
        im: The image now shown in the preview component, or None when it was cleared.
    """
    global selected_image
    if im is None:
        selected_image = None
        return
    selected_image = im.copy()
    # In-place shrink so the image handed to the model fits within 800x800.
    selected_image.thumbnail((800, 800))
def create_single_turn(image: PILImage, text: str) -> dict:
    """Wrap one user utterance as a chat-template message.

    Args:
        image: Image to attach before the text, or None for a text-only turn.
        text: The user's prompt.

    Returns:
        ``{"role": "user", "content": parts}`` where ``parts`` is an optional
        ``{"type": "image", "image": image}`` entry followed by a
        ``{"type": "text", "text": text}`` entry.
    """
    parts = [] if image is None else [{"type": "image", "image": image}]
    parts.append({"type": "text", "text": text})
    return {"role": "user", "content": parts}
def _generation_kwargs(
    temperature: float,
    repetition_penalty: float,
    top_p: float,
    top_k: float,
    max_new_tokens: int,
) -> dict:
    """Turn UI slider values into keyword arguments for ``model.generate``.

    The sliders allow values that transformers rejects when sampling (temperature 0,
    repetition penalty 0), and whole-number values can arrive as ``int``, which the
    temperature/repetition logits processors reject because they require floats. Values are
    therefore coerced to the types generate() expects. A non-positive temperature selects
    greedy decoding instead of sampling, and a non-positive repetition penalty is replaced
    by 1.0 (no penalty).
    """
    kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "num_beams": 1,
        "repetition_penalty": float(repetition_penalty) if repetition_penalty > 0 else 1.0,
    }
    if temperature > 0:
        kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # Temperature 0 means "no randomness": use greedy decoding rather than crashing.
        kwargs["do_sample"] = False
    return kwargs


@spaces.GPU
def generate(
    image: PILImage,
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: float = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> list[gr.ChatMessage]:
    """Run one multimodal chat turn through the model and return the exchange.

    Args:
        image: Image the question is about, or None for a text-only turn.
        message: The latest user message.
        chat_history: Earlier turns as chat-template dicts with "role" and "content";
            they precede the new turn in the prompt.
        temperature: Sampling temperature; values <= 0 switch to greedy decoding.
        repetition_penalty: Repetition penalty; values <= 0 are treated as 1.0 (disabled).
        top_p: Nucleus-sampling threshold (unused under greedy decoding).
        top_k: Top-k cutoff, coerced to int (unused under greedy decoding).
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        ``[user ChatMessage, assistant ChatMessage]``: the user's message and the model's
        reply, usable directly as the value of the ``type="messages"`` Chatbot.
    """
    # TODO: maybe add back custom sys prompt
    # conversation.append({"role": "system", "content": SYS_PROMPT})
    conversation = [*chat_history, create_single_turn(image, message)]

    # Convert messages to model inputs (token ids plus image tensors) on the model's device.
    inputs = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
    ).to(device)

    # TODO: This might cut out the image tokens -- find better strategy
    # if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    #     input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    #     gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    output = model.generate(
        **inputs,
        **_generation_kwargs(temperature, repetition_penalty, top_p, top_k, max_new_tokens),
    )

    # generate() returns the prompt followed by the completion. Decode only the new tokens,
    # so the reply no longer depends on the "<|assistant|>" marker surviving
    # skip_special_tokens=True.
    prompt_length = inputs["input_ids"].shape[1]
    reply = processor.decode(output[0][prompt_length:], skip_special_tokens=True).strip()

    return [gr.ChatMessage(role="user", content=message), gr.ChatMessage(role="assistant", content=reply)]
def multimodal_generate_v2(
    msg: str,
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: float = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
):
    """Answer ``msg`` about the currently selected image, loading the model on first use.

    The conversation is single-turn: no earlier history is sent, and the image is the
    module-level ``selected_image`` maintained by ``image_changed``.

    Args:
        msg: The user's question.
        temperature, repetition_penalty, top_p, top_k, max_new_tokens: Generation settings
            forwarded unchanged to ``generate``.

    Returns:
        Whatever ``generate`` returns: the user message and the model reply as chat messages.
    """
    global model, processor
    # Lazy loading keeps app start-up fast; the first request pays the load cost.
    if model is None:
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        # Place the whole model on the single module-level `device`, which is also where
        # generate() sends its inputs. device_map="auto" is deliberately not used: it may
        # shard or offload the model (which a later .to(device) cannot undo cleanly), and it
        # requires accelerate even on CPU.
        model = AutoModelForVision2Seq.from_pretrained(MODEL_ID).to(device)
    return generate(
        selected_image,
        msg,
        [],
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
    )
# NOTE(review): `tb` is never rendered or referenced elsewhere in this file; the chat input
# that is actually used is created inside the Blocks layout. Candidate for removal.
tb = gr.Textbox(submit_btn=True)


def _accordion_slider(minimum, maximum, value, step, label) -> gr.Slider:
    """Build an interactive slider tagged with the accordion CSS class.

    The sliders are created at module level, outside the Blocks layout, and are placed
    later via .render() inside the "Advanced Settings" accordion.
    """
    return gr.Slider(
        minimum=minimum,
        maximum=maximum,
        value=value,
        step=step,
        label=label,
        elem_classes=["gr_accordion_element"],
        interactive=True,
    )


# Advanced generation settings; their values are passed as inputs to the generation handlers.
temperature_slider = _accordion_slider(0, 1.0, TEMPERATURE, 0.1, "Temperature")
top_p_slider = _accordion_slider(0, 1.0, TOP_P, 0.05, "Top P")
top_k_slider = _accordion_slider(0, 100, TOP_K, 1, "Top K")
repetition_penalty_slider = _accordion_slider(0, 2.0, REPETITION_PENALTY, 0.05, "Repetition Penalty")
max_new_tokens_slider = _accordion_slider(1, 2000, MAX_NEW_TOKENS, 1, "Max New Tokens")

# Chat panel, rendered later by the layout; its examples are replaced when a document is selected.
chatbot = gr.Chatbot(examples=[{"text": "Hello World!"}], type="messages", label="Q&A about selected document")
css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"
with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    # True: the question textbox is shown. False: the "Reset Chat" button is shown instead.
    is_in_edit_mode = gr.State(True)  # in block to be reactive
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            # Create the preview image now so handlers can target it; it is rendered below the dataset.
            image_x = gr.Image(
                type="pil",
                label="Example image",
                render=False,
                interactive=False,
                show_label=False,
                show_fullscreen_button=False,
                height=800,
            )
            image_x.change(fn=image_changed, inputs=image_x)

            # Document picker over sample_data rows of [image, suggested questions].
            ds = gr.Dataset(label="Select one document", samples=sample_data, components=[gr.Image(render=False)])

            def sample_image_selected(d: gr.SelectData, dx):
                """Show the chosen document and start a fresh chat seeded with its questions.

                Selecting a document clears the chat, just like "Reset Chat". Edit mode is
                therefore restored too; otherwise the question textbox stayed hidden after
                switching documents in the middle of a conversation.
                """
                image_location, questions = dx[0], dx[1]
                return (
                    gr.Image(image_location),
                    gr.update(value=[], examples=[{"text": q} for q in questions]),
                    True,
                )

            ds.select(sample_image_selected, inputs=[ds], outputs=[image_x, chatbot, is_in_edit_mode])

            # Render image object after DS
            image_x.render()
        with gr.Column():
            chatbot.render()

            # Clicking a suggested question: leave edit mode, echo the question, then answer it.
            def update_user_chat_x(x: gr.SelectData):
                """Show the clicked example question as the user's chat message."""
                return [gr.ChatMessage(role="user", content=x.value["text"])]

            def send_generate_x(x: gr.SelectData, temperature, repetition_penalty, top_p, top_k, max_new_tokens):
                """Answer the clicked example question using the current slider settings."""
                txt = x.value["text"]
                return multimodal_generate_v2(txt, temperature, repetition_penalty, top_p, top_k, max_new_tokens)

            chatbot.example_select(lambda: False, outputs=is_in_edit_mode)
            chatbot.example_select(update_user_chat_x, outputs=[chatbot])
            chatbot.example_select(
                send_generate_x,
                inputs=[
                    temperature_slider,
                    repetition_penalty_slider,
                    top_p_slider,
                    top_k_slider,
                    max_new_tokens_slider,
                ],
                outputs=[chatbot],
            )

            # Question textbox, and the reset button that replaces it after a question is asked.
            tbb = gr.Textbox(submit_btn=True, show_label=False)
            fb = gr.Button("Reset Chat", visible=False)
            fb.click(lambda: [], outputs=[chatbot])

            # Handle toggling between edit and non-edit mode
            def textbox_switch(emode):
                """Return visibility updates for [textbox, reset button] for the given edit mode."""
                if not emode:
                    return [gr.update(visible=False), gr.update(visible=True)]
                return [gr.update(visible=True), gr.update(visible=False)]

            tbb.submit(lambda: False, outputs=[is_in_edit_mode])
            fb.click(lambda: True, outputs=[is_in_edit_mode])
            is_in_edit_mode.change(textbox_switch, inputs=[is_in_edit_mode], outputs=[tbb, fb])

            # Submitting a typed question: echo it, then answer it.
            tbb.submit(lambda x: [gr.ChatMessage(role="user", content=x)], inputs=tbb, outputs=chatbot)
            tbb.submit(
                multimodal_generate_v2,
                inputs=[
                    tbb,
                    temperature_slider,
                    repetition_penalty_slider,
                    top_p_slider,
                    top_k_slider,
                    max_new_tokens_slider,
                ],
                outputs=[chatbot],
            )

            # extra model parameters
            with gr.Accordion("Advanced Settings", open=False):
                max_new_tokens_slider.render()
                temperature_slider.render()
                top_k_slider.render()
                top_p_slider.render()
                repetition_penalty_slider.render()
if __name__ == "__main__":
    # Enable request queuing (at most 20 pending events), then start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()