import gradio as gr
import os
from groq import Groq
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
from datetime import datetime
import json
def validate_api_key(api_key):
    """Validate that the API key has the expected format."""
    # Basic format check for Groq API keys (they typically start with 'gsk_')
    if not api_key.strip():
        return False, "API key cannot be empty"
    if not api_key.startswith("gsk_"):
        return False, "Invalid API key format. Groq API keys typically start with 'gsk_'"
    return True, "API key looks valid"
def test_api_connection(api_key):
    """Test the API connection with a minimal request."""
    try:
        client = Groq(api_key=api_key)
        # Make a minimal API call to verify the key actually works
        client.chat.completions.create(
            model="llama3-70b-8192",
            messages=[{"role": "user", "content": "test"}],
            max_tokens=5
        )
        return True, "API connection successful"
    except Exception as e:
        # Catch broadly, since the Groq SDK may not expose specific error types
        if "authentication" in str(e).lower() or "api key" in str(e).lower():
            return False, "Authentication failed: Invalid API key"
        return False, f"Error connecting to Groq API: {str(e)}"
# Ensure the analytics directory exists before any logging happens
os.makedirs("analytics", exist_ok=True)
def log_chat_interaction(model, tokens_used, response_time, user_message_length):
    """Log a chat interaction for analytics."""
    timestamp = datetime.now().isoformat()
    log_file = "analytics/chat_log.json"

    log_entry = {
        "timestamp": timestamp,
        "model": model,
        "tokens_used": tokens_used,
        "response_time_sec": response_time,
        "user_message_length": user_message_length
    }

    # Append to the existing log, or start fresh if it is missing or corrupt
    logs = []
    if os.path.exists(log_file):
        try:
            with open(log_file, "r") as f:
                logs = json.load(f)
        except (json.JSONDecodeError, OSError):
            logs = []

    logs.append(log_entry)
    with open(log_file, "w") as f:
        json.dump(logs, f, indent=2)
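# Each entry in analytics/chat_log.json ends up shaped like this (values illustrative):
# {
#     "timestamp": "2024-05-01T12:34:56.789012",
#     "model": "llama3-8b-8192",
#     "tokens_used": 42,
#     "response_time_sec": 0.81,
#     "user_message_length": 25
# }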
def get_template_prompt(template_name):
    """Return the system prompt for a given template name."""
    templates = {
        "General Assistant": "You are a helpful, harmless, and honest AI assistant.",
        "Code Helper": "You are a programming assistant. Provide detailed code explanations and examples.",
        "Creative Writer": "You are a creative writing assistant. Generate engaging and imaginative content.",
        "Technical Expert": "You are a technical expert. Provide accurate, detailed technical information.",
        "Data Analyst": "You are a data analysis assistant. Help interpret and analyze data effectively."
    }
    return templates.get(template_name, "")
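# Example: an unknown name falls back to an empty prompt, which downstream code
# treats as "no system message":
#   get_template_prompt("Code Helper")  # -> "You are a programming assistant. ..."
#   get_template_prompt("Nonexistent")  # -> ""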
def enhanced_chat_with_groq(api_key, model, user_message, temperature, max_tokens, top_p, chat_history, template_name=""):
    """Chat with the Groq API and log the interaction for analytics."""
    start_time = datetime.now()

    # Get the system prompt if a template is selected
    system_prompt = get_template_prompt(template_name) if template_name else ""

    # Validate the key format, then verify it against the API
    is_valid, message = validate_api_key(api_key)
    if not is_valid:
        return chat_history + [[user_message, f"Error: {message}"]]

    connection_valid, connection_message = test_api_connection(api_key)
    if not connection_valid:
        return chat_history + [[user_message, f"Error: {connection_message}"]]

    try:
        # Rebuild the message list from Gradio's pair-format history:
        # [[user_msg, assistant_msg], ...]
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        for human, assistant in chat_history:
            messages.append({"role": "user", "content": human})
            messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": user_message})

        # Make the API call
        client = Groq(api_key=api_key)
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p
        )

        # Calculate metrics
        end_time = datetime.now()
        response_time = (end_time - start_time).total_seconds()
        tokens_used = response.usage.total_tokens

        # Log the interaction
        log_chat_interaction(
            model=model,
            tokens_used=tokens_used,
            response_time=response_time,
            user_message_length=len(user_message)
        )

        # Extract the assistant's reply
        assistant_response = response.choices[0].message.content
        return chat_history + [[user_message, assistant_response]]

    except Exception as e:
        return chat_history + [[user_message, f"Error: {str(e)}"]]
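# Usage sketch (hypothetical key and values):
#   history = enhanced_chat_with_groq("gsk_...", "llama3-8b-8192", "Hello!",
#                                     0.7, 4096, 0.95, [], "General Assistant")
#   history[-1]  # -> ["Hello!", "<assistant reply>"]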
def clear_conversation():
    """Clear the conversation history."""
    return []
def plt_to_html(fig):
    """Convert a matplotlib figure to an HTML <img> tag with inline base64 PNG data."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    img_str = base64.b64encode(buf.read()).decode("utf-8")
    plt.close(fig)
    return f'<img src="data:image/png;base64,{img_str}" alt="Chart">'
def clear_analytics():
    """Clear all analytics data by removing the log file."""
    log_file = "analytics/chat_log.json"
    if not os.path.exists(log_file):
        return "No analytics data to clear."
    try:
        os.remove(log_file)
        return "Analytics data cleared successfully."
    except Exception as e:
        return f"Error clearing analytics: {str(e)}"
def generate_analytics():
    """Generate an analytics summary and charts from the chat log."""
    log_file = "analytics/chat_log.json"
    if not os.path.exists(log_file):
        return "No analytics data available yet.", None, None

    try:
        with open(log_file, "r") as f:
            logs = json.load(f)

        if not logs:
            return "No analytics data available yet.", None, None

        # Convert to a DataFrame
        df = pd.DataFrame(logs)
        df["timestamp"] = pd.to_datetime(df["timestamp"])

        # Token usage by model
        model_usage = df.groupby("model").agg({
            "tokens_used": "sum",
            "timestamp": "count"
        }).reset_index()
        model_usage.columns = ["model", "total_tokens", "request_count"]

        fig1 = plt.figure(figsize=(10, 6))
        plt.bar(model_usage["model"], model_usage["total_tokens"])
        plt.title("Token Usage by Model")
        plt.xlabel("Model")
        plt.ylabel("Total Tokens Used")
        plt.xticks(rotation=45)
        plt.tight_layout()
        model_usage_img = plt_to_html(fig1)

        # Average response time by model
        model_response_time = df.groupby("model").agg({
            "response_time_sec": "mean"
        }).reset_index()

        fig2 = plt.figure(figsize=(10, 6))
        plt.bar(model_response_time["model"], model_response_time["response_time_sec"])
        plt.title("Average Response Time by Model")
        plt.xlabel("Model")
        plt.ylabel("Response Time (seconds)")
        plt.xticks(rotation=45)
        plt.tight_layout()
        response_time_img = plt_to_html(fig2)

        # Summary statistics
        total_tokens = df["tokens_used"].sum()
        total_requests = len(df)
        avg_response_time = df["response_time_sec"].mean()

        # Guard against the (unlikely) case of an empty aggregation
        if not model_usage.empty:
            most_used_model = model_usage.iloc[model_usage["request_count"].argmax()]["model"]
        else:
            most_used_model = "N/A"

        # Keep the Markdown flush-left so indented lines aren't rendered as code blocks
        summary = f"""
## Analytics Summary

- **Total API Requests**: {total_requests}
- **Total Tokens Used**: {total_tokens:,}
- **Average Response Time**: {avg_response_time:.2f} seconds
- **Most Used Model**: {most_used_model}
- **Date Range**: {df["timestamp"].min().date()} to {df["timestamp"].max().date()}
"""
        return summary, model_usage_img, response_time_img

    except Exception as e:
        return f"Error generating analytics: {str(e)}", None, None
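# Standalone usage sketch (assumes the log file already has entries):
#   summary_md, usage_html, latency_html = generate_analytics()
#   print(summary_md)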
# Define available models
models = [
    "llama3-70b-8192",
    "llama3-8b-8192",
    "mistral-saba-24b",
    "gemma2-9b-it",
    "allam-2-7b"
]

# Define system prompt templates
templates = ["General Assistant", "Code Helper", "Creative Writer", "Technical Expert", "Data Analyst"]
# Create the Gradio interface
with gr.Blocks(title="Groq AI Chat Playground") as app:
    gr.Markdown("# Groq AI Chat Playground")

    # Tabs for Chat and Analytics
    with gr.Tabs():
        with gr.Tab("Chat"):
            # Model information accordion
            with gr.Accordion("ℹ️ Model Information - Learn about available models", open=False):
gr.Markdown(""" | |
### Available Models and Use Cases | |
**llama3-70b-8192** | |
- Meta's most powerful language model | |
- 70 billion parameters with 8192 token context window | |
- Best for: Complex reasoning, sophisticated content generation, creative writing, and detailed analysis | |
- Optimal for users needing the highest quality AI responses | |
**llama3-8b-8192** | |
- Lighter version of Llama 3 | |
- 8 billion parameters with 8192 token context window | |
- Best for: Faster responses, everyday tasks, simpler queries | |
- Good balance between performance and speed | |
**mistral-saba-24b** | |
- Mistral AI's advanced model | |
- 24 billion parameters | |
- Best for: High-quality reasoning, code generation, and structured outputs | |
- Excellent for technical and professional use cases | |
**gemma2-9b-it** | |
- Google's instruction-tuned model | |
- 9 billion parameters | |
- Best for: Following specific instructions, educational content, and general knowledge queries | |
- Well-rounded performance for various tasks | |
**allam-2-7b** | |
- Specialized model from Aleph Alpha | |
- 7 billion parameters | |
- Best for: Multilingual support, concise responses, and straightforward Q&A | |
- Good for international users and simpler applications | |
*Note: Larger models generally provide higher quality responses but may take slightly longer to generate.* | |
""") | |
gr.Markdown("Enter your Groq API key to start chatting with AI models.") | |
with gr.Row(): | |
with gr.Column(scale=2): | |
api_key_input = gr.Textbox( | |
label="Groq API Key", | |
placeholder="Enter your Groq API key (starts with gsk_)", | |
type="password" | |
) | |
with gr.Column(scale=1): | |
test_button = gr.Button("Test API Connection") | |
api_status = gr.Textbox(label="API Status", interactive=False) | |
with gr.Row(): | |
with gr.Column(scale=2): | |
model_dropdown = gr.Dropdown( | |
choices=models, | |
label="Select Model", | |
value="llama3-70b-8192" | |
) | |
with gr.Column(scale=1): | |
template_dropdown = gr.Dropdown( | |
choices=templates, | |
label="Select Template", | |
value="General Assistant" | |
) | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Accordion("Advanced Settings", open=False): | |
temperature_slider = gr.Slider( | |
minimum=0.0, maximum=1.0, value=0.7, step=0.01, | |
label="Temperature (higher = more creative, lower = more focused)" | |
) | |
max_tokens_slider = gr.Slider( | |
minimum=256, maximum=8192, value=4096, step=256, | |
label="Max Tokens (maximum length of response)" | |
) | |
top_p_slider = gr.Slider( | |
minimum=0.0, maximum=1.0, value=0.95, step=0.01, | |
label="Top P (nucleus sampling probability threshold)" | |
) | |
chatbot = gr.Chatbot(label="Conversation", height=500) | |
with gr.Row(): | |
message_input = gr.Textbox( | |
label="Your Message", | |
placeholder="Type your message here...", | |
lines=3 | |
) | |
with gr.Row(): | |
submit_button = gr.Button("Send", variant="primary") | |
clear_button = gr.Button("Clear Conversation") | |
        # Analytics Dashboard tab
        with gr.Tab("Analytics Dashboard"):
            with gr.Column():
                gr.Markdown("# Usage Analytics Dashboard")
                with gr.Row():
                    refresh_analytics_button = gr.Button("Refresh Analytics")
                    clear_analytics_button = gr.Button("Clear Analytics", variant="secondary")
                analytics_status = gr.Markdown()
                analytics_summary = gr.Markdown()
                with gr.Row():
                    with gr.Column():
                        model_usage_chart = gr.HTML(label="Token Usage by Model")
                        response_time_chart = gr.HTML(label="Response Time by Model")
    # Connect components with functions
    submit_button.click(
        fn=enhanced_chat_with_groq,
        inputs=[api_key_input, model_dropdown, message_input, temperature_slider, max_tokens_slider, top_p_slider, chatbot, template_dropdown],
        outputs=chatbot
    ).then(
        fn=lambda: "",  # clear the message box after sending
        inputs=None,
        outputs=message_input
    )

    message_input.submit(
        fn=enhanced_chat_with_groq,
        inputs=[api_key_input, model_dropdown, message_input, temperature_slider, max_tokens_slider, top_p_slider, chatbot, template_dropdown],
        outputs=chatbot
    ).then(
        fn=lambda: "",  # clear the message box after sending
        inputs=None,
        outputs=message_input
    )

    clear_button.click(
        fn=clear_conversation,
        inputs=None,
        outputs=chatbot
    )

    # test_api_connection returns (bool, str); show only the message, since
    # there is a single output component
    test_button.click(
        fn=lambda key: test_api_connection(key)[1],
        inputs=[api_key_input],
        outputs=[api_status]
    )

    refresh_analytics_button.click(
        fn=generate_analytics,
        inputs=[],
        outputs=[analytics_summary, model_usage_chart, response_time_chart]
    )

    clear_analytics_button.click(
        fn=clear_analytics,
        inputs=[],
        outputs=[analytics_status]
    ).then(
        fn=generate_analytics,
        inputs=[],
        outputs=[analytics_summary, model_usage_chart, response_time_chart]
    )
# Launch the app
if __name__ == "__main__":
    app.launch(share=False)