import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import spaces
import time

# Load the model and tokenizer
model_name = "sarvamai/sarvam-m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)
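# Note: device_map="auto" requires the `accelerate` package, and torch.bfloat16
# assumes a GPU with bf16 support; swap in torch.float16 on older GPUs.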

# Braille spinner frames cycled through while the model is "thinking"
indicators = [
    "Thinking ⠋", "Thinking ⠙", "Thinking ⠹", "Thinking ⠸", "Thinking ⠼",
    "Thinking ⠴", "Thinking ⠦", "Thinking ⠧", "Thinking ⠇", "Thinking ⠏",
]


@spaces.GPU  # request a GPU per call on Hugging Face ZeroGPU Spaces
def generate_response(prompt, chat_history):
    # Show the user's message immediately, then stream the assistant's reply
    chat_history.append({"role": "user", "content": prompt})
    yield chat_history, ""
    # Rebuild the history without previous thought-process messages; only
    # final assistant answers are sent back to the model
    processed_chat_history = []
    for message in chat_history:
        if message["role"] == "assistant":
            metadata = message.get("metadata", {})
            if isinstance(metadata, dict) and metadata.get("title", "").startswith("Thought"):
                continue  # skip collapsed "Thought for ... seconds" blocks
        processed_chat_history.append(message)

    text = tokenizer.apply_chat_template(
        processed_chat_history, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # TextIteratorStreamer yields decoded text chunks as generation progresses
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Run generation on a background thread so tokens can be streamed as they
    # are produced
    generation_kwargs = dict(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=8192,
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
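    # model.generate runs on the background thread and pushes decoded text into
    # the streamer's internal queue; iterating over the streamer below blocks
    # until the next chunk arrives and stops once generation finishes.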

    # Track the reasoning ("thinking") text and the final answer separately
    reasoning_content = ""
    content = ""
    reasoning_done = False
    start_time = time.time()
    chat_history.append(
        {"role": "assistant", "content": reasoning_content, "metadata": {"title": "Thinking..."}}
    )
    indicator_index = 0
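    # gr.Chatbot renders assistant messages whose metadata has a "title" as a
    # collapsible "thought" bubble, so updating the title in place on each
    # yield animates the spinner.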
    for new_text in streamer:
        if not reasoning_done and "</think>" in new_text:
            # Split the chunk at the closing tag: text before it belongs to
            # the reasoning, text after it starts the final answer
            before, _, after = new_text.partition("</think>")
            reasoning_content += before
            chat_history[-1]["content"] = reasoning_content
            thought_duration = time.time() - start_time
            chat_history[-1]["metadata"] = {"title": f"Thought for {thought_duration:.2f} seconds"}
            reasoning_done = True
            content += after
            chat_history.append({"role": "assistant", "content": content})
        elif not reasoning_done:
            # Advance the spinner frame and append the new reasoning text
            indicator_index = (indicator_index + 1) % len(indicators)
            chat_history[-1]["metadata"] = {"title": indicators[indicator_index]}
            reasoning_content += new_text
            chat_history[-1]["content"] = reasoning_content
        else:
            content += new_text
            chat_history[-1]["content"] = content
        yield chat_history, ""

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Sarvam M Demo")
    chatbot = gr.Chatbot(height=500, type="messages")
    msg = gr.Textbox(label="Your Message")
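
    # generate_response is a generator: each yield streams an updated history
    # into the chatbot, and the empty string it returns clears the textbox.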
    msg.submit(generate_response, [msg, chatbot], [chatbot, msg])

if __name__ == "__main__":
    # mcp_server=True also exposes the app's functions as MCP tools
    # (supported in recent Gradio releases)
    demo.launch(mcp_server=True)
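
# To run locally (a sketch, assuming these dependencies; the `spaces` package
# is only active on Hugging Face ZeroGPU hardware, and @spaces.GPU is a no-op
# elsewhere):
#   pip install gradio transformers torch accelerate spaces
#   python app.py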