File size: 3,390 Bytes
192cb9d
0a1e0cd
fcde440
8ee6ae0
192cb9d
0a1e0cd
 
d0e9a6d
0a1e0cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ee6ae0
ba1b260
 
 
 
26db21a
192cb9d
4d5c96b
192cb9d
4d5c96b
b930686
 
 
 
8ee6ae0
192cb9d
0a1e0cd
8ee6ae0
 
b930686
 
 
 
 
 
 
 
 
8ee6ae0
 
192cb9d
1a9dd5a
0a1e0cd
 
 
 
8ee6ae0
192cb9d
 
 
 
 
0a1e0cd
 
 
 
 
 
 
 
 
192cb9d
 
8ee6ae0
0a1e0cd
26db21a
0a1e0cd
26db21a
 
0a1e0cd
 
 
 
192cb9d
 
 
 
 
0a1e0cd
 
 
 
 
 
 
 
192cb9d
 
 
8ee6ae0
 
192cb9d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import logging
import spaces
import gradio as gr
import torch
import uuid
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

def capture_logs(log_body, log_file, uuid_label):
    """Write one INFO record through the shared 'MyApp' logger.

    On the first call the logger is configured with two handlers that share an
    ``asctime - name - levelname - message`` format: a UTF-8 file handler for
    ``log_file`` and a console handler (``logging.StreamHandler``, stderr by
    default). Later calls reuse those handlers, so a different ``log_file``
    passed after the first call has no effect.

    Args:
        log_body: Value printed immediately after ``"uuid: "``.
        log_file: Path for the file handler; only used on the first call.
        uuid_label: Value printed after the ``" - "`` separator.

    NOTE(review): the parameter names do not match the output layout. The
    record reads ``uuid: <log_body> - <uuid_label>``. chatbot_response passes
    the UUID as ``log_body`` and the message as ``uuid_label``, which yields
    ``uuid: <uuid> - <message>``. Change both together if either changes.
    """
    logger = logging.getLogger('MyApp')
    logger.setLevel(logging.INFO)

    # getLogger returns the same object on every call; attaching handlers each
    # time would duplicate every line, so configure only once.
    if not logger.handlers:
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')

        # Queries are arbitrary user text; pin UTF-8 instead of relying on the
        # platform default encoding, which may not be able to encode them.
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        console_handler = logging.StreamHandler()

        # File handler first, console second (same order as before).
        for handler in (file_handler, console_handler):
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            logger.addHandler(handler)

    logger.info('uuid: %s - %s', log_body, uuid_label)


print("CUDA available: ", torch.cuda.is_available())
print("MPS available: ", torch.backends.mps.is_available())

# Hugging Face Hub id of the checkpoint served by this app.
MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# float16 halves accelerator memory, but on a CPU-only host half precision is
# slow (and some older PyTorch builds lack CPU float16 kernels), so fall back
# to float32 when neither CUDA nor MPS is available.
_MODEL_DTYPE = (
    torch.float16
    if torch.cuda.is_available() or torch.backends.mps.is_available()
    else torch.float32
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=_MODEL_DTYPE,
    device_map="auto",  # let transformers decide where the weights live
)

# Setting TOKENIZERS_PARALLELISM explicitly (to any value) silences the
# huggingface/tokenizers warning printed when the process forks after the
# tokenizer has used its thread pool. Note: "True" keeps parallelism ENABLED.
os.environ["TOKENIZERS_PARALLELISM"] = "True"


# Report the compute device. `device` is informational only: inference inputs
# are moved to `model.device`, chosen by device_map="auto" above.
if torch.cuda.is_available():
    device = torch.device("cuda")
    # Print GPU information
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    device = torch.device("cpu")
    print("No GPU available, using CPU")


# Function to handle user input and generate a response
@spaces.GPU
def chatbot_response(query, tokens, top_k, top_p):
    """Generate one sampled chat completion for ``query`` and log the request.

    Args:
        query: User message, sent to the model as a single ``user`` turn.
        tokens: Maximum number of new tokens to generate.
        top_k: Top-k sampling cutoff (in transformers, 0 turns top-k
            filtering off).
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded completion text, without the prompt or special tokens.
    """
    uuid_label = str(uuid.uuid4())
    start = time.perf_counter()  # monotonic: immune to wall-clock adjustments

    # Gradio sliders deliver JSON numbers. generate() rejects a non-int top_k,
    # so coerce the integer-valued controls before using them.
    tokens = int(tokens)
    top_k = int(top_k)

    messages = [{'role': 'user', 'content': query}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

    outputs = model.generate(
        input_ids,
        # A single unpadded sequence: every position is real, so an all-ones
        # mask is exact and generate() need not infer one from pad/eos ids.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=tokens,
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens.
    model_response = tokenizer.decode(
        outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)

    performance_time = round(time.perf_counter() - start, 2)

    # NOTE: the "pocessTime" (sic) key is kept verbatim so existing parsers of
    # query_logs.csv keep working.
    log_body = 'query: %s, pocessTime: %s,  tokens: %s, top_k: %s, top_p: %s' % (
        query, performance_time, tokens, top_k, top_p)

    # capture_logs prints its FIRST argument right after "uuid:", so the UUID
    # goes first and the message last, despite its parameter names.
    capture_logs(uuid_label, 'query_logs.csv', log_body)

    return model_response


# Set up the Gradio interface.
# Inputs are listed in the same order as the positional parameters of
# chatbot_response(query, tokens, top_k, top_p).
_question_input = gr.Textbox(label="Ask our DSChatbot Expert")
_max_new_tokens_input = gr.Slider(
    label="Max New Tokens", minimum=128, maximum=2048, step=128, value=512,
)
_top_k_input = gr.Slider(
    label="Top K", minimum=0, maximum=100, step=10, value=50,
)
_top_p_input = gr.Slider(
    label="Top P", minimum=0.0, maximum=1.0, step=0.1, value=0.95,
)
_answer_output = gr.Textbox(label="Hope it helps!")

iface = gr.Interface(
    fn=chatbot_response,
    inputs=[_question_input, _max_new_tokens_input, _top_k_input, _top_p_input],
    outputs=_answer_output,
    title="DSChatbot",
)

# Launch the web app only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()