Shining-Data committed on
Commit e017f5c · verified · 1 parent: fd5e60d

Create app.py

Files changed (1)
  1. app.py +428 -0
app.py ADDED
@@ -0,0 +1,428 @@
import os
import gc
import threading
from itertools import islice
from datetime import datetime

import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer, AutoTokenizer
from duckduckgo_search import DDGS
import spaces  # Import spaces early to enable ZeroGPU support

# Optional: Disable GPU visibility if you wish to force CPU usage
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
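# Note: CUDA_VISIBLE_DEVICES is read lazily, when the CUDA context is first
# created, so this assignment only takes effect if it runs before the first
# CUDA call in the process.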

# ------------------------------
# Global cancellation event, shared by all requests
# ------------------------------
cancel_event = threading.Event()

# ------------------------------
# Model definitions (all loadable via transformers)
# ------------------------------
MODELS = {
    "Yee-R1-mini": {
        "repo_id": "sds-ai/Yee-R1-mini",
        "description": "Yee (小熠), an AI data-security expert model.",
    },
    "secgpt-mini": {
        "repo_id": "clouditera/secgpt-mini",
        "description": "SecGPT is an open-source LLM released by Clouditera (云起无垠) in 2023, "
                       "built for cybersecurity scenarios and aimed at making security defenses "
                       "more efficient and effective through AI.",
    },
    "Qwen3-0.6B": {
        "repo_id": "Qwen/Qwen3-0.6B",
        "description": "Dense causal language model with 0.6B total parameters (0.44B non-embedding), "
                       "28 transformer layers, 16 query heads & 8 KV heads, native 32,768-token context "
                       "window, dual-mode generation, full multilingual & agentic capabilities.",
    },
    "Qwen3-1.7B": {
        "repo_id": "Qwen/Qwen3-1.7B",
        "description": "Dense causal language model with 1.7B total parameters (1.4B non-embedding), "
                       "28 layers, 16 query heads & 8 KV heads, 32,768-token context, stronger reasoning "
                       "than the 0.6B variant, dual-mode inference, instruction following across 100+ languages.",
    },
}
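
# Adding a model only requires a new entry above; load_pipeline() resolves the
# repo_id lazily the first time the model is selected in the UI.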

# Global cache of loaded pipelines so each model is only loaded once per process.
PIPELINES = {}


def load_pipeline(model_name):
    """
    Load and cache a transformers text-generation pipeline.

    Tries bfloat16 first, then float16, then float32; if every explicit dtype
    fails, falls back to the transformers default.
    """
    global PIPELINES
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    tokenizer = AutoTokenizer.from_pretrained(repo)
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=repo,
                tokenizer=tokenizer,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
            )
            PIPELINES[model_name] = pipe
            return pipe
        except Exception:
            continue
    # Final fallback: let transformers pick the dtype.
    pipe = pipeline(
        task="text-generation",
        model=repo,
        tokenizer=tokenizer,
        trust_remote_code=True,
        device_map="auto",
    )
    PIPELINES[model_name] = pipe
    return pipe
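
# Each failed dtype attempt above re-runs the full model load, so an
# unsupported dtype costs one extra load; the weights themselves stay in the
# local Hugging Face cache after the first download.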


def retrieve_context(query, max_results=6, max_chars=1000):
    """
    Fetch search snippets from DuckDuckGo (intended to run in a background thread).
    Returns a list of numbered "title - body" strings, or [] on any failure.
    """
    try:
        with DDGS() as ddgs:
            return [
                f"{i + 1}. {r.get('title', 'No Title')} - {r.get('body', '')[:max_chars]}"
                for i, r in enumerate(
                    islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results)
                )
            ]
    except Exception:
        return []
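
# The DDGS.text() keywords above follow the duckduckgo_search API:
# region="wt-wt" means no regional bias, and timelimit="y" restricts results
# to the past year.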

def format_conversation(history, system_prompt, tokenizer):
    """Render the chat history into a single prompt string for the model."""
    if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
        messages = [{"role": "system", "content": system_prompt.strip()}] + history
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
        )
    else:
        # Fallback for base LMs without a chat template.
        prompt = system_prompt.strip() + "\n"
        for msg in history:
            if msg['role'] == 'user':
                prompt += "User: " + msg['content'].strip() + "\n"
            elif msg['role'] == 'assistant':
                prompt += "Assistant: " + msg['content'].strip() + "\n"
        if not prompt.strip().endswith("Assistant:"):
            prompt += "Assistant: "
        return prompt
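
# enable_thinking is a Qwen3 chat-template variable; apply_chat_template
# forwards extra keyword arguments into the template context, and templates
# that do not reference the variable simply ignore it.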


@spaces.GPU(duration=60)
def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty, search_timeout):
    """
    Generate a streaming chat response, optionally enriched by a background
    web search. Yields (history, debug_markdown) pairs for the Gradio UI.
    """
    cancel_event.clear()
    history = list(chat_history or [])
    history.append({'role': 'user', 'content': user_msg})

    # Launch the web search in a background thread if enabled.
    debug = ''
    search_results = []
    if enable_search:
        debug = 'Search task started.'
        thread_search = threading.Thread(
            target=lambda: search_results.extend(
                retrieve_context(user_msg, int(max_results), int(max_chars))
            )
        )
        thread_search.daemon = True
        thread_search.start()
    else:
        debug = 'Web search disabled.'

    try:
        # Wait up to search_timeout seconds for snippets, then report what arrived.
        if enable_search:
            thread_search.join(timeout=float(search_timeout))
            if search_results:
                debug = "### Search results merged into prompt\n\n" + "\n".join(
                    f"- {r}" for r in search_results
                )
            else:
                debug = "*No web search results found.*"

        # Merge any fetched snippets into the system prompt.
        if search_results:
            enriched = system_prompt.strip() + "\n\nRelevant context:\n" + "\n".join(search_results)
        else:
            enriched = system_prompt

        pipe = load_pipeline(model_name)
        prompt = format_conversation(history, enriched, pipe.tokenizer)
        prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"
        streamer = TextIteratorStreamer(pipe.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        gen_thread = threading.Thread(
            target=pipe,
            args=(prompt,),
            kwargs={
                'max_new_tokens': max_tokens,
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'repetition_penalty': repeat_penalty,
                'streamer': streamer,
                'return_full_text': False,
            }
        )
        gen_thread.start()

        # Buffers and state for separating "thought" from the final answer.
        thought_buf = ''
        answer_buf = ''
        in_thought = False
        answer_started = False

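        # Reasoning models such as Qwen3 emit "<think> ... </think>" before the
        # visible answer; the loop below splits that stream into a collapsible
        # "Thought" message and a separate answer message.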
        # Stream tokens as they arrive.
        for chunk in streamer:
            if cancel_event.is_set():
                break
            text = chunk

            # Detect the start of a thinking block.
            if not in_thought and '<think>' in text:
                in_thought = True
                # Insert a placeholder message for the thought stream.
                history.append({
                    'role': 'assistant',
                    'content': '',
                    'metadata': {'title': '💭 Thought'}
                })
                # Capture everything after the opening tag.
                after = text.split('<think>', 1)[1]
                thought_buf += after
                # Handle a closing tag that arrives in the same chunk.
                if '</think>' in thought_buf:
                    before, after2 = thought_buf.split('</think>', 1)
                    history[-1]['content'] = before.strip()
                    in_thought = False
                    # Start the answer message.
                    answer_buf = after2
                    answer_started = True
                    history.append({'role': 'assistant', 'content': answer_buf})
                else:
                    history[-1]['content'] = thought_buf
                yield history, debug
                continue

            # Continue streaming the thought.
            if in_thought:
                thought_buf += text
                if '</think>' in thought_buf:
                    before, after2 = thought_buf.split('</think>', 1)
                    history[-1]['content'] = before.strip()
                    in_thought = False
                    # Start the answer message.
                    answer_buf = after2
                    answer_started = True
                    history.append({'role': 'assistant', 'content': answer_buf})
                else:
                    history[-1]['content'] = thought_buf
                yield history, debug
                continue

            # Stream the answer (create its message exactly once, even when the
            # closing </think> arrived with no trailing text).
            if not answer_started:
                history.append({'role': 'assistant', 'content': ''})
                answer_started = True
            answer_buf += text
            history[-1]['content'] = answer_buf
            yield history, debug

        gen_thread.join()
        yield history, debug + prompt_debug
    except Exception as e:
        history.append({'role': 'assistant', 'content': f"Error: {e}"})
        yield history, debug
    finally:
        gc.collect()


def cancel_generation():
    cancel_event.set()
    return 'Generation cancelled.'
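
# Cancelling only stops the UI stream early: chat_response still joins the
# generation thread, so the underlying pipeline call runs to completion in
# the background before the final yield.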


def update_default_prompt(enable_search):
    # enable_search is currently unused; the parameter exists so the
    # checkbox's change event can call this function directly.
    today = datetime.now().strftime('%Y-%m-%d')
    return f"You are a helpful assistant. Today is {today}."


def toggle_theme(current_theme):
    """Toggle between light and dark themes."""
    if current_theme == "light":
        return "dark", "☀️ Light Mode"
    else:
        return "light", "🌙 Dark Mode"


def toggle_language(current_lang):
    """Toggle between Chinese and English."""
    if current_lang == "zh":
        return "en"
    else:
        return "zh"


def get_ui_text(lang):
    """Return the UI strings for the given language ("zh" or "en")."""
    texts = {
        "zh": {
            "title": "## Yee-R1 Demo",
            "subtitle": "小熠(Yee)AI 数据安全专家",
            "dark_mode": "🌙 暗黑模式",
            "light_mode": "☀️ 明亮模式",
            "lang_btn": "🌐 English",
            "select_model": "选择模型",
            "enable_search": "启用网络搜索",
            "system_prompt": "系统提示词",
            "gen_params": "### 生成参数",
            "max_tokens": "最大令牌数",
            "temperature": "温度",
            "top_k": "Top-K",
            "top_p": "Top-P",
            "repeat_penalty": "重复惩罚",
            "search_settings": "### 网络搜索设置",
            "max_results": "最大结果数",
            "max_chars": "每个结果最大字符数",
            "search_timeout": "搜索超时时间 (秒)",
            "clear_chat": "清空对话",
            "cancel_gen": "取消生成",
            "placeholder": "输入您的消息并按回车..."
        },
        "en": {
            "title": "## Yee-R1 Demo",
            "subtitle": "Yee AI Data Security Expert",
            "dark_mode": "🌙 Dark Mode",
            "light_mode": "☀️ Light Mode",
            "lang_btn": "🌐 中文",
            "select_model": "Select Model",
            "enable_search": "Enable Web Search",
            "system_prompt": "System Prompt",
            "gen_params": "### Generation Parameters",
            "max_tokens": "Max Tokens",
            "temperature": "Temperature",
            "top_k": "Top-K",
            "top_p": "Top-P",
            "repeat_penalty": "Repetition Penalty",
            "search_settings": "### Web Search Settings",
            "max_results": "Max Results",
            "max_chars": "Max Chars/Result",
            "search_timeout": "Search Timeout (s)",
            "clear_chat": "Clear Chat",
            "cancel_gen": "Cancel Generation",
            "placeholder": "Type your message and press Enter..."
        }
    }
    return texts[lang]


# ------------------------------
# Gradio UI
# ------------------------------
with gr.Blocks(title="Yee-R1-Demo", theme=gr.themes.Default()) as demo:
    # States
    theme_state = gr.State("light")
    lang_state = gr.State("zh")

    # Header with language/theme controls. Initial labels are Chinese because
    # the UI starts in the "zh" locale.
    with gr.Row():
        title_md = gr.Markdown("## Yee-R1 Demo")
        with gr.Row(scale=0):
            lang_btn = gr.Button("🌐 English", size="sm")
            theme_btn = gr.Button("🌙 暗黑模式", size="sm")

    subtitle_md = gr.Markdown("小熠(Yee)AI 数据安全专家")

    with gr.Row():
        with gr.Column(scale=3):
            model_dd = gr.Dropdown(label="选择模型", choices=list(MODELS.keys()), value=list(MODELS.keys())[0])
            search_chk = gr.Checkbox(label="启用网络搜索", value=True)
            sys_prompt = gr.Textbox(label="系统提示词", lines=3, value=update_default_prompt(search_chk.value))
            gen_params_md = gr.Markdown("### 生成参数")
            max_tok = gr.Slider(64, 16384, value=4096, step=32, label="最大令牌数")
            temp = gr.Slider(0.1, 2.0, value=0.6, step=0.1, label="温度")
            k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
            p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
            rp = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="重复惩罚")
            search_settings_md = gr.Markdown("### 网络搜索设置")
            mr = gr.Number(value=6, precision=0, label="最大结果数")
            mc = gr.Number(value=600, precision=0, label="每个结果最大字符数")
            st = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, value=5.0, label="搜索超时时间 (秒)")
            clr = gr.Button("清空对话")
            cnl = gr.Button("取消生成")
        with gr.Column(scale=7):
            chat = gr.Chatbot(type="messages", show_copy_all_button=True, height="50vh")
            txt = gr.Textbox(placeholder="输入您的消息并按回车...")
            dbg = gr.Markdown()

    # Event handlers
    search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
    cnl.click(fn=cancel_generation, outputs=dbg)

    # Theme toggle
    def handle_theme_toggle(current_theme, current_lang):
        new_theme, _ = toggle_theme(current_theme)
        ui_text = get_ui_text(current_lang)
        new_btn_text = ui_text["light_mode"] if new_theme == "dark" else ui_text["dark_mode"]

        if new_theme == "dark":
            demo._theme = gr.themes.Monochrome()
        else:
            demo._theme = gr.themes.Default()
        return new_theme, new_btn_text
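
    # Note: swapping demo._theme mutates a private attribute after the app has
    # rendered; Gradio applies themes at render time, so the new theme
    # generally only shows up after the page is reloaded.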

    # Language toggle. Labels and placeholders are updated via gr.update();
    # returning bare strings here would overwrite component *values* (e.g. set
    # the model dropdown's value to its label text).
    def handle_language_toggle(current_lang, current_theme):
        new_lang = toggle_language(current_lang)
        ui_text = get_ui_text(new_lang)

        return [
            new_lang,                                                # lang_state
            ui_text["lang_btn"],                                     # lang_btn text
            ui_text["light_mode"] if current_theme == "dark"
            else ui_text["dark_mode"],                               # theme_btn text
            ui_text["title"],                                        # title_md
            ui_text["subtitle"],                                     # subtitle_md
            gr.update(label=ui_text["select_model"]),                # model_dd
            gr.update(label=ui_text["enable_search"]),               # search_chk
            gr.update(label=ui_text["system_prompt"]),               # sys_prompt
            ui_text["gen_params"],                                   # gen_params_md
            gr.update(label=ui_text["max_tokens"]),                  # max_tok
            gr.update(label=ui_text["temperature"]),                 # temp
            gr.update(label=ui_text["top_k"]),                       # k
            gr.update(label=ui_text["top_p"]),                       # p
            gr.update(label=ui_text["repeat_penalty"]),              # rp
            ui_text["search_settings"],                              # search_settings_md
            gr.update(label=ui_text["max_results"]),                 # mr
            gr.update(label=ui_text["max_chars"]),                   # mc
            gr.update(label=ui_text["search_timeout"]),              # st
            ui_text["clear_chat"],                                   # clr
            ui_text["cancel_gen"],                                   # cnl
            gr.update(placeholder=ui_text["placeholder"]),           # txt
        ]

    theme_btn.click(
        fn=handle_theme_toggle,
        inputs=[theme_state, lang_state],
        outputs=[theme_state, theme_btn]
    )

    lang_btn.click(
        fn=handle_language_toggle,
        inputs=[lang_state, theme_state],
        outputs=[
            lang_state, lang_btn, theme_btn, title_md, subtitle_md,
            model_dd, search_chk, sys_prompt, gen_params_md,
            max_tok, temp, k, p, rp, search_settings_md,
            mr, mc, st, clr, cnl, txt
        ]
    )

    txt.submit(fn=chat_response,
               inputs=[txt, chat, sys_prompt, search_chk, mr, mc,
                       model_dd, max_tok, temp, k, p, rp, st],
               outputs=[chat, dbg])

demo.launch()