import gradio as gr
import os
from groq import Groq
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import io
import base64
from datetime import datetime, timedelta
import json
import numpy as np
from statsmodels.tsa.arima.model import ARIMA
from sklearn.linear_model import LinearRegression
import matplotlib.dates as mdates
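# Assumed third-party dependencies, inferred from the imports above (a sketch
# of what `pip install` would need): gradio, groq, pandas, matplotlib,
# seaborn, numpy, statsmodels, scikit-learn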
# Set the style for better-looking charts
plt.style.use('ggplot')
sns.set_palette("pastel")
def validate_api_key(api_key):
    """Validate whether the API key has the correct format."""
    # Basic format check for Groq API keys (they typically start with 'gsk_')
    if not api_key.strip():
        return False, "API key cannot be empty"
    if not api_key.startswith("gsk_"):
        return False, "Invalid API key format. Groq API keys typically start with 'gsk_'"
    return True, "API key looks valid"
def test_api_connection(api_key):
    """Test the API connection with a minimal request."""
    try:
        client = Groq(api_key=api_key)
        # Make a minimal API call to test the connection
        client.chat.completions.create(
            model="llama3-70b-8192",
            messages=[{"role": "user", "content": "test"}],
            max_tokens=5
        )
        return True, "API connection successful"
    except Exception as e:
        # Handle all exceptions since Groq might not expose specific error types
        if "authentication" in str(e).lower() or "api key" in str(e).lower():
            return False, "Authentication failed: Invalid API key"
        return False, f"Error connecting to Groq API: {str(e)}"
# Ensure the analytics directory exists
os.makedirs("analytics", exist_ok=True)
def log_chat_interaction(model, tokens_used, response_time, user_message, message_type, session_id=None):
    """Log a chat interaction to the analytics file."""
    timestamp = datetime.now().isoformat()

    # Generate a session ID if none is provided
    if session_id is None:
        session_id = f"session_{datetime.now().strftime('%Y%m%d%H%M%S')}_{hash(timestamp) % 1000}"

    log_file = "analytics/chat_log.json"

    # Extract the message intent/category through simple keyword matching
    intent_categories = {
        "code": ["code", "programming", "function", "class", "algorithm", "debug"],
        "creative": ["story", "poem", "creative", "imagine", "write", "generate"],
        "technical": ["explain", "how does", "technical", "details", "documentation"],
        "data": ["data", "analysis", "statistics", "graph", "chart", "dataset"],
        "general": []  # Default category
    }

    message_content = user_message.lower() if isinstance(user_message, str) else ""
    message_intent = "general"
    for intent, keywords in intent_categories.items():
        if any(keyword in message_content for keyword in keywords):
            message_intent = intent
            break

    log_entry = {
        "timestamp": timestamp,
        "model": model,
        "tokens_used": tokens_used,
        "response_time_sec": response_time,
        "message_length": len(message_content),
        "message_type": message_type,
        "message_intent": message_intent,
        "session_id": session_id,
        "day_of_week": datetime.now().strftime("%A"),
        "hour_of_day": datetime.now().hour
    }

    # Append to the existing log, or start a new one
    logs = []
    if os.path.exists(log_file):
        try:
            with open(log_file, "r") as f:
                logs = json.load(f)
        except (json.JSONDecodeError, OSError):
            logs = []

    logs.append(log_entry)
    with open(log_file, "w") as f:
        json.dump(logs, f, indent=2)

    return session_id
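# A logged entry looks roughly like this (values are illustrative):
# {
#     "timestamp": "2024-05-01T12:34:56.789", "model": "llama3-70b-8192",
#     "tokens_used": 512, "response_time_sec": 1.83, "message_length": 42,
#     "message_type": "Code Helper", "message_intent": "code",
#     "session_id": "session_20240501123456_123",
#     "day_of_week": "Wednesday", "hour_of_day": 12
# }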
def get_template_prompt(template_name):
    """Get the system prompt for a given template name."""
    templates = {
        "General Assistant": "You are a helpful, harmless, and honest AI assistant.",
        "Code Helper": "You are a programming assistant. Provide detailed code explanations and examples.",
        "Creative Writer": "You are a creative writing assistant. Generate engaging and imaginative content.",
        "Technical Expert": "You are a technical expert. Provide accurate, detailed technical information.",
        "Data Analyst": "You are a data analysis assistant. Help interpret and analyze data effectively."
    }
    return templates.get(template_name, "")
def enhanced_chat_with_groq(api_key, model, user_message, temperature, max_tokens, top_p, chat_history, template_name="", session_id=None):
    """Chat with the Groq API and log the interaction for analytics."""
    start_time = datetime.now()

    # Get the system prompt if a template is selected
    system_prompt = get_template_prompt(template_name) if template_name else ""

    # Validate the key format, then the connection
    is_valid, message = validate_api_key(api_key)
    if not is_valid:
        return chat_history + [[user_message, f"Error: {message}"]], session_id

    connection_valid, connection_message = test_api_connection(api_key)
    if not connection_valid:
        return chat_history + [[user_message, f"Error: {connection_message}"]], session_id

    try:
        # Format the conversation history for the API
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        for human, assistant in chat_history:
            messages.append({"role": "user", "content": human})
            messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": user_message})

        # Make the API call
        client = Groq(api_key=api_key)
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p
        )

        # Calculate metrics
        end_time = datetime.now()
        response_time = (end_time - start_time).total_seconds()
        tokens_used = response.usage.total_tokens

        # Determine the message type based on the template
        message_type = template_name if template_name else "general"

        # Log the interaction
        session_id = log_chat_interaction(
            model=model,
            tokens_used=tokens_used,
            response_time=response_time,
            user_message=user_message,
            message_type=message_type,
            session_id=session_id
        )

        # Extract the response text
        assistant_response = response.choices[0].message.content
        return chat_history + [[user_message, assistant_response]], session_id
    except Exception as e:
        error_message = f"Error: {str(e)}"
        return chat_history + [[user_message, error_message]], session_id
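# The `messages` payload built above follows the standard chat-completions
# shape, e.g. (illustrative):
# [
#     {"role": "system", "content": "You are a helpful ... assistant."},
#     {"role": "user", "content": "earlier question"},
#     {"role": "assistant", "content": "earlier answer"},
#     {"role": "user", "content": "current message"}
# ]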
def clear_conversation():
    """Clear the conversation history."""
    return [], None  # Empty chat history and a reset session ID
def plt_to_html(fig):
    """Convert a matplotlib figure to an HTML <img> tag."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", dpi=100)
    buf.seek(0)
    img_str = base64.b64encode(buf.read()).decode("utf-8")
    plt.close(fig)
    return f'<img src="data:image/png;base64,{img_str}" alt="Chart">'
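# The returned string embeds the PNG as base64, so it can be assigned directly
# to a gr.HTML component, e.g. `chart_html = plt_to_html(fig)`.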
def predict_future_usage(df, days_ahead=7):
    """Predict future token usage based on historical data."""
    if len(df) < 5:  # Need a minimum amount of data for prediction
        return None, "Insufficient data for prediction"

    # Group by date and get total tokens per day
    df['date'] = pd.to_datetime(df['timestamp']).dt.date
    daily_data = df.groupby('date')['tokens_used'].sum().reset_index()
    daily_data['date'] = pd.to_datetime(daily_data['date'])
    daily_data = daily_data.sort_values('date')

    try:
        # Simple linear regression on day index vs. daily tokens
        X = np.array(range(len(daily_data))).reshape(-1, 1)
        y = daily_data['tokens_used'].values
        model = LinearRegression()
        model.fit(X, y)

        # Predict the next `days_ahead` days
        future_days = pd.date_range(
            start=daily_data['date'].max() + timedelta(days=1),
            periods=days_ahead
        )
        future_X = np.array(range(len(daily_data), len(daily_data) + days_ahead)).reshape(-1, 1)
        predictions = model.predict(future_X)

        # Build the prediction dataframe, clamping out negative values
        prediction_df = pd.DataFrame({
            'date': future_days,
            'predicted_tokens': np.maximum(predictions, 0)
        })

        # Create the visualization
        fig = plt.figure(figsize=(12, 6))
        plt.plot(daily_data['date'], daily_data['tokens_used'], 'o-', label='Historical Usage')
        plt.plot(prediction_df['date'], prediction_df['predicted_tokens'], 'o--', label='Predicted Usage')
        plt.title(f'Token Usage Forecast (Next {days_ahead} Days)')
        plt.xlabel('Date')
        plt.ylabel('Token Usage')
        plt.legend()
        plt.grid(True)
        plt.xticks(rotation=45)
        plt.tight_layout()
        return plt_to_html(fig), prediction_df
    except Exception as e:
        return None, f"Error in prediction: {str(e)}"
def export_analytics_csv(df):
    """Export analytics data to CSV."""
    try:
        output_path = "analytics/export_" + datetime.now().strftime("%Y%m%d_%H%M%S") + ".csv"
        df.to_csv(output_path, index=False)
        return f"Data exported to {output_path}"
    except Exception as e:
        return f"Error exporting data: {str(e)}"
def generate_enhanced_analytics(date_range=None):
    """Generate enhanced analytics from the chat log."""
    log_file = "analytics/chat_log.json"
    if not os.path.exists(log_file):
        return "No analytics data available yet.", None, None, None, None, None, None, None, None, []

    try:
        with open(log_file, "r") as f:
            logs = json.load(f)
        if not logs:
            return "No analytics data available yet.", None, None, None, None, None, None, None, None, []

        # Convert to a DataFrame
        df = pd.DataFrame(logs)
        df["timestamp"] = pd.to_datetime(df["timestamp"])

        # Apply the date filter if provided
        if date_range and date_range != "all":
            end_date = datetime.now()
            if date_range == "last_7_days":
                start_date = end_date - timedelta(days=7)
            elif date_range == "last_30_days":
                start_date = end_date - timedelta(days=30)
            elif date_range == "last_90_days":
                start_date = end_date - timedelta(days=90)
            else:  # Default to all time if the option is unrecognized
                start_date = df["timestamp"].min()
            df = df[(df["timestamp"] >= start_date) & (df["timestamp"] <= end_date)]
        # 1. Token usage by model
        model_usage = df.groupby("model").agg({
            "tokens_used": "sum",
            "timestamp": "count"
        }).reset_index()
        model_usage.columns = ["model", "total_tokens", "request_count"]

        fig1 = plt.figure(figsize=(10, 6))
        ax1 = sns.barplot(x="model", y="total_tokens", data=model_usage)
        plt.title("Token Usage by Model", fontsize=14)
        plt.xlabel("Model", fontsize=12)
        plt.ylabel("Total Tokens Used", fontsize=12)
        plt.xticks(rotation=45)
        # Add values on top of the bars
        for i, v in enumerate(model_usage["total_tokens"]):
            ax1.text(i, v + 0.1, f"{v:,}", ha='center')
        plt.tight_layout()
        model_usage_img = plt_to_html(fig1)
        # 2. Usage over time
        df["date"] = df["timestamp"].dt.date
        daily_usage = df.groupby("date").agg({"tokens_used": "sum"}).reset_index()

        fig2 = plt.figure(figsize=(10, 6))
        plt.plot(daily_usage["date"], daily_usage["tokens_used"], marker="o", linestyle="-", linewidth=2)
        plt.title("Daily Token Usage", fontsize=14)
        plt.xlabel("Date", fontsize=12)
        plt.ylabel("Tokens Used", fontsize=12)
        plt.grid(True, alpha=0.3)
        # Format the x-axis dates
        plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        plt.gca().xaxis.set_major_locator(mdates.AutoDateLocator())
        plt.xticks(rotation=45)
        plt.tight_layout()
        daily_usage_img = plt_to_html(fig2)
        # 3. Response time by model
        model_response_time = df.groupby("model").agg({
            "response_time_sec": ["mean", "median", "std"]
        }).reset_index()
        model_response_time.columns = ["model", "mean_time", "median_time", "std_time"]

        fig3 = plt.figure(figsize=(10, 6))
        ax3 = sns.barplot(x="model", y="mean_time", data=model_response_time)
        # Add error bars (one standard deviation)
        for i, v in enumerate(model_response_time["mean_time"]):
            std = model_response_time.iloc[i]["std_time"]
            if not np.isnan(std):
                plt.errorbar(i, v, yerr=std, fmt='none', color='black', capsize=5)
        plt.title("Response Time by Model", fontsize=14)
        plt.xlabel("Model", fontsize=12)
        plt.ylabel("Average Response Time (seconds)", fontsize=12)
        plt.xticks(rotation=45)
        # Add values on top of the bars
        for i, v in enumerate(model_response_time["mean_time"]):
            ax3.text(i, v + 0.1, f"{v:.2f}s", ha='center')
        plt.tight_layout()
        response_time_img = plt_to_html(fig3)
        # 4. Usage by time of day and day of week
        if "hour_of_day" in df.columns and "day_of_week" in df.columns:
            # Map day of week to a number to ensure correct ordering
            day_order = {day: i for i, day in enumerate(['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday'])}
            df['day_num'] = df['day_of_week'].map(day_order)

            hourly_usage = df.groupby("hour_of_day").agg({"tokens_used": "sum"}).reset_index()
            daily_usage_by_weekday = df.groupby("day_of_week").agg({"tokens_used": "sum"}).reset_index()

            # Sort by day of week
            daily_usage_by_weekday['day_num'] = daily_usage_by_weekday['day_of_week'].map(day_order)
            daily_usage_by_weekday = daily_usage_by_weekday.sort_values('day_num')

            fig4 = plt.figure(figsize=(18, 8))
            # Hourly usage chart
            plt.subplot(1, 2, 1)
            sns.barplot(x="hour_of_day", y="tokens_used", data=hourly_usage)
            plt.title("Token Usage by Hour of Day", fontsize=14)
            plt.xlabel("Hour of Day", fontsize=12)
            plt.ylabel("Total Tokens Used", fontsize=12)
            plt.xticks(ticks=range(0, 24, 2))
            # Day-of-week usage chart
            plt.subplot(1, 2, 2)
            sns.barplot(x="day_of_week", y="tokens_used", data=daily_usage_by_weekday)
            plt.title("Token Usage by Day of Week", fontsize=14)
            plt.xlabel("Day of Week", fontsize=12)
            plt.ylabel("Total Tokens Used", fontsize=12)
            plt.xticks(rotation=45)
            plt.tight_layout()
            time_pattern_img = plt_to_html(fig4)
        else:
            time_pattern_img = None
        # 5. Message intent/type analysis
        if "message_intent" in df.columns:
            intent_usage = df.groupby("message_intent").agg({
                "tokens_used": "sum",
                "timestamp": "count"
            }).reset_index()
            intent_usage.columns = ["intent", "total_tokens", "request_count"]

            fig5 = plt.figure(figsize=(12, 10))
            # Pie chart for intent distribution
            plt.subplot(2, 1, 1)
            plt.pie(intent_usage["request_count"], labels=intent_usage["intent"], autopct='%1.1f%%', startangle=90)
            plt.axis('equal')
            plt.title("Message Intent Distribution", fontsize=14)
            # Bar chart for tokens by intent
            plt.subplot(2, 1, 2)
            sns.barplot(x="intent", y="total_tokens", data=intent_usage)
            plt.title("Token Usage by Message Intent", fontsize=14)
            plt.xlabel("Intent", fontsize=12)
            plt.ylabel("Total Tokens Used", fontsize=12)
            plt.tight_layout()
            intent_analysis_img = plt_to_html(fig5)
        else:
            intent_analysis_img = None
        # 6. Model comparison chart
        if len(model_usage) > 1:
            fig6 = plt.figure(figsize=(12, 8))

            # Aggregate metrics for comparison
            model_comparison = df.groupby("model").agg({
                "tokens_used": ["mean", "median", "sum"],
                "response_time_sec": ["mean", "median"]
            }).reset_index()
            # Flatten the MultiIndex column names
            model_comparison.columns = [
                f"{col[0]}_{col[1]}" if col[1] else col[0]
                for col in model_comparison.columns
            ]
            # Token efficiency (tokens per second)
            model_comparison["tokens_per_second"] = model_comparison["tokens_used_mean"] / model_comparison["response_time_sec_mean"]

            # Normalized metrics (computed for a possible radar chart; not plotted yet)
            metrics = ['tokens_used_mean', 'response_time_sec_mean', 'tokens_per_second']
            model_comparison_norm = model_comparison.copy()
            for metric in metrics:
                max_val = model_comparison[metric].max()
                if max_val > 0:  # Avoid division by zero
                    model_comparison_norm[f"{metric}_norm"] = model_comparison[metric] / max_val

            # Bar chart comparison
            plt.subplot(1, 2, 1)
            x = np.arange(len(model_comparison["model"]))
            width = 0.35
            plt.bar(x - width/2, model_comparison["tokens_used_mean"], width, label="Avg Tokens")
            plt.bar(x + width/2, model_comparison["response_time_sec_mean"], width, label="Avg Time (s)")
            plt.xlabel("Model")
            plt.ylabel("Value")
            plt.title("Model Performance Comparison")
            plt.xticks(x, model_comparison["model"], rotation=45)
            plt.legend()

            # Scatter plot for efficiency
            plt.subplot(1, 2, 2)
            sns.scatterplot(
                x="response_time_sec_mean",
                y="tokens_used_mean",
                size="tokens_per_second",
                hue="model",
                data=model_comparison,
                sizes=(100, 500)
            )
            plt.xlabel("Average Response Time (s)")
            plt.ylabel("Average Tokens Used")
            plt.title("Model Efficiency")
            plt.tight_layout()
            model_comparison_img = plt_to_html(fig6)
        else:
            model_comparison_img = None
        # 7. Usage prediction chart
        forecast_chart, prediction_data = predict_future_usage(df)

        # Summary statistics
        total_tokens = df["tokens_used"].sum()
        total_requests = len(df)
        avg_response_time = df["response_time_sec"].mean()

        # Cost estimation (per-token rates are rough estimates; update them
        # with the actual published pricing for each model)
        estimated_cost_rates = {
            "llama3-70b-8192": 0.0001,  # per token
            "llama3-8b-8192": 0.00005,
            "mistral-saba-24b": 0.00008,
            "gemma2-9b-it": 0.00006,
            "allam-2-7b": 0.00005
        }

        total_estimated_cost = 0
        model_costs = []
        for model_name in df["model"].unique():
            model_tokens = df[df["model"] == model_name]["tokens_used"].sum()
            rate = estimated_cost_rates.get(model_name, 0.00007)  # Average rate if unknown
            cost = model_tokens * rate
            total_estimated_cost += cost
            model_costs.append({"model": model_name, "tokens": model_tokens, "cost": cost})

        # Handle the case where there might not be enough data
        if not model_usage.empty:
            most_used_model = model_usage.iloc[model_usage["request_count"].argmax()]["model"]
        else:
            most_used_model = "N/A"
        # Precompute usage-pattern stats, and build the summary without nested
        # f-strings to avoid the backslash issue
        busiest_day = df.groupby("date")["tokens_used"].sum().idxmax()
        busiest_day_tokens = df.groupby("date")["tokens_used"].sum().max()
        fastest_model = df.groupby("model")["response_time_sec"].mean().idxmin()
        fastest_model_time = df.groupby("model")["response_time_sec"].mean().min()

        summary = f"""
## Analytics Summary
### Overview
- **Total API Requests**: {total_requests:,}
- **Total Tokens Used**: {total_tokens:,}
- **Estimated Cost**: ${total_estimated_cost:.2f}
- **Average Response Time**: {avg_response_time:.2f} seconds
- **Most Used Model**: {most_used_model}
- **Date Range**: {df["timestamp"].min().date()} to {df["timestamp"].max().date()}
### Model Costs Breakdown
"""
        # Add each model cost as a separate line
        for cost in model_costs:
            summary += f"- **{cost['model']}**: {cost['tokens']:,} tokens / ${cost['cost']:.2f}\n"

        # Continue with the rest of the summary
        summary += f"""
### Usage Patterns
- **Busiest Day**: {busiest_day} ({busiest_day_tokens:,} tokens)
- **Most Efficient Model**: {fastest_model} ({fastest_model_time:.2f}s avg response)
"""
        # Only add the forecast section when the prediction succeeded
        # (predict_future_usage returns an error string otherwise)
        if isinstance(prediction_data, pd.DataFrame):
            summary += f"""
### Forecast
- **Projected Usage (Next 7 Days)**: {prediction_data["predicted_tokens"].sum():,.0f} tokens (estimated)
"""
        else:
            summary += "\n### Forecast\n- Not enough data yet to project usage.\n"
        return summary, model_usage_img, daily_usage_img, response_time_img, time_pattern_img, intent_analysis_img, model_comparison_img, forecast_chart, export_analytics_csv(df), df.to_dict("records")
    except Exception as e:
        error_message = f"Error generating analytics: {str(e)}"
        return error_message, None, None, None, None, None, None, None, None, []
# Available models
models = [
    "llama3-70b-8192",
    "llama3-8b-8192",
    "mistral-saba-24b",
    "gemma2-9b-it",
    "allam-2-7b"
]

# Prompt templates
templates = ["General Assistant", "Code Helper", "Creative Writer", "Technical Expert", "Data Analyst"]

# Date range options for analytics filtering
date_ranges = ["all", "last_7_days", "last_30_days", "last_90_days"]
# Create the Gradio interface
with gr.Blocks(title="Enhanced Groq AI Chat Playground") as app:
    # Store the session ID (hidden from the UI)
    session_id = gr.State(None)

    gr.Markdown("# Groq AI Chat Playground")

    # Tabs for Chat and Analytics
    with gr.Tabs():
        with gr.Tab("Chat"):
            # Model information accordion
            with gr.Accordion("ℹ️ Model Information - Learn about available models", open=False):
                gr.Markdown("""
                ### Available Models and Use Cases

                **llama3-70b-8192**
                - Meta's most powerful language model
                - 70 billion parameters with an 8192-token context window
                - Best for: complex reasoning, sophisticated content generation, creative writing, and detailed analysis
                - Optimal for users who need the highest-quality responses

                **llama3-8b-8192**
                - Lighter version of Llama 3
                - 8 billion parameters with an 8192-token context window
                - Best for: faster responses, everyday tasks, simpler queries
                - Good balance between quality and speed

                **mistral-saba-24b**
                - Mistral AI's advanced model
                - 24 billion parameters
                - Best for: high-quality reasoning, code generation, and structured outputs
                - Excellent for technical and professional use cases

                **gemma2-9b-it**
                - Google's instruction-tuned model
                - 9 billion parameters
                - Best for: following specific instructions, educational content, and general knowledge queries
                - Well-rounded performance across a variety of tasks

                **allam-2-7b**
                - SDAIA's bilingual (Arabic/English) model
                - 7 billion parameters
                - Best for: multilingual support, concise responses, and straightforward Q&A
                - Good for international users and simpler applications

                *Note: Larger models generally produce higher-quality responses but may take slightly longer to generate them.*
                """)
| gr.Markdown("Enter your Groq API key to start chatting with AI models.") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| api_key_input = gr.Textbox( | |
| label="Groq API Key", | |
| placeholder="Enter your Groq API key (starts with gsk_)", | |
| type="password" | |
| ) | |
| with gr.Column(scale=1): | |
| test_button = gr.Button("Test API Connection") | |
| api_status = gr.Textbox(label="API Status", interactive=False) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| model_dropdown = gr.Dropdown( | |
| choices=models, | |
| label="Select Model", | |
| value="llama3-70b-8192" | |
| ) | |
| with gr.Column(scale=1): | |
| template_dropdown = gr.Dropdown( | |
| choices=templates, | |
| label="Select Template", | |
| value="General Assistant" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| with gr.Accordion("Advanced Settings", open=False): | |
| temperature_slider = gr.Slider( | |
| minimum=0.0, maximum=1.0, value=0.7, step=0.01, | |
| label="Temperature (higher = more creative, lower = more focused)" | |
| ) | |
| max_tokens_slider = gr.Slider( | |
| minimum=256, maximum=8192, value=4096, step=256, | |
| label="Max Tokens (maximum length of response)" | |
| ) | |
| top_p_slider = gr.Slider( | |
| minimum=0.0, maximum=1.0, value=0.95, step=0.01, | |
| label="Top P (nucleus sampling probability threshold)" | |
| ) | |
| chatbot = gr.Chatbot(label="Conversation", height=500) | |
| with gr.Row(): | |
| message_input = gr.Textbox( | |
| label="Your Message", | |
| placeholder="Type your message here...", | |
| lines=3 | |
| ) | |
| with gr.Row(): | |
| submit_button = gr.Button("Send", variant="primary") | |
| clear_button = gr.Button("Clear Conversation") | |
        # Enhanced analytics dashboard tab
        with gr.Tab("Analytics Dashboard"):
            with gr.Column():
                gr.Markdown("# Enhanced Usage Analytics Dashboard")

                with gr.Row():
                    refresh_analytics_button = gr.Button("Refresh Analytics", variant="primary")
                    date_filter = gr.Dropdown(
                        choices=date_ranges,
                        value="all",
                        label="Date Range Filter",
                        info="Filter analytics by time period"
                    )
                    export_button = gr.Button("Export Data to CSV")

                analytics_summary = gr.Markdown()

                with gr.Tabs():
                    with gr.Tab("Overview"):
                        with gr.Row():
                            with gr.Column():
                                model_usage_chart = gr.HTML(label="Token Usage by Model")
                            with gr.Column():
                                daily_usage_chart = gr.HTML(label="Daily Token Usage")
                        response_time_chart = gr.HTML(label="Response Time by Model")
                    with gr.Tab("Usage Patterns"):
                        time_pattern_chart = gr.HTML(label="Usage by Time and Day")
                        intent_analysis_chart = gr.HTML(label="Message Intent Analysis")
                    with gr.Tab("Model Comparison"):
                        model_comparison_chart = gr.HTML(label="Model Performance Comparison")
                    with gr.Tab("Forecast"):
                        forecast_chart = gr.HTML(label="Token Usage Forecast")
                        gr.Markdown("""This forecast uses linear regression on your historical data to predict token usage for the next 7 days.
                        Note that predictions become more accurate with more usage data.""")
                    with gr.Tab("Raw Data"):
                        raw_data_table = gr.DataFrame(label="Raw Analytics Data")
                        export_status = gr.Textbox(label="Export Status")
    # Button callback functions
    def test_api_connection_btn(api_key):
        """Callback for testing the API connection."""
        is_valid, validation_message = validate_api_key(api_key)
        if not is_valid:
            return validation_message
        connection_valid, connection_message = test_api_connection(api_key)
        return connection_message

    def refresh_analytics_callback(date_range):
        """Callback for refreshing the analytics dashboard."""
        return generate_enhanced_analytics(date_range)

    def export_data_callback(df_records):
        """Callback for exporting data to CSV."""
        try:
            df = pd.DataFrame(df_records)
            return export_analytics_csv(df)
        except Exception as e:
            return f"Error exporting data: {str(e)}"
    # Wire up the event handlers
    test_button.click(
        test_api_connection_btn,
        inputs=[api_key_input],
        outputs=[api_status]
    )
    submit_button.click(
        enhanced_chat_with_groq,
        inputs=[
            api_key_input,
            model_dropdown,
            message_input,
            temperature_slider,
            max_tokens_slider,
            top_p_slider,
            chatbot,
            template_dropdown,
            session_id
        ],
        outputs=[chatbot, session_id]
    )
    message_input.submit(
        enhanced_chat_with_groq,
        inputs=[
            api_key_input,
            model_dropdown,
            message_input,
            temperature_slider,
            max_tokens_slider,
            top_p_slider,
            chatbot,
            template_dropdown,
            session_id
        ],
        outputs=[chatbot, session_id]
    )
    clear_button.click(
        clear_conversation,
        outputs=[chatbot, session_id]
    )
    refresh_analytics_button.click(
        refresh_analytics_callback,
        inputs=[date_filter],
        outputs=[
            analytics_summary,
            model_usage_chart,
            daily_usage_chart,
            response_time_chart,
            time_pattern_chart,
            intent_analysis_chart,
            model_comparison_chart,
            forecast_chart,
            export_status,
            raw_data_table
        ]
    )
    export_button.click(
        export_data_callback,
        inputs=[raw_data_table],
        outputs=[export_status]
    )
# Launch the application
if __name__ == "__main__":
    app.launch(share=False)  # Set share=True for a public URL