"""Image aesthetics scoring with an ensemble of (currently placeholder) models."""

import logging

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoModel, AutoProcessor

logger = logging.getLogger(__name__)

# Preprocessing parameters shared by every model slot.  The mean/std values
# match the channel statistics conventionally used with ImageNet-pretrained
# torchvision models; they require a 3-channel input tensor.
_INPUT_SIZE = (224, 224)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Score returned when a heuristic cannot be computed at all.
_NEUTRAL_SCORE = 5.0


class AestheticsEvaluator:
    """Image aesthetics assessment using multiple SOTA models.

    Every model slot ("uniaa", "musiq", "anime_aesthetic") is currently filled
    with a randomly initialised placeholder network (see
    ``create_mock_aesthetic_model``), so model scores are not meaningful until
    real weights are wired in.  All scores are on a 0-10 scale.

    Attributes:
        device: Torch device the models run on (CUDA when available).
        models: Mapping of model key -> ``nn.Module``.
        processors: Mapping of model key -> preprocessing transform.
        fallback_mode: True when model loading failed and only the heuristic
            scorers are usable.
    """

    def __init__(self):
        """Pick the compute device and load all model slots."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.processors = {}
        # Always defined, so callers can read it without risking AttributeError.
        self.fallback_mode = False
        self.load_models()

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------

    def load_models(self):
        """Load aesthetics assessment models.

        Each individual loader handles its own errors, so a failure there
        never reaches the ``except`` below.  Fallback mode is therefore also
        enabled when no model at all ended up registered.
        """
        try:
            # Load UNIAA model (primary)
            logger.info("Loading UNIAA model...")
            self.load_uniaa()

            # Load MUSIQ model (secondary)
            logger.info("Loading MUSIQ model...")
            self.load_musiq()

            # Load anime-specific aesthetic model
            logger.info("Loading anime aesthetic model...")
            self.load_anime_aesthetic_model()
        except Exception as e:
            logger.error("Error loading aesthetic models: %s", e)
            self.use_fallback_implementation()
            return

        if not self.models:
            self.use_fallback_implementation()

    @staticmethod
    def _build_preprocessor():
        """Return the resize -> tensor -> normalise pipeline used by all slots."""
        return transforms.Compose([
            transforms.Resize(_INPUT_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ])

    def _register_placeholder(self, key, label):
        """Create a placeholder model + preprocessor and register them under *key*.

        Both objects are built before either is stored, so a failure never
        leaves a model registered without its preprocessor.

        Args:
            key: Dictionary key used in ``self.models`` / ``self.processors``.
            label: Human-readable name used in the warning message.
        """
        try:
            model = self.create_mock_aesthetic_model()
            processor = self._build_preprocessor()
        except Exception as e:
            logger.warning("Could not load %s: %s", label, e)
            return
        self.models[key] = model
        self.processors[key] = processor

    def load_uniaa(self):
        """Load UNIAA model (placeholder implementation)."""
        self._register_placeholder('uniaa', 'UNIAA')

    def load_musiq(self):
        """Load MUSIQ model (placeholder implementation)."""
        self._register_placeholder('musiq', 'MUSIQ')

    def load_anime_aesthetic_model(self):
        """Load anime-specific aesthetic model (placeholder implementation)."""
        self._register_placeholder('anime_aesthetic', 'anime aesthetic model')

    def create_mock_aesthetic_model(self):
        """Create a mock aesthetic model for demonstration.

        Returns:
            A small randomly initialised CNN, moved to ``self.device`` and put
            in eval mode.  Its output is a sigmoid scaled to the 0-10 range.
        """

        class MockAestheticModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.backbone = torch.nn.Sequential(
                    torch.nn.Conv2d(3, 64, 3, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(64, 128, 3, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.AdaptiveAvgPool2d((1, 1)),
                    torch.nn.Flatten(),
                    torch.nn.Linear(128, 64),
                    torch.nn.ReLU(),
                    torch.nn.Linear(64, 1),
                    torch.nn.Sigmoid(),
                )

            def forward(self, x):
                return self.backbone(x) * 10  # Scale to 0-10

        model = MockAestheticModel().to(self.device)
        model.eval()
        return model

    def use_fallback_implementation(self):
        """Use simple fallback aesthetic assessment."""
        logger.info("Using fallback aesthetic assessment implementation")
        self.fallback_mode = True

    # ------------------------------------------------------------------
    # Model-based scoring
    # ------------------------------------------------------------------

    def _score_with_model(self, key, label, image):
        """Score *image* with the model registered under *key*.

        Args:
            key: Key into ``self.models`` / ``self.processors``.
            label: Name used in the error log message.
            image: PIL image to score.

        Returns:
            The model score clamped to [0, 10], or the heuristic fallback
            score when the model is missing or inference fails.
        """
        try:
            if key not in self.models:
                return self.fallback_aesthetic_score(image)

            # The normalisation step needs exactly three channels; grayscale,
            # palette and RGBA inputs would otherwise raise and silently
            # degrade to the heuristic fallback.
            rgb_image = image if image.mode == 'RGB' else image.convert('RGB')

            batch = self.processors[key](rgb_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                score = self.models[key](batch).item()

            return max(0.0, min(10.0, score))
        except Exception as e:
            logger.error("Error in %s evaluation: %s", label, e)
            return self.fallback_aesthetic_score(image)

    def evaluate_with_uniaa(self, image: Image.Image) -> float:
        """Evaluate aesthetics using UNIAA."""
        return self._score_with_model('uniaa', 'UNIAA', image)

    def evaluate_with_musiq(self, image: Image.Image) -> float:
        """Evaluate aesthetics using MUSIQ."""
        return self._score_with_model('musiq', 'MUSIQ', image)

    def evaluate_with_anime_model(self, image: Image.Image) -> float:
        """Evaluate aesthetics using anime-specific model."""
        return self._score_with_model('anime_aesthetic', 'anime aesthetic', image)

    # ------------------------------------------------------------------
    # Heuristic scoring
    # ------------------------------------------------------------------

    def evaluate_composition_rules(self, image: Image.Image) -> float:
        """Evaluate based on composition rules (rule of thirds, etc.).

        The score combines two heuristics:

        * local grayscale variance in a 20x20 window around each of the four
          rule-of-thirds intersections (weight 0.6);
        * mean per-channel standard deviation of the colour values
          (weight 0.4; neutral 5.0 for images without colour channels).

        Returns:
            Score in [0, 10]; 5.0 if the analysis fails.
        """
        try:
            img_array = np.array(image)
            height, width = img_array.shape[:2]

            if img_array.ndim == 3 and img_array.shape[2] >= 3:
                # Drop any alpha/extra channel so it cannot leak into the
                # colour statistics.
                rgb = img_array[..., :3]
                gray = np.dot(rgb, [0.2989, 0.5870, 0.1140])
            elif img_array.ndim == 3:
                # Fewer than three channels (e.g. luminance + alpha): treat
                # the first channel as the grayscale plane.
                rgb = None
                gray = img_array[..., 0]
            else:
                rgb = None
                gray = img_array

            # Rule of thirds analysis
            third_h, third_w = height // 3, width // 3
            intersections = [
                (third_h, third_w),
                (third_h, 2 * third_w),
                (2 * third_h, third_w),
                (2 * third_h, 2 * third_w),
            ]

            composition_score = 0.0
            for y, x in intersections:
                # Local variance around the intersection point
                region = gray[max(0, y - 10):min(height, y + 10),
                              max(0, x - 10):min(width, x + 10)]
                if region.size > 0:
                    composition_score += region.var()

            # Normalize composition score
            composition_score = min(10.0, composition_score / 1000.0)

            # Colour spread analysis
            if rgb is not None:
                color_std = np.std(rgb.reshape(-1, 3), axis=0).mean()
                color_harmony_score = min(10.0, color_std / 25.0)
            else:
                color_harmony_score = _NEUTRAL_SCORE

            final_score = composition_score * 0.6 + color_harmony_score * 0.4
            return float(max(0.0, min(10.0, final_score)))
        except Exception as e:
            logger.error("Error in composition analysis: %s", e)
            return _NEUTRAL_SCORE

    def fallback_aesthetic_score(self, image: Image.Image) -> float:
        """Simple fallback aesthetic assessment.

        Combines aspect ratio (0.3), pixel count (0.2), number of distinct
        colours (0.3) and mean brightness (0.2).

        Returns:
            Score in [0, 10]; 5.0 if the image cannot be analysed.
        """
        try:
            width, height = image.size

            # Aspect ratio score (prefer aesthetically pleasing ratios)
            aspect_ratio = width / height
            golden_ratio = 1.618
            near_golden = (
                abs(aspect_ratio - golden_ratio) < 0.1
                or abs(aspect_ratio - 1 / golden_ratio) < 0.1
            )
            if near_golden:
                aspect_score = 9.0
            elif 0.7 <= aspect_ratio <= 1.4:  # Square-ish
                aspect_score = 7.0
            elif 1.4 <= aspect_ratio <= 2.0:  # Landscape
                aspect_score = 8.0
            else:
                aspect_score = 5.0

            # Resolution score: saturates at 10 once the image has 2,000,000 pixels.
            total_pixels = width * height
            resolution_score = min(10.0, total_pixels / 200000.0)

            # Color analysis
            img_array = np.array(image)
            if img_array.ndim == 3 and img_array.shape[2] >= 3:
                # Ignore alpha/extra channels: they are not colour information
                # and would break the (-1, 3) reshape.
                rgb = img_array[..., :3]

                unique_colors = len(np.unique(rgb.reshape(-1, 3), axis=0))
                color_variety_score = min(10.0, unique_colors / 1000.0)

                # Brightness closest to mid-gray (127.5) scores highest.
                brightness = np.mean(rgb, axis=2)
                brightness_score = 10.0 - abs(brightness.mean() - 127.5) / 12.75
            else:
                color_variety_score = _NEUTRAL_SCORE
                brightness_score = _NEUTRAL_SCORE

            aesthetic_score = (
                aspect_score * 0.3
                + resolution_score * 0.2
                + color_variety_score * 0.3
                + brightness_score * 0.2
            )
            return float(max(0.0, min(10.0, aesthetic_score)))
        except Exception as e:
            # Last line of defence: never raise, but leave a trace in the log.
            logger.warning("Fallback aesthetic assessment failed: %s", e)
            return _NEUTRAL_SCORE  # Default neutral score

    # ------------------------------------------------------------------
    # Ensemble
    # ------------------------------------------------------------------

    def evaluate(self, image: Image.Image, anime_mode: bool = False) -> float:
        """
        Evaluate image aesthetics using ensemble of models

        Args:
            image: PIL Image to evaluate
            anime_mode: Whether to use anime-specific evaluation

        Returns:
            Aesthetic score from 0-10
        """
        try:
            if anime_mode:
                # Anime images: prioritise the anime-specific model, keep a
                # general model at lower weight.
                weighted_scorers = [
                    (self.evaluate_with_anime_model, 0.5),
                    (self.evaluate_with_uniaa, 0.3),
                    (self.evaluate_composition_rules, 0.2),
                ]
            else:
                # Realistic images: the two general models share equal weight.
                weighted_scorers = [
                    (self.evaluate_with_uniaa, 0.4),
                    (self.evaluate_with_musiq, 0.4),
                    (self.evaluate_composition_rules, 0.2),
                ]

            scores = [scorer(image) for scorer, _ in weighted_scorers]

            # Ensemble scoring
            final_score = sum(
                score * weight
                for score, (_, weight) in zip(scores, weighted_scorers)
            )

            logger.info("Aesthetic scores - Scores: %s, Final: %.2f", scores, final_score)

            return max(0.0, min(10.0, final_score))
        except Exception as e:
            logger.error("Error in aesthetic evaluation: %s", e)
            return self.fallback_aesthetic_score(image)