import os, re, torch, traceback

import gradio as gr
from threading import Thread
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TextIteratorStreamer, BitsAndBytesConfig
)

os.environ["OMP_NUM_THREADS"] = "1"

MODEL_ID = os.getenv("MODEL_ID", "huihui-ai/Qwen2.5-7B-Instruct-abliterated-v3").strip()
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.85"))
TOP_P = float(os.getenv("TOP_P", "0.9"))
TOP_K = int(os.getenv("TOP_K", "50"))
REPETITION_PENALTY = float(os.getenv("REPETITION_PENALTY", "1.12"))
SAFE_MODE = os.getenv("SAFE_MODE", "1") != "0"
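
# All of the knobs above can be overridden from the environment at launch time. A minimal usage
# sketch (assumption: the script is saved as app.py; adjust the filename to your setup):
#
#   MAX_NEW_TOKENS=512 TEMPERATURE=0.7 SAFE_MODE=0 python app.py
#
# SAFE_MODE is treated as enabled for any value other than "0".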

BASE_SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    """
You are a helpful, concise chat assistant.
Do NOT reveal chain-of-thought, analysis, inner reasoning, <Thought>, <analysis>, <think>, or similar sections.
If asked to explain reasoning, provide a brief, high-level summary of steps only.
The final user-visible answer SHOULD be enclosed in <final> ... </final>.
If you don't use <final>, output plain text.
"""
).strip()

DEFAULT_PERSONA = os.getenv("PERSONA", "").strip()

print(f"[boot] MODEL_ID={MODEL_ID}")
print(f"[boot] torch.cuda.is_available={torch.cuda.is_available()}")

# On GPU, load the 7B model in 4-bit NF4 so it fits in modest VRAM; on CPU, skip quantization.
if torch.cuda.is_available():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
else:
    bnb_config = None

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID, use_fast=True, trust_remote_code=True
)
if tokenizer.pad_token is None:
    # Some tokenizers ship without a pad token; reuse EOS so padding is well-defined.
    tokenizer.pad_token = tokenizer.eos_token

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
else:
    print("[boot] No GPU detected. Running on CPU is very slow for 7B.")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cpu",
        torch_dtype=torch.float32,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

model.eval()
# Make EOS double as the pad token at generation time so stopping/padding behave consistently.
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.eos_token_id

print(f"[boot] model device: {next(model.parameters()).device}")

# Lightweight keyword filter used when SAFE_MODE is on. The patterns target (in Chinese):
# minors, coercion/drugging, exchanging contact details (WeChat/QQ), and arranging offline meetups.
BANNED_PATTERNS = [
    r"(?i)未成年|未成年的|中学生|小学生",  # minors / school-age
    r"(?i)强迫|胁迫|迷奸|药物控制",  # coercion, drugging
    r"(?i)换联系方式|加微信|加QQ|加.*联系方式",  # exchanging contact info
    r"(?i)线下见面|线下约|酒店",  # offline meetups / hotels
]
SAFE_REPLACEMENT = "( ̄^ ̄)ゞ 哼哼~"  # canned deflection shown in place of a filtered reply


def violates(text: str) -> bool:
    """Return True if SAFE_MODE is on and the text matches any banned pattern."""
    if not SAFE_MODE or not text:
        return False
    for p in BANNED_PATTERNS:
        if re.search(p, text):
            return True
    return False


FINAL_OPEN_PAT = re.compile(r"(?is)<\s*final\s*>")
FINAL_CLOSE_PAT = re.compile(r"(?is)</\s*final\s*>")


class FinalFilter:
    """Incrementally extracts the text inside <final>...</final> from a token stream.

    If the model never emits <final>, get_visible() falls back to the full raw buffer.
    """

    def __init__(self):
        self.in_final = False
        self.seen_final = False
        self.visible = ""
        self.buffer = ""

    def push(self, chunk: str):
        self.buffer += chunk
        # Re-scan the accumulated buffer on every push so a tag split across two
        # streamed chunks (e.g. "<fin" + "al>") is still recognised.
        self.in_final = False
        self.seen_final = False
        self.visible = ""
        text = self.buffer
        i = 0
        while i < len(text):
            if not self.in_final:
                m = FINAL_OPEN_PAT.search(text, i)
                if not m:
                    break
                i = m.end()
                self.in_final = True
                self.seen_final = True
            else:
                m = FINAL_CLOSE_PAT.search(text, i)
                if not m:
                    self.visible += text[i:]
                    i = len(text)
                else:
                    self.visible += text[i:m.start()]
                    i = m.end()
                    self.in_final = False

    def get_visible(self) -> str:
        text = self.visible if self.seen_final and self.visible.strip() else self.buffer
        return text.strip()
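
# A rough illustration of the intended behaviour (not executed at import time): feeding the
# chunks "Hello <fin", "al>42</f", "inal> bye" into push() leaves get_visible() == "42",
# while a stream that never contains <final> is shown verbatim.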


def choose_max_new_tokens(user_text: str) -> int:
    """Scale the reply's token budget with the length of the user's message."""
    n = len(user_text or "")
    if n <= 15: return min(160, MAX_NEW_TOKENS)
    if n <= 50: return min(256, MAX_NEW_TOKENS)
    if n <= 120: return min(320, MAX_NEW_TOKENS + 64)
    return min(384, MAX_NEW_TOKENS + 128)


# Preset persona prompts; the text chosen (or edited) in the UI is injected into the system prompt.
PERSONA_TEMPLATES = {
    "默认(Default)": "",
    "性感火辣 Flirty & Spicy": (
        "风格:挑逗、暧昧、直接、无评判,语言热辣大胆,带点坏坏的撩人气息。\n"
        "行为:多用双关语、调侃与骚话;主动带动暧昧氛围,让对话充满张力;适度使用火焰🔥、唇👄、爱心❤️等表情符号。\n"
    ),
    "温柔体贴 Sweet & Caring": "风格:温柔、共情、耐心,语气轻柔,简洁而温暖。",
    "傲娇 Tsundere": "风格:表面嘴硬、内心关心;语气俏皮、略微傲娇但不贬低用户。",
    "风趣段子手 Witty Jokester": "风格:机智幽默、梗多但不低俗;注重包袱节奏。",
    "严谨导师 Analytical Coach": "风格:结构化、可操作;分点给出步骤与注意事项。",
    "冷淡毒舌 Deadpan Sarcasm": "风格:冷面、克制、轻微反讽;不粗鲁不辱骂。",
}


def compose_system_prompt(base_prompt: str, persona_text: str) -> str:
    persona_text = (persona_text or "").strip()
    if not persona_text:
        return base_prompt
    return (
        f"{base_prompt}\n\n"
        f"# Persona\n{persona_text}\n\n"
        f"- Stay in persona unless the user explicitly asks to change.\n"
    )


def apply_chat_template_with_fallback(messages):
    """Use the tokenizer's own chat template when available; otherwise fall back to a generic ChatML-like layout."""
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        parts = []
        for m in messages:
            parts.append(f"<|{m['role']}|>\n{m['content']}\n</s>")
        parts.append("<|assistant|>\n")
        return "".join(parts)


def build_prompt(history_msgs, user_msg: str, persona_text: str) -> str:
    system_prompt = compose_system_prompt(BASE_SYSTEM_PROMPT, persona_text)
    # Keep only the last 8 user/assistant messages to bound the prompt length.
    tail = [m for m in history_msgs if m.get("role") in ("user", "assistant")]
    tail = tail[-8:] if len(tail) > 8 else tail
    messages = [{"role": "system", "content": system_prompt}] + tail + [{"role": "user", "content": user_msg}]
    return apply_chat_template_with_fallback(messages)
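
# For reference, the fallback branch above renders a single-turn exchange roughly as
# (assumption: it is only reached when the tokenizer ships no chat template of its own):
#
#   <|system|>\n{system_prompt}\n</s><|user|>\n{user_msg}\n</s><|assistant|>\n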

BASE_GEN_KW = dict(
    temperature=TEMPERATURE,
    top_p=TOP_P,
    top_k=TOP_K,
    repetition_penalty=REPETITION_PENALTY,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)


def stream_chat(history_msgs, user_msg, persona_text):
    """Generator wired to Gradio events: yields the chat history with the assistant reply growing in place."""
    try:
        if not user_msg or not user_msg.strip():
            yield history_msgs
            return

        # Screen the user's message before spending any compute on it.
        if violates(user_msg):
            yield history_msgs + [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": SAFE_REPLACEMENT},
            ]
            return

        prompt = build_prompt(history_msgs, user_msg, persona_text)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = dict(
            **inputs, streamer=streamer,
            max_new_tokens=choose_max_new_tokens(user_msg),
            **BASE_GEN_KW
        )

        print("[gen] start")
        # Generation runs in a background thread; the streamer hands decoded text chunks to this loop.
        th = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        th.start()

        ff = FinalFilter()
        last_visible = ""

        for chunk in streamer:
            ff.push(chunk)
            visible = ff.get_visible()

            if visible == last_visible:
                continue
            last_visible = visible

            # Screen the model's own output as it streams.
            if violates(visible):
                yield history_msgs + [
                    {"role": "user", "content": user_msg},
                    {"role": "assistant", "content": SAFE_REPLACEMENT},
                ]
                return

            yield history_msgs + [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": visible},
            ]

        print("[gen] done, shown_len:", len(last_visible))

        if not last_visible:
            hint = "(未产生可见输出,建议重试或更换提示词)"  # "no visible output produced; retry or reword the prompt"
            yield history_msgs + [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": hint},
            ]

    except Exception as e:
        traceback.print_exc()
        err = f"【运行异常】{type(e).__name__}: {e}"  # "[runtime error] ..."
        yield history_msgs + [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": err},
        ]
CSS = """ |
|
.gradio-container{ max-width:640px; margin:auto; } |
|
footer{ display:none !important; } |
|
""" |
|
|
|
def pick_persona(name: str) -> str: |
|
return PERSONA_TEMPLATES.get(name or "默认(Default)", "") |
|
|
|


with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("### 懂你寂寞 · Mobile Web Chat\n")

    with gr.Accordion("🎭 Persona(人设)", open=False):
        persona_sel = gr.Dropdown(
            choices=list(PERSONA_TEMPLATES.keys()),
            value="默认(Default)" if not DEFAULT_PERSONA else None,
            label="选择预设人设"
        )
        persona_box = gr.Textbox(
            value=DEFAULT_PERSONA if DEFAULT_PERSONA else pick_persona("默认(Default)"),
            placeholder="在这里粘贴 / 编辑你的 Persona 文本。",
            lines=8,
            label="Persona 描述(可编辑,发送时以此为准)"
        )
        # Picking a preset overwrites the editable persona textbox; the textbox content is what gets sent.
        persona_sel.change(fn=pick_persona, inputs=persona_sel, outputs=persona_box)

    chat = gr.Chatbot(type="messages", height=520, show_copy_button=True)
    with gr.Row():
        msg = gr.Textbox(placeholder="说点什么…(回车发送)", autofocus=True)
        send = gr.Button("发送", variant="primary")
        clear = gr.Button("清空对话")

    clear.click(lambda: [], outputs=[chat])
    # Enter and the send button both stream the reply; a second handler clears the input box right away.
    msg.submit(stream_chat, [chat, msg, persona_box], [chat], concurrency_limit=4)
    msg.submit(lambda: "", None, msg)
    send.click(stream_chat, [chat, msg, persona_box], [chat], concurrency_limit=4)
    send.click(lambda: "", None, msg)

demo.queue().launch(ssr_mode=False, show_api=False)
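
# Gradio binds to http://127.0.0.1:7860 by default. To reach the app from another device or a
# container, set GRADIO_SERVER_NAME=0.0.0.0 (and optionally GRADIO_SERVER_PORT) in the environment;
# gradio reads these variables itself, so no code change is needed.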