"""Gradio chat demo for google/gemma-3-270m-it with streamed responses.

The UI accepts text plus optional image/video uploads; uploads are converted
into chat-template content items before generation.
"""

import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    Gemma3ForConditionalGeneration,
    TextIteratorStreamer,
)
import spaces
import time

TITLE = " google/gemma-3-270m-it "
DESCRIPTION = """ It's so small """
IS_RTL = False
TEXT_ALIGN = "left"

# Tag a user may type in the prompt to place the next uploaded file inline.
# Splitting on this tag (rather than on "", which raises ValueError) is what
# lets format_message interleave text and uploads.
IMAGE_PLACEHOLDER = "<image>"
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp"}
VIDEO_SUFFIXES = {".mp4", ".mov"}

# model config
model_name = "google/gemma-3-270m-it"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="eager",
).eval()
processor = AutoProcessor.from_pretrained(model_name)


# TODO: attach per-frame timestamps to the sampled frames.
def extract_video_frames(video_path, num_frames=8):
    """Sample up to ``num_frames`` RGB frames from a video file.

    Frames are read at positions ``0, step, 2*step, ...`` where ``step`` is
    ``total_frames // num_frames`` (at least 1). Positions that fail to decode
    are skipped, so fewer than ``num_frames`` images may be returned.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Maximum number of frames to sample; ``<= 0`` yields ``[]``.

    Returns:
        A list of ``PIL.Image.Image`` objects in RGB channel order.
    """
    if num_frames <= 0:
        return []
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            return frames
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        step = max(total_frames // num_frames, 1)
        for i in range(num_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; PIL expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        # Release the capture even if seeking/decoding raises.
        cap.release()
    return frames


def _file_path(file):
    """Return the filesystem path of an upload given as a str/Path, a dict with "path", or an object with ``.name``."""
    if isinstance(file, (str, Path)):
        return str(file)
    # NOTE(review): dict uploads ({"path": ...}) are assumed from Gradio's file-data format — confirm.
    if isinstance(file, dict) and "path" in file:
        return file["path"]
    return file.name


def _file_to_content(file):
    """Convert one upload into chat-template content items; unsupported suffixes yield ``[]``."""
    file_path = _file_path(file)
    suffix = Path(file_path).suffix.lower()
    if suffix in IMAGE_SUFFIXES:
        return [{"type": "image", "image": Image.open(file_path)}]
    if suffix in VIDEO_SUFFIXES:
        # A video is represented as a sequence of sampled still frames.
        return [{"type": "image", "image": frame} for frame in extract_video_frames(file_path)]
    return []


def format_message(content, files):
    """Build the content list for one user turn.

    The text is split on ``IMAGE_PLACEHOLDER``; each placeholder is replaced by
    the next upload from ``files``. Uploads left over after the text are
    appended at the end. The caller's ``files`` sequence is not modified.

    Args:
        content: User text (may be empty or None).
        files: Sequence of uploads (paths or file objects); may be empty or None.

    Returns:
        A list of ``{"type": "text", ...}`` / ``{"type": "image", ...}`` dicts.
    """
    pending = iter(files or [])
    message_content = []
    if content:
        parts = content.split(IMAGE_PLACEHOLDER)
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            # A placeholder sat between parts[i] and parts[i + 1]: consume one upload here.
            if i < len(parts) - 1:
                file = next(pending, None)
                if file is not None:
                    message_content.extend(_file_to_content(file))
    for file in pending:
        message_content.extend(_file_to_content(file))
    return message_content


def _history_content_to_items(content):
    """Convert one user history entry's content into chat-template content items."""
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    if isinstance(content, list):
        return list(content)
    # NOTE(review): Gradio's messages format is assumed to represent a file turn as a
    # ``(path, ...)`` tuple or a ``{"path": ...}`` dict — verify for the installed version.
    path = None
    if isinstance(content, tuple) and content and isinstance(content[0], str):
        path = content[0]
    elif isinstance(content, dict) and isinstance(content.get("path"), str):
        path = content["path"]
    if path is not None:
        return _file_to_content(path)
    return [{"type": "text", "text": str(content)}]


def format_conversation_history(chat_history):
    """Convert Gradio messages-format history into chat-template messages.

    Consecutive user entries are merged into a single user turn; each assistant
    entry becomes its own turn with its content stringified.

    Args:
        chat_history: List of ``{"role": ..., "content": ...}`` dicts (may be None).

    Returns:
        A list of ``{"role": ..., "content": [items]}`` messages.
    """
    messages = []
    current_user_content = []
    for item in chat_history or []:
        role = item["role"]
        content = item["content"]
        if role == "user":
            current_user_content.extend(_history_content_to_items(content))
        elif role == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages


@spaces.GPU(duration=120)
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    """Stream the model's reply to the latest user input.

    Args:
        input_data: MultimodalTextbox value (``{"text": ..., "files": [...]}``);
            any other value is stringified and treated as plain text.
        chat_history: Prior conversation in Gradio messages format.
        max_new_tokens: Generation length cap (cast to int).
        system_prompt: Optional system instruction; omitted when empty.
        temperature, top_p, repetition_penalty: Sampling settings for ``model.generate``.
        top_k: Top-k sampling cutoff (cast to int).

    Yields:
        The accumulated reply text after each streamed chunk.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread is
        re-raised here after the stream has been closed.
    """
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files") or []
    else:
        text = str(input_data)
        files = []

    new_message_content = format_message(text, files)
    system_message = (
        [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
        if system_prompt
        else []
    )
    messages = system_message + format_conversation_history(chat_history)
    # Fold the new input into a trailing user turn so user/assistant roles keep alternating.
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_message_content)
    else:
        messages.append({"role": "user", "content": new_message_content})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        # Gradio sliders may deliver floats; these settings must be integers.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
    )

    errors = []

    def _generate():
        # Without this, an exception inside generate() would never signal the
        # streamer's end, and the loop below would block forever.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    outputs = []
    for chunk in streamer:
        outputs.append(chunk)
        yield "".join(outputs)
    thread.join()
    if errors:
        raise errors[0]


chat_interface = gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(rtl=IS_RTL, show_copy_button=True, type="messages"),
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            value="You are a very helpful multimodal assistant",
            lines=4,
            placeholder="Change the settings",
            text_align=TEXT_ALIGN,
            rtl=IS_RTL,
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.2),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.4),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=30),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
    ],
    examples=[
        [{"text": "Write a poem which describes potatoes"}],
    ],
    textbox=gr.MultimodalTextbox(
        rtl=IS_RTL,
        label="input",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="Input text, Any image or video will be ignored",
    ),
    cache_examples=False,
    type="messages",
    fill_height=True,
    stop_btn="Stop",
    css_paths=["style.css"],
    multimodal=True,
    title=TITLE,
    description=DESCRIPTION,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    chat_interface.queue(max_size=20).launch()