"""Developed by Ruslan Magana Vsevolodovna""" | |
from collections.abc import Iterator | |
from datetime import datetime | |
from pathlib import Path | |
from threading import Thread | |
import io | |
import base64 | |
import random | |
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration | |
from themes.research_monochrome import theme | |
# =============================================================================
# Constants & Prompts
# =============================================================================
today_date = datetime.today().strftime("%B %-d, %Y")

SYS_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
TITLE = "IBM Granite 3.1 8b Reasoning & Vision Preview" | |
DESCRIPTION = """ | |
<p>Granite 3.1 8b Reasoning is an open‐source LLM supporting a 128k context window and Granite Vision 3.1 2B Preview for vision‐language capabilities. Start with one of the sample prompts | |
or enter your own. Keep in mind that AI can occasionally make mistakes. | |
<span class="gr_docs_link"> | |
<a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a> | |
</span> | |
</p> | |
""" | |
MAX_INPUT_TOKEN_LENGTH = 128_000
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.5
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05

# Vision defaults (advanced settings)
VISION_TEMPERATURE = 0.2
VISION_TOP_P = 0.95
VISION_TOP_K = 50
VISION_MAX_TOKENS = 128

if not torch.cuda.is_available():
    print("This demo may not work on CPU.")
# =============================================================================
# Text Model Loading
# =============================================================================
granite_text_model = "ruslanmv/granite-3.1-8b-Reasoning"
text_model = AutoModelForCausalLM.from_pretrained(
    granite_text_model,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(granite_text_model)
tokenizer.use_default_system_prompt = False

# =============================================================================
# Vision Model Loading
# =============================================================================
vision_model_path = "ibm-granite/granite-vision-3.1-2b-preview"
vision_processor = LlavaNextProcessor.from_pretrained(vision_model_path, use_fast=True)
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
    vision_model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,  # Ensure the custom code is used so that weight shapes match.
)
# =============================================================================
# Unified Display Function
# =============================================================================
def get_text_from_content(content):
    """Helper to extract text from a list of content items."""
    texts = []
    for item in content:
        if isinstance(item, dict):
            if item.get("type") == "text":
                texts.append(item.get("text", ""))
            elif item.get("type") == "image":
                image = item.get("image")
                if image is not None:
                    buffered = io.BytesIO()
                    image.save(buffered, format="JPEG")
                    img_str = base64.b64encode(buffered.getvalue()).decode()
                    texts.append(
                        f'<img src="data:image/jpeg;base64,{img_str}" style="max-width: 200px; max-height: 200px;">'
                    )
                else:
                    texts.append("<image>")
        else:
            texts.append(str(item))
    return " ".join(texts)
def display_unified_conversation(conversation):
    """
    Combine both text-only and vision messages.
    Each conversation entry is expected to be a dict with keys:
      - role: "user" or "assistant"
      - content: either a string (for text) or a list of content items (for vision)
    """
    chat_history = []
    i = 0
    while i < len(conversation):
        if conversation[i]["role"] == "user":
            user_content = conversation[i]["content"]
            if isinstance(user_content, list):
                user_msg = get_text_from_content(user_content)
            else:
                user_msg = user_content
            assistant_msg = ""
            if i + 1 < len(conversation) and conversation[i + 1]["role"] == "assistant":
                asst_content = conversation[i + 1]["content"]
                if isinstance(asst_content, list):
                    assistant_msg = get_text_from_content(asst_content)
                else:
                    assistant_msg = asst_content
                i += 2
            else:
                i += 1
            chat_history.append((user_msg, assistant_msg))
        else:
            i += 1
    return chat_history
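# Illustrative only (not part of the original app): a mixed conversation such as
#
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": [{"type": "text", "text": "Describe this"},
#                                 {"type": "image", "image": some_pil_image}]},   # some_pil_image is hypothetical
#    {"role": "assistant", "content": [{"type": "text", "text": "A cheetah on grass."}]}]
#
# is flattened by display_unified_conversation() into Chatbot tuples:
#
#   [("Hi", "Hello!"),
#    ('Describe this <img src="data:image/jpeg;base64,...">', "A cheetah on grass.")]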
# =============================================================================
# Text Generation Function (for text-only chat)
# =============================================================================
def generate(
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: float = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
    """
    Generate function for text chat. It streams tokens and stops once the generated answer
    contains the closing </answer> tag.
    """
    conversation = []
    conversation.append({"role": "system", "content": SYS_PROMPT})
    conversation.extend(chat_history)
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        add_generation_prompt=True,
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH - max_new_tokens,
    )
    input_ids = input_ids.to(text_model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    t = Thread(target=text_model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    reasoning_started = False
    answer_started = False
    collected_reasoning = ""
    collected_answer = ""
    for text in streamer:
        outputs.append(text)
        current_output = "".join(outputs)
        if "<reasoning>" in current_output and not reasoning_started:
            reasoning_started = True
            reasoning_start_index = current_output.find("<reasoning>") + len("<reasoning>")
            collected_reasoning = current_output[reasoning_start_index:]
            yield "[Reasoning]: "
            yield collected_reasoning
            outputs = [collected_reasoning]
        elif reasoning_started and "<answer>" in current_output and not answer_started:
            answer_started = True
            # outputs was reset after the <reasoning> tag was stripped, so everything up to
            # <answer> is reasoning text.
            reasoning_end_index = current_output.find("<answer>")
            collected_reasoning = current_output[:reasoning_end_index]
            answer_start_index = current_output.find("<answer>") + len("<answer>")
            collected_answer = current_output[answer_start_index:]
            yield "\n[Answer]: "
            outputs = [collected_answer]
            yield collected_answer
        elif reasoning_started and not answer_started:
            collected_reasoning += text
            yield text
        elif answer_started:
            collected_answer += text
            yield text
            if "</answer>" in collected_answer:
                break
        else:
            yield text
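# Illustrative only (not part of the original app): generate() can be driven outside of
# Gradio as a plain Python generator, e.g.
#
#   for chunk in generate("What is the capital of France?", chat_history=[]):
#       print(chunk, end="", flush=True)
#
# The streamed output interleaves the "[Reasoning]: " and "\n[Answer]: " markers with the
# model's tagged text.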
# =============================================================================
# Vision Chat Inference Function (for image+text chat)
# =============================================================================
def chat_inference(image, text, conversation, temperature=VISION_TEMPERATURE, top_p=VISION_TOP_P,
                   top_k=VISION_TOP_K, max_tokens=VISION_MAX_TOKENS):
    if conversation is None:
        conversation = []
    user_content = []
    if image is not None:
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        return display_unified_conversation(conversation), conversation
    conversation.append({"role": "user", "content": user_content})
    inputs = vision_processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(vision_model.device)  # Use the model's device rather than hard-coding "cuda".
    torch.manual_seed(random.randint(0, 10000))
    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": True,
    }
    output = vision_model.generate(**inputs, **generation_kwargs)
    assistant_response = vision_processor.decode(output[0], skip_special_tokens=True)
    if "<|assistant|>" in assistant_response:
        assistant_response_parts = assistant_response.split("<|assistant|>")
        assistant_response_text = assistant_response_parts[-1].strip()
    else:
        assistant_response_text = assistant_response.strip()
    conversation.append({"role": "assistant", "content": [{"type": "text", "text": assistant_response_text.strip()}]})
    return display_unified_conversation(conversation), conversation
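# Illustrative only (not part of the original app): chat_inference() can also be called
# directly with a PIL image, e.g.
#
#   from PIL import Image
#   img = Image.open("cheetah.jpg")  # hypothetical local file
#   history, conv = chat_inference(img, "What is in this image?", conversation=None)
#   print(history[-1][1])  # the assistant's description of the image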
# =============================================================================
# Unified Send-Message Function
#
# We now maintain two histories:
#   - unified_state: complete conversation (for display)
#   - internal_text_state: only text turns (for text generation)
# Vision turns update only unified_state.
# =============================================================================
def send_message(image, text,
                 text_temperature, text_repetition_penalty, text_top_p, text_top_k, text_max_new_tokens,
                 vision_temperature, vision_top_p, vision_top_k, vision_max_tokens,
                 unified_state, vision_state, internal_text_state):
    # Initialize states if empty
    if unified_state is None:
        unified_state = []
    if internal_text_state is None:
        internal_text_state = []
    if image is not None:
        # Use vision inference.
        user_msg = []
        user_msg.append({"type": "image", "image": image})
        if text and text.strip():
            user_msg.append({"type": "text", "text": text.strip()})
        unified_state.append({"role": "user", "content": user_msg})
        chat_history, updated_vision_conv = chat_inference(image, text, vision_state,
                                                           temperature=vision_temperature,
                                                           top_p=vision_top_p,
                                                           top_k=vision_top_k,
                                                           max_tokens=vision_max_tokens)
        vision_state = updated_vision_conv
        if updated_vision_conv and updated_vision_conv[-1]["role"] == "assistant":
            unified_state.append(updated_vision_conv[-1])
        yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state
    else:
        # Text-only mode: update both unified and internal text states.
        unified_state.append({"role": "user", "content": text})
        internal_text_state.append({"role": "user", "content": text})
        unified_state.append({"role": "assistant", "content": ""})
        internal_text_state.append({"role": "assistant", "content": ""})
        yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state
        # Exclude the just-appended user turn and the empty assistant placeholder;
        # generate() re-appends the user message itself.
        base_conv = internal_text_state[:-2]
        assistant_text = ""
        for chunk in generate(
            text, base_conv,
            temperature=text_temperature,
            repetition_penalty=text_repetition_penalty,
            top_p=text_top_p,
            top_k=text_top_k,
            max_new_tokens=text_max_new_tokens,
        ):
            assistant_text += chunk
            unified_state[-1]["content"] = assistant_text
            internal_text_state[-1]["content"] = assistant_text
            yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state
        yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state
# =============================================================================
# Clear Chat Function
# =============================================================================
def clear_chat():
    # Clear unified conversation, vision state, internal text state, text box, and image.
    return [], [], [], [], "", None

# =============================================================================
# UI Layout with Gradio
# =============================================================================
css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

with gr.Blocks(fill_height=True, css_paths=[str(css_file_path)], head_paths=[str(head_file_path)], theme=theme, title=TITLE) as demo:
    gr.HTML(f"<h1>{TITLE}</h1>", elem_classes=["gr_title"])
    gr.HTML(DESCRIPTION)
    chatbot = gr.Chatbot(label="Chat History", height=500)
    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
        with gr.Column(scale=1):
            with gr.Accordion("Text Advanced Settings", open=False):
                text_temperature_slider = gr.Slider(minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"])
                repetition_penalty_slider = gr.Slider(minimum=0, maximum=2.0, value=REPETITION_PENALTY, step=0.05, label="Repetition Penalty", elem_classes=["gr_accordion_element"])
                top_p_slider = gr.Slider(minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"])
                top_k_slider = gr.Slider(minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"])
                max_new_tokens_slider = gr.Slider(minimum=1, maximum=2000, value=MAX_NEW_TOKENS, step=1, label="Max New Tokens", elem_classes=["gr_accordion_element"])
            with gr.Accordion("Vision Advanced Settings", open=False):
                vision_temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=VISION_TEMPERATURE, step=0.01, label="Vision Temperature", elem_classes=["gr_accordion_element"])
                vision_top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=VISION_TOP_P, step=0.01, label="Vision Top p", elem_classes=["gr_accordion_element"])
                vision_top_k_slider = gr.Slider(minimum=0, maximum=100, value=VISION_TOP_K, step=1, label="Vision Top k", elem_classes=["gr_accordion_element"])
                vision_max_tokens_slider = gr.Slider(minimum=10, maximum=300, value=VISION_MAX_TOKENS, step=1, label="Vision Max Tokens", elem_classes=["gr_accordion_element"])
    send_button = gr.Button("Send Message")
    clear_button = gr.Button("Clear Chat")

    # Conversation state variables:
    #   - unified_state: complete conversation for display (text and vision)
    #   - vision_state: state for vision turns
    #   - internal_text_state: only text turns (for text generation)
    unified_state = gr.State([])
    vision_state = gr.State([])
    internal_text_state = gr.State([])
    send_button.click(
        send_message,
        inputs=[
            image_input, text_input,
            text_temperature_slider, repetition_penalty_slider, top_p_slider, top_k_slider, max_new_tokens_slider,
            vision_temperature_slider, vision_top_p_slider, vision_top_k_slider, vision_max_tokens_slider,
            unified_state, vision_state, internal_text_state
        ],
        outputs=[chatbot, unified_state, vision_state, internal_text_state],
    )
    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, unified_state, vision_state, internal_text_state, text_input, image_input]
    )
    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/cheetah1.jpg", "What is in this image?"],
            [None, "Compute Pi."],
            [None, "Explain quantum computing to a beginner."],
            [None, "What is OpenShift?"],
            [None, "Importance of low latency inference"],
            [None, "Boosting productivity habits"],
            [None, "Explain and document your code"],
            [None, "Generate Java Code"]
        ],
        inputs=[image_input, text_input],
        example_labels=[
            "Vision Example: What is in this image?",
            "Compute Pi.",
            "Explain quantum computing",
            "What is OpenShift?",
            "Importance of low latency inference",
            "Boosting productivity habits",
            "Explain and document your code",
            "Generate Java Code"
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.queue().launch(debug=True, share=False)