import gradio as gr
import json
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Default system prompt for the chat interface
DEFAULT_SYSTEM_PROMPT = """You are DeepThink, a helpful and knowledgeable AI assistant. You aim to provide accurate,
informative, and engaging responses while maintaining a professional and friendly demeanor."""
class ChatInterface:
    """Main chat interface handler with memory and parameter management"""

    def __init__(self):
        """Initialize the chat interface with default settings"""
        self.model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Half precision on GPU keeps the 7B weights manageable; full precision on CPU
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device)
        self.chat_history = []
        self.system_prompt = DEFAULT_SYSTEM_PROMPT
    def load_context_from_json(self, file_obj):
        """Load additional context from a JSON file"""
        if file_obj is None:
            return "No file uploaded", self.system_prompt
        try:
            # gr.File passes a file path (or a tempfile wrapper with .name),
            # not an open handle, so the JSON has to be read from disk
            path = file_obj if isinstance(file_obj, str) else file_obj.name
            content = json.loads(Path(path).read_text())
            if "system_prompt" in content:
                self.system_prompt = content["system_prompt"]
            return "Context loaded successfully!", self.system_prompt
        except Exception as e:
            return f"Error loading context: {str(e)}", self.system_prompt
    def generate_response(self, message, temperature, max_length, top_p, presence_penalty, frequency_penalty):
        """Generate AI response with given parameters"""
        # Format the input with system prompt and chat history
        conversation = f"System: {self.system_prompt}\n\n"
        for msg in self.chat_history:
            conversation += f"Human: {msg[0]}\nAssistant: {msg[1]}\n\n"
        conversation += f"Human: {message}\nAssistant:"

        # Generate a response with the specified sampling parameters.
        # transformers' generate() does not accept OpenAI-style presence_penalty /
        # frequency_penalty kwargs (passing them raises an error), so both sliders
        # are folded into repetition_penalty as a rough approximation.
        inputs = self.tokenizer(conversation, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.0 + 0.5 * (presence_penalty + frequency_penalty),
            pad_token_id=self.tokenizer.eos_token_id,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract the assistant's reply and update the chat history
        response = response.split("Assistant:")[-1].strip()
        self.chat_history.append((message, response))
        return response, self.format_chat_history()
    def format_chat_history(self):
        """Format chat history for display"""
        # gr.Chatbot renders (user, assistant) tuples directly, so no prefixes are needed
        return list(self.chat_history)

    def clear_history(self):
        """Clear the chat history"""
        self.chat_history = []
        return self.format_chat_history()
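
# Quick smoke test of the handler without the UI (illustrative values only):
#   ci = ChatInterface()
#   reply, history = ci.generate_response("Hello!", temperature=0.7, max_length=200,
#                                         top_p=0.9, presence_penalty=0.0,
#                                         frequency_penalty=0.0)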
# Initialize the chat interface
chat_interface = ChatInterface()

# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Main chat interface
            chatbot = gr.Chatbot(
                label="Chat History",
                height=600,
                show_label=True,
            )
            with gr.Row():
                message = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    lines=2
                )
                submit_btn = gr.Button("Send", variant="primary")
        with gr.Column(scale=1):
            # System settings and parameters
with gr.Group(label="System Configuration"): | |
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    value=DEFAULT_SYSTEM_PROMPT,
                    lines=4
                )
                context_file = gr.File(
                    label="Upload Context JSON",
                    file_types=[".json"]
                )
                upload_button = gr.Button("Load Context")
                context_status = gr.Textbox(label="Context Status", interactive=False)
with gr.Group(label="Generation Parameters"): | |
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                max_length = gr.Slider(
                    minimum=50,
                    maximum=2000,
                    value=500,
                    step=50,
                    label="Max Length"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top P"
                )
                presence_penalty = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Presence Penalty"
                )
                frequency_penalty = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Frequency Penalty"
                )

            clear_btn = gr.Button("Clear Chat History")
    # Event handlers
    def submit_message(message, temperature, max_length, top_p, presence_penalty, frequency_penalty):
        _, history = chat_interface.generate_response(
            message, temperature, max_length, top_p, presence_penalty, frequency_penalty
        )
        # Clear the input box and refresh the chat display
        return "", history

    submit_btn.click(
        submit_message,
        inputs=[message, temperature, max_length, top_p, presence_penalty, frequency_penalty],
        outputs=[message, chatbot]
    )
    message.submit(
        submit_message,
        inputs=[message, temperature, max_length, top_p, presence_penalty, frequency_penalty],
        outputs=[message, chatbot]
    )
    clear_btn.click(
        lambda: (chat_interface.clear_history(), ""),
        outputs=[chatbot, message]
    )
    upload_button.click(
        chat_interface.load_context_from_json,
        inputs=[context_file],
        outputs=[context_status, system_prompt]
    )
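
    # Keep the backend prompt in sync when the System Prompt box is edited by
    # hand (an assumed convenience; the original only updates it via JSON upload)
    def update_system_prompt(prompt):
        chat_interface.system_prompt = prompt

    system_prompt.change(update_system_prompt, inputs=[system_prompt], outputs=None)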

# Launch the interface
demo.launch()