import os
import gc
import threading
from itertools import islice
from datetime import datetime
from typing import Dict, Union

import gradio as gr
import torch
from transformers import AutoTokenizer, TextIteratorStreamer, pipeline
from duckduckgo_search import DDGS

import spaces  # Import spaces early to enable ZeroGPU support

# Optional: disable GPU visibility if you wish to force CPU usage
# os.environ["CUDA_VISIBLE_DEVICES"] = ""

# ------------------------------
# Global Cancellation Event
# ------------------------------
cancel_event = threading.Event()

# ------------------------------
# Torch-Compatible Model Definitions with Adjusted Descriptions
# ------------------------------
MODELS = {
    "Yee-R1-mini": {
        "repo_id": "sds-ai/Yee-R1-mini",
        "description": "Yee (小熠), an AI data-security expert.",
    },
    "secgpt-mini": {
        "repo_id": "clouditera/secgpt-mini",
        "description": "SecGPT is an open-source large model released by Clouditera (云起无垠) in 2023, built specifically for cybersecurity scenarios and designed to comprehensively improve the efficiency and effectiveness of security protection with AI.",
    },
    "Qwen3-0.6B": {
        "repo_id": "Qwen/Qwen3-0.6B",
        "description": "Dense causal language model with 0.6B total parameters (0.44B non-embedding), 28 transformer layers, 16 query heads & 8 KV heads, native 32,768-token context window, dual-mode generation, full multilingual & agentic capabilities.",
    },
    "Qwen3-1.7B": {
        "repo_id": "Qwen/Qwen3-1.7B",
        "description": "Dense causal language model with 1.7B total parameters (1.4B non-embedding), 28 layers, 16 query heads & 8 KV heads, 32,768-token context, stronger reasoning vs. the 0.6B variant, dual-mode inference, instruction following across 100+ languages.",
    },
}

# Global cache for pipelines to avoid re-loading.
PIPELINES = {}


def load_pipeline(model_name):
    """
    Load and cache a transformers pipeline for text generation.
    Tries bfloat16, falling back to float16 or float32 if unsupported.
    """
    global PIPELINES
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    tokenizer = AutoTokenizer.from_pretrained(repo)
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=repo,
                tokenizer=tokenizer,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
            )
            PIPELINES[model_name] = pipe
            return pipe
        except Exception:
            continue
    # Final fallback: let transformers choose the default dtype.
    pipe = pipeline(
        task="text-generation",
        model=repo,
        tokenizer=tokenizer,
        trust_remote_code=True,
        device_map="auto",
    )
    PIPELINES[model_name] = pipe
    return pipe


def retrieve_context(query, max_results=6, max_chars=1000):
    """
    Retrieve search snippets from DuckDuckGo (runs in a background thread).
    Returns a list of formatted result strings, or an empty list on failure.
    """
    try:
        with DDGS() as ddgs:
            return [
                f"{i+1}. {r.get('title', 'No Title')} - {r.get('body', '')[:max_chars]}"
                for i, r in enumerate(
                    islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results)
                )
            ]
    except Exception:
        return []
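
# Illustrative only (not executed at import time): the snippets produced by
# retrieve_context are plain numbered lines, so they can be joined directly
# into a prompt. The query below is a hypothetical example; actual output
# depends on live DuckDuckGo responses.
#
#   >>> retrieve_context("data security best practices", max_results=2, max_chars=80)
#   ['1. Some Title - first 80 characters of the snippet body...',
#    '2. Another Title - ...']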

def format_conversation(history, system_prompt, tokenizer):
    if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
        messages = [{"role": "system", "content": system_prompt.strip()}] + history
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
        )
    else:
        # Fallback for base LMs without a chat template.
        prompt = system_prompt.strip() + "\n"
        for msg in history:
            if msg['role'] == 'user':
                prompt += "User: " + msg['content'].strip() + "\n"
            elif msg['role'] == 'assistant':
                prompt += "Assistant: " + msg['content'].strip() + "\n"
        if not prompt.strip().endswith("Assistant:"):
            prompt += "Assistant: "
        return prompt


@spaces.GPU(duration=60)
def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty, search_timeout):
    """
    Generate streaming chat responses, optionally enriched by a background web search.
    """
    cancel_event.clear()
    history = list(chat_history or [])
    history.append({'role': 'user', 'content': user_msg})

    # Launch the web search in the background if enabled.
    debug = ''
    search_results = []
    if enable_search:
        debug = 'Search task started.'
        thread_search = threading.Thread(
            target=lambda: search_results.extend(
                retrieve_context(user_msg, int(max_results), int(max_chars))
            )
        )
        thread_search.daemon = True
        thread_search.start()
    else:
        debug = 'Web search disabled.'

    try:
        # Wait up to `search_timeout` seconds for snippets, then surface them
        # in the debug pane.
        if enable_search:
            thread_search.join(timeout=float(search_timeout))
            if search_results:
                debug = "### Search results merged into prompt\n\n" + "\n".join(
                    f"- {r}" for r in search_results
                )
            else:
                debug = "*No web search results found.*"

        # Merge any fetched snippets into the system prompt.
        if search_results:
            enriched = system_prompt.strip() + "\n\nRelevant context:\n" + "\n".join(search_results)
        else:
            enriched = system_prompt

        pipe = load_pipeline(model_name)
        prompt = format_conversation(history, enriched, pipe.tokenizer)
        prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"

        streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_thread = threading.Thread(
            target=pipe,
            args=(prompt,),
            kwargs={
                'max_new_tokens': max_tokens,
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'repetition_penalty': repeat_penalty,
                'streamer': streamer,
                'return_full_text': False,
            },
        )
        gen_thread.start()
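
        # Note: the streaming loop below is a small state machine. Models such
        # as Qwen3 emit an optional "<think>...</think>" span before the
        # answer; the loop routes that span into a separate chat message titled
        # "💭 Thought" and streams everything after "</think>" as the visible
        # answer. This assumes the tags arrive literally in the decoded stream,
        # i.e. the tokenizer does not strip them as special tokens.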
        # Buffers for thought vs. answer.
        thought_buf = ''
        answer_buf = ''
        in_thought = False

        # Stream tokens.
        for chunk in streamer:
            if cancel_event.is_set():
                break
            text = chunk

            # Detect the start of a thinking span.
            if not in_thought and '<think>' in text:
                in_thought = True
                # Insert a placeholder message for the thought.
                history.append({
                    'role': 'assistant',
                    'content': '',
                    'metadata': {'title': '💭 Thought'}
                })
                # Capture everything after the opening tag.
                after = text.split('<think>', 1)[1]
                thought_buf += after
                # The closing tag may arrive in the same chunk.
                if '</think>' in thought_buf:
                    before, after2 = thought_buf.split('</think>', 1)
                    history[-1]['content'] = before.strip()
                    in_thought = False
                    # Start the answer buffer.
                    answer_buf = after2
                    history.append({'role': 'assistant', 'content': answer_buf})
                else:
                    history[-1]['content'] = thought_buf
                yield history, debug
                continue

            # Continue streaming an open thought.
            if in_thought:
                thought_buf += text
                if '</think>' in thought_buf:
                    before, after2 = thought_buf.split('</think>', 1)
                    history[-1]['content'] = before.strip()
                    in_thought = False
                    # Start the answer buffer.
                    answer_buf = after2
                    history.append({'role': 'assistant', 'content': answer_buf})
                else:
                    history[-1]['content'] = thought_buf
                yield history, debug
                continue

            # Stream the answer.
            if not answer_buf:
                history.append({'role': 'assistant', 'content': ''})
            answer_buf += text
            history[-1]['content'] = answer_buf
            yield history, debug

        gen_thread.join()
        yield history, debug + prompt_debug
    except Exception as e:
        history.append({'role': 'assistant', 'content': f"Error: {e}"})
        yield history, debug
    finally:
        gc.collect()


def cancel_generation():
    cancel_event.set()
    return 'Generation cancelled.'


def update_default_prompt(enable_search):
    today = datetime.now().strftime('%Y-%m-%d')
    return f"You are a helpful assistant. Today is {today}."


# ------------------------------
# Gradio UI
# ------------------------------
# UI string translations.
UI_TEXTS: Dict[str, Dict[str, Union[str, Dict[str, str]]]] = {
    "en": {
        "title": "Yee-R1 Demo",
        "description": "Yee AI Data Security Expert",
        "select_model": "Select Model",
        "enable_search": "Enable Web Search",
        "system_prompt": "System Prompt",
        "generation_parameters": "Generation Parameters",
        "max_tokens": "Max Tokens",
        "temperature": "Temperature",
        "top_k": "Top-K",
        "top_p": "Top-P",
        "repeat_penalty": "Repetition Penalty",
        "web_search_settings": "Web Search Settings",
        "max_results": "Max Results",
        "max_chars_result": "Max Chars/Result",
        "search_timeout": "Search Timeout (s)",
        "clear_chat": "Clear Chat",
        "cancel_generation": "Cancel Generation",
        "chat_placeholder": "Type your message and press Enter...",
        "theme_label": "Select Theme",
        "language_label": "Select Language",
        "theme_light": "Light",
        "theme_dark": "Dark",
        "language_en": "English",
        "language_zh": "Chinese",
    },
    "zh": {
        "title": "小熠演示",
        "description": "小熠AI数据安全专家",
        "select_model": "选择模型",
        "enable_search": "启用网络搜索",
        "system_prompt": "系统提示词",
        "generation_parameters": "生成参数",
        "max_tokens": "最大生成长度",
        "temperature": "温度",
        "top_k": "Top-K",
        "top_p": "Top-P",
        "repeat_penalty": "重复惩罚",
        "web_search_settings": "网络搜索设置",
        "max_results": "最大结果数",
        "max_chars_result": "每个结果最大字符数",
        "search_timeout": "搜索超时 (秒)",
        "clear_chat": "清空聊天",
        "cancel_generation": "取消生成",
        "chat_placeholder": "请输入消息,按回车发送...",
        "theme_label": "选择主题",
        "language_label": "选择语言",
        "theme_light": "明亮模式",
        "theme_dark": "暗黑模式",
        "language_en": "English",
        "language_zh": "中文",
    },
}


def get_ui_text(language, key):
    return UI_TEXTS.get(language, UI_TEXTS["en"]).get(key, "")
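
# Illustrative only: get_ui_text falls back to the English table for unknown
# language keys and to an empty string for unknown text keys, e.g.
#   get_ui_text("zh", "select_model")  -> "选择模型"
#   get_ui_text("fr", "select_model")  -> "Select Model"   (falls back to "en")
#   get_ui_text("en", "no_such_key")   -> ""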

with gr.Blocks(title=get_ui_text("en", "title")) as demo:
    theme_dropdown = gr.Dropdown(
        label=get_ui_text("en", "theme_label"),
        choices=[get_ui_text("en", "theme_light"), get_ui_text("en", "theme_dark")],
        value=get_ui_text("en", "theme_light"),
    )
    language_dropdown = gr.Dropdown(
        label=get_ui_text("en", "language_label"),
        choices=[get_ui_text("en", "language_en"), get_ui_text("en", "language_zh")],
        value=get_ui_text("en", "language_en"),
    )
    title_md = gr.Markdown(value=get_ui_text("en", "title"))
    description_md = gr.Markdown(value=get_ui_text("en", "description"))

    with gr.Row():
        with gr.Column(scale=3):
            model_dd = gr.Dropdown(
                label=get_ui_text("en", "select_model"),
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
            )
            search_chk = gr.Checkbox(label=get_ui_text("en", "enable_search"), value=True)
            sys_prompt = gr.Textbox(
                label=get_ui_text("en", "system_prompt"),
                lines=3,
                value=update_default_prompt(search_chk.value),
            )
            gr.Markdown(f"### {get_ui_text('en', 'generation_parameters')}")
            max_tok = gr.Slider(64, 16384, value=4096, step=32, label=get_ui_text("en", "max_tokens"))
            temp = gr.Slider(0.1, 2.0, value=0.6, step=0.1, label=get_ui_text("en", "temperature"))
            k = gr.Slider(1, 100, value=40, step=1, label=get_ui_text("en", "top_k"))
            p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label=get_ui_text("en", "top_p"))
            rp = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label=get_ui_text("en", "repeat_penalty"))
            gr.Markdown(f"### {get_ui_text('en', 'web_search_settings')}")
            mr = gr.Number(value=6, precision=0, label=get_ui_text("en", "max_results"))
            mc = gr.Number(value=600, precision=0, label=get_ui_text("en", "max_chars_result"))
            st = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, value=5.0,
                           label=get_ui_text("en", "search_timeout"))
            clr = gr.Button(get_ui_text("en", "clear_chat"))
            cnl = gr.Button(get_ui_text("en", "cancel_generation"))
        with gr.Column(scale=7):
            chat = gr.Chatbot(type="messages", show_copy_all_button=True)
            txt = gr.Textbox(placeholder=get_ui_text("en", "chat_placeholder"))
            dbg = gr.Markdown()

    # Update UI text labels when the language selection changes. Runtime theme
    # switching cannot be done by returning a theme object from an event
    # handler, so the theme value is accepted but only text labels are updated.
    def toggle_language_and_theme(language, theme):
        # Map the displayed language name back to a locale key.
        lang_key = "en" if language == UI_TEXTS["en"]["language_en"] else "zh"
        texts = UI_TEXTS[lang_key]
        # One gr.update per output component, in the same order as `outputs`.
        return (
            gr.update(value=texts["title"]),
            gr.update(value=texts["description"]),
            gr.update(label=texts["select_model"]),
            gr.update(label=texts["enable_search"]),
            gr.update(label=texts["system_prompt"]),
            gr.update(label=texts["max_tokens"]),
            gr.update(label=texts["temperature"]),
            gr.update(label=texts["repeat_penalty"]),
            gr.update(label=texts["max_results"]),
            gr.update(label=texts["max_chars_result"]),
            gr.update(label=texts["search_timeout"]),
            gr.update(value=texts["clear_chat"]),
            gr.update(value=texts["cancel_generation"]),
            gr.update(placeholder=texts["chat_placeholder"]),
        )

    # Re-render labels on language change.
    language_dropdown.change(
        fn=toggle_language_and_theme,
        inputs=[language_dropdown, theme_dropdown],
        outputs=[
            title_md, description_md, model_dd, search_chk, sys_prompt,
            max_tok, temp, rp, mr, mc, st, clr, cnl, txt
        ],
    )
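
    # Sketch of the update contract (assumption: Gradio 4.x semantics): each
    # gr.update(...) in the returned tuple is applied positionally to the
    # component at the same index of the `outputs` list, so the tuple length
    # must match the outputs list exactly (14 entries here).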
    theme_dropdown.change(
        fn=toggle_language_and_theme,
        inputs=[language_dropdown, theme_dropdown],
        outputs=[
            title_md, description_md, model_dd, search_chk, sys_prompt,
            max_tok, temp, rp, mr, mc, st, clr, cnl, txt
        ],
    )

    search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
    cnl.click(fn=cancel_generation, outputs=dbg)
    txt.submit(
        fn=chat_response,
        inputs=[txt, chat, sys_prompt, search_chk, mr, mc,
                model_dd, max_tok, temp, k, p, rp, st],
        outputs=[chat, dbg],
    )

demo.launch()
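
# Local usage note (assumption, not verified against every spaces release):
# the @spaces.GPU decorator is documented as a no-op outside ZeroGPU Spaces,
# so this script should also run locally with `python app.py`, falling back to
# CPU or any visible CUDA device via device_map="auto".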