import time

import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Dictionary of available models: UI display name -> Hugging Face model id.
AVAILABLE_MODELS = {
    "GPT-2 (Small)": "gpt2",
    "DialoGPT (Small)": "microsoft/DialoGPT-small",
    "BLOOM (560M)": "bigscience/bloom-560m",
    "OPT (125M)": "facebook/opt-125m",
    "T5 (Small)": "t5-small"
}


class ChatAgent:
    """Hold one Hugging Face ``pipeline`` and swap it when the selected model changes.

    Attributes:
        current_model_name: Hugging Face id of the pipeline held in ``model``.
        model: the active ``transformers`` pipeline.
    """

    def __init__(self, model_name: str = "gpt2"):
        """Load ``model_name`` (defaults to ``"gpt2"``) as the initial model.

        Raises:
            Exception: propagated from ``load_model`` if the initial load fails.
        """
        self.current_model_name = None
        self.model = None
        self.load_model(model_name)
        print("ChatAgent initialized with model:", self.current_model_name)

    def load_model(self, model_name):
        """Load a new model and tokenizer and make it the active model.

        T5 models use a ``text2text-generation`` pipeline; every other model
        uses ``text-generation``. ``model`` and ``current_model_name`` are only
        updated after the pipeline has been built. A failed load therefore
        leaves the previously loaded model active and correctly labelled, so
        the next request retries the load.

        Raises:
            Exception: whatever ``transformers.pipeline`` raises, re-raised
                after being printed.
        """
        print(f"Loading model: {model_name}")
        task = "text2text-generation" if "t5" in model_name.lower() else "text-generation"
        try:
            new_model = pipeline(task, model=model_name)
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            raise  # bare raise keeps the original traceback
        self.model = new_model
        self.current_model_name = model_name
        print(f"Model {model_name} loaded successfully")

    def chat(self, message: str, history: list, model_choice: str) -> tuple[str, dict]:
        """Process a single chat message and return the response.

        Args:
            message: the user's input text.
            history: prior chat turns. Currently unused; the model only sees
                ``message``.
            model_choice: a key of ``AVAILABLE_MODELS``.

        Returns:
            ``(response, metadata)``. ``metadata`` holds the keys
            ``response_time``, ``input_tokens``, ``output_tokens`` and ``model``.
            On any error the response is an apology message and ``metadata``
            is an empty dict.
        """
        try:
            start_time = time.perf_counter()

            # Switch models if the user picked a different one.
            model_name = AVAILABLE_MODELS[model_choice]
            if model_name != self.current_model_name:
                self.load_model(model_name)

            # Generate the response according to the pipeline type.
            if "t5" in model_name.lower():
                generated = self.model(message, max_length=100, truncation=True)
            else:
                generated = self.model(
                    message,
                    max_length=100,
                    pad_token_id=self.model.tokenizer.eos_token_id,
                    truncation=True
                )
            response = generated[0]['generated_text']

            # Text-generation output may echo the prompt; strip it if present.
            if response.startswith(message):
                response = response[len(message):].strip()

            # Calculate metadata.
            response_time = round(time.perf_counter() - start_time, 2)
            input_tokens = len(self.model.tokenizer.encode(message))
            output_tokens = len(self.model.tokenizer.encode(response))

            metadata = {
                "response_time": f"{response_time}s",
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "model": model_choice
            }
            return response, metadata
        except Exception as e:
            # UI boundary: report the failure in the chat instead of crashing.
            return f"Sorry, I encountered an error: {str(e)}", {}


def chat_response(message, history, model_choice):
    """Handle a chat message from the Gradio UI.

    The ``ChatAgent`` is created lazily on the first message and cached on
    this function as ``chat_response.agent``, so each model is loaded once
    rather than once per message.

    Args:
        message: text from the input box.
        history: current Chatbot value, a list of ``(user, bot)`` pairs, or
            ``None`` after the chat has been cleared.
        model_choice: selected key of ``AVAILABLE_MODELS``.

    Returns:
        The updated chat history as a new list of ``(user, bot)`` tuples.
    """
    history = list(history) if history else []

    # Nothing to send: leave the conversation unchanged.
    if not message or not message.strip():
        return history

    agent = getattr(chat_response, "agent", None)
    if agent is None:
        try:
            # Start with the model the user selected instead of always gpt2.
            agent = ChatAgent(AVAILABLE_MODELS.get(model_choice, "gpt2"))
        except Exception as e:
            history.append((message, f"Sorry, I encountered an error: {str(e)}"))
            return history
        chat_response.agent = agent

    response, metadata = agent.chat(message, history, model_choice)

    # ChatAgent.chat returns empty metadata on error; skip the footer then.
    metadata_str = ""
    if metadata:
        metadata_str = f"\n\n*[📊 Model: {metadata['model']} | ⏱️ Response Time: {metadata['response_time']} | 📝 Tokens: {metadata['input_tokens']}→{metadata['output_tokens']}]*"

    # Format messages as tuples for the Gradio Chatbot.
    history.append((message, response + metadata_str))
    return history


# Create custom theme
theme = gr.themes.Soft().set(
    body_background_fill="white",
    block_background_fill="white",
    block_border_width="0",
    button_primary_background_fill="*primary_500",
    button_primary_text_color="white",
    button_secondary_background_fill="*neutral_200",
    button_secondary_text_color="*neutral_800"
)

# Add custom CSS for better styling
css = """
body {
    background-color: white !important;
}

#title {
    text-align: center;
    margin-bottom: 0.5rem;
    padding: 0.5rem;
    border-radius: 8px;
    box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
}

.control-section {
    background: white;
    padding: 0.8rem;
    margin-bottom: 0.5rem;
    border-radius: 8px;
    box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
}

.section-title {
    padding: 2px;
    border-radius: 5px;
    margin-bottom: 0.3rem;
}

.quick-actions {
    display: flex !important;
    justify-content: space-between !important;
    gap: 8px !important;
    padding: 4px !important;
    margin-bottom: 8px !important;
}

.quick-action-chip {
    flex: 1 !important;
    background: #f0f2f5 !important;
    border: none !important;
    border-radius: 20px !important;
    padding: 8px 16px !important;
    font-size: 0.9em !important;
    color: #1a1a1a !important;
    transition: all 0.2s ease-in-out !important;
    box-shadow: none !important;
    margin-bottom: 2px !important;
    height: auto !important;
    line-height: 1.2 !important;
    min-width: 120px !important;
}

.quick-action-chip:hover {
    background: #e4e6e9 !important;
    transform: translateY(-1px);
}

/* Enhanced Chat Styling */
.message {
    position: relative !important;
    padding: 12px 16px !important;
    margin: 8px 0 !important;
    border-radius: 12px !important;
    max-width: 85% !important;
    font-size: 0.95em !important;
    line-height: 1.4 !important;
    box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1) !important;
}

/* User message styling */
.message.user {
    background: #007AFF !important;
    color: white !important;
    margin-left: auto !important;
    border-bottom-right-radius: 4px !important;
}

/* Assistant message styling */
.message.assistant {
    background: #f0f2f5 !important;
    color: #1a1a1a !important;
    margin-right: auto !important;
    border-bottom-left-radius: 4px !important;
}

/* Metadata styling in assistant messages */
.message.assistant em {
    display: block !important;
    margin-top: 8px !important;
    padding-top: 8px !important;
    border-top: 1px solid rgba(0, 0, 0, 0.1) !important;
    color: #666 !important;
    font-size: 0.85em !important;
    font-style: normal !important;
}

/* Chat container styling */
.chat-container {
    border: 1px solid #e5e7eb !important;
    border-radius: 12px !important;
    background: white !important;
    padding: 16px !important;
    height: 460px !important;
    overflow-y: auto !important;
    scrollbar-width: thin !important;
    scrollbar-color: #cbd5e1 transparent !important;
}

.chat-container::-webkit-scrollbar {
    width: 6px !important;
}

.chat-container::-webkit-scrollbar-track {
    background: transparent !important;
}

.chat-container::-webkit-scrollbar-thumb {
    background-color: #cbd5e1 !important;
    border-radius: 3px !important;
}

/* Input area styling */
.input-area {
    border: 1px solid #e5e7eb !important;
    border-radius: 12px !important;
    background: white !important;
    margin-top: 16px !important;
}

textarea {
    min-height: 40px !important;
    padding: 12px !important;
    border-radius: 12px !important;
    border: none !important;
    background: transparent !important;
    resize: none !important;
    font-size: 0.95em !important;
}

textarea:focus {
    outline: none !important;
    box-shadow: 0 0 0 2px rgba(0, 122, 255, 0.1) !important;
}

.gradio-container {
    background: white !important;
}

.contain {
    background: white !important;
}

#clear-btn {
    background: #f0f2f5 !important;
    border: none !important;
    border-radius: 8px !important;
    color: #1a1a1a !important;
    transition: all 0.2s ease-in-out !important;
    width: 100% !important;
}

#clear-btn:hover {
    background: #e4e6e9 !important;
}
"""

# Create Gradio interface
with gr.Blocks(theme=theme, css=css) as demo:
    with gr.Row(equal_height=True):
        gr.Markdown(
            """
            # 🤖 Multi-Model AI Chat Assistant
            Chat with different AI models and explore their capabilities!
            """,
            elem_id="title"
        )

    with gr.Row(equal_height=True):
        # Left column for model selection and controls
        with gr.Column(scale=1, min_width=300):
            # Model Selection Section
            with gr.Group(elem_classes="control-section"):
                gr.Markdown(
                    """
                    ## 🎯 Model Selection
                    """,
                    elem_classes="section-title"
                )
                model_dropdown = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
                    value="GPT-2 (Small)",
                    label="Select AI Model",
                    info="Choose which AI model to chat with",
                    container=True
                )

            # Quick Actions Section
            with gr.Group(elem_classes="control-section"):
                gr.Markdown(
                    """
                    ## ⚡ Quick Prompts
                    """,
                    elem_classes="section-title"
                )
                with gr.Group(elem_classes="quick-actions"):
                    example_btn1 = gr.Button("🎭 Creative Story", elem_classes="quick-action-chip")
                    example_btn2 = gr.Button("🤖 Explain AI", elem_classes="quick-action-chip")
                    example_btn3 = gr.Button("🌺 Write Poetry", elem_classes="quick-action-chip")
                    example_btn4 = gr.Button("⚛️ Science Facts", elem_classes="quick-action-chip")

            # Clear Chat Section
            with gr.Group(elem_classes="control-section"):
                clear = gr.Button(
                    "🗑️ Clear Chat",
                    elem_id="clear-btn",
                    size="sm",
                    variant="secondary",
                    elem_classes="full-width"
                )

        # Right column for chat interface
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                value=[],
                label="Chat History",
                height=460,
                show_copy_button=True,
                container=False,
                elem_classes="chat-container"
            )

            with gr.Row(elem_classes="input-area"):
                msg = gr.Textbox(
                    label="Type your message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=8,
                    show_label=False,
                    container=False
                )
                submit = gr.Button(
                    "✉️ Send",
                    scale=1,
                    variant="primary",
                    size="sm",
                    elem_classes="send-button"
                )

    # Example button click handlers: each fills the input box with a prompt.
    example_btn1.click(lambda: "Tell me a story about a brave knight", None, msg)
    example_btn2.click(lambda: "What is artificial intelligence?", None, msg)
    example_btn3.click(lambda: "Write a poem about nature", None, msg)
    example_btn4.click(lambda: "Explain quantum computing in simple terms", None, msg)

    # Set up event handlers: send the message, then clear the input box.
    submit.click(
        chat_response,
        inputs=[msg, chatbot, model_dropdown],
        outputs=[chatbot]
    ).then(
        lambda: "", None, msg
    )

    msg.submit(
        chat_response,
        inputs=[msg, chatbot, model_dropdown],
        outputs=[chatbot]
    ).then(
        lambda: "", None, msg
    )

    clear.click(lambda: None, None, chatbot)

if __name__ == "__main__":
    print("Starting Multi-Model AI Chat Interface...")
    demo.queue().launch(debug=True, share=False)