NLarchive commited on
Commit
3bbd581
·
verified ·
1 Parent(s): 0d98fee

Create workflow_vizualizer.py

Browse files
Files changed (1) hide show
  1. workflow_vizualizer.py +501 -0
workflow_vizualizer.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import Dict, List, Any, Optional
3
+ from dataclasses import dataclass, asdict
4
+ import re
5
+
6
+ # Import visualization dependencies with fallbacks
7
+ try:
8
+ import networkx as nx
9
+ import matplotlib.pyplot as plt
10
+ plt.switch_backend('Agg')
11
+ import matplotlib
12
+ matplotlib.use('Agg')
13
+ import warnings
14
+ warnings.filterwarnings('ignore', category=UserWarning, module='matplotlib')
15
+ except ImportError:
16
+ print("Warning: Visualization deps missing. Install with: pip install networkx matplotlib")
17
+ nx = None
18
+ plt = None
19
+
20
+ @dataclass
21
+ class WorkflowStep:
22
+ step_id: str
23
+ step_type: str
24
+ timestamp: float
25
+ content: str
26
+ metadata: Dict[str, Any]
27
+ duration: Optional[float] = None
28
+ status: str = 'pending'
29
+ parent_step: Optional[str] = None
30
+ details: Optional[Dict[str, Any]] = None
31
+ mcp_server: Optional[str] = None # Added to track MCP server
32
+ tool_name: Optional[str] = None # Added to track specific tool
33
+
34
+ class EnhancedWorkflowVisualizer:
35
+ def __init__(self):
36
+ self.steps: List[WorkflowStep] = []
37
+ self.current_step: Optional[WorkflowStep] = None
38
+ self.start_time = time.time()
39
+ self.step_counter = 0
40
+
41
+ # MCP server mapping for better display names
42
+ self.server_display_names = {
43
+ "7860": "Semantic Server",
44
+ "7861": "Token Counter",
45
+ "7862": "Sentiment Analysis",
46
+ "7863": "Health Monitor"
47
+ }
48
+
49
+ def _extract_mcp_server_from_url(self, url_or_content: str) -> Optional[str]:
50
+ """Extract MCP server name from URL or content."""
51
+ if not url_or_content:
52
+ return None
53
+
54
+ # Extract port from URL
55
+ port_match = re.search(r':(\d{4})', url_or_content)
56
+ if port_match:
57
+ port = port_match.group(1)
58
+ return self.server_display_names.get(port, f"Port {port}")
59
+
60
+ # Check for server keywords in content
61
+ if "semantic" in url_or_content.lower():
62
+ return "Semantic Server"
63
+ elif "token" in url_or_content.lower():
64
+ return "Token Counter"
65
+ elif "sentiment" in url_or_content.lower():
66
+ return "Sentiment Analysis"
67
+ elif "health" in url_or_content.lower():
68
+ return "Health Monitor"
69
+
70
+ return None
71
+
72
+ def _extract_tool_name(self, content: str) -> Optional[str]:
73
+ """Extract tool name from content."""
74
+ # Enhanced tool patterns - prioritize actual function names
75
+ function_patterns = [
76
+ # Specific MCP tool functions (high priority)
77
+ r'\b(sentiment_analysis)\s*\(',
78
+ r'\b(count_tokens_openai_gpt4)\s*\(',
79
+ r'\b(count_tokens_openai_gpt3)\s*\(',
80
+ r'\b(count_tokens_openai_davinci)\s*\(',
81
+ r'\b(count_tokens_bert_family)\s*\(',
82
+ r'\b(count_tokens_roberta_family)\s*\(',
83
+ r'\b(count_tokens_gpt2_family)\s*\(',
84
+ r'\b(count_tokens_t5_family)\s*\(',
85
+ r'\b(count_tokens_distilbert)\s*\(',
86
+ r'\b(semantic_similarity)\s*\(',
87
+ r'\b(find_similar_sentences)\s*\(',
88
+ r'\b(extract_semantic_keywords)\s*\(',
89
+ r'\b(semantic_search_in_text)\s*\(',
90
+ r'\b(health_check)\s*\(',
91
+ r'\b(server_status)\s*\(',
92
+ r'\b(get_server_info)\s*\(',
93
+
94
+ # Generic patterns (lower priority)
95
+ r'(\w*sentiment_analysis\w*)',
96
+ r'(\w*semantic_similarity\w*)',
97
+ r'(\w*find_similar_sentences\w*)',
98
+ r'(\w*extract_semantic_keywords\w*)',
99
+ r'(\w*semantic_search_in_text\w*)',
100
+ r'(\w*count_tokens_\w+)',
101
+ r'(\w*health_check\w*)',
102
+ r'(\w*server_status\w*)'
103
+ ]
104
+
105
+ # Try high-priority function call patterns first
106
+ for pattern in function_patterns:
107
+ match = re.search(pattern, content, re.IGNORECASE)
108
+ if match:
109
+ tool_name = match.group(1)
110
+ # Skip common non-tool functions
111
+ if tool_name not in ['print', 'len', 'str', 'int', 'float', 'final_answer', 'sse', 'model']:
112
+ return tool_name
113
+
114
+ # Check for execution logs that contain actual function names
115
+ if "count_tokens_openai_gpt4" in content:
116
+ return "count_tokens_openai_gpt4"
117
+ elif "sentiment_analysis" in content:
118
+ return "sentiment_analysis"
119
+ elif "extract_semantic_keywords" in content:
120
+ return "extract_semantic_keywords"
121
+ elif "semantic_similarity" in content:
122
+ return "semantic_similarity"
123
+ elif "find_similar_sentences" in content:
124
+ return "find_similar_sentences"
125
+ elif "semantic_search_in_text" in content:
126
+ return "semantic_search_in_text"
127
+
128
+ return None
129
+
130
+ def add_step(self, step_type: str, content: str, metadata: Optional[Dict[str, Any]] = None,
131
+ parent_step: Optional[str] = None, details: Optional[Dict[str, Any]] = None,
132
+ mcp_server: Optional[str] = None, tool_name: Optional[str] = None) -> str:
133
+ step_id = f"{step_type}_{self.step_counter}"
134
+ self.step_counter += 1
135
+
136
+ # Auto-extract MCP server and tool if not provided
137
+ if not mcp_server:
138
+ mcp_server = self._extract_mcp_server_from_url(content)
139
+ if not tool_name:
140
+ tool_name = self._extract_tool_name(content)
141
+
142
+ step = WorkflowStep(
143
+ step_id=step_id,
144
+ step_type=step_type,
145
+ timestamp=time.time(),
146
+ content=content,
147
+ metadata=metadata or {},
148
+ status='running',
149
+ parent_step=parent_step,
150
+ details=details or {},
151
+ mcp_server=mcp_server,
152
+ tool_name=tool_name
153
+ )
154
+ self.steps.append(step)
155
+ self.current_step = step
156
+ return step_id
157
+
158
+ def complete_step(self, step_id: str, status: str = 'completed',
159
+ additional_metadata: Optional[Dict[str, Any]] = None,
160
+ details: Optional[Dict[str, Any]] = None):
161
+ for step in self.steps:
162
+ if step.step_id == step_id:
163
+ step.status = status
164
+ step.duration = time.time() - step.timestamp
165
+ if additional_metadata and step.metadata is not None:
166
+ step.metadata.update(additional_metadata)
167
+ if details and step.details is not None:
168
+ step.details.update(details)
169
+ break
170
+
171
+ def add_communication_step(self, from_component: str, to_component: str,
172
+ message_type: str, content: str,
173
+ parent_step: Optional[str] = None) -> str:
174
+ """Add a communication step between components."""
175
+ step_type = f"comm_{from_component}_to_{to_component}"
176
+
177
+ # Extract server info for communication steps
178
+ mcp_server = self._extract_mcp_server_from_url(content)
179
+ tool_name = self._extract_tool_name(content)
180
+
181
+ details = {
182
+ "from": from_component,
183
+ "to": to_component,
184
+ "message_type": message_type,
185
+ "content_preview": content[:100] + "..." if len(content) > 100 else content
186
+ }
187
+ return self.add_step(step_type, f"{message_type}: {from_component} → {to_component}",
188
+ parent_step=parent_step, details=details,
189
+ mcp_server=mcp_server, tool_name=tool_name)
190
+
191
+ def add_tool_execution_step(self, tool_name: str, mcp_server: str,
192
+ input_data: str, parent_step: Optional[str] = None) -> str:
193
+ """Specialized method for tool execution steps."""
194
+ content = f"Executing {tool_name} on {mcp_server}"
195
+ return self.add_step("tool_execution", content,
196
+ parent_step=parent_step,
197
+ mcp_server=mcp_server,
198
+ tool_name=tool_name,
199
+ details={"input_preview": input_data[:50] + "..." if len(input_data) > 50 else input_data})
200
+
201
+ def generate_graph(self) -> Any:
202
+ if nx is None:
203
+ return None
204
+
205
+ G = nx.DiGraph()
206
+
207
+ # Enhanced color mapping with server-specific colors
208
+ color_map = {
209
+ 'input': '#e3f2fd', # Light blue
210
+ 'agent_init': '#f3e5f5', # Light purple
211
+ 'agent_process': '#e8f5e8', # Light green
212
+ 'comm_agent_to_mcp': '#fff3e0', # Light orange
213
+ 'comm_mcp_to_server': '#ffebee', # Light red
214
+ 'comm_server_to_mcp': '#e0f2f1', # Light teal
215
+ 'comm_mcp_to_agent': '#f9fbe7', # Light lime
216
+ 'llm_call': '#fce4ec', # Light pink
217
+ 'tool_execution': '#e1f5fe', # Light cyan
218
+ 'response': '#f1f8e9', # Light green
219
+ 'error': '#ffcdd2' # Light red
220
+ }
221
+
222
+ # Add nodes with enhanced labeling
223
+ for step in self.steps:
224
+ color = color_map.get(step.step_type, '#f5f5f5')
225
+
226
+ # Create enhanced label with MCP server and tool info
227
+ duration_str = f" ({step.duration:.2f}s)" if step.duration else ""
228
+
229
+ # Build comprehensive label
230
+ label_parts = []
231
+
232
+ # Add step type
233
+ step_display = step.step_type.replace('_', ' ').title()
234
+ label_parts.append(step_display)
235
+
236
+ # Add MCP server info
237
+ if step.mcp_server:
238
+ label_parts.append(f"📡 {step.mcp_server}")
239
+
240
+ # Add tool name prominently
241
+ if step.tool_name:
242
+ label_parts.append(f"🔧 {step.tool_name}")
243
+
244
+ # Add content preview (shortened to make room for server/tool info)
245
+ content_preview = step.content[:20] + "..." if len(step.content) > 20 else step.content
246
+ if not step.tool_name or step.tool_name.lower() not in content_preview.lower():
247
+ label_parts.append(content_preview)
248
+
249
+ # Add duration
250
+ if duration_str:
251
+ label_parts.append(duration_str)
252
+
253
+ label = "\n".join(label_parts)
254
+
255
+ G.add_node(step.step_id,
256
+ label=label,
257
+ color=color,
258
+ step_type=step.step_type,
259
+ status=step.status,
260
+ mcp_server=step.mcp_server,
261
+ tool_name=step.tool_name)
262
+
263
+ # Add edges based on parent relationships and chronological order
264
+ for i, step in enumerate(self.steps):
265
+ if step.parent_step:
266
+ # Add edge from parent step
267
+ G.add_edge(step.parent_step, step.step_id, edge_type='parent')
268
+ elif i > 0:
269
+ # Add chronological edge to previous step
270
+ G.add_edge(self.steps[i-1].step_id, step.step_id, edge_type='sequence')
271
+
272
+ return G
273
+
274
+ def create_matplotlib_visualization(self) -> str:
275
+ if nx is None or plt is None:
276
+ return ""
277
+
278
+ G = self.generate_graph()
279
+ if not G or len(G.nodes()) == 0:
280
+ return ""
281
+
282
+ # Create larger figure to accommodate enhanced labels
283
+ fig, ax = plt.subplots(figsize=(20, 12))
284
+
285
+ # Use hierarchical layout if possible
286
+ try:
287
+ pos = nx.spring_layout(G, k=3, iterations=150, seed=42)
288
+ except:
289
+ pos = nx.circular_layout(G)
290
+
291
+ # Prepare node visualization with server-aware coloring
292
+ node_colors = []
293
+ node_labels = {}
294
+ node_sizes = []
295
+
296
+ for node_id in G.nodes():
297
+ step = next(s for s in self.steps if s.step_id == node_id)
298
+
299
+ # Enhanced color coding based on status and server
300
+ if step.status == 'error':
301
+ color = '#ff5252'
302
+ elif step.status == 'completed':
303
+ # Server-specific color coding
304
+ if step.mcp_server == "Semantic Server":
305
+ base_color = '#4caf50' # Green for semantic
306
+ elif step.mcp_server == "Token Counter":
307
+ base_color = '#2196f3' # Blue for token counting
308
+ elif step.mcp_server == "Sentiment Analysis":
309
+ base_color = '#ff9800' # Orange for sentiment
310
+ elif step.mcp_server == "Health Monitor":
311
+ base_color = '#9c27b0' # Purple for health
312
+ else:
313
+ # Default colors by step type
314
+ base_colors = {
315
+ 'input': '#4caf50',
316
+ 'agent_init': '#9c27b0',
317
+ 'agent_process': '#2e7d32',
318
+ 'comm_agent_to_mcp': '#ff9800',
319
+ 'comm_mcp_to_server': '#f44336',
320
+ 'comm_server_to_mcp': '#009688',
321
+ 'comm_mcp_to_agent': '#8bc34a',
322
+ 'llm_call': '#e91e63',
323
+ 'tool_execution': '#03a9f4',
324
+ 'response': '#4caf50'
325
+ }
326
+ base_color = base_colors.get(step.step_type, '#607d8b')
327
+ color = base_color
328
+ else:
329
+ color = '#bdbdbd'
330
+
331
+ node_colors.append(color)
332
+
333
+ # Create enhanced node labels
334
+ label_parts = []
335
+
336
+ # Step type
337
+ step_display = step.step_type.replace('_', ' ').title()
338
+ label_parts.append(f"**{step_display}**")
339
+
340
+ # MCP Server (prominent)
341
+ if step.mcp_server:
342
+ label_parts.append(f"📡 {step.mcp_server}")
343
+
344
+ # Tool name (most prominent)
345
+ if step.tool_name:
346
+ label_parts.append(f"🔧 **{step.tool_name}**")
347
+
348
+ # Duration
349
+ if step.duration:
350
+ label_parts.append(f"⏱️ {step.duration:.2f}s")
351
+
352
+ node_labels[node_id] = "\n".join(label_parts)
353
+
354
+ # Size based on importance - larger for tool executions
355
+ if step.step_type == 'tool_execution':
356
+ node_sizes.append(5000)
357
+ elif step.step_type in ['input', 'response']:
358
+ node_sizes.append(4000)
359
+ elif 'comm_' in step.step_type:
360
+ node_sizes.append(2500)
361
+ else:
362
+ node_sizes.append(3000)
363
+
364
+ # Draw the graph
365
+ nx.draw(G, pos,
366
+ node_color=node_colors,
367
+ node_size=node_sizes,
368
+ font_size=9,
369
+ font_weight='bold',
370
+ arrows=True,
371
+ arrowsize=20,
372
+ edge_color='#666666',
373
+ alpha=0.9,
374
+ ax=ax,
375
+ arrowstyle='->')
376
+
377
+ # Draw enhanced labels
378
+ nx.draw_networkx_labels(G, pos, node_labels, font_size=8, ax=ax)
379
+
380
+ # Add title and formatting
381
+ ax.set_title("MCP Agent Workflow: Server & Tool Execution Flow",
382
+ fontsize=20, pad=25, fontweight='bold')
383
+ ax.axis('off')
384
+
385
+ # Enhanced legend with server info
386
+ legend_elements = [
387
+ plt.Rectangle((0,0),1,1, facecolor='#4caf50', label='Semantic Server'),
388
+ plt.Rectangle((0,0),1,1, facecolor='#2196f3', label='Token Counter Server'),
389
+ plt.Rectangle((0,0),1,1, facecolor='#ff9800', label='Sentiment Analysis Server'),
390
+ plt.Rectangle((0,0),1,1, facecolor='#9c27b0', label='Health Monitor Server'),
391
+ plt.Rectangle((0,0),1,1, facecolor='#e91e63', label='LLM Calls'),
392
+ plt.Rectangle((0,0),1,1, facecolor='#607d8b', label='Agent Processing'),
393
+ ]
394
+ ax.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(0, 1))
395
+
396
+ fig.set_constrained_layout(True)
397
+
398
+ # Save to temporary file
399
+ import tempfile
400
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
401
+ plt.savefig(temp_file.name, format='png', dpi=300, bbox_inches='tight')
402
+ plt.close(fig)
403
+
404
+ return temp_file.name
405
+
406
+ def get_workflow_summary(self) -> Dict[str, Any]:
407
+ total_duration = time.time() - self.start_time
408
+
409
+ # Count steps by type and server
410
+ step_counts = {}
411
+ server_usage = {}
412
+ tool_usage = {}
413
+ communication_steps = []
414
+ processing_steps = []
415
+
416
+ for step in self.steps:
417
+ step_counts[step.step_type] = step_counts.get(step.step_type, 0) + 1
418
+
419
+ # Track server usage
420
+ if step.mcp_server:
421
+ server_usage[step.mcp_server] = server_usage.get(step.mcp_server, 0) + 1
422
+
423
+ # Track tool usage
424
+ if step.tool_name:
425
+ tool_usage[step.tool_name] = tool_usage.get(step.tool_name, 0) + 1
426
+
427
+ if 'comm_' in step.step_type:
428
+ communication_steps.append({
429
+ 'step_id': step.step_id,
430
+ 'from': step.details.get('from', 'unknown') if step.details else 'unknown',
431
+ 'to': step.details.get('to', 'unknown') if step.details else 'unknown',
432
+ 'message_type': step.details.get('message_type', 'unknown') if step.details else 'unknown',
433
+ 'mcp_server': step.mcp_server,
434
+ 'tool_name': step.tool_name,
435
+ 'duration': step.duration,
436
+ 'status': step.status
437
+ })
438
+ else:
439
+ processing_steps.append({
440
+ 'step_id': step.step_id,
441
+ 'type': step.step_type,
442
+ 'content': step.content[:50] + "..." if len(step.content) > 50 else step.content,
443
+ 'mcp_server': step.mcp_server,
444
+ 'tool_name': step.tool_name,
445
+ 'duration': step.duration,
446
+ 'status': step.status
447
+ })
448
+
449
+ # Calculate timing statistics
450
+ completed_steps = [s for s in self.steps if s.duration is not None]
451
+ avg_duration = (sum(s.duration or 0 for s in completed_steps) / len(completed_steps)) if completed_steps else 0
452
+
453
+ return {
454
+ 'total_steps': len(self.steps),
455
+ 'total_duration': round(total_duration, 3),
456
+ 'average_step_duration': round(avg_duration, 3),
457
+ 'step_counts': step_counts,
458
+ 'server_usage': server_usage, # New: server usage stats
459
+ 'tool_usage': tool_usage, # New: tool usage stats
460
+ 'communication_flow': communication_steps,
461
+ 'processing_steps': processing_steps,
462
+ 'status': 'completed' if all(s.status in ['completed', 'error'] for s in self.steps) else 'running',
463
+ 'error_count': sum(1 for s in self.steps if s.status == 'error'),
464
+ 'success_rate': round((sum(1 for s in self.steps if s.status == 'completed') / len(self.steps)) * 100, 1) if self.steps else 0,
465
+ 'detailed_steps': [asdict(s) for s in self.steps]
466
+ }
467
+
468
+ # Global instance
469
+ workflow_visualizer = EnhancedWorkflowVisualizer()
470
+
471
+ # Enhanced helper functions
472
+ def track_workflow_step(step_type: str, content: str, metadata: Optional[Dict[str, Any]] = None,
473
+ parent_step: Optional[str] = None, mcp_server: Optional[str] = None,
474
+ tool_name: Optional[str] = None) -> str:
475
+ return workflow_visualizer.add_step(step_type, content, metadata, parent_step,
476
+ mcp_server=mcp_server, tool_name=tool_name)
477
+
478
+ def track_communication(from_component: str, to_component: str, message_type: str,
479
+ content: str, parent_step: Optional[str] = None) -> str:
480
+ return workflow_visualizer.add_communication_step(from_component, to_component,
481
+ message_type, content, parent_step)
482
+
483
+ def track_tool_execution(tool_name: str, mcp_server: str, input_data: str,
484
+ parent_step: Optional[str] = None) -> str:
485
+ """New helper for tracking tool executions with clear server/tool info."""
486
+ return workflow_visualizer.add_tool_execution_step(tool_name, mcp_server, input_data, parent_step)
487
+
488
+ def complete_workflow_step(step_id: str, status: str = 'completed',
489
+ metadata: Optional[Dict[str, Any]] = None,
490
+ details: Optional[Dict[str, Any]] = None):
491
+ workflow_visualizer.complete_step(step_id, status, metadata, details)
492
+
493
+ def get_workflow_visualization() -> str:
494
+ return workflow_visualizer.create_matplotlib_visualization()
495
+
496
+ def get_workflow_summary() -> Dict[str, Any]:
497
+ return workflow_visualizer.get_workflow_summary()
498
+
499
+ def reset_workflow():
500
+ global workflow_visualizer
501
+ workflow_visualizer = EnhancedWorkflowVisualizer()