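"""Qwen × SimSimi hybrid chat demo.

Routes chitchat to the SimSimi SmallTalk API and task-style queries to a local
Qwen2.5 model, behind a simple Gradio UI.
"""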
import os
import re
import asyncio

import httpx
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "huihui-ai/Qwen2.5-7B-Instruct-abliterated-v3"

# SimSimi SmallTalk (Workshop API, version 190410)
SIMSIMI_ENDPOINT = "https://wsapi.simsimi.com/190410/talk"
SIMSIMI_API_KEY = os.getenv("SIMSIMI_API_KEY", "").strip()
SIMSIMI_LANG = os.getenv("SIMSIMI_LANG", "ch").strip()  # "ch" = Chinese
SIMSIMI_BAD_MAX = float(os.getenv("SIMSIMI_BAD_MAX", "0.30"))
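
# Example environment setup (shell; the key value is a placeholder):
#   export SIMSIMI_API_KEY="your-project-key"
#   export SIMSIMI_LANG="ch"
#   export SIMSIMI_BAD_MAX="0.30"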

# Prefer half precision on accelerators, full precision on CPU. float16 is the
# safer choice on Apple MPS, where bfloat16 support is still incomplete.
dtype = (
    torch.float16 if torch.cuda.is_available() or torch.backends.mps.is_available()
    else torch.float32
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map="auto",
    trust_remote_code=True,
)

# Some Qwen tokenizers ship without a pad token; reuse EOS so generate() can pad.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

SYSTEM_PROMPT = "You are a helpful, concise, and friendly AI assistant. Keep answers direct and useful."


def qwen_generate(messages, max_new_tokens=512, temperature=0.7, top_p=0.9):
    """Generate a reply with the local Qwen model.

    messages: list[{"role": "system"|"user"|"assistant", "content": str}]
    """
    try:
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Keep only the newly generated tokens, not the echoed prompt.
        gen_ids = outputs[0][inputs["input_ids"].shape[1]:]
        return tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    except Exception as e:
        return f"[Qwen generation error] {e}"


async def simsimi_smalltalk(user_text: str, lang: str = None, bad_max: float = None, timeout: float = 10.0):
    """Call the SimSimi SmallTalk API.

    - Endpoint: https://wsapi.simsimi.com/190410/talk
    - Header:   x-api-key: <project key>
    - Body:     {"utext": "...", "lang": "ch", "atext_bad_prob_max": 0.3}

    Returns (reply, error); exactly one of the two is None.
    """
    if not SIMSIMI_API_KEY:
        return None, "SIMSIMI_API_KEY is not configured"

    lang = (lang or SIMSIMI_LANG or "ch").strip()
    bad = SIMSIMI_BAD_MAX if bad_max is None else float(bad_max)

    headers = {
        "Content-Type": "application/json",
        "x-api-key": SIMSIMI_API_KEY,
    }
    payload = {
        "utext": user_text,
        "lang": lang,
        "atext_bad_prob_max": bad,
    }

    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            r = await client.post(SIMSIMI_ENDPOINT, headers=headers, json=payload)
            r.raise_for_status()
            data = r.json()

        reply = data.get("atext")
        if not reply:
            # Some error payloads carry the message under other keys.
            reply = data.get("response") or data.get("msg")
        return reply, None
    except Exception as e:
        return None, f"SimSimi request failed: {e}"


# Heuristic routing patterns: chitchat-flavored vs. task-flavored messages.
CHITCHAT_HINTS = [
    r"讲个(笑话|段子)", r"无聊", r"随便聊",
    r"你(会|能)吐槽", r"来点梗", r"夸我", r"损我一下",
    r"夸夸我", r"给我一句毒舌",
]
TASK_HINTS = [
    r"(怎么|如何|为何|为什么|为啥)",
    r"(写|生成|改|优化).{0,12}(代码|脚本|文案|提示词|SQL|正则)",
    r"(安装|配置|部署|调试|报错)",
    r"(引用|数据|来源|对比|表格)",
]


def is_chitchat(text: str) -> bool:
    """Classify a message as chitchat (route to SimSimi) or a task (route to Qwen)."""
    if re.search("|".join(TASK_HINTS), text, flags=re.I):
        return False
    if re.search("|".join(CHITCHAT_HINTS), text, flags=re.I):
        return True
    # Fallback: short messages without sentence punctuation read as chitchat.
    return len(text) <= 22 and not re.search(r"[,。!?.!?]", text)
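
# Illustrative cases (derived from the patterns above, not exhaustive):
#   is_chitchat("讲个笑话")         -> True   (chitchat hint: "tell me a joke")
#   is_chitchat("怎么部署这个项目") -> False  (task hint: "how do I deploy this")
#   is_chitchat("在吗")             -> True   (short, no sentence punctuation)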


async def hybrid_reply(history_messages, user_text, mode: str, lang: str, bad_max: float):
    """Produce one reply according to the selected mode.

    - "Auto hybrid": chitchat goes to SimSimi and tasks go to Qwen; when both
      fit, answer with Qwen, then let SimSimi add a one-line tail.
    - "Qwen only"
    - "SimSimi only"
    """
    lang = (lang or SIMSIMI_LANG or "ch").strip()
    bad_max = SIMSIMI_BAD_MAX if bad_max is None else float(bad_max)

    if mode == "SimSimi only":
        sim, err = await simsimi_smalltalk(user_text, lang=lang, bad_max=bad_max)
        return sim or f"[No SimSimi reply] {err or 'unknown error'}"

    if mode == "Qwen only":
        return qwen_generate(history_messages + [{"role": "user", "content": user_text}])

    # Auto hybrid: chitchat goes to SimSimi, with Qwen as the fallback.
    if is_chitchat(user_text):
        sim, err = await simsimi_smalltalk(user_text, lang=lang, bad_max=bad_max)
        if sim:
            return sim
        return qwen_generate(history_messages + [{"role": "user", "content": user_text}])

    # Task-style message: answer with Qwen, then ask SimSimi for a short
    # humorous closing line ("wrap this up in one short, funny sentence").
    base = qwen_generate(history_messages + [{"role": "user", "content": user_text}])
    sim_tail, _ = await simsimi_smalltalk(f"用一句简短幽默的方式做个收尾:{user_text}", lang=lang, bad_max=bad_max)
    if sim_tail:
        return f"{base}\n\n—— {sim_tail}"
    return base
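
# Standalone smoke test (needs a configured SIMSIMI_API_KEY; run in a REPL):
#   print(asyncio.run(hybrid_reply(
#       [{"role": "system", "content": SYSTEM_PROMPT}],
#       "讲个笑话", mode="Auto hybrid", lang="ch", bad_max=0.3)))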


with gr.Blocks(css="#chatbot {height: 560px}") as demo:
    gr.Markdown("## Qwen × SimSimi Hybrid Chat\n")

    mode_dd = gr.Dropdown(
        choices=["Auto hybrid", "Qwen only", "SimSimi only"],
        value="Auto hybrid",
        label="Chat mode",
    )

    chatbox = gr.Chatbot(elem_id="chatbot")
    user_in = gr.Textbox(placeholder="Type a message, then press Submit…", lines=2)
    submit_btn = gr.Button("Submit", variant="primary")
    clear_btn = gr.Button("Clear chat")

    # The full OpenAI-style message list (including the system prompt) lives in
    # session state; the Chatbot component only renders user/assistant pairs.
    state_msgs = gr.State([{"role": "system", "content": SYSTEM_PROMPT}])

    async def respond(user_text, history, messages, mode):
        user_text = (user_text or "").strip()
        if not user_text:
            # Nothing to send: leave the chat unchanged and clear the textbox.
            return gr.update(), messages, ""

        lang = SIMSIMI_LANG
        bad_max = SIMSIMI_BAD_MAX

        messages = list(messages) if messages else [{"role": "system", "content": SYSTEM_PROMPT}]
        messages.append({"role": "user", "content": user_text})

        # hybrid_reply appends the new user turn itself, so pass the history
        # without it to avoid sending the message to Qwen twice.
        reply = await hybrid_reply(messages[:-1], user_text, mode=mode, lang=lang, bad_max=bad_max)

        messages.append({"role": "assistant", "content": reply})
        history = (history or []) + [[user_text, reply]]
        return history, messages, ""

    def clear_all():
        return [], [{"role": "system", "content": SYSTEM_PROMPT}]

    submit_btn.click(
        respond,
        inputs=[user_in, chatbox, state_msgs, mode_dd],
        outputs=[chatbox, state_msgs, user_in],
    )
    clear_btn.click(
        clear_all,
        inputs=None,
        outputs=[chatbox, state_msgs],
    )


if __name__ == "__main__":
    demo.launch()
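    # To expose the app beyond localhost, something like:
    #   demo.launch(server_name="0.0.0.0", server_port=7860)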