import gradio as gr
from gradio_i18n import Translate, gettext as _
import io
from PIL import Image
import os
import base64
from openai import OpenAI

# Deployment flags, each enabled by setting the corresponding environment variable to "1".
huggingface_spaces = os.environ.get("HUGGINGFACE_SPACES") == "1"
local = os.environ.get("LOCAL") == "1"
pyinstaller = os.environ.get("PYINSTALLER") == "1"

default_img = None
default_engine = "local" if pyinstaller else "api"
default_base_url = "https://openrouter.ai/api/v1"
default_api_model = "google/gemma-3-27b-it"
model_id = "google/gemma-3-4b-it"

# Load the local Gemma 3 model only in deployments that can run it.
if huggingface_spaces or local or pyinstaller:
    from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
    import torch
    from threading import Thread

    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, device_map="auto"
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id)

generate_kwargs = {
    'max_new_tokens': 1000,
    'do_sample': True,
    'temperature': 1.0
}

analytics_code = """
"""

lang_store = {
    "und": {
        "confirm": "Confirm",
        "default_description": "",
        "additional_description": "Character description (optional)",
        "description_placeholder": "Information that is not shown in the reference sheet, such as the character's name, personality, backstory and speech habits.",
        "more_imgs_tab": "More reference images",
        "more_imgs": "More reference images of the character (optional)",
        "title": """
RefSheet Chat is open-sourced, developed by snowkylin, and powered by Gemma 3
" }, "zh": { "confirm": "确认", "default_description": "", "additional_description": "角色文字描述(可选)", "description_placeholder": "未在设定图中包含的角色信息,可以包括角色姓名、性格、言语习惯、过往经历等。", "more_imgs_tab": "额外角色参考图", "more_imgs": "额外角色参考图(可选,可上传多张)", "title": """RefSheet Chat 是开源的,由 snowkylin 开发,由开源的 Gemma 3 驱动
""" }, } def encode_img(filepath, thumbnail=(896, 896)): more_img = Image.open(filepath) more_img = more_img.convert('RGB') more_img.thumbnail(thumbnail) buffer = io.BytesIO() more_img.save(buffer, "JPEG", quality=60) encoded_img = "data:image/jpeg;base64," + base64.b64encode(buffer.getvalue()).decode("utf-8") return encoded_img def get_init_prompt(img, description, more_imgs, character_language): prompt = _("prompt") % _(character_language) if description != "": prompt += "\n" + _("additional_info_prompt") + description if more_imgs is None: more_imgs = [] if len(more_imgs) > 0: prompt += "\n" + _("additional_reference_images_prompt") content = [ {"type": "image", "url": encode_img(img)}, {"type": "text", "text": prompt} ] + [{"type": "image", "url": encode_img(filepath)} for filepath in more_imgs] return [ { "role": "user", "content": content } ] def generate(history, engine, base_url, api_model, api_key): if engine == 'local': inputs = processor.apply_chat_template( history, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ).to(model.device, dtype=torch.bfloat16) streamer = TextIteratorStreamer(processor, skip_prompt=True) with torch.inference_mode(): thread = Thread(target=model.generate, kwargs=dict(**inputs, **generate_kwargs, streamer=streamer)) thread.start() generated_text = "" for new_text in streamer: generated_text += new_text yield generated_text elif engine == 'api': for item in history: for item_i in item['content']: if item_i['type'] == 'image': item_i['type'] = 'image_url' item_i['image_url'] = {'url': item_i['url']} del item_i['url'] if base_url == default_base_url and api_model == default_api_model and api_key == "": api_key = os.environ['OPENROUTER_TOKEN'] client = OpenAI(base_url=base_url, api_key=api_key) stream = client.chat.completions.create( model=api_model, messages=history, stream=True, temperature=generate_kwargs['temperature'] ) collected_text = "" for chunk in stream: delta = chunk.choices[0].delta if delta.content: collected_text += delta.content yield collected_text def prefill_chatbot(img, description, more_imgs, character_language, engine, base_url, api_model, api_key): history = get_init_prompt(img, description, more_imgs, character_language) ret = [{'role': 'assistant', 'content': ""}] for generated_text in generate(history, engine, base_url, api_model, api_key): ret[0]['content'] = generated_text yield ret def response(message, history: list, img, description, more_imgs, character_language, engine, base_url, api_model, api_key): history = [{"role": item["role"], "content": [{"type": "text", "text": item["content"]}]} for item in history] history = get_init_prompt(img, description, more_imgs, character_language) + history history.append( {"role": "user", "content": [{"type": "text", "text": message}]} ) for generated_text in generate(history, engine, base_url, api_model, api_key): yield generated_text def set_default_character_language(request: gr.Request): if request.headers["Accept-Language"].split(",")[0].lower().startswith("zh"): default_language = lang_store['zh']['default_language'] else: default_language = lang_store['und']['default_language'] return gr.update(value=default_language) with gr.Blocks(title="Chat with a character via reference sheet!") as demo: with Translate(lang_store) as lang: gr.Markdown(_("title_pyinstaller" if pyinstaller else "title"), sanitize_html=False) img = gr.Image(type="filepath", value=default_img, label=_("upload"), render=False) description = gr.TextArea( 
value=_("default_description"), label=_("additional_description"), placeholder=_("description_placeholder"), render=False ) character_language = gr.Dropdown( choices=[ (_("en"), "en"), (_("zh"), "zh"), (_("zh-Hant"), "zh-Hant"), (_("ja"), "ja"), (_("ko"), "ko"), (_("fr"), "fr"), (_("de"), "de"), (_("es"), "es"), (_("ru"), "ru"), (_("ar"), "ar"), ], label=_("character_language"), render=False, interactive = True ) more_imgs = gr.Files( label=_("more_imgs"), file_types=["image"], render=False ) confirm_btn = gr.Button(_("confirm"), render=False, variant='primary') chatbot = gr.Chatbot(height=600, type='messages', label=_("chatbox"), render=False) engine = gr.Radio( choices=[ (_("local"), "local"), (_("API"), "api") ], value=default_engine, label=_("method"), render=False, interactive=True ) base_url = gr.Textbox(label=_("base_url"), render=False, value=default_base_url) api_model = gr.Textbox(label=_("api_model"), render=False, value=default_api_model) api_key = gr.Textbox(label=_("api_key"), render=False) with gr.Row(): with gr.Column(scale=4): img.render() with gr.Tab(_("description")): description.render() character_language.render() with gr.Tab(_("more_imgs_tab")): more_imgs.render() if local or huggingface_spaces: with gr.Tab(_("more_options")): engine.render() base_url.render() api_model.render() api_key.render() else: engine.visible = False base_url.visible = False api_model.visible = False api_key.visible = False engine.render() base_url.render() api_model.render() api_key.render() confirm_btn.render() with gr.Column(scale=6): chat = gr.ChatInterface( response, chatbot=chatbot, type="messages", additional_inputs=[img, description, more_imgs, character_language, engine, base_url, api_model, api_key], ) confirm_btn.click(prefill_chatbot, [img, description, more_imgs, character_language, engine, base_url, api_model, api_key], chat.chatbot)\ .then(lambda x: x, chat.chatbot, chat.chatbot_value) gr.HTML(analytics_code) gr.Markdown(_("author")) demo.load(set_default_character_language, None, character_language) if __name__ == "__main__": demo.launch(prevent_thread_lock=True if pyinstaller else False) if pyinstaller: import webview window = webview.create_window("RefSheet Chat", demo.local_url, maximized=True) webview.start()