app.py works mostly

Files changed:
- app.py +154 -106
- drift_detector.sqlite3 +0 -0
- ourllm.py +2 -0
app.py
CHANGED
@@ -2,9 +2,7 @@ import os
 import gradio as gr
 import asyncio
 from typing import Optional, List, Dict
-from
-from mcp import ClientSession, StdioServerParameters
-from mcp.client.stdio import stdio_client
+from mcp_agent.core.fastagent import FastAgent
 
 from database_module.db import SessionLocal
 from database_module.models import ModelEntry
@@ -12,8 +10,8 @@ from langchain.chat_models import init_chat_model
 # Modify imports section to include all required tools
 from database_module import (
     init_db,
-
-
+    get_all_models_handler,
+    search_models_handler,
     # save_model_handler,
     # get_model_details_handler,
     # calculate_drift_handler,
@@ -27,66 +25,30 @@ import plotly.graph_objects as go
 # Create tables and register MCP handlers
 init_db()
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            self.stdio, self.write = stdio_transport
-            self.session = await self.exit_stack.enter_async_context(
-                ClientSession(self.stdio, self.write)
-            )
-            await self.session.initialize()
-
-            # Get available tools from server
-            tools_response = await self.session.list_tools()
-            available_tools = [t.name for t in tools_response.tools]
-            print("Connected to server with tools:", available_tools)
-
-            return True
-        except Exception as e:
-            print(f"Failed to connect to MCP server: {e}")
-            return False
-
-    async def call_tool(self, tool_name: str, arguments: dict):
-        """Call a tool on the MCP server"""
-        if not self.session:
-            raise RuntimeError("Not connected to MCP server")
-        try:
-            response = await self.session.call_tool(tool_name, arguments)
-            return response.content
-        except Exception as e:
-            print(f"Error calling tool {tool_name}: {e}")
-            raise
-
-    async def close(self):
-        """Close the MCP client connection"""
-        if self.session:
-            await self.exit_stack.aclose()
-
-
-# Global MCP client instance
-mcp_client = MCPClient()
-
-
-# Helper to run async functions
+# Fast Agent client initialization - This is the "scapegoat" client whose drift we're detecting
+fast = FastAgent("Scapegoat Client")
+
+@fast.agent(
+    name="scapegoat",
+    instruction="You are a test client whose drift will be detected and measured over time",
+    servers=["drift-server"]
+)
+async def setup_agent():
+    # This function defines the scapegoat agent that will be monitored for drift
+    pass
+
+# Global scapegoat client instance to be monitored for drift
+scapegoat_client = None
+
+# Initialize the scapegoat client that will be tested for drift
+async def initialize_scapegoat_client():
+    global scapegoat_client
+    print("Initializing scapegoat client for drift monitoring...")
+    async with fast.run() as agent:
+        scapegoat_client = agent
+        return agent
+
+# Helper to run async functions with FastAgent
 def run_async(coro):
     try:
         loop = asyncio.get_running_loop()
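The body of run_async is only partially captured in this diff (both sides show just the try: and asyncio.get_running_loop() lines). For orientation, a minimal sketch of what such a helper usually does, assuming the goal is to run a coroutine to completion from synchronous Gradio callbacks whether or not an event loop is already running; this is an assumption, not the committed implementation:

import asyncio
import concurrent.futures

def run_async(coro):
    """Run a coroutine from synchronous code (sketch; assumed behaviour)."""
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: safe to use asyncio.run()
        return asyncio.run(coro)
    # A loop is already running (e.g. inside the UI framework): run the
    # coroutine in a separate thread with its own event loop and wait.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()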
@@ -102,10 +64,15 @@ def run_async(coro):
 def run_initial_diagnostics(model_name: str, capabilities: str):
     """Run initial diagnostics for a new model"""
     try:
-
-
-
-        }
+        # Use FastAgent's send method with a formatted message to call the tool
+        message = f"""Please call the run_initial_diagnostics tool with the following parameters:
+        model: {model_name}
+        model_capabilities: {capabilities}
+
+        This tool will generate and store baseline diagnostics for the model.
+        """
+
+        result = run_async(scapegoat_client(message))
         return result
     except Exception as e:
         print(f"Error running diagnostics: {e}")
@@ -114,9 +81,14 @@ def run_initial_diagnostics(model_name: str, capabilities: str):
 def check_model_drift(model_name: str):
     """Check drift for existing model"""
     try:
-
-
-        }
+        # Use FastAgent's send method with a formatted message to call the tool
+        message = f"""Please call the check_drift tool with the following parameters:
+        model: {model_name}
+
+        This tool will re-run diagnostics and compare to baseline for drift scoring.
+        """
+
+        result = run_async(scapegoat_client(message))
         return result
     except Exception as e:
         print(f"Error checking drift: {e}")
@@ -125,19 +97,32 @@ def check_model_drift(model_name: str):
 # Initialize MCP connection on startup
 def initialize_mcp_connection():
     try:
-        run_async(
-        print("Successfully connected to MCP server")
+        run_async(initialize_scapegoat_client())
+        print("Successfully connected scapegoat client to MCP server")
         return True
     except Exception as e:
-        print(f"Failed to connect to MCP server: {e}")
+        print(f"Failed to connect scapegoat client to MCP server: {e}")
         return False
 
 
 # Wrapper functions remain unchanged but now call real DB-backed MCP tools
 def get_models_from_db():
+    """Get all models from database using direct function call"""
     try:
-
-
+        # Direct function call to database_module instead of using MCP
+        result = get_all_models_handler({})
+
+        if result:
+            # Format the result to match the expected structure
+            return [
+                {
+                    "name": model["name"],
+                    "description": model.get("description", ""),
+                    "created": model.get("created", datetime.now().strftime("%Y-%m-%d"))
+                }
+                for model in result
+            ]
+        return []
     except Exception as e:
         print(f"Error getting models: {e}")
         return []
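Taken together, the new startup path is: initialize_mcp_connection() spins up the FastAgent "scapegoat" client once, and the Gradio callbacks then go through the wrapper functions, which phrase each tool call as a natural-language message to that agent. A rough driver sketch using only the functions defined in this file; the model name and capabilities string are placeholder values:

# Hypothetical driver sequence for the functions above (placeholder inputs).
if initialize_mcp_connection():
    # Baseline diagnostics for a newly registered model
    baseline = run_initial_diagnostics(
        "demo-model",                         # placeholder model name
        "summarisation, question answering",  # placeholder capabilities
    )
    # Later: re-run diagnostics and compare against the stored baseline
    drift = check_model_drift("demo-model")
    print(baseline)
    print(drift)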
@@ -148,12 +133,28 @@ def get_available_model_names():
 
 
 def search_models_in_db(search_term: str):
+    """Search models in database using direct function call"""
     try:
-
-
+        # Direct function call to database_module instead of using MCP
+        result = search_models_handler({"search_term": search_term})
+
+        if result:
+            # Format the result to match the expected structure
+            return [
+                {
+                    "name": model["name"],
+                    "description": model.get("description", ""),
+                    "created": model.get("created", datetime.now().strftime("%Y-%m-%d"))
+                }
+                for model in result
+            ]
+        # If no results, return empty list
+        return []
     except Exception as e:
         print(f"Error searching models: {e}")
+        # Fallback to filtering from all models if there's an error
         return [m for m in get_models_from_db() if search_term.lower() in m["name"].lower()]
+
 def format_dropdown_items(models):
     """Format dropdown items to show model name, creation date, and description preview"""
     formatted_items = []
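get_all_models_handler and search_models_handler come from database_module and are not shown in this commit. The calling code above assumes they accept a dict of arguments and return a list of model dicts with name, description, and created keys. A hypothetical sketch of that assumed contract (the real handlers in database_module may differ):

# Hypothetical sketch of the handler contract assumed by get_models_from_db()
# and search_models_in_db(); the actual handlers live in database_module.
from database_module.db import SessionLocal
from database_module.models import ModelEntry

def get_all_models_handler(arguments: dict) -> list[dict]:
    with SessionLocal() as session:
        return [
            {
                "name": m.name,
                "description": m.description or "",
                "created": m.created.strftime("%Y-%m-%d") if m.created else "",
            }
            for m in session.query(ModelEntry).all()
        ]

def search_models_handler(arguments: dict) -> list[dict]:
    term = arguments.get("search_term", "").lower()
    return [m for m in get_all_models_handler({}) if term in m["name"].lower()]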
@@ -172,49 +173,96 @@ def extract_model_name_from_dropdown(dropdown_value, model_mapping):
     return model_mapping.get(dropdown_value, dropdown_value.split(" (")[0] if dropdown_value else "")
 
 def get_model_details(model_name: str):
-    """Get model details from database via
+    """Get model details from database via direct DB access (fallback)"""
     try:
-
-
+        with SessionLocal() as session:
+            model_entry = session.query(ModelEntry).filter_by(name=model_name).first()
+            if model_entry:
+                return {
+                    "name": model_entry.name,
+                    "description": model_entry.description or "",
+                    "system_prompt": model_entry.capabilities.split("\nSystem Prompt: ")[1] if "\nSystem Prompt: " in model_entry.capabilities else "",
+                    "created": model_entry.created.strftime("%Y-%m-%d %H:%M:%S") if model_entry.created else ""
+                }
+            return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
     except Exception as e:
         print(f"Error getting model details: {e}")
         return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
 
 def enhance_prompt_via_mcp(prompt: str):
-    """Enhance prompt
-
-
-
-
-
-
+    """Enhance prompt locally since enhance_prompt tool is not available in server.py"""
+    # Provide a basic prompt enhancement functionality since server doesn't have it
+    enhanced_prompts = {
+        "helpful": f"{prompt}\n\nPlease be thorough, helpful, and provide detailed responses.",
+        "concise": f"{prompt}\n\nPlease provide concise, direct answers.",
+        "technical": f"{prompt}\n\nPlease provide technically accurate and comprehensive responses.",
+    }
+
+    if "helpful" in prompt.lower():
+        return enhanced_prompts["helpful"]
+    elif "concise" in prompt.lower() or "brief" in prompt.lower():
+        return enhanced_prompts["concise"]
+    elif "technical" in prompt.lower() or "detailed" in prompt.lower():
+        return enhanced_prompts["technical"]
+    else:
+        return f"{prompt}\n\nAdditional context: Be specific, helpful, and provide detailed responses while maintaining a professional tone."
 
 def save_model_to_db(model_name: str, system_prompt: str):
-    """Save model to database
+    """Save model to database directly since save_model tool is not available in server.py"""
     try:
-
-
-
-
-
+        # Check if model already exists
+        with SessionLocal() as session:
+            existing = session.query(ModelEntry).filter_by(name=model_name).first()
+            if existing:
+                # Update capabilities to include the new system prompt
+                capabilities = existing.capabilities
+                if "\nSystem Prompt: " in capabilities:
+                    # Replace the system prompt part
+                    parts = capabilities.split("\nSystem Prompt: ")
+                    capabilities = f"{parts[0]}\nSystem Prompt: {system_prompt}"
+                else:
+                    # Add system prompt if not present
+                    capabilities = f"{capabilities}\nSystem Prompt: {system_prompt}"
+
+                existing.capabilities = capabilities
+                existing.updated = datetime.now()
+                session.commit()
+                return {"message": f"Updated existing model: {model_name}"}
+            else:
+                # Should not happen as models are registered with capabilities before calling this function
+                return {"message": f"Model {model_name} not found. Please register it first."}
     except Exception as e:
         print(f"Error saving model: {e}")
-        return f"Error saving model: {e}"
+        return {"message": f"Error saving model: {e}"}
 
 def get_drift_history_from_db(model_name: str):
-    """Get drift history from database
+    """Get drift history from database directly without any fallbacks"""
     try:
-
-        return result if isinstance(result, list) else []
-    except Exception as e:
-        print(f"Error getting drift history: {e}")
-        # Fallback data for demonstration
-        return [
-            {"date": "2025-06-01", "drift_score": 0.12},
-            {"date": "2025-06-05", "drift_score": 0.18},
-            {"date": "2025-06-09", "drift_score": 0.15}
-        ]
-
+        from database_module.models import DriftEntry
+
+        with SessionLocal() as session:
+            # Query the drift_history table for this model
+            drift_entries = session.query(DriftEntry).filter(
+                DriftEntry.model_name == model_name
+            ).order_by(DriftEntry.date.desc()).all()
+
+            # If no entries found, return empty list
+            if not drift_entries:
+                return []
+
+            # Convert to the expected format
+            history = []
+            for entry in drift_entries:
+                history.append({
+                    "date": entry.date.strftime("%Y-%m-%d"),
+                    "drift_score": float(entry.drift_score),
+                    "model": entry.model_name
+                })
+
+            return history
+    except Exception as e:
+        print(f"Error getting drift history from database: {e}")
+        return []  # Return empty list on error, no fallbacks
 def create_drift_chart(drift_history):
     """Create drift chart using plotly"""
     if not drift_history:
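The new get_model_details and save_model_to_db rely on a convention of embedding the system prompt inside the model's capabilities text behind a literal "\nSystem Prompt: " marker. A small worked example of that round trip, with made-up capability text:

# Illustration of the "\nSystem Prompt: " convention used above (example data only).
capabilities = "Summarisation, Q&A\nSystem Prompt: You are a helpful AI assistant."

# Extraction, as in get_model_details():
system_prompt = capabilities.split("\nSystem Prompt: ")[1]   # "You are a helpful AI assistant."

# Replacement, as in save_model_to_db():
parts = capabilities.split("\nSystem Prompt: ")
capabilities = f"{parts[0]}\nSystem Prompt: You answer concisely."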
|
|
| 2 |
import gradio as gr
|
| 3 |
import asyncio
|
| 4 |
from typing import Optional, List, Dict
|
| 5 |
+
from mcp_agent.core.fastagent import FastAgent
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from database_module.db import SessionLocal
|
| 8 |
from database_module.models import ModelEntry
|
|
|
|
| 10 |
# Modify imports section to include all required tools
|
| 11 |
from database_module import (
|
| 12 |
init_db,
|
| 13 |
+
get_all_models_handler,
|
| 14 |
+
search_models_handler,
|
| 15 |
# save_model_handler,
|
| 16 |
# get_model_details_handler,
|
| 17 |
# calculate_drift_handler,
|
|
|
|
| 25 |
# Create tables and register MCP handlers
|
| 26 |
init_db()
|
| 27 |
|
| 28 |
+
# Fast Agent client initialization - This is the "scapegoat" client whose drift we're detecting
|
| 29 |
+
fast = FastAgent("Scapegoat Client")
|
| 30 |
|
| 31 |
+
@fast.agent(
|
| 32 |
+
name="scapegoat",
|
| 33 |
+
instruction="You are a test client whose drift will be detected and measured over time",
|
| 34 |
+
servers=["drift-server"]
|
| 35 |
+
)
|
| 36 |
+
async def setup_agent():
|
| 37 |
+
# This function defines the scapegoat agent that will be monitored for drift
|
| 38 |
+
pass
|
| 39 |
+
|
| 40 |
+
# Global scapegoat client instance to be monitored for drift
|
| 41 |
+
scapegoat_client = None
|
| 42 |
+
|
| 43 |
+
# Initialize the scapegoat client that will be tested for drift
|
| 44 |
+
async def initialize_scapegoat_client():
|
| 45 |
+
global scapegoat_client
|
| 46 |
+
print("Initializing scapegoat client for drift monitoring...")
|
| 47 |
+
async with fast.run() as agent:
|
| 48 |
+
scapegoat_client = agent
|
| 49 |
+
return agent
|
| 50 |
+
|
| 51 |
+
# Helper to run async functions with FastAgent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def run_async(coro):
|
| 53 |
try:
|
| 54 |
loop = asyncio.get_running_loop()
|
|
|
|
| 64 |
def run_initial_diagnostics(model_name: str, capabilities: str):
|
| 65 |
"""Run initial diagnostics for a new model"""
|
| 66 |
try:
|
| 67 |
+
# Use FastAgent's send method with a formatted message to call the tool
|
| 68 |
+
message = f"""Please call the run_initial_diagnostics tool with the following parameters:
|
| 69 |
+
model: {model_name}
|
| 70 |
+
model_capabilities: {capabilities}
|
| 71 |
+
|
| 72 |
+
This tool will generate and store baseline diagnostics for the model.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
result = run_async(scapegoat_client(message))
|
| 76 |
return result
|
| 77 |
except Exception as e:
|
| 78 |
print(f"Error running diagnostics: {e}")
|
|
|
|
| 81 |
def check_model_drift(model_name: str):
|
| 82 |
"""Check drift for existing model"""
|
| 83 |
try:
|
| 84 |
+
# Use FastAgent's send method with a formatted message to call the tool
|
| 85 |
+
message = f"""Please call the check_drift tool with the following parameters:
|
| 86 |
+
model: {model_name}
|
| 87 |
+
|
| 88 |
+
This tool will re-run diagnostics and compare to baseline for drift scoring.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
result = run_async(scapegoat_client(message))
|
| 92 |
return result
|
| 93 |
except Exception as e:
|
| 94 |
print(f"Error checking drift: {e}")
|
|
|
|
| 97 |
# Initialize MCP connection on startup
|
| 98 |
def initialize_mcp_connection():
|
| 99 |
try:
|
| 100 |
+
run_async(initialize_scapegoat_client())
|
| 101 |
+
print("Successfully connected scapegoat client to MCP server")
|
| 102 |
return True
|
| 103 |
except Exception as e:
|
| 104 |
+
print(f"Failed to connect scapegoat client to MCP server: {e}")
|
| 105 |
return False
|
| 106 |
|
| 107 |
|
| 108 |
# Wrapper functions remain unchanged but now call real DB-backed MCP tools
|
| 109 |
def get_models_from_db():
|
| 110 |
+
"""Get all models from database using direct function call"""
|
| 111 |
try:
|
| 112 |
+
# Direct function call to database_module instead of using MCP
|
| 113 |
+
result = get_all_models_handler({})
|
| 114 |
+
|
| 115 |
+
if result:
|
| 116 |
+
# Format the result to match the expected structure
|
| 117 |
+
return [
|
| 118 |
+
{
|
| 119 |
+
"name": model["name"],
|
| 120 |
+
"description": model.get("description", ""),
|
| 121 |
+
"created": model.get("created", datetime.now().strftime("%Y-%m-%d"))
|
| 122 |
+
}
|
| 123 |
+
for model in result
|
| 124 |
+
]
|
| 125 |
+
return []
|
| 126 |
except Exception as e:
|
| 127 |
print(f"Error getting models: {e}")
|
| 128 |
return []
|
|
|
|
| 133 |
|
| 134 |
|
| 135 |
def search_models_in_db(search_term: str):
|
| 136 |
+
"""Search models in database using direct function call"""
|
| 137 |
try:
|
| 138 |
+
# Direct function call to database_module instead of using MCP
|
| 139 |
+
result = search_models_handler({"search_term": search_term})
|
| 140 |
+
|
| 141 |
+
if result:
|
| 142 |
+
# Format the result to match the expected structure
|
| 143 |
+
return [
|
| 144 |
+
{
|
| 145 |
+
"name": model["name"],
|
| 146 |
+
"description": model.get("description", ""),
|
| 147 |
+
"created": model.get("created", datetime.now().strftime("%Y-%m-%d"))
|
| 148 |
+
}
|
| 149 |
+
for model in result
|
| 150 |
+
]
|
| 151 |
+
# If no results, return empty list
|
| 152 |
+
return []
|
| 153 |
except Exception as e:
|
| 154 |
print(f"Error searching models: {e}")
|
| 155 |
+
# Fallback to filtering from all models if there's an error
|
| 156 |
return [m for m in get_models_from_db() if search_term.lower() in m["name"].lower()]
|
| 157 |
+
|
| 158 |
def format_dropdown_items(models):
|
| 159 |
"""Format dropdown items to show model name, creation date, and description preview"""
|
| 160 |
formatted_items = []
|
|
|
|
| 173 |
return model_mapping.get(dropdown_value, dropdown_value.split(" (")[0] if dropdown_value else "")
|
| 174 |
|
| 175 |
def get_model_details(model_name: str):
|
| 176 |
+
"""Get model details from database via direct DB access (fallback)"""
|
| 177 |
try:
|
| 178 |
+
with SessionLocal() as session:
|
| 179 |
+
model_entry = session.query(ModelEntry).filter_by(name=model_name).first()
|
| 180 |
+
if model_entry:
|
| 181 |
+
return {
|
| 182 |
+
"name": model_entry.name,
|
| 183 |
+
"description": model_entry.description or "",
|
| 184 |
+
"system_prompt": model_entry.capabilities.split("\nSystem Prompt: ")[1] if "\nSystem Prompt: " in model_entry.capabilities else "",
|
| 185 |
+
"created": model_entry.created.strftime("%Y-%m-%d %H:%M:%S") if model_entry.created else ""
|
| 186 |
+
}
|
| 187 |
+
return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
|
| 188 |
except Exception as e:
|
| 189 |
print(f"Error getting model details: {e}")
|
| 190 |
return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
|
| 191 |
|
| 192 |
def enhance_prompt_via_mcp(prompt: str):
|
| 193 |
+
"""Enhance prompt locally since enhance_prompt tool is not available in server.py"""
|
| 194 |
+
# Provide a basic prompt enhancement functionality since server doesn't have it
|
| 195 |
+
enhanced_prompts = {
|
| 196 |
+
"helpful": f"{prompt}\n\nPlease be thorough, helpful, and provide detailed responses.",
|
| 197 |
+
"concise": f"{prompt}\n\nPlease provide concise, direct answers.",
|
| 198 |
+
"technical": f"{prompt}\n\nPlease provide technically accurate and comprehensive responses.",
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
if "helpful" in prompt.lower():
|
| 202 |
+
return enhanced_prompts["helpful"]
|
| 203 |
+
elif "concise" in prompt.lower() or "brief" in prompt.lower():
|
| 204 |
+
return enhanced_prompts["concise"]
|
| 205 |
+
elif "technical" in prompt.lower() or "detailed" in prompt.lower():
|
| 206 |
+
return enhanced_prompts["technical"]
|
| 207 |
+
else:
|
| 208 |
+
return f"{prompt}\n\nAdditional context: Be specific, helpful, and provide detailed responses while maintaining a professional tone."
|
| 209 |
|
| 210 |
def save_model_to_db(model_name: str, system_prompt: str):
|
| 211 |
+
"""Save model to database directly since save_model tool is not available in server.py"""
|
| 212 |
try:
|
| 213 |
+
# Check if model already exists
|
| 214 |
+
with SessionLocal() as session:
|
| 215 |
+
existing = session.query(ModelEntry).filter_by(name=model_name).first()
|
| 216 |
+
if existing:
|
| 217 |
+
# Update capabilities to include the new system prompt
|
| 218 |
+
capabilities = existing.capabilities
|
| 219 |
+
if "\nSystem Prompt: " in capabilities:
|
| 220 |
+
# Replace the system prompt part
|
| 221 |
+
parts = capabilities.split("\nSystem Prompt: ")
|
| 222 |
+
capabilities = f"{parts[0]}\nSystem Prompt: {system_prompt}"
|
| 223 |
+
else:
|
| 224 |
+
# Add system prompt if not present
|
| 225 |
+
capabilities = f"{capabilities}\nSystem Prompt: {system_prompt}"
|
| 226 |
+
|
| 227 |
+
existing.capabilities = capabilities
|
| 228 |
+
existing.updated = datetime.now()
|
| 229 |
+
session.commit()
|
| 230 |
+
return {"message": f"Updated existing model: {model_name}"}
|
| 231 |
+
else:
|
| 232 |
+
# Should not happen as models are registered with capabilities before calling this function
|
| 233 |
+
return {"message": f"Model {model_name} not found. Please register it first."}
|
| 234 |
except Exception as e:
|
| 235 |
print(f"Error saving model: {e}")
|
| 236 |
+
return {"message": f"Error saving model: {e}"}
|
| 237 |
|
| 238 |
def get_drift_history_from_db(model_name: str):
|
| 239 |
+
"""Get drift history from database directly without any fallbacks"""
|
| 240 |
try:
|
| 241 |
+
from database_module.models import DriftEntry
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
+
with SessionLocal() as session:
|
| 244 |
+
# Query the drift_history table for this model
|
| 245 |
+
drift_entries = session.query(DriftEntry).filter(
|
| 246 |
+
DriftEntry.model_name == model_name
|
| 247 |
+
).order_by(DriftEntry.date.desc()).all()
|
| 248 |
+
|
| 249 |
+
# If no entries found, return empty list
|
| 250 |
+
if not drift_entries:
|
| 251 |
+
return []
|
| 252 |
+
|
| 253 |
+
# Convert to the expected format
|
| 254 |
+
history = []
|
| 255 |
+
for entry in drift_entries:
|
| 256 |
+
history.append({
|
| 257 |
+
"date": entry.date.strftime("%Y-%m-%d"),
|
| 258 |
+
"drift_score": float(entry.drift_score),
|
| 259 |
+
"model": entry.model_name
|
| 260 |
+
})
|
| 261 |
+
|
| 262 |
+
return history
|
| 263 |
+
except Exception as e:
|
| 264 |
+
print(f"Error getting drift history from database: {e}")
|
| 265 |
+
return [] # Return empty list on error, no fallbacks
|
| 266 |
def create_drift_chart(drift_history):
|
| 267 |
"""Create drift chart using plotly"""
|
| 268 |
if not drift_history:
|
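get_drift_history_from_db imports a DriftEntry model from database_module.models that is not part of this commit. A hypothetical sketch of the columns that the query and formatting code assume (model_name, date, drift_score); the project's actual table definition may differ:

# Hypothetical SQLAlchemy model matching the fields used by get_drift_history_from_db().
from sqlalchemy import Column, Integer, String, Float, DateTime
from database_module.db import Base  # assumed declarative base

class DriftEntry(Base):
    __tablename__ = "drift_history"          # assumed table name
    id = Column(Integer, primary_key=True)
    model_name = Column(String, index=True)  # filtered on in the query
    date = Column(DateTime)                  # ordered by, formatted as %Y-%m-%d
    drift_score = Column(Float)              # cast to float for the chart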
drift_detector.sqlite3
CHANGED
Binary files a/drift_detector.sqlite3 and b/drift_detector.sqlite3 differ
ourllm.py
CHANGED
@@ -3,8 +3,10 @@ from typing import List
 import mcp.types as types
 from langchain.chat_models import init_chat_model
 from dotenv import load_dotenv
+import os
 # Load environment variables from .env file
 load_dotenv()
+print("GROQ_API_KEY is set:", "GROQ_API_KEY" in os.environ)
 
 llm = init_chat_model("llama-3.1-8b-instant",model_provider='groq')
 
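With the added debug print, a missing GROQ_API_KEY is now visible as soon as ourllm.py is imported. For reference, a minimal sketch of how the llm object defined here is typically consumed through LangChain's runnable interface (the prompt is only an example):

# Example use of the Groq-backed chat model created in ourllm.py.
from ourllm import llm

reply = llm.invoke("Summarise what model drift means in one sentence.")
print(reply.content)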