import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load pre-trained model and tokenizer
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
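# Note (assumption, not part of the original script): full-precision CPU inference
# can be slow for a 1.5B-parameter model; when a CUDA GPU is available, passing
# torch_dtype=torch.float16 to from_pretrained is a common way to cut memory use.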

# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

def generate_response_stream(message: str, history: list):
    """
    Stream the model's response token by token (greedy decoding with a KV cache).
    """
    # Combine the conversation history with the new message
    full_input = ""
    for user_msg, bot_msg in history:
        full_input += f"User: {user_msg}\nBot: {bot_msg}\n"
    full_input += f"User: {message}\nBot: "
    
    # Tokenize input text
    inputs = tokenizer.encode(full_input, return_tensors="pt", max_length=512, truncation=True).to(device)
    
    # Generate the response token by token, reusing the KV cache between steps
    past_key_values = None
    with torch.no_grad():
        for _ in range(100):  # Maximum number of new tokens
            outputs = model(
                inputs,
                past_key_values=past_key_values,
                use_cache=True,
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

            # Stop as soon as the end-of-sequence token is produced
            if next_token.item() == tokenizer.eos_token_id:
                break

            # Stream only the newly generated token; the caller accumulates the
            # full response, so yielding a joined string again at the end would
            # duplicate the text in the chat window.
            yield tokenizer.decode(next_token[0], skip_special_tokens=True)

            # Feed just the new token back in; past_key_values carries the context
            inputs = next_token
            past_key_values = outputs.past_key_values


# Create ChatInterface
with gr.Blocks() as demo:
    gr.Markdown("""<center><font size=8>Chat with DeepSeek πŸš€</center>""")
    
    chatbot = gr.Chatbot(label="Conversation")
    textbox = gr.Textbox(lines=2, label="Your Message")
    
    def add_user_message(message, history):
        """Add user message to the chat history."""
        return "", history + [[message, None]]

    def bot_response(history):
        """Generate bot response using streaming."""
        user_message = history[-1][0]
        history[-1][1] = ""
        for token in generate_response_stream(user_message, history[:-1]):
            history[-1][1] += token
            yield history

    # Event handlers
    submit_event = textbox.submit(add_user_message, [textbox, chatbot], [textbox, chatbot]).then(
        bot_response, chatbot, chatbot
    )

    # Add a "Send" button
    send_button = gr.Button("πŸš€ Send")
    send_event = send_button.click(add_user_message, [textbox, chatbot], [textbox, chatbot]).then(
        bot_response, chatbot, chatbot
    )

# Launch the app
demo.queue(api_open=False)
demo.launch(max_threads=40)
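
# To try this locally (assuming gradio, transformers and torch are installed),
# run this script with Python; Gradio serves the UI at http://127.0.0.1:7860 by default.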