# Hugging Face Spaces app: Multi-Model AI Chat Assistant (Gradio).
# (The "Spaces: Sleeping" lines were web-page residue from the Space listing, not code.)
import time

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Display name -> Hugging Face model id for every model the UI can serve.
# All entries are small checkpoints so the Space can run on CPU.
AVAILABLE_MODELS = {
    "GPT-2 (Small)": "gpt2",
    "DialoGPT (Small)": "microsoft/DialoGPT-small",
    "BLOOM (560M)": "bigscience/bloom-560m",
    "OPT (125M)": "facebook/opt-125m",
    "T5 (Small)": "t5-small",
}
class ChatAgent:
    """Wraps a Hugging Face pipeline and hot-swaps models between messages."""

    def __init__(self):
        # Eagerly load the default model; load_model raises on failure so a
        # broken default is visible at startup rather than on first message.
        self.current_model_name = "gpt2"
        self.model = None
        self.load_model(self.current_model_name)
        print("ChatAgent initialized with model:", self.current_model_name)

    def load_model(self, model_name):
        """Load a new model pipeline.

        State (``model`` / ``current_model_name``) is committed only after the
        pipeline is successfully built. The original code set
        ``current_model_name`` *before* loading, so a failed load left the old
        pipeline bound to the new name and ``chat`` would never retry the load.

        Raises:
            Exception: whatever ``transformers.pipeline`` raises on failure.
        """
        try:
            print(f"Loading model: {model_name}")
            # T5 is an encoder-decoder model and needs the text2text task;
            # the other entries are plain causal LMs.
            if "t5" in model_name.lower():
                new_model = pipeline("text2text-generation", model=model_name)
            else:
                new_model = pipeline("text-generation", model=model_name)
            self.model = new_model
            self.current_model_name = model_name
            print(f"Model {model_name} loaded successfully")
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            raise

    def chat(self, message: str, history: list, model_choice: str) -> tuple[str, dict]:
        """Process a single chat message and return ``(response, metadata)``.

        Args:
            message: the user's input text.
            history: prior chat turns (currently unused by generation).
            model_choice: display name, a key of ``AVAILABLE_MODELS``.

        Returns:
            ``(response_text, metadata)``; ``metadata`` is ``{}`` on error.
        """
        try:
            start_time = time.time()
            # Switch models lazily, only when the dropdown selection changed.
            model_name = AVAILABLE_MODELS[model_choice]
            if model_name != self.current_model_name:
                self.load_model(model_name)
            # Generate response based on model type.
            if "t5" in model_name.lower():
                generated = self.model(message, max_length=100, truncation=True)
                response = generated[0]['generated_text']
            else:
                generated = self.model(
                    message,
                    max_length=100,
                    # Silence the "pad_token_id not set" warning for GPT-style models.
                    pad_token_id=self.model.tokenizer.eos_token_id,
                    truncation=True
                )
                response = generated[0]['generated_text']
            # Causal LMs echo the prompt; strip it so only the reply is shown.
            if response.startswith(message):
                response = response[len(message):].strip()
            end_time = time.time()
            response_time = round(end_time - start_time, 2)
            input_tokens = len(self.model.tokenizer.encode(message))
            output_tokens = len(self.model.tokenizer.encode(response))
            metadata = {
                "response_time": f"{response_time}s",
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "model": model_choice
            }
            return response, metadata
        except Exception as e:
            # Best-effort: surface the error as the chat reply instead of crashing the UI.
            return f"Sorry, I encountered an error: {str(e)}", {}
def chat_response(message, history, model_choice):
    """Gradio handler: append the model's reply (plus metadata) to the history.

    Bug fixes vs. the original:
    - the cached-agent idiom checked ``hasattr(chat_response, 'agent')`` but
      never *stored* the agent, so every message built a fresh ``ChatAgent``
      and reloaded the model; we now memoize it on the function object.
    - ``ChatAgent.chat`` returns ``{}`` metadata on error, which made the
      metadata f-string raise ``KeyError``; we now skip the footer in that case.
    """
    if not hasattr(chat_response, 'agent'):
        chat_response.agent = ChatAgent()
    agent = chat_response.agent
    response, metadata = agent.chat(message, history, model_choice)
    if metadata:
        metadata_str = f"\n\n*[🔍 Model: {metadata['model']} | ⏱️ Response Time: {metadata['response_time']} | 📊 Tokens: {metadata['input_tokens']}→{metadata['output_tokens']}]*"
    else:
        metadata_str = ""
    # Format messages as (user, assistant) tuples for the Gradio Chatbot.
    if not history:
        history = []
    history.append((message, response + metadata_str))
    return history
# Custom Gradio theme: flat white surfaces with the default primary accent.
theme = gr.themes.Soft().set(
    body_background_fill="white",
    block_background_fill="white",
    block_border_width="0",
    # "*primary_500" / "*neutral_*" reference the theme's own color palette.
    button_primary_background_fill="*primary_500",
    button_primary_text_color="white",
    button_secondary_background_fill="*neutral_200",
    button_secondary_text_color="*neutral_800",
)
# Custom CSS injected into the Gradio page: white background, pill-shaped
# quick-action buttons, iMessage-style chat bubbles, and a styled input area.
css = """
body {
    background-color: white !important;
}
#title {
    text-align: center;
    margin-bottom: 0.5rem;
    padding: 0.5rem;
    border-radius: 8px;
    box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
}
.control-section {
    background: white;
    padding: 0.8rem;
    margin-bottom: 0.5rem;
    border-radius: 8px;
    box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
}
.section-title {
    padding: 2px;
    border-radius: 5px;
    margin-bottom: 0.3rem;
}
.quick-actions {
    display: flex !important;
    justify-content: space-between !important;
    gap: 8px !important;
    padding: 4px !important;
    margin-bottom: 8px !important;
}
.quick-action-chip {
    flex: 1 !important;
    background: #f0f2f5 !important;
    border: none !important;
    border-radius: 20px !important;
    padding: 8px 16px !important;
    font-size: 0.9em !important;
    color: #1a1a1a !important;
    transition: all 0.2s ease-in-out !important;
    box-shadow: none !important;
    margin-bottom: 2px !important;
    height: auto !important;
    line-height: 1.2 !important;
    min-width: 120px !important;
}
.quick-action-chip:hover {
    background: #e4e6e9 !important;
    transform: translateY(-1px);
}
/* Enhanced Chat Styling */
.message {
    position: relative !important;
    padding: 12px 16px !important;
    margin: 8px 0 !important;
    border-radius: 12px !important;
    max-width: 85% !important;
    font-size: 0.95em !important;
    line-height: 1.4 !important;
    box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1) !important;
}
/* User message styling */
.message.user {
    background: #007AFF !important;
    color: white !important;
    margin-left: auto !important;
    border-bottom-right-radius: 4px !important;
}
/* Assistant message styling */
.message.assistant {
    background: #f0f2f5 !important;
    color: #1a1a1a !important;
    margin-right: auto !important;
    border-bottom-left-radius: 4px !important;
}
/* Metadata styling in assistant messages */
.message.assistant em {
    display: block !important;
    margin-top: 8px !important;
    padding-top: 8px !important;
    border-top: 1px solid rgba(0, 0, 0, 0.1) !important;
    color: #666 !important;
    font-size: 0.85em !important;
    font-style: normal !important;
}
/* Chat container styling */
.chat-container {
    border: 1px solid #e5e7eb !important;
    border-radius: 12px !important;
    background: white !important;
    padding: 16px !important;
    height: 460px !important;
    overflow-y: auto !important;
    scrollbar-width: thin !important;
    scrollbar-color: #cbd5e1 transparent !important;
}
.chat-container::-webkit-scrollbar {
    width: 6px !important;
}
.chat-container::-webkit-scrollbar-track {
    background: transparent !important;
}
.chat-container::-webkit-scrollbar-thumb {
    background-color: #cbd5e1 !important;
    border-radius: 3px !important;
}
/* Input area styling */
.input-area {
    border: 1px solid #e5e7eb !important;
    border-radius: 12px !important;
    background: white !important;
    margin-top: 16px !important;
}
textarea {
    min-height: 40px !important;
    padding: 12px !important;
    border-radius: 12px !important;
    border: none !important;
    background: transparent !important;
    resize: none !important;
    font-size: 0.95em !important;
}
textarea:focus {
    outline: none !important;
    box-shadow: 0 0 0 2px rgba(0, 122, 255, 0.1) !important;
}
.gradio-container {
    background: white !important;
}
.contain {
    background: white !important;
}
#clear-btn {
    background: #f0f2f5 !important;
    border: none !important;
    border-radius: 8px !important;
    color: #1a1a1a !important;
    transition: all 0.2s ease-in-out !important;
    width: 100% !important;
}
#clear-btn:hover {
    background: #e4e6e9 !important;
}
"""
# Create Gradio interface: model/controls column on the left, chat on the right.
with gr.Blocks(theme=theme, css=css) as demo:
    with gr.Row(equal_height=True):
        gr.Markdown(
            """
            # 🤖 Multi-Model AI Chat Assistant
            Chat with different AI models and explore their capabilities!
            """,
            elem_id="title"
        )
    with gr.Row(equal_height=True):
        # Left column for model selection and controls
        with gr.Column(scale=1, min_width=300):
            # Model Selection Section
            with gr.Group(elem_classes="control-section"):
                gr.Markdown(
                    """
                    ## 🎯 Model Selection
                    """,
                    elem_classes="section-title"
                )
                model_dropdown = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
                    value="GPT-2 (Small)",
                    label="Select AI Model",
                    info="Choose which AI model to chat with",
                    container=True
                )
            # Quick Actions Section: one-click prompt presets
            with gr.Group(elem_classes="control-section"):
                gr.Markdown(
                    """
                    ## ⚡ Quick Prompts
                    """,
                    elem_classes="section-title"
                )
                with gr.Group(elem_classes="quick-actions"):
                    example_btn1 = gr.Button("📝 Creative Story", elem_classes="quick-action-chip")
                    example_btn2 = gr.Button("🤖 Explain AI", elem_classes="quick-action-chip")
                    example_btn3 = gr.Button("🌺 Write Poetry", elem_classes="quick-action-chip")
                    example_btn4 = gr.Button("⚛️ Science Facts", elem_classes="quick-action-chip")
            # Clear Chat Section
            with gr.Group(elem_classes="control-section"):
                clear = gr.Button("🗑️ Clear Chat", elem_id="clear-btn", size="sm", variant="secondary", elem_classes="full-width")
        # Right column for chat interface
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                value=[],
                label="Chat History",
                height=460,
                show_copy_button=True,
                container=False,
                elem_classes="chat-container"
            )
            with gr.Row(elem_classes="input-area"):
                msg = gr.Textbox(
                    label="Type your message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=8,
                    show_label=False,
                    container=False
                )
                submit = gr.Button(
                    "✉️ Send",
                    scale=1,
                    variant="primary",
                    size="sm",
                    elem_classes="send-button"
                )
    # Quick-prompt buttons just fill the message box; the user still hits Send.
    example_btn1.click(lambda: "Tell me a story about a brave knight", None, msg)
    example_btn2.click(lambda: "What is artificial intelligence?", None, msg)
    example_btn3.click(lambda: "Write a poem about nature", None, msg)
    example_btn4.click(lambda: "Explain quantum computing in simple terms", None, msg)
    # Send button and Enter both post the message, then clear the textbox.
    submit.click(
        chat_response,
        inputs=[msg, chatbot, model_dropdown],
        outputs=[chatbot]
    ).then(
        lambda: "",
        None,
        msg
    )
    msg.submit(
        chat_response,
        inputs=[msg, chatbot, model_dropdown],
        outputs=[chatbot]
    ).then(
        lambda: "",
        None,
        msg
    )
    # Returning None resets the Chatbot component to an empty history.
    clear.click(lambda: None, None, chatbot)
if __name__ == "__main__":
    print("Starting Multi-Model AI Chat Interface...")
    # queue() serializes generation requests so concurrent users don't race
    # the single shared pipeline; share=False keeps the app local to the Space.
    demo.queue().launch(debug=True, share=False)