import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import threading
import queue
import time

# Model configuration
model_name = "HelpingAI/Dhanishtha-2.0-preview"

# Global variables for model and tokenizer
model = None
tokenizer = None

def load_model():
    """Load the model and tokenizer"""
    global model, tokenizer

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    print("Model loaded successfully!")
class GradioTextStreamer(TextStreamer):
    """Custom TextStreamer for Gradio integration"""

    def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
        # TextStreamer accepts decode options such as skip_special_tokens only as
        # keyword arguments, so they must not be passed positionally.
        super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
        self.text_queue = queue.Queue()
        self.generated_text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Called when text is finalized"""
        self.generated_text += text
        self.text_queue.put(text)
        if stream_end:
            self.text_queue.put(None)

    def get_generated_text(self):
        """Get all generated text so far"""
        return self.generated_text

    def reset(self):
        """Reset the streamer"""
        self.generated_text = ""
        # Clear the queue
        while not self.text_queue.empty():
            try:
                self.text_queue.get_nowait()
            except queue.Empty:
                break
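# For reference, transformers also ships TextIteratorStreamer, which provides this same
# queue-backed streaming pattern out of the box. A minimal sketch of the equivalent usage
# (assuming the model, tokenizer, and tokenized model_inputs defined elsewhere in this file):
#
#   from transformers import TextIteratorStreamer
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   threading.Thread(target=model.generate,
#                    kwargs={**model_inputs, "max_new_tokens": 256, "streamer": streamer}).start()
#   for new_text in streamer:
#       print(new_text, end="")  # or accumulate and yield partial text to Gradio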
def generate_response(message, history, max_tokens, temperature, top_p):
    """Generate streaming response"""
    global model, tokenizer

    if model is None or tokenizer is None:
        yield "Model is still loading. Please wait..."
        return

    # Prepare conversation history
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize input
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Create and set up streamer
    streamer = GradioTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    streamer.reset()

    # Generation arguments
    generation_kwargs = {
        **model_inputs,
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "return_dict_in_generate": True
    }

    # Run generation in a separate thread
    def generate():
        try:
            with torch.no_grad():
                model.generate(**generation_kwargs)
        except Exception as e:
            streamer.text_queue.put(f"Error: {str(e)}")
            streamer.text_queue.put(None)

    thread = threading.Thread(target=generate)
    thread.start()

    # Stream the results
    generated_text = ""
    while True:
        try:
            new_text = streamer.text_queue.get(timeout=30)
            if new_text is None:
                break
            generated_text += new_text
            yield generated_text
        except queue.Empty:
            break

    thread.join(timeout=1)

    # Final yield with complete text
    if generated_text:
        yield generated_text
    else:
        yield "No response generated."
def chat_interface(message, history, max_tokens, temperature, top_p):
    """Main chat interface"""
    if not message.strip():
        # Nothing to send: yield the unchanged history so Gradio still receives
        # a value for both outputs.
        yield history, ""
        return

    # Add user message to history
    history.append([message, ""])

    # Generate response, streaming partial text into the last history entry
    for partial_response in generate_response(message, history[:-1], max_tokens, temperature, top_p):
        history[-1][1] = partial_response
        yield history, ""
# Load model on startup
print("Initializing model...")
load_model()

# Create Gradio interface
with gr.Blocks(title="Dhanishtha-2.0-preview Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🤖 Dhanishtha-2.0-preview Chat
        Chat with the **HelpingAI/Dhanishtha-2.0-preview** model!
        """
    )

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                bubble_full_width=False,
                height=500,
                show_copy_button=True
            )

            with gr.Row():
                msg = gr.Textbox(
                    container=False,
                    placeholder="Type your message here...",
                    label="Message",
                    autofocus=True,
                    scale=7
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Parameters")

            max_tokens = gr.Slider(
                minimum=1,
                maximum=4096,
                value=2048,
                step=1,
                label="Max Tokens",
                info="Maximum number of tokens to generate"
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness in generation"
            )

            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p",
                info="Controls diversity of generation"
            )

            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

    # Event handlers
    msg.submit(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1
    )

    send_btn.click(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1
    )

    clear_btn.click(
        lambda: ([], ""),
        outputs=[chatbot, msg]
    )

    # Example prompts
    gr.Examples(
        examples=[
            ["Hello! Who are you?"],
            ["Explain quantum computing in simple terms"],
            ["Write a short story about a robot learning to paint"],
            ["What are the benefits of renewable energy?"],
            ["Help me write a Python function to sort a list"]
        ],
        inputs=msg,
        label="💡 Example Prompts"
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
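A note on dependencies: to run this file as a Space, requirements.txt would need at least gradio, torch, and transformers (per the imports above), plus accelerate for device_map="auto". A minimal sketch, with versions left unpinned since none are specified in the source:

    gradio
    torch
    transformers
    accelerate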