import gradio as gr
import json
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Default system prompt for the chat interface
DEFAULT_SYSTEM_PROMPT = """You are DeepThink, a helpful and knowledgeable AI assistant. You aim to provide accurate,
informative, and engaging responses while maintaining a professional and friendly demeanor."""
class ChatInterface:
    """Main chat interface handler with memory and parameter management"""

    def __init__(self):
        """Initialize the chat interface with default settings"""
        self.model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Load in half precision and let accelerate place the weights
        # (requires the `accelerate` package); a 7B model in full float32
        # on CPU needs roughly 28 GB of RAM.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.chat_history = []
        self.system_prompt = DEFAULT_SYSTEM_PROMPT
    def load_context_from_json(self, file_obj):
        """Load additional context (e.g. a custom system prompt) from a JSON file"""
        if file_obj is None:
            return "No file uploaded", self.system_prompt
        try:
            # gr.File yields a file path (or an object with a .name path in
            # older Gradio versions), not an open file handle, so read the
            # file explicitly instead of calling json.load() on it directly.
            path = file_obj if isinstance(file_obj, (str, Path)) else file_obj.name
            content = json.loads(Path(path).read_text())
            if "system_prompt" in content:
                self.system_prompt = content["system_prompt"]
            return "Context loaded successfully!", self.system_prompt
        except Exception as e:
            return f"Error loading context: {str(e)}", self.system_prompt
    def generate_response(self, message, temperature, max_length, top_p, presence_penalty, frequency_penalty):
        """Generate an AI response with the given sampling parameters"""
        # Format the input with the system prompt and chat history
        conversation = f"System: {self.system_prompt}\n\n"
        for user_msg, assistant_msg in self.chat_history:
            conversation += f"Human: {user_msg}\nAssistant: {assistant_msg}\n\n"
        conversation += f"Human: {message}\nAssistant:"

        # Tokenize and move the tensors to the model's device
        inputs = self.tokenizer(conversation, return_tensors="pt").to(self.model.device)

        # do_sample=True is required for temperature/top_p to take effect,
        # and max_new_tokens bounds only the generated text, not the prompt.
        # transformers' generate() has no OpenAI-style presence/frequency
        # penalties, so both sliders are folded into repetition_penalty as a
        # rough approximation.
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.0 + presence_penalty + frequency_penalty,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens, then update chat history
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        self.chat_history.append((message, response))
        return response, self.format_chat_history()
    def format_chat_history(self):
        """Format chat history as (user, assistant) pairs for gr.Chatbot"""
        # gr.Chatbot renders the two speakers itself, so no "User:"/"Assistant:"
        # prefixes are needed inside the pairs.
        return list(self.chat_history)

    def clear_history(self):
        """Clear the chat history"""
        self.chat_history = []
        return self.format_chat_history()
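
# A quick standalone usage sketch (illustrative only, not part of the app):
#   chat = ChatInterface()
#   reply, history = chat.generate_response("Hello!", 0.7, 256, 0.9, 0.0, 0.0)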
# Initialize the chat interface
chat_interface = ChatInterface()
# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Main chat interface
            chatbot = gr.Chatbot(
                label="Chat History",
                height=600,
                show_label=True,
            )
            with gr.Row():
                message = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    lines=2,
                )
                submit_btn = gr.Button("Send", variant="primary")
        with gr.Column(scale=1):
            # System settings and parameters. gr.Group takes no label
            # parameter, so the section title is a Markdown heading instead.
            with gr.Group():
                gr.Markdown("### System Configuration")
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    value=DEFAULT_SYSTEM_PROMPT,
                    lines=4,
                )
                context_file = gr.File(
                    label="Upload Context JSON",
                    file_types=[".json"],
                )
                upload_button = gr.Button("Load Context")
                context_status = gr.Textbox(label="Context Status", interactive=False)
with gr.Group(label="Generation Parameters"):
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature"
)
max_length = gr.Slider(
minimum=50,
maximum=2000,
value=500,
step=50,
label="Max Length"
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.1,
label="Top P"
)
presence_penalty = gr.Slider(
minimum=0.0,
maximum=2.0,
value=0.0,
step=0.1,
label="Presence Penalty"
)
frequency_penalty = gr.Slider(
minimum=0.0,
maximum=2.0,
value=0.0,
step=0.1,
label="Frequency Penalty"
)
clear_btn = gr.Button("Clear Chat History")
    # Event handlers (must stay inside the Blocks context)
    def submit_message(message, temperature, max_length, top_p, presence_penalty, frequency_penalty):
        _, history = chat_interface.generate_response(
            message, temperature, max_length, top_p, presence_penalty, frequency_penalty
        )
        # Clear the input box and refresh the chat display
        return "", history
    submit_btn.click(
        submit_message,
        inputs=[message, temperature, max_length, top_p, presence_penalty, frequency_penalty],
        outputs=[message, chatbot],
    )
    message.submit(
        submit_message,
        inputs=[message, temperature, max_length, top_p, presence_penalty, frequency_penalty],
        outputs=[message, chatbot],
    )
    clear_btn.click(
        lambda: (chat_interface.clear_history(), ""),
        outputs=[chatbot, message],
    )
    upload_button.click(
        chat_interface.load_context_from_json,
        inputs=[context_file],
        outputs=[context_status, system_prompt],
    )
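
    # The System Prompt textbox above is otherwise never read back into the
    # backend; a minimal hook (an assumption, not in the original code) keeps
    # the two in sync whenever the textbox is edited.
    def update_system_prompt(prompt):
        # Hypothetical helper: push the edited prompt into the backend state
        chat_interface.system_prompt = prompt

    system_prompt.change(update_system_prompt, inputs=[system_prompt], outputs=[])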
# Launch the interface
demo.launch()