Gilvaa committed on
Commit
d0e2818
·
verified ·
1 Parent(s): 83d2d2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -28
app.py CHANGED
@@ -82,18 +82,11 @@ def build_system_prompt():
82
  # 把强约束和人设合并为 system prompt
83
  return f"{SYSTEM_SAFETY.strip()}\n\n=== Persona ===\n{PERSONA.strip()}"
84
 
85
- def build_prompt(history, user_msg):
86
  messages = [{"role": "system", "content": build_system_prompt()}]
87
-
88
- # 先注入 few-shot(固定风格与边界)
89
- messages.extend(FEW_SHOTS)
90
-
91
- # 再接最近对话,避免超过上下文
92
- recent = history[-4:] if len(history) > 4 else history
93
- for u, a in recent:
94
- if u: messages.append({"role": "user", "content": u})
95
- if a: messages.append({"role": "assistant", "content": a})
96
-
97
  messages.append({"role": "user", "content": user_msg})
98
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
99
  return prompt
@@ -108,37 +101,39 @@ GEN_KW = dict(
108
  eos_token_id=tokenizer.eos_token_id
109
  )
110
 
111
- def stream_chat(history, user_msg):
112
  if not user_msg or not user_msg.strip():
113
- yield history
114
  return
115
 
116
- # 输入侧轻过滤
117
  if violates(user_msg):
118
- reply = SAFE_REPLACEMENT
119
- yield history + [[user_msg, reply]]
 
 
120
  return
121
 
122
- prompt = build_prompt(history, user_msg)
123
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
124
-
125
- streamer = TextIteratorStreamer(
126
- tokenizer, skip_prompt=True, skip_special_tokens=True
127
- )
128
 
129
  gen_kwargs = dict(**inputs, streamer=streamer, **GEN_KW)
130
- thread = Thread(target=model.generate, kwargs=gen_kwargs)
131
- thread.start()
132
 
133
  reply = ""
134
  for new_text in streamer:
135
  reply += new_text
136
- # 输出侧轻过滤:一旦命中,立即替换为安全话术
137
  if violates(reply):
138
  reply = SAFE_REPLACEMENT
139
- yield history + [[user_msg, reply]]
 
 
 
140
  return
141
- yield history + [[user_msg, reply]]
 
 
 
142
 
143
  # ======== Gradio UI(移动端友好) ========
144
  CSS = """
@@ -148,7 +143,8 @@ footer{ display:none !important; }
148
 
149
  with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
150
  gr.Markdown("### 💋 御姐聊天 · Mobile Web\n温柔撩人,但始终优雅有分寸。")
151
- chat = gr.Chatbot(height=520, bubble_full_width=False, show_copy_button=True)
 
152
  with gr.Row():
153
  msg = gr.Textbox(placeholder="想跟姐姐聊点什么?(回车发送)", autofocus=True)
154
  send = gr.Button("发送", variant="primary")
@@ -160,4 +156,4 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
160
  msg.submit(stream_chat, [chat, msg], [chat]); msg.submit(lambda:"", None, msg)
161
  send.click(stream_chat, [chat, msg], [chat]); send.click(lambda:"", None, msg)
162
 
163
- demo.queue().launch()
 
82
  # 把强约束和人设合并为 system prompt
83
  return f"{SYSTEM_SAFETY.strip()}\n\n=== Persona ===\n{PERSONA.strip()}"
84
 
85
def build_prompt(history_msgs, user_msg):
    """Assemble the chat-template prompt string for one generation turn.

    Parameters
    ----------
    history_msgs : list[dict]
        Prior conversation as OpenAI-style ``{"role": ..., "content": ...}``
        dicts (the format of Gradio ``Chatbot(type="messages")``).
    user_msg : str
        The new user message to append as the final turn.

    Returns
    -------
    str
        Untokenized prompt produced by the tokenizer's chat template with
        the generation prompt appended, ready for ``tokenizer(prompt, ...)``.
    """
    # System prompt first: hard safety constraints merged with the persona.
    messages = [{"role": "system", "content": build_system_prompt()}]
    # Inject few-shot examples to pin the desired style and boundaries.
    messages.extend(FEW_SHOTS)
    # Keep only the most recent history to stay within the context window.
    # Slicing already covers short lists (xs[-8:] == xs when len(xs) <= 8),
    # so the original length check was redundant.
    # NOTE(review): a window of 8 keeps 4 user/assistant pairs only if the
    # history is pair-aligned; a misaligned history would start the tail on
    # an assistant turn — confirm alignment against the caller.
    messages.extend(history_msgs[-8:])
    messages.append({"role": "user", "content": user_msg})
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
 
101
  eos_token_id=tokenizer.eos_token_id
102
  )
103
 
104
def stream_chat(history_msgs, user_msg):
    # Streaming chat handler wired to the Gradio Chatbot (messages format).
    # A generator: each ``yield`` pushes the updated message list to the UI,
    # so partial replies render as they are produced.

    # Empty/whitespace-only input: re-emit the history unchanged and stop.
    if not user_msg or not user_msg.strip():
        yield history_msgs
        return

    # Input-side filter: if the user message trips the safety check, answer
    # with the canned safe reply instead of calling the model at all.
    if violates(user_msg):
        yield history_msgs + [
            {"role":"user","content": user_msg},
            {"role":"assistant","content": SAFE_REPLACEMENT},
        ]
        return

    prompt = build_prompt(history_msgs, user_msg)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # skip_prompt drops the echoed input; skip_special_tokens strips EOS etc.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation on a background thread; the streamer feeds decoded text
    # chunks back to this generator as they become available.
    gen_kwargs = dict(**inputs, streamer=streamer, **GEN_KW)
    Thread(target=model.generate, kwargs=gen_kwargs).start()

    reply = ""
    for new_text in streamer:
        reply += new_text
        # Output-side filter: at the first violation, replace the entire
        # partial reply with the safe fallback and stop streaming.
        if violates(reply):
            reply = SAFE_REPLACEMENT
            yield history_msgs + [
                {"role":"user","content": user_msg},
                {"role":"assistant","content": reply},
            ]
            return
        # NOTE(review): the diff rendering loses indentation; this per-chunk
        # yield is placed inside the loop so partial replies stream to the
        # UI — confirm against the deployed app.py.
        yield history_msgs + [
            {"role":"user","content": user_msg},
            {"role":"assistant","content": reply},
        ]
137
 
138
  # ======== Gradio UI(移动端友好) ========
139
  CSS = """
 
143
 
144
  with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
145
  gr.Markdown("### 💋 御姐聊天 · Mobile Web\n温柔撩人,但始终优雅有分寸。")
146
+ #chat = gr.Chatbot(height=520, bubble_full_width=False, show_copy_button=True)
147
+ chat = gr.Chatbot(type="messages", height=520, show_copy_button=True)
148
  with gr.Row():
149
  msg = gr.Textbox(placeholder="想跟姐姐聊点什么?(回车发送)", autofocus=True)
150
  send = gr.Button("发送", variant="primary")
 
156
  msg.submit(stream_chat, [chat, msg], [chat]); msg.submit(lambda:"", None, msg)
157
  send.click(stream_chat, [chat, msg], [chat]); send.click(lambda:"", None, msg)
158
 
159
+ demo.queue().launch(share=True)