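# Gradio chat demo for deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B.
# Each request is tagged with a UUID, and its query, latency, and sampling
# parameters are appended to query_logs.csv via capture_logs().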
import os
import logging
import spaces
import gradio as gr
import torch
import uuid
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
def capture_logs(uuid_label, log_file, log_body):
    """Write a UUID-tagged log entry to `log_file` and to the console."""
    logger = logging.getLogger('MyApp')
    logger.setLevel(logging.INFO)
    # Check if handlers are already added to avoid duplication
    if not logger.handlers:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.INFO)
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)
        logger.addHandler(fh)
        logger.addHandler(ch)
    logger.info('uuid: %s - %s', uuid_label, log_body)
print("CUDA available: ", torch.cuda.is_available()) | |
print("MPS available: ", torch.backends.mps.is_available()) | |
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    trust_remote_code=True,
    torch_dtype=torch.float16,  # Use float16 for better GPU memory efficiency
    device_map="auto"           # Automatically handle device placement
)
# Disable tokenizers parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Configure device
if torch.cuda.is_available():
    device = torch.device("cuda")
    # Print GPU information
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    device = torch.device("cpu")
    print("No GPU available, using CPU")
# Function to handle user input and generate a response.
# The @spaces.GPU decorator is assumed here (it is why `spaces` is imported):
# on a Hugging Face ZeroGPU Space it requests a GPU for the duration of each
# call, and it is a no-op when running elsewhere.
@spaces.GPU
def chatbot_response(query, tokens, top_k, top_p):
    uuid_label = str(uuid.uuid4())
    start_time = time.time()  # Start timer

    # Generate response using the model
    messages = [{'role': 'user', 'content': query}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    outputs = model.generate(
        inputs,
        max_new_tokens=tokens,
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens, skipping the prompt
    model_response = tokenizer.decode(
        outputs[0][len(inputs[0]):], skip_special_tokens=True)

    end_time = time.time()  # End timer
    performance_time = round(end_time - start_time, 2)
    log_body = 'query: %s, processTime: %s, tokens: %s, top_k: %s, top_p: %s' % (
        query, performance_time, tokens, top_k, top_p)
    capture_logs(uuid_label, 'query_logs.csv', log_body)
    return model_response
# Set up the Gradio interface
iface = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(label="Ask our DSChatbot Expert"),
        gr.Slider(label="Max New Tokens", minimum=128,
                  maximum=2048, step=128, value=512),
        gr.Slider(label="Top K", minimum=0, maximum=100, step=10, value=50),
        gr.Slider(label="Top P", minimum=0.0,
                  maximum=1.0, step=0.1, value=0.95),
    ],
    outputs=gr.Textbox(label="Hope it helps!"),
    title="DSChatbot"
)
if __name__ == "__main__":
    iface.launch()
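# Running this file starts a local Gradio server; open the printed URL
# (http://127.0.0.1:7860 by default) to interact with the chatbot.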