import os
import time
import gc
import threading
from datetime import datetime

import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer
import spaces  # Import spaces early to enable ZeroGPU support

# ------------------------------
# Global Cancellation Event
# ------------------------------
cancel_event = threading.Event()

# ------------------------------
# Qwen3 Model Definitions
# ------------------------------
MODELS = {
    "Qwen3-8B": {"repo_id": "Qwen/Qwen3-8B", "description": "Qwen3-8B - Largest model with highest capabilities"},
    "Qwen3-4B": {"repo_id": "Qwen/Qwen3-4B", "description": "Qwen3-4B - Good balance of performance and efficiency"},
    "Qwen3-1.7B": {"repo_id": "Qwen/Qwen3-1.7B", "description": "Qwen3-1.7B - Smaller model for faster responses"},
    "Qwen3-0.6B": {"repo_id": "Qwen/Qwen3-0.6B", "description": "Qwen3-0.6B - Ultra-lightweight model"},
}

# Global cache for pipelines to avoid re-loading.
PIPELINES = {}
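
# Note: device_map="auto" in load_pipeline below relies on the `accelerate`
# package being installed; transformers raises an error if a device map is
# requested without it.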

def load_pipeline(model_name):
    """
    Load and cache a transformers pipeline for text generation.
    Tries bfloat16, falls back to float16 or float32 if unsupported.
    """
    global PIPELINES
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=repo,
                tokenizer=repo,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
            )
            PIPELINES[model_name] = pipe
            return pipe
        except Exception:
            continue
    # Final fallback
    pipe = pipeline(
        task="text-generation",
        model=repo,
        tokenizer=repo,
        trust_remote_code=True,
        device_map="auto",
    )
    PIPELINES[model_name] = pipe
    return pipe

def format_conversation(history, system_prompt):
    """
    Flatten chat history and system prompt into a single string.
    """
    prompt = system_prompt.strip() + "\n"
    for turn in history:
        user_msg, assistant_msg = turn
        prompt += "User: " + user_msg.strip() + "\n"
        if assistant_msg:  # might be None or empty
            prompt += "Assistant: " + assistant_msg.strip() + "\n"
    if not prompt.strip().endswith("Assistant:"):
        prompt += "Assistant: "
    return prompt
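
# Optional alternative (not wired in): Qwen3 tokenizers ship a chat template,
# so the prompt could also be built with tokenizer.apply_chat_template instead
# of the plain "User:/Assistant:" format above. A minimal sketch, assuming the
# standard transformers chat-template API:
def format_conversation_with_template(history, system_prompt, tokenizer):
    messages = [{"role": "system", "content": system_prompt.strip()}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg.strip()})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg.strip()})
    # add_generation_prompt appends the assistant turn header so the model continues from it
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)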

@spaces.GPU  # Request a ZeroGPU slot while generating
def chat_response(user_msg, history, system_prompt,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty):
    """
    Generate streaming chat responses using the standard (user, assistant) format.
    """
    cancel_event.clear()
    # Add the user message to history
    history = history + [[user_msg, None]]
    # Format the conversation for the model
    prompt = format_conversation(history, system_prompt)
    try:
        pipe = load_pipeline(model_name)
        streamer = TextIteratorStreamer(pipe.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        # Run generation in a background thread; the streamer yields decoded text as it arrives.
        gen_thread = threading.Thread(
            target=pipe,
            args=(prompt,),
            kwargs={
                'max_new_tokens': max_tokens,
                'do_sample': True,  # required for temperature/top_k/top_p to take effect
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'repetition_penalty': repeat_penalty,
                'streamer': streamer,
                'return_full_text': False,
            },
        )
        gen_thread.start()
        # Stream the response
        assistant_text = ''
        for chunk in streamer:
            if cancel_event.is_set():
                break
            assistant_text += chunk
            history[-1][1] = assistant_text
            yield history
        # Cancelling only stops streaming to the UI; the generation thread
        # itself runs to completion before join() returns.
        gen_thread.join()
    except Exception as e:
        history[-1][1] = f"Error: {e}"
        yield history
    finally:
        gc.collect()

def cancel_generation():
    cancel_event.set()
    return 'Generation cancelled.'

def get_default_system_prompt():
    today = datetime.now().strftime('%Y-%m-%d')
    return f"""You are Qwen3, a helpful and friendly AI assistant created by Alibaba Cloud.
Today is {today}.
Be concise, accurate, and helpful in your responses."""

# CSS for improved visual style
css = """
.gradio-container {
    background-color: #f5f7fb !important;
}
.qwen-header {
    background: linear-gradient(90deg, #0099FF, #0066CC);
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
    text-align: center;
    color: white;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.qwen-container {
    border-radius: 10px;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
    background: white;
    padding: 20px;
    margin-bottom: 20px;
}
.controls-container {
    background: #f0f4fa;
    border-radius: 10px;
    padding: 15px;
    margin-bottom: 15px;
}
.model-select {
    border: 2px solid #0099FF !important;
    border-radius: 8px !important;
}
.button-primary {
    background-color: #0099FF !important;
    color: white !important;
}
.button-secondary {
    background-color: #6c757d !important;
    color: white !important;
}
.footer {
    text-align: center;
    margin-top: 20px;
    font-size: 0.8em;
    color: #666;
}
"""

# Get just the model name from the dropdown selection ("Name - description").
def get_model_name(full_selection):
    return full_selection.split(" - ")[0]

# Clear the chat history and the message box.
def clear_chat():
    return [], ""

# Append the user message to the history and clear the input box. The extra
# generation parameters are accepted (and ignored) so this handler can share
# the same input list as the response handler below.
def submit_message(user_input, history, system_prompt, model_name, max_tokens, temp, k, p, rp):
    return "", history + [[user_input, None]]

# ------------------------------
# Gradio UI
# ------------------------------
with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
    gr.HTML("""
    <div class="qwen-header">
        <h1>🤖 Qwen3 Chat</h1>
        <p>Interact with Alibaba Cloud's Qwen3 language models</p>
    </div>
    """)

    chatbot = gr.Chatbot(height=500)

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group(elem_classes="qwen-container"):
                model_dd = gr.Dropdown(
                    label="Select Qwen3 Model",
                    choices=[f"{k} - {v['description']}" for k, v in MODELS.items()],
                    value=f"{list(MODELS.keys())[0]} - {MODELS[list(MODELS.keys())[0]]['description']}",
                    elem_classes="model-select"
                )
            with gr.Group(elem_classes="controls-container"):
                gr.Markdown("### ⚙️ Generation Parameters")
                sys_prompt = gr.Textbox(label="System Prompt", lines=5, value=get_default_system_prompt())
                with gr.Row():
                    max_tok = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
                with gr.Row():
                    temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                    p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
                with gr.Row():
                    k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
                    rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
            with gr.Row():
                clr = gr.Button("Clear Chat", elem_classes="button-secondary")
                cnl = gr.Button("Cancel Generation", elem_classes="button-secondary")
        with gr.Column(scale=7):
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message and press Enter...",
                    lines=2,
                    show_label=False
                )
                send_btn = gr.Button("Send", variant="primary", elem_classes="button-primary")

    gr.HTML("""
    <div class="footer">
        <p>Qwen3 models developed by Alibaba Cloud. Interface powered by Gradio and ZeroGPU.</p>
    </div>
    """)

    # Event handlers
    clr.click(fn=clear_chat, outputs=[chatbot, msg])
    cnl.click(fn=cancel_generation)

    # Wrapper generator for the streaming handlers: a plain lambda would hand
    # Gradio a generator object instead of being a generator function itself,
    # so the chatbot would not update incrementally. yield from keeps streaming intact.
    def generate_response(history, system_prompt, model_sel, max_tokens, temperature, top_k, top_p, repeat_penalty):
        user_msg = history[-1][0]
        yield from chat_response(user_msg, history[:-1], system_prompt,
                                 get_model_name(model_sel), max_tokens,
                                 temperature, top_k, top_p, repeat_penalty)

    # Handle sending messages and generating responses
    msg.submit(
        fn=submit_message,
        inputs=[msg, chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
        outputs=[msg, chatbot]
    ).then(
        fn=generate_response,
        inputs=[chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
        outputs=chatbot
    )

    send_btn.click(
        fn=submit_message,
        inputs=[msg, chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
        outputs=[msg, chatbot]
    ).then(
        fn=generate_response,
        inputs=[chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp],
        outputs=chatbot
    )

if __name__ == "__main__":
    demo.queue()  # generator handlers stream through the queue (enabled by default on recent Gradio)
    demo.launch()
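
# Assumed Space dependencies (requirements.txt is not part of this file):
# gradio, transformers, torch, accelerate (for device_map="auto"), plus the
# `spaces` package provided on Hugging Face ZeroGPU hardware.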