Sars6 committed on
Commit
dc9d63b
·
1 Parent(s): 9fa019c

Idk if this works. The LLM added its own stuff.

Files changed (2)
  1. ourllm.py +7 -0
  2. server.py +152 -72
ourllm.py ADDED
@@ -0,0 +1,7 @@
+
+async def genratequestionnaire(model, capabilities):
+    return None
+
+
+async def gradeanswers(old_answers, new_answers):
+    return None
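
The two stubs above are awaited by server.py, which expects genratequestionnaire to return a list of types.SamplingMessage and gradeanswers to return a message list whose first entry carries a numeric drift score in its text content. A minimal sketch of what they could look like; the fixed probe questions and the difflib-based similarity score are assumptions, not part of this commit:

import difflib

import mcp.types as types


async def genratequestionnaire(model, capabilities):
    # Assumption: a fixed probe set derived from the model name and capability description.
    texts = [
        f"As {model}, summarize your capabilities in one sentence.",
        f"Given these capabilities: {capabilities}. Name one task you cannot do.",
    ]
    return [
        types.SamplingMessage(role="user", content=types.TextContent(type="text", text=t))
        for t in texts
    ]


async def gradeanswers(old_answers, new_answers):
    # Assumption: drift score is 100 * (1 - mean string similarity), so 0 means identical answers.
    ratios = [
        difflib.SequenceMatcher(None, old, new).ratio()
        for old, new in zip(old_answers, new_answers)
    ]
    score = 100 * (1 - sum(ratios) / max(len(ratios), 1))
    return [
        types.SamplingMessage(
            role="assistant",
            content=types.TextContent(type="text", text=f"{score:.1f}"),
        )
    ]
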
server.py CHANGED
@@ -1,92 +1,172 @@
-# server.py
 import asyncio
+import json
+import os
+from typing import Any
+
+import mcp.types as types
+from mcp import CreateMessageResult
 from mcp.server import Server
 from mcp.server.stdio import stdio_server
-import mcp.types as types
 
-# Define diagnostic prompts statically for now
-PROMPTS = {
-    "drift-diagnostics": types.Prompt(
-        name="drift-diagnostics",
-        description="Run a diagnostic questionnaire to test LLM consistency.",
-        arguments=[],
-    )
-}
-
-# Setup server
-app = Server("mcp-drift-server", version="0.1.0")
-
-
-@app.list_prompts()
-async def list_prompts() -> list[types.Prompt]:
-    return list(PROMPTS.values())
-
-
-@app.get_prompt()
-async def get_prompt(name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
-    if name not in PROMPTS:
-        raise ValueError(f"Prompt not found: {name}")
-
-    # Static message for MVP – replace with dynamic question set later
-    return types.GetPromptResult(
-        messages=[
-            types.PromptMessage(
-                role="user",
-                content=types.TextContent(
-                    type="text",
-                    text="Answer the following: What's the capital of France?"
-                )
-            ),
-            types.PromptMessage(
-                role="user",
-                content=types.TextContent(
-                    type="text",
-                    text="Explain why the sky is blue."
-                )
-            ),
-        ]
-    )
+from ourllm import genratequestionnaire, gradeanswers
+
+DATA_DIR = "data"
+os.makedirs(DATA_DIR, exist_ok=True)
+
+app = Server("mcp-drift-server")
+
+registered_models = {}
+
+def get_all_models():
+    """Retrieve all registered models."""
+    return list(registered_models.keys())
+
+def search_models(query: str):
+    """Search registered models by name."""
+    return [model for model in registered_models if query.lower() in model.lower()]
+
+def get_model_details(model_name: str):
+    """Get details of a specific model."""
+    return registered_models.get(model_name, None)
+
+def save_model(model_name: str, model_details: dict):
+    """Save a new model or update an existing one."""
+    registered_models[model_name] = model_details
+    with open(os.path.join(DATA_DIR, "models.json"), "w") as f:
+        json.dump(registered_models, f, indent=2)
 
-from mcp.server import Server
-import mcp.types as types
 
-# Assuming 'app' is your MCP Server instance
 
-async def sample(app: Server, messages: list[types.SamplingMessage]):
-    result = await app.request_context.session.create_message(
-        messages=messages,
-        max_tokens=300,
-        temperature=0.7
-    )
-    return result
 
 @app.list_tools()
 async def list_tools() -> list[types.Tool]:
     return [
         types.Tool(
-            name="init_diagnostics",
-            description="Run diagnostic questionnaire on the connected LLM.",
-            inputSchema={"model_name": "Name of the LLM model"},
-        )
+            name="run_initial_diagnostics",
+            description="Generate and store baseline diagnostics for a connected LLM.",
+            inputSchema={"type": "object",
+                         "properties": {
+                             "model": {
+                                 "type": "string",
+                                 "description": "The name of the model to run diagnostics on"
+                             },
+                             "model_capabilities": {
+                                 "type": "string",
+                                 "description": "Full description of the model's capabilities, including any special features"
+                             }
+                         },
+
+                         "required": ["model", "model_capabilities"]},
+        ),
+        types.Tool(
+            name="check_drift",
+            description="Re-run diagnostics and compare to baseline for drift scoring.",
+            inputSchema={"type": "object",
+                         "properties": {
+                             "model": {
+                                 "type": "string",
+                                 "description": "The name of the model to run diagnostics on"
+                             },
+                         },
+
+                         "required": ["model"]},
+        ),
     ]
 
-@app.call_tool()
-async def call_tool(name: str, arguments: dict[str, str] | None = None) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
-    """
-    Initializes diagnostics by running the questionnaire on the connected LLM.
-    """
-    # You could fetch dynamic questions here if needed
-    questions = [
-        types.SamplingMessage(role="user", content=types.TextContent(type="text", text="What is the capital of France?")),
-        types.SamplingMessage(role="user", content=types.TextContent(type="text", text="Why is the sky blue?")),
-    ]
 
-    response = await sample(app, questions)
+# === Sampling Wrapper ===
+async def sample(messages: list[types.SamplingMessage], max_tokens=300) -> CreateMessageResult:
+    return await app.request_context.session.create_message(
+        messages=messages,
+        max_tokens=max_tokens,
+        temperature=0.7
+    )
+
 
-    # Return the assistant’s message(s) back to the caller
-    return [types.TextContent(type="text", text=str(response.content))]
+# === Baseline File Paths ===
+def get_baseline_path(model_name):
+    return os.path.join(DATA_DIR, f"{model_name}_baseline.json")
 
-# Main entrypoint
+
+def get_response_path(model_name):
+    return os.path.join(DATA_DIR, f"{model_name}_latest.json")
+
+
+# === Core Logic ===
+
+
+async def run_initial_diagnostics(arguments: dict[str, Any]) -> list[types.TextContent]:
+    if arguments and "model" in arguments:
+        model = arguments["model"]
+    else:
+        raise ValueError("A model name is required")
+
+    # 1. Ask the server's internal LLM to generate a questionnaire
+
+    questions = await genratequestionnaire(model, arguments["model_capabilities"])  # Server-side trusted LLM
+
+    # 2. Send questionnaire to target LLM (i.e., the client)
+    answers = await sample(questions)  # Client model's reply (a single CreateMessageResult)
+
+    # 3. Save Q/A pair
+    with open(get_baseline_path(model), "w") as f:
+        json.dump({
+            "questions": [m.content.text for m in questions],
+            "answers": [answers.content.text]
+        }, f, indent=2)
+
+    return [types.TextContent(type="text", text="Baseline stored for model: " + model)]
+
+
+
+async def check_drift(arguments: dict[str, str]) -> list[types.TextContent]:
+    if arguments and "model" in arguments:
+        model = arguments["model"]
+    else:
+        raise ValueError("A model name is required")
+
+    baseline_path = get_baseline_path(model)
+    if not os.path.exists(baseline_path):
+        return [types.TextContent(type="text", text="No baseline exists for model: " + model)]
+
+    with open(baseline_path) as f:
+        data = json.load(f)
+    questions = [types.SamplingMessage(role="user", content=types.TextContent(type="text", text=q)) for q in
+                 data["questions"]]
+    old_answers = data["answers"]
+
+    # 1. Ask the model again
+    new_answers_result = await sample(questions)
+    new_answers = [new_answers_result.content.text]
+
+    # 2. Grade old vs. new answers with the server-side LLM
+    grading_response = await gradeanswers(old_answers, new_answers)
+    drift_score = grading_response[0].content.text.strip()
+
+    # 3. Save the response
+    with open(get_response_path(model), "w") as f:
+        json.dump({
+            "new_answers": new_answers,
+            "drift_score": drift_score
+        }, f, indent=2)
+
+    # 4. Optionally alert if high drift
+    alert = "🚨 Significant drift detected!" if float(drift_score) > 50 else "✅ Drift within acceptable limits."
+
+    return [
+        types.TextContent(type="text", text=f"Drift score for {model}: {drift_score}"),
+        types.TextContent(type="text", text=alert)
+    ]
+@app.call_tool()
+async def call_tool(name: str, arguments: dict[str, Any] | None = None):
+    if name == "run_initial_diagnostics":
+        return await run_initial_diagnostics(arguments)
+    elif name == "check_drift":
+        return await check_drift(arguments)
+    else:
+        raise ValueError(f"Unknown tool: {name}")
+
+# === Entrypoint ===
 async def main():
     async with stdio_server() as streams:
         await app.run(streams[0], streams[1], app.create_initialization_options())
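
For reference, a rough client-side sketch of how these tools could be exercised over stdio with the mcp Python SDK. The launch command, model name, and capability string are hypothetical, and the server's sample() call additionally requires a client that supports sampling, which this bare-bones script does not wire up:

import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def demo():
    # Hypothetical launch command; point it at wherever server.py lives.
    params = StdioServerParameters(command="python", args=["server.py"])
    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()

            # Store a baseline for a (hypothetical) model name.
            baseline = await session.call_tool(
                "run_initial_diagnostics",
                arguments={"model": "example-model", "model_capabilities": "general chat"},
            )
            print(baseline)

            # Later, re-run the questionnaire and compare against the stored baseline.
            drift = await session.call_tool("check_drift", arguments={"model": "example-model"})
            print(drift)


if __name__ == "__main__":
    asyncio.run(demo())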