initial commit

Files changed:
- .gitignore +0 -0
- README.md +73 -13
- agents/__init__.py +0 -0
- agents/assembler_agent.py +204 -0
- agents/base_agent.py +99 -0
- agents/context_agent.py +246 -0
- agents/image_agent.py +1065 -0
- agents/user_input_agent.py +321 -0
- app.py +54 -0
- config/__init__.py +0 -0
- config/settings.py +11 -0
- interface/__init__.py +49 -0
- interface/app.py +153 -0
- interface/display.py +50 -0
- interface/handlers.py +129 -0
- interface/utils.py +89 -0
- models/__init__.py +0 -0
- models/data_models.py +44 -0
- models/model_config.py +9 -0
- requirements.txt +31 -0
- utils/resource_manager.py +54 -0
.gitignore
ADDED
File without changes
README.md
CHANGED
@@ -1,13 +1,73 @@

Removed (old placeholder front matter):
- ---
- title:
- emoji:
- colorFrom:
- colorTo:
- sdk: gradio
- sdk_version: 5.16.0
- app_file: app.py
- pinned: false
Added (new README content):

---
title: Sustainable Content Moderation
emoji: 🌍
colorFrom: green
colorTo: blue
sdk: gradio
sdk_version: 5.16.0
app_file: app.py
pinned: false
---

# Multi-Model Multi-Agent Analysis System

A multi-agent system for analyzing pump equipment images using state-of-the-art AI models.

## Overview
This application uses multiple AI agents to:
- Process user queries about pump equipment
- Analyze equipment images
- Search for relevant technical context
- Generate comprehensive analysis reports

A rough sketch of how these agents fit together is shown below.
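The agent modules in this commit all follow the `reason()` / `execute()` contract defined in `agents/base_agent.py`. The actual wiring lives in `app.py` and the `interface/` package, which are not reproduced in this excerpt, so the snippet below is only a minimal sketch of how the pieces could be chained. It assumes the input dataclasses (`ContextInput`, `ImageAnalysisInput`, `AssemblerInput` from `models/data_models.py`) accept exactly the fields the agents read from them; the real constructors may differ.

```python
# Hypothetical orchestration sketch; the real wiring is in app.py / interface/handlers.py.
from agents.context_agent import ContextLearnAgent
from agents.image_agent import ImageAnalyzerAgent
from agents.assembler_agent import AssemblerAgent
from models.data_models import ContextInput, ImageAnalysisInput, AssemblerInput

def run_pipeline(query, images, constraints, top_k=3, report_format="summary"):
    # Each agent first plans (reason) and then acts (execute), per BaseAgent.
    context_agent = ContextLearnAgent()
    ctx_in = ContextInput(processed_query=query, constraints=constraints)
    context_agent.reason(ctx_in)
    context_results = context_agent.execute(ctx_in)

    image_agent = ImageAnalyzerAgent()
    img_in = ImageAnalysisInput(images=images, context=context_results,
                                constraints=constraints, top_k=top_k)
    image_agent.reason(img_in)
    image_results = image_agent.execute(img_in)

    assembler = AssemblerAgent()
    asm_in = AssemblerInput(user_input_results={'query': query, 'constraints': constraints},
                            context_results=context_results,
                            image_results=image_results,
                            report_format=report_format)
    assembler.reason(asm_in)
    return assembler.execute(asm_in)
```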
## Usage
1. Enter your query about pump equipment
2. (Optional) Add any specific constraints
3. Upload equipment images (supported formats: JPG, JPEG, PNG)
4. Select the number of top results to show
5. Choose a report format (summary/detailed)
6. Click "Analyze" to process

## Technical Details
- Built with Gradio 5.16.0
- Uses a multi-model approach (see the sketch below):
  - Lightweight models for initial processing
  - Advanced models for detailed analysis
- Implements memory-efficient processing
- Supports batch image processing
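The lightweight-first, advanced-on-demand behaviour is implemented in `ImageAnalyzerAgent.execute` further down in this commit: each image is captioned and classified with the small model, and only escalated to the larger one when the reported confidence falls below the `ModelConfig` threshold. A condensed sketch of that logic follows; the `_get_image_caption` helper is defined later in `agents/image_agent.py`, beyond the portion shown in this excerpt.

```python
# Condensed illustration of the tiered strategy used in ImageAnalyzerAgent.execute.
def caption_with_escalation(agent, img_path):
    result = agent._get_image_caption(img_path, 'lightweight')  # cheap first pass
    if result['confidence'] < agent.models['lightweight_caption'].threshold:
        # Low confidence: pay for the larger BLIP-2 model on this image only.
        result = agent._get_image_caption(img_path, 'advanced')
    return result
```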
## Limitations
- Maximum image size: 5MB
- Maximum resolution: 2048x2048
- Maximum images per request: 10
- Query length limit: 500 characters

A sketch of how such limits can be enforced is shown below.
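This commit does not show where these limits are checked (most likely `agents/user_input_agent.py` or `interface/utils.py`, not included in this excerpt), so the following is a hypothetical validation sketch built only from the documented limits; the function name and structure are illustrative.

```python
import os
from typing import List
from PIL import Image

MAX_BYTES = 5 * 1024 * 1024   # 5 MB per image
MAX_SIDE = 2048               # 2048x2048 maximum resolution
MAX_IMAGES = 10               # maximum images per request
MAX_QUERY_CHARS = 500         # query length limit

def check_request(query: str, image_paths: List[str]) -> List[str]:
    """Return human-readable validation errors; empty list means the request is OK."""
    errors = []
    if len(query) > MAX_QUERY_CHARS:
        errors.append(f"Query too long ({len(query)} > {MAX_QUERY_CHARS} characters)")
    if len(image_paths) > MAX_IMAGES:
        errors.append(f"Too many images ({len(image_paths)} > {MAX_IMAGES})")
    for path in image_paths:
        if os.path.getsize(path) > MAX_BYTES:
            errors.append(f"{path}: file larger than 5MB")
        with Image.open(path) as img:
            if max(img.size) > MAX_SIDE:
                errors.append(f"{path}: resolution exceeds {MAX_SIDE}x{MAX_SIDE}")
    return errors
```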
## Models Used
- Image Captioning: BLIP and BLIP-2
- Image Classification: ResNet and ViT
- Text Processing: LaMini-Flan-T5

A minimal captioning example with the BLIP checkpoint is shown below.
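For reference, generating a caption with the lightweight BLIP checkpoint configured in `agents/image_agent.py` follows the standard `transformers` pattern; `pump.jpg` is a placeholder input path.

```python
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

# Checkpoint name taken from agents/image_agent.py in this commit.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

image = Image.open("pump.jpg").convert("RGB")   # placeholder input image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(out[0], skip_special_tokens=True))
```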
## Requirements
- Python 3.8+
- See requirements.txt for full dependencies

## Local Development
Install the dependencies, then start the app:
`pip install -r requirements.txt`
`python app.py`

## Deployment
This app is optimized for HuggingFace Spaces deployment.

## Error Handling
- Input validation with clear error messages
- Resource monitoring and automatic cleanup (see the cleanup sketch below)
- Graceful error recovery
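`ImageAnalyzerAgent` calls `self._cleanup_models()` and `self._cleanup_llm()` between steps, and the repo ships a `utils/resource_manager.py`, but neither body appears in this excerpt. The snippet below is only a minimal sketch of the usual cleanup pattern on Spaces hardware; the function name is hypothetical.

```python
import gc
import torch

def cleanup_memory() -> None:
    """Run between images/model swaps to keep memory use bounded (illustrative only)."""
    gc.collect()                   # reclaim unreferenced Python objects
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # release cached GPU memory back to the driver
```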
## Credits
- Built using HuggingFace's model hub
- Powered by a Gradio interface
- Uses the Wikipedia API for context gathering

## License
MIT License

## Support
For issues or questions, please open a GitHub issue.
agents/__init__.py
ADDED
File without changes
agents/assembler_agent.py
ADDED
@@ -0,0 +1,204 @@
from typing import Dict, List
from models.model_config import ModelConfig
from models.data_models import AssemblerInput
from .base_agent import BaseAgent
from datetime import datetime

class AssemblerAgent(BaseAgent):
    def __init__(self, name: str = "AssemblerAgent"):
        super().__init__(name)
        self.final_report: Dict = {}

    def reason(self, input_data: AssemblerInput) -> List[str]:
        """
        Plan how to assemble the final report from all agent results
        """
        thoughts = []

        try:
            # Analyze available inputs
            thoughts.append("Analyzing inputs from all agents:")
            thoughts.append(f"- User input processing results available: {bool(input_data.user_input_results)}")
            thoughts.append(f"- Context learning results available: {bool(input_data.context_results)}")
            thoughts.append(f"- Image analysis results available: {bool(input_data.image_results)}")

            # Plan report structure
            thoughts.append("\nPlanning report structure:")
            thoughts.append("1. User Query Summary")
            thoughts.append("2. Context Analysis")
            thoughts.append("3. Image Analysis Results")
            thoughts.append("4. Model Performance Metrics")
            thoughts.append("5. Final Recommendations")

            # Consider report format
            thoughts.append(f"\nReport Format: {input_data.report_format}")
            if input_data.report_format == "detailed":
                thoughts.append("- Will include full model decisions")
                thoughts.append("- Will include confidence scores")
                thoughts.append("- Will include processing statistics")
            else:
                thoughts.append("- Will provide condensed summary")
                thoughts.append("- Will focus on key findings")

            # Store thoughts in state
            self.state.thoughts.extend(thoughts)
            self.logger.info("Reasoning complete for report assembly")

            return thoughts

        except Exception as e:
            error_msg = f"Error during assembly reasoning: {str(e)}"
            self.state.errors.append(error_msg)
            self.logger.error(error_msg)
            return thoughts

    def execute(self, input_data: AssemblerInput) -> Dict:
        """
        Assemble final report from all agent results
        """
        try:
            if not self.validate(input_data):
                return {
                    'status': 'error',
                    'error': self.state.errors[-1]
                }

            report = {
                'summary': {
                    'user_query': {},
                    'context_analysis': {},
                    'image_analysis': {},
                    'recommendations': []
                },
                'details': {
                    'model_decisions': {},
                    'processing_stats': {},
                    'confidence_scores': {}
                },
                'metadata': {
                    'report_format': input_data.report_format,
                    'timestamp': datetime.now().isoformat()
                },
                'status': 'processing'
            }

            # Process user input results
            if input_data.user_input_results:
                report['summary']['user_query'] = {
                    'original_query': input_data.user_input_results.get('query', ''),
                    'constraints': input_data.user_input_results.get('constraints', []),
                    'intent': input_data.user_input_results.get('intent', '')
                }

            # Process context results
            if input_data.context_results:
                report['summary']['context_analysis'] = {
                    'key_findings': input_data.context_results.get('summaries', {}),
                    'relevant_keywords': list(input_data.context_results.get('keywords', set())),
                    'sources': list(input_data.context_results.get('gathered_context', {}).keys())
                }

            # Process image analysis results
            if input_data.image_results:
                report['summary']['image_analysis'] = {
                    'selected_images': input_data.image_results.get('selected_images', []),
                    'analysis_summary': {
                        path: results['caption']['text']
                        for path, results in input_data.image_results.get('analyzed_images', {}).items()
                    }
                }

            # Add detailed information if requested
            if input_data.report_format == "detailed":
                report['details']['model_decisions'] = {
                    'context_models': input_data.context_results.get('model_decisions', {}),
                    'image_models': input_data.image_results.get('model_decisions', {})
                }
                report['details']['processing_stats'] = {
                    'context_processing': input_data.context_results.get('model_decisions', {}).get('processing_stats', {}),
                    'image_processing': input_data.image_results.get('model_decisions', {}).get('processing_stats', {})
                }

            # Generate recommendations
            report['summary']['recommendations'] = self._generate_recommendations(
                report['summary']['context_analysis'],
                report['summary']['image_analysis']
            )

            report['status'] = 'success'
            self.final_report = report

            # Log decision
            self.log_decision({
                'action': 'report_assembly',
                'format': input_data.report_format,
                'sections_completed': list(report['summary'].keys())
            })

            return report

        except Exception as e:
            error_msg = f"Error executing report assembly: {str(e)}"
            self.state.errors.append(error_msg)
            self.logger.error(error_msg)
            return {'status': 'error', 'error': error_msg}

    def _generate_recommendations(self, context_analysis: Dict, image_analysis: Dict) -> List[str]:
        """
        Generate recommendations based on context and image analysis

        Args:
            context_analysis: Results from context learning
            image_analysis: Results from image analysis

        Returns:
            List[str]: List of recommendations
        """
        try:
            recommendations = []

            # Check if we have sufficient data
            if not context_analysis or not image_analysis:
                return ["Insufficient data to generate recommendations"]

            # Analyze context findings
            if context_analysis.get('key_findings'):
                recommendations.append("Based on context analysis:")
                for source, finding in context_analysis['key_findings'].items():
                    if finding:  # Check if finding exists
                        recommendations.append(f"- {finding}")

            # Analyze image findings
            if image_analysis.get('selected_images'):
                recommendations.append("\nBased on image analysis:")
                recommendations.append(f"- Found {len(image_analysis['selected_images'])} relevant images")

                # Add specific image recommendations
                if image_analysis.get('analysis_summary'):
                    for img_path, caption in image_analysis['analysis_summary'].items():
                        if caption:  # Check if caption exists
                            recommendations.append(f"- {caption}")

            # Combine findings for final recommendations
            recommendations.append("\nKey Recommendations:")
            if context_analysis.get('relevant_keywords'):
                keywords = context_analysis['relevant_keywords'][:5]  # Top 5 keywords
                recommendations.append(f"- Focus areas identified: {', '.join(keywords)}")

            # Add source credibility note
            if context_analysis.get('sources'):
                recommendations.append(f"- Analysis based on {len(context_analysis['sources'])} credible sources")

            # Add confidence note
            recommendations.append("- Regular monitoring and updates recommended")

            self.logger.info("Generated recommendations successfully")
            return recommendations

        except Exception as e:
            error_msg = f"Error generating recommendations: {str(e)}"
            self.logger.error(error_msg)
            return ["Error generating recommendations. Please check the detailed report."]
agents/base_agent.py
ADDED
@@ -0,0 +1,99 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
import logging

from models.data_models import AgentState


class BaseAgent(ABC):
    def __init__(self, name: str):
        self.name = name
        self.state = AgentState()
        self.logger = logging.getLogger(name)

    @abstractmethod
    def reason(self, input_data: Any) -> List[str]:
        """Implement step-by-step reasoning process"""
        pass

    @abstractmethod
    def execute(self, input_data: Any) -> Any:
        """Implement main execution logic"""
        pass

    def validate(self, input_data: Any) -> bool:
        """
        Validate input data with basic checks and logging

        Args:
            input_data: Any - Data to validate

        Returns:
            bool - True if valid, False otherwise

        Note: Child classes should override this method
        with additional specific validation rules
        """
        try:
            # Basic validation: check if input exists
            if input_data is None:
                self.state.errors.append("Input data is None")
                return False

            # Check if input is empty
            if isinstance(input_data, (str, list, dict)) and not input_data:
                self.state.errors.append("Input data is empty")
                return False

            self.logger.debug(f"Input validation successful for {self.name}")
            return True

        except Exception as e:
            error_msg = f"Validation error: {str(e)}"
            self.state.errors.append(error_msg)
            self.logger.error(error_msg)
            return False

    def log_decision(self, decision: Dict) -> None:
        """
        Track agent decisions with timestamps and metadata

        Args:
            decision: Dictionary containing:
                - action: str - What action was taken
                - reason: str - Why this decision was made
                - metadata: Dict - Any additional information
        """
        timestamped_decision = {
            'timestamp': datetime.now().isoformat(),
            'agent_name': self.name,
            **decision
        }
        self.state.decisions.append(timestamped_decision)
        self.logger.info(f"Decision logged: {timestamped_decision}")

    def get_state(self) -> Dict:
        """
        Return current agent state as a dictionary

        Returns:
            Dict containing:
                - intent: str - Current agent intent
                - thoughts: List[str] - Reasoning steps
                - decisions: List[Dict] - History of decisions
                - errors: List[str] - Any errors encountered
                - last_updated: str - Timestamp of last state change
        """
        return {
            'intent': self.state.intent,
            'thoughts': self.state.thoughts.copy(),
            'decisions': self.state.decisions.copy(),
            'errors': self.state.errors.copy(),
            'last_updated': datetime.now().isoformat()
        }
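For readers skimming the subclasses that follow, this is the contract in miniature: a concrete agent only has to implement `reason()` and `execute()`, and it inherits `validate()`, `log_decision()` and `get_state()` from `BaseAgent`. The toy agent below is illustrative only and not part of the committed code.

```python
# Minimal illustration of the BaseAgent contract (not part of this commit).
from typing import Any, Dict, List
from agents.base_agent import BaseAgent

class EchoAgent(BaseAgent):
    def reason(self, input_data: Any) -> List[str]:
        thoughts = [f"Received input of type {type(input_data).__name__}"]
        self.state.thoughts.extend(thoughts)
        return thoughts

    def execute(self, input_data: Any) -> Dict:
        if not self.validate(input_data):
            return {'status': 'error', 'error': self.state.errors[-1]}
        self.log_decision({'action': 'echo', 'reason': 'demonstration only'})
        return {'status': 'success', 'echo': input_data}
```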
agents/context_agent.py
ADDED
@@ -0,0 +1,246 @@
from typing import Dict, List, Optional
import wikipedia
from transformers import pipeline
from models.model_config import ModelConfig
from models.data_models import ContextInput
from .base_agent import BaseAgent

class ContextLearnAgent(BaseAgent):

    def __init__(self, name: str = "ContextLearnAgent"):
        super().__init__(name)
        self.learned_context: Dict = {}
        self.models: Dict[str, ModelConfig] = {
            'lightweight_summarizer': ModelConfig(
                name='facebook/bart-large-cnn',
                type='lightweight',
                task='summarization'
            ),
            'advanced_summarizer': ModelConfig(
                name='facebook/bart-large-xsum',
                type='advanced',
                task='summarization'
            )
        }

    def reason(self, input_data: ContextInput) -> List[str]:
        """
        Plan the context gathering and learning process
        """
        thoughts = []

        try:
            # Analyze query requirements
            thoughts.append(f"Analyzing query: '{input_data.processed_query}'")

            # Plan Wikipedia search strategy
            thoughts.append("\nContext Gathering Strategy:")
            thoughts.append("1. Extract key terms from query")
            thoughts.append("2. Search Wikipedia for relevant articles")
            thoughts.append("3. Analyze article relevance")

            # Plan summarization approach
            thoughts.append("\nSummarization Strategy:")
            thoughts.append(f"- Start with {self.models['lightweight_summarizer'].name}")
            thoughts.append(f"- Use {self.models['advanced_summarizer'].name} for complex content")

            if input_data.constraints:
                thoughts.append("\nConstraint Handling:")
                for constraint in input_data.constraints:
                    thoughts.append(f"- Will filter content against: {constraint}")

            # Store thoughts in state
            self.state.thoughts.extend(thoughts)
            self.logger.info("Reasoning complete for context learning")

            return thoughts

        except Exception as e:
            error_msg = f"Error during context reasoning: {str(e)}"
            self.state.errors.append(error_msg)
            self.logger.error(error_msg)
            return thoughts

    def execute(self, input_data: ContextInput) -> Dict:
        """
        Gather and process context information
        """
        try:
            if not self.validate(input_data):
                return {
                    'status': 'error',
                    'error': self.state.errors[-1]
                }

            results = {
                'gathered_context': {},
                'keywords': set(),
                'summaries': {},
                'model_decisions': {
                    'summarizer_used': [],
                    'processing_stats': {
                        'articles_found': 0,
                        'articles_processed': 0
                    }
                },
                'status': 'processing'
            }

            # Search Wikipedia
            try:
                search_results = wikipedia.search(input_data.processed_query, results=5)
                results['model_decisions']['processing_stats']['articles_found'] = len(search_results)

                for title in search_results:
                    try:
                        page = wikipedia.page(title)

                        # Start with lightweight summarization
                        summary = self._get_summary(page.content, 'lightweight')

                        # Use advanced summarizer if content is complex
                        if len(page.content.split()) > 1000:  # Long article
                            summary = self._get_summary(page.content, 'advanced')

                        results['gathered_context'][title] = {
                            'url': page.url,
                            'summary': summary['text'],
                            'confidence': summary['confidence'],
                            'model_used': summary['model_used']
                        }

                        # Extract keywords
                        results['keywords'].update(self._extract_keywords(page.content))
                        results['model_decisions']['summarizer_used'].append(summary['model_used'])
                        results['model_decisions']['processing_stats']['articles_processed'] += 1

                    except wikipedia.exceptions.DisambiguationError as e:
                        self.logger.warning(f"Disambiguation for {title}: {str(e)}")
                    except wikipedia.exceptions.PageError as e:
                        self.logger.warning(f"Page error for {title}: {str(e)}")

            except Exception as e:
                self.logger.error(f"Wikipedia search error: {str(e)}")

            results['status'] = 'success'
            self.learned_context = results

            # Log decision
            self.log_decision({
                'action': 'context_gathering',
                'articles_processed': results['model_decisions']['processing_stats']['articles_processed'],
                'keywords_found': len(results['keywords'])
            })

            return results

        except Exception as e:
            error_msg = f"Error executing context gathering: {str(e)}"
            self.state.errors.append(error_msg)
            self.logger.error(error_msg)
            return {'status': 'error', 'error': error_msg}


    # def __init__(self, name: str = "ContextLearnAgent"):
    #     super().__init__(name)
    #     self.learned_context: Dict = {}

    # def reason(self, input_data: ContextInput) -> List[str]:
    #     """
    #     Analyze what information needs to be gathered about pumps
    #
    #     Args:
    #         input_data: Processed user query and constraints
    #     Returns:
    #         List[str]: Reasoning steps about required information
    #     """
    #     thoughts = []
    #
    #     try:
    #         # Analyze search requirements
    #         thoughts.append(f"Analyzing search query: '{input_data.processed_query}'")
    #
    #         # Consider search sources
    #         thoughts.append(f"Planning to search in: {', '.join(input_data.search_sources)}")
    #
    #         # Consider constraints
    #         if input_data.constraints:
    #             thoughts.append(f"Will filter results based on constraints: {input_data.constraints}")
    #
    #         # Consider result limit
    #         thoughts.append(f"Will gather up to {input_data.max_results} relevant results")
    #
    #         # Store thoughts in state
    #         self.state.thoughts.extend(thoughts)
    #         self.logger.info("Reasoning complete for context gathering")
    #
    #         return thoughts
    #
    #     except Exception as e:
    #         error_msg = f"Error during context reasoning: {str(e)}"
    #         self.state.errors.append(error_msg)
    #         self.logger.error(error_msg)
    #         return thoughts

    # def execute(self, input_data: ContextInput) -> Dict:
    #     """
    #     Gather and process information from specified sources
    #
    #     Args:
    #         input_data: Search parameters and constraints
    #     Returns:
    #         Dict containing:
    #             - sources: List[str] - Sources used
    #             - gathered_info: Dict - Information gathered by source
    #             - summary: str - Brief summary of findings
    #             - status: str - Processing status
    #     """
    #     try:
    #         # First validate the input
    #         if not self.validate(input_data):
    #             return {
    #                 'status': 'error',
    #                 'error': self.state.errors[-1]
    #             }
    #
    #         # Initialize results structure
    #         gathered_info = {
    #             'sources': input_data.search_sources,
    #             'gathered_info': {},
    #             'summary': '',
    #             'status': 'processing'
    #         }
    #
    #         # Process each source (placeholder for actual API calls)
    #         for source in input_data.search_sources:
    #             gathered_info['gathered_info'][source] = {
    #                 'status': 'pending',
    #                 'content': [],
    #                 'metadata': {
    #                     'timestamp': datetime.now().isoformat(),
    #                     'query': input_data.processed_query
    #                 }
    #             }
    #
    #         # Log the decision
    #         self.log_decision({
    #             'action': 'gather_context',
    #             'sources': input_data.search_sources,
    #             'query': input_data.processed_query
    #         })
    #
    #         # Store in agent's state
    #         self.learned_context = gathered_info
    #         gathered_info['status'] = 'success'
    #
    #         return gathered_info
    #
    #     except Exception as e:
    #         error_msg = f"Error executing context gathering: {str(e)}"
    #         self.state.errors.append(error_msg)
    #         self.logger.error(error_msg)
    #         return {'status': 'error', 'error': error_msg}
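Note that `ContextLearnAgent.execute` calls `self._get_summary` and `self._extract_keywords`, but neither method is defined in the 246 lines committed here, so the agent would fail at that point as written. Below is a minimal sketch of what the two helpers might look like, meant to live on the class, assuming the BART summarization checkpoints already configured in `self.models` and simple frequency-based keyword extraction; the return shapes and confidence values are guesses inferred from how `execute` uses the results.

```python
from collections import Counter
from typing import Dict, Set
from transformers import pipeline

# Hypothetical ContextLearnAgent methods (not part of the committed file).
def _get_summary(self, content: str, tier: str) -> Dict:
    """Summarize article text with the lightweight or advanced BART checkpoint."""
    model_key = f"{tier}_summarizer"
    summarizer = pipeline("summarization", model=self.models[model_key].name)
    # BART has a ~1024-token input limit, so only the head of the article is summarized here.
    output = summarizer(content[:3000], max_length=150, min_length=30, do_sample=False)
    return {
        'text': output[0]['summary_text'],
        'confidence': 0.5 if tier == 'lightweight' else 0.8,  # placeholder scores
        'model_used': self.models[model_key].name,
    }

def _extract_keywords(self, content: str, top_n: int = 20) -> Set[str]:
    """Very rough keyword extraction: most frequent longer lowercase tokens."""
    words = [w.lower().strip('.,()') for w in content.split() if len(w) > 5]
    return {word for word, _ in Counter(words).most_common(top_n)}
```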
agents/image_agent.py
ADDED
@@ -0,0 +1,1065 @@
| 1 |
+
from typing import Dict, List
|
| 2 |
+
import torch
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from transformers import (
|
| 5 |
+
BlipProcessor, BlipForConditionalGeneration,
|
| 6 |
+
Blip2Processor, Blip2ForConditionalGeneration,
|
| 7 |
+
AutoImageProcessor, ResNetForImageClassification,
|
| 8 |
+
ViTImageProcessor, ViTForImageClassification
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
from models.model_config import ModelConfig
|
| 12 |
+
from models.data_models import ImageAnalysisInput
|
| 13 |
+
from config.settings import config
|
| 14 |
+
from .base_agent import BaseAgent
|
| 15 |
+
|
| 16 |
+
from langchain.llms import HuggingFacePipeline
|
| 17 |
+
from langchain.prompts import PromptTemplate
|
| 18 |
+
from langchain.chains import LLMChain
|
| 19 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ImageAnalyzerAgent(BaseAgent):
|
| 24 |
+
def __init__(self, name: str = "ImageAnalyzerAgent"):
|
| 25 |
+
super().__init__(name)
|
| 26 |
+
self.analyzed_images: Dict = {}
|
| 27 |
+
# self.models: Dict[str, ModelConfig] = {
|
| 28 |
+
# 'lightweight_caption': ModelConfig('clip-vit-base', 'lightweight', 'caption'),
|
| 29 |
+
# 'advanced_caption': ModelConfig('blip2-opt', 'advanced', 'caption'),
|
| 30 |
+
# 'lightweight_classifier': ModelConfig('resnet18', 'lightweight', 'classification'),
|
| 31 |
+
# #'advanced_classifier': ModelConfig('vit-large', 'advanced', 'classification')
|
| 32 |
+
|
| 33 |
+
# }
|
| 34 |
+
self.models: Dict[str, ModelConfig] = {
|
| 35 |
+
'lightweight_caption': ModelConfig(
|
| 36 |
+
name='Salesforce/blip-image-captioning-base',
|
| 37 |
+
type='lightweight',
|
| 38 |
+
task='caption'
|
| 39 |
+
),
|
| 40 |
+
'advanced_caption': ModelConfig(
|
| 41 |
+
name='Salesforce/blip2-opt-2.7b',
|
| 42 |
+
type='advanced',
|
| 43 |
+
task='caption'
|
| 44 |
+
),
|
| 45 |
+
'lightweight_classifier': ModelConfig(
|
| 46 |
+
name='microsoft/resnet-50',
|
| 47 |
+
type='lightweight',
|
| 48 |
+
task='classification'
|
| 49 |
+
),
|
| 50 |
+
'advanced_classifier': ModelConfig(
|
| 51 |
+
name='google/vit-base-patch16-224',
|
| 52 |
+
type='lightweight',
|
| 53 |
+
task='classification'
|
| 54 |
+
)
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
# Add LangChain setup
|
| 58 |
+
self.llm_model_name = "MBZUAI/LaMini-Flan-T5-783M"
|
| 59 |
+
self._initialize_llm()
|
| 60 |
+
|
| 61 |
+
# def reason(self, input_data: ImageAnalysisInput) -> List[str]:
|
| 62 |
+
# """
|
| 63 |
+
# Plan the image analysis process for pump-related images
|
| 64 |
+
|
| 65 |
+
# Args:
|
| 66 |
+
# input_data: Contains uploaded images, context, and constraints
|
| 67 |
+
# Returns:
|
| 68 |
+
# List[str]: Reasoning steps about image analysis plan
|
| 69 |
+
# """
|
| 70 |
+
# thoughts = []
|
| 71 |
+
|
| 72 |
+
# try:
|
| 73 |
+
# # Consider number of images
|
| 74 |
+
# thoughts.append(f"Processing {len(input_data.images)} uploaded images")
|
| 75 |
+
|
| 76 |
+
# # Consider context from ContextLearnAgent
|
| 77 |
+
# if input_data.context:
|
| 78 |
+
# thoughts.append("Using learned context to guide image analysis")
|
| 79 |
+
# thoughts.append(f"Looking for pump-related features based on context")
|
| 80 |
+
|
| 81 |
+
# # Consider constraints
|
| 82 |
+
# if input_data.constraints:
|
| 83 |
+
# thoughts.append(f"Will apply constraints: {input_data.constraints}")
|
| 84 |
+
|
| 85 |
+
# # Plan analysis steps
|
| 86 |
+
# thoughts.append(f"Analysis plan:")
|
| 87 |
+
# thoughts.append("1. Verify image accessibility")
|
| 88 |
+
# thoughts.append("2. Perform image captioning")
|
| 89 |
+
# thoughts.append("3. Match captions with context")
|
| 90 |
+
# thoughts.append(f"4. Select top {input_data.top_k} relevant images")
|
| 91 |
+
|
| 92 |
+
# # Store thoughts in state
|
| 93 |
+
# self.state.thoughts.extend(thoughts)
|
| 94 |
+
# self.logger.info("Reasoning complete for image analysis")
|
| 95 |
+
|
| 96 |
+
# return thoughts
|
| 97 |
+
|
| 98 |
+
# except Exception as e:
|
| 99 |
+
# error_msg = f"Error during image analysis reasoning: {str(e)}"
|
| 100 |
+
# self.state.errors.append(error_msg)
|
| 101 |
+
# self.logger.error(error_msg)
|
| 102 |
+
# return thoughts
|
| 103 |
+
|
| 104 |
+
# def reason(self, input_data: ImageAnalysisInput) -> List[str]:
|
| 105 |
+
# """
|
| 106 |
+
# Plan the image analysis process with model selection strategy
|
| 107 |
+
|
| 108 |
+
# Args:
|
| 109 |
+
# input_data: Contains uploaded images, context, and constraints
|
| 110 |
+
# Returns:
|
| 111 |
+
# List[str]: Reasoning steps about image analysis plan
|
| 112 |
+
# """
|
| 113 |
+
# thoughts = []
|
| 114 |
+
|
| 115 |
+
# try:
|
| 116 |
+
# # Consider input volume and resources
|
| 117 |
+
# thoughts.append(f"Processing {len(input_data.images)} uploaded images")
|
| 118 |
+
|
| 119 |
+
# # Model selection strategy
|
| 120 |
+
# thoughts.append("Planning model selection strategy:")
|
| 121 |
+
# thoughts.append("- Will start with lightweight models for efficiency")
|
| 122 |
+
# thoughts.append(f"- Using {self.models['lightweight_caption'].name} for initial caption generation")
|
| 123 |
+
# thoughts.append(f"- Will switch to {self.models['advanced_caption'].name} if confidence below {self.models['lightweight_caption'].threshold}")
|
| 124 |
+
|
| 125 |
+
# # Context consideration
|
| 126 |
+
# if input_data.context:
|
| 127 |
+
# thoughts.append("Using learned context to guide image analysis:")
|
| 128 |
+
# thoughts.append("- Will match image captions against context")
|
| 129 |
+
# thoughts.append("- Will use context for relevance scoring")
|
| 130 |
+
|
| 131 |
+
# # Constraints handling
|
| 132 |
+
# if input_data.constraints:
|
| 133 |
+
# thoughts.append(f"Will apply constraints during analysis:")
|
| 134 |
+
# for constraint in input_data.constraints:
|
| 135 |
+
# thoughts.append(f"- {constraint}")
|
| 136 |
+
|
| 137 |
+
# # Analysis pipeline
|
| 138 |
+
# thoughts.append("Analysis pipeline:")
|
| 139 |
+
# thoughts.append("1. Initial lightweight model caption generation")
|
| 140 |
+
# thoughts.append("2. Confidence check and model escalation if needed")
|
| 141 |
+
# thoughts.append("3. Context matching and constraint application")
|
| 142 |
+
# thoughts.append(f"4. Selection of top {input_data.top_k} relevant images")
|
| 143 |
+
|
| 144 |
+
# # Store thoughts in state
|
| 145 |
+
# self.state.thoughts.extend(thoughts)
|
| 146 |
+
# self.logger.info("Reasoning complete for image analysis")
|
| 147 |
+
|
| 148 |
+
# return thoughts
|
| 149 |
+
|
| 150 |
+
# except Exception as e:
|
| 151 |
+
# error_msg = f"Error during image analysis reasoning: {str(e)}"
|
| 152 |
+
# self.state.errors.append(error_msg)
|
| 153 |
+
# self.logger.error(error_msg)
|
| 154 |
+
# return thoughts
|
| 155 |
+
|
| 156 |
+
def _initialize_llm(self):
|
| 157 |
+
"""Initialize LangChain components"""
|
| 158 |
+
try:
|
| 159 |
+
tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
|
| 160 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(self.llm_model_name)
|
| 161 |
+
|
| 162 |
+
pipe = pipeline(
|
| 163 |
+
"text2text-generation",
|
| 164 |
+
model=model,
|
| 165 |
+
tokenizer=tokenizer,
|
| 166 |
+
max_length=512,
|
| 167 |
+
temperature=0.3
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
self.llm = HuggingFacePipeline(pipeline=pipe)
|
| 171 |
+
|
| 172 |
+
# Create relevance analysis chain
|
| 173 |
+
relevance_template = """
|
| 174 |
+
Analyze the relevance between image content and context:
|
| 175 |
+
|
| 176 |
+
Image Caption: {caption}
|
| 177 |
+
Image Classification: {classification}
|
| 178 |
+
Context Keywords: {context_keywords}
|
| 179 |
+
Domain Context: {domain_context}
|
| 180 |
+
|
| 181 |
+
Provide a detailed analysis of:
|
| 182 |
+
1. Content Relevance
|
| 183 |
+
2. Domain Alignment
|
| 184 |
+
3. Context Matching Score (0-1)
|
| 185 |
+
|
| 186 |
+
Analysis:
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
self.relevance_chain = LLMChain(
|
| 190 |
+
llm=self.llm,
|
| 191 |
+
prompt=PromptTemplate(
|
| 192 |
+
template=relevance_template,
|
| 193 |
+
input_variables=["caption", "classification",
|
| 194 |
+
"context_keywords", "domain_context"]
|
| 195 |
+
)
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
except Exception as e:
|
| 199 |
+
self.logger.error(f"Error initializing LLM: {str(e)}")
|
| 200 |
+
raise
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def reason(self, input_data: ImageAnalysisInput) -> List[str]:
|
| 204 |
+
"""
|
| 205 |
+
Plan the image analysis process with LLM-enhanced reasoning
|
| 206 |
+
"""
|
| 207 |
+
thoughts = []
|
| 208 |
+
|
| 209 |
+
try:
|
| 210 |
+
# Initial assessment
|
| 211 |
+
thoughts.append(f"Processing {len(input_data.images)} images")
|
| 212 |
+
|
| 213 |
+
# Context understanding
|
| 214 |
+
if input_data.context:
|
| 215 |
+
thoughts.append("\nAnalyzing Context:")
|
| 216 |
+
thoughts.append(f"- Domain context available: {bool(input_data.context.get('domain_context'))}")
|
| 217 |
+
thoughts.append(f"- Keywords identified: {len(input_data.context.get('keywords', []))}")
|
| 218 |
+
|
| 219 |
+
# Use LLM to analyze context requirements
|
| 220 |
+
context_analysis = self.relevance_chain.run(
|
| 221 |
+
caption="Context analysis phase",
|
| 222 |
+
classification="Initial planning",
|
| 223 |
+
context_keywords=str(input_data.context.get('keywords', [])),
|
| 224 |
+
domain_context=input_data.context.get('domain_context', 'Not specified')
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
thoughts.append("\nLLM Context Analysis:")
|
| 228 |
+
for line in context_analysis.split('\n'):
|
| 229 |
+
if line.strip():
|
| 230 |
+
thoughts.append(f"- {line.strip()}")
|
| 231 |
+
|
| 232 |
+
# Model selection strategy
|
| 233 |
+
thoughts.append("\nModel Selection Strategy:")
|
| 234 |
+
thoughts.append("1. Image Processing Pipeline:")
|
| 235 |
+
thoughts.append(f" - Initial caption: {self.models['lightweight_caption'].name}")
|
| 236 |
+
thoughts.append(f" - Initial classification: {self.models['lightweight_classifier'].name}")
|
| 237 |
+
thoughts.append(" - Will escalate to advanced models if confidence below threshold")
|
| 238 |
+
|
| 239 |
+
# Analysis plan
|
| 240 |
+
thoughts.append("\nAnalysis Pipeline:")
|
| 241 |
+
thoughts.append("1. Generate captions and classifications")
|
| 242 |
+
thoughts.append("2. Perform LLM-based relevance analysis")
|
| 243 |
+
thoughts.append("3. Apply context matching")
|
| 244 |
+
thoughts.append(f"4. Select top {input_data.top_k} relevant images")
|
| 245 |
+
|
| 246 |
+
if input_data.constraints:
|
| 247 |
+
thoughts.append("\nConstraint Handling:")
|
| 248 |
+
for constraint in input_data.constraints:
|
| 249 |
+
thoughts.append(f"- Will verify: {constraint}")
|
| 250 |
+
|
| 251 |
+
# Store thoughts in state
|
| 252 |
+
self.state.thoughts.extend(thoughts)
|
| 253 |
+
self.logger.info("Reasoning complete for image analysis")
|
| 254 |
+
self._cleanup_llm()
|
| 255 |
+
|
| 256 |
+
return thoughts
|
| 257 |
+
|
| 258 |
+
except Exception as e:
|
| 259 |
+
error_msg = f"Error during reasoning: {str(e)}"
|
| 260 |
+
self.state.errors.append(error_msg)
|
| 261 |
+
self.logger.error(error_msg)
|
| 262 |
+
self._cleanup_llm()
|
| 263 |
+
return thoughts
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _analyze_image_context_relevance(self,
|
| 267 |
+
caption: str,
|
| 268 |
+
classification: str,
|
| 269 |
+
context: Dict) -> Dict:
|
| 270 |
+
"""
|
| 271 |
+
Analyze relevance between image content and context using LLM
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
caption: Generated image caption
|
| 275 |
+
classification: Image classification result
|
| 276 |
+
context: Context information from ContextLearnAgent
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
Dict containing:
|
| 280 |
+
- relevance_score: float (0-1)
|
| 281 |
+
- content_analysis: str
|
| 282 |
+
- domain_alignment: str
|
| 283 |
+
- confidence: float
|
| 284 |
+
"""
|
| 285 |
+
try:
|
| 286 |
+
# Prepare context information
|
| 287 |
+
context_keywords = context.get('keywords', [])
|
| 288 |
+
domain_context = context.get('domain_context', '')
|
| 289 |
+
|
| 290 |
+
# Get LLM analysis
|
| 291 |
+
llm_analysis = self.relevance_chain.run(
|
| 292 |
+
caption=caption,
|
| 293 |
+
classification=classification,
|
| 294 |
+
context_keywords=str(context_keywords),
|
| 295 |
+
domain_context=domain_context
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# Parse LLM output
|
| 299 |
+
analysis_result = {
|
| 300 |
+
'relevance_score': 0.0,
|
| 301 |
+
'content_analysis': '',
|
| 302 |
+
'domain_alignment': '',
|
| 303 |
+
'confidence': 0.0
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
# Extract information from LLM output
|
| 307 |
+
current_section = ''
|
| 308 |
+
for line in llm_analysis.split('\n'):
|
| 309 |
+
line = line.strip()
|
| 310 |
+
if 'Content Relevance' in line:
|
| 311 |
+
current_section = 'content'
|
| 312 |
+
elif 'Domain Alignment' in line:
|
| 313 |
+
current_section = 'domain'
|
| 314 |
+
elif 'Context Matching Score' in line:
|
| 315 |
+
try:
|
| 316 |
+
# Extract score (0-1) from text
|
| 317 |
+
score = float([s for s in line.split() if s.replace('.','').isdigit()][0])
|
| 318 |
+
analysis_result['relevance_score'] = min(1.0, max(0.0, score))
|
| 319 |
+
except:
|
| 320 |
+
pass
|
| 321 |
+
elif line and current_section:
|
| 322 |
+
if current_section == 'content':
|
| 323 |
+
analysis_result['content_analysis'] += line + ' '
|
| 324 |
+
elif current_section == 'domain':
|
| 325 |
+
analysis_result['domain_alignment'] += line + ' '
|
| 326 |
+
|
| 327 |
+
# Calculate confidence based on clarity of analysis
|
| 328 |
+
analysis_result['confidence'] = min(1.0,
|
| 329 |
+
(len(analysis_result['content_analysis']) +
|
| 330 |
+
len(analysis_result['domain_alignment'])) / 200)
|
| 331 |
+
|
| 332 |
+
self.logger.debug(f"Context relevance analysis completed with "
|
| 333 |
+
f"score: {analysis_result['relevance_score']}")
|
| 334 |
+
|
| 335 |
+
return analysis_result
|
| 336 |
+
|
| 337 |
+
except Exception as e:
|
| 338 |
+
self.logger.error(f"Error in context relevance analysis: {str(e)}")
|
| 339 |
+
return {
|
| 340 |
+
'relevance_score': 0.0,
|
| 341 |
+
'content_analysis': '',
|
| 342 |
+
'domain_alignment': '',
|
| 343 |
+
'confidence': 0.0
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def reason1(self, input_data: ImageAnalysisInput) -> List[str]:
|
| 349 |
+
"""
|
| 350 |
+
Plan the image analysis process with both captioning and classification
|
| 351 |
+
"""
|
| 352 |
+
thoughts = []
|
| 353 |
+
|
| 354 |
+
try:
|
| 355 |
+
# Input volume assessment
|
| 356 |
+
thoughts.append(f"Processing {len(input_data.images)} images")
|
| 357 |
+
|
| 358 |
+
# Model strategy explanation
|
| 359 |
+
thoughts.append("\nModel Selection Strategy:")
|
| 360 |
+
thoughts.append("1. Captioning Pipeline:")
|
| 361 |
+
thoughts.append(f" - Start with {self.models['lightweight_caption'].name} for efficient processing")
|
| 362 |
+
thoughts.append(f" - Escalate to {self.models['advanced_caption'].name} if confidence < {self.models['lightweight_caption'].threshold}")
|
| 363 |
+
|
| 364 |
+
thoughts.append("\n2. Classification Pipeline:")
|
| 365 |
+
thoughts.append(f" - Begin with {self.models['lightweight_classifier'].name} for initial classification")
|
| 366 |
+
thoughts.append(f" - Switch to {self.models['advanced_classifier'].name} if confidence < {self.models['lightweight_classifier'].threshold}")
|
| 367 |
+
|
| 368 |
+
# Context consideration
|
| 369 |
+
if input_data.context:
|
| 370 |
+
thoughts.append("\nContext Integration:")
|
| 371 |
+
thoughts.append("- Will match image captions against provided context")
|
| 372 |
+
thoughts.append("- Will verify classifications against context requirements")
|
| 373 |
+
|
| 374 |
+
# Constraints handling
|
| 375 |
+
if input_data.constraints:
|
| 376 |
+
thoughts.append("\nConstraints Application:")
|
| 377 |
+
for constraint in input_data.constraints:
|
| 378 |
+
thoughts.append(f"- {constraint}")
|
| 379 |
+
|
| 380 |
+
# Process outline
|
| 381 |
+
thoughts.append("\nAnalysis Pipeline:")
|
| 382 |
+
thoughts.append("1. Initial lightweight model processing (caption + classification)")
|
| 383 |
+
thoughts.append("2. Advanced model processing where needed")
|
| 384 |
+
thoughts.append("3. Context matching and constraint verification")
|
| 385 |
+
thoughts.append(f"4. Selection of top {input_data.top_k} most relevant images")
|
| 386 |
+
thoughts.append("5. Memory cleanup after processing")
|
| 387 |
+
|
| 388 |
+
# Store thoughts in state
|
| 389 |
+
self.state.thoughts.extend(thoughts)
|
| 390 |
+
self.logger.info("Reasoning complete for image analysis")
|
| 391 |
+
|
| 392 |
+
return thoughts
|
| 393 |
+
|
| 394 |
+
except Exception as e:
|
| 395 |
+
error_msg = f"Error during reasoning: {str(e)}"
|
| 396 |
+
self.state.errors.append(error_msg)
|
| 397 |
+
self.logger.error(error_msg)
|
| 398 |
+
return thoughts
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
# def execute(self, input_data: ImageAnalysisInput) -> Dict:
|
| 402 |
+
# """
|
| 403 |
+
# Analyze images using multiple models with tiered approach
|
| 404 |
+
|
| 405 |
+
# Args:
|
| 406 |
+
# input_data: Contains images and analysis parameters
|
| 407 |
+
# Returns:
|
| 408 |
+
# Dict containing:
|
| 409 |
+
# - analyzed_images: Dict of image analysis results
|
| 410 |
+
# - selected_images: List of top-k relevant images
|
| 411 |
+
# - model_decisions: Dict of model choices and confidence
|
| 412 |
+
# - status: Processing status
|
| 413 |
+
# """
|
| 414 |
+
# try:
|
| 415 |
+
# if not self.validate(input_data):
|
| 416 |
+
# return {
|
| 417 |
+
# 'status': 'error',
|
| 418 |
+
# 'error': self.state.errors[-1]
|
| 419 |
+
# }
|
| 420 |
+
|
| 421 |
+
# results = {
|
| 422 |
+
# 'analyzed_images': {},
|
| 423 |
+
# 'selected_images': [],
|
| 424 |
+
# 'model_decisions': {},
|
| 425 |
+
# 'status': 'processing'
|
| 426 |
+
# }
|
| 427 |
+
|
| 428 |
+
# # Process each image
|
| 429 |
+
# for img_path in input_data.images:
|
| 430 |
+
# # Start with lightweight models
|
| 431 |
+
# caption_result = self._get_image_caption(img_path, 'lightweight')
|
| 432 |
+
|
| 433 |
+
# # If confidence is low, use advanced model
|
| 434 |
+
# if caption_result['confidence'] < self.models['lightweight_caption'].threshold:
|
| 435 |
+
# caption_result = self._get_image_caption(img_path, 'advanced')
|
| 436 |
+
|
| 437 |
+
# results['analyzed_images'][img_path] = {
|
| 438 |
+
# 'caption': caption_result['caption'],
|
| 439 |
+
# 'confidence': caption_result['confidence'],
|
| 440 |
+
# 'model_used': caption_result['model_used']
|
| 441 |
+
# }
|
| 442 |
+
|
| 443 |
+
# # Select top-k relevant images based on context matching
|
| 444 |
+
# results['selected_images'] = self._select_relevant_images(
|
| 445 |
+
# results['analyzed_images'],
|
| 446 |
+
# input_data.context,
|
| 447 |
+
# input_data.top_k
|
| 448 |
+
# )
|
| 449 |
+
|
| 450 |
+
# results['status'] = 'success'
|
| 451 |
+
# self.analyzed_images = results
|
| 452 |
+
|
| 453 |
+
# # Log decision
|
| 454 |
+
# self.log_decision({
|
| 455 |
+
# 'action': 'analyze_images',
|
| 456 |
+
# 'num_images': len(input_data.images),
|
| 457 |
+
# 'selected_images': len(results['selected_images'])
|
| 458 |
+
# })
|
| 459 |
+
|
| 460 |
+
# return results
|
| 461 |
+
|
| 462 |
+
# except Exception as e:
|
| 463 |
+
# error_msg = f"Error executing image analysis: {str(e)}"
|
| 464 |
+
# self.state.errors.append(error_msg)
|
| 465 |
+
# self.logger.error(error_msg)
|
| 466 |
+
# return {'status': 'error', 'error': error_msg}
|
| 467 |
+
|
| 468 |
+
    def execute(self, input_data: ImageAnalysisInput) -> Dict:
        """
        Execute image analysis using tiered model approach
        """
        try:
            if not self.validate(input_data):
                return {
                    'status': 'error',
                    'error': self.state.errors[-1]
                }

            results = {
                'analyzed_images': {},
                'selected_images': [],
                'model_decisions': {
                    'caption_models': set(),
                    'classifier_models': set(),
                    'processing_stats': {
                        'lightweight_usage': 0,
                        'advanced_usage': 0
                    }
                },
                'status': 'processing'
            }

            # Process each image
            for img_path in input_data.images:
                self.logger.info(f"Processing image: {img_path}")

                # Initial lightweight processing
                caption_result = self._get_image_caption(img_path, 'lightweight')
                classify_result = self._get_classification(img_path, 'lightweight')
                results['model_decisions']['processing_stats']['lightweight_usage'] += 1

                # Advanced processing if needed
                if caption_result['confidence'] < self.models['lightweight_caption'].threshold:
                    self.logger.info("Escalating to advanced caption model")
                    caption_result = self._get_image_caption(img_path, 'advanced')
                    results['model_decisions']['processing_stats']['advanced_usage'] += 1

                if classify_result['confidence'] < self.models['lightweight_classifier'].threshold:
                    self.logger.info("Escalating to advanced classification model")
                    classify_result = self._get_classification(img_path, 'advanced')
                    results['model_decisions']['processing_stats']['advanced_usage'] += 1

                # Store results for this image
                results['analyzed_images'][img_path] = {
                    'caption': {
                        'text': caption_result['caption'],
                        'confidence': caption_result['confidence'],
                        'model_used': caption_result['model_used']
                    },
                    'classification': {
                        'label': classify_result['class'],
                        'confidence': classify_result['confidence'],
                        'model_used': classify_result['model_used']
                    },
                    'combined_confidence': (caption_result['confidence'] + classify_result['confidence']) / 2
                }

                # Track models used
                results['model_decisions']['caption_models'].add(caption_result['model_used'])
                results['model_decisions']['classifier_models'].add(classify_result['model_used'])

                # Cleanup after each image to manage memory
                self._cleanup_models()

            # Select top-k relevant images
            results['selected_images'] = self._select_relevant_images(
                results['analyzed_images'],
                input_data.context,
                input_data.top_k
            )

            results['status'] = 'success'
            self.analyzed_images = results

            # Log final decision
            self.log_decision({
                'action': 'complete_image_analysis',
                'num_images': len(input_data.images),
                'selected_images': len(results['selected_images']),
                'model_usage': results['model_decisions']['processing_stats']
            })

            return results

        except Exception as e:
            error_msg = f"Error executing image analysis: {str(e)}"
            self.state.errors.append(error_msg)
            self.logger.error(error_msg)
            return {'status': 'error', 'error': error_msg}

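    # Illustrative usage sketch (not part of the agent): how the tiered execute() above is
    # typically driven. ImageAnalysisInput and the context wiring come from
    # models/data_models.py and interface/handlers.py; the image paths here are placeholders.
    #
    #     agent = ImageAnalyzerAgent()
    #     result = agent.execute(ImageAnalysisInput(
    #         images=["pump_01.jpg", "pump_02.jpg"],
    #         context=context_results['gathered_context'],
    #         constraints=[],
    #         top_k=2,
    #     ))
    #     if result['status'] == 'success':
    #         stats = result['model_decisions']['processing_stats']
    #         print(f"lightweight runs: {stats['lightweight_usage']}, "
    #               f"advanced escalations: {stats['advanced_usage']}")
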
    # def execute(self, input_data: ImageAnalysisInput) -> Dict:
    #     """
    #     Analyze images using both captioning and classification models
    #
    #     Args:
    #         input_data: Contains images and analysis parameters
    #     Returns:
    #         Dict containing:
    #             - analyzed_images: Dict of image analysis results
    #             - selected_images: List of top-k relevant images
    #             - model_decisions: Dict of model choices and confidence
    #             - status: Processing status
    #     """
    #     try:
    #         if not self.validate(input_data):
    #             return {
    #                 'status': 'error',
    #                 'error': self.state.errors[-1]
    #             }
    #
    #         results = {
    #             'analyzed_images': {},
    #             'selected_images': [],
    #             'model_decisions': {},
    #             'status': 'processing'
    #         }
    #
    #         # Process each image
    #         for img_path in input_data.images:
    #             # Start with lightweight models
    #             caption_result = self._get_image_caption(img_path, 'lightweight')
    #             classify_result = self._get_classification(img_path, 'lightweight')
    #
    #             # If either confidence is low, use advanced models
    #             if caption_result['confidence'] < self.models['lightweight_caption'].threshold:
    #                 caption_result = self._get_image_caption(img_path, 'advanced')
    #
    #             if classify_result['confidence'] < self.models['lightweight_classifier'].threshold:
    #                 classify_result = self._get_classification(img_path, 'advanced')
    #
    #             # Store results for this image
    #             results['analyzed_images'][img_path] = {
    #                 'caption': caption_result['caption'],
    #                 'caption_confidence': caption_result['confidence'],
    #                 'caption_model': caption_result['model_used'],
    #                 'classification': classify_result['class'],
    #                 'class_confidence': classify_result['confidence'],
    #                 'class_model': classify_result['model_used']
    #             }
    #
    #         # Select top-k relevant images considering both caption and classification
    #         results['selected_images'] = self._select_relevant_images(
    #             results['analyzed_images'],
    #             input_data.context,
    #             input_data.top_k
    #         )
    #
    #         # Track model decisions
    #         results['model_decisions'] = {
    #             'caption_models_used': set(img['caption_model'] for img in results['analyzed_images'].values()),
    #             'classifier_models_used': set(img['class_model'] for img in results['analyzed_images'].values())
    #         }
    #
    #         results['status'] = 'success'
    #         self.analyzed_images = results
    #
    #         # Log decision
    #         self.log_decision({
    #             'action': 'analyze_images',
    #             'num_images': len(input_data.images),
    #             'selected_images': len(results['selected_images']),
    #             'models_used': results['model_decisions']
    #         })
    #
    #         # Cleanup to manage memory
    #         self._cleanup_models()
    #
    #         return results
    #
    #     except Exception as e:
    #         error_msg = f"Error executing image analysis: {str(e)}"
    #         self.state.errors.append(error_msg)
    #         self.logger.error(error_msg)
    #         return {'status': 'error', 'error': error_msg}

    # def _get_image_caption(self, img_path: str, model_type: str) -> Dict:
    #     """
    #     Generate caption for an image using specified model type
    #
    #     Args:
    #         img_path: Path to the image
    #         model_type: Either 'lightweight' or 'advanced'
    #
    #     Returns:
    #         Dict containing:
    #             - caption: str - Generated caption
    #             - confidence: float - Confidence score
    #             - model_used: str - Name of the model used
    #     """
    #     try:
    #         # Select appropriate model
    #         model_key = f"{model_type}_caption"
    #         model_config = self.models[model_key]
    #
    #         # TODO: Implement actual model loading and inference
    #         # For now, return placeholder result
    #         return {
    #             'caption': f"Placeholder caption for {img_path}",
    #             'confidence': 0.8 if model_type == 'advanced' else 0.6,
    #             'model_used': model_config.name
    #         }
    #
    #     except Exception as e:
    #         self.logger.error(f"Error in image captioning: {str(e)}")
    #         return {
    #             'caption': '',
    #             'confidence': 0.0,
    #             'model_used': 'none'
    #         }

    def _get_image_caption(self, img_path: str, model_type: str) -> Dict:
        """
        Generate caption for an image using BLIP models

        Args:
            img_path: Path to the image
            model_type: Either 'lightweight' or 'advanced'

        Returns:
            Dict containing:
                - caption: str - Generated caption
                - confidence: float - Confidence score
                - model_used: str - Name of model used
        """
        try:
            model_key = f"{model_type}_caption"
            model_config = self.models[model_key]

            # Load image
            image = Image.open(img_path).convert('RGB')

            # Initialize model and processor based on type.
            # Note: _cleanup_models() deletes the model after each image but keeps the
            # processor, so check for the model separately before reusing it.
            if model_type == 'lightweight':
                if 'lightweight_processor' not in self.__dict__:
                    self.lightweight_processor = BlipProcessor.from_pretrained(model_config.name)
                if 'lightweight_model' not in self.__dict__:
                    self.lightweight_model = BlipForConditionalGeneration.from_pretrained(model_config.name)
                processor = self.lightweight_processor
                model = self.lightweight_model
            else:
                if 'advanced_processor' not in self.__dict__:
                    self.advanced_processor = Blip2Processor.from_pretrained(model_config.name)
                if 'advanced_model' not in self.__dict__:
                    self.advanced_model = Blip2ForConditionalGeneration.from_pretrained(model_config.name)
                processor = self.advanced_processor
                model = self.advanced_model

            # Process image
            inputs = processor(image, return_tensors="pt")

            # Generate caption
            outputs = model.generate(**inputs, max_new_tokens=50)
            caption = processor.decode(outputs[0], skip_special_tokens=True)

            # Calculate confidence (using max probability as proxy)
            with torch.no_grad():
                logits = model(**inputs).logits
                confidence = float(torch.max(torch.softmax(logits[0], dim=-1)).item())

            return {
                'caption': caption,
                'confidence': confidence,
                'model_used': model_config.name
            }

        except Exception as e:
            self.logger.error(f"Error in image captioning: {str(e)}")
            return {
                'caption': '',
                'confidence': 0.0,
                'model_used': 'none'
            }

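    def _estimate_caption_confidence(self, model, inputs) -> float:
        """
        Illustrative alternative only (hypothetical helper, not wired into execute()):
        derive the caption confidence from the generation scores themselves instead of a
        separate forward pass, by averaging the per-step maximum token probability.
        Assumes a transformers generate() API with output_scores support.
        """
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=50,
                output_scores=True,
                return_dict_in_generate=True
            )
        if not out.scores:
            return 0.0
        # One score tensor per generated step: take the max softmax probability at each
        # step and average them as a rough sequence-level confidence proxy.
        step_probs = [torch.softmax(s[0], dim=-1).max().item() for s in out.scores]
        return float(sum(step_probs) / len(step_probs))
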
    # def _select_relevant_images(self, analyzed_images: Dict, context: Dict, top_k: int) -> List[str]:
    #     """
    #     Select most relevant images based on context matching and caption analysis
    #
    #     Args:
    #         analyzed_images: Dict of image analysis results
    #         context: Context information from ContextLearnAgent
    #         top_k: Number of top images to return
    #
    #     Returns:
    #         List[str]: Paths of top-k relevant images
    #     """
    #     try:
    #         relevance_scores = {}
    #
    #         for img_path, analysis in analyzed_images.items():
    #             # Calculate relevance score based on:
    #             # 1. Caption confidence
    #             # 2. Context matching
    #             # 3. Model reliability (advanced models given higher weight)
    #
    #             base_score = analysis['confidence']
    #
    #             # Adjust score based on model type
    #             model_weight = 1.2 if 'advanced' in analysis['model_used'] else 1.0
    #
    #             # TODO: Implement context matching using embedding similarity
    #             context_match_score = 0.5  # Placeholder
    #
    #             final_score = base_score * model_weight * context_match_score
    #             relevance_scores[img_path] = final_score
    #
    #             self.logger.debug(f"Score for {img_path}: {final_score}")
    #
    #         # Sort by score and get top-k
    #         selected_images = sorted(
    #             relevance_scores.items(),
    #             key=lambda x: x[1],
    #             reverse=True
    #         )[:top_k]
    #
    #         return [img_path for img_path, _ in selected_images]
    #
    #     except Exception as e:
    #         self.logger.error(f"Error in selecting relevant images: {str(e)}")
    #         return []

    def _cleanup_llm(self):
        """
        Cleanup LLM and model resources to manage memory
        """
        try:
            # Clear LangChain resources
            if hasattr(self, 'relevance_chain'):
                if hasattr(self.relevance_chain, 'clear'):
                    self.relevance_chain.clear()

            # Clear any cached data
            if hasattr(self, 'llm'):
                if hasattr(self.llm, 'clear_cache'):
                    self.llm.clear_cache()

            # Force garbage collection
            import gc
            gc.collect()

            # Clear CUDA cache if available
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self.logger.info("LLM and model resources cleaned up")

        except Exception as e:
            self.logger.error(f"Error in LLM cleanup: {str(e)}")

    def _select_relevant_images(self, analyzed_images: Dict, context: Dict, top_k: int = None) -> List[str]:
        """
        Select most relevant images using LLM-enhanced analysis
        """
        try:
            # Use provided top_k or fall back to global config
            top_k = top_k if top_k is not None else config.top_k

            relevance_scores = {}

            for img_path, analysis in analyzed_images.items():
                # Get base scores from caption and classification
                base_score = analysis['combined_confidence']

                # Get LLM-based relevance analysis
                llm_relevance = self._analyze_image_context_relevance(
                    caption=analysis['caption']['text'],
                    classification=analysis['classification']['label'],
                    context=context
                )

                # Calculate final score combining:
                # 1. Base confidence from models
                # 2. LLM relevance score
                # 3. Model weights (advanced models given higher weight)
                model_weight = 1.2 if 'advanced' in analysis['caption']['model_used'] else 1.0

                final_score = (
                    base_score * 0.3 +                        # Original confidence (30%)
                    llm_relevance['relevance_score'] * 0.5 +  # LLM relevance (50%)
                    llm_relevance['confidence'] * 0.2         # LLM confidence (20%)
                ) * model_weight

                # Store results
                relevance_scores[img_path] = {
                    'score': final_score,
                    'analysis': llm_relevance
                }

                self.logger.debug(f"Score for {img_path}: {final_score} "
                                  f"(LLM relevance: {llm_relevance['relevance_score']})")

            # Sort by score and get top-k
            selected_images = sorted(
                relevance_scores.items(),
                key=lambda x: x[1]['score'],
                reverse=True
            )[:top_k]

            # Log selection decisions
            self.log_decision({
                'action': 'image_selection',
                'selected_count': len(selected_images),
                'selection_criteria': {
                    'llm_analysis_used': True,
                    'top_scores': {img: data['score']
                                   for img, data in selected_images}
                }
            })

            self._cleanup_llm()

            return [img_path for img_path, _ in selected_images]

        except Exception as e:
            self.logger.error(f"Error in selecting relevant images: {str(e)}")
            self._cleanup_llm()
            return []

    def _select_relevant_images1(self, analyzed_images: Dict, context: Dict, top_k: int = None) -> List[str]:
        """
        Select most relevant images based on caption, classification, and context

        Args:
            analyzed_images: Dict of image analysis results
            context: Context information from ContextLearnAgent
            top_k: Number of top images to return

        Returns:
            List[str]: Paths of top-k most relevant images
        """
        try:
            top_k = top_k if top_k is not None else config.top_k

            relevance_scores = {}

            for img_path, analysis in analyzed_images.items():
                # Initialize base score from combined confidence
                base_score = analysis['combined_confidence']

                # Adjust score based on model types used
                caption_model_weight = 1.2 if 'advanced' in analysis['caption']['model_used'] else 1.0
                classify_model_weight = 1.2 if 'advanced' in analysis['classification']['model_used'] else 1.0

                # Context matching for captions
                caption_context_score = 0.5  # Default score
                if context and 'keywords' in context:
                    caption_matches = sum(1 for keyword in context['keywords']
                                          if keyword.lower() in analysis['caption']['text'].lower())
                    caption_context_score = min(1.0, caption_matches / len(context['keywords']))

                # Context matching for classification
                classification_context_score = 0.5  # Default score
                if context and 'expected_classes' in context:
                    if analysis['classification']['label'].lower() in [cls.lower() for cls in context['expected_classes']]:
                        classification_context_score = 1.0

                # Calculate final score
                final_score = (
                    base_score *
                    caption_model_weight *
                    classify_model_weight *
                    (caption_context_score + classification_context_score) / 2
                )

                relevance_scores[img_path] = final_score

                self.logger.debug(f"Score for {img_path}: {final_score} "
                                  f"(caption_score: {caption_context_score}, "
                                  f"class_score: {classification_context_score})")

            # Sort by score and get top-k
            selected_images = sorted(
                relevance_scores.items(),
                key=lambda x: x[1],
                reverse=True
            )[:top_k]

            # Log selection decisions
            self.log_decision({
                'action': 'image_selection',
                'selected_count': len(selected_images),
                'top_scores': {img: score for img, score in selected_images}
            })

            return [img_path for img_path, _ in selected_images]

        except Exception as e:
            self.logger.error(f"Error in selecting relevant images: {str(e)}")
            return []

    def _get_classification(self, img_path: str, model_type: str) -> Dict:
        """
        Classify image using ResNet (lightweight) or ViT (advanced)

        Args:
            img_path: Path to the image
            model_type: Either 'lightweight' or 'advanced'

        Returns:
            Dict containing:
                - class: str - Predicted class
                - confidence: float - Confidence score
                - model_used: str - Name of model used
        """
        try:
            model_key = f"{model_type}_classifier"
            model_config = self.models[model_key]

            # Load image
            image = Image.open(img_path).convert('RGB')

            # Initialize model and processor based on type.
            # As with captioning, the model is dropped by _cleanup_models() while the
            # processor is kept, so guard the two separately.
            if model_type == 'lightweight':
                if 'lightweight_clf_processor' not in self.__dict__:
                    self.lightweight_clf_processor = AutoImageProcessor.from_pretrained(model_config.name)
                if 'lightweight_clf_model' not in self.__dict__:
                    self.lightweight_clf_model = ResNetForImageClassification.from_pretrained(model_config.name)
                processor = self.lightweight_clf_processor
                model = self.lightweight_clf_model
            else:
                if 'advanced_clf_processor' not in self.__dict__:
                    self.advanced_clf_processor = ViTImageProcessor.from_pretrained(model_config.name)
                if 'advanced_clf_model' not in self.__dict__:
                    self.advanced_clf_model = ViTForImageClassification.from_pretrained(model_config.name)
                processor = self.advanced_clf_processor
                model = self.advanced_clf_model

            # Process image
            inputs = processor(image, return_tensors="pt")

            # Get predictions
            with torch.no_grad():
                outputs = model(**inputs)
                logits = outputs.logits
                probs = torch.softmax(logits, dim=-1)

            # Get highest probability class
            confidence, predicted_idx = torch.max(probs, dim=-1)
            predicted_label = model.config.id2label[predicted_idx.item()]

            return {
                'class': predicted_label,
                'confidence': float(confidence.item()),
                'model_used': model_config.name
            }

        except Exception as e:
            self.logger.error(f"Error in image classification: {str(e)}")
            return {
                'class': '',
                'confidence': 0.0,
                'model_used': 'none'
            }

    def _cleanup_models(self) -> None:
        """
        Clean up loaded models to free memory
        Strategy: Remove models but keep processors (smaller memory footprint)
        """
        try:
            # Caption models cleanup
            if hasattr(self, 'lightweight_model'):
                del self.lightweight_model
                torch.cuda.empty_cache() if torch.cuda.is_available() else None
                self.logger.info("Cleaned up lightweight caption model")

            if hasattr(self, 'advanced_model'):
                del self.advanced_model
                torch.cuda.empty_cache() if torch.cuda.is_available() else None
                self.logger.info("Cleaned up advanced caption model")

            # Classification models cleanup
            if hasattr(self, 'lightweight_clf_model'):
                del self.lightweight_clf_model
                torch.cuda.empty_cache() if torch.cuda.is_available() else None
                self.logger.info("Cleaned up lightweight classification model")

            if hasattr(self, 'advanced_clf_model'):
                del self.advanced_clf_model
                torch.cuda.empty_cache() if torch.cuda.is_available() else None
                self.logger.info("Cleaned up advanced classification model")

            # Keep processors as they're smaller and faster to reload
            self.logger.info("Model cleanup completed")

        except Exception as e:
            self.logger.error(f"Error during model cleanup: {str(e)}")
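
# Illustrative only: execute() and the helpers above assume self.models maps keys such as
# 'lightweight_caption' to objects exposing .name (a Hugging Face model id) and .threshold
# (the escalation cutoff). The concrete registry lives in models/model_config.py and the
# agent's __init__ (not shown here); the entries below are placeholder assumptions that
# sketch the expected shape, not the project's actual values.
#
#     from dataclasses import dataclass
#
#     @dataclass
#     class ModelConfig:
#         name: str
#         threshold: float
#
#     models = {
#         'lightweight_caption': ModelConfig("Salesforce/blip-image-captioning-base", 0.7),
#         'advanced_caption': ModelConfig("Salesforce/blip2-opt-2.7b", 0.6),
#         'lightweight_classifier': ModelConfig("microsoft/resnet-50", 0.7),
#         'advanced_classifier': ModelConfig("google/vit-base-patch16-224", 0.6),
#     }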
agents/user_input_agent.py
ADDED
@@ -0,0 +1,321 @@
from dataclasses import dataclass
from typing import Dict, List, Any

from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from datetime import datetime

from .base_agent import BaseAgent
from models.data_models import UserInput


class UserInputAgent(BaseAgent):
    def __init__(self, name: str = "UserInputAgent"):
        super().__init__(name)
        self.model_name = "MBZUAI/LaMini-Flan-T5-783M"
        self._initialize_llm()


    def _initialize_llm(self):
        """Initialize the language model pipeline"""
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

            # Create pipeline
            pipe = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                max_length=512,
                temperature=0.3
            )

            # Create LangChain HF pipeline
            self.llm = HuggingFacePipeline(pipeline=pipe)

            # Create intent analysis chain
            intent_template = """
            Analyze this image analysis task request:
            Query: {query}
            Constraints: {constraints}

            Provide a structured analysis including:
            1. Main purpose of the request
            2. Key requirements
            3. Important constraints to consider

            Analysis:
            """

            self.intent_chain = LLMChain(
                llm=self.llm,
                prompt=PromptTemplate(
                    template=intent_template,
                    input_variables=["query", "constraints"]
                )
            )

        except Exception as e:
            self.logger.error(f"Error initializing LLM: {str(e)}")
            raise

    def reason(self, input_data: UserInput) -> List[str]:
        """
        Analyze user input using LangChain for better understanding
        """
        thoughts = []

        try:
            # Initial analysis
            thoughts.append(f"Processing user query: '{input_data.query}'")
            if input_data.constraints:
                thoughts.append(f"With constraints: '{input_data.constraints}'")

            # Use LangChain for intent analysis
            llm_analysis = self.intent_chain.run(
                query=input_data.query,
                constraints=input_data.constraints if input_data.constraints else "None specified"
            )

            # Parse and add LLM insights
            thoughts.append("\nLLM Analysis:")
            for line in llm_analysis.split('\n'):
                if line.strip():
                    thoughts.append(f"- {line.strip()}")

            # Add reasoning about next steps
            thoughts.append("\nPlanned Actions:")
            thoughts.append("1. Extract key terms for context search")
            thoughts.append("2. Prepare constraints for filtering")
            thoughts.append("3. Format requirements for image analysis")

            # Store thoughts in state
            self.state.thoughts.extend(thoughts)
            self.logger.info("Reasoning complete for user input")

            return thoughts

        except Exception as e:
            error_msg = f"Error during reasoning: {str(e)}"
            self.state.errors.append(error_msg)
            self.logger.error(error_msg)
            return thoughts

    # def reason1(self, input_data: UserInput) -> List[str]:
    #     """
    #     Analyze user input and formulate reasoning steps about their intent
    #
    #     Args:
    #         input_data: UserInput containing query, constraints, and top_k
    #
    #     Returns:
    #         List[str]: Reasoning steps about user's intent
    #     """
    #     thoughts = []
    #
    #     try:
    #         # Analyze main query
    #         thoughts.append(f"Analyzing user query: '{input_data.query}'")
    #
    #         # Analyze constraints if provided
    #         if input_data.constraints:
    #             thoughts.append(f"Considering constraints: '{input_data.constraints}'")
    #         else:
    #             thoughts.append("No specific constraints provided")
    #
    #         # Consider result limit
    #         thoughts.append(f"User requests top {input_data.top_k} results")
    #
    #         # Store thoughts in state
    #         self.state.thoughts.extend(thoughts)
    #         self.logger.info("Reasoning complete for user input")
    #
    #         return thoughts
    #
    #     except Exception as e:
    #         error_msg = f"Error during reasoning: {str(e)}"
    #         self.state.errors.append(error_msg)
    #         self.logger.error(error_msg)
    #         return thoughts

    # def execute(self, input_data: UserInput) -> Dict:
    #     """
    #     Process user input and prepare structured data for other agents
    #
    #     Args:
    #         input_data: UserInput containing query, constraints, and top_k
    #
    #     Returns:
    #         Dict containing:
    #             - processed_query: str - Cleaned and formatted query
    #             - constraints: List[str] - List of parsed constraints
    #             - parameters: Dict - Additional parameters including top_k
    #             - status: str - Processing status
    #     """
    #     try:
    #         # First validate the input
    #         if not self.validate(input_data):
    #             return {
    #                 'status': 'error',
    #                 'error': self.state.errors[-1]
    #             }
    #
    #         # Process the input
    #         processed_data = {
    #             'processed_query': input_data.query.strip().lower(),
    #             'constraints': [c.strip() for c in input_data.constraints.split(';')] if input_data.constraints else [],
    #             'parameters': {
    #                 'top_k': input_data.top_k,
    #                 'timestamp': datetime.now().isoformat()
    #             },
    #             'status': 'success'
    #         }
    #
    #         # Log the decision
    #         self.log_decision({
    #             'action': 'process_user_input',
    #             'input': str(input_data),
    #             'output': processed_data
    #         })
    #
    #         return processed_data
    #
    #     except Exception as e:
    #         error_msg = f"Error executing user input processing: {str(e)}"
    #         self.state.errors.append(error_msg)
    #         self.logger.error(error_msg)
    #         return {'status': 'error', 'error': error_msg}

    def execute(self, input_data: UserInput) -> Dict:
        """
        Process user input and prepare it for other agents
        """
        try:
            if not self.validate(input_data):
                return {
                    'status': 'error',
                    'error': self.state.errors[-1]
                }

            # Process input with LLM analysis
            llm_analysis = self._analyze_intent(
                query=input_data.query,
                constraints=input_data.constraints
            )

            results = {
                'processed_input': {
                    'original_query': input_data.query,
                    'constraints': input_data.constraints.split(';') if input_data.constraints else [],
                    'llm_analysis': {
                        'main_purpose': llm_analysis.get('purpose', ''),
                        'key_requirements': llm_analysis.get('requirements', []),
                        'constraint_interpretation': llm_analysis.get('constraints', [])
                    }
                },
                'metadata': {
                    'timestamp': datetime.now().isoformat(),
                    'model_used': self.model_name
                },
                'status': 'success'
            }

            # Log decision
            self.log_decision({
                'action': 'process_user_input',
                'input': str(input_data),
                'llm_insights': llm_analysis
            })

            # Cleanup LLM resources
            self._cleanup_llm()

            return results

        except Exception as e:
            error_msg = f"Error executing user input processing: {str(e)}"
            self.state.errors.append(error_msg)
            self.logger.error(error_msg)
            return {'status': 'error', 'error': error_msg}


    def _analyze_intent(self, query: str, constraints: str) -> Dict:
        """
        Process user input through LangChain and structure the results

        Args:
            query: User's query string
            constraints: User's constraints string

        Returns:
            Dict containing:
                - purpose: str - Main purpose identified
                - requirements: List[str] - Key requirements
                - constraints: List[str] - Interpreted constraints
        """
        try:
            # Get raw LLM analysis
            raw_analysis = self.intent_chain.run(
                query=query,
                constraints=constraints if constraints else "None specified"
            )

            # Structure the analysis
            analysis_dict = {
                'purpose': '',
                'requirements': [],
                'constraints': []
            }

            # Parse the raw output
            current_section = ''
            for line in raw_analysis.split('\n'):
                line = line.strip()
                if 'purpose' in line.lower():
                    current_section = 'purpose'
                elif 'requirement' in line.lower():
                    current_section = 'requirements'
                elif 'constraint' in line.lower():
                    current_section = 'constraints'
                elif line:
                    if current_section == 'purpose':
                        analysis_dict['purpose'] = line
                    elif current_section in ['requirements', 'constraints']:
                        analysis_dict[current_section].append(line)

            return analysis_dict

        except Exception as e:
            self.logger.error(f"Error in intent analysis: {str(e)}")
            return {'purpose': '', 'requirements': [], 'constraints': []}

    def _cleanup_llm(self):
        """
        Cleanup LLM resources to manage memory
        """
        try:
            # Clear any cached data
            if hasattr(self, 'intent_chain'):
                # Clear any stored predictions
                if hasattr(self.intent_chain, 'clear'):
                    self.intent_chain.clear()

            # Force garbage collection
            import gc
            gc.collect()

            # Clear CUDA cache if available
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self.logger.info("LLM resources cleaned up")

        except Exception as e:
            self.logger.error(f"Error in LLM cleanup: {str(e)}")

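# Illustrative usage sketch (placeholder values, not part of the module): how this agent
# is driven by the pipeline in interface/handlers.py.
#
#     agent = UserInputAgent()
#     parsed = agent.execute(UserInput(
#         query="Check safety issues in pump systems",
#         constraints="Exclude routine maintenance issues",
#         top_k=5,
#     ))
#     if parsed['status'] == 'success':
#         print(parsed['processed_input']['llm_analysis']['main_purpose'])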
app.py
ADDED
@@ -0,0 +1,54 @@
# app.py
import gradio as gr
import logging
from interface.app import create_interface
from utils.resource_manager import ResourceManager
from config.settings import config

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def initialize_app():
    """Initialize application components"""
    try:
        # Initialize resource manager
        resource_manager = ResourceManager()

        # Check initial resources
        resources_ok, error_msg = resource_manager.check_resources()
        if not resources_ok:
            logger.error(f"Resource check failed: {error_msg}")
            raise Exception(error_msg)

        # Create Gradio interface
        interface = create_interface()

        return interface

    except Exception as e:
        logger.error(f"Initialization error: {str(e)}")
        raise

if __name__ == "__main__":
    try:
        # Initialize app
        interface = initialize_app()

        # Launch Gradio interface.
        # Note: queueing is always on in recent Gradio releases; the old
        # enable_queue launch flag no longer exists, so it is not passed here.
        interface.launch(
            share=True,
            server_name="0.0.0.0",
            server_port=7860,
            max_threads=4,      # Limit concurrent processing
            auth=None,          # Add authentication if needed
            ssl_verify=False,   # For HuggingFace spaces
        )

    except Exception as e:
        logger.error(f"Application startup failed: {str(e)}")
        raise
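
# Illustrative only: if request queueing needs tuning (replacing the removed
# enable_queue flag), it can be configured on the Blocks object before launch.
# The concurrency numbers below are assumptions, not values taken from this repo.
#
#     interface.queue(max_size=16, default_concurrency_limit=2).launch(
#         server_name="0.0.0.0",
#         server_port=7860,
#     )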
config/__init__.py
ADDED
File without changes
config/settings.py
ADDED
@@ -0,0 +1,11 @@
# config/settings.py
from dataclasses import dataclass

@dataclass
class GlobalConfig:
    top_k: int = 5                   # default value
    max_wiki_results: int = 5        # our existing ContextLearnAgent parameter
    report_format: str = "detailed"  # for AssemblerAgent

# Can be accessed throughout the application
config = GlobalConfig()
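
# Illustrative usage (not part of the module): any component can import the shared
# instance and read or override its defaults at runtime.
#
#     from config.settings import config
#
#     config.top_k = 3              # tighten image selection globally
#     print(config.report_format)   # "detailed"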
interface/__init__.py
ADDED
@@ -0,0 +1,49 @@
# interface/app.py
import gradio as gr
from typing import List, Dict
from config.settings import config

def create_interface():
    """Create the Gradio interface"""

    # Create interface components
    with gr.Blocks(title="Pump Inspection Analysis") as app:
        gr.Markdown("# Pump Inspection Analysis System")

        with gr.Row():
            # Input components
            with gr.Column():
                query = gr.Textbox(
                    label="What would you like to analyze?",
                    placeholder="e.g., Check safety issues in pump systems",
                    lines=3
                )

                constraints = gr.Textbox(
                    label="Any specific constraints? (optional)",
                    placeholder="e.g., Exclude routine maintenance issues",
                    lines=2
                )

                top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Number of top results to show"
                )

                report_format = gr.Radio(
                    choices=["summary", "detailed"],
                    value="summary",
                    label="Report Format"
                )

                images = gr.File(
                    file_count="multiple",
                    label="Upload Images",
                    file_types=["image"]
                )
interface/app.py
ADDED
@@ -0,0 +1,153 @@
# interface/app.py
import gradio as gr
from typing import List, Dict
from config.settings import config
from .handlers import ProcessingHandler
from .utils import InputValidator

def create_interface():
    """Create the Gradio interface"""
    handler = ProcessingHandler()

    # Create interface components
    with gr.Blocks(title="Image Inspection Analysis") as app:
        gr.Markdown("# Image Inspection Analysis System")

        with gr.Row():
            # Input components
            with gr.Column():
                query = gr.Textbox(
                    label="What would you like to analyze?",
                    placeholder="e.g., Check safety issues in pump systems",
                    lines=3
                )

                constraints = gr.Textbox(
                    label="Any specific constraints? (optional)",
                    placeholder="e.g., Exclude routine maintenance issues",
                    lines=2
                )

                top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Number of top results to show"
                )

                report_format = gr.Radio(
                    choices=["summary", "detailed"],
                    value="summary",
                    label="Report Format"
                )

                images = gr.File(
                    file_count="multiple",
                    label="Upload Images",
                    file_types=["image"]
                )

                # Single submit button, re-used by the event wiring below
                submit_btn = gr.Button("Analyze", variant="primary")

            # Right column - Outputs
            with gr.Column():
                with gr.Tab("Results"):
                    analysis_status = gr.Markdown("Ready for analysis...")

                    results_box = gr.Markdown(
                        visible=False,
                        label="Analysis Results"
                    )

                    selected_images = gr.Gallery(
                        label="Selected Relevant Images",
                        visible=False,
                        columns=2,
                        height=400
                    )

                    confidence_scores = gr.Json(
                        label="Confidence Scores",
                        visible=False
                    )

                with gr.Tab("Processing Details"):
                    processing_status = gr.JSON(
                        label="Processing Steps",
                        visible=False
                    )

                with gr.Tab("Errors"):
                    error_box = gr.Markdown(
                        visible=False
                    )

        # Second: Helper functions for UI updates
        def update_ui_on_error(error_msg):
            return {
                results_box: gr.update(visible=True, value=error_msg),
                selected_images: gr.update(visible=False),
                confidence_scores: gr.update(visible=False),
                processing_status: gr.update(visible=True, value={'status': 'error'})
            }

        def update_ui_on_success(results):
            return {
                results_box: gr.update(visible=True, value=results['content']),
                selected_images: gr.update(visible=True, value=results['images']),
                confidence_scores: gr.update(visible=True, value=results['scores']),
                processing_status: gr.update(visible=True, value={'status': 'success'})
            }

        def validate_and_process(query, constraints, images, top_k, report_format):
            # Validate inputs (the validator expects all five arguments)
            is_valid, error_message = InputValidator.validate_inputs(
                query, constraints, images, top_k, report_format
            )

            if not is_valid:
                return update_ui_on_error(error_message)

            # If valid, proceed with processing
            return handler.process_inputs(
                query, constraints, images, top_k, report_format
            )

        # Connect the submit button: first clear the outputs, then process
        submit_btn.click(
            fn=lambda: [
                gr.update(visible=True, value="Validating inputs......"),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False)
            ],
            inputs=None,
            outputs=[analysis_status, results_box, selected_images,
                     confidence_scores, processing_status]
        ).then(  # Chain the processing after clearing
            fn=validate_and_process,
            inputs=[query, constraints, images, top_k, report_format],
            outputs=[results_box, selected_images, confidence_scores, processing_status],
            # show_progress=True
        ).then(
            # Third: Update UI based on results
            fn=lambda results: update_ui_on_success(results) if results['status'] == 'success'
                               else update_ui_on_error(results['error']),
            inputs=[processing_status],
            outputs=[results_box, selected_images, confidence_scores, processing_status]
        )

    return app

# Launch the interface
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
interface/display.py
ADDED
@@ -0,0 +1,50 @@
# interface/display.py
from typing import Dict, List
import json

class DisplayFormatter:
    @staticmethod
    def format_error(error_msg: str) -> Dict:
        """Format error messages for display"""
        return {
            'status': 'error',
            'content': f"""
### ❌ Error
{error_msg}

Please try again or contact support if the issue persists.
"""
        }

    @staticmethod
    def format_results(report: Dict, images: List[str], scores: Dict) -> Dict:
        """Format successful results for display"""
        try:
            # Format main results
            markdown_content = f"""
### 📊 Analysis Results

#### Query Analysis
{report.get('query_analysis', 'No analysis available')}

#### Key Findings
{report.get('key_findings', 'No findings available')}

#### Recommendations
{report.get('recommendations', 'No recommendations available')}
"""

            # Format confidence scores
            confidence_display = {
                'Model Performance': scores.get('model_decisions', {}),
                'Overall Confidence': f"{scores.get('average_confidence', 0) * 100:.1f}%"
            }

            return {
                'status': 'success',
                'content': markdown_content,
                'scores': confidence_display,
                'images': images
            }
        except Exception as e:
            return DisplayFormatter.format_error(f"Error formatting results: {str(e)}")
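
# Illustrative usage (placeholder data, not part of the module):
#
#     formatted = DisplayFormatter.format_results(
#         report={'query_analysis': '...', 'key_findings': '...', 'recommendations': '...'},
#         images=['pump_01.jpg'],
#         scores={'model_decisions': {}, 'average_confidence': 0.82},
#     )
#     assert formatted['status'] == 'success'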
interface/handlers.py
ADDED
@@ -0,0 +1,129 @@
# interface/handlers.py
import gradio as gr
from typing import Dict, List, Tuple
from agents.user_input_agent import UserInputAgent
from agents.context_agent import ContextLearnAgent
from agents.image_agent import ImageAnalyzerAgent
from agents.assembler_agent import AssemblerAgent
from models.data_models import UserInput, ContextInput, ImageAnalysisInput
from config.settings import config
from .display import DisplayFormatter
from utils.resource_manager import ResourceManager

class ProcessingHandler:
    def __init__(self):
        self.formatter = DisplayFormatter()
        self.resource_manager = ResourceManager()

        """Initialize agents"""
        self.user_agent = UserInputAgent()
        self.context_agent = ContextLearnAgent()
        self.image_agent = ImageAnalyzerAgent()
        self.assembler_agent = AssemblerAgent()

    async def process_inputs(
        self,
        query: str,
        constraints: str,
        images: List[str],
        top_k: int,
        report_format: str,
        progress=gr.Progress()
    ) -> Tuple[str, List[str], Dict, Dict]:
        """
        Process inputs through agent pipeline

        Returns:
            Tuple containing:
            - results_markdown: str
            - selected_images: List[str]
            - confidence_scores: Dict
            - processing_details: Dict
        """
        try:
            resources_ok, error_msg = self.resource_manager.check_resources()
            if not resources_ok:
                raise Exception(error_msg)
            processing_details = {'status': 'processing'}

            # Step 1: Process user input
            progress(0.1, desc="Processing user input...")
            user_results = self.user_agent.execute(
                # top_k is required by the UserInput dataclass
                UserInput(query=query, constraints=constraints, top_k=top_k)
            )

            if user_results['status'] == 'error':
                raise Exception(user_results['error'])

            # Step 2: Gather context
            progress(0.3, desc="Learning context...")
            context_results = self.context_agent.execute(
                ContextInput(
                    processed_query=user_results['processed_input']['original_query'],
                    constraints=user_results['processed_input']['constraints'],
                    domain="oil_and_gas"
                )
            )

            # Step 3: Analyze images
            progress(0.5, desc="Analyzing images...")
            image_results = self.image_agent.execute(
                ImageAnalysisInput(
                    images=images,
                    context=context_results['gathered_context'],
                    constraints=user_results['processed_input']['constraints'],
                    top_k=top_k
                )
            )

            # Step 4: Assemble final report
            progress(0.8, desc="Assembling report...")
            final_report = self.assembler_agent.execute({
                'user_input_results': user_results,
                'context_results': context_results,
                'image_results': image_results,
                'report_format': report_format
            })

            progress(1.0, desc="Complete!")

            # Format results for Gradio
            # return (
            #     final_report['summary'],            # results_markdown
            #     image_results['selected_images'],   # selected_images
            #     image_results['model_decisions'],   # confidence_scores
            #     processing_details                  # processing_details
            # )

            # Monitor resources during processing
            self.resource_manager.monitor_and_cleanup()

            # Format results
            formatted_results = self.formatter.format_results(
                final_report['summary'],
                image_results['selected_images'],
                image_results['model_decisions']
            )

            # Final cleanup
            self.resource_manager.cleanup()

            return (
                formatted_results['content'],
                formatted_results['images'],
                formatted_results['scores'],
                {'status': 'success', 'details': processing_details}
            )

        except Exception as e:
            self.resource_manager.cleanup()  # Ensure cleanup on error
            error_format = self.formatter.format_error(str(e))
            return (
                error_format['content'],
                [],
                {},
                {'status': 'error', 'error': str(e)}
            )
            # error_msg = f"Error during processing: {str(e)}"
            # processing_details['status'] = 'error'
            # processing_details['error'] = error_msg
            # return "", [], {}, processing_details

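# Illustrative only (local smoke test, not part of the module): process_inputs is a
# coroutine, so it must be awaited when exercised outside of a Gradio event handler.
# The file path below is a placeholder.
#
#     import asyncio
#
#     handler = ProcessingHandler()
#     outputs = asyncio.run(handler.process_inputs(
#         query="Check safety issues in pump systems",
#         constraints="",
#         images=["pump_01.jpg"],
#         top_k=1,
#         report_format="summary",
#     ))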
interface/utils.py
ADDED
@@ -0,0 +1,89 @@
# interface/utils.py
from typing import Tuple, List
import os
from PIL import Image

class InputValidator:
    # Constants for validation
    MAX_QUERY_LENGTH = 500
    MIN_IMAGES = 1
    MAX_IMAGES = 10
    ALLOWED_IMAGE_TYPES = ['.jpg', '.jpeg', '.png']
    MAX_IMAGE_SIZE_MB = 5
    MAX_IMAGE_RESOLUTION = (2048, 2048)
    ALLOWED_FORMATS = ['summary', 'detailed']

    @staticmethod
    def validate_inputs(
        query: str,
        constraints: str,
        images: List[str],
        top_k: int,
        report_format: str
    ) -> Tuple[bool, str]:
        """
        Validate all user inputs.

        Args:
            query: User's query string
            constraints: User's constraints string
            images: List of image paths
            top_k: Number of top results to return
            report_format: Type of report to generate

        Returns:
            Tuple of (is_valid: bool, error_message: str)
        """
        try:
            # Query validation
            if not query or query.isspace():
                return False, "Query is required"
            if len(query) > InputValidator.MAX_QUERY_LENGTH:
                return False, f"Query too long (max {InputValidator.MAX_QUERY_LENGTH} characters)"

            # Images validation
            if not images:
                return False, "At least one image is required"
            if len(images) > InputValidator.MAX_IMAGES:
                return False, f"Too many images. Maximum allowed: {InputValidator.MAX_IMAGES}"

            # Process each image
            for img_path in images:
                # File type check
                file_ext = os.path.splitext(img_path)[1].lower()
                if file_ext not in InputValidator.ALLOWED_IMAGE_TYPES:
                    return False, f"Invalid image type: {file_ext}. Allowed types: {', '.join(InputValidator.ALLOWED_IMAGE_TYPES)}"

                # File size check
                file_size_mb = os.path.getsize(img_path) / (1024 * 1024)
                if file_size_mb > InputValidator.MAX_IMAGE_SIZE_MB:
                    return False, f"Image too large: {file_size_mb:.1f}MB. Maximum size: {InputValidator.MAX_IMAGE_SIZE_MB}MB"

                # Image integrity and resolution check
                try:
                    with Image.open(img_path) as img:
                        width, height = img.size  # read size before verify(), which exhausts the file
                        img.verify()  # Verify image integrity
                        if width > InputValidator.MAX_IMAGE_RESOLUTION[0] or height > InputValidator.MAX_IMAGE_RESOLUTION[1]:
                            return False, f"Image resolution too high. Maximum: {InputValidator.MAX_IMAGE_RESOLUTION[0]}x{InputValidator.MAX_IMAGE_RESOLUTION[1]}"
                except Exception:
                    return False, f"Invalid or corrupted image: {os.path.basename(img_path)}"

            # Top-k validation
            if not isinstance(top_k, int) or top_k < 1:
                return False, "Top-k must be a positive integer"
            if top_k > len(images):
                return False, f"Top-k ({top_k}) cannot be larger than number of images ({len(images)})"

            # Report format validation
            if report_format not in InputValidator.ALLOWED_FORMATS:
                return False, f"Invalid report format. Allowed formats: {', '.join(InputValidator.ALLOWED_FORMATS)}"

            # Optional constraints validation
            if constraints and len(constraints) > InputValidator.MAX_QUERY_LENGTH:
                return False, f"Constraints too long (max {InputValidator.MAX_QUERY_LENGTH} characters)"

            return True, ""

        except Exception as e:
            return False, f"Validation error: {str(e)}"
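A short usage sketch for the validator as the handlers would likely call it before any model work; the paths and query below are placeholders, not values from the repository:

# Hypothetical call into InputValidator; the image paths are placeholders.
from interface.utils import InputValidator

candidate_images = ["/tmp/pump_front.jpg", "/tmp/pump_side.png"]
is_valid, error = InputValidator.validate_inputs(
    query="Identify visible corrosion on the impeller housing",
    constraints="focus on the mechanical seal area",
    images=candidate_images,
    top_k=2,
    report_format="summary",
)
if not is_valid:
    # Surface the message instead of raising, matching the (bool, str) contract.
    print(f"Validation failed: {error}")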
models/__init__.py
ADDED
|
File without changes
|
models/data_models.py
ADDED
|
@@ -0,0 +1,44 @@
from dataclasses import dataclass, field
from typing import List, Dict

# @dataclass
# class ContextInput:
#     processed_query: str
#     constraints: List[str]
#     search_sources: List[str] = field(default_factory=lambda: ['wikipedia'])
#     max_results: int = 5
@dataclass
class ContextInput:
    processed_query: str
    constraints: List[str]
    domain: str  # e.g., "oil_and_gas", "fine_art"
    max_results: int = 5
    min_confidence: float = 0.7

@dataclass
class UserInput:
    query: str
    constraints: str
    top_k: int

@dataclass
class ImageAnalysisInput:
    images: List[str]  # List of image paths
    context: Dict  # Context from ContextLearnAgent
    constraints: List[str]
    top_k: int = 5

@dataclass
class AgentState:
    """Tracks the current state of the agent"""
    intent: str = ""
    thoughts: List[str] = field(default_factory=list)
    decisions: List[Dict] = field(default_factory=list)
    errors: List[str] = field(default_factory=list)

@dataclass
class AssemblerInput:
    user_input_results: Dict  # From UserInputAgent
    context_results: Dict  # From ContextLearnAgent
    image_results: Dict  # From ImageAnalyzerAgent
    report_format: str = "detailed"  # or "summary"
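To make the hand-off between agents concrete, here is a hedged example of constructing these dataclasses; every field value is illustrative, not taken from the agents' actual outputs:

# Illustrative construction of the shared data models; all values are made up.
from models.data_models import (
    UserInput, ContextInput, ImageAnalysisInput, AssemblerInput
)

user = UserInput(query="Assess impeller wear", constraints="centrifugal pumps only", top_k=3)

context = ContextInput(
    processed_query=user.query.lower(),
    constraints=[user.constraints],
    domain="oil_and_gas",
    max_results=5,
)

analysis = ImageAnalysisInput(
    images=["pump_a.jpg", "pump_b.jpg"],
    context={"domain_terms": ["impeller", "volute", "seal"]},
    constraints=[user.constraints],
    top_k=user.top_k,
)

report_request = AssemblerInput(
    user_input_results={"query": user.query},
    context_results={"snippets": []},
    image_results={"selected_images": analysis.images[:user.top_k]},
    report_format="summary",
)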
models/model_config.py
ADDED
|
@@ -0,0 +1,9 @@
from dataclasses import dataclass
from typing import Literal

@dataclass
class ModelConfig:
    name: str
    type: Literal['lightweight', 'advanced']
    task: Literal['caption', 'classification', 'detection', 'summarization']
    threshold: float = 0.5
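A hedged sketch of how ModelConfig could describe the lightweight/advanced model pairs mentioned in the README; the Hugging Face repo IDs are assumptions and may differ from what agents/image_agent.py actually loads:

# Assumed model registry; repo IDs are illustrative, not read from image_agent.py.
from models.model_config import ModelConfig

MODEL_REGISTRY = [
    ModelConfig(name="Salesforce/blip-image-captioning-base", type="lightweight", task="caption"),
    ModelConfig(name="Salesforce/blip2-opt-2.7b", type="advanced", task="caption", threshold=0.6),
    ModelConfig(name="microsoft/resnet-50", type="lightweight", task="classification"),
    ModelConfig(name="google/vit-base-patch16-224", type="advanced", task="classification", threshold=0.7),
]

# Run the lightweight models first, escalating to advanced ones only when needed.
lightweight_first = sorted(MODEL_REGISTRY, key=lambda cfg: cfg.type != "lightweight")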
requirements.txt
ADDED
|
@@ -0,0 +1,31 @@
# Core dependencies
gradio>=5.16.0  # match the Space's sdk_version pinned in README.md
torch>=2.0.0
transformers>=4.30.0
langchain>=0.0.300

# Image processing
Pillow>=10.0.0
numpy>=1.24.0

# Resource management
psutil>=5.9.0

# Web and API
requests>=2.31.0
httpx>=0.24.1

# Utilities
python-dotenv>=1.0.0
tqdm>=4.65.0
pandas>=2.0.0

# Text processing
nltk>=3.8.1
beautifulsoup4>=4.12.0

# Wikipedia access
wikipedia-api>=0.6.0

# Logging and monitoring
# (logging is part of the Python standard library; it must not be pip-installed)
utils/resource_manager.py
ADDED
|
@@ -0,0 +1,54 @@
# utils/resource_manager.py
import torch
import gc
import psutil
import os
from typing import Dict, Tuple
import logging

class ResourceManager:
    # Constants
    MAX_MEMORY_USAGE_PCT = 90   # Maximum memory usage percentage
    CLEANUP_THRESHOLD_PCT = 80  # Threshold to trigger cleanup

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self._active_models: Dict = {}

    def check_resources(self) -> Tuple[bool, str]:
        """Check if sufficient resources are available"""
        try:
            memory = psutil.Process(os.getpid()).memory_percent()
            if memory > self.MAX_MEMORY_USAGE_PCT:
                return False, f"Memory usage too high: {memory:.1f}%"
            return True, ""
        except Exception as e:
            return False, f"Resource check error: {str(e)}"

    def cleanup(self):
        """Force cleanup of resources"""
        try:
            # Clear models
            self._active_models.clear()

            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Force garbage collection
            gc.collect()

            self.logger.info("Resource cleanup completed")

        except Exception as e:
            self.logger.error(f"Error during cleanup: {str(e)}")

    def monitor_and_cleanup(self):
        """Monitor resources and cleanup if needed"""
        try:
            memory = psutil.Process(os.getpid()).memory_percent()
            if memory > self.CLEANUP_THRESHOLD_PCT:
                self.logger.warning(f"High memory usage ({memory:.1f}%), triggering cleanup")
                self.cleanup()
        except Exception as e:
            self.logger.error(f"Monitoring error: {str(e)}")
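Finally, a brief hedged example of the call pattern the handler code above implies: check before starting, monitor between batches, always clean up at the end. The batch contents are placeholders.

# Illustrative ResourceManager usage; the batch contents are placeholders.
from utils.resource_manager import ResourceManager

manager = ResourceManager()

ok, reason = manager.check_resources()
if not ok:
    raise RuntimeError(f"Refusing to start analysis: {reason}")

try:
    for batch in (["img_1.jpg"], ["img_2.jpg"]):
        # ... run captioning / classification on the batch here ...
        manager.monitor_and_cleanup()  # frees memory if usage crosses the 80% threshold
finally:
    manager.cleanup()  # always release cached models and the CUDA cache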