HarshBhati commited on
Commit
0733fd6
Β·
1 Parent(s): 7209b84

registered agents are showing but groq is not working fine

Browse files
app.py CHANGED
@@ -2,118 +2,142 @@ import os
2
  import gradio as gr
3
  import asyncio
4
  from typing import Optional, List, Dict
5
- from mcp_agent.core.fastagent import FastAgent
6
-
7
- from database_module.db import SessionLocal
8
- from database_module.models import ModelEntry
9
- from langchain.chat_models import init_chat_model
10
- # Modify imports section to include all required tools
11
- from database_module import (
12
- init_db,
13
- get_all_models_handler,
14
- search_models_handler,
15
- # save_model_handler,
16
- # get_model_details_handler,
17
- # calculate_drift_handler,
18
- # get_drift_history_handler
19
- )
 
 
 
 
 
 
 
 
 
 
 
20
  import json
21
  from datetime import datetime
22
  import plotly.graph_objects as go
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # --- Initialize database and MCP tool registration ---
25
- # Create tables and register MCP handlers
26
- init_db()
27
 
28
- # Fast Agent client initialization - This is the "scapegoat" client whose drift we're detecting
29
- fast = FastAgent("Scapegoat Client")
30
 
31
- @fast.agent(
32
- name="scapegoat",
33
- instruction="You are a test client whose drift will be detected and measured over time",
34
- servers=["drift-server"]
35
- )
36
- async def setup_agent():
37
- # This function defines the scapegoat agent that will be monitored for drift
38
- pass
39
 
40
- # Global scapegoat client instance to be monitored for drift
 
 
 
 
 
 
 
 
41
  scapegoat_client = None
 
 
42
 
43
- # Initialize the scapegoat client that will be tested for drift
44
- async def initialize_scapegoat_client():
45
- global scapegoat_client
46
- print("Initializing scapegoat client for drift monitoring...")
47
- async with fast.run() as agent:
48
- scapegoat_client = agent
49
- return agent
50
-
51
- # Helper to run async functions with FastAgent
52
- def run_async(coro):
53
- try:
54
- loop = asyncio.get_running_loop()
55
- except RuntimeError:
56
- loop = asyncio.new_event_loop()
57
- asyncio.set_event_loop(loop)
58
- return loop.run_until_complete(coro)
59
- else:
60
- # return result if coroutine returns value, else schedule
61
- task = loop.create_task(coro)
62
- return loop.run_until_complete(task) if not task.done() else task
63
-
64
- def run_initial_diagnostics(model_name: str, capabilities: str):
65
- """Run initial diagnostics for a new model"""
66
- try:
67
- # Use FastAgent's send method with a formatted message to call the tool
68
- message = f"""Please call the run_initial_diagnostics tool with the following parameters:
69
- model: {model_name}
70
- model_capabilities: {capabilities}
71
-
72
- This tool will generate and store baseline diagnostics for the model.
73
- """
74
-
75
- result = run_async(scapegoat_client(message))
76
- return result
77
- except Exception as e:
78
- print(f"Error running diagnostics: {e}")
79
- return None
80
 
81
- def check_model_drift(model_name: str):
82
- """Check drift for existing model"""
 
 
 
 
 
83
  try:
84
- # Use FastAgent's send method with a formatted message to call the tool
85
- message = f"""Please call the check_drift tool with the following parameters:
86
- model: {model_name}
87
-
88
- This tool will re-run diagnostics and compare to baseline for drift scoring.
89
- """
90
-
91
- result = run_async(scapegoat_client(message))
92
- return result
93
  except Exception as e:
94
- print(f"Error checking drift: {e}")
95
- return None
 
 
 
 
 
 
96
 
97
- # Initialize MCP connection on startup
98
- def initialize_mcp_connection():
99
  try:
100
- run_async(initialize_scapegoat_client())
101
- print("Successfully connected scapegoat client to MCP server")
102
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  except Exception as e:
104
- print(f"Failed to connect scapegoat client to MCP server: {e}")
105
  return False
106
 
107
 
108
- # Wrapper functions remain unchanged but now call real DB-backed MCP tools
109
  def get_models_from_db():
110
- """Get all models from database using direct function call"""
 
 
 
111
  try:
112
- # Direct function call to database_module instead of using MCP
113
  result = get_all_models_handler({})
114
-
115
  if result:
116
- # Format the result to match the expected structure
117
  return [
118
  {
119
  "name": model["name"],
@@ -124,22 +148,77 @@ def get_models_from_db():
124
  ]
125
  return []
126
  except Exception as e:
127
- print(f"Error getting models: {e}")
128
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
 
 
130
 
131
- def get_available_model_names():
132
- return [m["name"] for m in get_models_from_db()]
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  def search_models_in_db(search_term: str):
136
- """Search models in database using direct function call"""
 
 
 
137
  try:
138
- # Direct function call to database_module instead of using MCP
139
  result = search_models_handler({"search_term": search_term})
140
-
141
  if result:
142
- # Format the result to match the expected structure
143
  return [
144
  {
145
  "name": model["name"],
@@ -148,546 +227,496 @@ def search_models_in_db(search_term: str):
148
  }
149
  for model in result
150
  ]
151
- # If no results, return empty list
152
  return []
153
  except Exception as e:
154
- print(f"Error searching models: {e}")
155
- # Fallback to filtering from all models if there's an error
156
  return [m for m in get_models_from_db() if search_term.lower() in m["name"].lower()]
157
 
 
158
  def format_dropdown_items(models):
159
- """Format dropdown items to show model name, creation date, and description preview"""
 
 
 
160
  formatted_items = []
161
  model_mapping = {}
162
-
163
  for model in models:
164
  desc_preview = model["description"][:40] + ("..." if len(model["description"]) > 40 else "")
165
  item_label = f"{model['name']} (Created: {model['created']}) - {desc_preview}"
166
  formatted_items.append(item_label)
167
  model_mapping[item_label] = model["name"]
168
-
169
  return formatted_items, model_mapping
170
 
 
171
  def extract_model_name_from_dropdown(dropdown_value, model_mapping):
172
- """Extract actual model name from formatted dropdown value"""
 
 
173
  return model_mapping.get(dropdown_value, dropdown_value.split(" (")[0] if dropdown_value else "")
174
 
 
175
  def get_model_details(model_name: str):
176
- """Get model details from database via direct DB access (fallback)"""
177
  try:
178
- with SessionLocal() as session:
179
- model_entry = session.query(ModelEntry).filter_by(name=model_name).first()
180
- if model_entry:
181
- return {
182
- "name": model_entry.name,
183
- "description": model_entry.description or "",
184
- "system_prompt": model_entry.capabilities.split("\nSystem Prompt: ")[1] if "\nSystem Prompt: " in model_entry.capabilities else "",
185
- "created": model_entry.created.strftime("%Y-%m-%d %H:%M:%S") if model_entry.created else ""
186
- }
187
- return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
 
 
188
  except Exception as e:
189
- print(f"Error getting model details: {e}")
190
- return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
191
-
192
- def enhance_prompt_via_mcp(prompt: str):
193
- """Enhance prompt locally since enhance_prompt tool is not available in server.py"""
194
- # Provide a basic prompt enhancement functionality since server doesn't have it
195
- enhanced_prompts = {
196
- "helpful": f"{prompt}\n\nPlease be thorough, helpful, and provide detailed responses.",
197
- "concise": f"{prompt}\n\nPlease provide concise, direct answers.",
198
- "technical": f"{prompt}\n\nPlease provide technically accurate and comprehensive responses.",
199
- }
200
 
201
- if "helpful" in prompt.lower():
202
- return enhanced_prompts["helpful"]
203
- elif "concise" in prompt.lower() or "brief" in prompt.lower():
204
- return enhanced_prompts["concise"]
205
- elif "technical" in prompt.lower() or "detailed" in prompt.lower():
206
- return enhanced_prompts["technical"]
207
- else:
208
- return f"{prompt}\n\nAdditional context: Be specific, helpful, and provide detailed responses while maintaining a professional tone."
209
-
210
- def save_model_to_db(model_name: str, system_prompt: str):
211
- """Save model to database directly since save_model tool is not available in server.py"""
212
- try:
213
- # Check if model already exists
214
- with SessionLocal() as session:
215
- existing = session.query(ModelEntry).filter_by(name=model_name).first()
216
- if existing:
217
- # Update capabilities to include the new system prompt
218
- capabilities = existing.capabilities
219
- if "\nSystem Prompt: " in capabilities:
220
- # Replace the system prompt part
221
- parts = capabilities.split("\nSystem Prompt: ")
222
- capabilities = f"{parts[0]}\nSystem Prompt: {system_prompt}"
223
- else:
224
- # Add system prompt if not present
225
- capabilities = f"{capabilities}\nSystem Prompt: {system_prompt}"
226
 
227
- existing.capabilities = capabilities
228
- existing.updated = datetime.now()
229
- session.commit()
230
- return {"message": f"Updated existing model: {model_name}"}
231
- else:
232
- # Should not happen as models are registered with capabilities before calling this function
233
- return {"message": f"Model {model_name} not found. Please register it first."}
234
- except Exception as e:
235
- print(f"Error saving model: {e}")
236
- return {"message": f"Error saving model: {e}"}
237
 
238
- def get_drift_history_from_db(model_name: str):
239
- """Get drift history from database directly without any fallbacks"""
240
  try:
241
- from database_module.models import DriftEntry
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
- with SessionLocal() as session:
244
- # Query the drift_history table for this model
245
- drift_entries = session.query(DriftEntry).filter(
246
- DriftEntry.model_name == model_name
247
- ).order_by(DriftEntry.date.desc()).all()
248
 
249
- # If no entries found, return empty list
250
- if not drift_entries:
251
- return []
 
252
 
253
- # Convert to the expected format
254
- history = []
255
- for entry in drift_entries:
256
- history.append({
257
- "date": entry.date.strftime("%Y-%m-%d"),
258
- "drift_score": float(entry.drift_score),
259
- "model": entry.model_name
260
- })
261
-
262
- return history
263
  except Exception as e:
264
- print(f"Error getting drift history from database: {e}")
265
- return [] # Return empty list on error, no fallbacks
266
- def create_drift_chart(drift_history):
267
- """Create drift chart using plotly"""
268
- if not drift_history:
269
- return gr.update(value=None)
270
-
271
- dates = [entry["date"] for entry in drift_history]
272
- scores = [entry["drift_score"] for entry in drift_history]
273
-
274
- fig = go.Figure()
275
- fig.add_trace(go.Scatter(
276
- x=dates,
277
- y=scores,
278
- mode='lines+markers',
279
- name='Drift Score',
280
- line=dict(color='#ff6b6b', width=3),
281
- marker=dict(size=8, color='#ff6b6b')
282
- ))
283
-
284
- fig.update_layout(
285
- title='Model Drift Over Time',
286
- xaxis_title='Date',
287
- yaxis_title='Drift Score',
288
- template='plotly_white',
289
- height=400,
290
- showlegend=True
291
- )
292
-
293
- return fig
294
 
295
- # Global variable to store model mapping
296
- current_model_mapping = {}
297
 
298
- # Gradio interface functions
299
- def update_model_dropdown(search_term):
300
- """Update dropdown choices based on search term"""
301
- global current_model_mapping
302
-
303
- if search_term.strip():
304
- models = search_models_in_db(search_term.strip())
305
- else:
306
- models = get_models_from_db()
307
-
308
- formatted_items, model_mapping = format_dropdown_items(models)
309
- current_model_mapping = model_mapping
310
-
311
- return gr.update(choices=formatted_items, value=formatted_items[0] if formatted_items else None)
312
 
313
- def on_model_select(dropdown_value):
314
- """Handle model selection"""
315
- if not dropdown_value:
316
- return "", ""
317
-
318
- actual_model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
319
- return actual_model_name, actual_model_name
320
 
321
  def cancel_create_new():
322
  """Cancel create new model"""
323
  return [
324
  gr.update(visible=False), # create_new_section
325
- None, # new_model_name (dropdown)
326
  "", # new_system_prompt
327
  gr.update(visible=False), # enhanced_prompt_display
328
  gr.update(visible=False), # prompt_choice
329
  gr.update(visible=False), # save_model_button
330
- gr.update(visible=False) # save_status
331
  ]
332
 
 
333
  def enhance_prompt(original_prompt):
334
- """Enhance prompt and show options"""
335
- if not original_prompt.strip():
336
  return [
337
  gr.update(visible=False),
338
  gr.update(visible=False),
339
  gr.update(visible=False)
340
  ]
341
-
342
- enhanced = enhance_prompt_via_mcp(original_prompt.strip())
343
  return [
344
  gr.update(value=enhanced, visible=True),
345
  gr.update(visible=True),
346
  gr.update(visible=True)
347
  ]
348
 
349
- def register_model_with_capabilities(model_name: str, capabilities: str):
350
- """Register a new model with its capabilities in the database"""
351
- try:
352
- with SessionLocal() as session:
353
- model_entry = ModelEntry(
354
- name=model_name,
355
- capabilities=capabilities,
356
- created=datetime.now()
357
- )
358
- session.add(model_entry)
359
- session.commit()
360
- return True
361
- except Exception as e:
362
- print(f"Error registering model: {e}")
363
- return False
364
 
 
 
 
365
 
366
- def save_new_model(selected_model_name, selected_llm, original_prompt, enhanced_prompt, choice):
367
- """Save new model to database"""
368
- if not selected_model_name or not original_prompt.strip() or not selected_llm:
369
  return [
370
- "Please provide model name, LLM selection, and system prompt",
371
  gr.update(visible=True),
372
  gr.update()
373
  ]
374
-
375
- final_prompt = enhanced_prompt if choice == "Keep Enhanced" else original_prompt
376
-
377
  try:
378
- # Save the model with LLM capabilities
379
  capabilities = f"{selected_llm}\nSystem Prompt: {final_prompt}"
380
- register_model_with_capabilities(selected_model_name, capabilities)
381
-
382
- status = save_model_to_db(selected_model_name, final_prompt)
383
-
384
- # Run initial diagnostics
385
- diagnostic_result = run_initial_diagnostics(
386
- selected_model_name,
387
- capabilities
388
- )
389
-
390
- if diagnostic_result:
391
- status = f"{status}\n{diagnostic_result[0].text if isinstance(diagnostic_result, list) else diagnostic_result}"
 
 
392
  except Exception as e:
393
- status = f"Error saving model: {e}"
394
-
395
- # Update dropdown choices
396
- updated_models = get_models_from_db()
397
- formatted_items, model_mapping = format_dropdown_items(updated_models)
398
- global current_model_mapping
399
- current_model_mapping = model_mapping
400
-
401
  return [
402
  status,
403
  gr.update(visible=True),
404
- gr.update(choices=formatted_items)
405
  ]
406
 
 
407
  def chatbot_response(message, history, dropdown_value):
408
- """Generate chatbot response using selected model"""
409
- if not message.strip() or not dropdown_value:
410
  return history, ""
411
-
412
- model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
413
- model_details = get_model_details(model_name)
414
- system_prompt = model_details.get("system_prompt", "")
415
-
416
  try:
417
- # Initialize LLM based on model details
418
- # Get model configuration from database
419
- with SessionLocal() as session:
420
- model_entry = session.query(ModelEntry).filter_by(name=model_name).first()
421
- if not model_entry:
422
- return history + [[message, "Error: Model not found"]], ""
423
-
424
- llm_name = model_entry.capabilities.split("\n")[0] if model_entry.capabilities else "groq-llama-3.1-8b-instant"
425
-
426
- # Initialize the LLM using langchain
427
- llm = init_chat_model(
428
- llm_name,
429
- model_provider='groq' if llm_name.startswith('groq') else 'google'
430
- )
431
-
432
- # Format the conversation with system prompt
433
- formatted_prompt = f"System: {system_prompt}\nUser: {message}"
434
-
435
- # Get response from LLM
436
- response = llm.invoke(formatted_prompt)
437
- response_text = response.content
438
-
439
- history.append([message, response_text])
440
  return history, ""
441
-
442
  except Exception as e:
443
- error_message = f"Error generating response: {str(e)}"
444
- history.append([message, error_message])
445
  return history, ""
446
 
 
447
  def calculate_drift(dropdown_value):
448
- """Calculate drift for selected model"""
449
  if not dropdown_value:
450
- return "Please select a model first"
451
-
452
- model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
453
-
454
- # First try the drift calculation tool
455
  try:
456
- result = check_model_drift(model_name)
457
- if result and isinstance(result, list):
458
- return "\n".join(msg.text for msg in result)
 
 
 
 
 
459
  except Exception as e:
460
- print(f"Error calculating drift: {e}")
461
- return f"Error calculating drift from server side: {e}"
462
-
463
- # Fallback to the simpler drift calculation if needed
464
- # result = calculate_drift_handler({"model_name": model_name})
465
- return f"Drift Score: {result.get('drift_score', 0.0):.3f}\n{result.get('message', '')}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
  def refresh_drift_history(dropdown_value):
468
- """Refresh drift history for selected model"""
469
  if not dropdown_value:
470
  return [], gr.update(value=None)
471
-
472
- model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
473
- history = get_drift_history_from_db(model_name)
474
- chart = create_drift_chart(history)
475
-
476
- return history, chart
 
 
 
 
 
 
 
 
 
 
 
477
 
478
  def initialize_interface():
479
- """Initialize interface with MCP connection and default data"""
480
- # Connect to MCP server
481
- mcp_connected = initialize_mcp_connection()
482
-
483
- # Get initial model data
484
- models = get_models_from_db()
485
- formatted_items, model_mapping = format_dropdown_items(models)
486
  global current_model_mapping
487
- current_model_mapping = model_mapping
488
-
489
- return (
490
- formatted_items, # model_dropdown choices
491
- formatted_items[0] if formatted_items else None, # model_dropdown value
492
- "", # new_model_name - should be empty string, not choices
493
- formatted_items[0].split(" (")[0] if formatted_items else "", # selected_model_display
494
- formatted_items[0].split(" (")[0] if formatted_items else "" # drift_model_display
495
- )
496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
 
498
- # Create Gradio interface
499
- with gr.Blocks(title="AI Model Management & Interaction Platform") as demo:
500
- gr.Markdown("# AI Model Management & Interaction Platform")
501
-
502
- with gr.Row():
503
- # Left Column - Model Selection
504
- with gr.Column(scale=1):
505
- gr.Markdown("### Model Selection")
506
-
507
- model_dropdown = gr.Dropdown(
508
- choices=[], #work here Here show the already created models (fetched from database using mcp functions defined above)
509
- label="Select Model",
510
- interactive=True
511
- )
512
-
513
- search_box = gr.Textbox(
514
- placeholder="Search by model name or description...",
515
- label="Search Models"
516
- )
517
-
518
- create_new_button = gr.Button("Create New Model", variant="secondary")
519
-
520
- # Create New Model Section (Initially Hidden)
521
- with gr.Group(visible=False) as create_new_section:
522
- gr.Markdown("#### Create New Model")
523
- new_model_name = gr.Textbox(
524
- label="Model name",
525
- placeholder="Model name"
526
- )
527
- new_llm = gr.Dropdown(
528
- choices=[
529
- "gemini-1.0-pro",
530
- "gemini-1.5-pro",
531
- "groq-llama-3.1-8b-instant",
532
- "groq-mixtral-8x7b",
533
- "groq-gpt4"
534
- ], #work here to show options to select llms(available to use) like gemini-1.5-pro, etc google models, groq models (atleast 5 in total)
535
- label="Select LLM Name",
536
- interactive=True
537
- )
538
- new_system_prompt = gr.Textbox(
539
- label="System Prompt",
540
- placeholder="Enter system prompt",
541
- lines=3
542
- )
543
-
544
- with gr.Row():
545
- enhance_button = gr.Button("Enhance Prompt", variant="primary")
546
- cancel_button = gr.Button("Cancel", variant="secondary")
547
-
548
- enhanced_prompt_display = gr.Textbox(
549
- label="Enhanced Prompt",
550
- interactive=False,
551
- lines=4,
552
- visible=False
553
  )
554
-
555
- prompt_choice = gr.Radio(
556
- choices=["Keep Enhanced", "Keep Original"],
557
- label="Choose Prompt to Use",
558
- visible=False
559
  )
560
-
561
- save_model_button = gr.Button("Save Model", variant="primary", visible=False)
562
- save_status = gr.Textbox(label="Status", interactive=False, visible=False)
563
-
564
- # Right Column - Model Operations
565
- with gr.Column(scale=2):
566
- gr.Markdown("### Model Operations")
567
-
568
- with gr.Tabs():
569
- # Chatbot Tab
570
- with gr.TabItem("Chatbot"):
571
- selected_model_display = gr.Textbox(
572
- label="Currently Selected Model",
573
- interactive=False
574
  )
575
-
576
- chatbot_interface = gr.Chatbot(height=400)
577
-
578
- with gr.Row():
579
- msg_input = gr.Textbox(
580
- placeholder="Enter your message...",
581
- label="Message",
582
- scale=4
583
- )
584
- send_button = gr.Button("Send", variant="primary", scale=1)
585
-
586
- clear_chat = gr.Button("Clear Chat", variant="secondary")
587
-
588
- # Drift Analysis Tab
589
- with gr.TabItem("Drift Analysis"):
590
- drift_model_display = gr.Textbox(
591
- label="Model for Drift Analysis",
592
- interactive=False
593
  )
594
-
 
 
 
 
 
595
  with gr.Row():
596
- calculate_drift_button = gr.Button("Calculate New Drift", variant="primary")
597
- refresh_history_button = gr.Button("Refresh History", variant="secondary")
598
-
599
- drift_result = gr.Textbox(label="Latest Drift Calculation", interactive=False)
600
-
601
- gr.Markdown("#### Drift History")
602
- drift_history_display = gr.JSON(label="Drift History Data")
603
-
604
- gr.Markdown("#### Drift Chart")
605
- drift_chart = gr.Plot(label="Drift Over Time")
606
-
607
- # Event Handlers
608
-
609
- # Search functionality - Dynamic update
610
- search_box.change(
611
- update_model_dropdown,
612
- inputs=[search_box],
613
- outputs=[model_dropdown]
614
- )
615
-
616
- # Model selection updates
617
- model_dropdown.change(
618
- on_model_select,
619
- inputs=[model_dropdown],
620
- outputs=[selected_model_display, drift_model_display]
621
- )
622
-
623
- # Create new model functionality
624
- def show_create_new():
625
- """Show the create new model section"""
626
- return gr.update(visible=True), gr.update(value="")
627
-
628
- create_new_button.click(
629
- show_create_new,
630
- outputs=[create_new_section, new_model_name]
631
- )
632
-
633
- cancel_button.click(cancel_create_new, outputs=[
634
- create_new_section, new_model_name, new_system_prompt,
635
- enhanced_prompt_display, prompt_choice, save_model_button, save_status
636
- ])
637
-
638
- # Enhance prompt
639
- enhance_button.click(
640
- enhance_prompt,
641
- inputs=[new_system_prompt],
642
- outputs=[enhanced_prompt_display, prompt_choice, save_model_button]
643
- )
644
-
645
- # Save model
646
- save_model_button.click(
647
- save_new_model,
648
- inputs=[new_model_name, new_llm, new_system_prompt, enhanced_prompt_display, prompt_choice],
649
- outputs=[save_status, save_status, model_dropdown]
650
- )
651
-
652
- # Chatbot functionality
653
- send_button.click(
654
- chatbot_response,
655
- inputs=[msg_input, chatbot_interface, model_dropdown],
656
- outputs=[chatbot_interface, msg_input]
657
- )
658
-
659
- msg_input.submit(
660
- chatbot_response,
661
- inputs=[msg_input, chatbot_interface, model_dropdown],
662
- outputs=[chatbot_interface, msg_input]
663
- )
664
-
665
- clear_chat.click(lambda: [], outputs=[chatbot_interface])
666
-
667
- # Drift analysis functionality
668
- calculate_drift_button.click(
669
- calculate_drift,
670
- inputs=[model_dropdown],
671
- outputs=[drift_result]
672
- )
673
-
674
- refresh_history_button.click(
675
- refresh_drift_history,
676
- inputs=[model_dropdown],
677
- outputs=[drift_history_display, drift_chart]
678
- )
679
-
680
- # Initialize interface on load
681
- demo.load(
682
- initialize_interface,
683
- outputs=[
684
- model_dropdown, # dropdown choices
685
- model_dropdown, # dropdown value
686
- new_model_name, # textbox for new model name (empty string)
687
- selected_model_display,
688
- drift_model_display
689
- ]
690
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691
 
692
  if __name__ == "__main__":
693
- demo.launch()
 
 
 
 
 
 
 
2
  import gradio as gr
3
  import asyncio
4
  from typing import Optional, List, Dict
5
+ import subprocess
6
+ import time
7
+ import signal
8
+ import sys
9
+ import threading
10
+ import concurrent.futures
11
+ # Add these imports at the top of your Gradio file
12
+ from ourllm import llm # Import the actual LLM instance
13
+ from dotenv import load_dotenv
14
+ # Add error handling for imports
15
+ try:
16
+ from database_module.db import SessionLocal
17
+ from database_module.models import ModelEntry
18
+ from langchain.chat_models import init_chat_model
19
+ from database_module import (
20
+ init_db,
21
+ get_all_models_handler,
22
+ search_models_handler,
23
+ )
24
+
25
+ DATABASE_AVAILABLE = True
26
+ except ImportError as e:
27
+ print(f"⚠️ Database modules not available: {e}")
28
+ print("⚠️ Running in demo mode without database functionality")
29
+ DATABASE_AVAILABLE = False
30
+
31
  import json
32
  from datetime import datetime
33
  import plotly.graph_objects as go
34
+ try:
35
+ from ourllm import llm
36
+ print("βœ… Successfully imported LLM from ourllm.py")
37
+ LLM_AVAILABLE = True
38
+ except ImportError as e:
39
+ print(f"❌ Failed to import LLM: {e}")
40
+ LLM_AVAILABLE = False
41
+
42
+ # Mock database functions for when database is not available
43
+ def mock_init_db():
44
+ print("πŸ“ Mock database initialized")
45
+ return True
46
+
47
+
48
+ def mock_get_all_models():
49
+ return [
50
+ {"name": "demo-model-1", "description": "Demo model for testing", "created": "2024-01-01"},
51
+ {"name": "demo-model-2", "description": "Another demo model", "created": "2024-01-02"}
52
+ ]
53
+
54
+
55
+ def mock_search_models(search_term):
56
+ all_models = mock_get_all_models()
57
+ return [m for m in all_models if search_term.lower() in m["name"].lower()]
58
+
59
 
60
+ def mock_register_model(model_name, capabilities):
61
+ print(f"πŸ“ Mock: Registered model {model_name}")
62
+ return True
63
 
 
 
64
 
65
+ # Use mock functions if database is not available
66
+ if not DATABASE_AVAILABLE:
67
+ init_db = mock_init_db
68
+ get_all_models_handler = lambda x: mock_get_all_models()
69
+ search_models_handler = lambda x: mock_search_models(x.get("search_term", ""))
 
 
 
70
 
71
+ # Initialize database (or mock)
72
+ try:
73
+ init_db()
74
+ print("βœ… Database initialization successful")
75
+ except Exception as e:
76
+ print(f"⚠️ Database initialization failed: {e}")
77
+ DATABASE_AVAILABLE = False
78
+
79
+ # Global variables
80
  scapegoat_client = None
81
+ server_manager = None
82
+ current_model_mapping = {}
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ # --- Simplified Database Functions ---
86
+ def ensure_database_setup():
87
+ """Ensure database is properly set up"""
88
+ if not DATABASE_AVAILABLE:
89
+ print("βœ… Running in demo mode - no database required")
90
+ return True
91
+
92
  try:
93
+ # Test database connection
94
+ with SessionLocal() as session:
95
+ session.execute("SELECT 1")
96
+ session.commit()
97
+ print("βœ… Database connection successful")
98
+ return True
 
 
 
99
  except Exception as e:
100
+ print(f"❌ Database setup failed: {e}")
101
+ return False
102
+
103
+
104
+ def register_model_with_capabilities(model_name: str, capabilities: str):
105
+ """Register a new model with its capabilities"""
106
+ if not DATABASE_AVAILABLE:
107
+ return mock_register_model(model_name, capabilities)
108
 
 
 
109
  try:
110
+ with SessionLocal() as session:
111
+ existing = session.query(ModelEntry).filter_by(name=model_name).first()
112
+ if existing:
113
+ existing.capabilities = capabilities
114
+ existing.updated = datetime.now()
115
+ session.commit()
116
+ print(f"βœ… Updated existing model: {model_name}")
117
+ else:
118
+ model_entry = ModelEntry(
119
+ name=model_name,
120
+ capabilities=capabilities,
121
+ created=datetime.now()
122
+ )
123
+ session.add(model_entry)
124
+ session.commit()
125
+ print(f"βœ… Registered new model: {model_name}")
126
+ return True
127
  except Exception as e:
128
+ print(f"❌ Error registering model: {e}")
129
  return False
130
 
131
 
132
+ # --- Simplified Model Management Functions ---
133
  def get_models_from_db():
134
+ """Get all models from database"""
135
+ if not DATABASE_AVAILABLE:
136
+ return mock_get_all_models()
137
+
138
  try:
 
139
  result = get_all_models_handler({})
 
140
  if result:
 
141
  return [
142
  {
143
  "name": model["name"],
 
148
  ]
149
  return []
150
  except Exception as e:
151
+ print(f"❌ Error getting models: {e}")
152
+ return mock_get_all_models()
153
+
154
+
155
+ load_dotenv()
156
+
157
+
158
+ # Replace your current chatbot_response function with this:
159
+ def chatbot_response(message, history, dropdown_value):
160
+ """Generate chatbot response using actual LLM with debug info"""
161
+ print(f"πŸ” DEBUG: Function called with message: '{message}'")
162
+ print(f"πŸ” DEBUG: LLM_AVAILABLE: {LLM_AVAILABLE}")
163
+ print(f"πŸ” DEBUG: GROQ_API_KEY exists: {'GROQ_API_KEY' in os.environ}")
164
+
165
+ if not message or not message.strip() or not dropdown_value:
166
+ print("πŸ” DEBUG: Empty message or dropdown")
167
+ return history, ""
168
+
169
+ try:
170
+ model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
171
+ print(f"πŸ” DEBUG: Model name: {model_name}")
172
+
173
+ # Initialize history if needed
174
+ if history is None:
175
+ history = []
176
+
177
+ # Check if LLM is available and API key is set
178
+ if not LLM_AVAILABLE:
179
+ response_text = "❌ LLM not available - check ourllm.py import"
180
+ elif not os.getenv("GROQ_API_KEY"):
181
+ response_text = "❌ GROQ_API_KEY not found in environment variables"
182
+ else:
183
+ try:
184
+ print("πŸ” DEBUG: Attempting to call LLM...")
185
+
186
+ # Simple direct call to LLM
187
+ response = llm.invoke(message)
188
+ response_text = str(response.content).strip()
189
+
190
+ print(f"πŸ” DEBUG: LLM response received: {response_text[:100]}...")
191
 
192
+ if not response_text:
193
+ response_text = "❌ LLM returned empty response"
194
 
195
+ except Exception as e:
196
+ print(f"πŸ” DEBUG: LLM call failed: {e}")
197
+ response_text = f"❌ LLM Error: {str(e)}"
198
 
199
+ # Add to history
200
+ history.append({"role": "user", "content": message})
201
+ history.append({"role": "assistant", "content": response_text})
202
+
203
+ print(f"πŸ” DEBUG: Final response: {response_text}")
204
+ return history, ""
205
+
206
+ except Exception as e:
207
+ print(f"πŸ” DEBUG: General error in chatbot_response: {e}")
208
+ if history is None:
209
+ history = []
210
+ history.append({"role": "user", "content": message})
211
+ history.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
212
+ return history, ""
213
 
214
  def search_models_in_db(search_term: str):
215
+ """Search models in database"""
216
+ if not DATABASE_AVAILABLE:
217
+ return mock_search_models(search_term)
218
+
219
  try:
 
220
  result = search_models_handler({"search_term": search_term})
 
221
  if result:
 
222
  return [
223
  {
224
  "name": model["name"],
 
227
  }
228
  for model in result
229
  ]
 
230
  return []
231
  except Exception as e:
232
+ print(f"❌ Error searching models: {e}")
 
233
  return [m for m in get_models_from_db() if search_term.lower() in m["name"].lower()]
234
 
235
+
236
def format_dropdown_items(models):
    """Build display labels for the model dropdown.

    Returns a ``(labels, mapping)`` pair: ``labels`` is the list of
    human-readable dropdown entries and ``mapping`` maps each label back
    to the underlying model name.
    """
    if not models:
        return [], {}

    labels = []
    label_to_name = {}
    for entry in models:
        description = entry["description"]
        # Truncate long descriptions so labels stay readable.
        preview = description[:40] + ("..." if len(description) > 40 else "")
        label = f"{entry['name']} (Created: {entry['created']}) - {preview}"
        labels.append(label)
        label_to_name[label] = entry["name"]
    return labels, label_to_name
251
 
252
+
253
def extract_model_name_from_dropdown(dropdown_value, model_mapping):
    """Resolve a dropdown label back to the raw model name.

    Falls back to stripping the " (Created: ...)" suffix when the label is
    not present in ``model_mapping``.
    """
    if not dropdown_value:
        return ""
    try:
        return model_mapping[dropdown_value]
    except KeyError:
        return dropdown_value.split(" (")[0]
258
 
259
+
260
def get_model_details(model_name: str):
    """Look up a model's stored details, falling back to demo defaults.

    Returns a dict with at least ``name``, ``system_prompt`` and
    ``description``; successful database hits also include ``created``.
    """
    default_prompt = "You are a helpful AI assistant."
    fallback = {"name": model_name, "system_prompt": default_prompt, "description": "Demo model"}
    try:
        if DATABASE_AVAILABLE:
            with SessionLocal() as session:
                entry = session.query(ModelEntry).filter_by(name=model_name).first()
                if entry:
                    # The system prompt is embedded in the capabilities text
                    # as "System Prompt: <prompt>".
                    prompt = default_prompt
                    if entry.capabilities and "System Prompt: " in entry.capabilities:
                        prompt = entry.capabilities.split("System Prompt: ")[1]
                    created = entry.created.strftime("%Y-%m-%d %H:%M:%S") if entry.created else ""
                    return {
                        "name": entry.name,
                        "description": entry.description or "",
                        "system_prompt": prompt,
                        "created": created,
                    }
        return fallback
    except Exception as e:
        print(f"❌ Error getting model details: {e}")
        return fallback
 
 
 
 
 
 
 
 
 
278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
# --- Gradio Interface Functions ---
def update_model_dropdown(search_term):
    """Refresh the model dropdown from the (optionally filtered) database.

    Also rebuilds the module-level ``current_model_mapping`` so labels can
    be resolved back to model names later.
    """
    global current_model_mapping

    try:
        term = search_term.strip() if search_term else ""
        models = search_models_in_db(term) if term else get_models_from_db()

        labels, mapping = format_dropdown_items(models)
        current_model_mapping = mapping

        if not labels:
            return gr.update(choices=[], value=None)
        # Pre-select the first entry so dependent displays have a value.
        return gr.update(choices=labels, value=labels[0])
    except Exception as e:
        print(f"❌ Error updating dropdown: {e}")
        return gr.update(choices=[], value=None)
302
 
 
 
 
 
 
303
 
304
def on_model_select(dropdown_value):
    """Mirror the selected model name into both display fields.

    Returns the resolved name twice: once for the chatbot display and once
    for the drift-analysis display.
    """
    if not dropdown_value or not current_model_mapping:
        return "", ""

    try:
        name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
    except Exception as e:
        print(f"❌ Error in model selection: {e}")
        return "", ""
    return name, name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
 
 
316
 
317
def show_create_new():
    """Reveal the create-new-model panel with a blank name field."""
    panel_update = gr.update(visible=True)
    name_update = gr.update(value="")
    return panel_update, name_update
 
 
 
 
 
 
 
 
 
 
 
320
 
 
 
 
 
 
 
 
321
 
322
def cancel_create_new():
    """Hide and reset every widget of the create-new-model panel.

    Output order: create_new_section, new_model_name, new_system_prompt,
    enhanced_prompt_display, prompt_choice, save_model_button, save_status.
    """
    outputs = [gr.update(visible=False)]   # create_new_section
    outputs += ["", ""]                    # new_model_name, new_system_prompt
    # enhanced_prompt_display, prompt_choice, save_model_button, save_status
    outputs += [gr.update(visible=False) for _ in range(4)]
    return outputs
333
 
334
+
335
def enhance_prompt(original_prompt):
    """Append a local boilerplate enhancement to the user's system prompt.

    Returns updates for: enhanced_prompt_display, prompt_choice,
    save_model_button.
    """
    if not (original_prompt and original_prompt.strip()):
        # Nothing to enhance: keep the follow-up widgets hidden.
        return [gr.update(visible=False) for _ in range(3)]

    suffix = "Additional context: Be specific, helpful, and provide detailed responses while maintaining a professional tone."
    enhanced = f"{original_prompt}\n\n{suffix}"
    return [
        gr.update(value=enhanced, visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
    ]
350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
 
352
def save_new_model(model_name, selected_llm, original_prompt, enhanced_prompt, choice):
    """Persist a newly defined model and refresh the dropdown choices.

    Returns [status message, status-visibility update, dropdown update].
    """
    global current_model_mapping

    has_prompt = bool(original_prompt and original_prompt.strip())
    if not (model_name and selected_llm and has_prompt):
        return [
            "❌ Please provide model name, LLM selection, and system prompt",
            gr.update(visible=True),
            gr.update(),
        ]

    dropdown_update = gr.update()
    try:
        # Honour the user's choice between original and enhanced prompt.
        final_prompt = enhanced_prompt if choice == "Keep Enhanced" else original_prompt
        capabilities = f"{selected_llm}\nSystem Prompt: {final_prompt}"

        if register_model_with_capabilities(model_name, capabilities):
            status = f"✅ Model '{model_name}' saved successfully!"
            # Rebuild the dropdown so the new model appears immediately.
            labels, mapping = format_dropdown_items(get_models_from_db())
            current_model_mapping = mapping
            dropdown_update = gr.update(choices=labels, value=labels[0] if labels else None)
        else:
            status = "❌ Error saving model to database"
    except Exception as e:
        status = f"❌ Error saving model: {e}"
        dropdown_update = gr.update()

    return [status, gr.update(visible=True), dropdown_update]
389
 
390
+
391
def chatbot_response(message, history, dropdown_value):
    """Generate a chatbot reply for the selected model.

    NOTE(review): this second definition previously returned only a canned
    demo string, shadowing the LLM-backed ``chatbot_response`` defined
    earlier in this module — which is why Groq replies never appeared.
    The real LLM path is restored here with the same signature.

    Returns the updated message history (list of role/content dicts in the
    gr.Chatbot "messages" format) and an empty string to clear the input box.
    """
    if not message or not message.strip() or not dropdown_value:
        return history, ""

    if history is None:
        history = []

    try:
        # Resolved for diagnostics; the shared `llm` instance is model-agnostic.
        model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
        print(f"💬 chatbot_response using model: {model_name}")

        if not LLM_AVAILABLE:
            response_text = "❌ LLM not available - check ourllm.py import"
        elif not os.getenv("GROQ_API_KEY"):
            response_text = "❌ GROQ_API_KEY not found in environment variables"
        else:
            try:
                response = llm.invoke(message)
                response_text = str(response.content).strip()
                if not response_text:
                    response_text = "❌ LLM returned empty response"
            except Exception as e:
                response_text = f"❌ LLM Error: {e}"

        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response_text})
        return history, ""
    except Exception as e:
        print(f"❌ Error in chatbot response: {e}")
        return history, ""
412
 
413
+
414
def calculate_drift(dropdown_value):
    """Produce a (mock) drift report string for the selected model."""
    if not dropdown_value:
        return "❌ Please select a model first"

    try:
        import random

        model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
        # Demo-only scoring: a random value stands in for a real drift metric.
        drift_score = random.randint(10, 80)
        if drift_score > 50:
            alert = "🚨 Significant drift detected!"
        else:
            alert = "✅ Drift within acceptable range"
        return f"Drift analysis for {model_name}:\nDrift Score: {drift_score}/100\n{alert}"
    except Exception as e:
        print(f"❌ Error calculating drift: {e}")
        return "❌ Error calculating drift"
431
+
432
+
433
def create_drift_chart(drift_history):
    """Render drift history as a Plotly line chart (sample data when empty)."""
    try:
        if drift_history:
            dates = [point["date"] for point in drift_history]
            scores = [point["drift_score"] for point in drift_history]
        else:
            # No history yet: show illustrative demo data.
            dates = ['2024-01-01', '2024-01-02', '2024-01-03', '2024-01-04', '2024-01-05']
            scores = [25, 30, 45, 35, 40]

        figure = go.Figure()
        figure.add_trace(go.Scatter(
            x=dates,
            y=scores,
            mode='lines+markers',
            name='Drift Score',
            line=dict(color='#ff6b6b', width=3),
            marker=dict(size=8, color='#ff6b6b'),
        ))
        figure.update_layout(
            title='Model Drift Over Time',
            xaxis_title='Date',
            yaxis_title='Drift Score',
            template='plotly_white',
            height=400,
            showlegend=True,
        )
        return figure
    except Exception as e:
        print(f"❌ Error creating drift chart: {e}")
        # Fall back to an empty figure so the UI still renders.
        return go.Figure()
467
+
468
 
469
def refresh_drift_history(dropdown_value):
    """Return (history list, chart) for the selected model (demo data)."""
    if not dropdown_value:
        return [], gr.update(value=None)

    try:
        # Demo data; a real implementation would query the drift table.
        history = [
            {"date": f"2024-01-0{day}", "drift_score": score}
            for day, score in zip(range(1, 6), [25, 30, 45, 35, 40])
        ]
        return history, create_drift_chart(history)
    except Exception as e:
        print(f"❌ Error refreshing drift history: {e}")
        return [], gr.update(value=None)
489
+
490
 
491
def initialize_interface():
    """Populate the dropdown and model-name displays on app load.

    Returns updates for: model_dropdown, new_model_name,
    selected_model_display, drift_model_display.
    """
    global current_model_mapping

    empty = (gr.update(choices=[], value=None), "", "", "")
    try:
        labels, mapping = format_dropdown_items(get_models_from_db())
        current_model_mapping = mapping

        if not labels:
            return empty

        selected = labels[0]
        first_name = extract_model_name_from_dropdown(selected, mapping)
        return (
            gr.update(choices=labels, value=selected),  # model_dropdown
            "",                                         # new_model_name
            first_name,                                 # selected_model_display
            first_name,                                 # drift_model_display
        )
    except Exception as e:
        print(f"❌ Error initializing interface: {e}")
        return empty
524
 
525
+
526
# --- Gradio Interface ---
def create_interface():
    """Assemble the Gradio Blocks UI and wire up all event handlers."""
    with gr.Blocks(title="AI Model Management & Interaction Platform", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🤖 AI Model Management & Interaction Platform")

        if not DATABASE_AVAILABLE:
            gr.Markdown("⚠️ **Demo Mode**: Running without database connectivity. Some features are simulated.")

        with gr.Row():
            # Left column: model selection and creation.
            with gr.Column(scale=1):
                gr.Markdown("### 📋 Model Selection")

                model_dropdown = gr.Dropdown(
                    choices=[], label="Select Model",
                    interactive=True, allow_custom_value=False, value=None,
                )
                search_box = gr.Textbox(
                    placeholder="Search by model name or description...",
                    label="🔍 Search Models",
                )
                create_new_button = gr.Button("➕ Create New Model", variant="secondary")

                # Hidden until "Create New Model" is clicked.
                with gr.Group(visible=False) as create_new_section:
                    gr.Markdown("#### 🆕 Create New Model")
                    new_model_name = gr.Textbox(
                        label="Model Name", placeholder="Enter model name",
                    )
                    new_llm = gr.Dropdown(
                        choices=[
                            "gemini-1.0-pro",
                            "gemini-1.5-pro",
                            "groq-llama-3.1-8b-instant",
                            "groq-mixtral-8x7b-32768",
                            "claude-3-sonnet-20240229",
                        ],
                        label="Select LLM", interactive=True,
                    )
                    new_system_prompt = gr.Textbox(
                        label="System Prompt", placeholder="Enter system prompt", lines=3,
                    )
                    with gr.Row():
                        enhance_button = gr.Button("✨ Enhance Prompt", variant="primary")
                        cancel_button = gr.Button("❌ Cancel", variant="secondary")
                    enhanced_prompt_display = gr.Textbox(
                        label="Enhanced Prompt", interactive=False, lines=4, visible=False,
                    )
                    prompt_choice = gr.Radio(
                        choices=["Keep Enhanced", "Keep Original"],
                        label="Choose Prompt", visible=False,
                    )
                    save_model_button = gr.Button("💾 Save Model", variant="primary", visible=False)
                    save_status = gr.Textbox(label="Status", interactive=False, visible=False)

            # Right column: chat and drift-analysis tabs.
            with gr.Column(scale=2):
                gr.Markdown("### 🛠️ Model Operations")

                with gr.Tabs():
                    with gr.TabItem("💬 Chatbot"):
                        selected_model_display = gr.Textbox(
                            label="Currently Selected Model", interactive=False,
                        )
                        chatbot_interface = gr.Chatbot(
                            type="messages", height=400, show_label=False,
                        )
                        with gr.Row():
                            msg_input = gr.Textbox(
                                placeholder="Enter your message...", label="Message", scale=4,
                            )
                            send_button = gr.Button("📤 Send", variant="primary", scale=1)
                        clear_chat = gr.Button("🗑️ Clear Chat", variant="secondary")

                    with gr.TabItem("📊 Drift Analysis"):
                        drift_model_display = gr.Textbox(
                            label="Model for Drift Analysis", interactive=False,
                        )
                        with gr.Row():
                            calculate_drift_button = gr.Button("🔍 Calculate New Drift", variant="primary")
                            refresh_history_button = gr.Button("🔄 Refresh History", variant="secondary")
                        drift_result = gr.Textbox(label="Latest Drift Calculation", interactive=False)
                        gr.Markdown("#### 📈 Drift History")
                        drift_history_display = gr.JSON(label="Drift History Data")
                        gr.Markdown("#### 📊 Drift Chart")
                        drift_chart = gr.Plot(label="Drift Over Time")

        # --- Event wiring ---
        search_box.change(update_model_dropdown, inputs=[search_box], outputs=[model_dropdown])
        model_dropdown.change(on_model_select, inputs=[model_dropdown],
                              outputs=[selected_model_display, drift_model_display])

        create_new_button.click(show_create_new, outputs=[create_new_section, new_model_name])
        cancel_button.click(cancel_create_new,
                            outputs=[create_new_section, new_model_name, new_system_prompt,
                                     enhanced_prompt_display, prompt_choice, save_model_button, save_status])

        enhance_button.click(enhance_prompt, inputs=[new_system_prompt],
                             outputs=[enhanced_prompt_display, prompt_choice, save_model_button])
        # NOTE(review): save_status is listed twice so the handler can set both
        # its value and its visibility — confirm the installed Gradio version
        # accepts duplicate output components.
        save_model_button.click(save_new_model,
                                inputs=[new_model_name, new_llm, new_system_prompt,
                                        enhanced_prompt_display, prompt_choice],
                                outputs=[save_status, save_status, model_dropdown])

        send_button.click(chatbot_response, inputs=[msg_input, chatbot_interface, model_dropdown],
                          outputs=[chatbot_interface, msg_input])
        msg_input.submit(chatbot_response, inputs=[msg_input, chatbot_interface, model_dropdown],
                         outputs=[chatbot_interface, msg_input])
        clear_chat.click(lambda: [], outputs=[chatbot_interface])

        calculate_drift_button.click(calculate_drift, inputs=[model_dropdown], outputs=[drift_result])
        refresh_history_button.click(refresh_drift_history, inputs=[model_dropdown],
                                     outputs=[drift_history_display, drift_chart])

        # Populate dropdown and displays on first page load.
        demo.load(initialize_interface,
                  outputs=[model_dropdown, new_model_name, selected_model_display, drift_model_display])

    return demo
677
+
678
+
679
def main():
    """Create the Gradio app and launch the web server.

    Returns True when the server launched, False when launch failed.
    """
    print("🚀 Starting AI Model Management Platform...")
    demo = create_interface()

    port = 7860  # default Gradio port
    launch_config = {
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": port,
        "share": False,            # flip to True for a public link
        "show_error": True,        # surface detailed errors in the UI
        "quiet": False,
        "show_api": True,          # expose API docs
    }

    print("📡 Launching Gradio interface...")
    print("🌐 Server will be available at:")
    print(f" - Local: http://localhost:{port}")
    print(f" - Network: http://0.0.0.0:{port}")

    try:
        demo.launch(**launch_config)
        return True
    except Exception as e:
        print(f"❌ Failed to launch Gradio interface: {e}")
        for line in (
            "🔧 Troubleshooting suggestions:",
            " 1. Check if port 7860 is already in use",
            " 2. Try a different port: demo.launch(server_port=7861)",
            " 3. Check firewall settings",
            " 4. Ensure Gradio is properly installed: pip install gradio",
        ):
            print(line)
        return False
713
+
714
 
715
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n👋 Shutting down gracefully...")
    except Exception as e:
        print(f"❌ Application error: {e}")
        # Fix: `sys` is never imported at the top of this file, so the
        # original `sys.exit(1)` raised NameError inside this error path.
        # SystemExit preserves the intended non-zero exit status.
        raise SystemExit(1)
database_module/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- #database_module/__init__.py
2
  from .db import init_db
3
  from .mcp_tools import (
4
  get_all_models_handler,
@@ -7,4 +7,4 @@ from .mcp_tools import (
7
  get_baseline_diagnostics,
8
  save_drift_score,
9
  register_model_with_capabilities
10
- )
 
1
+ # database_module/__init__.py
2
  from .db import init_db
3
  from .mcp_tools import (
4
  get_all_models_handler,
 
7
  get_baseline_diagnostics,
8
  save_drift_score,
9
  register_model_with_capabilities
10
+ )
database_module/db.py CHANGED
@@ -1,4 +1,4 @@
1
- #database_module/db.py
2
  import os
3
  from sqlalchemy import create_engine, inspect, text
4
  from sqlalchemy.ext.declarative import declarative_base
@@ -26,25 +26,34 @@ SessionLocal = sessionmaker(
26
  )
27
  Base = declarative_base()
28
 
 
29
  def apply_migrations():
30
  """
31
  Apply any necessary migrations to existing tables.
32
  """
33
  with engine.connect() as conn:
34
- # Check if the models table exists and has the capabilities column
35
  inspector = inspect(engine)
36
  if "models" in inspector.get_table_names():
37
  columns = [col['name'] for col in inspector.get_columns('models')]
 
 
38
  if "capabilities" not in columns:
39
- # Add capabilities column to models table
40
  conn.execute(text("ALTER TABLE models ADD COLUMN capabilities TEXT"))
41
  conn.commit()
42
  print("Migration: Added capabilities column to models table")
43
 
 
 
 
 
 
 
 
44
  def init_db():
45
  """
46
  Create tables if they don't exist.
47
  Call this once at application startup.
48
  """
49
  Base.metadata.create_all(bind=engine)
50
- apply_migrations()
 
1
+ # database_module/db.py
2
  import os
3
  from sqlalchemy import create_engine, inspect, text
4
  from sqlalchemy.ext.declarative import declarative_base
 
26
  )
27
  Base = declarative_base()
28
 
29
+
30
  def apply_migrations():
31
  """
32
  Apply any necessary migrations to existing tables.
33
  """
34
  with engine.connect() as conn:
35
+ # Check if the models table exists and has the required columns
36
  inspector = inspect(engine)
37
  if "models" in inspector.get_table_names():
38
  columns = [col['name'] for col in inspector.get_columns('models')]
39
+
40
+ # Add capabilities column if missing
41
  if "capabilities" not in columns:
 
42
  conn.execute(text("ALTER TABLE models ADD COLUMN capabilities TEXT"))
43
  conn.commit()
44
  print("Migration: Added capabilities column to models table")
45
 
46
+ # Add updated column if missing
47
+ if "updated" not in columns:
48
+ conn.execute(text("ALTER TABLE models ADD COLUMN updated DATETIME"))
49
+ conn.commit()
50
+ print("Migration: Added updated column to models table")
51
+
52
+
53
  def init_db():
54
  """
55
  Create tables if they don't exist.
56
  Call this once at application startup.
57
  """
58
  Base.metadata.create_all(bind=engine)
59
+ apply_migrations()
database_module/mcp_tools.py CHANGED
@@ -16,7 +16,11 @@ def get_all_models_handler(_: Dict[str, Any]) -> List[Dict[str, Any]]:
16
  with SessionLocal() as session:
17
  entries = session.query(ModelEntry).all()
18
  return [
19
- {"name": e.name, "created": e.created.isoformat(), "description": e.description or ""}
 
 
 
 
20
  for e in entries
21
  ]
22
 
@@ -39,7 +43,11 @@ def search_models_handler(params: Dict[str, Any]) -> List[Dict[str, Any]]:
39
  )
40
  entries = query.all()
41
  return [
42
- {"name": e.name, "created": e.created.isoformat(), "description": e.description or ""}
 
 
 
 
43
  for e in entries
44
  ]
45
 
@@ -53,9 +61,22 @@ def get_model_details_handler(params: Dict[str, Any]) -> Dict[str, Any]:
53
  with SessionLocal() as session:
54
  e = session.query(ModelEntry).filter_by(name=model_name).first()
55
  if not e:
56
- return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
57
- # You can store system_prompt as a column if desired; here placeholder
58
- return {"name": e.name, "system_prompt": "You are a helpful AI assistant.", "description": e.description or ""}
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
 
61
  def save_model_handler(params: Dict[str, Any]) -> Dict[str, Any]:
@@ -71,11 +92,16 @@ def save_model_handler(params: Dict[str, Any]) -> Dict[str, Any]:
71
  # New model; created today
72
  entry = ModelEntry(
73
  name=name,
74
- created=datetime.utcnow().date(),
75
- description=""
 
76
  )
77
  session.add(entry)
78
- # Optionally store prompt in another table or JSON field
 
 
 
 
79
  session.commit()
80
  return {"message": f"Model '{name}' saved."}
81
 
@@ -88,7 +114,7 @@ def calculate_drift_handler(params: Dict[str, Any]) -> Dict[str, Any]:
88
  import random
89
  name = params.get("model_name")
90
  score = round(random.uniform(0, 1), 3)
91
- today = datetime.utcnow().date()
92
  with SessionLocal() as session:
93
  entry = DriftEntry(
94
  model_name=name,
@@ -117,10 +143,10 @@ def get_drift_history_handler(params: Dict[str, Any]) -> List[Dict[str, Any]]:
117
  # === New functions for drift detection database operations ===
118
 
119
  def save_diagnostic_data(
120
- model_name: str,
121
- questions: list,
122
- answers: list,
123
- is_baseline: bool = False
124
  ) -> None:
125
  """
126
  Save diagnostic questions and answers to the database
@@ -131,7 +157,7 @@ def save_diagnostic_data(
131
  if not model:
132
  model = ModelEntry(
133
  name=model_name,
134
- created=datetime.utcnow().date(),
135
  description=""
136
  )
137
  session.add(model)
@@ -142,7 +168,7 @@ def save_diagnostic_data(
142
  is_baseline=1 if is_baseline else 0,
143
  questions=questions,
144
  answers=answers,
145
- created=datetime.utcnow()
146
  )
147
  session.add(diagnostic)
148
  session.commit()
@@ -153,9 +179,9 @@ def get_baseline_diagnostics(model_name: str) -> Optional[Dict[str, Any]]:
153
  Retrieve baseline diagnostics for a model
154
  """
155
  with SessionLocal() as session:
156
- baseline = session.query(DiagnosticData)\
157
- .filter_by(model_name=model_name, is_baseline=1)\
158
- .order_by(DiagnosticData.created.desc())\
159
  .first()
160
 
161
  if not baseline:
@@ -181,7 +207,7 @@ def save_drift_score(model_name: str, drift_score: str) -> None:
181
  with SessionLocal() as session:
182
  entry = DriftEntry(
183
  model_name=model_name,
184
- date=datetime.utcnow(),
185
  drift_score=score_float
186
  )
187
  session.add(entry)
@@ -197,13 +223,14 @@ def register_model_with_capabilities(model_name: str, capabilities: str) -> None
197
 
198
  if model:
199
  model.capabilities = capabilities
 
200
  else:
201
  model = ModelEntry(
202
  name=model_name,
203
- created=datetime.utcnow().date(),
204
  capabilities=capabilities,
205
  description=""
206
  )
207
  session.add(model)
208
 
209
- session.commit()
 
16
  with SessionLocal() as session:
17
  entries = session.query(ModelEntry).all()
18
  return [
19
+ {
20
+ "name": e.name,
21
+ "created": e.created.isoformat() if e.created else datetime.now().isoformat(),
22
+ "description": e.description or ""
23
+ }
24
  for e in entries
25
  ]
26
 
 
43
  )
44
  entries = query.all()
45
  return [
46
+ {
47
+ "name": e.name,
48
+ "created": e.created.isoformat() if e.created else datetime.now().isoformat(),
49
+ "description": e.description or ""
50
+ }
51
  for e in entries
52
  ]
53
 
 
61
  with SessionLocal() as session:
62
  e = session.query(ModelEntry).filter_by(name=model_name).first()
63
  if not e:
64
+ return {
65
+ "name": model_name,
66
+ "system_prompt": "You are a helpful AI assistant.",
67
+ "description": ""
68
+ }
69
+
70
+ # Extract system prompt from capabilities if available
71
+ system_prompt = "You are a helpful AI assistant."
72
+ if e.capabilities and "System Prompt: " in e.capabilities:
73
+ system_prompt = e.capabilities.split("System Prompt: ")[1]
74
+
75
+ return {
76
+ "name": e.name,
77
+ "system_prompt": system_prompt,
78
+ "description": e.description or ""
79
+ }
80
 
81
 
82
  def save_model_handler(params: Dict[str, Any]) -> Dict[str, Any]:
 
92
  # New model; created today
93
  entry = ModelEntry(
94
  name=name,
95
+ created=datetime.now(),
96
+ description="",
97
+ capabilities=f"System Prompt: {prompt}"
98
  )
99
  session.add(entry)
100
+ else:
101
+ # Update existing model
102
+ entry.capabilities = f"System Prompt: {prompt}"
103
+ entry.updated = datetime.now()
104
+
105
  session.commit()
106
  return {"message": f"Model '{name}' saved."}
107
 
 
114
  import random
115
  name = params.get("model_name")
116
  score = round(random.uniform(0, 1), 3)
117
+ today = datetime.now()
118
  with SessionLocal() as session:
119
  entry = DriftEntry(
120
  model_name=name,
 
143
  # === New functions for drift detection database operations ===
144
 
145
  def save_diagnostic_data(
146
+ model_name: str,
147
+ questions: list,
148
+ answers: list,
149
+ is_baseline: bool = False
150
  ) -> None:
151
  """
152
  Save diagnostic questions and answers to the database
 
157
  if not model:
158
  model = ModelEntry(
159
  name=model_name,
160
+ created=datetime.now(),
161
  description=""
162
  )
163
  session.add(model)
 
168
  is_baseline=1 if is_baseline else 0,
169
  questions=questions,
170
  answers=answers,
171
+ created=datetime.now()
172
  )
173
  session.add(diagnostic)
174
  session.commit()
 
179
  Retrieve baseline diagnostics for a model
180
  """
181
  with SessionLocal() as session:
182
+ baseline = session.query(DiagnosticData) \
183
+ .filter_by(model_name=model_name, is_baseline=1) \
184
+ .order_by(DiagnosticData.created.desc()) \
185
  .first()
186
 
187
  if not baseline:
 
207
  with SessionLocal() as session:
208
  entry = DriftEntry(
209
  model_name=model_name,
210
+ date=datetime.now(),
211
  drift_score=score_float
212
  )
213
  session.add(entry)
 
223
 
224
  if model:
225
  model.capabilities = capabilities
226
+ model.updated = datetime.now()
227
  else:
228
  model = ModelEntry(
229
  name=model_name,
230
+ created=datetime.now(),
231
  capabilities=capabilities,
232
  description=""
233
  )
234
  session.add(model)
235
 
236
+ session.commit()
database_module/models.py CHANGED
@@ -8,16 +8,17 @@ class ModelEntry(Base):
8
 
9
  id = Column(Integer, primary_key=True, index=True)
10
  name = Column(String, unique=True, nullable=False, index=True)
11
- created = Column(Date, nullable=False)
 
12
  description = Column(Text, nullable=True)
13
- capabilities = Column(Text, nullable=True) # Added to store model_capabilities
14
 
15
  class DriftEntry(Base):
16
  __tablename__ = "drift_history"
17
 
18
  id = Column(Integer, primary_key=True, index=True)
19
  model_name = Column(String, nullable=False, index=True)
20
- date = Column(DateTime, nullable=False, default=datetime.utcnow)
21
  drift_score = Column(Float, nullable=True)
22
 
23
  class DiagnosticData(Base):
@@ -25,7 +26,7 @@ class DiagnosticData(Base):
25
 
26
  id = Column(Integer, primary_key=True, index=True)
27
  model_name = Column(String, nullable=False, index=True)
28
- created = Column(DateTime, nullable=False, default=datetime.utcnow)
29
  is_baseline = Column(Integer, nullable=False, default=0) # 0=latest, 1=baseline
30
  questions = Column(JSON, nullable=True)
31
- answers = Column(JSON, nullable=True)
 
8
 
9
  id = Column(Integer, primary_key=True, index=True)
10
  name = Column(String, unique=True, nullable=False, index=True)
11
+ created = Column(DateTime, nullable=False, default=datetime.now)
12
+ updated = Column(DateTime, nullable=True) # Added updated field
13
  description = Column(Text, nullable=True)
14
+ capabilities = Column(Text, nullable=True) # Store model_capabilities
15
 
16
  class DriftEntry(Base):
17
  __tablename__ = "drift_history"
18
 
19
  id = Column(Integer, primary_key=True, index=True)
20
  model_name = Column(String, nullable=False, index=True)
21
+ date = Column(DateTime, nullable=False, default=datetime.now)
22
  drift_score = Column(Float, nullable=True)
23
 
24
  class DiagnosticData(Base):
 
26
 
27
  id = Column(Integer, primary_key=True, index=True)
28
  model_name = Column(String, nullable=False, index=True)
29
+ created = Column(DateTime, nullable=False, default=datetime.now)
30
  is_baseline = Column(Integer, nullable=False, default=0) # 0=latest, 1=baseline
31
  questions = Column(JSON, nullable=True)
32
+ answers = Column(JSON, nullable=True)
drift_detector.sqlite3 CHANGED
Binary files a/drift_detector.sqlite3 and b/drift_detector.sqlite3 differ
 
fastagent.config.yaml CHANGED
@@ -1,12 +1,16 @@
1
- mcp:
2
- servers:
3
- drift-server:
4
- transport: stdio
5
- command: "python"
6
- args: ["server.py"]
7
 
8
- default_model: "generic.llama3.1"
9
- generic:
10
- api_key: "ollama" # doesn't matter, just a placeholder
11
- base_url: "http://localhost:11434/v1"
 
 
 
 
12
 
 
 
 
 
 
 
1
+ # fastagent.config.yaml
 
 
 
 
 
2
 
3
+ servers:
4
+ drift-server:
5
+ transport: stdio
6
+ # Launch drift-server via Python in unbuffered mode so JSON‐RPC messages flow correctly
7
+ command: python
8
+ args:
9
+ - -u
10
+ - server.py
11
 
12
+ # Your model defaults (unchanged)
13
+ default_model: generic.llama3.1
14
+ generic:
15
+ api_key: ollama # placeholder
16
+ base_url: http://localhost:11434/v1
ourllm.py CHANGED
@@ -4,87 +4,107 @@ import mcp.types as types
4
  from langchain.chat_models import init_chat_model
5
  from dotenv import load_dotenv
6
  import os
 
 
7
  # Load environment variables from .env file
8
  load_dotenv()
9
  print("GROQ_API_KEY is set:", "GROQ_API_KEY" in os.environ)
10
 
11
- llm = init_chat_model("llama-3.1-8b-instant",model_provider='groq')
12
 
13
 
14
- def genratequestionnaire(model: str, capabilities: str) -> List[types.SamplingMessage]:
15
  """
16
  Generate a baseline questionnaire for the given model.
17
- Returns a list of SamplingMessage instances (role="user") with diagnostic questions.
18
  """
19
  global llm
20
  questions = []
21
  previously_generated = ""
22
 
23
- for i in range(0,5):
24
- response = llm.invoke("Generate a questionnaire for a model with the following capabilities:\n"
25
- "Model Name: " + model + "\n"
26
- "Capabilities Overview:\n" + capabilities + "\n"
27
- "Please provide one more question that cover the model's capabilities and typical use-cases.\n"
28
- "Previously generated questions:\n" + previously_generated +
29
- "\nQuestion " + str(i+1) + ":")
30
- new_question = str(response.content)
31
- questions.append(new_question)
32
- # Update previously_generated to include the new question
33
- if previously_generated:
34
- previously_generated += "\n"
35
- previously_generated += f"Question {i+1}: {new_question}"
36
-
37
- return [
38
- types.SamplingMessage(
39
- role="user",
40
- content=types.TextContent(type="text", text=q)
41
- )
42
- for q in questions
43
- ]
44
-
45
-
46
- def gradeanswers(old_answers: List[str], new_answers: List[str]) -> List[types.SamplingMessage]:
 
 
 
47
  """
48
  Use the LLM to compare the old and new answers to compute a drift score.
49
- Returns a list with a single SamplingMessage (role="assistant") whose content.text is the drift percentage.
50
  """
51
  global llm
52
 
53
  if not old_answers or not new_answers:
54
- drift_pct = 0.0
55
- else:
 
 
 
 
56
  # Prepare a prompt with old and new answers for the LLM to analyze
57
  prompt = "You're tasked with detecting semantic drift between two sets of model responses.\n\n"
58
  prompt += "Original responses:\n"
59
  for i, ans in enumerate(old_answers):
60
- prompt += f"Response {i+1}: {ans}\n\n"
61
 
62
  prompt += "New responses:\n"
63
  for i, ans in enumerate(new_answers):
64
- prompt += f"Response {i+1}: {ans}\n\n"
65
 
66
- prompt += "Analyze the semantic differences between the original and new responses. "
67
- prompt += "Provide a drift percentage score (0-100%) that represents how much the meaning, "
68
- prompt += "intent, or capabilities have changed between the two sets of responses. "
69
- prompt += "Only return the numerical percentage value without any explanation or additional text."
70
 
71
  # Get the drift assessment from the LLM
72
  response = llm.invoke(prompt)
73
  drift_text = str(response.content).strip()
74
 
75
  # Extract just the numerical value if there's extra text
76
- import re
77
  drift_match = re.search(r'(\d+\.?\d*)', drift_text)
78
  if drift_match:
79
  drift_pct = float(drift_match.group(1))
 
80
  else:
81
- # Fallback if no number found
82
- drift_pct = 0.0
83
-
84
- drift_text = f"{drift_pct}"
85
- return [
86
- types.SamplingMessage(
87
- role="assistant",
88
- content=types.TextContent(type="text", text=drift_text)
89
- )
90
- ]
 
 
 
 
 
 
 
 
 
 
 
 
4
  from langchain.chat_models import init_chat_model
5
  from dotenv import load_dotenv
6
  import os
7
+ import re
8
+
9
  # Load environment variables from .env file
10
  load_dotenv()
11
  print("GROQ_API_KEY is set:", "GROQ_API_KEY" in os.environ)
12
 
13
+ llm = init_chat_model("llama-3.1-8b-instant", model_provider='groq')
14
 
15
 
16
def genratequestionnaire(model: str, capabilities: str) -> List[str]:
    """Generate a baseline questionnaire for the given model.

    Asks the module-level LLM for one question at a time, feeding back the
    successfully generated questions so far so the model does not repeat
    itself.  Any per-question LLM failure is swallowed and replaced by a
    generic fallback question, so the result always has exactly five entries.

    Args:
        model: Name of the model under test.
        capabilities: Free-text description of the model's capabilities.

    Returns:
        A list of five question strings for diagnostic purposes.
    """
    global llm

    questions: List[str] = []
    history = ""  # only questions the LLM actually produced (not fallbacks)

    for idx in range(1, 6):
        prompt = (
            "Generate a questionnaire for a model with the following capabilities:\n"
            f"Model Name: {model}\n"
            f"Capabilities Overview:\n{capabilities}\n"
            "Please provide one more question that covers the model's capabilities and typical use-cases.\n"
            f"Previously generated questions:\n{history}\n"
            f"Question {idx}:"
        )
        try:
            reply = llm.invoke(prompt)
            question = str(reply.content).strip()
            questions.append(question)
            # Track history only for LLM-produced questions, matching prior behavior.
            history = (history + "\n" if history else "") + f"Question {idx}: {question}"
        except Exception as e:
            print(f"Error generating question {idx}: {e}")
            # Fallback question keeps the questionnaire at a fixed length.
            questions.append(f"What are your capabilities as {model}?")

    return questions
49
+
50
+
51
def gradeanswers(old_answers: List[str], new_answers: List[str]) -> str:
    """Compare old and new answer sets and return a drift percentage.

    The primary grader asks the module-level LLM for a 0-100 drift score.
    If the LLM call fails or returns non-numeric output, a text-similarity
    fallback (difflib) is used so the function always produces a result.

    Args:
        old_answers: Baseline answers.
        new_answers: Freshly sampled answers, in the same order as the baseline.

    Returns:
        Drift percentage in [0, 100] as the string form of an integer, e.g. "42".
    """
    global llm

    def _similarity_drift() -> str:
        """Fallback score: 100% minus average difflib similarity per answer pair."""
        # Local import fixes a NameError: difflib was used here but never imported.
        import difflib

        ratios = [
            difflib.SequenceMatcher(None, old, new).ratio()
            for old, new in zip(old_answers, new_answers)
        ]
        avg_similarity = sum(ratios) / len(ratios)
        return str(int((1 - avg_similarity) * 100))

    if not old_answers or not new_answers:
        return "0"

    if len(old_answers) != len(new_answers):
        return "100"  # Major drift if answer count differs

    try:
        # Prepare a prompt with old and new answers for the LLM to analyze.
        parts = [
            "You're tasked with detecting semantic drift between two sets of model responses.\n\n",
            "Original responses:\n",
        ]
        parts += [f"Response {i + 1}: {ans}\n\n" for i, ans in enumerate(old_answers)]
        parts.append("New responses:\n")
        parts += [f"Response {i + 1}: {ans}\n\n" for i, ans in enumerate(new_answers)]
        parts.append(
            "Analyze the semantic differences between the original and new responses. "
            "Provide a drift percentage score (0-100%) that represents how much the meaning, "
            "intent, or capabilities have changed between the two sets of responses. "
            "Only return the numerical percentage value without any explanation or additional text."
        )

        # Get the drift assessment from the LLM.
        response = llm.invoke("".join(parts))
        drift_text = str(response.content).strip()

        # Extract just the numerical value if there's extra text.
        drift_match = re.search(r'(\d+\.?\d*)', drift_text)
        if drift_match:
            return str(int(float(drift_match.group(1))))  # Return as integer string
        return _similarity_drift()

    except Exception as e:
        print(f"Error grading answers: {e}")
        return _similarity_drift()
server.py CHANGED
@@ -25,6 +25,7 @@ init_db()
25
 
26
  app = Server("mcp-drift-server")
27
 
 
28
  # === Tool Manifest ===
29
  @app.list_tools()
30
  async def list_tools() -> List[types.Tool]:
@@ -36,7 +37,8 @@ async def list_tools() -> List[types.Tool]:
36
  "type": "object",
37
  "properties": {
38
  "model": {"type": "string", "description": "The name of the model to run diagnostics on"},
39
- "model_capabilities": {"type": "string", "description": "Full description of the model's capabilities, along with the system prompt."}
 
40
  },
41
  "required": ["model", "model_capabilities"]
42
  },
@@ -46,7 +48,8 @@ async def list_tools() -> List[types.Tool]:
46
  description="Re-run diagnostics and compare to baseline for drift scoring.",
47
  inputSchema={
48
  "type": "object",
49
- "properties": {"model": {"type": "string", "description": "The name of the model to run diagnostics on"}},
 
50
  "required": ["model"]
51
  },
52
  ),
@@ -66,132 +69,213 @@ async def list_tools() -> List[types.Tool]:
66
  ),
67
  ]
68
 
 
69
  # === Sampling Wrapper ===
70
  async def sample(messages: list[types.SamplingMessage], max_tokens=600) -> CreateMessageResult:
71
- return await app.request_context.session.create_message(
72
- messages=messages,
73
- max_tokens=max_tokens,
74
- temperature=0.7
75
- )
 
 
 
 
 
 
 
 
 
 
76
 
77
  # === Core Logic ===
78
  async def run_initial_diagnostics(arguments: Dict[str, Any]) -> List[types.TextContent]:
79
  model = arguments["model"]
80
  capabilities = arguments["model_capabilities"]
81
 
82
- # 1. Ask the server's internal LLM to generate a questionnaire
83
- questions = genratequestionnaire(model, capabilities) # Server-side trusted LLM
84
- answers = []
85
- for q in questions:
86
- a = await sample([q])
87
- answers.append(a)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # 2. Save the model capabilities and questions/answers to database
90
- register_model_with_capabilities(model, capabilities)
91
- save_diagnostic_data(
92
- model_name=model,
93
- questions=[m.content.text for m in questions],
94
- answers=[m.content.text for m in answers],
95
- is_baseline=True
96
- )
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- return [types.TextContent(type="text", text=f"βœ… Baseline stored for model: {model}")]
99
 
100
  async def check_drift(arguments: Dict[str, Any]) -> List[types.TextContent]:
101
  model = arguments["model"]
102
 
103
- # Get baseline from database
104
- baseline = get_baseline_diagnostics(model)
 
105
 
106
- # Ensure baseline exists
107
- if not baseline:
108
- return [types.TextContent(type="text", text=f"❌ No baseline for model: {model}")]
109
 
110
- # Convert questions to sampling messages
111
- questions = [
112
- types.SamplingMessage(role="user", content=types.TextContent(type="text", text=q))
113
- for q in baseline["questions"]
114
- ]
115
- old_answers = baseline["answers"]
116
-
117
- # Ask the model again
118
- new_answers_msgs = []
119
- for q in questions:
120
- a = await sample([q])
121
- new_answers_msgs.append(a)
122
- new_answers = [m.content.text for m in new_answers_msgs]
123
-
124
- # Grade the answers and get a drift score
125
- grading_response = gradeanswers(old_answers, new_answers)
126
- drift_score = grading_response[0].content.text.strip()
127
-
128
- # Save the latest responses and drift score to database
129
- save_diagnostic_data(
130
- model_name=model,
131
- questions=baseline["questions"],
132
- answers=new_answers,
133
- is_baseline=False
134
- )
135
- save_drift_score(model, drift_score)
136
-
137
- # Alert threshold
138
- try:
139
- score_val = float(drift_score)
140
- alert = "🚨 Significant drift!" if score_val > 50 else "βœ… Drift OK"
141
- except ValueError:
142
- alert = "⚠️ Drift score not numeric"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- return [
145
- types.TextContent(type="text", text=f"Drift score for {model}: {drift_score}"),
146
- types.TextContent(type="text", text=alert)
147
- ]
148
 
149
  # Database tool handlers
150
  async def get_all_models_handler_async(_: Dict[str, Any]) -> List[types.TextContent]:
151
- models = get_all_models_handler({})
152
- if not models:
153
- return [types.TextContent(type="text", text="No models registered.")]
 
 
 
 
 
 
 
 
 
 
154
 
155
- model_list = "\n".join([f"β€’ {m['name']} - {m['description']}" for m in models])
156
- return [types.TextContent(
157
- type="text",
158
- text=f"Registered models:\n{model_list}"
159
- )]
160
 
161
  async def search_models_handler_async(arguments: Dict[str, Any]) -> List[types.TextContent]:
162
- query = arguments.get("query", "")
163
- models = search_models_handler({"search_term": query})
 
 
 
 
 
 
 
164
 
165
- if not models:
166
  return [types.TextContent(
167
  type="text",
168
- text=f"No models found matching '{query}'."
169
  )]
 
 
 
170
 
171
- model_list = "\n".join([f"β€’ {m['name']} - {m['description']}" for m in models])
172
- return [types.TextContent(
173
- type="text",
174
- text=f"Models matching '{query}':\n{model_list}"
175
- )]
176
 
177
  # === Dispatcher ===
178
  @app.call_tool()
179
  async def dispatch_tool(name: str, arguments: Dict[str, Any] | None = None):
180
- if name == "run_initial_diagnostics":
181
- return await run_initial_diagnostics(arguments)
182
- elif name == "check_drift":
183
- return await check_drift(arguments)
184
- elif name == "get_all_models":
185
- return await get_all_models_handler_async(arguments or {})
186
- elif name == "search_models":
187
- return await search_models_handler_async(arguments or {})
188
- else:
189
- raise ValueError(f"Unknown tool: {name}")
 
 
 
 
 
190
 
191
  # === Entrypoint ===
192
  async def main():
193
- async with stdio_server() as (reader, writer):
194
- await app.run(reader, writer, app.create_initialization_options())
 
 
 
 
195
 
196
  if __name__ == "__main__":
197
- asyncio.run(main())
 
25
 
26
  app = Server("mcp-drift-server")
27
 
28
+
29
  # === Tool Manifest ===
30
  @app.list_tools()
31
  async def list_tools() -> List[types.Tool]:
 
37
  "type": "object",
38
  "properties": {
39
  "model": {"type": "string", "description": "The name of the model to run diagnostics on"},
40
+ "model_capabilities": {"type": "string",
41
+ "description": "Full description of the model's capabilities, along with the system prompt."}
42
  },
43
  "required": ["model", "model_capabilities"]
44
  },
 
48
  description="Re-run diagnostics and compare to baseline for drift scoring.",
49
  inputSchema={
50
  "type": "object",
51
+ "properties": {
52
+ "model": {"type": "string", "description": "The name of the model to run diagnostics on"}},
53
  "required": ["model"]
54
  },
55
  ),
 
69
  ),
70
  ]
71
 
72
+
73
  # === Sampling Wrapper ===
74
async def sample(messages: list[types.SamplingMessage], max_tokens=600) -> CreateMessageResult:
    """Proxy a sampling request to the connected MCP client session.

    On any session/transport failure a stub assistant message is returned
    instead of propagating the exception, so callers always get a result.
    """
    request = {
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": 0.7,
    }
    try:
        return await app.request_context.session.create_message(**request)
    except Exception as err:
        print(f"Error in sampling: {err}")
        # Fallback: synthesize a minimal assistant reply so diagnostics continue.
        return CreateMessageResult(
            content=types.TextContent(type="text", text="Error generating response"),
            model="unknown",
            role="assistant",
        )
89
+
90
 
91
  # === Core Logic ===
92
async def run_initial_diagnostics(arguments: Dict[str, Any]) -> List[types.TextContent]:
    """Build a baseline for a model: generate questions, sample answers, persist both.

    Args:
        arguments: Tool arguments; requires "model" (name) and
            "model_capabilities" (free-text capability description).

    Returns:
        A single TextContent describing success or the failure encountered.
    """
    model = arguments["model"]
    capabilities = arguments["model_capabilities"]

    def _text_of(result) -> str:
        """Unwrap a sampling result into plain text, with defensive fallbacks."""
        if not hasattr(result, 'content'):
            return "No response generated"
        content = result.content
        return content.text if hasattr(content, 'text') else str(content)

    try:
        # 1. Generate questionnaire using ourllm (returns list of strings).
        questions = genratequestionnaire(model, capabilities)

        # 2. Sample one answer per question via the client session.
        answers = []
        for question_text in questions:
            try:
                message = types.SamplingMessage(
                    role="user",
                    content=types.TextContent(type="text", text=question_text),
                )
                answers.append(_text_of(await sample([message])))
            except Exception as e:
                print(f"Error getting answer for question '{question_text}': {e}")
                answers.append(f"Error: {str(e)}")

        # 3. Persist capabilities plus the baseline question/answer set.
        try:
            register_model_with_capabilities(model, capabilities)
            save_diagnostic_data(
                model_name=model,
                questions=questions,
                answers=answers,
                is_baseline=True,
            )
        except Exception as e:
            print(f"Error saving diagnostic data: {e}")
            return [types.TextContent(type="text", text=f"❌ Error saving baseline for {model}: {str(e)}")]

        return [
            types.TextContent(type="text", text=f"✅ Baseline stored for model: {model} ({len(questions)} questions)")]

    except Exception as e:
        print(f"Error in run_initial_diagnostics: {e}")
        return [types.TextContent(type="text", text=f"❌ Error running diagnostics for {model}: {str(e)}")]
142
 
 
143
 
144
async def check_drift(arguments: Dict[str, Any]) -> List[types.TextContent]:
    """Re-ask the stored baseline questions and score drift against the old answers.

    Args:
        arguments: Tool arguments; requires "model" (name of the monitored model).

    Returns:
        Two TextContent items (score and alert) on success, or one error message.
    """
    model = arguments["model"]

    def _text_of(result) -> str:
        """Unwrap a sampling result into plain text, with defensive fallbacks."""
        if not hasattr(result, 'content'):
            return "No response generated"
        content = result.content
        return content.text if hasattr(content, 'text') else str(content)

    try:
        # Baseline must exist before drift can be measured.
        baseline = get_baseline_diagnostics(model)
        if not baseline:
            return [types.TextContent(type="text", text=f"❌ No baseline for model: {model}")]

        old_answers = baseline["answers"]
        questions = baseline["questions"]

        # Re-ask every baseline question.
        new_answers = []
        for question_text in questions:
            try:
                message = types.SamplingMessage(
                    role="user",
                    content=types.TextContent(type="text", text=question_text),
                )
                new_answers.append(_text_of(await sample([message])))
            except Exception as e:
                print(f"Error getting new answer for question '{question_text}': {e}")
                new_answers.append(f"Error: {str(e)}")

        # Grade the answers; gradeanswers returns a string percentage, e.g. "42".
        drift_score_str = gradeanswers(old_answers, new_answers)

        # Best-effort persistence of the fresh answers and the score.
        try:
            save_diagnostic_data(
                model_name=model,
                questions=questions,
                answers=new_answers,
                is_baseline=False,
            )
            save_drift_score(model, drift_score_str)
        except Exception as e:
            print(f"Error saving drift data: {e}")

        # Alert threshold: >50% counts as significant drift.
        try:
            alert = "🚨 Significant drift!" if float(drift_score_str) > 50 else "✅ Drift OK"
        except ValueError:
            alert = "⚠️ Drift score not numeric"

        return [
            types.TextContent(type="text", text=f"Drift score for {model}: {drift_score_str}%"),
            types.TextContent(type="text", text=alert),
        ]

    except Exception as e:
        print(f"Error in check_drift: {e}")
        return [types.TextContent(type="text", text=f"❌ Error checking drift for {model}: {str(e)}")]
212
 
 
 
 
 
213
 
214
  # Database tool handlers
215
async def get_all_models_handler_async(_: Dict[str, Any]) -> List[types.TextContent]:
    """List every registered model as a bulleted text block."""
    try:
        models = get_all_models_handler({})
        if not models:
            return [types.TextContent(type="text", text="No models registered.")]

        bullets = "\n".join(
            f"• {m['name']} - {m.get('description', 'No description')}" for m in models
        )
        return [types.TextContent(
            type="text",
            text=f"Registered models:\n{bullets}",
        )]
    except Exception as e:
        print(f"Error getting all models: {e}")
        return [types.TextContent(type="text", text=f"❌ Error retrieving models: {str(e)}")]
229
 
 
 
 
 
 
230
 
231
async def search_models_handler_async(arguments: Dict[str, Any]) -> List[types.TextContent]:
    """Search registered models by a free-text query and list the matches."""
    try:
        query = arguments.get("query", "")
        matches = search_models_handler({"search_term": query})

        if not matches:
            return [types.TextContent(
                type="text",
                text=f"No models found matching '{query}'.",
            )]

        bullets = "\n".join(
            f"• {m['name']} - {m.get('description', 'No description')}" for m in matches
        )
        return [types.TextContent(
            type="text",
            text=f"Models matching '{query}':\n{bullets}",
        )]
    except Exception as e:
        print(f"Error searching models: {e}")
        return [types.TextContent(type="text", text=f"❌ Error searching models: {str(e)}")]
250
 
 
 
 
 
 
251
 
252
  # === Dispatcher ===
253
@app.call_tool()
async def dispatch_tool(name: str, arguments: Dict[str, Any] | None = None):
    """Route an MCP tool call to its handler.

    Unknown tool names and handler exceptions are reported as text replies
    rather than raised, so the protocol session stays alive.
    """
    # Dispatch table: tool name -> zero-arg coroutine factory.
    handlers = {
        "run_initial_diagnostics": lambda: run_initial_diagnostics(arguments),
        "check_drift": lambda: check_drift(arguments),
        "get_all_models": lambda: get_all_models_handler_async(arguments or {}),
        "search_models": lambda: search_models_handler_async(arguments or {}),
    }
    try:
        handler = handlers.get(name)
        if handler is None:
            return [types.TextContent(type="text", text=f"❌ Unknown tool: {name}")]
        return await handler()
    except Exception as e:
        print(f"Error in dispatch_tool for {name}: {e}")
        return [types.TextContent(type="text", text=f"❌ Error executing {name}: {str(e)}")]
269
+
270
 
271
  # === Entrypoint ===
272
async def main():
    """Entry point: serve MCP over stdio until the stream closes, logging fatal errors."""
    try:
        async with stdio_server() as (read_stream, write_stream):
            init_options = app.create_initialization_options()
            await app.run(read_stream, write_stream, init_options)
    except Exception as err:
        print(f"Error running MCP server: {err}")
278
+
279
 
280
if __name__ == "__main__":
    # Script entry: drive the async server loop to completion.
    asyncio.run(main())
test_llm.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# test_llm.py - Create this as a separate file to test your LLM setup
# Standalone smoke test: verifies .env loading, the GROQ key, and one LLM round trip.
import os
from dotenv import load_dotenv

print("=== Testing LLM Setup ===")

# Load environment variables from .env
load_dotenv()
print(f"✅ Environment loaded")
print(f"🔑 GROQ_API_KEY exists: {'GROQ_API_KEY' in os.environ}")
if 'GROQ_API_KEY' in os.environ:
    key = os.environ['GROQ_API_KEY']
    # Only the key prefix is printed, never the full secret.
    print(f"🔑 API Key starts with: {key[:10]}...")

# Import and exercise the shared LLM instance.
try:
    from ourllm import llm

    print("✅ Successfully imported LLM")

    test_message = "Hello, please respond with 'LLM is working correctly'"
    print(f"🧪 Testing with message: {test_message}")

    response = llm.invoke(test_message)
    print(f"✅ LLM Response: {response.content}")

except ImportError as e:
    print(f"❌ Import error: {e}")
except Exception as e:
    print(f"❌ LLM call error: {e}")

print("=== Test Complete ===")