Spaces:
Sleeping
Sleeping
| """ | |
| Product Comparison Multi-Agent System | |
| This module implements a multi-agent system for product comparison based on images. | |
| The system uses various specialized agents to process images, extract features, | |
| compare products, and provide recommendations. | |
| Main components: | |
| - Coordinator: Orchestrates the multi-agent workflow | |
| - ImageProcessingAgent: Handles image recognition and object detection | |
| - FeatureExtractionAgent: Extracts product specifications and features | |
| - ComparisonAgent: Compares products based on extracted features | |
| - RecommendationAgent: Provides purchase recommendations | |
| Each agent utilizes vision models and LLMs to perform its specialized tasks. | |
| """ | |
| import os | |
| import uuid | |
| import json | |
| import time | |
| from typing import Dict, List, Any, Optional, Tuple | |
| import base64 | |
| from io import BytesIO | |
| from datetime import datetime | |
| from threading import Thread | |
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| # Import LangChain components for agent implementation | |
| try: | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser, JsonOutputParser | |
| from langchain.agents import create_openai_functions_agent | |
| from langchain.agents import AgentExecutor | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain.tools.render import format_tool_to_openai_function | |
| from langchain_openai import ChatOpenAI | |
| from langchain_experimental.tools.python.tool import PythonAstREPLTool | |
| LANGCHAIN_AVAILABLE = True | |
| except ImportError: | |
| print("Warning: LangChain components not available. Product comparison will work with limited functionality.") | |
| # Set all LangChain components to None | |
| ChatPromptTemplate = None | |
| StrOutputParser = None | |
| JsonOutputParser = None | |
| create_openai_functions_agent = None | |
| AgentExecutor = None | |
| ConversationBufferMemory = None | |
| format_tool_to_openai_function = None | |
| ChatOpenAI = None | |
| PythonAstREPLTool = None | |
| LANGCHAIN_AVAILABLE = False | |
| # Import vision models if available | |
| try: | |
| from ultralytics import YOLO | |
| except ImportError: | |
| YOLO = None | |
| try: | |
| from transformers import CLIPProcessor, CLIPModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from transformers import ViTImageProcessor, ViTForImageClassification | |
| except ImportError: | |
| CLIPProcessor = None | |
| CLIPModel = None | |
| AutoModelForCausalLM = None | |
| AutoTokenizer = None | |
| ViTImageProcessor = None | |
| ViTForImageClassification = None | |
| # Session storage for SSE communication | |
class SessionManager:
    """Manages active product comparison sessions and their event streams.

    ``active_sessions`` maps a session id to a dict holding ``created_at``
    (a ``datetime``), ``messages`` (a list of message dicts) and ``status``
    (a string). Final analysis results are kept separately in ``results``,
    keyed by the same session id.
    """

    def __init__(self):
        self.active_sessions = {}
        self.results = {}

    def create_session(self, session_id=None):
        """Register a session and return its id.

        A random UUID4 string is generated when ``session_id`` is None.
        Passing an id that already exists replaces that session's record.
        """
        if session_id is None:
            session_id = str(uuid.uuid4())
        self.active_sessions[session_id] = {
            "created_at": datetime.now(),
            "messages": [],
            "status": "created",
        }
        return session_id

    def add_message(self, session_id, message, agent_type="system"):
        """Append a message to the session event stream.

        Returns True when the message was stored, False when the session
        is unknown.
        """
        session = self.active_sessions.get(session_id)
        if session is None:
            return False
        session["messages"].append({
            "message": message,
            "agent": agent_type,
            "timestamp": datetime.now().strftime("%H:%M:%S"),
        })
        return True

    def get_messages(self, session_id):
        """Return the session's message list, or an empty list if unknown."""
        session = self.active_sessions.get(session_id)
        if session is None:
            return []
        return session["messages"]

    def set_final_result(self, session_id, result):
        """Store the final analysis result and mark the session completed.

        The result is stored even when the session record no longer exists
        (for example after ``clean_old_sessions`` ran); in that case only the
        status update is skipped instead of raising ``KeyError``.
        """
        self.results[session_id] = result
        session = self.active_sessions.get(session_id)
        if session is not None:
            session["status"] = "completed"

    def get_final_result(self, session_id):
        """Return the stored final result, or None if there is none."""
        return self.results.get(session_id)

    def set_status(self, session_id, status):
        """Update the session status; unknown sessions are ignored."""
        session = self.active_sessions.get(session_id)
        if session is not None:
            session["status"] = status

    def get_status(self, session_id):
        """Return the session status, or None for an unknown session."""
        session = self.active_sessions.get(session_id)
        if session is None:
            return None
        return session["status"]

    def clean_old_sessions(self, max_age_hours=24):
        """Drop sessions older than ``max_age_hours`` along with their results.

        Results whose session record no longer exists are purged as well so
        that ``results`` cannot grow without bound.
        """
        now = datetime.now()
        max_age_seconds = max_age_hours * 3600
        # Collect first: a dict must not be mutated while it is iterated.
        expired = [
            session_id
            for session_id, session_data in self.active_sessions.items()
            if (now - session_data["created_at"]).total_seconds() > max_age_seconds
        ]
        for session_id in expired:
            del self.active_sessions[session_id]
            self.results.pop(session_id, None)

        orphaned = [
            session_id
            for session_id in self.results
            if session_id not in self.active_sessions
        ]
        for session_id in orphaned:
            del self.results[session_id]
# Module-level SessionManager instance; BaseAgent.log() routes agent
# messages through it via session_manager.add_message().
session_manager = SessionManager()
| # Base Agent Class | |
class BaseAgent:
    """Base class for all specialized agents.

    Attributes:
        name: Agent name; also used as the ``agent_type`` of logged messages.
        llm: The language model used by the agent, or None when the agent
            runs in fallback (no-LLM) mode.
    """

    # Single source of truth for the default model so that the value passed
    # to ChatOpenAI and the value reported in the log cannot drift apart.
    DEFAULT_MODEL = "gpt-4o"
    DEFAULT_TEMPERATURE = 0.7

    def __init__(self, name, llm=None):
        """Create the agent.

        Args:
            name: Agent name.
            llm: Optional pre-built LLM. When None, a ChatOpenAI client is
                created if LangChain is importable and ``OPENAI_API_KEY`` is
                set; otherwise ``self.llm`` stays None.
        """
        self.name = name
        if llm is not None:
            self.llm = llm
            return

        self.llm = None
        try:
            if not (LANGCHAIN_AVAILABLE and ChatOpenAI is not None):
                print(f"Warning: LangChain not available. {name} will use fallback mode.")
                return
            api_key = os.environ.get('OPENAI_API_KEY')
            if not api_key:
                print(f"Warning: OPENAI_API_KEY not found. {name} will use fallback mode.")
                return
            self.llm = ChatOpenAI(
                model=self.DEFAULT_MODEL,
                temperature=self.DEFAULT_TEMPERATURE,
                api_key=api_key,
            )
            print(f"Initialized {name} with ChatOpenAI ({self.DEFAULT_MODEL})")
        except Exception as e:
            # Best effort: any client construction failure degrades to
            # fallback mode rather than preventing the agent from existing.
            print(f"Error initializing ChatOpenAI for {name}: {e}")
            self.llm = None

    def log(self, session_id, message):
        """Add ``message`` to the session's event stream under this agent's name.

        Returns whatever ``session_manager.add_message`` returns.
        """
        return session_manager.add_message(session_id, message, agent_type=self.name)

    async def process(self, session_id, data):
        """Process data with this agent - to be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("Subclasses must implement this method")
class ImageProcessingAgent(BaseAgent):
    """Agent responsible for image processing and object recognition.

    This agent uses computer vision models to detect product type, brand, model,
    and other visual characteristics from product images.
    """

    # Pretrained checkpoints loaded when the agent is constructed.
    YOLO_WEIGHTS = "yolov8n.pt"
    VIT_CHECKPOINT = "google/vit-base-patch16-224"
    CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

    def __init__(self, name="ImageProcessingAgent"):
        super().__init__(name)
        # Maps a model key ("yolo", "vit_processor", ...) to a loaded model;
        # keys are simply absent for models that could not be loaded.
        self.models = self._load_vision_models()

    def _load_vision_models(self):
        """Load the vision models that are importable and return them by key.

        Each model family is loaded independently; a failure is printed and
        that family is left out of the returned dict.
        """
        models = {}

        try:
            if YOLO is not None:
                models["yolo"] = YOLO(self.YOLO_WEIGHTS)
        except Exception as e:
            print(f"Error loading YOLO: {e}")

        try:
            if ViTImageProcessor is not None and ViTForImageClassification is not None:
                models["vit_processor"] = ViTImageProcessor.from_pretrained(self.VIT_CHECKPOINT)
                models["vit_model"] = ViTForImageClassification.from_pretrained(self.VIT_CHECKPOINT)
        except Exception as e:
            print(f"Error loading ViT: {e}")

        try:
            if CLIPProcessor is not None and CLIPModel is not None:
                models["clip_processor"] = CLIPProcessor.from_pretrained(self.CLIP_CHECKPOINT)
                models["clip_model"] = CLIPModel.from_pretrained(self.CLIP_CHECKPOINT)
        except Exception as e:
            print(f"Error loading CLIP: {e}")

        return models

    def _process_with_yolo(self, image):
        """Run YOLO object detection on ``image``.

        Returns ``{"detections": [...]}`` where each entry has ``class``,
        ``confidence`` (rounded to 2 decimals) and ``bbox`` (rounded xyxy
        coordinates), or ``{"error": ...}`` when YOLO is not loaded.
        """
        if "yolo" not in self.models:
            return {"error": "YOLO model not available"}

        # PIL images are converted to arrays; anything else is passed as-is.
        image_np = np.array(image) if isinstance(image, Image.Image) else image
        results = self.models["yolo"](image_np)

        detections = []
        for result in results:
            for box in result.boxes:
                class_id = int(box.cls[0].item())
                detections.append({
                    "class": result.names[class_id],
                    "confidence": round(box.conf[0].item(), 2),
                    "bbox": [round(x) for x in box.xyxy[0].tolist()],
                })
        return {"detections": detections}

    def _process_with_vit(self, image):
        """Classify ``image`` with ViT and return the top predictions.

        Returns ``{"classifications": [...]}`` with up to five entries of
        ``rank``, ``class`` and ``probability`` (rounded to 3 decimals), or
        ``{"error": ...}`` when the ViT model or processor is not loaded.
        """
        if "vit_model" not in self.models or "vit_processor" not in self.models:
            return {"error": "ViT model not available"}

        inputs = self.models["vit_processor"](images=image, return_tensors="pt")
        with torch.no_grad():
            logits = self.models["vit_model"](**inputs).logits

        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        # torch.topk raises when k exceeds the number of classes.
        top_k = min(5, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs, top_k)

        id2label = self.models["vit_model"].config.id2label
        classifications = [
            {
                "rank": rank,
                "class": id2label[idx.item()],
                "probability": round(prob.item(), 3),
            }
            for rank, (prob, idx) in enumerate(zip(top_probs, top_indices), start=1)
        ]
        return {"classifications": classifications}

    def _extract_product_info_from_vision(self, image, results):
        """Turn raw vision results into a product-information dict.

        Without an LLM, the most confident detection (if any) supplies
        ``product_type`` and ``confidence``. With an LLM, the model is asked
        for a JSON object; replies that are not valid JSON are parsed as
        loose ``key: value`` lines. The return value is always a dict.

        ``image`` is accepted for interface compatibility and is not used.
        """
        if self.llm is None:
            detections = results.get("detections")
            if detections:
                top_detection = max(detections, key=lambda d: d["confidence"])
                return {
                    "product_type": top_detection["class"],
                    "confidence": top_detection["confidence"],
                }
            return {"product_type": "unknown"}

        prompt = (
            "Analyze this product image detection results and extract detailed product information.\n"
            f"Detection results: {json.dumps(results)}\n"
            "Extract the following information in JSON format:\n"
            "- product_type: The category of the product (car, smartphone, laptop, etc.)\n"
            "- brand: The most likely brand of the product\n"
            "- model: Any model information that can be determined\n"
            "- color: The main color of the product\n"
            "- key_features: List of notable visual features\n"
            "Return only valid JSON format."
        )
        try:
            response = self.llm.invoke(prompt)
            response_text = response.content if hasattr(response, 'content') else str(response)
            try:
                extracted = json.loads(response_text)
            except json.JSONDecodeError:
                # Not valid JSON: salvage "key: value" lines.
                extracted = {}
                for line in response_text.split('\n'):
                    if ':' not in line:
                        continue
                    key, value = line.split(':', 1)
                    key = key.strip().lower().replace(' ', '_')
                    value = value.strip().strip('"').strip("'")
                    if key and value:
                        extracted[key] = value
            if not isinstance(extracted, dict):
                # Valid JSON that is not an object (list, string, number)
                # would break callers that use dict methods on the result.
                return {
                    "product_type": "unknown",
                    "error": "LLM response was not a JSON object",
                }
            return extracted
        except Exception as e:
            print(f"Error extracting product info: {e}")
            return {"product_type": "unknown", "error": str(e)}

    async def process(self, session_id, data):
        """Process product images to identify products and extract features.

        Args:
            session_id: Session whose event stream receives progress messages.
            data: Dict with an optional ``"images"`` list; None entries are
                skipped but still consume their ``image<N>`` index.

        Returns:
            Dict with ``vision_results`` (per-image raw vision output plus
            product info) and ``product_info`` (per-image product info).
        """
        self.log(session_id, "Starting image analysis to identify products...")
        results = {}
        product_info = {}

        for index, img in enumerate(data.get("images", []), start=1):
            if img is None:
                continue
            image_key = f"image{index}"
            self.log(session_id, f"Processing {image_key}...")

            yolo_results = self._process_with_yolo(img)
            vit_results = self._process_with_vit(img)
            vision_results = {**yolo_results, **vit_results}
            # Both helpers report failures under "error"; keep every message
            # instead of letting the ViT error overwrite the YOLO one.
            errors = [r["error"] for r in (yolo_results, vit_results) if "error" in r]
            if errors:
                vision_results["error"] = "; ".join(errors)

            info = self._extract_product_info_from_vision(img, vision_results)
            self.log(session_id, f"Identified product in {image_key}: {info.get('product_type', 'unknown')}")
            if "brand" in info:
                self.log(session_id, f"Detected brand: {info['brand']}")

            results[image_key] = {
                "vision_results": vision_results,
                "product_info": info,
            }
            product_info[image_key] = info

        self.log(session_id, "Image processing completed")
        return {
            "vision_results": results,
            "product_info": product_info,
        }
class FeatureExtractionAgent(BaseAgent):
    """Agent responsible for extracting detailed product features and specifications.

    This agent analyzes image processing results and uses LLMs to extract detailed
    product specifications, features, and technical details.
    """

    # (category, substrings that select it), checked in order against the
    # lower-cased product type.
    _CATEGORY_KEYWORDS = (
        ("car", ("car", "vehicle")),
        ("phone", ("phone", "smartphone")),
        ("computer", ("laptop", "computer")),
    )

    # Specification fields requested from the LLM per category; the None
    # entry is the generic list used when no category matches.
    _SPEC_FIELDS = {
        "car": (
            "- make: The manufacturer of the car\n"
            "- model: The model name\n"
            "- year: Estimated year of manufacture\n"
            "- body_type: The body style (sedan, SUV, etc.)\n"
            "- fuel_type: The fuel type if identifiable\n"
            "- engine: Engine specifications if identifiable\n"
            "- color: Exterior color\n"
            "- features: List of visible features\n"
        ),
        "phone": (
            "- brand: The manufacturer\n"
            "- model: The phone model\n"
            "- screen_size: Estimated screen size\n"
            "- camera: Visible camera specifications\n"
            "- color: Device color\n"
            "- generation: Device generation if identifiable\n"
            "- features: List of visible features\n"
        ),
        "computer": (
            "- brand: The manufacturer\n"
            "- model: The computer model\n"
            "- screen_size: Estimated screen size\n"
            "- form_factor: Laptop, desktop, all-in-one, etc.\n"
            "- color: Device color\n"
            "- features: List of visible features\n"
        ),
        None: (
            "- brand: The manufacturer if identifiable\n"
            "- model: The product model if identifiable\n"
            "- color: Product color\n"
            "- dimensions: Estimated dimensions\n"
            "- features: List of visible features\n"
            "- materials: Visible materials used\n"
        ),
    }

    def __init__(self, name="FeatureExtractionAgent"):
        super().__init__(name)

    @classmethod
    def _category(cls, product_type):
        """Map a free-form product type to "car", "phone", "computer" or None."""
        lowered = str(product_type or "").lower()
        for category, keywords in cls._CATEGORY_KEYWORDS:
            if any(keyword in lowered for keyword in keywords):
                return category
        return None

    def _extract_specifications(self, product_info):
        """Extract detailed specifications from product information.

        Without an LLM, ``product_info`` is returned unchanged. With an LLM,
        the reply is parsed as JSON, falling back to loose ``key: value``
        line parsing. The result is always a dict; ``product_type`` from
        ``product_info`` is carried over when the LLM reply omits it, because
        later stages select type-specific behavior from that key. On failure
        ``{"error": ...}`` is returned.
        """
        if self.llm is None:
            return product_info

        # str()/or guards against a None or non-string product_type.
        product_type = str(product_info.get("product_type") or "unknown")
        try:
            prompt = (
                "Based on this product information, generate a detailed list of specifications.\n"
                f"Product information: {json.dumps(product_info)}\n"
                f"For a {product_type}, extract or infer the following specifications in JSON format:\n"
                "\n"
                + self._SPEC_FIELDS[self._category(product_type)]
            )
            response = self.llm.invoke(prompt)
            response_text = response.content if hasattr(response, 'content') else str(response)
            try:
                specs = json.loads(response_text)
            except json.JSONDecodeError:
                # Not valid JSON: salvage "key: value" lines.
                specs = {}
                for line in response_text.split('\n'):
                    if ':' not in line:
                        continue
                    key, value = line.split(':', 1)
                    key = key.strip().lower().replace(' ', '_').strip('-').strip('_')
                    value = value.strip().strip('"').strip("'")
                    if key and value:
                        specs[key] = value
            if not isinstance(specs, dict):
                # Downstream code calls dict methods on the specifications.
                specs = {"error": "LLM response was not a JSON object"}
            if "product_type" in product_info:
                specs.setdefault("product_type", product_info["product_type"])
            return specs
        except Exception as e:
            print(f"Error extracting specifications: {e}")
            return {"error": str(e)}

    def _get_feature_highlights(self, specs):
        """Build a list of short highlight strings from a specifications dict.

        Returns an empty list when ``specs`` is empty or not a dict. Every
        returned entry is a string so the list can be joined for logging.
        """
        if not specs or not isinstance(specs, dict):
            return []

        highlights = []
        category = self._category(specs.get("product_type"))

        if category == "car":
            if "make" in specs and "model" in specs:
                highlights.append(f"{specs['make']} {specs['model']}")
            if "year" in specs:
                highlights.append(f"{specs['year']} model")
            if "engine" in specs:
                highlights.append(f"Engine: {specs['engine']}")
            if "body_type" in specs:
                highlights.append(f"{specs['body_type']} body style")
        elif category == "phone":
            if "brand" in specs and "model" in specs:
                highlights.append(f"{specs['brand']} {specs['model']}")
            if "screen_size" in specs:
                highlights.append(f"{specs['screen_size']} display")
            if "camera" in specs:
                highlights.append(f"Camera: {specs['camera']}")
        elif category == "computer":
            if "brand" in specs and "model" in specs:
                highlights.append(f"{specs['brand']} {specs['model']}")
            if "screen_size" in specs:
                highlights.append(f"{specs['screen_size']} display")

        # Generic highlights for any product: the first three features.
        features = specs.get("features")
        if isinstance(features, list):
            # LLM output may contain non-string items (numbers, dicts).
            highlights.extend(str(feature) for feature in features[:3])
        elif isinstance(features, str):
            highlights.extend(part.strip() for part in features.split(",")[:3])

        if "color" in specs:
            highlights.append(f"{specs['color']} color")
        return highlights

    async def process(self, session_id, data):
        """Process product information to extract detailed specifications.

        Args:
            session_id: Session whose event stream receives progress messages.
            data: Dict with a ``"product_info"`` mapping of product key to
                product-information dict.

        Returns:
            ``{"specifications": {key: {"specifications": ..., "highlights": ...}}}``,
            or ``{"specifications": {}}`` when there is no product info.
        """
        self.log(session_id, "Extracting detailed product specifications...")
        product_info = data.get("product_info", {})
        if not product_info:
            self.log(session_id, "No product information available for specification extraction")
            return {"specifications": {}}

        results = {}
        for key, info in product_info.items():
            self.log(session_id, f"Extracting specifications for {key}...")
            specs = self._extract_specifications(info)
            highlights = self._get_feature_highlights(specs)
            if highlights:
                self.log(session_id, f"Key features for {key}: {', '.join(highlights[:3])}")
            results[key] = {
                "specifications": specs,
                "highlights": highlights,
            }

        self.log(session_id, "Feature extraction completed")
        return {"specifications": results}
class ComparisonAgent(BaseAgent):
    """Agent responsible for comparing products based on their specifications.

    This agent analyzes the specifications of multiple products and identifies
    the key differences, advantages, and disadvantages between them.
    """

    # Keys that carry bookkeeping rather than comparable product attributes.
    _EXCLUDED_KEYS = frozenset({'error', 'product_type', 'confidence'})
    # Keys for which a smaller / larger leading number counts as an advantage.
    _LOWER_IS_BETTER = frozenset({'price', 'weight'})
    _HIGHER_IS_BETTER = frozenset(
        {'screen_size', 'storage', 'memory', 'ram', 'battery', 'capacity'}
    )

    def __init__(self, name="ComparisonAgent"):
        super().__init__(name)

    @staticmethod
    def _leading_number(value):
        """Return the first whitespace-separated token of ``value`` as a float, or None."""
        try:
            return float(str(value).split()[0])
        except (ValueError, IndexError):
            return None

    @staticmethod
    def _parse_sections(response_text):
        """Parse "Header:" / "- item" style text into ``{header: [lines]}``."""
        sections = {}
        current_section = None
        section_content = []
        for raw_line in response_text.split('\n'):
            line = raw_line.strip()
            if not line:
                continue
            if line.endswith(":") and not line.startswith("-"):
                # New section header: flush the previous section first.
                if current_section and section_content:
                    sections[current_section] = section_content
                section_content = []
                current_section = line.strip(":").lower().replace(" ", "_")
            elif line.startswith("-") and current_section:
                section_content.append(line.strip("- "))
            elif current_section:
                section_content.append(line)
        if current_section and section_content:
            sections[current_section] = section_content
        return sections

    def _compare_with_llm(self, specs1, specs2):
        """Ask the LLM for a comparison; return a non-empty dict or None."""
        try:
            prompt = (
                "Compare these two products based on their specifications.\n"
                f"Product 1: {json.dumps(specs1)}\n"
                f"Product 2: {json.dumps(specs2)}\n"
                "Provide a detailed comparison in JSON format with the following structure:\n"
                '{"differences": [...], "product1_advantages": [...], '
                '"product2_advantages": [...], "summary": "..."}\n'
                "- differences: List key differences between the products\n"
                "- product1_advantages: List advantages of Product 1 over Product 2\n"
                "- product2_advantages: List advantages of Product 2 over Product 1\n"
                "- summary: A concise summary of the comparison\n"
            )
            response = self.llm.invoke(prompt)
            response_text = response.content if hasattr(response, 'content') else str(response)
            try:
                comparison = json.loads(response_text)
            except json.JSONDecodeError:
                comparison = self._parse_sections(response_text)
            if isinstance(comparison, dict) and comparison:
                return comparison
            print("LLM comparison was empty or not a JSON object; using simple comparison")
        except Exception as e:
            print(f"Error in LLM comparison: {e}")
        return None

    def _compare_specifications(self, specs1, specs2):
        """Compare two sets of product specifications.

        Uses the LLM when one is configured and it yields a usable result;
        otherwise falls back to a key-by-key comparison.

        Returns:
            ``{"error": ...}`` when either side is empty or not a dict,
            otherwise a dict that (in the fallback path) contains
            ``differences``, ``product1_advantages``, ``product2_advantages``
            and ``summary``.
        """
        if (
            not specs1
            or not specs2
            or not isinstance(specs1, dict)
            or not isinstance(specs2, dict)
        ):
            return {"error": "Insufficient specification data for comparison"}

        if self.llm is not None:
            comparison = self._compare_with_llm(specs1, specs2)
            if comparison is not None:
                return comparison

        differences = []
        product1_advantages = []
        product2_advantages = []

        compare_keys = (set(specs1) | set(specs2)) - self._EXCLUDED_KEYS
        # Sorted so the reported differences come out in a stable order;
        # set iteration order of strings varies between interpreter runs.
        for key in sorted(compare_keys, key=str):
            val1 = specs1.get(key)
            val2 = specs2.get(key)
            if val1 is None and val2 is not None:
                differences.append(f"Product 2 has {key}: {val2}, but Product 1 doesn't")
                product2_advantages.append(f"Has {key}: {val2}")
                continue
            if val1 is not None and val2 is None:
                differences.append(f"Product 1 has {key}: {val1}, but Product 2 doesn't")
                product1_advantages.append(f"Has {key}: {val1}")
                continue
            if val1 == val2:
                continue

            differences.append(f"Different {key}: Product 1 has {val1}, Product 2 has {val2}")
            if key in self._LOWER_IS_BETTER:
                label, smaller_wins = "Lower", True
            elif key in self._HIGHER_IS_BETTER:
                label, smaller_wins = "Higher", False
            else:
                continue

            num1 = self._leading_number(val1)
            num2 = self._leading_number(val2)
            # Equal numbers (e.g. "15 inch" vs "15 in") are a textual
            # difference only; neither product gets an advantage.
            if num1 is None or num2 is None or num1 == num2:
                continue
            if (num1 < num2) == smaller_wins:
                product1_advantages.append(f"{label} {key}: {val1}")
            else:
                product2_advantages.append(f"{label} {key}: {val2}")

        product1_type = specs1.get('product_type', 'Product 1')
        product2_type = specs2.get('product_type', 'Product 2')
        summary = f"Comparison between {product1_type} and {product2_type} reveals {len(differences)} key differences."
        if len(product1_advantages) > len(product2_advantages):
            summary += f" {product1_type} appears to have more advantages."
        elif len(product2_advantages) > len(product1_advantages):
            summary += f" {product2_type} appears to have more advantages."
        else:
            summary += " Both products have similar number of advantages."

        return {
            "differences": differences,
            "product1_advantages": product1_advantages,
            "product2_advantages": product2_advantages,
            "summary": summary,
        }

    async def process(self, session_id, data):
        """Compare products based on their specifications.

        Args:
            session_id: Session whose event stream receives progress messages.
            data: Dict with a ``"specifications"`` mapping of product key to
                ``{"specifications": {...}, ...}``. Only the first two
                products are compared.

        Returns:
            ``{"comparison": ..., "product_keys": [key1, key2]}``, or
            ``{"comparison": {"error": ...}}`` with fewer than two products.
        """
        self.log(session_id, "Starting product comparison analysis...")
        specifications = data.get("specifications", {})
        if len(specifications) < 2:
            self.log(session_id, "Not enough products to compare")
            return {"comparison": {"error": "Need at least two products to compare"}}

        product_keys = list(specifications.keys())
        if len(product_keys) > 2:
            self.log(session_id, f"Found {len(product_keys)} products, comparing the first two only")
            product_keys = product_keys[:2]

        product1_specs = specifications.get(product_keys[0], {}).get("specifications", {})
        product2_specs = specifications.get(product_keys[1], {}).get("specifications", {})
        comparison = self._compare_specifications(product1_specs, product2_specs)

        if "differences" in comparison:
            num_diff = len(comparison["differences"])
            self.log(session_id, f"Found {num_diff} key differences between the products")
            # Log a few example differences.
            for number, diff in enumerate(comparison["differences"][:3], start=1):
                self.log(session_id, f"Difference {number}: {diff}")
        if "summary" in comparison:
            self.log(session_id, f"Comparison summary: {comparison['summary']}")

        self.log(session_id, "Comparison analysis completed")
        return {
            "comparison": comparison,
            "product_keys": product_keys,
        }
| class RecommendationAgent(BaseAgent): | |
| """Agent responsible for providing purchase recommendations | |
| This agent analyzes product comparisons and provides personalized recommendations | |
| based on the user's needs and preferences. | |
| """ | |
def __init__(self, name="RecommendationAgent"):
    """Create the agent under ``name``; all setup is done by the base initializer."""
    super().__init__(name=name)
| def _generate_recommendation(self, comparison, product_keys, specifications): | |
| """Generate a purchase recommendation based on comparison results""" | |
| if not comparison or "error" in comparison: | |
| return { | |
| "recommendation": "Insufficient data to make a recommendation", | |
| "reason": "Could not compare the products", | |
| "confidence": 0.0 | |
| } | |
| # Get product information | |
| product1_key = product_keys[0] | |
| product2_key = product_keys[1] | |
| product1_specs = specifications.get(product1_key, {}).get("specifications", {}) | |
| product2_specs = specifications.get(product2_key, {}).get("specifications", {}) | |
| product1_type = product1_specs.get("product_type", product1_key) | |
| product2_type = product2_specs.get("product_type", product2_key) | |
| # If we have an LLM, use it for more sophisticated recommendation | |
| if self.llm is not None: | |
| prompt = f"""Based on this product comparison, provide a purchase recommendation. | |
| Product 1: {json.dumps(product1_specs)} | |
| Product 2: {json.dumps(product2_specs)} | |
| Comparison: {json.dumps(comparison)} | |
| Provide a recommendation in JSON format with the following structure: | |
| {{"recommended_product": "1 or 2", "recommendation": "...", "reason": "...", "confidence": 0.0-1.0, "use_cases": [...]}} | |
| - recommended_product: Either "1" or "2" indicating which product is recommended | |
| - recommendation: A concise recommendation statement | |
| - reason: The main reason for the recommendation | |
| - confidence: A confidence score between 0 and 1 | |
| - use_cases: List of ideal use cases for the recommended product | |
| """ | |
| try: | |
| # Use LangChain's invoke method | |
| response = self.llm.invoke(prompt) | |
| # Extract content from LangChain response | |
| if hasattr(response, 'content'): | |
| response_text = response.content | |
| else: | |
| response_text = str(response) | |
| # Try to parse as JSON | |
| try: | |
| recommendation = json.loads(response_text) | |
| return recommendation | |
| except json.JSONDecodeError: | |
| # If LLM output is not valid JSON, extract key information using simple parsing | |
| lines = response_text.split('\n') | |
| recommendation = {} | |
| # Look for recommendation indicator | |
| for line in lines: | |
| if "recommended product" in line.lower() and ("1" in line or "2" in line): | |
| recommendation["recommended_product"] = "1" if "1" in line else "2" | |
| elif "recommendation" in line.lower() and ":" in line: | |
| recommendation["recommendation"] = line.split(":", 1)[1].strip().strip('"') | |
| elif "reason" in line.lower() and ":" in line: | |
| recommendation["reason"] = line.split(":", 1)[1].strip().strip('"') | |
| elif "confidence" in line.lower() and ":" in line: | |
| try: | |
| conf_str = line.split(":", 1)[1].strip() | |
| conf = float(conf_str) | |
| recommendation["confidence"] = conf | |
| except ValueError: | |
| recommendation["confidence"] = 0.5 | |
| elif "use cases" in line.lower() and ":" in line: | |
| use_cases = line.split(":", 1)[1].strip().strip('"') | |
| recommendation["use_cases"] = [uc.strip() for uc in use_cases.split(",")] | |
| return recommendation | |
| except Exception as e: | |
| print(f"Error in LLM recommendation: {e}") | |
| # Fall back to simple recommendation | |
| pass | |
| # Simple recommendation logic as fallback | |
| product1_advantages = comparison.get("product1_advantages", []) | |
| product2_advantages = comparison.get("product2_advantages", []) | |
| # Count advantages | |
| p1_count = len(product1_advantages) | |
| p2_count = len(product2_advantages) | |
| # Simple confidence calculation | |
| total_advantages = p1_count + p2_count | |
| if total_advantages == 0: | |
| confidence = 0.5 # Can't determine | |
| else: | |
| # Maximum confidence is 0.95, minimum is 0.55 | |
| max_confidence = 0.95 | |
| min_confidence = 0.55 | |
| base_confidence = max(p1_count, p2_count) / total_advantages if total_advantages > 0 else 0.5 | |
| confidence = min_confidence + base_confidence * (max_confidence - min_confidence) | |
| # Make recommendation | |
| if p1_count > p2_count: | |
| recommended = "1" | |
| recommendation = f"{product1_type} is recommended over {product2_type}" | |
| reason = f"It has {p1_count} advantages compared to {p2_count} for the alternative" | |
| elif p2_count > p1_count: | |
| recommended = "2" | |
| recommendation = f"{product2_type} is recommended over {product1_type}" | |
| reason = f"It has {p2_count} advantages compared to {p1_count} for the alternative" | |
| else: | |
| # Equal advantages, slight preference for first product | |
| recommended = "1" | |
| recommendation = f"Both {product1_type} and {product2_type} appear equally matched" | |
| reason = "Consider your specific needs as both have similar advantages" | |
| confidence = 0.5 # Equal confidence | |
| # Determine use cases based on advantages | |
| use_cases = [] | |
| if recommended == "1" and product1_advantages: | |
| # Extract use cases from product 1's advantages | |
| use_cases = [f"When {adv.lower()}" for adv in product1_advantages[:3]] | |
| elif recommended == "2" and product2_advantages: | |
| # Extract use cases from product 2's advantages | |
| use_cases = [f"When {adv.lower()}" for adv in product2_advantages[:3]] | |
| # Add generic use case if none found | |
| if not use_cases: | |
| use_cases = [f"General {product1_type if recommended == '1' else product2_type} usage"] | |
| return { | |
| "recommended_product": recommended, | |
| "recommendation": recommendation, | |
| "reason": reason, | |
| "confidence": round(confidence, 2), | |
| "use_cases": use_cases | |
| } | |
async def process(self, session_id, data):
    """Generate purchase recommendations based on product comparison.

    Args:
        session_id: Session identifier forwarded to ``self.log``.
        data: Mapping read for the keys ``"comparison"``, ``"product_keys"``
            and ``"specifications"``; missing keys default to empty values.

    Returns:
        ``{"recommendation": <dict>}``. When the comparison is empty or fewer
        than two product keys are supplied, the inner dict only carries an
        ``"error"`` message and no recommendation is generated.
    """
    self.log(session_id, "Generating product recommendations...")

    # Get comparison data
    comparison = data.get("comparison", {})
    product_keys = data.get("product_keys", [])
    specifications = data.get("specifications", {})

    # A recommendation needs a comparison and at least two products
    if not comparison or not product_keys or len(product_keys) < 2:
        self.log(session_id, "Insufficient data for recommendation")
        return {"recommendation": {"error": "Not enough data to generate recommendation"}}

    # Generate recommendation
    recommendation = self._generate_recommendation(comparison, product_keys, specifications)

    # Log recommendation
    self.log(session_id, f"Recommendation: {recommendation.get('recommendation', 'No recommendation')}")
    self.log(session_id, f"Reason: {recommendation.get('reason', 'No reason provided')}")

    # Log confidence as a whole percentage. round() is used instead of int()
    # because binary floats make e.g. 0.57 * 100 == 56.99999999999999, which
    # int() would truncate to 56.
    confidence = recommendation.get('confidence', 0)
    confidence_percent = f"{round(confidence * 100)}%"
    self.log(session_id, f"Confidence in recommendation: {confidence_percent}")

    self.log(session_id, "Recommendation generation completed")
    return {"recommendation": recommendation}
class ProductComparisonCoordinator:
    """Main coordinator for the product comparison multi-agent system.

    This class orchestrates the entire product comparison workflow by managing
    all the specialized agents and their interactions. Status, progress
    messages and final results are recorded through the module-level
    ``session_manager``.
    """

    def __init__(self):
        # Initialize all agents
        self.image_processor = ImageProcessingAgent()
        self.feature_extractor = FeatureExtractionAgent()
        self.comparison_agent = ComparisonAgent()
        self.recommendation_agent = RecommendationAgent()

    @staticmethod
    def _mock_number(seed, base, spread):
        """Return ``base`` plus an offset in ``[0, spread)`` derived from ``hash(seed)``."""
        # NOTE: str hashes are salted per interpreter process unless
        # PYTHONHASHSEED is fixed, so these demo values are only stable
        # within a single process.
        return base + int(hash(seed) % spread)

    @classmethod
    def _build_value_analysis(cls, session_id, include_second):
        """Build the mock value-analysis payload (demo prices and scores)."""
        seed_a = str(session_id)
        # Convert to str BEFORE appending the suffix so that non-string
        # session ids (int, uuid.UUID, ...) do not raise TypeError.
        seed_b = seed_a + "B"

        value_analysis = {
            "priceA": "$" + str(cls._mock_number(seed_a, 1000, 500)),  # Mock price for demo
            "valueScoreA": cls._mock_number(seed_a, 7, 3),  # Mock score between 7-9
            "analysis": "Based on the specifications and market positioning, this product offers good value for money."
        }
        if include_second:
            value_analysis["priceB"] = "$" + str(cls._mock_number(seed_b, 1200, 500))
            value_analysis["valueScoreB"] = cls._mock_number(seed_b, 6, 4)
        return value_analysis

    @staticmethod
    def _single_product_recommendation(product_info):
        """Build the canned recommendation used when only one product was analyzed."""
        product_name = next(iter(product_info.values())).get("name", "Product")
        return {
            "recommendedProduct": product_name,
            "reason": "This is the only product analyzed and appears to meet standard quality benchmarks.",
            "confidence": 0.85,
            "alternatives": [
                {"name": "Similar model with higher storage", "reason": "If you need more storage capacity"},
                {"name": "Budget alternative", "reason": "If price is your primary concern"}
            ],
            "buyingTips": [
                "Wait for seasonal sales for the best price",
                "Check for warranty terms before purchasing"
            ]
        }

    async def process_images(self, session_id, images, session_metadata=None):
        """Process images through the entire multi-agent workflow.

        Args:
            session_id: Unique session identifier
            images: List of image data (PIL Images or numpy arrays)
            session_metadata: Optional dictionary with additional session
                information; its ``analysis_type`` entry (default ``'info'``)
                selects which sections are included in the result.

        Returns:
            Dictionary containing the final analysis results, or
            ``{"error": <message>}`` when no product is detected or any step
            raises an exception.
        """
        # Initialize session if it doesn't exist
        if session_manager.get_status(session_id) is None:
            session_manager.create_session(session_id)

        # Set default metadata if not provided
        if session_metadata is None:
            session_metadata = {}

        # Get analysis type from metadata
        analysis_type = session_metadata.get('analysis_type', 'info')

        session_manager.set_status(session_id, "processing")
        session_manager.add_message(session_id, f"Starting product {analysis_type} analysis")

        try:
            # Step 1: Process images with Image Processing Agent
            session_manager.add_message(session_id, "Step 1: Analyzing product images...")
            image_results = await self.image_processor.process(session_id, {"images": images})

            # Check if we have enough products to compare
            product_info = image_results.get("product_info", {})
            if len(product_info) < 1:
                session_manager.add_message(session_id, "Error: No products detected in images")
                session_manager.set_status(session_id, "error")
                return {"error": "No products detected in images"}

            # Step 2: Extract features with Feature Extraction Agent
            session_manager.add_message(session_id, "Step 2: Extracting product specifications...")
            feature_results = await self.feature_extractor.process(session_id, image_results)
            specifications = feature_results.get("specifications", {})
            has_multiple_products = len(product_info) >= 2

            # Step 3: Compare products with Comparison Agent if we have multiple products
            if has_multiple_products:
                session_manager.add_message(session_id, "Step 3: Comparing products...")
                comparison_results = await self.comparison_agent.process(
                    session_id,
                    {**feature_results, "specifications": specifications}
                )

                # Step 4: Generate recommendation with Recommendation Agent
                session_manager.add_message(session_id, "Step 4: Generating purchase recommendation...")
                recommendation_results = await self.recommendation_agent.process(
                    session_id,
                    {**comparison_results, "specifications": specifications}
                )
            else:
                session_manager.add_message(session_id, "Skipping comparison: Only one product detected")
                comparison_results = {"comparison": {"error": "Need at least two products to compare"}}
                recommendation_results = {"recommendation": {"error": "Need at least two products to compare"}}

            # Tailor results based on analysis type
            final_results = {
                "status": "completed"
            }

            if analysis_type == 'info':
                final_results["productInfo"] = image_results.get("product_info", {})
                final_results["specifications"] = feature_results.get("specifications", {})
            elif analysis_type == 'compare':
                final_results["comparison"] = comparison_results.get("comparison", {})
                final_results["productInfo"] = image_results.get("product_info", {})
            elif analysis_type == 'value':
                # Value analysis currently returns mock prices and scores
                final_results["valueAnalysis"] = self._build_value_analysis(
                    session_id, has_multiple_products
                )
            elif analysis_type == 'recommend':
                recommendation = recommendation_results.get("recommendation", {})
                if recommendation.get("error") and len(product_info) == 1:
                    # Create a mock recommendation if only one product
                    final_results["recommendation"] = self._single_product_recommendation(product_info)
                else:
                    final_results["recommendation"] = recommendation

            # Always include basic product info and vision results for context
            final_results["vision_results"] = image_results.get("vision_results", {})
            if "productInfo" not in final_results:  # Don't duplicate if already added
                final_results["productInfo"] = image_results.get("product_info", {})

            # Set final results in session manager
            session_manager.set_final_result(session_id, final_results)
            session_manager.set_status(session_id, "completed")
            session_manager.add_message(session_id, "Product comparison analysis completed successfully")

            return final_results

        except Exception as e:
            # Top-level boundary: report the failure through the session
            # instead of propagating it to the caller.
            error_msg = f"Error during product comparison: {str(e)}"
            session_manager.add_message(session_id, error_msg)
            session_manager.set_status(session_id, "error")
            return {"error": error_msg}

    def get_session_messages(self, session_id):
        """Get all messages for a session"""
        return session_manager.get_messages(session_id)

    def get_session_result(self, session_id):
        """Get the final result for a session"""
        return session_manager.get_final_result(session_id)

    def get_session_status(self, session_id):
        """Get the status of a session"""
        return session_manager.get_status(session_id)
# Factory helper exposing one shared coordinator
def get_product_comparison_coordinator():
    """Return the shared ProductComparisonCoordinator, creating it on first use.

    The instance is cached as the ``_instance`` attribute of this function,
    so every later call returns the same object.
    """
    holder = get_product_comparison_coordinator
    try:
        return holder._instance
    except AttributeError:
        # First call: nothing cached yet, so build it below.
        pass
    holder._instance = ProductComparisonCoordinator()
    return holder._instance
# Helper for turning base64 text (optionally a data URL) into a PIL Image
def decode_base64_image(base64_data):
    """Decode base64 image data into a PIL Image.

    A leading ``data:image...`` data-URL header is dropped (everything up to
    the first comma) before decoding. Any failure is printed and reported by
    returning ``None`` instead of raising.
    """
    try:
        payload = base64_data
        if payload.startswith('data:image'):
            # Keep only the part after the first comma of the data URL
            payload = payload.split(',', 1)[1]
        raw_bytes = base64.b64decode(payload)
        return Image.open(BytesIO(raw_bytes))
    except Exception as e:
        print(f"Error decoding base64 image: {e}")
        return None