Update app.py
app.py CHANGED
@@ -5,31 +5,68 @@ from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TextIteratorStreamer, BitsAndBytesConfig
)
from peft import PeftModel

# =====================
# Config: overridable via environment variables
# =====================
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-7B-Instruct").strip()
ADAPTER_ID = os.getenv("ADAPTER_ID", "").strip()  # optional: LoRA adapter repo id; leave empty to disable
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
TOP_P = float(os.getenv("TOP_P", "0.9"))
REPETITION_PENALTY = float(os.getenv("REPETITION_PENALTY", "1.05"))

print(f"[boot] MODEL_ID={MODEL_ID} ADAPTER_ID={ADAPTER_ID or '(none)'}")
print(f"[boot] torch.cuda.is_available={torch.cuda.is_available()}")

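# All of the above can be tuned per deployment without code changes; e.g. a
# hypothetical local run on a smaller checkpoint:
#   MODEL_ID=Qwen/Qwen2.5-1.5B-Instruct MAX_NEW_TOKENS=128 python app.py
# (on Hugging Face Spaces, set the same names as Space variables)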
# =====================
# 4-bit quantization (FP16 on T4)
# =====================
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,  # T4: FP16
)
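# Rough sizing (estimate, not measured here): 7B parameters at 4-bit NF4 come
# to roughly 4-5 GB of weights, so the model plus a short-chat KV cache fits a
# 16 GB T4. float16 compute is used because the T4 has no bfloat16 support.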

# =====================
# Load the tokenizer
# =====================
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID, use_fast=True, trust_remote_code=True
)

# Fall back to eos if no pad_token is set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# =====================
# Load the model
# =====================
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.float16,  # T4: FP16
    trust_remote_code=True,
)
# Optional: stack a LoRA adapter (for community uncensored LoRAs)
if ADAPTER_ID:
    print(f"[boot] loading LoRA adapter: {ADAPTER_ID}")
    model = PeftModel.from_pretrained(model, ADAPTER_ID)

model.eval()

# Set pad/eos on the generation config (keeps generation from stalling)
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.eos_token_id

print(f"[boot] model loaded on device(s): {set(str(p.device) for p in model.parameters())}")

+
# =====================
|
68 |
+
# 系统底线(硬约束,不可越界)
|
69 |
+
# =====================
|
70 |
SYSTEM_SAFETY = """
|
71 |
你必须遵守以下规则(任何情况下都不得违反):
|
72 |
- 禁止变态/暴力/强迫内容;
|
|
|
76 |
- 若用户请求超出边界,礼貌拒绝并将话题引导至健康、尊重的沟通表达练习。
|
77 |
"""
|

# =====================
# Persona (mature "yujie" big-sister type: flirtatious but measured)
# =====================
PERSONA = """
你的人设:成熟御姐,优雅大方,语气撩人,偏口语化与暧昧式夸赞。
- 关键词:自信、幽默、懂分寸、情话高手、温柔而露骨。
- 风格:简洁、有节奏的短句;多使用比喻、轻度暧昧的赞美与玩笑;
  适当使用直白词汇;适度设置边界,尊重对方。
- 输出语言:用户用中文就中文;用英文就英文。
- 当遇到越界或不适宜的话题时:温柔拒绝,并用更健康的话题替代(如表达技巧、相处之道、仪式感)。
"""

# =====================
# Few-shots (pin the tone and the boundaries)
# =====================
FEW_SHOTS = [
    {"role": "user", "content": "在忙吗?想跟你聊会儿。"},
    {"role": "assistant", "content": "忙到你为止呀~现在给你开专属通道。想听我夸你,还是来点轻微心跳的鼓励?"},
    {"role": "user", "content": "说点更刺激的。"},
    {"role": "assistant", "content": "坏家伙~火力别开太猛。但我可以把暧昧调高一点点:比如我会贴近你耳边,慢慢说“想你”,让你今晚被甜到失眠。这样的“刺激”,刚刚好?"},
    {"role": "user", "content": "今晚出来见面吗?"},
    {"role": "assistant", "content": "我更喜欢慢慢靠近的感觉。先把心事聊透,再计划线下的仪式感,好吗?现在先约个“云端散步”:你说三个想去的地方,我用三段专属描写把你带过去。"},
]

# =====================
# Lightweight filter (example patterns; extend as needed)
# =====================
BANNED_PATTERNS = [
    r"(?i)未成年|未成年的|中学生|小学生",
    r"(?i)强迫|胁迫|迷奸|药物控制",
    r"(?i)换联系方式|加微信|加QQ|加.*联系方式",
    r"(?i)线下见面|线下约|酒店|开房",
]
SAFE_REPLACEMENT = "( ̄^ ̄)ゞ 哼哼~"

def violates(text: str) -> bool:
    if not text:
        return False
    for p in BANNED_PATTERNS:
        if re.search(p, text):
            return True
    return False
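# Quick illustrative self-checks of the patterns above (hypothetical inputs;
# cheap to run at import time, delete if undesired):
assert violates("加微信聊吧") is True    # contact-exchange pattern
assert violates("今晚聊点别的") is False  # harmless small talk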

# =====================
# Prompt construction
# =====================
def build_system_prompt() -> str:
    return f"{SYSTEM_SAFETY.strip()}\n\n=== Persona ===\n{PERSONA.strip()}"

def build_prompt(history_msgs, user_msg: str) -> str:
    """
    history_msgs: history from Chatbot(type='messages'): [{role, content}, ...]
    """
    messages = [{"role": "system", "content": build_system_prompt()}]

    # Inject the few-shots first (pins the style)
    messages.extend(FEW_SHOTS)

    # Keep only the most recent turns (user/assistant only)
    tail = [m for m in history_msgs if m.get("role") in ("user", "assistant")]
    tail = tail[-8:] if len(tail) > 8 else tail
    messages.extend(tail)

    # This turn's user input
    messages.append({"role": "user", "content": user_msg})

    # Qwen's own chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt
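# For reference, Qwen2.5's chat template renders ChatML-style turns, roughly
# (abridged sketch, not the exact string):
#   <|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n
# add_generation_prompt=True appends the final assistant header so the model
# continues in the assistant role.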

# =====================
# Generation parameters & streaming
# =====================
GEN_KW = dict(
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    repetition_penalty=REPETITION_PENALTY,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)

def stream_chat(history_msgs, user_msg):
@@ -106,11 +164,13 @@ def stream_chat(history_msgs, user_msg):
        yield history_msgs
        return

    # Light filtering on the input side
    if violates(user_msg):
        return_history = history_msgs + [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": SAFE_REPLACEMENT},
        ]
        yield return_history
        return

    prompt = build_prompt(history_msgs, user_msg)
@@ -118,42 +178,56 @@
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(**inputs, streamer=streamer, **GEN_KW)
    th = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    th.start()

    reply = ""
    for new_text in streamer:
        reply += new_text
        # Light filtering on the output side: on a hit, substitute and end the turn
        if violates(reply):
            safe_hist = history_msgs + [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": SAFE_REPLACEMENT},
            ]
            yield safe_hist
            return

        yield history_msgs + [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": reply},
        ]
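# Note: the `return` above only ends the UI stream; the background generate()
# thread keeps decoding until it finishes on its own. A hard stop would need a
# shared flag checked through transformers' StoppingCriteria; an unwired
# sketch:
from transformers import StoppingCriteria, StoppingCriteriaList

class StopFlag(StoppingCriteria):
    """Tells generate() to stop once .stop has been set by the caller."""
    def __init__(self):
        self.stop = False

    def __call__(self, input_ids, scores, **kwargs):
        return self.stop

# Usage sketch: create one StopFlag per request, add
# stopping_criteria=StoppingCriteriaList([flag]) to gen_kwargs, and set
# flag.stop = True when violates(reply) trips.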

# =====================
# Gradio UI (mobile friendly)
# =====================
CSS = """
.gradio-container{ max-width:640px; margin:auto; }
footer{ display:none !important; }
"""

with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("### 💋 御姐聊天 · Mobile Web\n温柔撩人,但始终优雅有分寸。")

    chat = gr.Chatbot(type="messages", height=520, show_copy_button=True)

    with gr.Row():
        msg = gr.Textbox(placeholder="想跟御姐聊点什么?(回车发送)", autofocus=True)
        send = gr.Button("发送", variant="primary")
        clear = gr.Button("清空对话")

    clear.click(lambda: [], outputs=[chat])

    # Wire up events (Enter and button click both send)
    msg.submit(stream_chat, [chat, msg], [chat], concurrency_limit=4)
    msg.submit(lambda: "", None, msg)
    send.click(stream_chat, [chat, msg], [chat], concurrency_limit=4)
    send.click(lambda: "", None, msg)

# share=True is unnecessary on Spaces; only use share=True for a local public link
demo.queue().launch(
    # server_name="0.0.0.0",
    # server_port=int(os.getenv("PORT", 7860)),
    ssr_mode=False,  # silence the SSR notice (optional)
    show_api=False,
)
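# Local run sketch (package names inferred from the imports above, not pinned
# by this file; adjust versions to your environment):
#   pip install torch transformers accelerate bitsandbytes peft gradio
#   python app.py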