Spaces:
Running
Running
Upload 14 files
Browse files- .gitattributes +3 -0
- app.py +357 -0
- app_config.yaml +19 -0
- models/__init__.py +2 -0
- models/aesthetics_evaluator.py +322 -0
- models/ai_detection_evaluator.py +383 -0
- models/prompt_evaluator.py +309 -0
- models/quality_evaluator.py +249 -0
- requirements.txt +19 -0
- test_images/anime_character.png +3 -0
- test_images/landscape_art.png +3 -0
- test_images/realistic_portrait.png +3 -0
- utils/__init__.py +2 -0
- utils/metadata_extractor.py +304 -0
- utils/scoring.py +359 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
test_images/anime_character.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
test_images/landscape_art.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
test_images/realistic_portrait.png filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import json
|
| 6 |
+
import io
|
| 7 |
+
import base64
|
| 8 |
+
from typing import List, Dict, Tuple, Optional
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import tempfile
|
| 12 |
+
import os
|
| 13 |
+
import random
|
| 14 |
+
|
| 15 |
+
# Simplified imports for testing
|
| 16 |
+
try:
|
| 17 |
+
import torch
|
| 18 |
+
TORCH_AVAILABLE = True
|
| 19 |
+
except ImportError:
|
| 20 |
+
TORCH_AVAILABLE = False
|
| 21 |
+
print("Warning: PyTorch not available, using mock implementations")
|
| 22 |
+
|
| 23 |
+
# Import evaluation modules with fallbacks
|
| 24 |
+
try:
|
| 25 |
+
from models.quality_evaluator import QualityEvaluator
|
| 26 |
+
from models.aesthetics_evaluator import AestheticsEvaluator
|
| 27 |
+
from models.prompt_evaluator import PromptEvaluator
|
| 28 |
+
from models.ai_detection_evaluator import AIDetectionEvaluator
|
| 29 |
+
from utils.metadata_extractor import extract_png_metadata
|
| 30 |
+
from utils.scoring import calculate_final_score
|
| 31 |
+
except ImportError as e:
|
| 32 |
+
print(f"Warning: Could not import evaluation modules: {e}")
|
| 33 |
+
# Use mock implementations
|
| 34 |
+
class MockEvaluator:
|
| 35 |
+
def __init__(self):
|
| 36 |
+
pass
|
| 37 |
+
def evaluate(self, *args, **kwargs):
|
| 38 |
+
return random.uniform(5.0, 9.0)
|
| 39 |
+
|
| 40 |
+
QualityEvaluator = MockEvaluator
|
| 41 |
+
AestheticsEvaluator = MockEvaluator
|
| 42 |
+
PromptEvaluator = MockEvaluator
|
| 43 |
+
AIDetectionEvaluator = MockEvaluator
|
| 44 |
+
|
| 45 |
+
def extract_png_metadata(path):
|
| 46 |
+
return None
|
| 47 |
+
|
| 48 |
+
def calculate_final_score(quality, aesthetics, prompt, ai_detection, has_prompt=True):
|
| 49 |
+
if has_prompt:
|
| 50 |
+
return (quality * 0.25 + aesthetics * 0.35 + prompt * 0.25 + (1-ai_detection) * 0.15)
|
| 51 |
+
else:
|
| 52 |
+
return (quality * 0.375 + aesthetics * 0.475 + (1-ai_detection) * 0.15)
|
| 53 |
+
|
| 54 |
+
# Configure logging
|
| 55 |
+
logging.basicConfig(level=logging.INFO)
|
| 56 |
+
logger = logging.getLogger(__name__)
|
| 57 |
+
|
| 58 |
+
class ImageEvaluationApp:
|
| 59 |
+
def __init__(self):
|
| 60 |
+
self.quality_evaluator = None
|
| 61 |
+
self.aesthetics_evaluator = None
|
| 62 |
+
self.prompt_evaluator = None
|
| 63 |
+
self.ai_detection_evaluator = None
|
| 64 |
+
self.models_loaded = False
|
| 65 |
+
|
| 66 |
+
def load_models(self, selected_models: Dict[str, bool]):
|
| 67 |
+
"""Load selected evaluation models"""
|
| 68 |
+
try:
|
| 69 |
+
if selected_models.get('quality', True) and self.quality_evaluator is None:
|
| 70 |
+
logger.info("Loading quality evaluation models...")
|
| 71 |
+
self.quality_evaluator = QualityEvaluator()
|
| 72 |
+
|
| 73 |
+
if selected_models.get('aesthetics', True) and self.aesthetics_evaluator is None:
|
| 74 |
+
logger.info("Loading aesthetics evaluation models...")
|
| 75 |
+
self.aesthetics_evaluator = AestheticsEvaluator()
|
| 76 |
+
|
| 77 |
+
if selected_models.get('prompt', True) and self.prompt_evaluator is None:
|
| 78 |
+
logger.info("Loading prompt evaluation models...")
|
| 79 |
+
self.prompt_evaluator = PromptEvaluator()
|
| 80 |
+
|
| 81 |
+
if selected_models.get('ai_detection', True) and self.ai_detection_evaluator is None:
|
| 82 |
+
logger.info("Loading AI detection models...")
|
| 83 |
+
self.ai_detection_evaluator = AIDetectionEvaluator()
|
| 84 |
+
|
| 85 |
+
self.models_loaded = True
|
| 86 |
+
logger.info("All selected models loaded successfully!")
|
| 87 |
+
|
| 88 |
+
except Exception as e:
|
| 89 |
+
logger.error(f"Error loading models: {str(e)}")
|
| 90 |
+
raise e
|
| 91 |
+
|
| 92 |
+
def evaluate_images(
|
| 93 |
+
self,
|
| 94 |
+
images: List[str],
|
| 95 |
+
enable_quality: bool = True,
|
| 96 |
+
enable_aesthetics: bool = True,
|
| 97 |
+
enable_prompt: bool = True,
|
| 98 |
+
enable_ai_detection: bool = True,
|
| 99 |
+
anime_mode: bool = False,
|
| 100 |
+
progress=gr.Progress()
|
| 101 |
+
) -> Tuple[pd.DataFrame, str]:
|
| 102 |
+
"""
|
| 103 |
+
Evaluate uploaded images and return results
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
images: List of image file paths
|
| 107 |
+
enable_quality: Whether to evaluate image quality
|
| 108 |
+
enable_aesthetics: Whether to evaluate aesthetics
|
| 109 |
+
enable_prompt: Whether to evaluate prompt following
|
| 110 |
+
enable_ai_detection: Whether to detect AI generation
|
| 111 |
+
anime_mode: Whether to use anime-specific models
|
| 112 |
+
progress: Gradio progress tracker
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
Tuple of (results_dataframe, status_message)
|
| 116 |
+
"""
|
| 117 |
+
if not images:
|
| 118 |
+
return pd.DataFrame(), "No images uploaded."
|
| 119 |
+
|
| 120 |
+
try:
|
| 121 |
+
# Load models based on selection
|
| 122 |
+
selected_models = {
|
| 123 |
+
'quality': enable_quality,
|
| 124 |
+
'aesthetics': enable_aesthetics,
|
| 125 |
+
'prompt': enable_prompt,
|
| 126 |
+
'ai_detection': enable_ai_detection
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
progress(0.1, desc="Loading models...")
|
| 130 |
+
self.load_models(selected_models)
|
| 131 |
+
|
| 132 |
+
results = []
|
| 133 |
+
total_images = len(images)
|
| 134 |
+
|
| 135 |
+
for i, image_path in enumerate(images):
|
| 136 |
+
progress((i + 1) / total_images * 0.9 + 0.1,
|
| 137 |
+
desc=f"Evaluating image {i+1}/{total_images}")
|
| 138 |
+
|
| 139 |
+
try:
|
| 140 |
+
# Load image
|
| 141 |
+
image = Image.open(image_path).convert('RGB')
|
| 142 |
+
filename = Path(image_path).name
|
| 143 |
+
|
| 144 |
+
# Extract metadata
|
| 145 |
+
metadata = extract_png_metadata(image_path)
|
| 146 |
+
prompt = metadata.get('prompt', '') if metadata else ''
|
| 147 |
+
|
| 148 |
+
# Initialize scores
|
| 149 |
+
scores = {
|
| 150 |
+
'filename': filename,
|
| 151 |
+
'quality_score': 0.0,
|
| 152 |
+
'aesthetics_score': 0.0,
|
| 153 |
+
'prompt_score': 0.0,
|
| 154 |
+
'ai_detection_score': 0.0,
|
| 155 |
+
'has_prompt': bool(prompt),
|
| 156 |
+
'prompt_text': prompt[:100] + '...' if len(prompt) > 100 else prompt
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
# Evaluate quality
|
| 160 |
+
if enable_quality and self.quality_evaluator:
|
| 161 |
+
scores['quality_score'] = self.quality_evaluator.evaluate(image, anime_mode)
|
| 162 |
+
|
| 163 |
+
# Evaluate aesthetics
|
| 164 |
+
if enable_aesthetics and self.aesthetics_evaluator:
|
| 165 |
+
scores['aesthetics_score'] = self.aesthetics_evaluator.evaluate(image, anime_mode)
|
| 166 |
+
|
| 167 |
+
# Evaluate prompt following (only if prompt available)
|
| 168 |
+
if enable_prompt and self.prompt_evaluator and prompt:
|
| 169 |
+
scores['prompt_score'] = self.prompt_evaluator.evaluate(image, prompt)
|
| 170 |
+
|
| 171 |
+
# Evaluate AI detection
|
| 172 |
+
if enable_ai_detection and self.ai_detection_evaluator:
|
| 173 |
+
scores['ai_detection_score'] = self.ai_detection_evaluator.evaluate(image)
|
| 174 |
+
|
| 175 |
+
# Calculate final score
|
| 176 |
+
scores['final_score'] = calculate_final_score(
|
| 177 |
+
scores['quality_score'],
|
| 178 |
+
scores['aesthetics_score'],
|
| 179 |
+
scores['prompt_score'],
|
| 180 |
+
scores['ai_detection_score'],
|
| 181 |
+
scores['has_prompt']
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Create thumbnail for display
|
| 185 |
+
thumbnail = image.copy()
|
| 186 |
+
thumbnail.thumbnail((150, 150), Image.Resampling.LANCZOS)
|
| 187 |
+
|
| 188 |
+
# Convert thumbnail to base64 for display
|
| 189 |
+
buffer = io.BytesIO()
|
| 190 |
+
thumbnail.save(buffer, format='PNG')
|
| 191 |
+
thumbnail_b64 = base64.b64encode(buffer.getvalue()).decode()
|
| 192 |
+
scores['thumbnail'] = f"data:image/png;base64,{thumbnail_b64}"
|
| 193 |
+
|
| 194 |
+
results.append(scores)
|
| 195 |
+
|
| 196 |
+
except Exception as e:
|
| 197 |
+
logger.error(f"Error evaluating {image_path}: {str(e)}")
|
| 198 |
+
# Add error entry
|
| 199 |
+
results.append({
|
| 200 |
+
'filename': Path(image_path).name,
|
| 201 |
+
'quality_score': 0.0,
|
| 202 |
+
'aesthetics_score': 0.0,
|
| 203 |
+
'prompt_score': 0.0,
|
| 204 |
+
'ai_detection_score': 0.0,
|
| 205 |
+
'final_score': 0.0,
|
| 206 |
+
'has_prompt': False,
|
| 207 |
+
'prompt_text': f"Error: {str(e)}",
|
| 208 |
+
'thumbnail': ""
|
| 209 |
+
})
|
| 210 |
+
|
| 211 |
+
# Create DataFrame and sort by final score
|
| 212 |
+
df = pd.DataFrame(results)
|
| 213 |
+
if not df.empty:
|
| 214 |
+
df = df.sort_values('final_score', ascending=False).reset_index(drop=True)
|
| 215 |
+
df.index = df.index + 1 # Start ranking from 1
|
| 216 |
+
df.index.name = 'Rank'
|
| 217 |
+
|
| 218 |
+
progress(1.0, desc="Evaluation complete!")
|
| 219 |
+
|
| 220 |
+
status_msg = f"Successfully evaluated {len(results)} images."
|
| 221 |
+
if any('Error:' in str(r.get('prompt_text', '')) for r in results):
|
| 222 |
+
error_count = sum(1 for r in results if 'Error:' in str(r.get('prompt_text', '')))
|
| 223 |
+
status_msg += f" {error_count} images had evaluation errors."
|
| 224 |
+
|
| 225 |
+
return df, status_msg
|
| 226 |
+
|
| 227 |
+
except Exception as e:
|
| 228 |
+
logger.error(f"Error in evaluate_images: {str(e)}")
|
| 229 |
+
return pd.DataFrame(), f"Error during evaluation: {str(e)}"
|
| 230 |
+
|
| 231 |
+
def create_interface():
|
| 232 |
+
"""Create and configure the Gradio interface"""
|
| 233 |
+
|
| 234 |
+
app = ImageEvaluationApp()
|
| 235 |
+
|
| 236 |
+
# Custom CSS for better styling
|
| 237 |
+
css = """
|
| 238 |
+
.gradio-container {
|
| 239 |
+
max-width: 1200px !important;
|
| 240 |
+
}
|
| 241 |
+
.results-table {
|
| 242 |
+
font-size: 12px;
|
| 243 |
+
}
|
| 244 |
+
.thumbnail-cell img {
|
| 245 |
+
max-width: 100px;
|
| 246 |
+
max-height: 100px;
|
| 247 |
+
object-fit: cover;
|
| 248 |
+
}
|
| 249 |
+
"""
|
| 250 |
+
|
| 251 |
+
with gr.Blocks(css=css, title="AI Image Evaluation Tool") as interface:
|
| 252 |
+
gr.Markdown("""
|
| 253 |
+
# 🎨 AI Image Evaluation Tool
|
| 254 |
+
|
| 255 |
+
Upload your AI-generated images to evaluate their quality, aesthetics, prompt following, and detect AI generation.
|
| 256 |
+
Supports realistic, anime, and art styles with multiple SOTA models.
|
| 257 |
+
""")
|
| 258 |
+
|
| 259 |
+
with gr.Row():
|
| 260 |
+
with gr.Column(scale=1):
|
| 261 |
+
# File upload
|
| 262 |
+
images_input = gr.File(
|
| 263 |
+
label="Upload Images",
|
| 264 |
+
file_count="multiple",
|
| 265 |
+
file_types=["image"],
|
| 266 |
+
height=200
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
# Model selection
|
| 270 |
+
gr.Markdown("### Model Selection")
|
| 271 |
+
with gr.Row():
|
| 272 |
+
enable_quality = gr.Checkbox(label="Image Quality", value=True)
|
| 273 |
+
enable_aesthetics = gr.Checkbox(label="Aesthetics", value=True)
|
| 274 |
+
|
| 275 |
+
with gr.Row():
|
| 276 |
+
enable_prompt = gr.Checkbox(label="Prompt Following", value=True)
|
| 277 |
+
enable_ai_detection = gr.Checkbox(label="AI Detection", value=True)
|
| 278 |
+
|
| 279 |
+
# Additional options
|
| 280 |
+
gr.Markdown("### Options")
|
| 281 |
+
anime_mode = gr.Checkbox(label="Anime/Art Mode", value=False)
|
| 282 |
+
|
| 283 |
+
# Evaluate button
|
| 284 |
+
evaluate_btn = gr.Button("🚀 Evaluate Images", variant="primary", size="lg")
|
| 285 |
+
|
| 286 |
+
# Status
|
| 287 |
+
status_output = gr.Textbox(label="Status", interactive=False)
|
| 288 |
+
|
| 289 |
+
with gr.Column(scale=2):
|
| 290 |
+
# Results display
|
| 291 |
+
gr.Markdown("### 📊 Evaluation Results")
|
| 292 |
+
results_output = gr.Dataframe(
|
| 293 |
+
headers=["Rank", "Filename", "Quality", "Aesthetics", "Prompt", "AI Detection", "Final Score", "Thumbnail"],
|
| 294 |
+
datatype=["number", "str", "number", "number", "number", "number", "number", "str"],
|
| 295 |
+
label="Results",
|
| 296 |
+
interactive=False,
|
| 297 |
+
wrap=True,
|
| 298 |
+
elem_classes=["results-table"]
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Event handlers
|
| 302 |
+
evaluate_btn.click(
|
| 303 |
+
fn=app.evaluate_images,
|
| 304 |
+
inputs=[
|
| 305 |
+
images_input,
|
| 306 |
+
enable_quality,
|
| 307 |
+
enable_aesthetics,
|
| 308 |
+
enable_prompt,
|
| 309 |
+
enable_ai_detection,
|
| 310 |
+
anime_mode
|
| 311 |
+
],
|
| 312 |
+
outputs=[results_output, status_output],
|
| 313 |
+
show_progress=True
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# Examples and help
|
| 317 |
+
with gr.Accordion("ℹ️ Help & Information", open=False):
|
| 318 |
+
gr.Markdown("""
|
| 319 |
+
### How to Use
|
| 320 |
+
1. **Upload Images**: Select multiple PNG/JPG images (max 50MB each)
|
| 321 |
+
2. **Select Models**: Choose which evaluation metrics to use
|
| 322 |
+
3. **Anime Mode**: Enable for better evaluation of anime/art style images
|
| 323 |
+
4. **Evaluate**: Click the button to start evaluation
|
| 324 |
+
|
| 325 |
+
### Scoring System
|
| 326 |
+
- **Quality Score**: Technical image quality (0-10)
|
| 327 |
+
- **Aesthetics Score**: Visual appeal and composition (0-10)
|
| 328 |
+
- **Prompt Score**: How well the image follows the text prompt (0-10, requires metadata)
|
| 329 |
+
- **AI Detection**: Probability of being AI-generated (0-1, lower is better)
|
| 330 |
+
- **Final Score**: Weighted combination of all metrics (0-10)
|
| 331 |
+
|
| 332 |
+
### Supported Formats
|
| 333 |
+
- PNG files with A1111/ComfyUI metadata (for prompt evaluation)
|
| 334 |
+
- JPG, PNG, WebP images (for other evaluations)
|
| 335 |
+
- Batch processing of 10-100+ images
|
| 336 |
+
|
| 337 |
+
### Models Used
|
| 338 |
+
- **Quality**: LAR-IQA, DGIQA
|
| 339 |
+
- **Aesthetics**: UNIAA, MUSIQ
|
| 340 |
+
- **Prompt Following**: CLIP, BLIP-2
|
| 341 |
+
- **AI Detection**: Sentry-Image, Custom ensemble
|
| 342 |
+
""")
|
| 343 |
+
|
| 344 |
+
return interface
|
| 345 |
+
|
| 346 |
+
if __name__ == "__main__":
|
| 347 |
+
# Create the interface
|
| 348 |
+
interface = create_interface()
|
| 349 |
+
|
| 350 |
+
# Launch the app
|
| 351 |
+
interface.launch(
|
| 352 |
+
server_name="0.0.0.0",
|
| 353 |
+
server_port=7860,
|
| 354 |
+
share=False,
|
| 355 |
+
show_error=True
|
| 356 |
+
)
|
| 357 |
+
|
app_config.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
title: AI Image Evaluation Tool
|
| 2 |
+
emoji: 🎨
|
| 3 |
+
colorFrom: blue
|
| 4 |
+
colorTo: purple
|
| 5 |
+
sdk: gradio
|
| 6 |
+
sdk_version: 5.38.0
|
| 7 |
+
app_file: app.py
|
| 8 |
+
pinned: false
|
| 9 |
+
license: mit
|
| 10 |
+
short_description: Evaluate AI-generated images using multiple SOTA models for quality, aesthetics, prompt following, and AI detection
|
| 11 |
+
tags:
|
| 12 |
+
- image-evaluation
|
| 13 |
+
- ai-detection
|
| 14 |
+
- image-quality
|
| 15 |
+
- aesthetics
|
| 16 |
+
- prompt-following
|
| 17 |
+
- gradio
|
| 18 |
+
- computer-vision
|
| 19 |
+
|
models/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Models package for image evaluation
|
| 2 |
+
|
models/aesthetics_evaluator.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torchvision.transforms as transforms
|
| 6 |
+
from transformers import AutoModel, AutoProcessor
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class AestheticsEvaluator:
|
| 12 |
+
"""Image aesthetics assessment using multiple SOTA models"""
|
| 13 |
+
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 16 |
+
self.models = {}
|
| 17 |
+
self.processors = {}
|
| 18 |
+
self.load_models()
|
| 19 |
+
|
| 20 |
+
def load_models(self):
|
| 21 |
+
"""Load aesthetics assessment models"""
|
| 22 |
+
try:
|
| 23 |
+
# Load UNIAA model (primary)
|
| 24 |
+
logger.info("Loading UNIAA model...")
|
| 25 |
+
self.load_uniaa()
|
| 26 |
+
|
| 27 |
+
# Load MUSIQ model (secondary)
|
| 28 |
+
logger.info("Loading MUSIQ model...")
|
| 29 |
+
self.load_musiq()
|
| 30 |
+
|
| 31 |
+
# Load anime-specific aesthetic model
|
| 32 |
+
logger.info("Loading anime aesthetic model...")
|
| 33 |
+
self.load_anime_aesthetic_model()
|
| 34 |
+
|
| 35 |
+
except Exception as e:
|
| 36 |
+
logger.error(f"Error loading aesthetic models: {str(e)}")
|
| 37 |
+
self.use_fallback_implementation()
|
| 38 |
+
|
| 39 |
+
def load_uniaa(self):
|
| 40 |
+
"""Load UNIAA model"""
|
| 41 |
+
try:
|
| 42 |
+
# Placeholder implementation for UNIAA
|
| 43 |
+
self.models['uniaa'] = self.create_mock_aesthetic_model()
|
| 44 |
+
self.processors['uniaa'] = transforms.Compose([
|
| 45 |
+
transforms.Resize((224, 224)),
|
| 46 |
+
transforms.ToTensor(),
|
| 47 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 48 |
+
std=[0.229, 0.224, 0.225])
|
| 49 |
+
])
|
| 50 |
+
except Exception as e:
|
| 51 |
+
logger.warning(f"Could not load UNIAA: {str(e)}")
|
| 52 |
+
|
| 53 |
+
def load_musiq(self):
|
| 54 |
+
"""Load MUSIQ model"""
|
| 55 |
+
try:
|
| 56 |
+
# Placeholder implementation for MUSIQ
|
| 57 |
+
self.models['musiq'] = self.create_mock_aesthetic_model()
|
| 58 |
+
self.processors['musiq'] = transforms.Compose([
|
| 59 |
+
transforms.Resize((224, 224)),
|
| 60 |
+
transforms.ToTensor(),
|
| 61 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 62 |
+
std=[0.229, 0.224, 0.225])
|
| 63 |
+
])
|
| 64 |
+
except Exception as e:
|
| 65 |
+
logger.warning(f"Could not load MUSIQ: {str(e)}")
|
| 66 |
+
|
| 67 |
+
def load_anime_aesthetic_model(self):
|
| 68 |
+
"""Load anime-specific aesthetic model"""
|
| 69 |
+
try:
|
| 70 |
+
# Placeholder for anime-specific model
|
| 71 |
+
self.models['anime_aesthetic'] = self.create_mock_aesthetic_model()
|
| 72 |
+
self.processors['anime_aesthetic'] = transforms.Compose([
|
| 73 |
+
transforms.Resize((224, 224)),
|
| 74 |
+
transforms.ToTensor(),
|
| 75 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 76 |
+
std=[0.229, 0.224, 0.225])
|
| 77 |
+
])
|
| 78 |
+
except Exception as e:
|
| 79 |
+
logger.warning(f"Could not load anime aesthetic model: {str(e)}")
|
| 80 |
+
|
| 81 |
+
def create_mock_aesthetic_model(self):
|
| 82 |
+
"""Create a mock aesthetic model for demonstration"""
|
| 83 |
+
class MockAestheticModel(nn.Module):
|
| 84 |
+
def __init__(self):
|
| 85 |
+
super().__init__()
|
| 86 |
+
self.backbone = torch.nn.Sequential(
|
| 87 |
+
torch.nn.Conv2d(3, 64, 3, padding=1),
|
| 88 |
+
torch.nn.ReLU(),
|
| 89 |
+
torch.nn.Conv2d(64, 128, 3, padding=1),
|
| 90 |
+
torch.nn.ReLU(),
|
| 91 |
+
torch.nn.AdaptiveAvgPool2d((1, 1)),
|
| 92 |
+
torch.nn.Flatten(),
|
| 93 |
+
torch.nn.Linear(128, 64),
|
| 94 |
+
torch.nn.ReLU(),
|
| 95 |
+
torch.nn.Linear(64, 1),
|
| 96 |
+
torch.nn.Sigmoid()
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def forward(self, x):
|
| 100 |
+
return self.backbone(x) * 10 # Scale to 0-10
|
| 101 |
+
|
| 102 |
+
model = MockAestheticModel().to(self.device)
|
| 103 |
+
model.eval()
|
| 104 |
+
return model
|
| 105 |
+
|
| 106 |
+
def use_fallback_implementation(self):
|
| 107 |
+
"""Use simple fallback aesthetic assessment"""
|
| 108 |
+
logger.info("Using fallback aesthetic assessment implementation")
|
| 109 |
+
self.fallback_mode = True
|
| 110 |
+
|
| 111 |
+
def evaluate_with_uniaa(self, image: Image.Image) -> float:
|
| 112 |
+
"""Evaluate aesthetics using UNIAA"""
|
| 113 |
+
try:
|
| 114 |
+
if 'uniaa' not in self.models:
|
| 115 |
+
return self.fallback_aesthetic_score(image)
|
| 116 |
+
|
| 117 |
+
# Preprocess image
|
| 118 |
+
tensor = self.processors['uniaa'](image).unsqueeze(0).to(self.device)
|
| 119 |
+
|
| 120 |
+
# Get prediction
|
| 121 |
+
with torch.no_grad():
|
| 122 |
+
score = self.models['uniaa'](tensor).item()
|
| 123 |
+
|
| 124 |
+
return max(0.0, min(10.0, score))
|
| 125 |
+
|
| 126 |
+
except Exception as e:
|
| 127 |
+
logger.error(f"Error in UNIAA evaluation: {str(e)}")
|
| 128 |
+
return self.fallback_aesthetic_score(image)
|
| 129 |
+
|
| 130 |
+
def evaluate_with_musiq(self, image: Image.Image) -> float:
|
| 131 |
+
"""Evaluate aesthetics using MUSIQ"""
|
| 132 |
+
try:
|
| 133 |
+
if 'musiq' not in self.models:
|
| 134 |
+
return self.fallback_aesthetic_score(image)
|
| 135 |
+
|
| 136 |
+
# Preprocess image
|
| 137 |
+
tensor = self.processors['musiq'](image).unsqueeze(0).to(self.device)
|
| 138 |
+
|
| 139 |
+
# Get prediction
|
| 140 |
+
with torch.no_grad():
|
| 141 |
+
score = self.models['musiq'](tensor).item()
|
| 142 |
+
|
| 143 |
+
return max(0.0, min(10.0, score))
|
| 144 |
+
|
| 145 |
+
except Exception as e:
|
| 146 |
+
logger.error(f"Error in MUSIQ evaluation: {str(e)}")
|
| 147 |
+
return self.fallback_aesthetic_score(image)
|
| 148 |
+
|
| 149 |
+
def evaluate_with_anime_model(self, image: Image.Image) -> float:
|
| 150 |
+
"""Evaluate aesthetics using anime-specific model"""
|
| 151 |
+
try:
|
| 152 |
+
if 'anime_aesthetic' not in self.models:
|
| 153 |
+
return self.fallback_aesthetic_score(image)
|
| 154 |
+
|
| 155 |
+
# Preprocess image
|
| 156 |
+
tensor = self.processors['anime_aesthetic'](image).unsqueeze(0).to(self.device)
|
| 157 |
+
|
| 158 |
+
# Get prediction
|
| 159 |
+
with torch.no_grad():
|
| 160 |
+
score = self.models['anime_aesthetic'](tensor).item()
|
| 161 |
+
|
| 162 |
+
return max(0.0, min(10.0, score))
|
| 163 |
+
|
| 164 |
+
except Exception as e:
|
| 165 |
+
logger.error(f"Error in anime aesthetic evaluation: {str(e)}")
|
| 166 |
+
return self.fallback_aesthetic_score(image)
|
| 167 |
+
|
| 168 |
+
def evaluate_composition_rules(self, image: Image.Image) -> float:
|
| 169 |
+
"""Evaluate based on composition rules (rule of thirds, etc.)"""
|
| 170 |
+
try:
|
| 171 |
+
# Convert to numpy array
|
| 172 |
+
img_array = np.array(image)
|
| 173 |
+
height, width = img_array.shape[:2]
|
| 174 |
+
|
| 175 |
+
# Convert to grayscale for analysis
|
| 176 |
+
if len(img_array.shape) == 3:
|
| 177 |
+
gray = np.dot(img_array[...,:3], [0.2989, 0.5870, 0.1140])
|
| 178 |
+
else:
|
| 179 |
+
gray = img_array
|
| 180 |
+
|
| 181 |
+
# Rule of thirds analysis
|
| 182 |
+
third_h, third_w = height // 3, width // 3
|
| 183 |
+
|
| 184 |
+
# Check for interesting content at rule of thirds intersections
|
| 185 |
+
intersections = [
|
| 186 |
+
(third_h, third_w), (third_h, 2*third_w),
|
| 187 |
+
(2*third_h, third_w), (2*third_h, 2*third_w)
|
| 188 |
+
]
|
| 189 |
+
|
| 190 |
+
composition_score = 0.0
|
| 191 |
+
for y, x in intersections:
|
| 192 |
+
# Check local variance around intersection points
|
| 193 |
+
region = gray[max(0, y-10):min(height, y+10),
|
| 194 |
+
max(0, x-10):min(width, x+10)]
|
| 195 |
+
if region.size > 0:
|
| 196 |
+
composition_score += region.var()
|
| 197 |
+
|
| 198 |
+
# Normalize composition score
|
| 199 |
+
composition_score = min(10.0, composition_score / 1000.0)
|
| 200 |
+
|
| 201 |
+
# Color harmony analysis
|
| 202 |
+
if len(img_array.shape) == 3:
|
| 203 |
+
# Calculate color distribution
|
| 204 |
+
colors = img_array.reshape(-1, 3)
|
| 205 |
+
color_std = np.std(colors, axis=0).mean()
|
| 206 |
+
color_harmony_score = min(10.0, color_std / 25.0)
|
| 207 |
+
else:
|
| 208 |
+
color_harmony_score = 5.0
|
| 209 |
+
|
| 210 |
+
# Combine scores
|
| 211 |
+
final_score = (composition_score * 0.6 + color_harmony_score * 0.4)
|
| 212 |
+
|
| 213 |
+
return max(0.0, min(10.0, final_score))
|
| 214 |
+
|
| 215 |
+
except Exception as e:
|
| 216 |
+
logger.error(f"Error in composition analysis: {str(e)}")
|
| 217 |
+
return 5.0
|
| 218 |
+
|
| 219 |
+
def fallback_aesthetic_score(self, image: Image.Image) -> float:
|
| 220 |
+
"""Simple fallback aesthetic assessment"""
|
| 221 |
+
try:
|
| 222 |
+
# Basic aesthetic assessment based on image properties
|
| 223 |
+
width, height = image.size
|
| 224 |
+
|
| 225 |
+
# Aspect ratio score (prefer aesthetically pleasing ratios)
|
| 226 |
+
aspect_ratio = width / height
|
| 227 |
+
golden_ratio = 1.618
|
| 228 |
+
|
| 229 |
+
if abs(aspect_ratio - golden_ratio) < 0.1 or abs(aspect_ratio - 1/golden_ratio) < 0.1:
|
| 230 |
+
aspect_score = 9.0
|
| 231 |
+
elif 0.7 <= aspect_ratio <= 1.4: # Square-ish
|
| 232 |
+
aspect_score = 7.0
|
| 233 |
+
elif 1.4 <= aspect_ratio <= 2.0: # Landscape
|
| 234 |
+
aspect_score = 8.0
|
| 235 |
+
else:
|
| 236 |
+
aspect_score = 5.0
|
| 237 |
+
|
| 238 |
+
# Resolution score (higher resolution often looks better)
|
| 239 |
+
total_pixels = width * height
|
| 240 |
+
resolution_score = min(10.0, total_pixels / 200000.0) # Normalize by 2MP
|
| 241 |
+
|
| 242 |
+
# Color analysis
|
| 243 |
+
img_array = np.array(image)
|
| 244 |
+
if len(img_array.shape) == 3:
|
| 245 |
+
# Color variety score
|
| 246 |
+
unique_colors = len(np.unique(img_array.reshape(-1, 3), axis=0))
|
| 247 |
+
color_variety_score = min(10.0, unique_colors / 1000.0)
|
| 248 |
+
|
| 249 |
+
# Brightness distribution
|
| 250 |
+
brightness = np.mean(img_array, axis=2)
|
| 251 |
+
brightness_score = 10.0 - abs(brightness.mean() - 127.5) / 12.75
|
| 252 |
+
else:
|
| 253 |
+
color_variety_score = 5.0
|
| 254 |
+
brightness_score = 5.0
|
| 255 |
+
|
| 256 |
+
# Combine scores
|
| 257 |
+
aesthetic_score = (aspect_score * 0.3 +
|
| 258 |
+
resolution_score * 0.2 +
|
| 259 |
+
color_variety_score * 0.3 +
|
| 260 |
+
brightness_score * 0.2)
|
| 261 |
+
|
| 262 |
+
return max(0.0, min(10.0, aesthetic_score))
|
| 263 |
+
|
| 264 |
+
except Exception:
|
| 265 |
+
return 5.0 # Default neutral score
|
| 266 |
+
|
| 267 |
+
def evaluate(self, image: Image.Image, anime_mode: bool = False) -> float:
|
| 268 |
+
"""
|
| 269 |
+
Evaluate image aesthetics using ensemble of models
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
image: PIL Image to evaluate
|
| 273 |
+
anime_mode: Whether to use anime-specific evaluation
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
Aesthetic score from 0-10
|
| 277 |
+
"""
|
| 278 |
+
try:
|
| 279 |
+
scores = []
|
| 280 |
+
|
| 281 |
+
if anime_mode:
|
| 282 |
+
# For anime images, prioritize anime-specific model
|
| 283 |
+
anime_score = self.evaluate_with_anime_model(image)
|
| 284 |
+
scores.append(anime_score)
|
| 285 |
+
|
| 286 |
+
# Also use general models but with lower weight
|
| 287 |
+
uniaa_score = self.evaluate_with_uniaa(image)
|
| 288 |
+
scores.append(uniaa_score)
|
| 289 |
+
|
| 290 |
+
# Composition rules
|
| 291 |
+
composition_score = self.evaluate_composition_rules(image)
|
| 292 |
+
scores.append(composition_score)
|
| 293 |
+
|
| 294 |
+
# Weights for anime mode
|
| 295 |
+
weights = [0.5, 0.3, 0.2]
|
| 296 |
+
|
| 297 |
+
else:
|
| 298 |
+
# For realistic images, use general aesthetic models
|
| 299 |
+
uniaa_score = self.evaluate_with_uniaa(image)
|
| 300 |
+
scores.append(uniaa_score)
|
| 301 |
+
|
| 302 |
+
musiq_score = self.evaluate_with_musiq(image)
|
| 303 |
+
scores.append(musiq_score)
|
| 304 |
+
|
| 305 |
+
# Composition rules
|
| 306 |
+
composition_score = self.evaluate_composition_rules(image)
|
| 307 |
+
scores.append(composition_score)
|
| 308 |
+
|
| 309 |
+
# Weights for realistic mode
|
| 310 |
+
weights = [0.4, 0.4, 0.2]
|
| 311 |
+
|
| 312 |
+
# Ensemble scoring
|
| 313 |
+
final_score = sum(score * weight for score, weight in zip(scores, weights))
|
| 314 |
+
|
| 315 |
+
logger.info(f"Aesthetic scores - Scores: {scores}, Final: {final_score:.2f}")
|
| 316 |
+
|
| 317 |
+
return max(0.0, min(10.0, final_score))
|
| 318 |
+
|
| 319 |
+
except Exception as e:
|
| 320 |
+
logger.error(f"Error in aesthetic evaluation: {str(e)}")
|
| 321 |
+
return self.fallback_aesthetic_score(image)
|
| 322 |
+
|
models/ai_detection_evaluator.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torchvision.transforms as transforms
|
| 6 |
+
from transformers import AutoModel, AutoProcessor
|
| 7 |
+
import cv2
|
| 8 |
+
import logging
|
| 9 |
+
from scipy import ndimage
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class AIDetectionEvaluator:
|
| 14 |
+
"""AI-generated image detection using multiple approaches"""
|
| 15 |
+
|
| 16 |
+
def __init__(self):
|
| 17 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 18 |
+
self.models = {}
|
| 19 |
+
self.processors = {}
|
| 20 |
+
self.load_models()
|
| 21 |
+
|
| 22 |
+
def load_models(self):
|
| 23 |
+
"""Load AI detection models"""
|
| 24 |
+
try:
|
| 25 |
+
# Load Sentry-Image model (primary)
|
| 26 |
+
logger.info("Loading Sentry-Image model...")
|
| 27 |
+
self.load_sentry_image()
|
| 28 |
+
|
| 29 |
+
# Load custom ensemble model (secondary)
|
| 30 |
+
logger.info("Loading custom ensemble model...")
|
| 31 |
+
self.load_custom_ensemble()
|
| 32 |
+
|
| 33 |
+
# Load traditional artifact detection
|
| 34 |
+
logger.info("Loading traditional artifact detection...")
|
| 35 |
+
self.load_artifact_detection()
|
| 36 |
+
|
| 37 |
+
except Exception as e:
|
| 38 |
+
logger.error(f"Error loading AI detection models: {str(e)}")
|
| 39 |
+
self.use_fallback_implementation()
|
| 40 |
+
|
| 41 |
+
def load_sentry_image(self):
|
| 42 |
+
"""Load Sentry-Image model"""
|
| 43 |
+
try:
|
| 44 |
+
# Placeholder implementation for Sentry-Image
|
| 45 |
+
# In production, this would load the actual Sentry-Image model
|
| 46 |
+
self.models['sentry'] = self.create_mock_detection_model()
|
| 47 |
+
self.processors['sentry'] = transforms.Compose([
|
| 48 |
+
transforms.Resize((224, 224)),
|
| 49 |
+
transforms.ToTensor(),
|
| 50 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 51 |
+
std=[0.229, 0.224, 0.225])
|
| 52 |
+
])
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.warning(f"Could not load Sentry-Image: {str(e)}")
|
| 55 |
+
|
| 56 |
+
def load_custom_ensemble(self):
|
| 57 |
+
"""Load custom ensemble detection model"""
|
| 58 |
+
try:
|
| 59 |
+
# Placeholder for custom ensemble
|
| 60 |
+
self.models['ensemble'] = self.create_mock_detection_model()
|
| 61 |
+
self.processors['ensemble'] = transforms.Compose([
|
| 62 |
+
transforms.Resize((224, 224)),
|
| 63 |
+
transforms.ToTensor(),
|
| 64 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 65 |
+
std=[0.229, 0.224, 0.225])
|
| 66 |
+
])
|
| 67 |
+
except Exception as e:
|
| 68 |
+
logger.warning(f"Could not load custom ensemble: {str(e)}")
|
| 69 |
+
|
| 70 |
+
def load_artifact_detection(self):
|
| 71 |
+
"""Load traditional artifact detection methods"""
|
| 72 |
+
try:
|
| 73 |
+
# These would be implemented using opencv and scipy
|
| 74 |
+
self.artifact_detection_available = True
|
| 75 |
+
except Exception as e:
|
| 76 |
+
logger.warning(f"Could not load artifact detection: {str(e)}")
|
| 77 |
+
self.artifact_detection_available = False
|
| 78 |
+
|
| 79 |
+
def create_mock_detection_model(self):
|
| 80 |
+
"""Create a mock detection model for demonstration"""
|
| 81 |
+
class MockDetectionModel(nn.Module):
|
| 82 |
+
def __init__(self):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.backbone = torch.nn.Sequential(
|
| 85 |
+
torch.nn.Conv2d(3, 64, 3, padding=1),
|
| 86 |
+
torch.nn.ReLU(),
|
| 87 |
+
torch.nn.Conv2d(64, 128, 3, padding=1),
|
| 88 |
+
torch.nn.ReLU(),
|
| 89 |
+
torch.nn.AdaptiveAvgPool2d((1, 1)),
|
| 90 |
+
torch.nn.Flatten(),
|
| 91 |
+
torch.nn.Linear(128, 64),
|
| 92 |
+
torch.nn.ReLU(),
|
| 93 |
+
torch.nn.Linear(64, 1),
|
| 94 |
+
torch.nn.Sigmoid()
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
def forward(self, x):
|
| 98 |
+
return self.backbone(x) # Returns probability 0-1
|
| 99 |
+
|
| 100 |
+
model = MockDetectionModel().to(self.device)
|
| 101 |
+
model.eval()
|
| 102 |
+
return model
|
| 103 |
+
|
| 104 |
+
def use_fallback_implementation(self):
|
| 105 |
+
"""Use simple fallback AI detection"""
|
| 106 |
+
logger.info("Using fallback AI detection implementation")
|
| 107 |
+
self.fallback_mode = True
|
| 108 |
+
|
| 109 |
+
def evaluate_with_sentry(self, image: Image.Image) -> float:
|
| 110 |
+
"""Evaluate AI generation probability using Sentry-Image"""
|
| 111 |
+
try:
|
| 112 |
+
if 'sentry' not in self.models:
|
| 113 |
+
return self.fallback_detection_score(image)
|
| 114 |
+
|
| 115 |
+
# Preprocess image
|
| 116 |
+
tensor = self.processors['sentry'](image).unsqueeze(0).to(self.device)
|
| 117 |
+
|
| 118 |
+
# Get prediction
|
| 119 |
+
with torch.no_grad():
|
| 120 |
+
probability = self.models['sentry'](tensor).item()
|
| 121 |
+
|
| 122 |
+
return max(0.0, min(1.0, probability))
|
| 123 |
+
|
| 124 |
+
except Exception as e:
|
| 125 |
+
logger.error(f"Error in Sentry evaluation: {str(e)}")
|
| 126 |
+
return self.fallback_detection_score(image)
|
| 127 |
+
|
| 128 |
+
def evaluate_with_ensemble(self, image: Image.Image) -> float:
|
| 129 |
+
"""Evaluate AI generation probability using custom ensemble"""
|
| 130 |
+
try:
|
| 131 |
+
if 'ensemble' not in self.models:
|
| 132 |
+
return self.fallback_detection_score(image)
|
| 133 |
+
|
| 134 |
+
# Preprocess image
|
| 135 |
+
tensor = self.processors['ensemble'](image).unsqueeze(0).to(self.device)
|
| 136 |
+
|
| 137 |
+
# Get prediction
|
| 138 |
+
with torch.no_grad():
|
| 139 |
+
probability = self.models['ensemble'](tensor).item()
|
| 140 |
+
|
| 141 |
+
return max(0.0, min(1.0, probability))
|
| 142 |
+
|
| 143 |
+
except Exception as e:
|
| 144 |
+
logger.error(f"Error in ensemble evaluation: {str(e)}")
|
| 145 |
+
return self.fallback_detection_score(image)
|
| 146 |
+
|
| 147 |
+
def detect_compression_artifacts(self, image: Image.Image) -> float:
|
| 148 |
+
"""Detect compression artifacts that might indicate AI generation"""
|
| 149 |
+
try:
|
| 150 |
+
# Convert to numpy array
|
| 151 |
+
img_array = np.array(image)
|
| 152 |
+
|
| 153 |
+
# Convert to grayscale
|
| 154 |
+
if len(img_array.shape) == 3:
|
| 155 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 156 |
+
else:
|
| 157 |
+
gray = img_array
|
| 158 |
+
|
| 159 |
+
# Detect JPEG compression artifacts using DCT analysis
|
| 160 |
+
# This is a simplified version - real implementation would be more complex
|
| 161 |
+
|
| 162 |
+
# Calculate local variance to detect blocking artifacts
|
| 163 |
+
kernel = np.ones((8, 8), np.float32) / 64
|
| 164 |
+
local_mean = cv2.filter2D(gray.astype(np.float32), -1, kernel)
|
| 165 |
+
local_variance = cv2.filter2D((gray.astype(np.float32) - local_mean) ** 2, -1, kernel)
|
| 166 |
+
|
| 167 |
+
# High variance in 8x8 blocks might indicate JPEG artifacts
|
| 168 |
+
block_variance = np.mean(local_variance)
|
| 169 |
+
|
| 170 |
+
# Normalize to 0-1 probability
|
| 171 |
+
artifact_probability = min(1.0, block_variance / 1000.0)
|
| 172 |
+
|
| 173 |
+
return artifact_probability
|
| 174 |
+
|
| 175 |
+
except Exception as e:
|
| 176 |
+
logger.error(f"Error in compression artifact detection: {str(e)}")
|
| 177 |
+
return 0.5
|
| 178 |
+
|
| 179 |
+
def detect_frequency_anomalies(self, image: Image.Image) -> float:
|
| 180 |
+
"""Detect frequency domain anomalies common in AI-generated images"""
|
| 181 |
+
try:
|
| 182 |
+
# Convert to numpy array and grayscale
|
| 183 |
+
img_array = np.array(image)
|
| 184 |
+
if len(img_array.shape) == 3:
|
| 185 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 186 |
+
else:
|
| 187 |
+
gray = img_array
|
| 188 |
+
|
| 189 |
+
# Apply FFT
|
| 190 |
+
f_transform = np.fft.fft2(gray)
|
| 191 |
+
f_shift = np.fft.fftshift(f_transform)
|
| 192 |
+
magnitude_spectrum = np.log(np.abs(f_shift) + 1)
|
| 193 |
+
|
| 194 |
+
# Analyze frequency distribution
|
| 195 |
+
# AI-generated images often have specific frequency patterns
|
| 196 |
+
|
| 197 |
+
# Calculate radial frequency distribution
|
| 198 |
+
h, w = magnitude_spectrum.shape
|
| 199 |
+
center_y, center_x = h // 2, w // 2
|
| 200 |
+
|
| 201 |
+
# Create radial mask
|
| 202 |
+
y, x = np.ogrid[:h, :w]
|
| 203 |
+
mask = (x - center_x) ** 2 + (y - center_y) ** 2
|
| 204 |
+
|
| 205 |
+
# Calculate mean magnitude at different frequencies
|
| 206 |
+
low_freq_mask = mask <= (min(h, w) // 8) ** 2
|
| 207 |
+
high_freq_mask = mask >= (min(h, w) // 4) ** 2
|
| 208 |
+
|
| 209 |
+
low_freq_energy = np.mean(magnitude_spectrum[low_freq_mask])
|
| 210 |
+
high_freq_energy = np.mean(magnitude_spectrum[high_freq_mask])
|
| 211 |
+
|
| 212 |
+
# AI images often have unusual low/high frequency ratios
|
| 213 |
+
if high_freq_energy > 0:
|
| 214 |
+
freq_ratio = low_freq_energy / high_freq_energy
|
| 215 |
+
# Normalize to probability
|
| 216 |
+
anomaly_probability = min(1.0, abs(freq_ratio - 10.0) / 20.0)
|
| 217 |
+
else:
|
| 218 |
+
anomaly_probability = 0.5
|
| 219 |
+
|
| 220 |
+
return anomaly_probability
|
| 221 |
+
|
| 222 |
+
except Exception as e:
|
| 223 |
+
logger.error(f"Error in frequency analysis: {str(e)}")
|
| 224 |
+
return 0.5
|
| 225 |
+
|
| 226 |
+
def detect_pixel_patterns(self, image: Image.Image) -> float:
|
| 227 |
+
"""Detect suspicious pixel patterns common in AI-generated images"""
|
| 228 |
+
try:
|
| 229 |
+
img_array = np.array(image)
|
| 230 |
+
|
| 231 |
+
# Check for perfect pixel repetitions (uncommon in natural images)
|
| 232 |
+
if len(img_array.shape) == 3:
|
| 233 |
+
# Flatten to check for repeated pixel values
|
| 234 |
+
pixels = img_array.reshape(-1, 3)
|
| 235 |
+
unique_pixels = np.unique(pixels, axis=0)
|
| 236 |
+
|
| 237 |
+
# Calculate pixel diversity
|
| 238 |
+
pixel_diversity = len(unique_pixels) / len(pixels)
|
| 239 |
+
|
| 240 |
+
# Very low diversity might indicate AI generation
|
| 241 |
+
if pixel_diversity < 0.1:
|
| 242 |
+
pattern_probability = 0.8
|
| 243 |
+
elif pixel_diversity < 0.3:
|
| 244 |
+
pattern_probability = 0.6
|
| 245 |
+
else:
|
| 246 |
+
pattern_probability = 0.2
|
| 247 |
+
else:
|
| 248 |
+
pattern_probability = 0.5
|
| 249 |
+
|
| 250 |
+
# Check for unnatural smoothness
|
| 251 |
+
if len(img_array.shape) == 3:
|
| 252 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 253 |
+
else:
|
| 254 |
+
gray = img_array
|
| 255 |
+
|
| 256 |
+
# Calculate local standard deviation
|
| 257 |
+
local_std = ndimage.generic_filter(gray.astype(np.float32), np.std, size=3)
|
| 258 |
+
avg_local_std = np.mean(local_std)
|
| 259 |
+
|
| 260 |
+
# Very smooth images might be AI-generated
|
| 261 |
+
if avg_local_std < 5.0:
|
| 262 |
+
smoothness_probability = 0.7
|
| 263 |
+
elif avg_local_std < 15.0:
|
| 264 |
+
smoothness_probability = 0.4
|
| 265 |
+
else:
|
| 266 |
+
smoothness_probability = 0.2
|
| 267 |
+
|
| 268 |
+
# Combine pattern and smoothness indicators
|
| 269 |
+
combined_probability = (pattern_probability + smoothness_probability) / 2
|
| 270 |
+
|
| 271 |
+
return max(0.0, min(1.0, combined_probability))
|
| 272 |
+
|
| 273 |
+
except Exception as e:
|
| 274 |
+
logger.error(f"Error in pixel pattern detection: {str(e)}")
|
| 275 |
+
return 0.5
|
| 276 |
+
|
| 277 |
+
def analyze_metadata_indicators(self, image: Image.Image) -> float:
|
| 278 |
+
"""Analyze image metadata for AI generation indicators"""
|
| 279 |
+
try:
|
| 280 |
+
# Check image format and properties
|
| 281 |
+
format_probability = 0.0
|
| 282 |
+
|
| 283 |
+
# PNG format is more common for AI-generated images
|
| 284 |
+
if image.format == 'PNG':
|
| 285 |
+
format_probability += 0.3
|
| 286 |
+
|
| 287 |
+
# Check for specific dimensions common in AI generation
|
| 288 |
+
width, height = image.size
|
| 289 |
+
|
| 290 |
+
# Common AI generation resolutions
|
| 291 |
+
ai_resolutions = [
|
| 292 |
+
(512, 512), (768, 768), (1024, 1024), # Square formats
|
| 293 |
+
(512, 768), (768, 512), # 2:3 ratios
|
| 294 |
+
(1024, 768), (768, 1024) # 4:3 ratios
|
| 295 |
+
]
|
| 296 |
+
|
| 297 |
+
if (width, height) in ai_resolutions:
|
| 298 |
+
format_probability += 0.4
|
| 299 |
+
|
| 300 |
+
# Check for perfect aspect ratios (less common in natural photos)
|
| 301 |
+
aspect_ratio = width / height
|
| 302 |
+
common_ai_ratios = [1.0, 1.5, 0.67, 1.33, 0.75, 1.25]
|
| 303 |
+
|
| 304 |
+
for ratio in common_ai_ratios:
|
| 305 |
+
if abs(aspect_ratio - ratio) < 0.01:
|
| 306 |
+
format_probability += 0.2
|
| 307 |
+
break
|
| 308 |
+
|
| 309 |
+
return max(0.0, min(1.0, format_probability))
|
| 310 |
+
|
| 311 |
+
except Exception as e:
|
| 312 |
+
logger.error(f"Error in metadata analysis: {str(e)}")
|
| 313 |
+
return 0.5
|
| 314 |
+
|
| 315 |
+
def fallback_detection_score(self, image: Image.Image) -> float:
|
| 316 |
+
"""Simple fallback AI detection"""
|
| 317 |
+
try:
|
| 318 |
+
# Combine multiple simple heuristics
|
| 319 |
+
scores = []
|
| 320 |
+
|
| 321 |
+
# Compression artifacts
|
| 322 |
+
artifact_score = self.detect_compression_artifacts(image)
|
| 323 |
+
scores.append(artifact_score)
|
| 324 |
+
|
| 325 |
+
# Frequency anomalies
|
| 326 |
+
freq_score = self.detect_frequency_anomalies(image)
|
| 327 |
+
scores.append(freq_score)
|
| 328 |
+
|
| 329 |
+
# Pixel patterns
|
| 330 |
+
pattern_score = self.detect_pixel_patterns(image)
|
| 331 |
+
scores.append(pattern_score)
|
| 332 |
+
|
| 333 |
+
# Metadata indicators
|
| 334 |
+
metadata_score = self.analyze_metadata_indicators(image)
|
| 335 |
+
scores.append(metadata_score)
|
| 336 |
+
|
| 337 |
+
# Average the scores
|
| 338 |
+
final_score = np.mean(scores)
|
| 339 |
+
|
| 340 |
+
return max(0.0, min(1.0, final_score))
|
| 341 |
+
|
| 342 |
+
except Exception:
|
| 343 |
+
return 0.5 # Default neutral probability
|
| 344 |
+
|
| 345 |
+
def evaluate(self, image: Image.Image) -> float:
|
| 346 |
+
"""
|
| 347 |
+
Evaluate probability that image is AI-generated
|
| 348 |
+
|
| 349 |
+
Args:
|
| 350 |
+
image: PIL Image to evaluate
|
| 351 |
+
|
| 352 |
+
Returns:
|
| 353 |
+
AI generation probability from 0-1 (0 = likely real, 1 = likely AI)
|
| 354 |
+
"""
|
| 355 |
+
try:
|
| 356 |
+
scores = []
|
| 357 |
+
|
| 358 |
+
# Sentry-Image evaluation (primary)
|
| 359 |
+
sentry_score = self.evaluate_with_sentry(image)
|
| 360 |
+
scores.append(sentry_score)
|
| 361 |
+
|
| 362 |
+
# Custom ensemble evaluation (secondary)
|
| 363 |
+
ensemble_score = self.evaluate_with_ensemble(image)
|
| 364 |
+
scores.append(ensemble_score)
|
| 365 |
+
|
| 366 |
+
# Traditional artifact detection
|
| 367 |
+
artifact_score = self.fallback_detection_score(image)
|
| 368 |
+
scores.append(artifact_score)
|
| 369 |
+
|
| 370 |
+
# Ensemble scoring
|
| 371 |
+
weights = [0.5, 0.3, 0.2] # Sentry gets highest weight
|
| 372 |
+
final_score = sum(score * weight for score, weight in zip(scores, weights))
|
| 373 |
+
|
| 374 |
+
logger.info(f"AI detection scores - Sentry: {sentry_score:.3f}, "
|
| 375 |
+
f"Ensemble: {ensemble_score:.3f}, Artifacts: {artifact_score:.3f}, "
|
| 376 |
+
f"Final: {final_score:.3f}")
|
| 377 |
+
|
| 378 |
+
return max(0.0, min(1.0, final_score))
|
| 379 |
+
|
| 380 |
+
except Exception as e:
|
| 381 |
+
logger.error(f"Error in AI detection evaluation: {str(e)}")
|
| 382 |
+
return self.fallback_detection_score(image)
|
| 383 |
+
|
models/prompt_evaluator.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import clip
|
| 5 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
| 6 |
+
import logging
|
| 7 |
+
from sentence_transformers import SentenceTransformer, util
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class PromptEvaluator:
|
| 12 |
+
"""Prompt following assessment using CLIP and other vision-language models"""
|
| 13 |
+
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 16 |
+
self.models = {}
|
| 17 |
+
self.processors = {}
|
| 18 |
+
self.load_models()
|
| 19 |
+
|
| 20 |
+
def load_models(self):
|
| 21 |
+
"""Load prompt evaluation models"""
|
| 22 |
+
try:
|
| 23 |
+
# Load CLIP model (primary)
|
| 24 |
+
logger.info("Loading CLIP model...")
|
| 25 |
+
self.load_clip()
|
| 26 |
+
|
| 27 |
+
# Load BLIP-2 model (secondary)
|
| 28 |
+
logger.info("Loading BLIP-2 model...")
|
| 29 |
+
self.load_blip2()
|
| 30 |
+
|
| 31 |
+
# Load sentence transformer for text similarity
|
| 32 |
+
logger.info("Loading sentence transformer...")
|
| 33 |
+
self.load_sentence_transformer()
|
| 34 |
+
|
| 35 |
+
except Exception as e:
|
| 36 |
+
logger.error(f"Error loading prompt evaluation models: {str(e)}")
|
| 37 |
+
self.use_fallback_implementation()
|
| 38 |
+
|
| 39 |
+
def load_clip(self):
|
| 40 |
+
"""Load CLIP model"""
|
| 41 |
+
try:
|
| 42 |
+
model, preprocess = clip.load("ViT-B/32", device=self.device)
|
| 43 |
+
self.models['clip'] = model
|
| 44 |
+
self.processors['clip'] = preprocess
|
| 45 |
+
logger.info("CLIP model loaded successfully")
|
| 46 |
+
except Exception as e:
|
| 47 |
+
logger.warning(f"Could not load CLIP: {str(e)}")
|
| 48 |
+
|
| 49 |
+
def load_blip2(self):
|
| 50 |
+
"""Load BLIP-2 model"""
|
| 51 |
+
try:
|
| 52 |
+
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
| 53 |
+
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
|
| 54 |
+
model = model.to(self.device)
|
| 55 |
+
|
| 56 |
+
self.models['blip2'] = model
|
| 57 |
+
self.processors['blip2'] = processor
|
| 58 |
+
logger.info("BLIP-2 model loaded successfully")
|
| 59 |
+
except Exception as e:
|
| 60 |
+
logger.warning(f"Could not load BLIP-2: {str(e)}")
|
| 61 |
+
|
| 62 |
+
def load_sentence_transformer(self):
|
| 63 |
+
"""Load sentence transformer for text similarity"""
|
| 64 |
+
try:
|
| 65 |
+
model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 66 |
+
self.models['sentence_transformer'] = model
|
| 67 |
+
logger.info("Sentence transformer loaded successfully")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
logger.warning(f"Could not load sentence transformer: {str(e)}")
|
| 70 |
+
|
| 71 |
+
def use_fallback_implementation(self):
|
| 72 |
+
"""Use simple fallback prompt evaluation"""
|
| 73 |
+
logger.info("Using fallback prompt evaluation implementation")
|
| 74 |
+
self.fallback_mode = True
|
| 75 |
+
|
| 76 |
+
def evaluate_with_clip(self, image: Image.Image, prompt: str) -> float:
|
| 77 |
+
"""Evaluate prompt following using CLIP"""
|
| 78 |
+
try:
|
| 79 |
+
if 'clip' not in self.models:
|
| 80 |
+
return self.fallback_prompt_score(image, prompt)
|
| 81 |
+
|
| 82 |
+
model = self.models['clip']
|
| 83 |
+
preprocess = self.processors['clip']
|
| 84 |
+
|
| 85 |
+
# Preprocess image
|
| 86 |
+
image_tensor = preprocess(image).unsqueeze(0).to(self.device)
|
| 87 |
+
|
| 88 |
+
# Tokenize text
|
| 89 |
+
text_tokens = clip.tokenize([prompt]).to(self.device)
|
| 90 |
+
|
| 91 |
+
# Get features
|
| 92 |
+
with torch.no_grad():
|
| 93 |
+
image_features = model.encode_image(image_tensor)
|
| 94 |
+
text_features = model.encode_text(text_tokens)
|
| 95 |
+
|
| 96 |
+
# Normalize features
|
| 97 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
| 98 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
| 99 |
+
|
| 100 |
+
# Calculate similarity
|
| 101 |
+
similarity = (image_features @ text_features.T).item()
|
| 102 |
+
|
| 103 |
+
# Convert similarity to 0-10 scale
|
| 104 |
+
# CLIP similarity is typically between -1 and 1, but usually 0-1 for related content
|
| 105 |
+
score = max(0.0, min(10.0, (similarity + 1) * 5))
|
| 106 |
+
|
| 107 |
+
return score
|
| 108 |
+
|
| 109 |
+
except Exception as e:
|
| 110 |
+
logger.error(f"Error in CLIP evaluation: {str(e)}")
|
| 111 |
+
return self.fallback_prompt_score(image, prompt)
|
| 112 |
+
|
| 113 |
+
def evaluate_with_blip2(self, image: Image.Image, prompt: str) -> float:
|
| 114 |
+
"""Evaluate prompt following using BLIP-2"""
|
| 115 |
+
try:
|
| 116 |
+
if 'blip2' not in self.models:
|
| 117 |
+
return self.fallback_prompt_score(image, prompt)
|
| 118 |
+
|
| 119 |
+
model = self.models['blip2']
|
| 120 |
+
processor = self.processors['blip2']
|
| 121 |
+
|
| 122 |
+
# Generate caption for the image
|
| 123 |
+
inputs = processor(image, return_tensors="pt").to(self.device)
|
| 124 |
+
|
| 125 |
+
with torch.no_grad():
|
| 126 |
+
out = model.generate(**inputs, max_length=50)
|
| 127 |
+
generated_caption = processor.decode(out[0], skip_special_tokens=True)
|
| 128 |
+
|
| 129 |
+
# Compare generated caption with original prompt using text similarity
|
| 130 |
+
if 'sentence_transformer' in self.models:
|
| 131 |
+
similarity_score = self.calculate_text_similarity(prompt, generated_caption)
|
| 132 |
+
else:
|
| 133 |
+
# Simple word overlap fallback
|
| 134 |
+
similarity_score = self.simple_text_similarity(prompt, generated_caption)
|
| 135 |
+
|
| 136 |
+
return similarity_score
|
| 137 |
+
|
| 138 |
+
except Exception as e:
|
| 139 |
+
logger.error(f"Error in BLIP-2 evaluation: {str(e)}")
|
| 140 |
+
return self.fallback_prompt_score(image, prompt)
|
| 141 |
+
|
| 142 |
+
def calculate_text_similarity(self, text1: str, text2: str) -> float:
|
| 143 |
+
"""Calculate semantic similarity between two texts"""
|
| 144 |
+
try:
|
| 145 |
+
model = self.models['sentence_transformer']
|
| 146 |
+
|
| 147 |
+
# Encode texts
|
| 148 |
+
embeddings = model.encode([text1, text2])
|
| 149 |
+
|
| 150 |
+
# Calculate cosine similarity
|
| 151 |
+
similarity = util.cos_sim(embeddings[0], embeddings[1]).item()
|
| 152 |
+
|
| 153 |
+
# Convert to 0-10 scale
|
| 154 |
+
score = max(0.0, min(10.0, (similarity + 1) * 5))
|
| 155 |
+
|
| 156 |
+
return score
|
| 157 |
+
|
| 158 |
+
except Exception as e:
|
| 159 |
+
logger.error(f"Error calculating text similarity: {str(e)}")
|
| 160 |
+
return self.simple_text_similarity(text1, text2)
|
| 161 |
+
|
| 162 |
+
def simple_text_similarity(self, text1: str, text2: str) -> float:
|
| 163 |
+
"""Simple word overlap similarity"""
|
| 164 |
+
try:
|
| 165 |
+
# Convert to lowercase and split into words
|
| 166 |
+
words1 = set(text1.lower().split())
|
| 167 |
+
words2 = set(text2.lower().split())
|
| 168 |
+
|
| 169 |
+
# Calculate Jaccard similarity
|
| 170 |
+
intersection = len(words1.intersection(words2))
|
| 171 |
+
union = len(words1.union(words2))
|
| 172 |
+
|
| 173 |
+
if union == 0:
|
| 174 |
+
return 0.0
|
| 175 |
+
|
| 176 |
+
jaccard_similarity = intersection / union
|
| 177 |
+
|
| 178 |
+
# Convert to 0-10 scale
|
| 179 |
+
score = jaccard_similarity * 10
|
| 180 |
+
|
| 181 |
+
return max(0.0, min(10.0, score))
|
| 182 |
+
|
| 183 |
+
except Exception:
|
| 184 |
+
return 5.0 # Default neutral score
|
| 185 |
+
|
| 186 |
+
def extract_key_concepts(self, prompt: str) -> list:
|
| 187 |
+
"""Extract key concepts from prompt for detailed analysis"""
|
| 188 |
+
try:
|
| 189 |
+
# Simple keyword extraction
|
| 190 |
+
# In production, this could use more sophisticated NLP
|
| 191 |
+
|
| 192 |
+
# Remove common words
|
| 193 |
+
stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should'}
|
| 194 |
+
|
| 195 |
+
words = prompt.lower().split()
|
| 196 |
+
key_concepts = [word for word in words if word not in stop_words and len(word) > 2]
|
| 197 |
+
|
| 198 |
+
return key_concepts
|
| 199 |
+
|
| 200 |
+
except Exception:
|
| 201 |
+
return []
|
| 202 |
+
|
| 203 |
+
def evaluate_concept_presence(self, image: Image.Image, concepts: list) -> float:
|
| 204 |
+
"""Evaluate presence of specific concepts in image"""
|
| 205 |
+
try:
|
| 206 |
+
if 'clip' not in self.models or not concepts:
|
| 207 |
+
return 5.0
|
| 208 |
+
|
| 209 |
+
model = self.models['clip']
|
| 210 |
+
preprocess = self.processors['clip']
|
| 211 |
+
|
| 212 |
+
# Preprocess image
|
| 213 |
+
image_tensor = preprocess(image).unsqueeze(0).to(self.device)
|
| 214 |
+
|
| 215 |
+
# Create concept queries
|
| 216 |
+
concept_queries = [f"a photo of {concept}" for concept in concepts]
|
| 217 |
+
|
| 218 |
+
# Tokenize concepts
|
| 219 |
+
text_tokens = clip.tokenize(concept_queries).to(self.device)
|
| 220 |
+
|
| 221 |
+
# Get features
|
| 222 |
+
with torch.no_grad():
|
| 223 |
+
image_features = model.encode_image(image_tensor)
|
| 224 |
+
text_features = model.encode_text(text_tokens)
|
| 225 |
+
|
| 226 |
+
# Normalize features
|
| 227 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
| 228 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
| 229 |
+
|
| 230 |
+
# Calculate similarities
|
| 231 |
+
similarities = (image_features @ text_features.T).squeeze(0)
|
| 232 |
+
|
| 233 |
+
# Average similarity across concepts
|
| 234 |
+
avg_similarity = similarities.mean().item()
|
| 235 |
+
|
| 236 |
+
# Convert to 0-10 scale
|
| 237 |
+
score = max(0.0, min(10.0, (avg_similarity + 1) * 5))
|
| 238 |
+
|
| 239 |
+
return score
|
| 240 |
+
|
| 241 |
+
except Exception as e:
|
| 242 |
+
logger.error(f"Error in concept presence evaluation: {str(e)}")
|
| 243 |
+
return 5.0
|
| 244 |
+
|
| 245 |
+
def fallback_prompt_score(self, image: Image.Image, prompt: str) -> float:
|
| 246 |
+
"""Simple fallback prompt evaluation"""
|
| 247 |
+
try:
|
| 248 |
+
# Very basic evaluation based on prompt length and image properties
|
| 249 |
+
prompt_length = len(prompt.split())
|
| 250 |
+
|
| 251 |
+
# Longer, more detailed prompts might be harder to follow perfectly
|
| 252 |
+
if prompt_length < 5:
|
| 253 |
+
length_penalty = 0.0
|
| 254 |
+
elif prompt_length < 15:
|
| 255 |
+
length_penalty = 0.5
|
| 256 |
+
else:
|
| 257 |
+
length_penalty = 1.0
|
| 258 |
+
|
| 259 |
+
# Base score
|
| 260 |
+
base_score = 7.0 - length_penalty
|
| 261 |
+
|
| 262 |
+
return max(0.0, min(10.0, base_score))
|
| 263 |
+
|
| 264 |
+
except Exception:
|
| 265 |
+
return 5.0 # Default neutral score
|
| 266 |
+
|
| 267 |
+
def evaluate(self, image: Image.Image, prompt: str) -> float:
|
| 268 |
+
"""
|
| 269 |
+
Evaluate how well the image follows the given prompt
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
image: PIL Image to evaluate
|
| 273 |
+
prompt: Text prompt to compare against
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
Prompt following score from 0-10
|
| 277 |
+
"""
|
| 278 |
+
try:
|
| 279 |
+
if not prompt or not prompt.strip():
|
| 280 |
+
return 0.0 # No prompt to evaluate against
|
| 281 |
+
|
| 282 |
+
scores = []
|
| 283 |
+
|
| 284 |
+
# CLIP evaluation (primary)
|
| 285 |
+
clip_score = self.evaluate_with_clip(image, prompt)
|
| 286 |
+
scores.append(clip_score)
|
| 287 |
+
|
| 288 |
+
# BLIP-2 evaluation (secondary)
|
| 289 |
+
blip2_score = self.evaluate_with_blip2(image, prompt)
|
| 290 |
+
scores.append(blip2_score)
|
| 291 |
+
|
| 292 |
+
# Concept presence evaluation
|
| 293 |
+
key_concepts = self.extract_key_concepts(prompt)
|
| 294 |
+
concept_score = self.evaluate_concept_presence(image, key_concepts)
|
| 295 |
+
scores.append(concept_score)
|
| 296 |
+
|
| 297 |
+
# Ensemble scoring
|
| 298 |
+
weights = [0.5, 0.3, 0.2] # CLIP gets highest weight
|
| 299 |
+
final_score = sum(score * weight for score, weight in zip(scores, weights))
|
| 300 |
+
|
| 301 |
+
logger.info(f"Prompt scores - CLIP: {clip_score:.2f}, BLIP-2: {blip2_score:.2f}, "
|
| 302 |
+
f"Concepts: {concept_score:.2f}, Final: {final_score:.2f}")
|
| 303 |
+
|
| 304 |
+
return max(0.0, min(10.0, final_score))
|
| 305 |
+
|
| 306 |
+
except Exception as e:
|
| 307 |
+
logger.error(f"Error in prompt evaluation: {str(e)}")
|
| 308 |
+
return self.fallback_prompt_score(image, prompt)
|
| 309 |
+
|
models/quality_evaluator.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torchvision.transforms as transforms
|
| 6 |
+
from transformers import AutoModel, AutoProcessor
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class QualityEvaluator:
|
| 12 |
+
"""Image quality assessment using multiple SOTA models"""
|
| 13 |
+
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 16 |
+
self.models = {}
|
| 17 |
+
self.processors = {}
|
| 18 |
+
self.load_models()
|
| 19 |
+
|
| 20 |
+
def load_models(self):
|
| 21 |
+
"""Load quality assessment models"""
|
| 22 |
+
try:
|
| 23 |
+
# Load LAR-IQA model (primary)
|
| 24 |
+
logger.info("Loading LAR-IQA model...")
|
| 25 |
+
self.load_lar_iqa()
|
| 26 |
+
|
| 27 |
+
# Load DGIQA model (secondary)
|
| 28 |
+
logger.info("Loading DGIQA model...")
|
| 29 |
+
self.load_dgiqa()
|
| 30 |
+
|
| 31 |
+
# Load traditional metrics as fallback
|
| 32 |
+
logger.info("Loading traditional quality metrics...")
|
| 33 |
+
self.load_traditional_metrics()
|
| 34 |
+
|
| 35 |
+
except Exception as e:
|
| 36 |
+
logger.error(f"Error loading quality models: {str(e)}")
|
| 37 |
+
# Use fallback implementation
|
| 38 |
+
self.use_fallback_implementation()
|
| 39 |
+
|
| 40 |
+
def load_lar_iqa(self):
|
| 41 |
+
"""Load LAR-IQA model"""
|
| 42 |
+
try:
|
| 43 |
+
# For now, use a placeholder implementation
|
| 44 |
+
# In production, this would load the actual LAR-IQA model
|
| 45 |
+
self.models['lar_iqa'] = self.create_mock_model()
|
| 46 |
+
self.processors['lar_iqa'] = transforms.Compose([
|
| 47 |
+
transforms.Resize((224, 224)),
|
| 48 |
+
transforms.ToTensor(),
|
| 49 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 50 |
+
std=[0.229, 0.224, 0.225])
|
| 51 |
+
])
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logger.warning(f"Could not load LAR-IQA: {str(e)}")
|
| 54 |
+
|
| 55 |
+
def load_dgiqa(self):
|
| 56 |
+
"""Load DGIQA model"""
|
| 57 |
+
try:
|
| 58 |
+
# Placeholder implementation
|
| 59 |
+
self.models['dgiqa'] = self.create_mock_model()
|
| 60 |
+
self.processors['dgiqa'] = transforms.Compose([
|
| 61 |
+
transforms.Resize((224, 224)),
|
| 62 |
+
transforms.ToTensor(),
|
| 63 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 64 |
+
std=[0.229, 0.224, 0.225])
|
| 65 |
+
])
|
| 66 |
+
except Exception as e:
|
| 67 |
+
logger.warning(f"Could not load DGIQA: {str(e)}")
|
| 68 |
+
|
| 69 |
+
def load_traditional_metrics(self):
|
| 70 |
+
"""Load traditional quality metrics (BRISQUE, NIQE, etc.)"""
|
| 71 |
+
try:
|
| 72 |
+
# These would be implemented using scikit-image or opencv
|
| 73 |
+
self.traditional_metrics_available = True
|
| 74 |
+
except Exception as e:
|
| 75 |
+
logger.warning(f"Could not load traditional metrics: {str(e)}")
|
| 76 |
+
self.traditional_metrics_available = False
|
| 77 |
+
|
| 78 |
+
def create_mock_model(self):
|
| 79 |
+
"""Create a mock model for demonstration purposes"""
|
| 80 |
+
class MockQualityModel(nn.Module):
|
| 81 |
+
def __init__(self):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.backbone = torch.nn.Sequential(
|
| 84 |
+
torch.nn.Conv2d(3, 64, 3, padding=1),
|
| 85 |
+
torch.nn.ReLU(),
|
| 86 |
+
torch.nn.AdaptiveAvgPool2d((1, 1)),
|
| 87 |
+
torch.nn.Flatten(),
|
| 88 |
+
torch.nn.Linear(64, 1),
|
| 89 |
+
torch.nn.Sigmoid()
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
def forward(self, x):
|
| 93 |
+
return self.backbone(x) * 10 # Scale to 0-10
|
| 94 |
+
|
| 95 |
+
model = MockQualityModel().to(self.device)
|
| 96 |
+
model.eval()
|
| 97 |
+
return model
|
| 98 |
+
|
| 99 |
+
def use_fallback_implementation(self):
|
| 100 |
+
"""Use simple fallback quality assessment"""
|
| 101 |
+
logger.info("Using fallback quality assessment implementation")
|
| 102 |
+
self.fallback_mode = True
|
| 103 |
+
|
| 104 |
+
def evaluate_with_lar_iqa(self, image: Image.Image) -> float:
|
| 105 |
+
"""Evaluate image quality using LAR-IQA"""
|
| 106 |
+
try:
|
| 107 |
+
if 'lar_iqa' not in self.models:
|
| 108 |
+
return self.fallback_quality_score(image)
|
| 109 |
+
|
| 110 |
+
# Preprocess image
|
| 111 |
+
tensor = self.processors['lar_iqa'](image).unsqueeze(0).to(self.device)
|
| 112 |
+
|
| 113 |
+
# Get prediction
|
| 114 |
+
with torch.no_grad():
|
| 115 |
+
score = self.models['lar_iqa'](tensor).item()
|
| 116 |
+
|
| 117 |
+
return max(0.0, min(10.0, score))
|
| 118 |
+
|
| 119 |
+
except Exception as e:
|
| 120 |
+
logger.error(f"Error in LAR-IQA evaluation: {str(e)}")
|
| 121 |
+
return self.fallback_quality_score(image)
|
| 122 |
+
|
| 123 |
+
def evaluate_with_dgiqa(self, image: Image.Image) -> float:
|
| 124 |
+
"""Evaluate image quality using DGIQA"""
|
| 125 |
+
try:
|
| 126 |
+
if 'dgiqa' not in self.models:
|
| 127 |
+
return self.fallback_quality_score(image)
|
| 128 |
+
|
| 129 |
+
# Preprocess image
|
| 130 |
+
tensor = self.processors['dgiqa'](image).unsqueeze(0).to(self.device)
|
| 131 |
+
|
| 132 |
+
# Get prediction
|
| 133 |
+
with torch.no_grad():
|
| 134 |
+
score = self.models['dgiqa'](tensor).item()
|
| 135 |
+
|
| 136 |
+
return max(0.0, min(10.0, score))
|
| 137 |
+
|
| 138 |
+
except Exception as e:
|
| 139 |
+
logger.error(f"Error in DGIQA evaluation: {str(e)}")
|
| 140 |
+
return self.fallback_quality_score(image)
|
| 141 |
+
|
| 142 |
+
def evaluate_traditional_metrics(self, image: Image.Image) -> float:
|
| 143 |
+
"""Evaluate using traditional quality metrics"""
|
| 144 |
+
try:
|
| 145 |
+
# Convert to numpy array
|
| 146 |
+
img_array = np.array(image)
|
| 147 |
+
|
| 148 |
+
# Simple quality metrics based on image statistics
|
| 149 |
+
# In production, this would use BRISQUE, NIQE, etc.
|
| 150 |
+
|
| 151 |
+
# Calculate sharpness (Laplacian variance)
|
| 152 |
+
from scipy import ndimage
|
| 153 |
+
gray = np.dot(img_array[...,:3], [0.2989, 0.5870, 0.1140])
|
| 154 |
+
laplacian_var = ndimage.laplace(gray).var()
|
| 155 |
+
sharpness_score = min(10.0, laplacian_var / 100.0)
|
| 156 |
+
|
| 157 |
+
# Calculate contrast
|
| 158 |
+
contrast_score = min(10.0, gray.std() / 25.0)
|
| 159 |
+
|
| 160 |
+
# Calculate brightness distribution
|
| 161 |
+
brightness_score = 10.0 - abs(gray.mean() - 127.5) / 12.75
|
| 162 |
+
|
| 163 |
+
# Combine scores
|
| 164 |
+
quality_score = (sharpness_score * 0.4 +
|
| 165 |
+
contrast_score * 0.3 +
|
| 166 |
+
brightness_score * 0.3)
|
| 167 |
+
|
| 168 |
+
return max(0.0, min(10.0, quality_score))
|
| 169 |
+
|
| 170 |
+
except Exception as e:
|
| 171 |
+
logger.error(f"Error in traditional metrics: {str(e)}")
|
| 172 |
+
return 5.0 # Default score
|
| 173 |
+
|
| 174 |
+
def fallback_quality_score(self, image: Image.Image) -> float:
|
| 175 |
+
"""Simple fallback quality assessment"""
|
| 176 |
+
try:
|
| 177 |
+
# Basic quality assessment based on image properties
|
| 178 |
+
width, height = image.size
|
| 179 |
+
|
| 180 |
+
# Resolution score
|
| 181 |
+
total_pixels = width * height
|
| 182 |
+
resolution_score = min(10.0, total_pixels / 100000.0) # Normalize by 1MP
|
| 183 |
+
|
| 184 |
+
# Aspect ratio score (prefer standard ratios)
|
| 185 |
+
aspect_ratio = width / height
|
| 186 |
+
if 0.5 <= aspect_ratio <= 2.0:
|
| 187 |
+
aspect_score = 8.0
|
| 188 |
+
else:
|
| 189 |
+
aspect_score = 5.0
|
| 190 |
+
|
| 191 |
+
# File format score (prefer lossless)
|
| 192 |
+
format_score = 8.0 if image.format == 'PNG' else 6.0
|
| 193 |
+
|
| 194 |
+
# Combine scores
|
| 195 |
+
quality_score = (resolution_score * 0.5 +
|
| 196 |
+
aspect_score * 0.3 +
|
| 197 |
+
format_score * 0.2)
|
| 198 |
+
|
| 199 |
+
return max(0.0, min(10.0, quality_score))
|
| 200 |
+
|
| 201 |
+
except Exception:
|
| 202 |
+
return 5.0 # Default neutral score
|
| 203 |
+
|
| 204 |
+
def evaluate(self, image: Image.Image, anime_mode: bool = False) -> float:
|
| 205 |
+
"""
|
| 206 |
+
Evaluate image quality using ensemble of models
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
image: PIL Image to evaluate
|
| 210 |
+
anime_mode: Whether to use anime-specific evaluation
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
Quality score from 0-10
|
| 214 |
+
"""
|
| 215 |
+
try:
|
| 216 |
+
scores = []
|
| 217 |
+
|
| 218 |
+
# LAR-IQA evaluation
|
| 219 |
+
lar_score = self.evaluate_with_lar_iqa(image)
|
| 220 |
+
scores.append(lar_score)
|
| 221 |
+
|
| 222 |
+
# DGIQA evaluation
|
| 223 |
+
dgiqa_score = self.evaluate_with_dgiqa(image)
|
| 224 |
+
scores.append(dgiqa_score)
|
| 225 |
+
|
| 226 |
+
# Traditional metrics
|
| 227 |
+
traditional_score = self.evaluate_traditional_metrics(image)
|
| 228 |
+
scores.append(traditional_score)
|
| 229 |
+
|
| 230 |
+
# Ensemble scoring
|
| 231 |
+
if anime_mode:
|
| 232 |
+
# For anime images, weight traditional metrics higher
|
| 233 |
+
# as they may be more reliable for stylized content
|
| 234 |
+
weights = [0.3, 0.3, 0.4]
|
| 235 |
+
else:
|
| 236 |
+
# For realistic images, weight modern models higher
|
| 237 |
+
weights = [0.4, 0.4, 0.2]
|
| 238 |
+
|
| 239 |
+
final_score = sum(score * weight for score, weight in zip(scores, weights))
|
| 240 |
+
|
| 241 |
+
logger.info(f"Quality scores - LAR: {lar_score:.2f}, DGIQA: {dgiqa_score:.2f}, "
|
| 242 |
+
f"Traditional: {traditional_score:.2f}, Final: {final_score:.2f}")
|
| 243 |
+
|
| 244 |
+
return max(0.0, min(10.0, final_score))
|
| 245 |
+
|
| 246 |
+
except Exception as e:
|
| 247 |
+
logger.error(f"Error in quality evaluation: {str(e)}")
|
| 248 |
+
return self.fallback_quality_score(image)
|
| 249 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.0.0
|
| 2 |
+
Pillow>=9.0.0
|
| 3 |
+
numpy>=1.21.0
|
| 4 |
+
pandas>=1.3.0
|
| 5 |
+
scipy>=1.9.0
|
| 6 |
+
|
| 7 |
+
# Optional dependencies for full functionality
|
| 8 |
+
# Uncomment these for production deployment with real models
|
| 9 |
+
# torch>=2.0.0
|
| 10 |
+
# torchvision>=0.15.0
|
| 11 |
+
# transformers>=4.30.0
|
| 12 |
+
# opencv-python>=4.5.0
|
| 13 |
+
# scikit-image>=0.19.0
|
| 14 |
+
# huggingface-hub>=0.15.0
|
| 15 |
+
# accelerate>=0.20.0
|
| 16 |
+
# timm>=0.9.0
|
| 17 |
+
# sentence-transformers>=2.2.0
|
| 18 |
+
# git+https://github.com/openai/CLIP.git
|
| 19 |
+
|
test_images/anime_character.png
ADDED
|
Git LFS Details
|
test_images/landscape_art.png
ADDED
|
Git LFS Details
|
test_images/realistic_portrait.png
ADDED
|
Git LFS Details
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Utils package for image evaluation
|
| 2 |
+
|
utils/metadata_extractor.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from PIL.PngImagePlugin import PngInfo
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
def extract_png_metadata(image_path: str) -> dict:
|
| 10 |
+
"""
|
| 11 |
+
Extract metadata from PNG files generated by A1111 or ComfyUI
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
image_path: Path to the PNG image file
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
Dictionary containing extracted metadata
|
| 18 |
+
"""
|
| 19 |
+
try:
|
| 20 |
+
with Image.open(image_path) as img:
|
| 21 |
+
metadata = {}
|
| 22 |
+
|
| 23 |
+
# Check for A1111 metadata
|
| 24 |
+
a1111_data = extract_a1111_metadata(img)
|
| 25 |
+
if a1111_data:
|
| 26 |
+
metadata.update(a1111_data)
|
| 27 |
+
metadata['source'] = 'automatic1111'
|
| 28 |
+
|
| 29 |
+
# Check for ComfyUI metadata
|
| 30 |
+
comfyui_data = extract_comfyui_metadata(img)
|
| 31 |
+
if comfyui_data:
|
| 32 |
+
metadata.update(comfyui_data)
|
| 33 |
+
metadata['source'] = 'comfyui'
|
| 34 |
+
|
| 35 |
+
# Check for other common metadata fields
|
| 36 |
+
other_data = extract_other_metadata(img)
|
| 37 |
+
if other_data:
|
| 38 |
+
metadata.update(other_data)
|
| 39 |
+
|
| 40 |
+
return metadata if metadata else None
|
| 41 |
+
|
| 42 |
+
except Exception as e:
|
| 43 |
+
logger.error(f"Error extracting metadata from {image_path}: {str(e)}")
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
def extract_a1111_metadata(img: Image.Image) -> dict:
|
| 47 |
+
"""Extract Automatic1111 metadata from PNG text fields"""
|
| 48 |
+
try:
|
| 49 |
+
metadata = {}
|
| 50 |
+
|
| 51 |
+
# A1111 stores metadata in the 'parameters' text field
|
| 52 |
+
if hasattr(img, 'text') and 'parameters' in img.text:
|
| 53 |
+
parameters_text = img.text['parameters']
|
| 54 |
+
metadata.update(parse_a1111_parameters(parameters_text))
|
| 55 |
+
|
| 56 |
+
return metadata
|
| 57 |
+
|
| 58 |
+
except Exception as e:
|
| 59 |
+
logger.error(f"Error extracting A1111 metadata: {str(e)}")
|
| 60 |
+
return {}
|
| 61 |
+
|
| 62 |
+
def parse_a1111_parameters(parameters_text: str) -> dict:
|
| 63 |
+
"""Parse A1111 parameters text into structured data"""
|
| 64 |
+
try:
|
| 65 |
+
metadata = {}
|
| 66 |
+
|
| 67 |
+
# Split the parameters text into lines
|
| 68 |
+
lines = parameters_text.strip().split('\n')
|
| 69 |
+
|
| 70 |
+
# The first line is usually the prompt
|
| 71 |
+
if lines:
|
| 72 |
+
metadata['prompt'] = lines[0].strip()
|
| 73 |
+
|
| 74 |
+
# Look for negative prompt
|
| 75 |
+
negative_prompt_match = re.search(r'Negative prompt:\s*(.+?)(?:\n|$)', parameters_text, re.DOTALL)
|
| 76 |
+
if negative_prompt_match:
|
| 77 |
+
metadata['negative_prompt'] = negative_prompt_match.group(1).strip()
|
| 78 |
+
|
| 79 |
+
# Extract other parameters using regex
|
| 80 |
+
param_patterns = {
|
| 81 |
+
'steps': r'Steps:\s*(\d+)',
|
| 82 |
+
'sampler': r'Sampler:\s*([^,\n]+)',
|
| 83 |
+
'cfg_scale': r'CFG scale:\s*([\d.]+)',
|
| 84 |
+
'seed': r'Seed:\s*(\d+)',
|
| 85 |
+
'size': r'Size:\s*(\d+x\d+)',
|
| 86 |
+
'model_hash': r'Model hash:\s*([a-fA-F0-9]+)',
|
| 87 |
+
'model': r'Model:\s*([^,\n]+)',
|
| 88 |
+
'denoising_strength': r'Denoising strength:\s*([\d.]+)',
|
| 89 |
+
'clip_skip': r'Clip skip:\s*(\d+)',
|
| 90 |
+
'ensd': r'ENSD:\s*(\d+)'
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
for param_name, pattern in param_patterns.items():
|
| 94 |
+
match = re.search(pattern, parameters_text)
|
| 95 |
+
if match:
|
| 96 |
+
value = match.group(1).strip()
|
| 97 |
+
# Convert numeric values
|
| 98 |
+
if param_name in ['steps', 'seed', 'clip_skip', 'ensd']:
|
| 99 |
+
metadata[param_name] = int(value)
|
| 100 |
+
elif param_name in ['cfg_scale', 'denoising_strength']:
|
| 101 |
+
metadata[param_name] = float(value)
|
| 102 |
+
else:
|
| 103 |
+
metadata[param_name] = value
|
| 104 |
+
|
| 105 |
+
# Parse size into width and height
|
| 106 |
+
if 'size' in metadata:
|
| 107 |
+
size_match = re.match(r'(\d+)x(\d+)', metadata['size'])
|
| 108 |
+
if size_match:
|
| 109 |
+
metadata['width'] = int(size_match.group(1))
|
| 110 |
+
metadata['height'] = int(size_match.group(2))
|
| 111 |
+
|
| 112 |
+
return metadata
|
| 113 |
+
|
| 114 |
+
except Exception as e:
|
| 115 |
+
logger.error(f"Error parsing A1111 parameters: {str(e)}")
|
| 116 |
+
return {}
|
| 117 |
+
|
| 118 |
+
def extract_comfyui_metadata(img: Image.Image) -> dict:
|
| 119 |
+
"""Extract ComfyUI metadata from PNG text fields"""
|
| 120 |
+
try:
|
| 121 |
+
metadata = {}
|
| 122 |
+
|
| 123 |
+
# ComfyUI stores metadata in 'workflow' and 'prompt' text fields
|
| 124 |
+
if hasattr(img, 'text'):
|
| 125 |
+
# Check for workflow data
|
| 126 |
+
if 'workflow' in img.text:
|
| 127 |
+
try:
|
| 128 |
+
workflow_data = json.loads(img.text['workflow'])
|
| 129 |
+
metadata.update(parse_comfyui_workflow(workflow_data))
|
| 130 |
+
except json.JSONDecodeError:
|
| 131 |
+
logger.warning("Could not parse ComfyUI workflow JSON")
|
| 132 |
+
|
| 133 |
+
# Check for prompt data
|
| 134 |
+
if 'prompt' in img.text:
|
| 135 |
+
try:
|
| 136 |
+
prompt_data = json.loads(img.text['prompt'])
|
| 137 |
+
metadata.update(parse_comfyui_prompt(prompt_data))
|
| 138 |
+
except json.JSONDecodeError:
|
| 139 |
+
logger.warning("Could not parse ComfyUI prompt JSON")
|
| 140 |
+
|
| 141 |
+
return metadata
|
| 142 |
+
|
| 143 |
+
except Exception as e:
|
| 144 |
+
logger.error(f"Error extracting ComfyUI metadata: {str(e)}")
|
| 145 |
+
return {}
|
| 146 |
+
|
| 147 |
+
def parse_comfyui_workflow(workflow_data: dict) -> dict:
|
| 148 |
+
"""Parse ComfyUI workflow data"""
|
| 149 |
+
try:
|
| 150 |
+
metadata = {}
|
| 151 |
+
|
| 152 |
+
# Extract nodes from workflow
|
| 153 |
+
if 'nodes' in workflow_data:
|
| 154 |
+
nodes = workflow_data['nodes']
|
| 155 |
+
|
| 156 |
+
# Look for common node types
|
| 157 |
+
for node in nodes:
|
| 158 |
+
if isinstance(node, dict):
|
| 159 |
+
node_type = node.get('type', '')
|
| 160 |
+
|
| 161 |
+
# Extract prompt from text nodes
|
| 162 |
+
if 'text' in node_type.lower() or 'prompt' in node_type.lower():
|
| 163 |
+
if 'widgets_values' in node and node['widgets_values']:
|
| 164 |
+
text_value = node['widgets_values'][0]
|
| 165 |
+
if isinstance(text_value, str) and len(text_value) > 10:
|
| 166 |
+
if 'prompt' not in metadata:
|
| 167 |
+
metadata['prompt'] = text_value
|
| 168 |
+
|
| 169 |
+
# Extract sampler settings
|
| 170 |
+
elif 'sampler' in node_type.lower():
|
| 171 |
+
if 'widgets_values' in node:
|
| 172 |
+
values = node['widgets_values']
|
| 173 |
+
if len(values) >= 3:
|
| 174 |
+
metadata['steps'] = values[0] if isinstance(values[0], int) else None
|
| 175 |
+
metadata['cfg_scale'] = values[1] if isinstance(values[1], (int, float)) else None
|
| 176 |
+
metadata['sampler'] = values[2] if isinstance(values[2], str) else None
|
| 177 |
+
|
| 178 |
+
return metadata
|
| 179 |
+
|
| 180 |
+
except Exception as e:
|
| 181 |
+
logger.error(f"Error parsing ComfyUI workflow: {str(e)}")
|
| 182 |
+
return {}
|
| 183 |
+
|
| 184 |
+
def parse_comfyui_prompt(prompt_data: dict) -> dict:
|
| 185 |
+
"""Parse ComfyUI prompt data"""
|
| 186 |
+
try:
|
| 187 |
+
metadata = {}
|
| 188 |
+
|
| 189 |
+
# ComfyUI prompt data is usually a nested structure
|
| 190 |
+
# Extract common parameters from the prompt structure
|
| 191 |
+
for node_id, node_data in prompt_data.items():
|
| 192 |
+
if isinstance(node_data, dict) and 'inputs' in node_data:
|
| 193 |
+
inputs = node_data['inputs']
|
| 194 |
+
|
| 195 |
+
# Look for text inputs (prompts)
|
| 196 |
+
for key, value in inputs.items():
|
| 197 |
+
if isinstance(value, str) and len(value) > 10:
|
| 198 |
+
if 'text' in key.lower() or 'prompt' in key.lower():
|
| 199 |
+
if 'prompt' not in metadata:
|
| 200 |
+
metadata['prompt'] = value
|
| 201 |
+
|
| 202 |
+
# Look for numeric parameters
|
| 203 |
+
if 'steps' in inputs:
|
| 204 |
+
metadata['steps'] = inputs['steps']
|
| 205 |
+
if 'cfg' in inputs:
|
| 206 |
+
metadata['cfg_scale'] = inputs['cfg']
|
| 207 |
+
if 'seed' in inputs:
|
| 208 |
+
metadata['seed'] = inputs['seed']
|
| 209 |
+
if 'denoise' in inputs:
|
| 210 |
+
metadata['denoising_strength'] = inputs['denoise']
|
| 211 |
+
|
| 212 |
+
return metadata
|
| 213 |
+
|
| 214 |
+
except Exception as e:
|
| 215 |
+
logger.error(f"Error parsing ComfyUI prompt: {str(e)}")
|
| 216 |
+
return {}
|
| 217 |
+
|
| 218 |
+
def extract_other_metadata(img: Image.Image) -> dict:
|
| 219 |
+
"""Extract other common metadata fields"""
|
| 220 |
+
try:
|
| 221 |
+
metadata = {}
|
| 222 |
+
|
| 223 |
+
# Check standard EXIF data
|
| 224 |
+
if hasattr(img, '_getexif') and img._getexif():
|
| 225 |
+
exif_data = img._getexif()
|
| 226 |
+
|
| 227 |
+
# Extract relevant EXIF fields
|
| 228 |
+
exif_fields = {
|
| 229 |
+
'software': 0x0131, # Software tag
|
| 230 |
+
'artist': 0x013B, # Artist tag
|
| 231 |
+
'copyright': 0x8298 # Copyright tag
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
for field_name, tag_id in exif_fields.items():
|
| 235 |
+
if tag_id in exif_data:
|
| 236 |
+
metadata[field_name] = exif_data[tag_id]
|
| 237 |
+
|
| 238 |
+
# Check for other text fields that might contain prompts
|
| 239 |
+
if hasattr(img, 'text'):
|
| 240 |
+
text_fields = ['description', 'comment', 'title', 'subject']
|
| 241 |
+
for field in text_fields:
|
| 242 |
+
if field in img.text:
|
| 243 |
+
value = img.text[field].strip()
|
| 244 |
+
if len(value) > 10 and 'prompt' not in metadata:
|
| 245 |
+
metadata['prompt'] = value
|
| 246 |
+
|
| 247 |
+
return metadata
|
| 248 |
+
|
| 249 |
+
except Exception as e:
|
| 250 |
+
logger.error(f"Error extracting other metadata: {str(e)}")
|
| 251 |
+
return {}
|
| 252 |
+
|
| 253 |
+
def clean_prompt_text(prompt: str) -> str:
|
| 254 |
+
"""Clean and normalize prompt text"""
|
| 255 |
+
try:
|
| 256 |
+
if not prompt:
|
| 257 |
+
return ""
|
| 258 |
+
|
| 259 |
+
# Remove extra whitespace
|
| 260 |
+
prompt = re.sub(r'\s+', ' ', prompt.strip())
|
| 261 |
+
|
| 262 |
+
# Remove common prefixes/suffixes
|
| 263 |
+
prefixes_to_remove = [
|
| 264 |
+
'prompt:', 'positive prompt:', 'text prompt:',
|
| 265 |
+
'description:', 'caption:'
|
| 266 |
+
]
|
| 267 |
+
|
| 268 |
+
for prefix in prefixes_to_remove:
|
| 269 |
+
if prompt.lower().startswith(prefix):
|
| 270 |
+
prompt = prompt[len(prefix):].strip()
|
| 271 |
+
|
| 272 |
+
return prompt
|
| 273 |
+
|
| 274 |
+
except Exception:
|
| 275 |
+
return prompt if prompt else ""
|
| 276 |
+
|
| 277 |
+
def get_generation_parameters(metadata: dict) -> dict:
|
| 278 |
+
"""Extract key generation parameters for display"""
|
| 279 |
+
try:
|
| 280 |
+
params = {}
|
| 281 |
+
|
| 282 |
+
# Essential parameters
|
| 283 |
+
if 'prompt' in metadata:
|
| 284 |
+
params['prompt'] = clean_prompt_text(metadata['prompt'])
|
| 285 |
+
|
| 286 |
+
if 'negative_prompt' in metadata:
|
| 287 |
+
params['negative_prompt'] = clean_prompt_text(metadata['negative_prompt'])
|
| 288 |
+
|
| 289 |
+
# Technical parameters
|
| 290 |
+
technical_params = ['steps', 'cfg_scale', 'sampler', 'seed', 'model', 'width', 'height']
|
| 291 |
+
for param in technical_params:
|
| 292 |
+
if param in metadata:
|
| 293 |
+
params[param] = metadata[param]
|
| 294 |
+
|
| 295 |
+
# Source information
|
| 296 |
+
if 'source' in metadata:
|
| 297 |
+
params['source'] = metadata['source']
|
| 298 |
+
|
| 299 |
+
return params
|
| 300 |
+
|
| 301 |
+
except Exception as e:
|
| 302 |
+
logger.error(f"Error extracting generation parameters: {str(e)}")
|
| 303 |
+
return {}
|
| 304 |
+
|
utils/scoring.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
logger = logging.getLogger(__name__)
|
| 5 |
+
|
| 6 |
+
def calculate_final_score(
|
| 7 |
+
quality_score: float,
|
| 8 |
+
aesthetics_score: float,
|
| 9 |
+
prompt_score: float,
|
| 10 |
+
ai_detection_score: float,
|
| 11 |
+
has_prompt: bool = True
|
| 12 |
+
) -> float:
|
| 13 |
+
"""
|
| 14 |
+
Calculate weighted composite score for image evaluation
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
quality_score: Technical image quality (0-10)
|
| 18 |
+
aesthetics_score: Visual appeal score (0-10)
|
| 19 |
+
prompt_score: Prompt adherence score (0-10)
|
| 20 |
+
ai_detection_score: AI generation probability (0-1)
|
| 21 |
+
has_prompt: Whether prompt metadata is available
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Final composite score (0-10)
|
| 25 |
+
"""
|
| 26 |
+
try:
|
| 27 |
+
# Validate input scores
|
| 28 |
+
quality_score = max(0.0, min(10.0, quality_score))
|
| 29 |
+
aesthetics_score = max(0.0, min(10.0, aesthetics_score))
|
| 30 |
+
prompt_score = max(0.0, min(10.0, prompt_score))
|
| 31 |
+
ai_detection_score = max(0.0, min(1.0, ai_detection_score))
|
| 32 |
+
|
| 33 |
+
if has_prompt:
|
| 34 |
+
# Standard weights when prompt is available
|
| 35 |
+
weights = {
|
| 36 |
+
'quality': 0.25, # 25% - Technical quality
|
| 37 |
+
'aesthetics': 0.35, # 35% - Visual appeal (highest weight)
|
| 38 |
+
'prompt': 0.25, # 25% - Prompt following
|
| 39 |
+
'ai_detection': 0.15 # 15% - AI detection (inverted)
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
# Calculate weighted score
|
| 43 |
+
score = (
|
| 44 |
+
quality_score * weights['quality'] +
|
| 45 |
+
aesthetics_score * weights['aesthetics'] +
|
| 46 |
+
prompt_score * weights['prompt'] +
|
| 47 |
+
(1 - ai_detection_score) * weights['ai_detection']
|
| 48 |
+
)
|
| 49 |
+
else:
|
| 50 |
+
# Redistribute prompt weight when no prompt available
|
| 51 |
+
weights = {
|
| 52 |
+
'quality': 0.375, # 25% + 12.5% from prompt
|
| 53 |
+
'aesthetics': 0.475, # 35% + 12.5% from prompt
|
| 54 |
+
'ai_detection': 0.15 # 15% - AI detection (inverted)
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
# Calculate weighted score without prompt
|
| 58 |
+
score = (
|
| 59 |
+
quality_score * weights['quality'] +
|
| 60 |
+
aesthetics_score * weights['aesthetics'] +
|
| 61 |
+
(1 - ai_detection_score) * weights['ai_detection']
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Ensure score is in valid range
|
| 65 |
+
final_score = max(0.0, min(10.0, score))
|
| 66 |
+
|
| 67 |
+
logger.debug(f"Score calculation - Quality: {quality_score:.2f}, "
|
| 68 |
+
f"Aesthetics: {aesthetics_score:.2f}, Prompt: {prompt_score:.2f}, "
|
| 69 |
+
f"AI Detection: {ai_detection_score:.3f}, Has Prompt: {has_prompt}, "
|
| 70 |
+
f"Final: {final_score:.2f}")
|
| 71 |
+
|
| 72 |
+
return final_score
|
| 73 |
+
|
| 74 |
+
except Exception as e:
|
| 75 |
+
logger.error(f"Error calculating final score: {str(e)}")
|
| 76 |
+
return 5.0 # Default neutral score
|
| 77 |
+
|
| 78 |
+
def calculate_category_rankings(scores_list: list, category: str) -> list:
|
| 79 |
+
"""
|
| 80 |
+
Calculate rankings for a specific category
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
scores_list: List of score dictionaries
|
| 84 |
+
category: Category to rank by ('quality_score', 'aesthetics_score', etc.)
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
List of rankings (1-based)
|
| 88 |
+
"""
|
| 89 |
+
try:
|
| 90 |
+
if not scores_list or category not in scores_list[0]:
|
| 91 |
+
return [1] * len(scores_list)
|
| 92 |
+
|
| 93 |
+
# Extract scores for the category
|
| 94 |
+
category_scores = [item[category] for item in scores_list]
|
| 95 |
+
|
| 96 |
+
# Calculate rankings (higher score = better rank)
|
| 97 |
+
rankings = []
|
| 98 |
+
for i, score in enumerate(category_scores):
|
| 99 |
+
rank = 1
|
| 100 |
+
for j, other_score in enumerate(category_scores):
|
| 101 |
+
if other_score > score:
|
| 102 |
+
rank += 1
|
| 103 |
+
rankings.append(rank)
|
| 104 |
+
|
| 105 |
+
return rankings
|
| 106 |
+
|
| 107 |
+
except Exception as e:
|
| 108 |
+
logger.error(f"Error calculating category rankings: {str(e)}")
|
| 109 |
+
return list(range(1, len(scores_list) + 1))
|
| 110 |
+
|
| 111 |
+
def normalize_scores(scores: list, target_range: tuple = (0, 10)) -> list:
|
| 112 |
+
"""
|
| 113 |
+
Normalize a list of scores to a target range
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
scores: List of numerical scores
|
| 117 |
+
target_range: Tuple of (min, max) for target range
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
List of normalized scores
|
| 121 |
+
"""
|
| 122 |
+
try:
|
| 123 |
+
if not scores:
|
| 124 |
+
return []
|
| 125 |
+
|
| 126 |
+
min_score = min(scores)
|
| 127 |
+
max_score = max(scores)
|
| 128 |
+
|
| 129 |
+
# Avoid division by zero
|
| 130 |
+
if max_score == min_score:
|
| 131 |
+
return [target_range[1]] * len(scores)
|
| 132 |
+
|
| 133 |
+
target_min, target_max = target_range
|
| 134 |
+
target_span = target_max - target_min
|
| 135 |
+
score_span = max_score - min_score
|
| 136 |
+
|
| 137 |
+
normalized = []
|
| 138 |
+
for score in scores:
|
| 139 |
+
normalized_score = target_min + (score - min_score) * target_span / score_span
|
| 140 |
+
normalized.append(max(target_min, min(target_max, normalized_score)))
|
| 141 |
+
|
| 142 |
+
return normalized
|
| 143 |
+
|
| 144 |
+
except Exception as e:
|
| 145 |
+
logger.error(f"Error normalizing scores: {str(e)}")
|
| 146 |
+
return scores
|
| 147 |
+
|
| 148 |
+
def calculate_confidence_intervals(scores: list, confidence_level: float = 0.95) -> dict:
|
| 149 |
+
"""
|
| 150 |
+
Calculate confidence intervals for a list of scores
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
scores: List of numerical scores
|
| 154 |
+
confidence_level: Confidence level (0-1)
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
Dictionary with mean, std, lower_bound, upper_bound
|
| 158 |
+
"""
|
| 159 |
+
try:
|
| 160 |
+
if not scores:
|
| 161 |
+
return {'mean': 0, 'std': 0, 'lower_bound': 0, 'upper_bound': 0}
|
| 162 |
+
|
| 163 |
+
mean_score = np.mean(scores)
|
| 164 |
+
std_score = np.std(scores)
|
| 165 |
+
|
| 166 |
+
# Calculate confidence interval using t-distribution
|
| 167 |
+
from scipy import stats
|
| 168 |
+
n = len(scores)
|
| 169 |
+
t_value = stats.t.ppf((1 + confidence_level) / 2, n - 1)
|
| 170 |
+
margin_error = t_value * std_score / np.sqrt(n)
|
| 171 |
+
|
| 172 |
+
return {
|
| 173 |
+
'mean': float(mean_score),
|
| 174 |
+
'std': float(std_score),
|
| 175 |
+
'lower_bound': float(mean_score - margin_error),
|
| 176 |
+
'upper_bound': float(mean_score + margin_error)
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
except Exception as e:
|
| 180 |
+
logger.error(f"Error calculating confidence intervals: {str(e)}")
|
| 181 |
+
return {'mean': 0, 'std': 0, 'lower_bound': 0, 'upper_bound': 0}
|
| 182 |
+
|
| 183 |
+
def detect_outliers(scores: list, method: str = 'iqr') -> list:
|
| 184 |
+
"""
|
| 185 |
+
Detect outliers in a list of scores
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
scores: List of numerical scores
|
| 189 |
+
method: Method to use ('iqr', 'zscore', 'modified_zscore')
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
List of boolean values indicating outliers
|
| 193 |
+
"""
|
| 194 |
+
try:
|
| 195 |
+
if not scores or len(scores) < 3:
|
| 196 |
+
return [False] * len(scores)
|
| 197 |
+
|
| 198 |
+
scores_array = np.array(scores)
|
| 199 |
+
|
| 200 |
+
if method == 'iqr':
|
| 201 |
+
# Interquartile Range method
|
| 202 |
+
q1 = np.percentile(scores_array, 25)
|
| 203 |
+
q3 = np.percentile(scores_array, 75)
|
| 204 |
+
iqr = q3 - q1
|
| 205 |
+
lower_bound = q1 - 1.5 * iqr
|
| 206 |
+
upper_bound = q3 + 1.5 * iqr
|
| 207 |
+
outliers = (scores_array < lower_bound) | (scores_array > upper_bound)
|
| 208 |
+
|
| 209 |
+
elif method == 'zscore':
|
| 210 |
+
# Z-score method
|
| 211 |
+
z_scores = np.abs(stats.zscore(scores_array))
|
| 212 |
+
outliers = z_scores > 2.5
|
| 213 |
+
|
| 214 |
+
elif method == 'modified_zscore':
|
| 215 |
+
# Modified Z-score method (more robust)
|
| 216 |
+
median = np.median(scores_array)
|
| 217 |
+
mad = np.median(np.abs(scores_array - median))
|
| 218 |
+
modified_z_scores = 0.6745 * (scores_array - median) / mad
|
| 219 |
+
outliers = np.abs(modified_z_scores) > 3.5
|
| 220 |
+
|
| 221 |
+
else:
|
| 222 |
+
outliers = [False] * len(scores)
|
| 223 |
+
|
| 224 |
+
return outliers.tolist()
|
| 225 |
+
|
| 226 |
+
except Exception as e:
|
| 227 |
+
logger.error(f"Error detecting outliers: {str(e)}")
|
| 228 |
+
return [False] * len(scores)
|
| 229 |
+
|
| 230 |
+
def calculate_score_distribution(scores: list) -> dict:
|
| 231 |
+
"""
|
| 232 |
+
Calculate distribution statistics for scores
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
scores: List of numerical scores
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
Dictionary with distribution statistics
|
| 239 |
+
"""
|
| 240 |
+
try:
|
| 241 |
+
if not scores:
|
| 242 |
+
return {}
|
| 243 |
+
|
| 244 |
+
scores_array = np.array(scores)
|
| 245 |
+
|
| 246 |
+
distribution = {
|
| 247 |
+
'count': len(scores),
|
| 248 |
+
'mean': float(np.mean(scores_array)),
|
| 249 |
+
'median': float(np.median(scores_array)),
|
| 250 |
+
'std': float(np.std(scores_array)),
|
| 251 |
+
'min': float(np.min(scores_array)),
|
| 252 |
+
'max': float(np.max(scores_array)),
|
| 253 |
+
'q1': float(np.percentile(scores_array, 25)),
|
| 254 |
+
'q3': float(np.percentile(scores_array, 75)),
|
| 255 |
+
'skewness': float(stats.skew(scores_array)),
|
| 256 |
+
'kurtosis': float(stats.kurtosis(scores_array))
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
return distribution
|
| 260 |
+
|
| 261 |
+
except Exception as e:
|
| 262 |
+
logger.error(f"Error calculating score distribution: {str(e)}")
|
| 263 |
+
return {}
|
| 264 |
+
|
| 265 |
+
def apply_score_adjustments(
|
| 266 |
+
scores: dict,
|
| 267 |
+
adjustments: dict = None
|
| 268 |
+
) -> dict:
|
| 269 |
+
"""
|
| 270 |
+
Apply custom score adjustments based on specific criteria
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
scores: Dictionary of scores
|
| 274 |
+
adjustments: Dictionary of adjustment parameters
|
| 275 |
+
|
| 276 |
+
Returns:
|
| 277 |
+
Dictionary of adjusted scores
|
| 278 |
+
"""
|
| 279 |
+
try:
|
| 280 |
+
if adjustments is None:
|
| 281 |
+
adjustments = {}
|
| 282 |
+
|
| 283 |
+
adjusted_scores = scores.copy()
|
| 284 |
+
|
| 285 |
+
# Apply anime mode adjustments
|
| 286 |
+
if adjustments.get('anime_mode', False):
|
| 287 |
+
# Boost aesthetics score for anime images
|
| 288 |
+
if 'aesthetics_score' in adjusted_scores:
|
| 289 |
+
adjusted_scores['aesthetics_score'] *= 1.1
|
| 290 |
+
adjusted_scores['aesthetics_score'] = min(10.0, adjusted_scores['aesthetics_score'])
|
| 291 |
+
|
| 292 |
+
# Apply quality penalties for low resolution
|
| 293 |
+
if adjustments.get('penalize_low_resolution', True):
|
| 294 |
+
width = adjustments.get('width', 1024)
|
| 295 |
+
height = adjustments.get('height', 1024)
|
| 296 |
+
total_pixels = width * height
|
| 297 |
+
|
| 298 |
+
if total_pixels < 262144: # Less than 512x512
|
| 299 |
+
penalty = 0.8
|
| 300 |
+
if 'quality_score' in adjusted_scores:
|
| 301 |
+
adjusted_scores['quality_score'] *= penalty
|
| 302 |
+
|
| 303 |
+
# Apply prompt complexity adjustments
|
| 304 |
+
prompt_length = adjustments.get('prompt_length', 0)
|
| 305 |
+
if prompt_length > 0 and 'prompt_score' in adjusted_scores:
|
| 306 |
+
if prompt_length > 100: # Very long prompts are harder to follow
|
| 307 |
+
adjusted_scores['prompt_score'] *= 0.95
|
| 308 |
+
elif prompt_length < 10: # Very short prompts are easier
|
| 309 |
+
adjusted_scores['prompt_score'] *= 1.05
|
| 310 |
+
adjusted_scores['prompt_score'] = min(10.0, adjusted_scores['prompt_score'])
|
| 311 |
+
|
| 312 |
+
return adjusted_scores
|
| 313 |
+
|
| 314 |
+
except Exception as e:
|
| 315 |
+
logger.error(f"Error applying score adjustments: {str(e)}")
|
| 316 |
+
return scores
|
| 317 |
+
|
| 318 |
+
def generate_score_summary(results_list: list) -> dict:
|
| 319 |
+
"""
|
| 320 |
+
Generate summary statistics for a batch of evaluation results
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
results_list: List of result dictionaries
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
Dictionary with summary statistics
|
| 327 |
+
"""
|
| 328 |
+
try:
|
| 329 |
+
if not results_list:
|
| 330 |
+
return {}
|
| 331 |
+
|
| 332 |
+
# Extract scores by category
|
| 333 |
+
categories = ['quality_score', 'aesthetics_score', 'prompt_score', 'ai_detection_score', 'final_score']
|
| 334 |
+
summary = {}
|
| 335 |
+
|
| 336 |
+
for category in categories:
|
| 337 |
+
if category in results_list[0]:
|
| 338 |
+
scores = [result[category] for result in results_list if category in result]
|
| 339 |
+
if scores:
|
| 340 |
+
summary[category] = calculate_score_distribution(scores)
|
| 341 |
+
|
| 342 |
+
# Calculate overall statistics
|
| 343 |
+
final_scores = [result['final_score'] for result in results_list if 'final_score' in result]
|
| 344 |
+
if final_scores:
|
| 345 |
+
summary['overall'] = {
|
| 346 |
+
'total_images': len(results_list),
|
| 347 |
+
'average_score': np.mean(final_scores),
|
| 348 |
+
'best_score': max(final_scores),
|
| 349 |
+
'worst_score': min(final_scores),
|
| 350 |
+
'score_range': max(final_scores) - min(final_scores),
|
| 351 |
+
'images_with_prompts': sum(1 for r in results_list if r.get('has_prompt', False))
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
return summary
|
| 355 |
+
|
| 356 |
+
except Exception as e:
|
| 357 |
+
logger.error(f"Error generating score summary: {str(e)}")
|
| 358 |
+
return {}
|
| 359 |
+
|