import os, re, torch, traceback
import gradio as gr
from threading import Thread
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
)

# ======================
# Environment fix (prevents libgomp errors)
# ======================
os.environ["OMP_NUM_THREADS"] = "1"

# ======================
# Tunable parameters (can also be overridden via the Space's Variables)
# ======================
MODEL_ID = os.getenv("MODEL_ID", "huihui-ai/Qwen2.5-7B-Instruct-abliterated-v3").strip()
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.85"))  # slightly higher to reduce repetition
TOP_P = float(os.getenv("TOP_P", "0.9"))
TOP_K = int(os.getenv("TOP_K", "50"))
REPETITION_PENALTY = float(os.getenv("REPETITION_PENALTY", "1.12"))
SAFE_MODE = os.getenv("SAFE_MODE", "1") != "0"  # 1 = enable basic filtering; set to 0 to disable

# -- Base system prompt + default persona --
BASE_SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    """
You are a helpful, concise chat assistant.
Do NOT reveal chain-of-thought, analysis, inner reasoning, or similar hidden sections.
If asked to explain reasoning, provide a brief, high-level summary of steps only.
The final user-visible answer SHOULD be enclosed in <final> ... </final>.
If you don't use <final>, output plain text.
"""
).strip()
DEFAULT_PERSONA = os.getenv("PERSONA", "").strip()

print(f"[boot] MODEL_ID={MODEL_ID}")
print(f"[boot] torch.cuda.is_available={torch.cuda.is_available()}")

# ======================
# 4-bit quantization (FP16 compute dtype on T4)
# ======================
if torch.cuda.is_available():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
else:
    bnb_config = None

# ======================
# Load tokenizer
# ======================
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=True,
    trust_remote_code=True
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ======================
# Load model
# ======================
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
else:
    print("[boot] No GPU detected. Running on CPU is very slow for 7B.")
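    # CPU fallback: load full-precision FP32 weights without 4-bit quantization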
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cpu",
        torch_dtype=torch.float32,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

model.eval()
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.eos_token_id
print(f"[boot] model device: {next(model.parameters()).device}")

# ======================
# Safety filter
# ======================
BANNED_PATTERNS = [
    r"(?i)未成年|未成年的|中学生|小学生",
    r"(?i)强迫|胁迫|迷奸|药物控制",
    r"(?i)换联系方式|加微信|加QQ|加.*联系方式",
    r"(?i)线下见面|线下约|酒店",
]
SAFE_REPLACEMENT = "( ̄^ ̄)ゞ 哼哼~"

def violates(text: str) -> bool:
    if not SAFE_MODE or not text:
        return False
    for p in BANNED_PATTERNS:
        if re.search(p, text):
            return True
    return False

# ======================
# FinalFilter: prefer showing only the <final>...</final> content;
# fall back to the full text when no <final> tag appears
# ======================
FINAL_OPEN_PAT = re.compile(r"(?is)<\s*final\s*>")
FINAL_CLOSE_PAT = re.compile(r"(?is)<\s*/\s*final\s*>")

class FinalFilter:
    def __init__(self):
        self.in_final = False
        self.seen_final = False
        self.visible = ""
        self.buffer = ""

    def push(self, chunk: str):
        self.buffer += chunk
        i = 0
        while i < len(chunk):
            if not self.in_final:
                m = FINAL_OPEN_PAT.search(chunk, i)
                if not m:
                    break
                i = m.end()
                self.in_final = True
                self.seen_final = True
            else:
                m = FINAL_CLOSE_PAT.search(chunk, i)
                if not m:
                    self.visible += chunk[i:]
                    i = len(chunk)
                else:
                    self.visible += chunk[i:m.start()]
                    i = m.end()
                    self.in_final = False

    def get_visible(self) -> str:
        text = (self.visible if self.seen_final and self.visible.strip() else self.buffer)
        return text.strip()

# ======================
# Dynamic output length
# ======================
def choose_max_new_tokens(user_text: str) -> int:
    n = len(user_text or "")
    if n <= 15:
        return min(160, MAX_NEW_TOKENS)
    if n <= 50:
        return min(256, MAX_NEW_TOKENS)
    if n <= 120:
        return min(320, MAX_NEW_TOKENS + 64)
    return min(384, MAX_NEW_TOKENS + 128)

# ======================
# Persona templates
# ======================
PERSONA_TEMPLATES = {
    "默认(Default)": "",
    "性感火辣 Flirty & Spicy": (
        "风格:挑逗、暧昧、直接、无评判,语言热辣大胆,带点坏坏的撩人气息。\n"
        "行为:多用双关语、调侃与骚话;主动带动暧昧氛围,让对话充满张力;适度使用火焰🔥、唇👄、爱心❤️等表情符号。\n"
    ),
    "温柔体贴 Sweet & Caring": "风格:温柔、共情、耐心,语气轻柔,简洁而温暖。",
    "傲娇 Tsundere": "风格:表面嘴硬、内心关心;语气俏皮、略微傲娇但不贬低用户。",
    "风趣段子手 Witty Jokester": "风格:机智幽默、梗多但不低俗;注重包袱节奏。",
    "严谨导师 Analytical Coach": "风格:结构化、可操作;分点给出步骤与注意事项。",
    "冷淡毒舌 Deadpan Sarcasm": "风格:冷面、克制、轻微反讽;不粗鲁不辱骂。",
}

def compose_system_prompt(base_prompt: str, persona_text: str) -> str:
    persona_text = (persona_text or "").strip()
    if not persona_text:
        return base_prompt
    return (
        f"{base_prompt}\n\n"
        f"# Persona\n{persona_text}\n\n"
        f"- Stay in persona unless the user explicitly asks to change.\n"
    )

# ======================
# Build prompt
# ======================
def apply_chat_template_with_fallback(messages):
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        parts = []
        for m in messages:
            parts.append(f"<|{m['role']}|>\n{m['content']}\n")
        parts.append("<|assistant|>\n")
        return "".join(parts)

def build_prompt(history_msgs, user_msg: str, persona_text: str) -> str:
    system_prompt = compose_system_prompt(BASE_SYSTEM_PROMPT, persona_text)
    tail = [m for m in history_msgs if m.get("role") in ("user", "assistant")]
    tail = tail[-8:] if len(tail) > 8 else tail
    messages = [{"role": "system", "content": system_prompt}] + tail + [{"role": "user", "content": user_msg}]
    return apply_chat_template_with_fallback(messages)

# ======================
# Generation parameters
# ======================
BASE_GEN_KW = dict(
    temperature=TEMPERATURE,
    top_p=TOP_P,
    top_k=TOP_K,
    repetition_penalty=REPETITION_PENALTY,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)

# ======================
# Streaming output
# ======================
def stream_chat(history_msgs, user_msg, persona_text):
    try:
        if not user_msg or not user_msg.strip():
            yield history_msgs
            return
        if violates(user_msg):
            yield history_msgs + [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": SAFE_REPLACEMENT},
            ]
            return

        prompt = build_prompt(history_msgs, user_msg, persona_text)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=choose_max_new_tokens(user_msg),
            **BASE_GEN_KW
        )

        print("[gen] start")
        th = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        th.start()

        ff = FinalFilter()
        last_len = 0
        for chunk in streamer:
            ff.push(chunk)
            visible = ff.get_visible()
            new_text = visible[last_len:]
            if not new_text:
                continue
            last_len = len(visible)

            if violates(visible):
                yield history_msgs + [
                    {"role": "user", "content": user_msg},
                    {"role": "assistant", "content": SAFE_REPLACEMENT},
                ]
                return

            yield history_msgs + [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": visible},
            ]

        print("[gen] done, shown_len:", last_len)
        if last_len == 0:
            hint = "(未产生可见输出,建议重试或更换提示词)"
            yield history_msgs + [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": hint},
            ]

    except Exception as e:
        traceback.print_exc()
        err = f"【运行异常】{type(e).__name__}: {e}"
        yield history_msgs + [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": err},
        ]

# ======================
# Gradio UI
# ======================
CSS = """
.gradio-container{ max-width:640px; margin:auto; }
footer{ display:none !important; }
"""

def pick_persona(name: str) -> str:
    return PERSONA_TEMPLATES.get(name or "默认(Default)", "")

with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("### 懂你寂寞 · Mobile Web Chat\n")

    with gr.Accordion("🎭 Persona(人设)", open=False):
        persona_sel = gr.Dropdown(
            choices=list(PERSONA_TEMPLATES.keys()),
            value="默认(Default)" if not DEFAULT_PERSONA else None,
            label="选择预设人设"
        )
        persona_box = gr.Textbox(
            value=DEFAULT_PERSONA if DEFAULT_PERSONA else pick_persona("默认(Default)"),
            placeholder="在这里粘贴 / 编辑你的 Persona 文本。",
            lines=8,
            label="Persona 描述(可编辑,发送时以此为准)"
        )
        persona_sel.change(fn=pick_persona, inputs=persona_sel, outputs=persona_box)

    chat = gr.Chatbot(type="messages", height=520, show_copy_button=True)
    with gr.Row():
        msg = gr.Textbox(placeholder="说点什么…(回车发送)", autofocus=True)
        send = gr.Button("发送", variant="primary")
    clear = gr.Button("清空对话")

    clear.click(lambda: [], outputs=[chat])

    msg.submit(stream_chat, [chat, msg, persona_box], [chat], concurrency_limit=4)
    msg.submit(lambda: "", None, msg)
    send.click(stream_chat, [chat, msg, persona_box], [chat], concurrency_limit=4)
    send.click(lambda: "", None, msg)

demo.queue().launch(ssr_mode=False, show_api=False)