# Rakesh2205's picture
# Update app.py
# de6daef verified
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
import time
# Selectable models: the key is the label offered in the model dropdown,
# the value is the model identifier handed to transformers.pipeline().
AVAILABLE_MODELS = {
    # Models loaded through the "text-generation" pipeline
    "GPT-2 (Small)": "gpt2",
    "DialoGPT (Small)": "microsoft/DialoGPT-small",
    "BLOOM (560M)": "bigscience/bloom-560m",
    "OPT (125M)": "facebook/opt-125m",
    # Identifier contains "t5", so it is loaded as "text2text-generation"
    "T5 (Small)": "t5-small",
}
class ChatAgent:
    """Answer single chat messages with a swappable transformers pipeline.

    Attributes:
        current_model_name: Identifier of the model whose pipeline is
            currently held in ``model``.
        model: The loaded ``transformers`` pipeline, or ``None`` before the
            first successful load.
    """

    def __init__(self):
        """Create the agent and eagerly load the default ``gpt2`` model."""
        self.current_model_name = "gpt2"
        self.model = None
        self.load_model(self.current_model_name)
        print("ChatAgent initialized with model:", self.current_model_name)

    def load_model(self, model_name):
        """Load a new model and tokenizer.

        Identifiers containing ``"t5"`` (case-insensitive) are loaded with the
        ``text2text-generation`` task; everything else uses
        ``text-generation``.

        Args:
            model_name: Model identifier passed to ``pipeline(model=...)``.

        Raises:
            Exception: Whatever the pipeline construction raised, re-raised
                unchanged after being logged. In that case ``model`` and
                ``current_model_name`` keep their previous values.
        """
        try:
            print(f"Loading model: {model_name}")
            if "t5" in model_name.lower():
                task = "text2text-generation"
            else:
                task = "text-generation"
            loaded = pipeline(task, model=model_name)
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            raise
        # Commit the new state only after a successful load. Recording the
        # name up front would make chat() believe a model that failed to load
        # is active and silently keep answering with the previous pipeline.
        self.model = loaded
        self.current_model_name = model_name
        print(f"Model {model_name} loaded successfully")

    def chat(self, message: str, history: list, model_choice: str) -> tuple[str, dict]:
        """Process a single chat message and return the response.

        Args:
            message: The user's prompt.
            history: Previous conversation turns; accepted for interface
                compatibility but not used when generating.
            model_choice: A key of ``AVAILABLE_MODELS`` selecting the model.

        Returns:
            ``(response, metadata)``. ``metadata`` has the keys
            ``response_time``, ``input_tokens``, ``output_tokens`` and
            ``model``. If anything fails, the response is an apology string
            containing the error and ``metadata`` is an empty dict.
        """
        try:
            start_time = time.time()

            # Check if we need to switch models
            model_name = AVAILABLE_MODELS[model_choice]
            if model_name != self.current_model_name:
                self.load_model(model_name)

            # Generate response based on model type
            if "t5" in model_name.lower():
                generated = self.model(message, max_length=100, truncation=True)
                response = generated[0]['generated_text']
            else:
                generated = self.model(
                    message,
                    # max_length counts the prompt tokens as well, so a long
                    # prompt would leave no room for a reply; bound only the
                    # newly generated tokens instead.
                    max_new_tokens=100,
                    pad_token_id=self.model.tokenizer.eos_token_id,
                    truncation=True,
                )
                response = generated[0]['generated_text']
                # Clean up the response by removing the input message if it's included
                if response.startswith(message):
                    response = response[len(message):].strip()

            # Calculate metadata (the elapsed time includes any model switch)
            end_time = time.time()
            response_time = round(end_time - start_time, 2)
            input_tokens = len(self.model.tokenizer.encode(message))
            output_tokens = len(self.model.tokenizer.encode(response))
            metadata = {
                "response_time": f"{response_time}s",
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "model": model_choice
            }
            return response, metadata
        except Exception as e:
            # Top-level boundary for the UI: report the failure as the reply
            # text instead of propagating it.
            return f"Sorry, I encountered an error: {str(e)}", {}
def chat_response(message, history, model_choice):
    """Handler for chat messages.

    Args:
        message: Text the user submitted.
        history: Existing list of ``(user, bot)`` tuples, or a falsy value
            when the conversation is empty.
        model_choice: Model label forwarded to ``ChatAgent.chat``.

    Returns:
        The history list with the new ``(message, reply)`` tuple appended.
        The reply carries a metadata footer when the agent reported metadata.
    """
    # Create the agent once and keep it on the function object. Building a
    # ChatAgent loads a model, so doing it for every message would reload the
    # model each time and discard any model switch made earlier.
    agent = getattr(chat_response, 'agent', None)
    if agent is None:
        agent = ChatAgent()
        chat_response.agent = agent

    response, metadata = agent.chat(message, history, model_choice)

    # The agent returns an empty dict when it failed; in that case show the
    # error reply on its own instead of raising KeyError on the missing keys.
    if metadata:
        metadata_str = f"\n\n*[πŸ“Š Model: {metadata['model']} | ⏱️ Response Time: {metadata['response_time']} | πŸ“ Tokens: {metadata['input_tokens']}β†’{metadata['output_tokens']}]*"
    else:
        metadata_str = ""

    # Format messages as tuples for Gradio chat
    if not history:
        history = []
    history.append((message, response + metadata_str))
    return history
# Custom theme: take Gradio's Soft theme as the base, then override the
# background, border and button styling values listed below.
theme = gr.themes.Soft()
theme = theme.set(
    # Page and block surfaces
    body_background_fill="white",
    block_background_fill="white",
    block_border_width="0",
    # Primary buttons
    button_primary_background_fill="*primary_500",
    button_primary_text_color="white",
    # Secondary buttons
    button_secondary_background_fill="*neutral_200",
    button_secondary_text_color="*neutral_800",
)
# Custom stylesheet handed to gr.Blocks(css=...) below.
# The #title, .control-section, .section-title, .quick-actions,
# .quick-action-chip, .chat-container, .input-area and #clear-btn selectors
# match elem_id / elem_classes values assigned to components in the layout.
# The remaining selectors (.message*, textarea, .gradio-container, .contain)
# are not assigned in this file - presumably they target Gradio's own markup;
# verify against the installed Gradio version.
# NOTE(review): the "send-button" and "full-width" classes assigned in the
# layout have no rule here.
css = """
body {
background-color: white !important;
}
#title {
text-align: center;
margin-bottom: 0.5rem;
padding: 0.5rem;
border-radius: 8px;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
}
.control-section {
background: white;
padding: 0.8rem;
margin-bottom: 0.5rem;
border-radius: 8px;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
}
.section-title {
padding: 2px;
border-radius: 5px;
margin-bottom: 0.3rem;
}
.quick-actions {
display: flex !important;
justify-content: space-between !important;
gap: 8px !important;
padding: 4px !important;
margin-bottom: 8px !important;
}
.quick-action-chip {
flex: 1 !important;
background: #f0f2f5 !important;
border: none !important;
border-radius: 20px !important;
padding: 8px 16px !important;
font-size: 0.9em !important;
color: #1a1a1a !important;
transition: all 0.2s ease-in-out !important;
box-shadow: none !important;
margin-bottom: 2px !important;
height: auto !important;
line-height: 1.2 !important;
min-width: 120px !important;
}
.quick-action-chip:hover {
background: #e4e6e9 !important;
transform: translateY(-1px);
}
/* Enhanced Chat Styling */
.message {
position: relative !important;
padding: 12px 16px !important;
margin: 8px 0 !important;
border-radius: 12px !important;
max-width: 85% !important;
font-size: 0.95em !important;
line-height: 1.4 !important;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1) !important;
}
/* User message styling */
.message.user {
background: #007AFF !important;
color: white !important;
margin-left: auto !important;
border-bottom-right-radius: 4px !important;
}
/* Assistant message styling */
.message.assistant {
background: #f0f2f5 !important;
color: #1a1a1a !important;
margin-right: auto !important;
border-bottom-left-radius: 4px !important;
}
/* Metadata styling in assistant messages */
.message.assistant em {
display: block !important;
margin-top: 8px !important;
padding-top: 8px !important;
border-top: 1px solid rgba(0, 0, 0, 0.1) !important;
color: #666 !important;
font-size: 0.85em !important;
font-style: normal !important;
}
/* Chat container styling */
.chat-container {
border: 1px solid #e5e7eb !important;
border-radius: 12px !important;
background: white !important;
padding: 16px !important;
height: 460px !important;
overflow-y: auto !important;
scrollbar-width: thin !important;
scrollbar-color: #cbd5e1 transparent !important;
}
.chat-container::-webkit-scrollbar {
width: 6px !important;
}
.chat-container::-webkit-scrollbar-track {
background: transparent !important;
}
.chat-container::-webkit-scrollbar-thumb {
background-color: #cbd5e1 !important;
border-radius: 3px !important;
}
/* Input area styling */
.input-area {
border: 1px solid #e5e7eb !important;
border-radius: 12px !important;
background: white !important;
margin-top: 16px !important;
}
textarea {
min-height: 40px !important;
padding: 12px !important;
border-radius: 12px !important;
border: none !important;
background: transparent !important;
resize: none !important;
font-size: 0.95em !important;
}
textarea:focus {
outline: none !important;
box-shadow: 0 0 0 2px rgba(0, 122, 255, 0.1) !important;
}
.gradio-container {
background: white !important;
}
.contain {
background: white !important;
}
#clear-btn {
background: #f0f2f5 !important;
border: none !important;
border-radius: 8px !important;
color: #1a1a1a !important;
transition: all 0.2s ease-in-out !important;
width: 100% !important;
}
#clear-btn:hover {
background: #e4e6e9 !important;
}
"""
# Assemble the Gradio UI: a title row, then a two-column row with the model
# controls on the left and the chat on the right, followed by event wiring.
# Markdown literals are kept flush-left so their text is unchanged.
with gr.Blocks(theme=theme, css=css) as demo:
    # --- Title ---
    with gr.Row(equal_height=True):
        gr.Markdown(
            """
# πŸ€– Multi-Model AI Chat Assistant
Chat with different AI models and explore their capabilities!
""",
            elem_id="title",
        )

    with gr.Row(equal_height=True):
        # --- Left column: model selection and controls ---
        with gr.Column(scale=1, min_width=300):
            # Model picker
            with gr.Group(elem_classes="control-section"):
                gr.Markdown(
                    """
## 🎯 Model Selection
""",
                    elem_classes="section-title",
                )
                model_dropdown = gr.Dropdown(
                    label="Select AI Model",
                    info="Choose which AI model to chat with",
                    choices=list(AVAILABLE_MODELS),
                    value="GPT-2 (Small)",
                    container=True,
                )

            # Canned prompts that pre-fill the message box
            with gr.Group(elem_classes="control-section"):
                gr.Markdown(
                    """
## ⚑ Quick Prompts
""",
                    elem_classes="section-title",
                )
                with gr.Group(elem_classes="quick-actions"):
                    example_btn1 = gr.Button("🎭 Creative Story", elem_classes="quick-action-chip")
                    example_btn2 = gr.Button("πŸ€– Explain AI", elem_classes="quick-action-chip")
                    example_btn3 = gr.Button("🌺 Write Poetry", elem_classes="quick-action-chip")
                    example_btn4 = gr.Button("βš›οΈ Science Facts", elem_classes="quick-action-chip")

            # Conversation reset
            with gr.Group(elem_classes="control-section"):
                clear = gr.Button(
                    "πŸ—‘οΈ Clear Chat",
                    variant="secondary",
                    size="sm",
                    elem_id="clear-btn",
                    elem_classes="full-width",
                )

        # --- Right column: chat interface ---
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Chat History",
                value=[],
                height=460,
                container=False,
                show_copy_button=True,
                elem_classes="chat-container",
            )
            with gr.Row(elem_classes="input-area"):
                msg = gr.Textbox(
                    placeholder="Type your message here...",
                    label="Type your message",
                    show_label=False,
                    lines=2,
                    scale=8,
                    container=False,
                )
                submit = gr.Button(
                    "βœ‰οΈ Send",
                    variant="primary",
                    size="sm",
                    scale=1,
                    elem_classes="send-button",
                )

    # --- Quick-prompt buttons: each writes a fixed prompt into the textbox ---
    example_btn1.click(fn=lambda: "Tell me a story about a brave knight", inputs=None, outputs=msg)
    example_btn2.click(fn=lambda: "What is artificial intelligence?", inputs=None, outputs=msg)
    example_btn3.click(fn=lambda: "Write a poem about nature", inputs=None, outputs=msg)
    example_btn4.click(fn=lambda: "Explain quantum computing in simple terms", inputs=None, outputs=msg)

    # --- Sending: the Send button and the textbox submit event both run the
    # chat handler and then blank the textbox ---
    submit.click(
        fn=chat_response,
        inputs=[msg, chatbot, model_dropdown],
        outputs=[chatbot],
    ).then(fn=lambda: "", inputs=None, outputs=msg)

    msg.submit(
        fn=chat_response,
        inputs=[msg, chatbot, model_dropdown],
        outputs=[chatbot],
    ).then(fn=lambda: "", inputs=None, outputs=msg)

    # --- Clearing: writes None to the chatbot component ---
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot)
if __name__ == "__main__":
    # Script entry point: announce start-up, enable the request queue, then
    # launch with debug output on and public sharing off.
    print("Starting Multi-Model AI Chat Interface...")
    queued_demo = demo.queue()
    queued_demo.launch(share=False, debug=True)