import os
import subprocess
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
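# Install FlashAttention at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE lets the wheel
# install without compiling CUDA kernels locally (a common pattern on GPU Spaces).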
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    # Merge os.environ so PATH and friends are preserved for the pip subprocess.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
DESCRIPTION = """\
# Gemma 3 270m IT πŸ’ŽπŸ’¬
Chat with Google's smallest instruction-tuned Gemma 3 model.
[πŸͺͺ **Model card**](https://huggingface.co/google/gemma-3-270m-it)
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
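# MAX_INPUT_TOKEN_LENGTH can be overridden via the environment variable of the same name;
# prompts longer than this are trimmed from the left in generate().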
# Pick the attention backend and dtype based on device availability.
if torch.cuda.is_available():
    device = "cuda"
    attn_impl = "flash_attention_2"
    torch_dtype = torch.bfloat16  # or torch.float16
else:
    device = "cpu"
    attn_impl = "eager"
    # bfloat16 is supported on CPUs with AVX512-BF16 or AMX (e.g. Intel Ice Lake / Sapphire
    # Rapids, some newer AMD); many ops may still fall back to float32.
    torch_dtype = torch.bfloat16
# model_id = "google/gemma-3-270m-it"
model_id = "unsloth/gemma-3-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch_dtype,  # use the device-dependent dtype selected above
    attn_implementation=attn_impl,
    trust_remote_code=True,
)
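# Override the sliding attention window used by the model's local attention layers
# (presumably to match the 4096-token input budget above).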
model.config.sliding_window = 4096
model.eval()
@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_message: str = "",
    max_new_tokens: int = 1024,
    temperature: float = 0.001,
    top_p: float = 1.0,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
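    """Stream the model's reply to `message`, given the running chat history."""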
    conversation = []
    if system_message:
        # Skip an empty system message so it does not add a blank turn to the prompt.
        conversation.append({"role": "system", "content": system_message})
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        disable_compile=True,  # https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune#test_model_inference
    )
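    # Run generation in a background thread and stream partial output to the UI as it arrives.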
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(
            value="",
            label="System message",
            render=False,
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=4.0,
            step=0.1,
            value=1.0,  # default from https://huggingface.co/docs/transformers/en/main_classes/text_generation
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,  # from https://huggingface.co/google/gemma-3-270m-it/blob/main/generation_config.json
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=64,  # from https://huggingface.co/google/gemma-3-270m-it/blob/main/generation_config.json
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,  # default from https://huggingface.co/docs/transformers/en/main_classes/text_generation
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hi! How are you?"],
        ["Pros and cons of a long-term relationship. Bullet list with max 3 pros and 3 cons, concise."],
        ["How many hours does it take a man to eat a helicopter?"],
        ["How do you open a JSON file in Python?"],
        ["Make a bullet list of pros and cons of living in San Francisco. Maximum 2 pros and 2 cons."],
        ["Invent a short story with animals about the value of friendship."],
        ["Can you briefly explain what the Python programming language is?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI Research'."],
    ],
    cache_examples=False,
)
with gr.Blocks(css="style.css", fill_height=True, theme="soft") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
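# Run locally with `python app.py` (or `gradio app.py` for auto-reload); outside a
# ZeroGPU Space the @spaces.GPU decorator should be a no-op.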