import gradio as gr
import torch
from typing import Iterator

from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the pre-trained model and tokenizer
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Move the model to GPU if available and switch to inference mode
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
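
# Optional sketch (an assumption, not part of the original app): on a CUDA
# device the model can be loaded in half precision to roughly halve memory use.
# torch_dtype is a standard from_pretrained argument:
#   model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)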

def generate_response_stream(message: str, history: list) -> Iterator[str]:
    """
    Generate a response in real time, yielding each new piece of text.
    """
    # Combine the conversation history with the new message
    full_input = ""
    for user_msg, bot_msg in history:
        full_input += f"User: {user_msg}\nBot: {bot_msg}\n"
    full_input += f"User: {message}\nBot: "

    # Tokenize the input text
    inputs = tokenizer.encode(full_input, return_tensors="pt", max_length=512, truncation=True).to(device)
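
    # Note (a hedged sketch, assuming this checkpoint ships a chat template):
    # instruct-tuned models usually respond better when the prompt is built
    # with the tokenizer's own template instead of a hand-rolled format:
    #   messages = [{"role": "user", "content": message}]
    #   inputs = tokenizer.apply_chat_template(
    #       messages, add_generation_prompt=True, return_tensors="pt"
    #   ).to(device)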

    # Generate the response token by token, reusing the KV cache between steps
    past_key_values = None
    with torch.no_grad():
        for _ in range(100):  # Max length of the response
            outputs = model(
                inputs,
                past_key_values=past_key_values,
                use_cache=True,
            )
            # Greedy decoding: pick the most likely next token
            next_token_logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

            # Stop generation if the end-of-sequence token is encountered
            if next_token.item() == tokenizer.eos_token_id:
                break

            # Decode and stream the newly generated token; the caller
            # accumulates these deltas, so no final joined yield is needed
            yield tokenizer.decode(next_token[0], skip_special_tokens=True)

            # Feed only the new token next step; the cache holds the prefix
            inputs = next_token
            past_key_values = outputs.past_key_values
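

# Alternative sketch (an assumption, requires a transformers version that
# provides TextIteratorStreamer, roughly >= 4.28): model.generate can stream
# through a background thread, avoiding the manual KV-cache loop above.
# Greedy decoding is kept for parity with generate_response_stream.
def generate_response_stream_v2(message: str, history: list) -> Iterator[str]:
    from threading import Thread

    from transformers import TextIteratorStreamer

    full_input = "".join(f"User: {u}\nBot: {b}\n" for u, b in history)
    full_input += f"User: {message}\nBot: "
    input_ids = tokenizer.encode(
        full_input, return_tensors="pt", max_length=512, truncation=True
    ).to(device)

    # The streamer decodes tokens as generate() produces them
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(inputs=input_ids, streamer=streamer, max_new_tokens=100, do_sample=False),
    )
    thread.start()
    for text_chunk in streamer:
        yield text_chunk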


# Build the chat UI with Gradio Blocks
with gr.Blocks() as demo:
    gr.Markdown("""<center><font size=8>Chat with DeepSeek</font></center>""")
    chatbot = gr.Chatbot(label="Conversation")
    textbox = gr.Textbox(lines=2, label="Your Message")

    def add_user_message(message, history):
        """Add the user message to the chat history and clear the textbox."""
        return "", history + [[message, None]]

    def bot_response(history):
        """Stream the bot response into the last history entry."""
        user_message = history[-1][0]
        history[-1][1] = ""
        for token in generate_response_stream(user_message, history[:-1]):
            history[-1][1] += token
            yield history

    # Event handlers: submit on Enter or via the Send button
    submit_event = textbox.submit(add_user_message, [textbox, chatbot], [textbox, chatbot]).then(
        bot_response, chatbot, chatbot
    )

    send_button = gr.Button("Send")
    send_event = send_button.click(add_user_message, [textbox, chatbot], [textbox, chatbot]).then(
        bot_response, chatbot, chatbot
    )
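
    # Alternative sketch (an assumption, not from the original app): Gradio's
    # gr.ChatInterface could replace the manual wiring above. It expects each
    # yield to be the full reply so far, so the token deltas are accumulated:
    #   def chat_fn(message, history):
    #       partial = ""
    #       for token in generate_response_stream(message, history):
    #           partial += token
    #           yield partial
    #   demo_alt = gr.ChatInterface(fn=chat_fn, title="Chat with DeepSeek")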

# Launch the app; queuing must be enabled for streaming (generator) handlers
demo.queue(api_open=False)
demo.launch(max_threads=40)