Sars6 commited on
Commit
7209b84
·
1 Parent(s): f869a71

app.py works mostly

Browse files
Files changed (3) hide show
  1. app.py +154 -106
  2. drift_detector.sqlite3 +0 -0
  3. ourllm.py +2 -0
app.py CHANGED
@@ -2,9 +2,7 @@ import os
2
  import gradio as gr
3
  import asyncio
4
  from typing import Optional, List, Dict
5
- from contextlib import AsyncExitStack
6
- from mcp import ClientSession, StdioServerParameters
7
- from mcp.client.stdio import stdio_client
8
 
9
  from database_module.db import SessionLocal
10
  from database_module.models import ModelEntry
@@ -12,8 +10,8 @@ from langchain.chat_models import init_chat_model
12
  # Modify imports section to include all required tools
13
  from database_module import (
14
  init_db,
15
- # get_all_models_handler,
16
- # search_models_handler,
17
  # save_model_handler,
18
  # get_model_details_handler,
19
  # calculate_drift_handler,
@@ -27,66 +25,30 @@ import plotly.graph_objects as go
27
  # Create tables and register MCP handlers
28
  init_db()
29
 
 
 
30
 
31
- # Ensure server.py imports and registers these tools:
32
- # app.register_tool("get_all_models", get_all_models_handler)
33
- # app.register_tool("search_models", search_models_handler)
34
-
35
- # Replace the existing MCP client class with this updated version
36
- class MCPClient:
37
- def __init__(self):
38
- self.session: Optional[ClientSession] = None
39
- self.exit_stack = AsyncExitStack()
40
-
41
- async def connect_to_server(self, server_script_path: str = "server.py"):
42
- """Connect to MCP server"""
43
- try:
44
- server_params = StdioServerParameters(
45
- command="python",
46
- args=[server_script_path],
47
- env=None
48
- )
49
- stdio_transport = await self.exit_stack.enter_async_context(
50
- stdio_client(server_params)
51
- )
52
- self.stdio, self.write = stdio_transport
53
- self.session = await self.exit_stack.enter_async_context(
54
- ClientSession(self.stdio, self.write)
55
- )
56
- await self.session.initialize()
57
-
58
- # Get available tools from server
59
- tools_response = await self.session.list_tools()
60
- available_tools = [t.name for t in tools_response.tools]
61
- print("Connected to server with tools:", available_tools)
62
-
63
- return True
64
- except Exception as e:
65
- print(f"Failed to connect to MCP server: {e}")
66
- return False
67
-
68
- async def call_tool(self, tool_name: str, arguments: dict):
69
- """Call a tool on the MCP server"""
70
- if not self.session:
71
- raise RuntimeError("Not connected to MCP server")
72
- try:
73
- response = await self.session.call_tool(tool_name, arguments)
74
- return response.content
75
- except Exception as e:
76
- print(f"Error calling tool {tool_name}: {e}")
77
- raise
78
-
79
- async def close(self):
80
- """Close the MCP client connection"""
81
- if self.session:
82
- await self.exit_stack.aclose()
83
-
84
-
85
- # Global MCP client instance
86
- mcp_client = MCPClient()
87
-
88
-
89
- # Helper to run async functions
90
  def run_async(coro):
91
  try:
92
  loop = asyncio.get_running_loop()
@@ -102,10 +64,15 @@ def run_async(coro):
102
  def run_initial_diagnostics(model_name: str, capabilities: str):
103
  """Run initial diagnostics for a new model"""
104
  try:
105
- result = run_async(mcp_client.call_tool("run_initial_diagnostics", {
106
- "model": model_name,
107
- "model_capabilities": capabilities
108
- }))
 
 
 
 
 
109
  return result
110
  except Exception as e:
111
  print(f"Error running diagnostics: {e}")
@@ -114,9 +81,14 @@ def run_initial_diagnostics(model_name: str, capabilities: str):
114
  def check_model_drift(model_name: str):
115
  """Check drift for existing model"""
116
  try:
117
- result = run_async(mcp_client.call_tool("check_drift", {
118
- "model": model_name
119
- }))
 
 
 
 
 
120
  return result
121
  except Exception as e:
122
  print(f"Error checking drift: {e}")
@@ -125,19 +97,32 @@ def check_model_drift(model_name: str):
125
  # Initialize MCP connection on startup
126
  def initialize_mcp_connection():
127
  try:
128
- run_async(mcp_client.connect_to_server())
129
- print("Successfully connected to MCP server")
130
  return True
131
  except Exception as e:
132
- print(f"Failed to connect to MCP server: {e}")
133
  return False
134
 
135
 
136
  # Wrapper functions remain unchanged but now call real DB-backed MCP tools
137
  def get_models_from_db():
 
138
  try:
139
- result = run_async(mcp_client.call_tool("get_all_models", {}))
140
- return result if isinstance(result, list) else []
 
 
 
 
 
 
 
 
 
 
 
 
141
  except Exception as e:
142
  print(f"Error getting models: {e}")
143
  return []
@@ -148,12 +133,28 @@ def get_available_model_names():
148
 
149
 
150
  def search_models_in_db(search_term: str):
 
151
  try:
152
- result = run_async(mcp_client.call_tool("search_models", {"search_term": search_term}))
153
- return result if isinstance(result, list) else []
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  except Exception as e:
155
  print(f"Error searching models: {e}")
 
156
  return [m for m in get_models_from_db() if search_term.lower() in m["name"].lower()]
 
157
  def format_dropdown_items(models):
158
  """Format dropdown items to show model name, creation date, and description preview"""
159
  formatted_items = []
@@ -172,49 +173,96 @@ def extract_model_name_from_dropdown(dropdown_value, model_mapping):
172
  return model_mapping.get(dropdown_value, dropdown_value.split(" (")[0] if dropdown_value else "")
173
 
174
  def get_model_details(model_name: str):
175
- """Get model details from database via MCP"""
176
  try:
177
- result = run_async(mcp_client.call_tool("get_model_details", {"model_name": model_name}))
178
- return result
 
 
 
 
 
 
 
 
179
  except Exception as e:
180
  print(f"Error getting model details: {e}")
181
  return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
182
 
183
  def enhance_prompt_via_mcp(prompt: str):
184
- """Enhance prompt using MCP server"""
185
- try:
186
- result = run_async(mcp_client.call_tool("enhance_prompt", {"prompt": prompt}))
187
- return result.get("enhanced_prompt", prompt)
188
- except Exception as e:
189
- print(f"Error enhancing prompt: {e}")
190
- return f"Enhanced: {prompt}\n\nAdditional context: Be more specific, helpful, and provide detailed responses while maintaining a professional tone."
 
 
 
 
 
 
 
 
 
191
 
192
  def save_model_to_db(model_name: str, system_prompt: str):
193
- """Save model to database via MCP"""
194
  try:
195
- result = run_async(mcp_client.call_tool("save_model", {
196
- "model_name": model_name,
197
- "system_prompt": system_prompt
198
- }))
199
- return result.get("message", "Model saved successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  except Exception as e:
201
  print(f"Error saving model: {e}")
202
- return f"Error saving model: {e}"
203
 
204
  def get_drift_history_from_db(model_name: str):
205
- """Get drift history from database via MCP"""
206
  try:
207
- result = run_async(mcp_client.call_tool("get_drift_history", {"model_name": model_name}))
208
- return result if isinstance(result, list) else []
209
- except Exception as e:
210
- print(f"Error getting drift history: {e}")
211
- # Fallback data for demonstration
212
- return [
213
- {"date": "2025-06-01", "drift_score": 0.12},
214
- {"date": "2025-06-05", "drift_score": 0.18},
215
- {"date": "2025-06-09", "drift_score": 0.15}
216
- ]
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  def create_drift_chart(drift_history):
219
  """Create drift chart using plotly"""
220
  if not drift_history:
 
2
  import gradio as gr
3
  import asyncio
4
  from typing import Optional, List, Dict
5
+ from mcp_agent.core.fastagent import FastAgent
 
 
6
 
7
  from database_module.db import SessionLocal
8
  from database_module.models import ModelEntry
 
10
  # Modify imports section to include all required tools
11
  from database_module import (
12
  init_db,
13
+ get_all_models_handler,
14
+ search_models_handler,
15
  # save_model_handler,
16
  # get_model_details_handler,
17
  # calculate_drift_handler,
 
25
  # Create tables and register MCP handlers
26
  init_db()
27
 
28
+ # Fast Agent client initialization - This is the "scapegoat" client whose drift we're detecting
29
+ fast = FastAgent("Scapegoat Client")
30
 
31
+ @fast.agent(
32
+ name="scapegoat",
33
+ instruction="You are a test client whose drift will be detected and measured over time",
34
+ servers=["drift-server"]
35
+ )
36
+ async def setup_agent():
37
+ # This function defines the scapegoat agent that will be monitored for drift
38
+ pass
39
+
40
+ # Global scapegoat client instance to be monitored for drift
41
+ scapegoat_client = None
42
+
43
+ # Initialize the scapegoat client that will be tested for drift
44
+ async def initialize_scapegoat_client():
45
+ global scapegoat_client
46
+ print("Initializing scapegoat client for drift monitoring...")
47
+ async with fast.run() as agent:
48
+ scapegoat_client = agent
49
+ return agent
50
+
51
+ # Helper to run async functions with FastAgent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def run_async(coro):
53
  try:
54
  loop = asyncio.get_running_loop()
 
64
  def run_initial_diagnostics(model_name: str, capabilities: str):
65
  """Run initial diagnostics for a new model"""
66
  try:
67
+ # Use FastAgent's send method with a formatted message to call the tool
68
+ message = f"""Please call the run_initial_diagnostics tool with the following parameters:
69
+ model: {model_name}
70
+ model_capabilities: {capabilities}
71
+
72
+ This tool will generate and store baseline diagnostics for the model.
73
+ """
74
+
75
+ result = run_async(scapegoat_client(message))
76
  return result
77
  except Exception as e:
78
  print(f"Error running diagnostics: {e}")
 
81
  def check_model_drift(model_name: str):
82
  """Check drift for existing model"""
83
  try:
84
+ # Use FastAgent's send method with a formatted message to call the tool
85
+ message = f"""Please call the check_drift tool with the following parameters:
86
+ model: {model_name}
87
+
88
+ This tool will re-run diagnostics and compare to baseline for drift scoring.
89
+ """
90
+
91
+ result = run_async(scapegoat_client(message))
92
  return result
93
  except Exception as e:
94
  print(f"Error checking drift: {e}")
 
97
  # Initialize MCP connection on startup
98
  def initialize_mcp_connection():
99
  try:
100
+ run_async(initialize_scapegoat_client())
101
+ print("Successfully connected scapegoat client to MCP server")
102
  return True
103
  except Exception as e:
104
+ print(f"Failed to connect scapegoat client to MCP server: {e}")
105
  return False
106
 
107
 
108
  # Wrapper functions remain unchanged but now call real DB-backed MCP tools
109
  def get_models_from_db():
110
+ """Get all models from database using direct function call"""
111
  try:
112
+ # Direct function call to database_module instead of using MCP
113
+ result = get_all_models_handler({})
114
+
115
+ if result:
116
+ # Format the result to match the expected structure
117
+ return [
118
+ {
119
+ "name": model["name"],
120
+ "description": model.get("description", ""),
121
+ "created": model.get("created", datetime.now().strftime("%Y-%m-%d"))
122
+ }
123
+ for model in result
124
+ ]
125
+ return []
126
  except Exception as e:
127
  print(f"Error getting models: {e}")
128
  return []
 
133
 
134
 
135
  def search_models_in_db(search_term: str):
136
+ """Search models in database using direct function call"""
137
  try:
138
+ # Direct function call to database_module instead of using MCP
139
+ result = search_models_handler({"search_term": search_term})
140
+
141
+ if result:
142
+ # Format the result to match the expected structure
143
+ return [
144
+ {
145
+ "name": model["name"],
146
+ "description": model.get("description", ""),
147
+ "created": model.get("created", datetime.now().strftime("%Y-%m-%d"))
148
+ }
149
+ for model in result
150
+ ]
151
+ # If no results, return empty list
152
+ return []
153
  except Exception as e:
154
  print(f"Error searching models: {e}")
155
+ # Fallback to filtering from all models if there's an error
156
  return [m for m in get_models_from_db() if search_term.lower() in m["name"].lower()]
157
+
158
  def format_dropdown_items(models):
159
  """Format dropdown items to show model name, creation date, and description preview"""
160
  formatted_items = []
 
173
  return model_mapping.get(dropdown_value, dropdown_value.split(" (")[0] if dropdown_value else "")
174
 
175
  def get_model_details(model_name: str):
176
+ """Get model details from database via direct DB access (fallback)"""
177
  try:
178
+ with SessionLocal() as session:
179
+ model_entry = session.query(ModelEntry).filter_by(name=model_name).first()
180
+ if model_entry:
181
+ return {
182
+ "name": model_entry.name,
183
+ "description": model_entry.description or "",
184
+ "system_prompt": model_entry.capabilities.split("\nSystem Prompt: ")[1] if "\nSystem Prompt: " in model_entry.capabilities else "",
185
+ "created": model_entry.created.strftime("%Y-%m-%d %H:%M:%S") if model_entry.created else ""
186
+ }
187
+ return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
188
  except Exception as e:
189
  print(f"Error getting model details: {e}")
190
  return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
191
 
192
  def enhance_prompt_via_mcp(prompt: str):
193
+ """Enhance prompt locally since enhance_prompt tool is not available in server.py"""
194
+ # Provide a basic prompt enhancement functionality since server doesn't have it
195
+ enhanced_prompts = {
196
+ "helpful": f"{prompt}\n\nPlease be thorough, helpful, and provide detailed responses.",
197
+ "concise": f"{prompt}\n\nPlease provide concise, direct answers.",
198
+ "technical": f"{prompt}\n\nPlease provide technically accurate and comprehensive responses.",
199
+ }
200
+
201
+ if "helpful" in prompt.lower():
202
+ return enhanced_prompts["helpful"]
203
+ elif "concise" in prompt.lower() or "brief" in prompt.lower():
204
+ return enhanced_prompts["concise"]
205
+ elif "technical" in prompt.lower() or "detailed" in prompt.lower():
206
+ return enhanced_prompts["technical"]
207
+ else:
208
+ return f"{prompt}\n\nAdditional context: Be specific, helpful, and provide detailed responses while maintaining a professional tone."
209
 
210
  def save_model_to_db(model_name: str, system_prompt: str):
211
+ """Save model to database directly since save_model tool is not available in server.py"""
212
  try:
213
+ # Check if model already exists
214
+ with SessionLocal() as session:
215
+ existing = session.query(ModelEntry).filter_by(name=model_name).first()
216
+ if existing:
217
+ # Update capabilities to include the new system prompt
218
+ capabilities = existing.capabilities
219
+ if "\nSystem Prompt: " in capabilities:
220
+ # Replace the system prompt part
221
+ parts = capabilities.split("\nSystem Prompt: ")
222
+ capabilities = f"{parts[0]}\nSystem Prompt: {system_prompt}"
223
+ else:
224
+ # Add system prompt if not present
225
+ capabilities = f"{capabilities}\nSystem Prompt: {system_prompt}"
226
+
227
+ existing.capabilities = capabilities
228
+ existing.updated = datetime.now()
229
+ session.commit()
230
+ return {"message": f"Updated existing model: {model_name}"}
231
+ else:
232
+ # Should not happen as models are registered with capabilities before calling this function
233
+ return {"message": f"Model {model_name} not found. Please register it first."}
234
  except Exception as e:
235
  print(f"Error saving model: {e}")
236
+ return {"message": f"Error saving model: {e}"}
237
 
238
  def get_drift_history_from_db(model_name: str):
239
+ """Get drift history from database directly without any fallbacks"""
240
  try:
241
+ from database_module.models import DriftEntry
 
 
 
 
 
 
 
 
 
242
 
243
+ with SessionLocal() as session:
244
+ # Query the drift_history table for this model
245
+ drift_entries = session.query(DriftEntry).filter(
246
+ DriftEntry.model_name == model_name
247
+ ).order_by(DriftEntry.date.desc()).all()
248
+
249
+ # If no entries found, return empty list
250
+ if not drift_entries:
251
+ return []
252
+
253
+ # Convert to the expected format
254
+ history = []
255
+ for entry in drift_entries:
256
+ history.append({
257
+ "date": entry.date.strftime("%Y-%m-%d"),
258
+ "drift_score": float(entry.drift_score),
259
+ "model": entry.model_name
260
+ })
261
+
262
+ return history
263
+ except Exception as e:
264
+ print(f"Error getting drift history from database: {e}")
265
+ return [] # Return empty list on error, no fallbacks
266
  def create_drift_chart(drift_history):
267
  """Create drift chart using plotly"""
268
  if not drift_history:
drift_detector.sqlite3 CHANGED
Binary files a/drift_detector.sqlite3 and b/drift_detector.sqlite3 differ
 
ourllm.py CHANGED
@@ -3,8 +3,10 @@ from typing import List
3
  import mcp.types as types
4
  from langchain.chat_models import init_chat_model
5
  from dotenv import load_dotenv
 
6
  # Load environment variables from .env file
7
  load_dotenv()
 
8
 
9
  llm = init_chat_model("llama-3.1-8b-instant",model_provider='groq')
10
 
 
3
  import mcp.types as types
4
  from langchain.chat_models import init_chat_model
5
  from dotenv import load_dotenv
6
+ import os
7
  # Load environment variables from .env file
8
  load_dotenv()
9
+ print("GROQ_API_KEY is set:", "GROQ_API_KEY" in os.environ)
10
 
11
  llm = init_chat_model("llama-3.1-8b-instant",model_provider='groq')
12