Spaces:
Runtime error
Runtime error
| """Advanced strategy coordination patterns for the unified reasoning engine.""" | |
| import logging | |
| from typing import Dict, Any, List, Optional, Set, Union, Type, Callable | |
| import json | |
| from dataclasses import dataclass, field | |
| from enum import Enum | |
| from datetime import datetime | |
| import asyncio | |
| from collections import defaultdict | |
| from .base import ReasoningStrategy | |
| from .unified_engine import StrategyType, StrategyResult, UnifiedResult | |
class CoordinationPattern(Enum):
    """How a group of reasoning strategies is orchestrated for one query.

    Each member's value is the string reported in a coordination result's
    ``"pattern"`` field.
    """

    # Run strategies one after another, each seeing the previous result.
    PIPELINE = "pipeline"
    # Run all active strategies concurrently, then synthesise.
    PARALLEL = "parallel"
    # Run strategies level by level, concurrently within a level.
    HIERARCHICAL = "hierarchical"
    # Repeat rounds of all active strategies, feeding back generated feedback.
    FEEDBACK = "feedback"
    # Pick the next strategy one at a time based on results so far.
    ADAPTIVE = "adaptive"
    # Run every active strategy, then combine the members' results.
    ENSEMBLE = "ensemble"
class CoordinationPhase(Enum):
    """Lifecycle phase recorded on a ``CoordinationState``.

    NOTE(review): the members are not switched on by the coordinator's
    pattern methods; confirm where phase transitions are made.
    """

    INITIALIZATION = "initialization"
    EXECUTION = "execution"
    SYNCHRONIZATION = "synchronization"
    ADAPTATION = "adaptation"
    COMPLETION = "completion"
@dataclass
class CoordinationState:
    """State of one strategy-coordination run.

    Declared as a dataclass so that it gets a keyword/positional ``__init__``
    and so that ``metadata`` receives a fresh dict per instance through
    ``field(default_factory=dict)``; without the decorator that ``field(...)``
    call is just an inert class attribute.
    """

    # Coordination pattern this run uses.
    pattern: CoordinationPattern
    # Strategy type -> flag; the coordinator only runs strategies whose flag is truthy.
    active_strategies: Dict[StrategyType, bool]
    # Current lifecycle phase of the run.
    phase: CoordinationPhase
    # NOTE(review): not read by the coordinator's pattern methods; presumably
    # data shared between strategies - confirm where it is populated.
    shared_context: Dict[str, Any]
    # NOTE(review): not read by the coordinator's pattern methods; confirm usage.
    synchronization_points: List[str]
    # NOTE(review): not read by the coordinator's pattern methods; confirm usage.
    adaptation_history: List[Dict[str, Any]]
    # Free-form extra data; each instance gets its own empty dict by default.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class StrategyInteraction:
    """One recorded hand-off between two strategies.

    Declared as a dataclass so that the keyword construction used by the
    coordinator (``StrategyInteraction(source=..., target=..., ...)``) works;
    without the decorator the class has no such ``__init__``.
    """

    # Strategy that produced the data.
    source: StrategyType
    # Receiving strategy, or None when there is no next strategy
    # (e.g. the last step of a pipeline).
    target: Optional[StrategyType]
    # Free-form label for the kind of hand-off, e.g. "pipeline_transfer".
    interaction_type: str
    # Payload passed along with the interaction.
    data: Dict[str, Any]
    # Creation time from datetime.now() (naive, local time).
    timestamp: datetime = field(default_factory=datetime.now)
class StrategyCoordinator:
    """
    Advanced strategy coordinator that:
    1. Manages strategy interactions
    2. Implements coordination patterns
    3. Handles state synchronization
    4. Adapts coordination dynamically
    5. Optimizes strategy combinations

    Each ``_coordinate_<pattern>`` method runs the registered strategies
    according to one ``CoordinationPattern``. It returns a dict with at least
    ``success``, ``results``, ``pattern`` and ``metrics`` keys.
    ``metrics["success_rate"]`` feeds the per-pattern weights that bias
    automatic pattern selection.

    NOTE(review): ``_initialize_coordination``, ``_parse_pattern_selection``,
    ``_determine_pipeline_order``, ``_synthesize_parallel_results``,
    ``_build_strategy_hierarchy``, ``_generate_feedback``,
    ``_should_terminate_feedback``, ``_calculate_feedback_impact``,
    ``_select_next_strategy``, ``_adapt_strategy_selection``,
    ``_calculate_adaptation_impact`` and ``_combine_ensemble_results`` are
    called below and must be provided elsewhere on this class - confirm.
    """

    def __init__(self,
                 strategies: Dict[StrategyType, ReasoningStrategy],
                 learning_rate: float = 0.1):
        """Create a coordinator.

        Args:
            strategies: Mapping from strategy type to the strategy whose
                ``reason(query, context)`` coroutine is awaited.
            learning_rate: Smoothing factor of the exponential moving average
                that updates pattern weights (0 keeps weights fixed; 1 makes
                each weight equal the latest success rate).
        """
        self.strategies = strategies
        self.learning_rate = learning_rate

        # Coordination state
        self.states: Dict[str, CoordinationState] = {}
        self.interactions: List[StrategyInteraction] = []

        # Pattern performance: raw success-rate history per pattern, and the
        # smoothed weight used when a pattern is selected automatically.
        self.pattern_performance: Dict[CoordinationPattern, List[float]] = defaultdict(list)
        self.pattern_weights: Dict[CoordinationPattern, float] = {
            pattern: 1.0 for pattern in CoordinationPattern
        }

    async def coordinate(self,
                         query: str,
                         context: Dict[str, Any],
                         pattern: Optional[CoordinationPattern] = None) -> Dict[str, Any]:
        """Coordinate strategy execution using the given or an auto-selected pattern.

        Args:
            query: Question passed to every strategy's ``reason`` call.
            context: Shared context. It must contain ``"groq_api"`` when
                ``pattern`` is omitted, because pattern selection queries it.
            pattern: A ``CoordinationPattern`` member or its string value.
                When falsy, a pattern is selected automatically.

        Returns:
            The pattern-specific result dict. Any exception is logged and
            returned as ``{"success": False, "error": ..., "pattern": ...}``
            instead of propagating.
        """
        try:
            # Select pattern if not specified
            if not pattern:
                pattern = await self._select_pattern(query, context)
            # Normalise string values (e.g. "pipeline") to enum members;
            # enum members pass through unchanged.
            pattern = CoordinationPattern(pattern)

            # Initialize coordination
            state = await self._initialize_coordination(pattern, context)

            handlers = {
                CoordinationPattern.PIPELINE: self._coordinate_pipeline,
                CoordinationPattern.PARALLEL: self._coordinate_parallel,
                CoordinationPattern.HIERARCHICAL: self._coordinate_hierarchical,
                CoordinationPattern.FEEDBACK: self._coordinate_feedback,
                CoordinationPattern.ADAPTIVE: self._coordinate_adaptive,
                CoordinationPattern.ENSEMBLE: self._coordinate_ensemble,
            }
            handler = handlers.get(pattern)
            if handler is None:
                raise ValueError(f"Unsupported coordination pattern: {pattern}")
            result = await handler(query, context, state)

            # Update performance metrics
            self._update_pattern_performance(pattern, result)

            return result

        except Exception as e:
            logging.exception(f"Error in strategy coordination: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                # `pattern` may still be None or an unconverted string here.
                "pattern": pattern.value if isinstance(pattern, CoordinationPattern) else pattern
            }

    async def _select_pattern(self, query: str, context: Dict[str, Any]) -> CoordinationPattern:
        """Ask the LLM for a pattern and bias its scores by past performance.

        The client is read from ``context["groq_api"]``. Its ``predict``
        response must contain an ``"answer"`` key. Each pattern's parsed score
        is multiplied by that pattern's learned weight, and the highest product
        wins. Ties, including all-zero scores, resolve to the earliest pattern
        in enum order.
        """
        # The API client object in ``context`` is not JSON-serialisable, so
        # json.dumps(context) would raise TypeError. Leave the client out of the
        # prompt and stringify any other value json cannot encode.
        prompt_context = {key: value for key, value in context.items() if key != "groq_api"}
        prompt = f"""
        Select coordination pattern:
        Query: {query}
        Context: {json.dumps(prompt_context, default=str)}
        Consider:
        1. Task complexity and type
        2. Strategy dependencies
        3. Resource constraints
        4. Performance history
        5. Adaptation needs
        Format as:
        [Selection]
        Pattern: ...
        Rationale: ...
        Confidence: ...
        """

        response = await context["groq_api"].predict(prompt)
        selection = self._parse_pattern_selection(response["answer"])

        # Weight by performance history
        weighted_patterns = {
            pattern: self.pattern_weights[pattern] * selection.get(pattern.value, 0.0)
            for pattern in CoordinationPattern
        }
        return max(weighted_patterns, key=weighted_patterns.get)

    async def _coordinate_pipeline(self,
                                   query: str,
                                   context: Dict[str, Any],
                                   state: CoordinationState) -> Dict[str, Any]:
        """Run strategies one after another, feeding each the previous output.

        The order comes from ``_determine_pipeline_order``. After each step,
        the running context receives ``previous_result`` (that step's raw
        output dict) and ``pipeline_position`` (that step's index in the
        order). A strategy that raises is logged and skipped, and the pipeline
        continues with the next strategy.
        """
        results: List[StrategyResult] = []
        current_context = context.copy()

        # Determine optimal order; materialise it so it can be indexed.
        strategy_order = list(await self._determine_pipeline_order(query, context))

        for position, strategy_type in enumerate(strategy_order):
            try:
                strategy = self.strategies[strategy_type]
                result = await strategy.reason(query, current_context)

                # Expose this step's output to the next step.
                current_context.update({
                    "previous_result": result,
                    "pipeline_position": position
                })

                results.append(self._to_strategy_result(strategy_type, result))

                # Use the step's own position rather than len(results), so a
                # failed earlier step does not shift the recorded target.
                next_position = position + 1
                self._record_interaction(
                    source=strategy_type,
                    target=strategy_order[next_position] if next_position < len(strategy_order) else None,
                    interaction_type="pipeline_transfer",
                    data={"result": result}
                )

            except Exception as e:
                logging.error(f"Error in pipeline strategy {strategy_type}: {str(e)}")

        return {
            "success": any(r.success for r in results),
            "results": results,
            "pattern": CoordinationPattern.PIPELINE.value,
            "metrics": {
                "total_steps": len(results),
                "success_rate": self._success_rate(results)
            }
        }

    async def _coordinate_parallel(self,
                                   query: str,
                                   context: Dict[str, Any],
                                   state: CoordinationState) -> Dict[str, Any]:
        """Run all active strategies concurrently and synthesise their results.

        Every active strategy receives the same ``context``. A strategy that
        raises is logged and represented by a failed ``StrategyResult``, so
        the synthesis step sees one entry per active strategy.
        """
        tasks = [self._reason_safely(strategy_type, query, context, "parallel")
                 for strategy_type in self._active_strategy_types(state)]
        results = await asyncio.gather(*tasks)

        # Synthesize results
        synthesis = await self._synthesize_parallel_results(results, context)

        return {
            "success": synthesis.get("success", False),
            "results": results,
            "synthesis": synthesis,
            "pattern": CoordinationPattern.PARALLEL.value,
            "metrics": {
                "total_strategies": len(results),
                "success_rate": self._success_rate(results)
            }
        }

    async def _coordinate_hierarchical(self,
                                       query: str,
                                       context: Dict[str, Any],
                                       state: CoordinationState) -> Dict[str, Any]:
        """Run strategies level by level, concurrently within each level.

        ``_build_strategy_hierarchy`` is assumed to return a sized sequence of
        per-level strategy-type lists (confirm). Only strategies flagged active
        in ``state`` run. After each level, the running context receives
        ``previous_level_results`` and ``hierarchy_level``. A strategy that
        raises is logged and recorded as a failed ``StrategyResult`` instead of
        aborting the whole run.
        """
        hierarchy = await self._build_strategy_hierarchy(query, context)
        results: Dict[int, List[StrategyResult]] = {}
        current_context = context.copy()

        for level, level_strategies in enumerate(hierarchy):
            runnable = [strategy_type for strategy_type in level_strategies
                        if state.active_strategies.get(strategy_type, False)]
            # Each result is built with the strategy type that produced it, so
            # inactive strategies in the level cannot shift the attribution.
            results[level] = list(await asyncio.gather(*(
                self._reason_safely(strategy_type, query, current_context, "hierarchical")
                for strategy_type in runnable
            )))

            # Update context for next level
            current_context.update({
                "previous_level_results": results[level],
                "hierarchy_level": level
            })

        all_results = [r for level_results in results.values() for r in level_results]
        return {
            "success": any(r.success for r in all_results),
            "results": results,
            "hierarchy": hierarchy,
            "pattern": CoordinationPattern.HIERARCHICAL.value,
            "metrics": {
                "total_levels": len(hierarchy),
                "success_rate": self._success_rate(all_results),
                "level_success_rates": {
                    level: self._success_rate(level_results)
                    for level, level_results in results.items() if level_results
                }
            }
        }

    async def _coordinate_feedback(self,
                                   query: str,
                                   context: Dict[str, Any],
                                   state: CoordinationState) -> Dict[str, Any]:
        """Run all active strategies in rounds, feeding back generated feedback.

        The method runs at most 5 rounds and stops early when
        ``_should_terminate_feedback`` says so. Between rounds, the running
        context receives ``previous_results``, ``feedback`` and ``iteration``.
        Strategies that raise are logged and omitted from that round.
        ``metrics["success_rate"]`` reflects the final round.
        """
        results: List[List[StrategyResult]] = []
        feedback_history = []
        current_context = context.copy()
        max_iterations = 5  # Prevent infinite loops
        iteration = 0

        while iteration < max_iterations:
            iteration += 1

            iteration_results = []
            for strategy_type in self._active_strategy_types(state):
                try:
                    strategy = self.strategies[strategy_type]
                    result = await strategy.reason(query, current_context)
                    iteration_results.append(self._to_strategy_result(strategy_type, result))
                except Exception as e:
                    logging.error(f"Error in feedback strategy {strategy_type}: {str(e)}")

            results.append(iteration_results)

            # Generate feedback
            feedback = await self._generate_feedback(iteration_results, current_context)
            feedback_history.append(feedback)

            # Check termination condition
            if self._should_terminate_feedback(feedback, iteration_results):
                break

            # Update context with feedback
            current_context.update({
                "previous_results": iteration_results,
                "feedback": feedback,
                "iteration": iteration
            })

        final_round = results[-1] if results else []
        return {
            "success": any(any(r.success for r in iteration_results)
                           for iteration_results in results),
            "results": results,
            "feedback_history": feedback_history,
            "pattern": CoordinationPattern.FEEDBACK.value,
            "metrics": {
                "total_iterations": iteration,
                "success_rate": self._success_rate(final_round),
                "feedback_impact": self._calculate_feedback_impact(results, feedback_history)
            }
        }

    async def _coordinate_adaptive(self,
                                   query: str,
                                   context: Dict[str, Any],
                                   state: CoordinationState) -> Dict[str, Any]:
        """Choose and run strategies one at a time based on results so far.

        ``_select_next_strategy`` picks each strategy. A falsy pick ends the
        run. After each successful step, ``_adapt_strategy_selection`` is
        consulted, and the running context receives ``previous_results``,
        ``adaptations`` and ``current_strategy``. Strategies that raise are
        logged and omitted.
        """
        results: List[StrategyResult] = []
        adaptations = []
        current_context = context.copy()
        attempts = 0

        # Bound by attempts, not successes: a strategy that raises never
        # reaches `results`, so a success-based bound could spin forever if
        # the selector kept choosing it.
        while attempts < len(state.active_strategies):
            attempts += 1

            # Select next strategy
            next_strategy = await self._select_next_strategy(
                results, state.active_strategies, current_context)
            if not next_strategy:
                break

            try:
                strategy = self.strategies[next_strategy]
                result = await strategy.reason(query, current_context)
                strategy_result = self._to_strategy_result(next_strategy, result)
                results.append(strategy_result)

                # Adapt strategy selection
                adaptation = await self._adapt_strategy_selection(
                    strategy_result, current_context)
                adaptations.append(adaptation)

                # Update context
                current_context.update({
                    "previous_results": results,
                    "adaptations": adaptations,
                    "current_strategy": next_strategy
                })

            except Exception as e:
                logging.error(f"Error in adaptive strategy {next_strategy}: {str(e)}")

        return {
            "success": any(r.success for r in results),
            "results": results,
            "adaptations": adaptations,
            "pattern": CoordinationPattern.ADAPTIVE.value,
            "metrics": {
                "total_strategies": len(results),
                "success_rate": self._success_rate(results),
                "adaptation_impact": self._calculate_adaptation_impact(results, adaptations)
            }
        }

    async def _coordinate_ensemble(self,
                                   query: str,
                                   context: Dict[str, Any],
                                   state: CoordinationState) -> Dict[str, Any]:
        """Run every active strategy in turn, then combine them as an ensemble.

        Strategies run sequentially on the same ``context``. Those that raise
        are logged and left out of the ensemble. Overall success and
        confidence come from ``_combine_ensemble_results``.
        """
        results: List[StrategyResult] = []
        for strategy_type in self._active_strategy_types(state):
            try:
                strategy = self.strategies[strategy_type]
                result = await strategy.reason(query, context)
                results.append(self._to_strategy_result(strategy_type, result))
            except Exception as e:
                logging.error(f"Error in ensemble strategy {strategy_type}: {str(e)}")

        # Combine results using ensemble methods
        ensemble_result = await self._combine_ensemble_results(results, context)

        return {
            "success": ensemble_result.get("success", False),
            "results": results,
            "ensemble_result": ensemble_result,
            "pattern": CoordinationPattern.ENSEMBLE.value,
            "metrics": {
                "total_members": len(results),
                "success_rate": self._success_rate(results),
                "ensemble_confidence": ensemble_result.get("confidence", 0.0)
            }
        }

    async def _reason_safely(self,
                             strategy_type: StrategyType,
                             query: str,
                             context: Dict[str, Any],
                             label: str) -> StrategyResult:
        """Run one strategy, turning any exception into a failed ``StrategyResult``.

        ``label`` names the calling pattern in the logged error message.
        """
        try:
            strategy = self.strategies[strategy_type]
            result = await strategy.reason(query, context)
            return self._to_strategy_result(strategy_type, result)
        except Exception as e:
            logging.error(f"Error in {label} strategy {strategy_type}: {str(e)}")
            return StrategyResult(
                strategy_type=strategy_type,
                success=False,
                answer=None,
                confidence=0.0,
                reasoning_trace=[{"error": str(e)}],
                metadata={},
                performance_metrics={}
            )

    @staticmethod
    def _to_strategy_result(strategy_type: StrategyType, result: Dict[str, Any]) -> StrategyResult:
        """Wrap a strategy's raw ``reason`` output dict, defaulting missing keys."""
        return StrategyResult(
            strategy_type=strategy_type,
            success=result.get("success", False),
            answer=result.get("answer"),
            confidence=result.get("confidence", 0.0),
            reasoning_trace=result.get("reasoning_trace", []),
            metadata=result.get("metadata", {}),
            performance_metrics=result.get("performance_metrics", {})
        )

    @staticmethod
    def _active_strategy_types(state: CoordinationState) -> List[StrategyType]:
        """Return the strategy types whose flag in ``state`` is truthy, in dict order."""
        return [strategy_type for strategy_type, enabled in state.active_strategies.items()
                if enabled]

    @staticmethod
    def _success_rate(results: List[StrategyResult]) -> float:
        """Return the fraction of ``results`` that succeeded (0 for an empty list)."""
        if not results:
            return 0
        return sum(1 for r in results if r.success) / len(results)

    def _record_interaction(self,
                            source: StrategyType,
                            target: Optional[StrategyType],
                            interaction_type: str,
                            data: Dict[str, Any]):
        """Append a ``StrategyInteraction`` to ``self.interactions``."""
        self.interactions.append(StrategyInteraction(
            source=source,
            target=target,
            interaction_type=interaction_type,
            data=data
        ))

    def _update_pattern_performance(self, pattern: CoordinationPattern, result: Dict[str, Any]):
        """Record ``result``'s success rate for ``pattern`` and update its weight.

        The weight is an exponential moving average with smoothing factor
        ``self.learning_rate``. If ``result`` has no
        ``metrics["success_rate"]``, its overall ``success`` flag (1.0 or 0.0)
        is used instead, so such patterns are not scored as permanent failures.
        """
        metrics = result.get("metrics") or {}
        success_rate = metrics.get("success_rate")
        if success_rate is None:
            success_rate = 1.0 if result.get("success") else 0.0
        self.pattern_performance[pattern].append(success_rate)

        # Update weights using exponential moving average
        current_weight = self.pattern_weights[pattern]
        self.pattern_weights[pattern] = (
            (1 - self.learning_rate) * current_weight +
            self.learning_rate * success_rate
        )

    def get_performance_metrics(self) -> Dict[str, Any]:
        """Summarise pattern weights, average success rates and interaction counts.

        ``interaction_counts`` maps each interaction type to the number of
        recorded interactions of that type. It is a ``defaultdict(int)``, so
        unseen types read as 0. ``active_patterns`` lists patterns whose
        current weight exceeds 0.5.
        """
        interaction_counts: Dict[str, int] = defaultdict(int)
        for interaction in self.interactions:
            interaction_counts[interaction.interaction_type] += 1

        return {
            "pattern_weights": dict(self.pattern_weights),
            "average_performance": {
                pattern.value: sum(scores) / len(scores) if scores else 0
                for pattern, scores in self.pattern_performance.items()
            },
            "interaction_counts": interaction_counts,
            "active_patterns": [
                pattern.value for pattern, weight in self.pattern_weights.items()
                if weight > 0.5
            ]
        }