import gradio as gr
from gradio_i18n import Translate, gettext as _
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import torch
from threading import Thread
import requests
import json
import os
import base64
from openai import OpenAI
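
# Defaults: no preloaded reference image, OpenRouter as the OpenAI-compatible endpoint,
# and a free Gemma 3 27B model for the API backend.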
default_img = None
default_base_url = "https://openrouter.ai/api/v1"
default_api_model = "google/gemma-3-27b-it:free"
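
# Local Gemma 3 4B model, loaded once at startup for the "Local" backend.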
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, device_map="auto"
).eval()
processor = AutoProcessor.from_pretrained(model_id)
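
# Generation settings for the local model; temperature is reused for the API backend.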
generate_kwargs = {
'max_new_tokens': 1000,
'do_sample': True,
'temperature': 1.0
}
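
# gradio_i18n translation table; "und" is the fallback locale, "zh" provides Chinese strings.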
lang_store = {
"und": {
"confirm": "Confirm",
"default_description": "",
"additional_description": "Character description (optional)",
"title": "<h1>Chat with a character via reference sheet!</h1>",
"upload": "Upload the reference sheet of your character here",
"prompt": "You are the character in the image. Start without confirmation.",
"additional_info_prompt": "Additional info: ",
"description": "Description",
"more_options": "More Options",
"method": "Method",
"base_url": "Base URL",
"api_model": "API Model",
"api_key": "API Key",
"local": "Local",
"chatbox": "Chat Box"
},
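    # Simplified Chinese locale; entries mirror the "und" strings above.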
"zh": {
"confirm": "确认",
"default_description": "",
"additional_description": "角色描述(可选)",
"title": "<h1>与设定图中的角色聊天!</h1>",
"upload": "在这里上传角色设定图",
"prompt": "你的身份是图中的角色,使用中文。无需确认。",
"additional_info_prompt": "补充信息:",
"description": "角色描述",
"more_options": "更多选项",
"method": "方法",
"base_url": "API 地址",
"api_model": "API 模型",
"api_key": "API Key",
"local": "本地",
"chatbox": "聊天窗口"
},
}
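
# Build the initial multimodal message: the character reference sheet plus the
# role-play prompt and optional extra description.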
def get_init_prompt(img, description):
prompt = _("prompt")
if description != "":
prompt += _("additional_info_prompt") + description
return [
{
"role": "user",
"content": [
{"type": "image", "url": img},
{"type": "text", "text": prompt}
]
}
]
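
# Stream a reply from either the local Gemma 3 model or an OpenAI-compatible API,
# yielding the accumulated text as it grows.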
def generate(history, engine, base_url, api_model, api_key):
if engine == 'local':
inputs = processor.apply_chat_template(
history, add_generation_prompt=True, tokenize=True,
return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)
streamer = TextIteratorStreamer(processor, skip_prompt=True)
        # torch.inference_mode() is thread-local, so enter it inside the worker thread where generate() runs
        def run_generation():
            with torch.inference_mode():
                model.generate(**inputs, **generate_kwargs, streamer=streamer)
        Thread(target=run_generation).start()
generated_text = ""
for new_text in streamer:
generated_text += new_text
yield generated_text
elif engine == 'api':
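        # Convert transformers-style image entries into OpenAI image_url entries carrying base64 data URIs.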
for item in history:
for item_i in item['content']:
if item_i['type'] == 'image':
item_i['type'] = 'image_url'
with open(item_i['url'], "rb") as image_file:
data = base64.b64encode(image_file.read()).decode("utf-8")
item_i['image_url'] = {'url': 'data:image/jpeg;base64,' + data}
del item_i['url']
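        # With unchanged defaults and no user-supplied key, fall back to the OPENROUTER_TOKEN environment variable.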
if base_url == default_base_url and api_model == default_api_model and api_key == "":
api_key = os.environ['OPENROUTER_TOKEN']
client = OpenAI(base_url=base_url, api_key=api_key)
stream = client.chat.completions.create(
model=api_model,
messages=history,
stream=True,
temperature=generate_kwargs['temperature']
)
collected_text = ""
for chunk in stream:
delta = chunk.choices[0].delta
if delta.content:
collected_text += delta.content
yield collected_text
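
# Generate the character's opening message when the user confirms the reference sheet.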
def prefill_chatbot(img, description, engine, base_url, api_model, api_key):
history = get_init_prompt(img, description)
ret = [{'role': 'assistant', 'content': ""}]
for generated_text in generate(history, engine, base_url, api_model, api_key):
ret[0]['content'] = generated_text
yield ret
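
# ChatInterface handler: prepend the image prompt to the text-only chat history and stream the next reply.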
def response(message, history: list, img, description, engine, base_url, api_model, api_key):
history = [{"role": item["role"], "content": [{"type": "text", "text": item["content"]}]} for item in history]
history = get_init_prompt(img, description) + history
history.append(
{"role": "user", "content": [{"type": "text", "text": message}]}
)
for generated_text in generate(history, engine, base_url, api_model, api_key):
yield generated_text
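
# UI layout: components are created with render=False so they can be placed inside the rows, columns and tabs below.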
with gr.Blocks(title="Chat with a character via reference sheet!") as demo:
with Translate(lang_store) as lang:
gr.HTML(_("title"))
img = gr.Image(type="filepath", value=default_img, label=_("upload"), render=False)
description = gr.TextArea(value=_("default_description"), label=_("additional_description"), render=False)
confirm_btn = gr.Button(_("confirm"), render=False)
chatbot = gr.Chatbot(height=600, type='messages', label=_("chatbox"), render=False)
engine = gr.Radio([(_('local'), 'local'), ('API', 'api')],
value='api', label=_("method"), render=False, interactive=True)
base_url = gr.Textbox(label=_("base_url"), render=False, value=default_base_url)
api_model = gr.Textbox(label=_("api_model"), render=False, value=default_api_model)
api_key = gr.Textbox(label=_("api_key"), render=False)
with gr.Row():
with gr.Column(scale=4):
img.render()
with gr.Tab(_("description")):
description.render()
with gr.Tab(_("more_options")):
engine.render()
base_url.render()
api_model.render()
api_key.render()
confirm_btn.render()
with gr.Column(scale=6):
chat = gr.ChatInterface(
response,
chatbot=chatbot,
type="messages",
additional_inputs=[img, description, engine, base_url, api_model, api_key],
)
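
        # On Confirm, prefill the chatbot with the character's first message, then sync it into the ChatInterface value.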
confirm_btn.click(prefill_chatbot, [img, description, engine, base_url, api_model, api_key], chat.chatbot)\
.then(lambda x: x, chat.chatbot, chat.chatbot_value)
if __name__ == "__main__":
demo.launch() |