import os, re, torch, traceback
import gradio as gr
from threading import Thread
from transformers import (
AutoTokenizer, AutoModelForCausalLM,
TextIteratorStreamer, BitsAndBytesConfig
)
# ======================
# Environment variable fix (avoids libgomp errors)
# ======================
os.environ["OMP_NUM_THREADS"] = "1"
# ======================
# Tunable parameters (can also be overridden via the Space's Variables)
# ======================
MODEL_ID = os.getenv("MODEL_ID", "huihui-ai/Qwen2.5-7B-Instruct-abliterated-v3").strip()
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.85"))  # slightly higher to reduce repetition
TOP_P = float(os.getenv("TOP_P", "0.9"))
TOP_K = int(os.getenv("TOP_K", "50"))
REPETITION_PENALTY = float(os.getenv("REPETITION_PENALTY", "1.12"))
SAFE_MODE = os.getenv("SAFE_MODE", "1") != "0"  # 1 = basic content filtering on; set to 0 to disable
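# Example (hypothetical values): in the Space's Settings -> Variables, set
#   MODEL_ID=Qwen/Qwen2.5-7B-Instruct
#   MAX_NEW_TOKENS=512
# to swap models and raise the output budget without editing this file.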
# -- Base system prompt + default persona --
BASE_SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    """
You are a helpful, concise chat assistant.
Do NOT reveal chain-of-thought, analysis, inner reasoning, <Thought>, <analysis>, <think>, or similar sections.
If asked to explain reasoning, provide a brief, high-level summary of steps only.
The final user-visible answer SHOULD be enclosed in <final> ... </final>.
If you don't use <final>, output plain text.
"""
).strip()
DEFAULT_PERSONA = os.getenv("PERSONA", "").strip()
print(f"[boot] MODEL_ID={MODEL_ID}")
print(f"[boot] torch.cuda.is_available={torch.cuda.is_available()}")
# ======================
# 4-bit quantization (FP16 compute dtype on T4)
# ======================
if torch.cuda.is_available():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
else:
    bnb_config = None
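# Rough sizing note (estimate, not measured here): with NF4 plus double
# quantization, 7B weights occupy on the order of 4 GB of VRAM; actual
# usage also depends on context length and the KV cache.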
# ======================
# Load the tokenizer
# ======================
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID, use_fast=True, trust_remote_code=True
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# ======================
# Load the model
# ======================
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
else:
    print("[boot] No GPU detected. Running on CPU is very slow for 7B.")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cpu",
        torch_dtype=torch.float32,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
model.eval()
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.eos_token_id
print(f"[boot] model device: {next(model.parameters()).device}")
# ======================
# Safety filtering
# ======================
BANNED_PATTERNS = [
    r"(?i)未成年|未成年的|中学生|小学生",        # minors
    r"(?i)强迫|胁迫|迷奸|药物控制",              # coercion / drugging
    r"(?i)换联系方式|加微信|加QQ|加.*联系方式",  # exchanging contact details
    r"(?i)线下见面|线下约|酒店",                 # offline meetups
]
SAFE_REPLACEMENT = "( ̄^ ̄)ゞ 哼哼~"  # playful deflection shown in place of a blocked reply
def violates(text: str) -> bool:
    if not SAFE_MODE or not text:
        return False
    for p in BANNED_PATTERNS:
        if re.search(p, text):
            return True
    return False
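# Illustrative behaviour (assuming SAFE_MODE is on): a message containing
# "加微信" ("add me on WeChat") matches the contact-exchange pattern, so
# violates("加微信吧") -> True, while violates("你好") -> False.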
# ======================
# FinalFilter: prefer to keep only <final>; fall back to the full text when no <final> appears
# ======================
FINAL_OPEN_PAT = re.compile(r"(?is)<\s*final\s*>")
FINAL_CLOSE_PAT = re.compile(r"(?is)</\s*final\s*>")
class FinalFilter:
    """Show only the text inside <final>...</final>; if the model never
    emits <final>, fall back to showing the full generated text."""
    def __init__(self):
        self.seen_final = False
        self.buffer = ""
    def push(self, chunk: str):
        self.buffer += chunk
    def get_visible(self) -> str:
        # Re-scan the whole buffer on every call so that tags split across
        # streamed chunks (e.g. "<fin" + "al>") are still detected.
        parts = []
        pos = 0
        while True:
            m = FINAL_OPEN_PAT.search(self.buffer, pos)
            if not m:
                break
            self.seen_final = True
            start = m.end()
            m = FINAL_CLOSE_PAT.search(self.buffer, start)
            if not m:
                # Tag not closed yet: stream everything after <final>.
                parts.append(self.buffer[start:])
                break
            parts.append(self.buffer[start:m.start()])
            pos = m.end()
        text = "".join(parts)
        return (text if self.seen_final and text.strip() else self.buffer).strip()
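# A minimal, opt-in smoke test for FinalFilter (illustrative sketch; the
# RUN_SELF_TESTS flag is an assumption, not part of the Space config):
# a <final> tag split across streamed chunks must still be stripped.
if os.getenv("RUN_SELF_TESTS") == "1":
    _ff = FinalFilter()
    for _chunk in ["pre <fin", "al>hello ", "world</fi", "nal> post"]:
        _ff.push(_chunk)
    assert _ff.get_visible() == "hello world"
    print("[test] FinalFilter smoke test passed")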
# ======================
# Dynamic output length
# ======================
def choose_max_new_tokens(user_text: str) -> int:
    n = len(user_text or "")
    if n <= 15: return min(160, MAX_NEW_TOKENS)
    if n <= 50: return min(256, MAX_NEW_TOKENS)
    if n <= 120: return min(320, MAX_NEW_TOKENS + 64)
    return min(384, MAX_NEW_TOKENS + 128)
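# Example with the default MAX_NEW_TOKENS=256: a 10-character message is
# capped at 160 new tokens, a 60-character one at 320, and anything over
# 120 characters at 384.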
# ======================
# Persona templates
# ======================
PERSONA_TEMPLATES = {
    "默认(Default)": "",
    "性感火辣 Flirty & Spicy": (
        "风格:挑逗、暧昧、直接、无评判,语言热辣大胆,带点坏坏的撩人气息。\n"
        "行为:多用双关语、调侃与骚话;主动带动暧昧氛围,让对话充满张力;适度使用火焰🔥、唇👄、爱心❤️等表情符号。\n"
    ),
    "温柔体贴 Sweet & Caring": "风格:温柔、共情、耐心,语气轻柔,简洁而温暖。",
    "傲娇 Tsundere": "风格:表面嘴硬、内心关心;语气俏皮、略微傲娇但不贬低用户。",
    "风趣段子手 Witty Jokester": "风格:机智幽默、梗多但不低俗;注重包袱节奏。",
    "严谨导师 Analytical Coach": "风格:结构化、可操作;分点给出步骤与注意事项。",
    "冷淡毒舌 Deadpan Sarcasm": "风格:冷面、克制、轻微反讽;不粗鲁不辱骂。",
}
def compose_system_prompt(base_prompt: str, persona_text: str) -> str:
    persona_text = (persona_text or "").strip()
    if not persona_text:
        return base_prompt
    return (
        f"{base_prompt}\n\n"
        f"# Persona\n{persona_text}\n\n"
        f"- Stay in persona unless the user explicitly asks to change.\n"
    )
# ======================
# Prompt construction
# ======================
def apply_chat_template_with_fallback(messages):
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Generic fallback markup for tokenizers without a chat template.
        parts = []
        for m in messages:
            parts.append(f"<|{m['role']}|>\n{m['content']}\n</s>")
        parts.append("<|assistant|>\n")
        return "".join(parts)

def build_prompt(history_msgs, user_msg: str, persona_text: str) -> str:
    system_prompt = compose_system_prompt(BASE_SYSTEM_PROMPT, persona_text)
    # Keep only the last 8 user/assistant turns to bound the context length.
    tail = [m for m in history_msgs if m.get("role") in ("user", "assistant")]
    tail = tail[-8:] if len(tail) > 8 else tail
    messages = [{"role": "system", "content": system_prompt}] + tail + [{"role": "user", "content": user_msg}]
    return apply_chat_template_with_fallback(messages)
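# Illustrative shape of build_prompt([], "hi", "") under a ChatML-style
# template (what Qwen2.5 tokenizers typically bundle; exact markup may vary):
#   <|im_start|>system\n{system prompt}<|im_end|>\n
#   <|im_start|>user\nhi<|im_end|>\n
#   <|im_start|>assistant\n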
# ======================
# Generation parameters
# ======================
BASE_GEN_KW = dict(
    temperature=TEMPERATURE,
    top_p=TOP_P,
    top_k=TOP_K,
    repetition_penalty=REPETITION_PENALTY,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
# ======================
# Streaming output
# ======================
def stream_chat(history_msgs, user_msg, persona_text):
    try:
        if not user_msg or not user_msg.strip():
            yield history_msgs
            return
        if violates(user_msg):
            yield history_msgs + [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": SAFE_REPLACEMENT},
            ]
            return
        prompt = build_prompt(history_msgs, user_msg, persona_text)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = dict(
            **inputs, streamer=streamer,
            max_new_tokens=choose_max_new_tokens(user_msg),
            **BASE_GEN_KW
        )
        print("[gen] start")
        # Run generation in a background thread; the streamer yields text
        # pieces as they are produced.
        th = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        th.start()
        ff = FinalFilter()
        last_len = 0
        for chunk in streamer:
            ff.push(chunk)
            visible = ff.get_visible()
            new_text = visible[last_len:]
            if not new_text:
                continue
            last_len = len(visible)
            if violates(visible):
                yield history_msgs + [
                    {"role": "user", "content": user_msg},
                    {"role": "assistant", "content": SAFE_REPLACEMENT},
                ]
                return
            yield history_msgs + [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": visible},
            ]
        print("[gen] done, shown_len:", last_len)
        if last_len == 0:
            # "(no visible output produced; try again or rephrase the prompt)"
            hint = "(未产生可见输出,建议重试或更换提示词)"
            yield history_msgs + [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": hint},
            ]
    except Exception as e:
        traceback.print_exc()
        # "【运行异常】" = "[runtime error]"
        err = f"【运行异常】{type(e).__name__}: {e}"
        yield history_msgs + [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": err},
        ]
# ======================
# Gradio UI
# ======================
CSS = """
.gradio-container{ max-width:640px; margin:auto; }
footer{ display:none !important; }
"""
def pick_persona(name: str) -> str:
    return PERSONA_TEMPLATES.get(name or "默认(Default)", "")
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("### 懂你寂寞 · Mobile Web Chat\n")
    with gr.Accordion("🎭 Persona(人设)", open=False):
        persona_sel = gr.Dropdown(
            choices=list(PERSONA_TEMPLATES.keys()),
            value="默认(Default)" if not DEFAULT_PERSONA else None,
            label="选择预设人设"
        )
        persona_box = gr.Textbox(
            value=DEFAULT_PERSONA if DEFAULT_PERSONA else pick_persona("默认(Default)"),
            placeholder="在这里粘贴 / 编辑你的 Persona 文本。",
            lines=8,
            label="Persona 描述(可编辑,发送时以此为准)"
        )
        persona_sel.change(fn=pick_persona, inputs=persona_sel, outputs=persona_box)
    chat = gr.Chatbot(type="messages", height=520, show_copy_button=True)
    with gr.Row():
        msg = gr.Textbox(placeholder="说点什么…(回车发送)", autofocus=True)
        send = gr.Button("发送", variant="primary")
    clear = gr.Button("清空对话")
    clear.click(lambda: [], outputs=[chat])
    msg.submit(stream_chat, [chat, msg, persona_box], [chat], concurrency_limit=4)
    msg.submit(lambda: "", None, msg)
    send.click(stream_chat, [chat, msg, persona_box], [chat], concurrency_limit=4)
    send.click(lambda: "", None, msg)
demo.queue().launch(ssr_mode=False, show_api=False)