"""Template Demo for IBM Granite Hugging Face spaces.""" | |
from collections.abc import Iterator | |
from datetime import datetime | |
from pathlib import Path | |
from threading import Thread | |
import gradio as gr | |
import PIL | |
import spaces | |
import torch | |
from PIL.Image import Image as PILImage | |
from PIL.Image import Resampling | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoModelForVision2Seq, | |
AutoProcessor, | |
AutoTokenizer, | |
LlavaNextForConditionalGeneration, | |
LlavaNextProcessor, | |
TextIteratorStreamer, | |
) | |
from themes.research_monochrome import theme | |
dir_ = Path(__file__).parent.parent
today_date = datetime.today().strftime("%B %-d, %Y")  # noqa: DTZ002

MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
MODEL_ID_PREVIEW = "ibm-granite/granite-vision-3.1-2b-preview"

# SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
# Today's Date: {today_date}.
# You are Granite, developed by IBM. You are a helpful AI assistant"""

TITLE = "IBM Granite Vision 3.2 2b"
DESCRIPTION = "Try one of the sample prompts below or write your own. Remember, \
AI models can make mistakes."

MAX_INPUT_TOKEN_LENGTH = 4096
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05
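# A brief guide to the sampling defaults above: temperature scales the logits before
# sampling (lower values are more deterministic), top_p keeps the smallest set of
# tokens whose cumulative probability exceeds 0.85, top_k restricts each step to the
# 50 most likely tokens, and a repetition_penalty above 1 discourages repeated text.
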
sample_data = [
    [
        "https://www.ibm.com/design/language/static/159e89b3d8d6efcb5db43f543df36b23/a5df1/rebusgallery_tshirt.png",
        ["What are the three symbols on the tshirt?"],
    ],
    [
        str(dir_ / "data" / "p2-report.png"),
        [
            "What's the difference in rental income between 2020 and 2019?",
            "Which table entries are less in 2020 than 2019?",
        ],
    ],
    [
        "https://www.ibm.com/design/language/static/159e89b3d8d6efcb5db43f543df36b23/a5df1/rebusgallery_tshirt.png",
        ["What's this?"],
    ],
]
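# Each sample pairs an image (a URL, or a file under the repository's data/ directory)
# with the example questions that will be offered for it in the chat panel.
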
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

processor: LlavaNextProcessor = None
model: LlavaNextForConditionalGeneration = None
selected_image: PILImage = None
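# The processor and model are left as None here and loaded lazily on the first
# generate call (see multimodal_generate_v2), so the Space starts quickly and only
# pays the model-loading cost once a user actually asks a question.
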

def image_changed(im: PILImage) -> None:
    """Keep a downscaled copy of the currently selected sample image."""
    global selected_image
    if im is None:
        selected_image = None
    else:
        selected_image = im.copy()
        # thumbnail() resizes in place, preserving the aspect ratio.
        selected_image.thumbnail((800, 800))


def create_single_turn(image: PILImage, text: str) -> dict:
    if image is None:
        return {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
            ],
        }
    return {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": text},
        ],
    }

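# For reference, a turn produced with an image follows the chat-template format that
# apply_chat_template expects, e.g.:
# {"role": "user", "content": [{"type": "image", "image": <PIL image>},
#                              {"type": "text", "text": "What's this?"}]}
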

def generate(
    image: PILImage,
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
):
    """Generate a response for the chat demo.

    Args:
        image: The image the question refers to, or None for a text-only turn.
        message: The latest input message from the user.
        chat_history: Previous turns, each a dict containing 'role' and 'content'.
        temperature: Sampling temperature; higher values produce more varied output.
        repetition_penalty: Values above 1 penalize tokens that already appeared.
        top_p: Nucleus sampling threshold.
        top_k: Number of highest-probability tokens considered at each step.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The echoed user turn and the generated assistant turn as gr.ChatMessage objects.
    """
    # Build the conversation: prior turns plus the new user turn (with image, if any).
    conversation = []
    # TODO: maybe add back custom sys prompt
    # conversation.append({"role": "system", "content": SYS_PROMPT})
    conversation += chat_history
    conversation.append(create_single_turn(image, message))
    # Convert messages to prompt format
    inputs = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
    ).to(device)
    # TODO: This might cut out the image tokens -- find better strategy
    # if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    #     input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    #     gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    output = model.generate(**inputs, **generate_kwargs)
    # The decoded string contains the prompt as well; everything after the last
    # "<|assistant|>" marker is the newly generated answer.
    out = processor.decode(output[0], skip_special_tokens=True)
    out_s = out.strip().split("<|assistant|>")
    return [gr.ChatMessage(role="user", content=message), gr.ChatMessage(role="assistant", content=out_s[-1])]

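# NOTE: On a ZeroGPU Space, the generation function is typically decorated with
# @spaces.GPU so that a GPU is attached for the duration of the call; this demo
# currently relies on default device placement instead.
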

def multimodal_generate_v2(
    msg: str,
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
):
    global model, processor
    # Lazy loading: fetch the processor and model on the first call only.
    if model is None:
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        # device_map="auto" already places the model on the available device(s);
        # calling .to(device) on an accelerate-dispatched model can fail, so it is omitted.
        model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, device_map="auto")
    return generate(
        selected_image,
        msg,
        [],
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
    )

# Advanced settings (displayed in an Accordion at the bottom of the page)
temperature_slider = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=TEMPERATURE,
    step=0.1,
    label="Temperature",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)
top_p_slider = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=TOP_P,
    step=0.05,
    label="Top P",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)
top_k_slider = gr.Slider(
    minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"], interactive=True
)
repetition_penalty_slider = gr.Slider(
    minimum=0,
    maximum=2.0,
    value=REPETITION_PENALTY,
    step=0.05,
    label="Repetition Penalty",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)
max_new_tokens_slider = gr.Slider(
    minimum=1,
    maximum=2000,
    value=MAX_NEW_TOKENS,
    step=1,
    label="Max New Tokens",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)
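# The sliders above (and the chatbot below) are created outside the Blocks layout and
# attached to the page later via explicit .render() calls -- a Gradio pattern that
# lets event handlers reference a component before it appears in the layout.
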
chatbot = gr.Chatbot(examples=[{"text": "Hello World!"}], type="messages", label="Q&A about selected document")

css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    is_in_edit_mode = gr.State(True)  # in block to be reactive
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            # Create sample image object for reference, render later
            image_x = gr.Image(
                type="pil",
                label="Example image",
                render=False,
                interactive=False,
                show_label=False,
                show_fullscreen_button=False,
                height=800,
            )
            image_x.change(fn=image_changed, inputs=image_x)

            # Create Dataset object and render it
            ds = gr.Dataset(label="Select one document", samples=sample_data, components=[gr.Image(render=False)])

            def sample_image_selected(d: gr.SelectData, dx):
                # dx is the selected sample: [image path or URL, list of example questions]
                return gr.Image(dx[0]), gr.update(examples=[{"text": x} for x in dx[1]])

            # Selecting a document clears the chat, then shows the image and its example questions.
            ds.select(lambda: [], outputs=[chatbot])
            ds.select(sample_image_selected, inputs=[ds], outputs=[image_x, chatbot])

            # Render image object after the Dataset
            image_x.render()
        with gr.Column():
            # Render ChatBot
            chatbot.render()

            # Define behavior for example selection
            def update_user_chat_x(x: gr.SelectData):
                return [gr.ChatMessage(role="user", content=x.value["text"])]

            def send_generate_x(x: gr.SelectData, temperature, repetition_penalty, top_p, top_k, max_new_tokens):
                txt = x.value["text"]
                return multimodal_generate_v2(txt, temperature, repetition_penalty, top_p, top_k, max_new_tokens)

            # Clicking an example leaves edit mode, echoes the question, then generates the answer.
            chatbot.example_select(lambda: False, outputs=is_in_edit_mode)
            chatbot.example_select(update_user_chat_x, outputs=[chatbot])
            chatbot.example_select(
                send_generate_x,
                inputs=[
                    temperature_slider,
                    repetition_penalty_slider,
                    top_p_slider,
                    top_k_slider,
                    max_new_tokens_slider,
                ],
                outputs=[chatbot],
            )
            # Create User Chat Textbox and Reset Button
            tbb = gr.Textbox(submit_btn=True, show_label=False)
            fb = gr.Button("Reset Chat", visible=False)
            fb.click(lambda: [], outputs=[chatbot])

            # Handle toggling between edit and non-edit mode
            def textbox_switch(emode):
                if not emode:
                    return [gr.update(visible=False), gr.update(visible=True)]
                return [gr.update(visible=True), gr.update(visible=False)]

            tbb.submit(lambda: False, outputs=[is_in_edit_mode])
            fb.click(lambda: True, outputs=[is_in_edit_mode])
            is_in_edit_mode.change(textbox_switch, inputs=[is_in_edit_mode], outputs=[tbb, fb])
            # Submit user question: echo it in the chat, then generate the answer.
            tbb.submit(lambda x: [gr.ChatMessage(role="user", content=x)], inputs=tbb, outputs=chatbot)
            tbb.submit(
                multimodal_generate_v2,
                inputs=[
                    tbb,
                    temperature_slider,
                    repetition_penalty_slider,
                    top_p_slider,
                    top_k_slider,
                    max_new_tokens_slider,
                ],
                outputs=[chatbot],
            )
    # Extra model parameters
    with gr.Accordion("Advanced Settings", open=False):
        max_new_tokens_slider.render()
        temperature_slider.render()
        top_k_slider.render()
        top_p_slider.render()
        repetition_penalty_slider.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
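
# To try this locally (assuming the repository's themes/ package, data/ samples, and
# app.css/app_head.html assets are present alongside this file), run the script with
# Python and open the local URL that Gradio prints.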