import gradio as gr
from gradio_i18n import Translate, gettext as _
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import torch
from threading import Thread
import os
import base64
from openai import OpenAI

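# Defaults for the UI: no preloaded reference image, and the API backend points
# at OpenRouter's OpenAI-compatible endpoint with a free hosted Gemma 3 27B model.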
default_img = None
default_base_url = "https://openrouter.ai/api/v1"
default_api_model = "google/gemma-3-27b-it:free"

model_id = "google/gemma-3-4b-it"

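# Load the local Gemma 3 checkpoint once at startup; device_map="auto" places the
# weights on whatever GPU/CPU devices are available.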
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto"
).eval()

processor = AutoProcessor.from_pretrained(model_id)

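# Sampling settings for the local model; the API path reuses only the temperature.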
generate_kwargs = {
    'max_new_tokens': 1000,
    'do_sample': True,
    'temperature': 1.0
}

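# UI strings for gradio_i18n: "und" holds the default (language-undetermined)
# strings and "zh" the Chinese translations. The role-play prompt is localized
# too, so the character answers in the interface language.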
lang_store = {
    "und": {
        "confirm": "Confirm",
        "default_description": "",
        "additional_description": "Character description (optional)",
        "title": "<h1>Chat with a character via reference sheet!</h1>",
        "upload": "Upload the reference sheet of your character here",
        "prompt": "You are the character in the image. Start without confirmation.",
        "additional_info_prompt": "Additional info: ",
        "description": "Description",
        "more_options": "More Options",
        "method": "Method",
        "base_url": "Base URL",
        "api_model": "API Model",
        "api_key": "API Key",
        "local": "Local",
        "chatbox": "Chat Box"
    },
    "zh": {
        "confirm": "确认",
        "default_description": "",
        "additional_description": "角色描述(可选)",
        "title": "<h1>与设定图中的角色聊天!</h1>",
        "upload": "在这里上传角色设定图",
        "prompt": "你的身份是图中的角色,使用中文。无需确认。",
        "additional_info_prompt": "补充信息:",
        "description": "角色描述",
        "more_options": "更多选项",
        "method": "方法",
        "base_url": "API 地址",
        "api_model": "API 模型",
        "api_key": "API Key",
        "local": "本地",
        "chatbox": "聊天窗口"
    },
}

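# Build the opening user turn: the uploaded reference sheet plus the role-play
# instruction, with the optional character description appended.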
def get_init_prompt(img, description):
    prompt = _("prompt")
    if description != "":
        prompt += _("additional_info_prompt") + description
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": img},
                {"type": "text", "text": prompt}
            ]
        }
    ]


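# Stream a reply. engine == 'local' runs the in-process Gemma 3 model through a
# TextIteratorStreamer; engine == 'api' rewrites image entries into base64 data
# URLs and streams chunks from an OpenAI-compatible chat completions endpoint.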
def generate(history, engine, base_url, api_model, api_key):
    if engine == 'local':
        inputs = processor.apply_chat_template(
            history, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)

        streamer = TextIteratorStreamer(processor, skip_prompt=True)

        # torch.inference_mode() is thread-local, so enter it inside the worker
        # thread that actually calls model.generate.
        def run_generation():
            with torch.inference_mode():
                model.generate(**inputs, **generate_kwargs, streamer=streamer)

        thread = Thread(target=run_generation)
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
        thread.join()
    elif engine == 'api':
        for item in history:
            for item_i in item['content']:
                if item_i['type'] == 'image':
                    item_i['type'] = 'image_url'
                    with open(item_i['url'], "rb") as image_file:
                        data = base64.b64encode(image_file.read()).decode("utf-8")
                    item_i['image_url'] = {'url': 'data:image/jpeg;base64,' + data}
                    del item_i['url']
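        # With the defaults unchanged and no key supplied, fall back to the
        # OPENROUTER_TOKEN environment variable.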
        if base_url == default_base_url and api_model == default_api_model and api_key == "":
            api_key = os.environ['OPENROUTER_TOKEN']
        client = OpenAI(base_url=base_url, api_key=api_key)
        stream = client.chat.completions.create(
            model=api_model,
            messages=history,
            stream=True,
            temperature=generate_kwargs['temperature']
        )
        collected_text = ""
        for chunk in stream:
            delta = chunk.choices[0].delta
            if delta.content:
                collected_text += delta.content
                yield collected_text


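# Generate the character's opening message when the Confirm button is clicked,
# streaming it into the chatbot.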
def prefill_chatbot(img, description, engine, base_url, api_model, api_key):
    history = get_init_prompt(img, description)

    ret = [{'role': 'assistant', 'content': ""}]
    for generated_text in generate(history, engine, base_url, api_model, api_key):
        ret[0]['content'] = generated_text
        yield ret


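# ChatInterface callback: rebuild the multimodal history (init prompt + previous
# text-only turns + the new user message) and stream the assistant reply.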
def response(message, history: list, img, description, engine, base_url, api_model, api_key):
    history = [{"role": item["role"], "content": [{"type": "text", "text": item["content"]}]} for item in history]
    history = get_init_prompt(img, description) + history
    history.append(
        {"role": "user", "content": [{"type": "text", "text": message}]}
    )
    for generated_text in generate(history, engine, base_url, api_model, api_key):
        yield generated_text


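# Layout: reference-sheet upload and options on the left, chat on the right.
# Components are created with render=False so they can be placed in the layout
# below and also passed to ChatInterface as additional inputs.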
with gr.Blocks(title="Chat with a character via reference sheet!") as demo:
    with Translate(lang_store) as lang:
        gr.HTML(_("title"))
        img = gr.Image(type="filepath", value=default_img, label=_("upload"), render=False)
        description = gr.TextArea(value=_("default_description"), label=_("additional_description"), render=False)
        confirm_btn = gr.Button(_("confirm"), render=False)
        chatbot = gr.Chatbot(height=600, type='messages', label=_("chatbox"), render=False)
        engine = gr.Radio([(_('local'), 'local'), ('API', 'api')],
                          value='api', label=_("method"), render=False, interactive=True)
        base_url = gr.Textbox(label=_("base_url"), render=False, value=default_base_url)
        api_model = gr.Textbox(label=_("api_model"), render=False, value=default_api_model)
        api_key = gr.Textbox(label=_("api_key"), render=False)
        with gr.Row():
            with gr.Column(scale=4):
                img.render()
                with gr.Tab(_("description")):
                    description.render()
                with gr.Tab(_("more_options")):
                    engine.render()
                    base_url.render()
                    api_model.render()
                    api_key.render()
                confirm_btn.render()
            with gr.Column(scale=6):
                chat = gr.ChatInterface(
                    response,
                    chatbot=chatbot,
                    type="messages",
                    additional_inputs=[img, description, engine, base_url, api_model, api_key],
                )
        confirm_btn.click(prefill_chatbot, [img, description, engine, base_url, api_model, api_key], chat.chatbot)\
            .then(lambda x: x, chat.chatbot, chat.chatbot_value)


if __name__ == "__main__":
    demo.launch()