Anurag Prasad commited on
Commit
d65ad43
·
1 Parent(s): 39a354b

made some changes in app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -28
app.py CHANGED
@@ -5,7 +5,16 @@ from typing import Optional, List, Dict
5
  from contextlib import AsyncExitStack
6
  from mcp import ClientSession, StdioServerParameters
7
  from mcp.client.stdio import stdio_client
8
- from database_module import init_db, get_all_models_handler, search_models_handler
 
 
 
 
 
 
 
 
 
9
  import json
10
  from datetime import datetime
11
  import plotly.graph_objects as go
@@ -19,6 +28,7 @@ init_db()
19
  # app.register_tool("get_all_models", get_all_models_handler)
20
  # app.register_tool("search_models", search_models_handler)
21
 
 
22
  class MCPClient:
23
  def __init__(self):
24
  self.session: Optional[ClientSession] = None
@@ -26,33 +36,46 @@ class MCPClient:
26
 
27
  async def connect_to_server(self, server_script_path: str = "server.py"):
28
  """Connect to MCP server"""
29
- is_python = server_script_path.endswith('.py')
30
- command = "python" if is_python else "node"
31
- server_params = StdioServerParameters(
32
- command=command,
33
- args=[server_script_path],
34
- env=None
35
- )
36
- stdio_transport = await self.exit_stack.enter_async_context(
37
- stdio_client(server_params)
38
- )
39
- self.stdio, self.write = stdio_transport
40
- self.session = await self.exit_stack.enter_async_context(
41
- ClientSession(self.stdio, self.write)
42
- )
43
- await self.session.initialize()
44
- tools = (await self.session.list_tools()).tools
45
- print("Connected to server with tools:", [t.name for t in tools])
 
 
 
 
 
 
 
46
 
47
  async def call_tool(self, tool_name: str, arguments: dict):
48
  """Call a tool on the MCP server"""
49
  if not self.session:
50
- raise RuntimeError("Not connected to server")
51
- return (await self.session.call_tool(tool_name, arguments)).content
 
 
 
 
 
52
 
53
  async def close(self):
54
  """Close the MCP client connection"""
55
- await self.exit_stack.aclose()
 
56
 
57
 
58
  # Global MCP client instance
@@ -72,6 +95,28 @@ def run_async(coro):
72
  task = loop.create_task(coro)
73
  return loop.run_until_complete(task) if not task.done() else task
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  # Initialize MCP connection on startup
77
  def initialize_mcp_connection():
@@ -274,7 +319,21 @@ def save_new_model(selected_model_name, original_prompt, enhanced_prompt, choice
274
  ]
275
 
276
  final_prompt = enhanced_prompt if choice == "Keep Enhanced" else original_prompt
277
- status = save_model_to_db(selected_model_name, final_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  # Update dropdown choices
280
  updated_models = get_models_from_db()
@@ -297,7 +356,7 @@ def chatbot_response(message, history, dropdown_value):
297
  model_details = get_model_details(model_name)
298
  system_prompt = model_details.get("system_prompt", "")
299
 
300
- # Simulate response (replace with actual LLM call)
301
  response = f"[{model_name}] Response to: {message}\n(Using system prompt: {system_prompt[:50]}...)"
302
  history.append([message, response])
303
  return history, ""
@@ -308,11 +367,18 @@ def calculate_drift(dropdown_value):
308
  return "Please select a model first"
309
 
310
  model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
311
- result = calculate_drift_via_mcp(model_name)
312
- drift_score = result.get("drift_score", 0.0)
313
- message = result.get("message", "")
314
 
315
- return f"Drift Score: {drift_score:.3f}\n{message}"
 
 
 
 
 
 
 
 
 
 
316
 
317
  def refresh_drift_history(dropdown_value):
318
  """Refresh drift history for selected model"""
@@ -372,6 +438,7 @@ with gr.Blocks(title="AI Model Management & Interaction Platform") as demo:
372
  # Create New Model Section (Initially Hidden)
373
  with gr.Group(visible=False) as create_new_section:
374
  gr.Markdown("#### Create New Model")
 
375
  new_model_name = gr.Dropdown(
376
  choices=[],
377
  label="Select Model Name",
@@ -526,4 +593,4 @@ with gr.Blocks(title="AI Model Management & Interaction Platform") as demo:
526
  )
527
 
528
  if __name__ == "__main__":
529
- demo.launch(share=True)
 
5
  from contextlib import AsyncExitStack
6
  from mcp import ClientSession, StdioServerParameters
7
  from mcp.client.stdio import stdio_client
8
+ # Modify imports section to include all required tools
9
+ from database_module import (
10
+ init_db,
11
+ # get_all_models_handler,
12
+ # search_models_handler,
13
+ # save_model_handler,
14
+ # get_model_details_handler,
15
+ # calculate_drift_handler,
16
+ # get_drift_history_handler
17
+ )
18
  import json
19
  from datetime import datetime
20
  import plotly.graph_objects as go
 
28
  # app.register_tool("get_all_models", get_all_models_handler)
29
  # app.register_tool("search_models", search_models_handler)
30
 
31
+ # Replace the existing MCP client class with this updated version
32
  class MCPClient:
33
  def __init__(self):
34
  self.session: Optional[ClientSession] = None
 
36
 
37
  async def connect_to_server(self, server_script_path: str = "server.py"):
38
  """Connect to MCP server"""
39
+ try:
40
+ server_params = StdioServerParameters(
41
+ command="python",
42
+ args=[server_script_path],
43
+ env=None
44
+ )
45
+ stdio_transport = await self.exit_stack.enter_async_context(
46
+ stdio_client(server_params)
47
+ )
48
+ self.stdio, self.write = stdio_transport
49
+ self.session = await self.exit_stack.enter_async_context(
50
+ ClientSession(self.stdio, self.write)
51
+ )
52
+ await self.session.initialize()
53
+
54
+ # Get available tools from server
55
+ tools_response = await self.session.list_tools()
56
+ available_tools = [t.name for t in tools_response.tools]
57
+ print("Connected to server with tools:", available_tools)
58
+
59
+ return True
60
+ except Exception as e:
61
+ print(f"Failed to connect to MCP server: {e}")
62
+ return False
63
 
64
  async def call_tool(self, tool_name: str, arguments: dict):
65
  """Call a tool on the MCP server"""
66
  if not self.session:
67
+ raise RuntimeError("Not connected to MCP server")
68
+ try:
69
+ response = await self.session.call_tool(tool_name, arguments)
70
+ return response.content
71
+ except Exception as e:
72
+ print(f"Error calling tool {tool_name}: {e}")
73
+ raise
74
 
75
  async def close(self):
76
  """Close the MCP client connection"""
77
+ if self.session:
78
+ await self.exit_stack.aclose()
79
 
80
 
81
  # Global MCP client instance
 
95
  task = loop.create_task(coro)
96
  return loop.run_until_complete(task) if not task.done() else task
97
 
98
+ def run_initial_diagnostics(model_name: str, capabilities: str):
99
+ """Run initial diagnostics for a new model"""
100
+ try:
101
+ result = run_async(mcp_client.call_tool("run_initial_diagnostics", {
102
+ "model": model_name,
103
+ "model_capabilities": capabilities
104
+ }))
105
+ return result
106
+ except Exception as e:
107
+ print(f"Error running diagnostics: {e}")
108
+ return None
109
+
110
+ def check_model_drift(model_name: str):
111
+ """Check drift for existing model"""
112
+ try:
113
+ result = run_async(mcp_client.call_tool("check_drift", {
114
+ "model": model_name
115
+ }))
116
+ return result
117
+ except Exception as e:
118
+ print(f"Error checking drift: {e}")
119
+ return None
120
 
121
  # Initialize MCP connection on startup
122
  def initialize_mcp_connection():
 
319
  ]
320
 
321
  final_prompt = enhanced_prompt if choice == "Keep Enhanced" else original_prompt
322
+
323
+ try:
324
+ # Save the model first
325
+ status = save_model_to_db(selected_model_name, final_prompt)
326
+
327
+ # Run initial diagnostics
328
+ diagnostic_result = run_initial_diagnostics(
329
+ selected_model_name,
330
+ f"System Prompt: {final_prompt}\nCapabilities: General language model capabilities"
331
+ )
332
+
333
+ if diagnostic_result:
334
+ status = f"{status}\n{diagnostic_result[0].text if isinstance(diagnostic_result, list) else diagnostic_result}"
335
+ except Exception as e:
336
+ status = f"Error saving model: {e}"
337
 
338
  # Update dropdown choices
339
  updated_models = get_models_from_db()
 
356
  model_details = get_model_details(model_name)
357
  system_prompt = model_details.get("system_prompt", "")
358
 
359
+ # Simulate response (replace with actual LLM call) //Work here
360
  response = f"[{model_name}] Response to: {message}\n(Using system prompt: {system_prompt[:50]}...)"
361
  history.append([message, response])
362
  return history, ""
 
367
  return "Please select a model first"
368
 
369
  model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
 
 
 
370
 
371
+ # First try the drift calculation tool
372
+ try:
373
+ result = check_model_drift(model_name)
374
+ if result and isinstance(result, list):
375
+ return "\n".join(msg.text for msg in result)
376
+ except Exception as e:
377
+ print(f"Error calculating drift: {e}")
378
+
379
+ # Fallback to the simpler drift calculation if needed
380
+ result = calculate_drift_handler({"model_name": model_name})
381
+ return f"Drift Score: {result.get('drift_score', 0.0):.3f}\n{result.get('message', '')}"
382
 
383
  def refresh_drift_history(dropdown_value):
384
  """Refresh drift history for selected model"""
 
438
  # Create New Model Section (Initially Hidden)
439
  with gr.Group(visible=False) as create_new_section:
440
  gr.Markdown("#### Create New Model")
441
+ #work here to show options to select model
442
  new_model_name = gr.Dropdown(
443
  choices=[],
444
  label="Select Model Name",
 
593
  )
594
 
595
  if __name__ == "__main__":
596
+ demo.launch()