import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Load pre-trained model and tokenizer
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
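# Put the model in eval mode for inference (disables dropout and other
# training-only behavior); a standard step before generation
model.eval()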
def generate_response_stream(message: str, history: list):
    """
    Generate a response token by token, yielding each piece as it is produced.
    """
    # Combine the conversation history with the new message into a single prompt
    full_input = ""
    for user_msg, bot_msg in history:
        full_input += f"User: {user_msg}\nBot: {bot_msg}\n"
    full_input += f"User: {message}\nBot: "

    # Tokenize the input text
    inputs = tokenizer.encode(full_input, return_tensors="pt", max_length=512, truncation=True).to(device)

    # Generate the response one token at a time, reusing the KV cache
    past_key_values = None
    with torch.no_grad():
        for _ in range(100):  # Max length of the response
            outputs = model(
                inputs,
                past_key_values=past_key_values,
                use_cache=True
            )
            next_token_logits = outputs.logits[:, -1, :]
            # Greedy decoding: pick the most likely next token
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

            # Stop generation if the end-of-sequence token is encountered
            if next_token.item() == tokenizer.eos_token_id:
                break

            # Decode and yield only the newly generated token
            yield tokenizer.decode(next_token[0], skip_special_tokens=True)

            # Feed only the new token back in; past_key_values carries the context
            inputs = next_token
            past_key_values = outputs.past_key_values
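# Alternative sketch (illustrative only, not wired into the UI below):
# transformers also provides TextIteratorStreamer, which lets model.generate()
# handle sampling and stopping while text pieces are consumed from a background
# thread. The hand-rolled "User:/Bot:" prompt above is kept for consistency;
# tokenizer.apply_chat_template would be another option for chat-tuned models.
def generate_response_stream_v2(message: str, history: list):
    from threading import Thread
    from transformers import TextIteratorStreamer

    full_input = ""
    for user_msg, bot_msg in history:
        full_input += f"User: {user_msg}\nBot: {bot_msg}\n"
    full_input += f"User: {message}\nBot: "

    inputs = tokenizer(full_input, return_tensors="pt", max_length=512, truncation=True).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a background thread; the streamer yields text chunks
    # as they are produced
    thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=100))
    thread.start()
    for new_text in streamer:
        yield new_text
    thread.join()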
# Build the chat interface with Blocks
with gr.Blocks() as demo:
    gr.Markdown("""<center><font size=8>Chat with DeepSeek π</center>""")
    chatbot = gr.Chatbot(label="Conversation")
    textbox = gr.Textbox(lines=2, label="Your Message")

    def add_user_message(message, history):
        """Add the user message to the chat history and clear the textbox."""
        return "", history + [[message, None]]

    def bot_response(history):
        """Stream the bot response into the last history entry."""
        user_message = history[-1][0]
        history[-1][1] = ""
        for token in generate_response_stream(user_message, history[:-1]):
            history[-1][1] += token
            yield history

    # Event handlers: submitting the textbox adds the user message, then
    # streams the bot reply into the chatbot
    submit_event = textbox.submit(add_user_message, [textbox, chatbot], [textbox, chatbot]).then(
        bot_response, chatbot, chatbot
    )

    # Add a "Send" button with the same behavior
    send_button = gr.Button("π Send")
    send_event = send_button.click(add_user_message, [textbox, chatbot], [textbox, chatbot]).then(
        bot_response, chatbot, chatbot
    )
# Launch the app
demo.queue(api_open=False)
demo.launch(max_threads=40)
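# Running this file locally assumes gradio, transformers, and torch are
# installed and the machine has enough memory for the 1.5B model; on a
# Hugging Face Space the app is started from this file automatically.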