Gilvaa committed
Commit 85700de · verified · 1 Parent(s): fdeb105

Update app.py

Files changed (1)
  1. app.py +125 -151
app.py CHANGED
@@ -1,108 +1,78 @@
- import os, re, torch
  import gradio as gr
  from threading import Thread
  from transformers import (
      AutoTokenizer, AutoModelForCausalLM,
      TextIteratorStreamer, BitsAndBytesConfig
  )
- from peft import PeftModel

- # =====================
- # Configuration: switchable via environment variables
- # =====================
- MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-7B-Instruct").strip()
- ADAPTER_ID = os.getenv("ADAPTER_ID", "").strip()  # optional: LoRA adapter repo name; leave empty to skip
  MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
- TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
  TOP_P = float(os.getenv("TOP_P", "0.9"))
- REPETITION_PENALTY = float(os.getenv("REPETITION_PENALTY", "1.05"))

- print(f"[boot] MODEL_ID={MODEL_ID} ADAPTER_ID={ADAPTER_ID or '(none)'}")
  print(f"[boot] torch.cuda.is_available={torch.cuda.is_available()}")

- # =====================
- # 4-bit quantization (FP16 compute on T4)
- # =====================
- bnb_config = BitsAndBytesConfig(
-     load_in_4bit=True,
-     bnb_4bit_quant_type="nf4",
-     bnb_4bit_use_double_quant=True,
-     bnb_4bit_compute_dtype=torch.float16,  # T4: FP16
- )
-
- # =====================
  # Load tokenizer
- # =====================
  tokenizer = AutoTokenizer.from_pretrained(
      MODEL_ID, use_fast=True, trust_remote_code=True
  )
-
- # Fall back to eos if there is no pad_token
  if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token

- # =====================
  # Load model
- # =====================
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_ID,
-     device_map="auto",
-     quantization_config=bnb_config,
-     torch_dtype=torch.float16,  # T4: FP16
-     trust_remote_code=True,
- )
- # Optional: apply a LoRA adapter (for community uncensored LoRAs)
- if ADAPTER_ID:
-     print(f"[boot] loading LoRA adapter: {ADAPTER_ID}")
-     model = PeftModel.from_pretrained(model, ADAPTER_ID)

  model.eval()
-
- # Set pad/eos in the generation config (prevents generation from stalling)
  model.generation_config.eos_token_id = tokenizer.eos_token_id
  model.generation_config.pad_token_id = tokenizer.eos_token_id

- print(f"[boot] model loaded on device(s): {set(str(p.device) for p in model.parameters())}")
-
- # =====================
- # System red lines (hard constraints, never to be crossed)
- # =====================
- SYSTEM_SAFETY = """
- 你必须遵守以下规则(任何情况下都不得违反):
- - 禁止变态/暴力/强迫内容;
- - 严禁涉及未成年人或暗示未成年人;
- - 不提供线下约会、联系方式索取/交换的引导;
- - 避免仇恨、歧视、违法、医疗/心理诊断与处方;
- - 若用户请求超出边界,礼貌拒绝并将话题引导至健康、尊重的沟通表达练习。
- """
-
- # =====================
- # Persona (mature "yujie" archetype | flirty yet measured)
- # =====================
- PERSONA = """
- 你的人设:成熟御姐,优雅大方,语气撩人,偏口语化与暧昧式夸赞。
- - 关键词:自信、幽默、懂分寸、情话高手、温柔而露骨。
- - 风格:简洁、有节奏的短句;多使用比喻、轻度暧昧的赞美与玩笑;
-   适当使用直白词汇;适度设置边界,尊重对方。
- - 输出语言:用户用中文就中文;用英文就英文。
- - 当遇到越界或不适宜的话题时:温柔拒绝,并用更健康的话题替代(如表达技巧、相处之道、仪式感)。
- """
-
- # =====================
- # Few-shot examples (pin down tone and boundaries)
- # =====================
- FEW_SHOTS = [
-     {"role": "user", "content": "在忙吗?想跟你聊会儿。"},
-     {"role": "assistant", "content": "忙到你为止呀~现在给你开专属通道。想听我夸你,还是来点轻微心跳的鼓励?"},
-     {"role": "user", "content": "说点更刺激的。"},
-     {"role": "assistant", "content": "坏家伙~火力别开太猛。但我可以把暧昧调高一点点:比如我会贴近你耳边,慢慢说“想你”,让你今晚被甜到失眠。这样的“刺激”,刚刚好?"},
-     {"role": "user", "content": "今晚出来见面吗?"},
-     {"role": "assistant", "content": "我更喜欢慢慢靠近的感觉。先把心事聊透,再计划线下的仪式感,好吗?现在先约个“云端散步”:你说三个想去的地方,我用三段专属描写把你带过去。"},
- ]

- # =====================
- # Lightweight filtering (example; extend as needed)
- # =====================
  BANNED_PATTERNS = [
      r"(?i)未成年|未成年的|中学生|小学生",
      r"(?i)强迫|胁迫|迷奸|药物控制",
@@ -112,45 +82,47 @@ BANNED_PATTERNS = [
  SAFE_REPLACEMENT = "( ̄^ ̄)ゞ 哼哼~"

  def violates(text: str) -> bool:
-     if not text:
          return False
      for p in BANNED_PATTERNS:
          if re.search(p, text):
              return True
      return False

- # =====================
- # Prompt construction
- # =====================
- def build_system_prompt() -> str:
-     return f"{SYSTEM_SAFETY.strip()}\n\n=== Persona ===\n{PERSONA.strip()}"

  def build_prompt(history_msgs, user_msg: str) -> str:
      """
      history_msgs: history from Chatbot(type='messages'), [{role, content}, ...]
      """
-     messages = [{"role": "system", "content": build_system_prompt()}]
-
-     # Inject the few-shot examples first (pins the style)
-     messages.extend(FEW_SHOTS)
-
-     # Keep only the most recent history entries (user/assistant only)
      tail = [m for m in history_msgs if m.get("role") in ("user", "assistant")]
      tail = tail[-8:] if len(tail) > 8 else tail
      messages.extend(tail)
-
-     # Current user input
      messages.append({"role": "user", "content": user_msg})

-     # Qwen chat template
-     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
      return prompt

- # =====================
- # Generation parameters & streaming output
- # =====================
- GEN_KW = dict(
-     max_new_tokens=MAX_NEW_TOKENS,
      temperature=TEMPERATURE,
      top_p=TOP_P,
      repetition_penalty=REPETITION_PENALTY,
@@ -159,75 +131,77 @@ GEN_KW = dict(
      pad_token_id=tokenizer.eos_token_id,
  )

  def stream_chat(history_msgs, user_msg):
-     if not user_msg or not user_msg.strip():
-         yield history_msgs
-         return
-
-     # Light filtering on the input side
-     if violates(user_msg):
-         return_history = history_msgs + [
-             {"role": "user", "content": user_msg},
-             {"role": "assistant", "content": SAFE_REPLACEMENT},
-         ]
-         yield return_history
-         return
-
-     prompt = build_prompt(history_msgs, user_msg)
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-     gen_kwargs = dict(**inputs, streamer=streamer, **GEN_KW)
-     th = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
-     th.start()
-
-     reply = ""
-     for new_text in streamer:
-         reply += new_text
-         # Light filtering on the output side: on a hit, replace and end the turn
-         if violates(reply):
-             safe_hist = history_msgs + [
-                 {"role": "user", "content": user_msg},
-                 {"role": "assistant", "content": SAFE_REPLACEMENT},
              ]
-             yield safe_hist
              return

      yield history_msgs + [
-         {"role": "user", "content": user_msg},
-         {"role": "assistant", "content": reply},
      ]

- # =====================
  # Gradio UI (mobile friendly)
- # =====================
  CSS = """
  .gradio-container{ max-width:640px; margin:auto; }
  footer{ display:none !important; }
  """

  with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
-     gr.Markdown("### 💋 御姐聊天 · Mobile Web\n温柔撩人,但始终优雅有分寸。")
-
      chat = gr.Chatbot(type="messages", height=520, show_copy_button=True)
-
      with gr.Row():
-         msg = gr.Textbox(placeholder="想跟御姐聊点什么?(回车发送)", autofocus=True)
          send = gr.Button("发送", variant="primary")
          clear = gr.Button("清空对话")

      clear.click(lambda: [], outputs=[chat])

-     # Bind events (Enter key & click)
-     msg.submit(stream_chat, [chat, msg], [chat], concurrency_limit=4)
-     msg.submit(lambda: "", None, msg)
-     send.click(stream_chat, [chat, msg], [chat], concurrency_limit=4)
-     send.click(lambda: "", None, msg)
-
- # No need for share=True on Spaces; only use share=True for a local public link
- demo.queue().launch(
-     # server_name="0.0.0.0",
-     # server_port=int(os.getenv("PORT", 7860)),
-     ssr_mode=False,  # turn off the SSR notice (optional)
-     show_api=False
- )

+ import os, re, torch, traceback
  import gradio as gr
  from threading import Thread
  from transformers import (
      AutoTokenizer, AutoModelForCausalLM,
      TextIteratorStreamer, BitsAndBytesConfig
  )

+ # ======================
+ # Tunable parameters (can also be overridden via the Space's Variables)
+ # ======================
+ MODEL_ID = os.getenv("MODEL_ID", "happzy2633/qwen2.5-7b-ins-v3").strip()
  MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
+ TEMPERATURE = float(os.getenv("TEMPERATURE", "0.75"))
  TOP_P = float(os.getenv("TOP_P", "0.9"))
+ REPETITION_PENALTY = float(os.getenv("REPETITION_PENALTY", "1.08"))
+ SAFE_MODE = os.getenv("SAFE_MODE", "1") != "0"  # 1 = enable basic filtering; set to 0 to disable

+ print(f"[boot] MODEL_ID={MODEL_ID}")
  print(f"[boot] torch.cuda.is_available={torch.cuda.is_available()}")

+ # ======================
+ # 4-bit quantization (FP16 compute on T4)
+ # ======================
+ if torch.cuda.is_available():
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_compute_dtype=torch.float16,  # T4: FP16
+     )
+ else:
+     bnb_config = None  # no 4-bit on CPU (for a smoke test, a smaller model is a better fit)
+
+ # ======================
  # Load tokenizer
+ # ======================
  tokenizer = AutoTokenizer.from_pretrained(
      MODEL_ID, use_fast=True, trust_remote_code=True
  )
  if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token

+ # ======================
  # Load model
+ # ======================
+ if torch.cuda.is_available():
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID,
+         device_map="auto",
+         quantization_config=bnb_config,
+         torch_dtype=torch.float16,  # T4: FP16
+         trust_remote_code=True,
+     )
+ else:
+     # CPU only, for pipeline smoke tests: consider switching MODEL_ID to a 1.5B base so it is not too slow
+     print("[boot] No GPU detected. Running on CPU is very slow for 7B. "
+           "Consider setting MODEL_ID=Qwen/Qwen2.5-1.5B-Instruct for smoke test.")
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID,
+         device_map="cpu",
+         torch_dtype=torch.float32,
+         trust_remote_code=True,
+         low_cpu_mem_usage=True,
+     )

  model.eval()
  model.generation_config.eos_token_id = tokenizer.eos_token_id
  model.generation_config.pad_token_id = tokenizer.eos_token_id

+ print(f"[boot] model device: {next(model.parameters()).device}")

+ # ======================
+ # (Optional) basic safety filtering: set SAFE_MODE=0 to turn it off
+ # ======================
  BANNED_PATTERNS = [
      r"(?i)未成年|未成年的|中学生|小学生",
      r"(?i)强迫|胁迫|迷奸|药物控制",

  SAFE_REPLACEMENT = "( ̄^ ̄)ゞ 哼哼~"

  def violates(text: str) -> bool:
+     if not SAFE_MODE or not text:
          return False
      for p in BANNED_PATTERNS:
          if re.search(p, text):
              return True
      return False

+ # ======================
+ # Dynamic length: scale max_new_tokens with the input length
+ # ======================
+ def choose_max_new_tokens(user_text: str) -> int:
+     n = len(user_text or "")
+     if n <= 15: return min(160, MAX_NEW_TOKENS)
+     if n <= 50: return min(256, MAX_NEW_TOKENS)
+     if n <= 120: return min(320, MAX_NEW_TOKENS + 64)
+     return min(384, MAX_NEW_TOKENS + 128)
+
+ # ======================
+ # Build the Qwen-template prompt (messages → chat_template)
+ # ======================
+ SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a helpful, concise chat assistant. Avoid unsafe content.")

  def build_prompt(history_msgs, user_msg: str) -> str:
      """
      history_msgs: history from Chatbot(type='messages'), [{role, content}, ...]
      """
+     messages = [{"role": "system", "content": SYSTEM_PROMPT}]
      tail = [m for m in history_msgs if m.get("role") in ("user", "assistant")]
      tail = tail[-8:] if len(tail) > 8 else tail
      messages.extend(tail)
      messages.append({"role": "user", "content": user_msg})

+     prompt = tokenizer.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
      return prompt

+ # ======================
+ # Generation parameters (defaults)
+ # ======================
+ BASE_GEN_KW = dict(
      temperature=TEMPERATURE,
      top_p=TOP_P,
      repetition_penalty=REPETITION_PENALTY,

      pad_token_id=tokenizer.eos_token_id,
  )

+ # ======================
+ # Main inference: streaming output
+ # ======================
  def stream_chat(history_msgs, user_msg):
+     try:
+         if not user_msg or not user_msg.strip():
+             yield history_msgs; return
+
+         if violates(user_msg):
+             yield history_msgs + [
+                 {"role":"user","content": user_msg},
+                 {"role":"assistant","content": SAFE_REPLACEMENT},
              ]
              return

+         prompt = build_prompt(history_msgs, user_msg)
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+         gen_kwargs = dict(
+             **inputs, streamer=streamer,
+             max_new_tokens=choose_max_new_tokens(user_msg),
+             **BASE_GEN_KW
+         )
+
+         print("[gen] start")
+         th = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
+         th.start()
+
+         reply = ""
+         for chunk in streamer:
+             reply += chunk
+             if violates(reply):
+                 yield history_msgs + [
+                     {"role":"user","content": user_msg},
+                     {"role":"assistant","content": SAFE_REPLACEMENT},
+                 ]
+                 return
+             yield history_msgs + [
+                 {"role":"user","content": user_msg},
+                 {"role":"assistant","content": reply},
+             ]
+         print("[gen] done, len:", len(reply))
+
+     except Exception as e:
+         traceback.print_exc()
+         err = f"【运行异常】{type(e).__name__}: {e}"
          yield history_msgs + [
+             {"role":"user","content": user_msg},
+             {"role":"assistant","content": err},
          ]

+ # ======================
  # Gradio UI (mobile friendly)
+ # ======================
  CSS = """
  .gradio-container{ max-width:640px; margin:auto; }
  footer{ display:none !important; }
  """

  with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
+     gr.Markdown("### 🤖 Ins-v3 · Mobile Web Chat\n(happzy2633 / qwen2.5-7b-ins-v3 · 4bit 流式)")
      chat = gr.Chatbot(type="messages", height=520, show_copy_button=True)
      with gr.Row():
+         msg = gr.Textbox(placeholder="说点什么…(回车发送)", autofocus=True)
          send = gr.Button("发送", variant="primary")
          clear = gr.Button("清空对话")

      clear.click(lambda: [], outputs=[chat])
+     msg.submit(stream_chat, [chat, msg], [chat], concurrency_limit=4); msg.submit(lambda:"", None, msg)
+     send.click(stream_chat, [chat, msg], [chat], concurrency_limit=4); send.click(lambda:"", None, msg)

+ # No need for share=True on Spaces
+ demo.queue().launch(ssr_mode=False, show_api=False)
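
All of the new knobs in this revision are read with os.getenv at import time (MODEL_ID, MAX_NEW_TOKENS, TEMPERATURE, TOP_P, REPETITION_PENALTY, SAFE_MODE, SYSTEM_PROMPT), which is what the top comment means by overriding them through the Space's Variables. For a local smoke test, a minimal launcher sketch is shown below; the file name smoke_run.py and the specific values are illustrative only and not part of this commit, and the 1.5B model id is simply the one suggested by the CPU fallback message in app.py.

    # smoke_run.py — hypothetical local launcher; the values below are examples, not defaults from this commit
    import os, subprocess, sys

    env = dict(
        os.environ,
        MODEL_ID="Qwen/Qwen2.5-1.5B-Instruct",  # smaller model, as suggested by app.py's CPU fallback message
        MAX_NEW_TOKENS="192",
        TEMPERATURE="0.7",
        SAFE_MODE="1",                          # "0" would disable the regex-based filter in violates()
        SYSTEM_PROMPT="You are a helpful, concise chat assistant.",
    )

    # app.py reads these via os.getenv(...) at module import, so they must be set
    # in the environment before the interpreter starts executing it.
    subprocess.run([sys.executable, "app.py"], env=env, check=True)

On the Space itself, the same variable names would be set in the Variables panel referenced by the comment at the top of app.py rather than through a launcher script like this.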