# gemma-3-270m-it / app.py
# Author: Norod78 — commit 7bf2a69 ("Update app.py", verified)
import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import AutoModelForCausalLM, AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import spaces
import time
# ---- UI text / layout -------------------------------------------------------
TITLE = " google/gemma-3-270m-it "
DESCRIPTION = """
It's so small
"""
IS_RTL = False  # passed as `rtl=` to the chat and text components below
TEXT_ALIGN = "left"

# ---- Model ------------------------------------------------------------------
model_name = "google/gemma-3-270m-it"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="eager",
)
model.eval()  # switch to inference mode (returns the same module object)
processor = AutoProcessor.from_pretrained(model_name)
# TODO: attach a timestamp to each sampled frame.
def extract_video_frames(video_path, num_frames=8):
    """Sample up to ``num_frames`` RGB frames spread across a video.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Maximum number of frames to return (<= 0 returns ``[]``).

    Returns:
        A list of ``PIL.Image`` frames in playback order. Frames that fail to
        decode are skipped, so the list may be shorter than ``num_frames``; it
        is empty when the video cannot be opened.
    """
    if num_frames <= 0:
        return []
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            return frames
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames > 0:
            # i * total // num spreads the picks over the whole clip; a fixed
            # step of total // num truncates and leaves the tail unsampled
            # (15 frames, 8 picks -> only frames 0..7). The set drops
            # duplicates when the clip has fewer frames than requested.
            indices = sorted({i * total_frames // num_frames for i in range(num_frames)})
        else:
            # No positive frame count reported: just try the first positions.
            indices = range(num_frames)
        for index in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, index)
            ok, frame = cap.read()
            if ok:
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        # Release the capture even if decoding/conversion raises.
        cap.release()
    return frames
_IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp"})
_VIDEO_SUFFIXES = frozenset({".mp4", ".mov", ".webm", ".avi", ".mkv"})


def _file_path(file):
    """Return the filesystem path of an uploaded file entry.

    Accepts a path string, a ``{"path": ...}`` dict, or any object exposing
    the path as ``.name``.
    """
    if isinstance(file, str):
        return file
    if isinstance(file, dict):
        return file["path"]
    return file.name


def _load_image(path):
    """Open an image and return an in-memory copy so the file handle is closed."""
    with Image.open(path) as img:
        return img.copy()


def _media_items(file):
    """Turn one uploaded file into chat ``image`` content items.

    An image yields one item, a video yields one item per sampled frame
    (via ``extract_video_frames``), and any other file type yields nothing.
    """
    path = _file_path(file)
    suffix = Path(path).suffix.lower()
    if suffix in _IMAGE_SUFFIXES:
        return [{"type": "image", "image": _load_image(path)}]
    if suffix in _VIDEO_SUFFIXES:
        return [{"type": "image", "image": frame} for frame in extract_video_frames(path)]
    return []


def format_message(content, files):
    """Build the chat-template content list for one user turn.

    Args:
        content: The user's text. Each ``<image>`` marker in it is replaced, in
            order, by the media of the next uploaded file. May be empty/None.
        files: Uploaded files (paths, ``{"path": ...}`` dicts, or objects with
            ``.name``). Not modified. May be None.

    Returns:
        A list of ``{"type": "text", "text": ...}`` and
        ``{"type": "image", "image": ...}`` items. Text segments are stripped
        and blank ones dropped. Files not consumed by a marker are appended at
        the end. Unsupported file types contribute nothing.
    """
    # Work on a copy: the caller's list (Gradio's input payload) must not be
    # consumed by the marker substitution below.
    pending = list(files or [])
    message_content = []
    if content:
        parts = content.split("<image>")
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            if i < len(parts) - 1 and pending:
                message_content.extend(_media_items(pending.pop(0)))
    for file in pending:
        message_content.extend(_media_items(file))
    return message_content
def _user_content_parts(content):
    """Normalise one user history entry into chat-template content items.

    - ``str``  -> one text item.
    - ``list`` -> assumed to already hold content items; copied as-is.
    - ``tuple`` or a dict with a ``"path"`` key -> a file attachment
      (NOTE(review): this is the form Gradio's messages-format history uses for
      uploads — confirm against the installed Gradio version). Dropped, because
      ``str()``-ing it would paste a server-side file path into the prompt.
    - anything else -> its ``str()`` as a text item.
    """
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    if isinstance(content, list):
        return list(content)
    if isinstance(content, tuple) or (isinstance(content, dict) and "path" in content):
        return []
    return [{"type": "text", "text": str(content)}]


def format_conversation_history(chat_history):
    """Convert Gradio ``messages``-style history into chat-template messages.

    Consecutive user entries are merged into a single user turn (so turns
    alternate user/assistant); each assistant entry becomes one text turn.
    Entries with any other role are ignored.

    Args:
        chat_history: Iterable of ``{"role": ..., "content": ...}`` dicts, or None.

    Returns:
        A list of ``{"role": ..., "content": [content items]}`` dicts.
    """
    messages = []
    pending_user = []
    for item in chat_history or []:
        role = item["role"]
        content = item["content"]
        if role == "user":
            pending_user.extend(_user_content_parts(content))
        elif role == "assistant":
            if pending_user:
                messages.append({"role": "user", "content": pending_user})
                pending_user = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if pending_user:
        messages.append({"role": "user", "content": pending_user})
    return messages
def _text_only(messages):
    """Return a copy of chat messages keeping only ``{"type": "text"}`` items.

    The checkpoint is loaded with ``AutoModelForCausalLM`` and the input box
    tells users that images and videos are ignored, so non-text items are
    removed here instead of being handed to the chat template.
    """
    return [
        {
            **message,
            "content": [
                part for part in message["content"]
                if isinstance(part, dict) and part.get("type") == "text"
            ],
        }
        for message in messages
    ]


@spaces.GPU(duration=120)
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    """Stream the model's reply to the latest user message.

    Args:
        input_data: The MultimodalTextbox payload (``{"text": ..., "files": [...]}``)
            or any value, which is then used as plain text via ``str()``.
        chat_history: Gradio ``messages``-style history of prior turns.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling settings forwarded to ``model.generate``.
        system_prompt: Optional system message; skipped when empty.

    Yields:
        The cumulative decoded reply after each streamed chunk.

    Raises:
        Any exception raised by ``model.generate`` on the worker thread, after
        the already-streamed text has been yielded.
    """
    if isinstance(input_data, dict) and "text" in input_data:
        user_text = input_data["text"]
        files = input_data.get("files") or []
    else:
        user_text = str(input_data)
        files = []
    new_message = {"role": "user", "content": format_message(user_text, files)}
    system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
    messages = system_message + format_conversation_history(chat_history)
    # The chat template expects alternating turns, so a history that already
    # ends with a user turn absorbs the new message instead of adding another.
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)

    inputs = processor.apply_chat_template(
        _text_only(messages),
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        # Slider values may arrive as floats; these two must be ints.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
    )

    errors = []

    def _run_generation():
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # surfaced to the caller after streaming stops
            errors.append(exc)
            # Without this the consumer loop below would wait forever for a
            # stop signal that a failed generate() never sends.
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()
    chunks = []
    for chunk in streamer:
        chunks.append(chunk)
        yield "".join(chunks)
    thread.join()
    if errors:
        raise errors[0]
# Components are built up front, in the same order the inline version created
# them, then handed to ChatInterface.
_chatbot = gr.Chatbot(rtl=IS_RTL, show_copy_button=True, type="messages")

# Order must match generate_response's parameters after (message, history):
# max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty.
_generation_settings = [
    gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
    gr.Textbox(
        label="System Prompt",
        value="You are a very helpful multimodal assistant",
        lines=4,
        placeholder="Change the settings",
        text_align=TEXT_ALIGN,
        rtl=IS_RTL,
    ),
    gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.2),
    gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.4),
    gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=30),
    gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
]

_example_prompts = [
    [{"text": "Write a poem which describes potatoes"}],
]

_input_box = gr.MultimodalTextbox(
    rtl=IS_RTL,
    label="input",
    file_types=["image", "video"],
    file_count="multiple",
    placeholder="Input text, Any image or video will be ignored",
)

chat_interface = gr.ChatInterface(
    fn=generate_response,
    chatbot=_chatbot,
    additional_inputs=_generation_settings,
    examples=_example_prompts,
    textbox=_input_box,
    cache_examples=False,
    type="messages",
    fill_height=True,
    stop_btn="Stop",
    css_paths=["style.css"],
    multimodal=True,
    title=TITLE,
    description=DESCRIPTION,
    theme=gr.themes.Soft(),
)
if __name__ == "__main__":
    # Enable request queuing (at most 20 pending events), then serve the app.
    app = chat_interface.queue(max_size=20)
    app.launch()