# MCP drift-detection server: exposes tools to baseline an LLM's behavior and
# score how far it has drifted from that baseline over time.
import asyncio
import os
from typing import Any, Dict, List

import mcp.types as types
from mcp import CreateMessageResult
from mcp.server import Server
from mcp.server.stdio import stdio_server

from database_module import (
    get_all_models_handler,
    get_baseline_diagnostics,
    init_db,
    register_model_with_capabilities,
    save_diagnostic_data,
    save_drift_score,
    search_models_handler,
)
from ourllm import genratequestionnaire, gradeanswers
# --- Module setup: storage directory, database, and the MCP server object ---

# Directory for on-disk artifacts; created eagerly so later writes never hit a
# missing path.
DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)

# Create database tables (no-op if they already exist) before any handler runs.
init_db()

# Single server instance that all tool handlers attach to.
app = Server("mcp-drift-server")
# === Tool Manifest ===
@app.list_tools()  # BUG FIX: without this registration the server advertises no tools
async def list_tools() -> List[types.Tool]:
    """Describe the tools this server exposes to MCP clients.

    Returns:
        JSON-schema manifests for the four drift-server tools:
        run_initial_diagnostics, check_drift, get_all_models, search_models.
    """
    return [
        types.Tool(
            name="run_initial_diagnostics",
            description="Generate and store baseline diagnostics for a connected LLM.",
            inputSchema={
                "type": "object",
                "properties": {
                    "model": {
                        "type": "string",
                        "description": "The name of the model to run diagnostics on",
                    },
                    "model_capabilities": {
                        "type": "string",
                        "description": "Full description of the model's capabilities, along with the system prompt.",
                    },
                },
                "required": ["model", "model_capabilities"],
            },
        ),
        types.Tool(
            name="check_drift",
            description="Re-run diagnostics and compare to baseline for drift scoring.",
            inputSchema={
                "type": "object",
                "properties": {
                    "model": {
                        "type": "string",
                        "description": "The name of the model to run diagnostics on",
                    },
                },
                "required": ["model"],
            },
        ),
        types.Tool(
            name="get_all_models",
            description="Retrieve all registered models from the database.",
            inputSchema={"type": "object", "properties": {}, "required": []},
        ),
        types.Tool(
            name="search_models",
            description="Search registered models by name.",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Substring to match model names against",
                    },
                },
                "required": ["query"],
            },
        ),
    ]
# === Sampling Wrapper ===
async def sample(messages: list[types.SamplingMessage], max_tokens=600) -> CreateMessageResult:
    """Request a completion from the connected client via MCP sampling.

    Args:
        messages: Conversation to forward to the client's model.
        max_tokens: Cap on the generated response length.

    Returns:
        The client's CreateMessageResult; if the round-trip raises, a stub
        result carrying the text "Error generating response" instead.
    """
    try:
        # Temperature is fixed; callers only vary the messages and token cap.
        return await app.request_context.session.create_message(
            messages=messages,
            max_tokens=max_tokens,
            temperature=0.7,
        )
    except Exception as exc:
        # Best-effort: a sampling failure must not take down the server, so
        # surface a placeholder response instead of propagating.
        print(f"Error in sampling: {exc}")
        fallback = CreateMessageResult(
            content=types.TextContent(type="text", text="Error generating response"),
            model="unknown",
            role="assistant",
        )
        return fallback
# === Core Logic ===
async def run_initial_diagnostics(arguments: Dict[str, Any]) -> List[types.TextContent]:
    """Generate a questionnaire for a model, collect its answers, and persist
    the pair as the drift baseline.

    Args:
        arguments: Tool arguments; requires "model" (name) and
            "model_capabilities" (free-text capability / system-prompt text).

    Returns:
        A single TextContent describing success or failure.
    """
    model = arguments["model"]
    capabilities = arguments["model_capabilities"]
    try:
        # 1. Generate questionnaire using ourllm (returns a list of question
        #    strings — presumably derived from the capability text; see ourllm).
        questions = genratequestionnaire(model, capabilities)

        # 2. Put each question to the connected model via MCP sampling.
        answers: List[str] = []
        for question_text in questions:
            try:
                sampling_msg = types.SamplingMessage(
                    role="user",
                    content=types.TextContent(type="text", text=question_text),
                )
                answer_result = await sample([sampling_msg])
                # Extract plain text; fall back to str() for non-text content
                # and a stub when there is no content at all.
                content = getattr(answer_result, "content", None)
                if content is None:
                    answers.append("No response generated")
                elif hasattr(content, "text"):
                    answers.append(content.text)
                else:
                    answers.append(str(content))
            except Exception as e:
                # One failed question must not abort the whole baseline run.
                print(f"Error getting answer for question '{question_text}': {e}")
                answers.append(f"Error: {str(e)}")

        # 3. Persist the capabilities plus Q/A pairs as the baseline snapshot.
        try:
            register_model_with_capabilities(model, capabilities)
            save_diagnostic_data(
                model_name=model,
                questions=questions,
                answers=answers,
                is_baseline=True,
            )
        except Exception as e:
            print(f"Error saving diagnostic data: {e}")
            return [types.TextContent(
                type="text",
                text=f"❌ Error saving baseline for {model}: {str(e)}",
            )]

        return [types.TextContent(
            type="text",
            text=f"✅ Baseline stored for model: {model} ({len(questions)} questions)",
        )]
    except Exception as e:
        print(f"Error in run_initial_diagnostics: {e}")
        return [types.TextContent(
            type="text",
            text=f"❌ Error running diagnostics for {model}: {str(e)}",
        )]
async def check_drift(arguments: Dict[str, Any]) -> List[types.TextContent]:
    """Re-ask the baseline questionnaire and grade the new answers against the
    stored baseline to produce a drift score.

    Args:
        arguments: Tool arguments; requires "model" (a model that already has
            a baseline stored by run_initial_diagnostics).

    Returns:
        Two TextContents (score line plus alert line) on success, or a single
        error message.
    """
    model = arguments["model"]
    try:
        # Fetch the stored baseline; without one there is nothing to compare.
        baseline = get_baseline_diagnostics(model)
        if not baseline:
            return [types.TextContent(type="text", text=f"❌ No baseline for model: {model}")]

        old_answers = baseline["answers"]
        questions = baseline["questions"]

        # Put the identical questions to the model again.
        new_answers: List[str] = []
        for question_text in questions:
            try:
                sampling_msg = types.SamplingMessage(
                    role="user",
                    content=types.TextContent(type="text", text=question_text),
                )
                answer_result = await sample([sampling_msg])
                # Extract plain text; mirror the baseline run's fallbacks.
                content = getattr(answer_result, "content", None)
                if content is None:
                    new_answers.append("No response generated")
                elif hasattr(content, "text"):
                    new_answers.append(content.text)
                else:
                    new_answers.append(str(content))
            except Exception as e:
                print(f"Error getting new answer for question '{question_text}': {e}")
                new_answers.append(f"Error: {str(e)}")

        # Grade old vs. new answers; gradeanswers returns the score as a
        # string (presumably a bare 0-100 number — TODO confirm with ourllm).
        drift_score_str = gradeanswers(old_answers, new_answers)

        # Persist the re-run; failures here are logged but do not block the
        # report, since the score itself was already computed.
        try:
            save_diagnostic_data(
                model_name=model,
                questions=questions,
                answers=new_answers,
                is_baseline=False,
            )
            save_drift_score(model, drift_score_str)
        except Exception as e:
            print(f"Error saving drift data: {e}")

        # Alert threshold: anything above 50% counts as significant drift.
        try:
            score_val = float(drift_score_str)
            alert = "🚨 Significant drift!" if score_val > 50 else "✅ Drift OK"
        except ValueError:
            alert = "⚠️ Drift score not numeric"

        return [
            types.TextContent(type="text", text=f"Drift score for {model}: {drift_score_str}%"),
            types.TextContent(type="text", text=alert),
        ]
    except Exception as e:
        print(f"Error in check_drift: {e}")
        return [types.TextContent(type="text", text=f"❌ Error checking drift for {model}: {str(e)}")]
# Database tool handlers
async def get_all_models_handler_async(_: Dict[str, Any]) -> List[types.TextContent]:
    """List every registered model as a bulleted text block.

    The argument is ignored (the tool declares no inputs) but kept so all tool
    handlers share the same signature.
    """
    try:
        models = get_all_models_handler({})
        if not models:
            return [types.TextContent(type="text", text="No models registered.")]
        model_list = "\n".join(
            f"• {m['name']} - {m.get('description', 'No description')}" for m in models
        )
        return [types.TextContent(
            type="text",
            text=f"Registered models:\n{model_list}",
        )]
    except Exception as e:
        print(f"Error getting all models: {e}")
        return [types.TextContent(type="text", text=f"❌ Error retrieving models: {str(e)}")]
async def search_models_handler_async(arguments: Dict[str, Any]) -> List[types.TextContent]:
    """List registered models whose names match the "query" argument.

    Args:
        arguments: Tool arguments; "query" is the substring to match
            (defaults to "" — match everything — if absent).
    """
    try:
        query = arguments.get("query", "")
        models = search_models_handler({"search_term": query})
        if not models:
            return [types.TextContent(
                type="text",
                text=f"No models found matching '{query}'.",
            )]
        model_list = "\n".join(
            f"• {m['name']} - {m.get('description', 'No description')}" for m in models
        )
        return [types.TextContent(
            type="text",
            text=f"Models matching '{query}':\n{model_list}",
        )]
    except Exception as e:
        print(f"Error searching models: {e}")
        return [types.TextContent(type="text", text=f"❌ Error searching models: {str(e)}")]
# === Dispatcher ===
@app.call_tool()  # BUG FIX: without this registration tool calls never reach the dispatcher
async def dispatch_tool(name: str, arguments: Dict[str, Any] | None = None):
    """Route a tool invocation to its handler.

    Args:
        name: Tool name from the client's call_tool request.
        arguments: Tool arguments (may be None for argument-less tools).

    Returns:
        The handler's list of TextContent, or an error message for unknown
        tools / handler failures.
    """
    # Dict dispatch keeps the mapping in one place and in sync with list_tools.
    handlers = {
        "run_initial_diagnostics": run_initial_diagnostics,
        "check_drift": check_drift,
        "get_all_models": get_all_models_handler_async,
        "search_models": search_models_handler_async,
    }
    try:
        handler = handlers.get(name)
        if handler is None:
            return [types.TextContent(type="text", text=f"❌ Unknown tool: {name}")]
        return await handler(arguments or {})
    except Exception as e:
        print(f"Error in dispatch_tool for {name}: {e}")
        return [types.TextContent(type="text", text=f"❌ Error executing {name}: {str(e)}")]
# === Entrypoint ===
async def main():
    """Serve MCP over stdio until the client disconnects."""
    try:
        async with stdio_server() as (reader, writer):
            await app.run(reader, writer, app.create_initialization_options())
    except Exception as exc:
        # Last-resort guard so startup/teardown failures are at least reported.
        print(f"Error running MCP server: {exc}")


if __name__ == "__main__":
    asyncio.run(main())