app.py works mostly
- app.py +154 -106
- drift_detector.sqlite3 +0 -0
- ourllm.py +2 -0
app.py CHANGED

@@ -2,9 +2,7 @@ import os
 import gradio as gr
 import asyncio
 from typing import Optional, List, Dict
-from
-from mcp import ClientSession, StdioServerParameters
-from mcp.client.stdio import stdio_client
+from mcp_agent.core.fastagent import FastAgent
 
 from database_module.db import SessionLocal
 from database_module.models import ModelEntry
@@ -12,8 +10,8 @@ from langchain.chat_models import init_chat_model
 # Modify imports section to include all required tools
 from database_module import (
     init_db,
-
-
+    get_all_models_handler,
+    search_models_handler,
     # save_model_handler,
     # get_model_details_handler,
     # calculate_drift_handler,
@@ -27,66 +25,30 @@ import plotly.graph_objects as go
 # Create tables and register MCP handlers
 init_db()
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            self.stdio, self.write = stdio_transport
-            self.session = await self.exit_stack.enter_async_context(
-                ClientSession(self.stdio, self.write)
-            )
-            await self.session.initialize()
-
-            # Get available tools from server
-            tools_response = await self.session.list_tools()
-            available_tools = [t.name for t in tools_response.tools]
-            print("Connected to server with tools:", available_tools)
-
-            return True
-        except Exception as e:
-            print(f"Failed to connect to MCP server: {e}")
-            return False
-
-    async def call_tool(self, tool_name: str, arguments: dict):
-        """Call a tool on the MCP server"""
-        if not self.session:
-            raise RuntimeError("Not connected to MCP server")
-        try:
-            response = await self.session.call_tool(tool_name, arguments)
-            return response.content
-        except Exception as e:
-            print(f"Error calling tool {tool_name}: {e}")
-            raise
-
-    async def close(self):
-        """Close the MCP client connection"""
-        if self.session:
-            await self.exit_stack.aclose()
-
-
-# Global MCP client instance
-mcp_client = MCPClient()
-
-
-# Helper to run async functions
+# Fast Agent client initialization - This is the "scapegoat" client whose drift we're detecting
+fast = FastAgent("Scapegoat Client")
+
+@fast.agent(
+    name="scapegoat",
+    instruction="You are a test client whose drift will be detected and measured over time",
+    servers=["drift-server"]
+)
+async def setup_agent():
+    # This function defines the scapegoat agent that will be monitored for drift
+    pass
+
+# Global scapegoat client instance to be monitored for drift
+scapegoat_client = None
+
+# Initialize the scapegoat client that will be tested for drift
+async def initialize_scapegoat_client():
+    global scapegoat_client
+    print("Initializing scapegoat client for drift monitoring...")
+    async with fast.run() as agent:
+        scapegoat_client = agent
+        return agent
+
+# Helper to run async functions with FastAgent
 def run_async(coro):
     try:
         loop = asyncio.get_running_loop()
@@ -102,10 +64,15 @@ def run_async(coro):
 def run_initial_diagnostics(model_name: str, capabilities: str):
     """Run initial diagnostics for a new model"""
     try:
-
-
-
-        }
+        # Use FastAgent's send method with a formatted message to call the tool
+        message = f"""Please call the run_initial_diagnostics tool with the following parameters:
+        model: {model_name}
+        model_capabilities: {capabilities}
+
+        This tool will generate and store baseline diagnostics for the model.
+        """
+
+        result = run_async(scapegoat_client(message))
         return result
     except Exception as e:
         print(f"Error running diagnostics: {e}")
@@ -114,9 +81,14 @@ def run_initial_diagnostics(model_name: str, capabilities: str):
 def check_model_drift(model_name: str):
     """Check drift for existing model"""
     try:
-
-
-        }
+        # Use FastAgent's send method with a formatted message to call the tool
+        message = f"""Please call the check_drift tool with the following parameters:
+        model: {model_name}
+
+        This tool will re-run diagnostics and compare to baseline for drift scoring.
+        """
+
+        result = run_async(scapegoat_client(message))
         return result
     except Exception as e:
         print(f"Error checking drift: {e}")
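The diff shows only the first line of run_async's body (the try around asyncio.get_running_loop()), so the full helper is not visible here. A typical sync-to-async bridge of this shape, shown purely as an illustrative sketch rather than the file's actual implementation, looks like:

```python
# Illustrative sketch only; app.py's actual run_async body is truncated in
# this diff. Bridges synchronous Gradio callbacks to async FastAgent calls.
import asyncio
import concurrent.futures

def run_async(coro):
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop in this thread: run the coroutine to completion.
        return asyncio.run(coro)
    # A loop is already running (e.g. inside the web server): execute the
    # coroutine on a fresh loop in a worker thread so we don't block it.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
```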
@@ -125,19 +97,32 @@ def check_model_drift(model_name: str):
 # Initialize MCP connection on startup
 def initialize_mcp_connection():
     try:
-        run_async(
-        print("Successfully connected to MCP server")
+        run_async(initialize_scapegoat_client())
+        print("Successfully connected scapegoat client to MCP server")
         return True
     except Exception as e:
-        print(f"Failed to connect to MCP server: {e}")
+        print(f"Failed to connect scapegoat client to MCP server: {e}")
         return False
 
 
 # Wrapper functions remain unchanged but now call real DB-backed MCP tools
 def get_models_from_db():
+    """Get all models from database using direct function call"""
     try:
-
-
+        # Direct function call to database_module instead of using MCP
+        result = get_all_models_handler({})
+
+        if result:
+            # Format the result to match the expected structure
+            return [
+                {
+                    "name": model["name"],
+                    "description": model.get("description", ""),
+                    "created": model.get("created", datetime.now().strftime("%Y-%m-%d"))
+                }
+                for model in result
+            ]
+        return []
     except Exception as e:
         print(f"Error getting models: {e}")
         return []
@@ -148,12 +133,28 @@ def get_available_model_names():
 
 
 def search_models_in_db(search_term: str):
+    """Search models in database using direct function call"""
     try:
-
-
+        # Direct function call to database_module instead of using MCP
+        result = search_models_handler({"search_term": search_term})
+
+        if result:
+            # Format the result to match the expected structure
+            return [
+                {
+                    "name": model["name"],
+                    "description": model.get("description", ""),
+                    "created": model.get("created", datetime.now().strftime("%Y-%m-%d"))
+                }
+                for model in result
+            ]
+        # If no results, return empty list
+        return []
     except Exception as e:
         print(f"Error searching models: {e}")
+        # Fallback to filtering from all models if there's an error
         return [m for m in get_models_from_db() if search_term.lower() in m["name"].lower()]
+
 def format_dropdown_items(models):
     """Format dropdown items to show model name, creation date, and description preview"""
     formatted_items = []
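get_all_models_handler and search_models_handler come from database_module and their implementations are not part of this diff. From the way they are called above, the assumed contract is a dict of arguments in and a list of model dicts out, roughly:

```python
# Assumed interface of the database_module handlers called above; the real
# implementations live in database_module and are not shown in this diff.
from typing import Dict, List

def get_all_models_handler(arguments: Dict) -> List[Dict]:
    # e.g. [{"name": "my-model", "description": "...", "created": "2025-06-01"}, ...]
    ...

def search_models_handler(arguments: Dict) -> List[Dict]:
    # Same row shape, filtered by arguments["search_term"].
    ...
```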
@@ -172,49 +173,96 @@ def extract_model_name_from_dropdown(dropdown_value, model_mapping):
     return model_mapping.get(dropdown_value, dropdown_value.split(" (")[0] if dropdown_value else "")
 
 def get_model_details(model_name: str):
-    """Get model details from database via
+    """Get model details from database via direct DB access (fallback)"""
     try:
-
-
+        with SessionLocal() as session:
+            model_entry = session.query(ModelEntry).filter_by(name=model_name).first()
+            if model_entry:
+                return {
+                    "name": model_entry.name,
+                    "description": model_entry.description or "",
+                    "system_prompt": model_entry.capabilities.split("\nSystem Prompt: ")[1] if "\nSystem Prompt: " in model_entry.capabilities else "",
+                    "created": model_entry.created.strftime("%Y-%m-%d %H:%M:%S") if model_entry.created else ""
+                }
+            return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
     except Exception as e:
         print(f"Error getting model details: {e}")
         return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
 
 def enhance_prompt_via_mcp(prompt: str):
-    """Enhance prompt
-
-
-
-
-
-
+    """Enhance prompt locally since enhance_prompt tool is not available in server.py"""
+    # Provide a basic prompt enhancement functionality since server doesn't have it
+    enhanced_prompts = {
+        "helpful": f"{prompt}\n\nPlease be thorough, helpful, and provide detailed responses.",
+        "concise": f"{prompt}\n\nPlease provide concise, direct answers.",
+        "technical": f"{prompt}\n\nPlease provide technically accurate and comprehensive responses.",
+    }
+
+    if "helpful" in prompt.lower():
+        return enhanced_prompts["helpful"]
+    elif "concise" in prompt.lower() or "brief" in prompt.lower():
+        return enhanced_prompts["concise"]
+    elif "technical" in prompt.lower() or "detailed" in prompt.lower():
+        return enhanced_prompts["technical"]
+    else:
+        return f"{prompt}\n\nAdditional context: Be specific, helpful, and provide detailed responses while maintaining a professional tone."
 
 def save_model_to_db(model_name: str, system_prompt: str):
-    """Save model to database
+    """Save model to database directly since save_model tool is not available in server.py"""
     try:
-
-
-
-
-
+        # Check if model already exists
+        with SessionLocal() as session:
+            existing = session.query(ModelEntry).filter_by(name=model_name).first()
+            if existing:
+                # Update capabilities to include the new system prompt
+                capabilities = existing.capabilities
+                if "\nSystem Prompt: " in capabilities:
+                    # Replace the system prompt part
+                    parts = capabilities.split("\nSystem Prompt: ")
+                    capabilities = f"{parts[0]}\nSystem Prompt: {system_prompt}"
+                else:
+                    # Add system prompt if not present
+                    capabilities = f"{capabilities}\nSystem Prompt: {system_prompt}"
+
+                existing.capabilities = capabilities
+                existing.updated = datetime.now()
+                session.commit()
+                return {"message": f"Updated existing model: {model_name}"}
+            else:
+                # Should not happen as models are registered with capabilities before calling this function
+                return {"message": f"Model {model_name} not found. Please register it first."}
     except Exception as e:
         print(f"Error saving model: {e}")
-        return f"Error saving model: {e}"
+        return {"message": f"Error saving model: {e}"}
 
 def get_drift_history_from_db(model_name: str):
-    """Get drift history from database
+    """Get drift history from database directly without any fallbacks"""
     try:
-
-        return result if isinstance(result, list) else []
-    except Exception as e:
-        print(f"Error getting drift history: {e}")
-        # Fallback data for demonstration
-        return [
-            {"date": "2025-06-01", "drift_score": 0.12},
-            {"date": "2025-06-05", "drift_score": 0.18},
-            {"date": "2025-06-09", "drift_score": 0.15}
-        ]
-
+        from database_module.models import DriftEntry
+
+        with SessionLocal() as session:
+            # Query the drift_history table for this model
+            drift_entries = session.query(DriftEntry).filter(
+                DriftEntry.model_name == model_name
+            ).order_by(DriftEntry.date.desc()).all()
+
+            # If no entries found, return empty list
+            if not drift_entries:
+                return []
+
+            # Convert to the expected format
+            history = []
+            for entry in drift_entries:
+                history.append({
+                    "date": entry.date.strftime("%Y-%m-%d"),
+                    "drift_score": float(entry.drift_score),
+                    "model": entry.model_name
+                })
+
+            return history
+    except Exception as e:
+        print(f"Error getting drift history from database: {e}")
+        return []  # Return empty list on error, no fallbacks
 def create_drift_chart(drift_history):
     """Create drift chart using plotly"""
     if not drift_history:
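The hunk ends at the unchanged signature of create_drift_chart, so its body is not shown in this commit. Given the plotly.graph_objects import visible in the @@ -27 hunk header and the record shape produced by get_drift_history_from_db, a chart over that history would look roughly like this sketch; it is an assumption, not the file's actual code.

```python
# Illustrative only: a minimal drift-over-time chart from the
# [{"date": ..., "drift_score": ...}, ...] records built above.
import plotly.graph_objects as go

def create_drift_chart(drift_history):
    fig = go.Figure()
    if drift_history:
        fig.add_trace(go.Scatter(
            x=[d["date"] for d in drift_history],
            y=[d["drift_score"] for d in drift_history],
            mode="lines+markers",
            name="drift score",
        ))
    fig.update_layout(title="Model drift over time",
                      xaxis_title="Date", yaxis_title="Drift score")
    return fig
```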
drift_detector.sqlite3 CHANGED

Binary files a/drift_detector.sqlite3 and b/drift_detector.sqlite3 differ
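Because drift_detector.sqlite3 is a binary SQLite file, the diff can only note that it changed. One way to see what the commit actually touched is to open the file read-only with Python's built-in sqlite3 module; the table names are whatever the database_module models created and are not listed in this diff.

```python
# Read-only peek at the updated database shipped with the Space.
import sqlite3

con = sqlite3.connect("file:drift_detector.sqlite3?mode=ro", uri=True)
for (name,) in con.execute("SELECT name FROM sqlite_master WHERE type='table'"):
    count = con.execute(f"SELECT COUNT(*) FROM {name}").fetchone()[0]
    print(name, count)
con.close()
```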
ourllm.py CHANGED

@@ -3,8 +3,10 @@ from typing import List
 import mcp.types as types
 from langchain.chat_models import init_chat_model
 from dotenv import load_dotenv
+import os
 # Load environment variables from .env file
 load_dotenv()
+print("GROQ_API_KEY is set:", "GROQ_API_KEY" in os.environ)
 
 llm = init_chat_model("llama-3.1-8b-instant",model_provider='groq')
 
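The ourllm.py change only imports os and prints, at import time, whether GROQ_API_KEY is present before init_chat_model runs. Since llm is a LangChain chat model, a minimal smoke test (assuming the key is actually set in the environment or .env) would be:

```python
# Minimal smoke test for the Groq-backed chat model built in ourllm.py.
# Requires GROQ_API_KEY, as the new print statement checks.
from ourllm import llm

reply = llm.invoke("Reply with the single word: ok")
print(reply.content)
```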