Comp-I / src /utils /image_utils.py
axrzce's picture
Deploy from GitHub main
338d95d verified
raw
history blame
10.9 kB
"""
Image processing utilities for CompI Phase 2.E: Style Reference/Example Image Integration
This module provides utilities for:
- Image loading from files and URLs
- Image validation and preprocessing
- Style analysis and feature extraction
- Image format conversion and optimization
"""
import os
import io
import requests
import hashlib
from typing import Optional, Tuple, Dict, Any, Union, List
from pathlib import Path
import logging
import torch
import numpy as np
from PIL import Image, ImageStat, ImageFilter
import cv2
from src.utils.logging_utils import setup_logger
# Module-level logger used by ImageProcessor and StyleAnalyzer below; created via
# the project's setup_logger helper (its configuration lives in
# src.utils.logging_utils, not in this module).
logger = setup_logger(__name__)
class ImageProcessor:
    """
    Handles image loading, validation, and preprocessing for style reference.

    Loader and analysis methods never raise: failures are logged through the
    module logger and signalled by a sentinel return value (``None``, ``{}``,
    ``""``, or the unmodified input image, depending on the method).
    """

    def __init__(self, max_size: Tuple[int, int] = (1024, 1024)):
        """
        Args:
            max_size: Stored as ``self.max_size``.
                NOTE(review): no method in this class reads it -- confirm
                whether external callers rely on it before removing.
        """
        self.max_size = max_size
        # File extensions accepted by load_image_from_file (compared lower-cased).
        self.supported_formats = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}

    def load_image_from_url(
        self,
        url: str,
        timeout: int = 10,
        max_file_size: int = 10 * 1024 * 1024  # 10MB
    ) -> Optional[Image.Image]:
        """
        Load an image from an HTTP(S) URL with validation and error handling.

        The body is streamed in 8 KiB chunks and the download is aborted as
        soon as it exceeds ``max_file_size``, even when the server omits or
        misreports ``Content-Length``.

        Args:
            url: Image URL; must start with ``http://`` or ``https://``.
            timeout: Request timeout in seconds, passed to ``requests.get``.
            max_file_size: Maximum file size in bytes.

        Returns:
            RGB PIL Image, or None if the URL is malformed, the request fails,
            the content type is not an image, the file is too large, or the
            payload cannot be decoded.
        """
        try:
            logger.info(f"Loading image from URL: {url}")
            # Validate URL format
            if not url.startswith(('http://', 'https://')):
                logger.error(f"Invalid URL format: {url}")
                return None
            # Make request with headers to avoid blocking
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            }
            # With stream=True the connection stays checked out until the body
            # is fully consumed or the response is closed; the `with` block
            # releases it on every early return below as well.
            with requests.get(url, timeout=timeout, headers=headers, stream=True) as response:
                response.raise_for_status()
                # Check content type
                content_type = response.headers.get('content-type', '').lower()
                if not any(img_type in content_type for img_type in ['image/', 'jpeg', 'png', 'webp']):
                    logger.error(f"Invalid content type: {content_type}")
                    return None
                # Check declared file size before downloading anything
                content_length = response.headers.get('content-length')
                if content_length and int(content_length) > max_file_size:
                    logger.error(f"File too large: {content_length} bytes")
                    return None
                # Download, enforcing the limit on the bytes actually received
                image_data = io.BytesIO()
                downloaded_size = 0
                for chunk in response.iter_content(chunk_size=8192):
                    downloaded_size += len(chunk)
                    if downloaded_size > max_file_size:
                        logger.error(f"File too large during download: {downloaded_size} bytes")
                        return None
                    image_data.write(chunk)
            image_data.seek(0)
            # Open and validate image; convert() decodes fully and returns an
            # independent RGB copy, so the source can be closed afterwards.
            with Image.open(image_data) as source:
                image = source.convert('RGB')
            logger.info(f"Successfully loaded image: {image.size}")
            return image
        except requests.exceptions.RequestException as e:
            logger.error(f"Request error loading image from {url}: {e}")
            return None
        except Exception as e:
            logger.error(f"Error loading image from {url}: {e}")
            return None

    def load_image_from_file(self, file_path: Union[str, Path]) -> Optional[Image.Image]:
        """
        Load an image from a local file with validation.

        Args:
            file_path: Path to the image file. Its extension must be in
                ``self.supported_formats`` (case-insensitive).

        Returns:
            RGB PIL Image, or None if the path does not exist, has an
            unsupported extension, or cannot be decoded.
        """
        try:
            file_path = Path(file_path)
            if not file_path.exists():
                logger.error(f"File does not exist: {file_path}")
                return None
            if file_path.suffix.lower() not in self.supported_formats:
                logger.error(f"Unsupported format: {file_path.suffix}")
                return None
            # Closing the opened image releases its file handle (Pillow keeps it
            # open for multi-frame formats); convert() returns a loaded copy.
            with Image.open(file_path) as source:
                image = source.convert('RGB')
            logger.info(f"Successfully loaded image from file: {image.size}")
            return image
        except Exception as e:
            logger.error(f"Error loading image from {file_path}: {e}")
            return None

    def preprocess_image(
        self,
        image: Image.Image,
        target_size: Optional[Tuple[int, int]] = None,
        maintain_aspect_ratio: bool = True
    ) -> Image.Image:
        """
        Resize an image to a fixed size for stable diffusion.

        The input image is never modified; a new image is returned.

        Args:
            image: Input PIL Image.
            target_size: Target size ``(width, height)``; defaults to (512, 512).
            maintain_aspect_ratio: If True, shrink the image to fit inside
                ``target_size`` (it is never enlarged) and center it on a white
                RGB canvas of exactly ``target_size``. If False, stretch it to
                ``target_size``.

        Returns:
            The preprocessed PIL Image, or the unmodified input image if
            preprocessing fails.
        """
        if target_size is None:
            target_size = (512, 512)  # Default SD size
        try:
            if maintain_aspect_ratio:
                # Image.thumbnail resizes in place, so operate on a copy to
                # avoid mutating the caller's image.
                fitted = image.copy()
                fitted.thumbnail(target_size, Image.Resampling.LANCZOS)
                # Pad to the exact target size, centering the fitted image.
                result = Image.new('RGB', target_size, (255, 255, 255))
                paste_x = (target_size[0] - fitted.width) // 2
                paste_y = (target_size[1] - fitted.height) // 2
                result.paste(fitted, (paste_x, paste_y))
            else:
                result = image.resize(target_size, Image.Resampling.LANCZOS)
            logger.info(f"Preprocessed image to size: {result.size}")
            return result
        except Exception as e:
            logger.error(f"Error preprocessing image: {e}")
            return image

    def analyze_image_properties(self, image: Image.Image) -> Dict[str, Any]:
        """
        Compute simple global statistics used to characterize a style reference.

        The image is analyzed in RGB. Non-RGB inputs such as 'L', 'RGBA' or 'P'
        are converted first, because the per-channel unpacking and the
        RGB-to-gray conversion below require exactly three channels.

        Args:
            image: PIL Image to analyze.

        Returns:
            Dictionary with keys:
                - 'dimensions': (width, height)
                - 'aspect_ratio': width / height
                - 'brightness': mean of the per-channel pixel means
                - 'contrast': mean of the per-channel standard deviations
                - 'color_means': (R, G, B) channel means
                - 'color_variance': per-channel variances, as a list
                - 'edge_density': fraction of pixels marked as edges by
                  cv2.Canny with thresholds 50/150
                - 'file_size_pixels': width * height
            An empty dict is returned if analysis fails.
        """
        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            # Basic properties
            width, height = image.size
            aspect_ratio = width / height
            # Color analysis
            stat = ImageStat.Stat(image)
            avg_brightness = sum(stat.mean) / len(stat.mean)
            avg_contrast = sum(stat.stddev) / len(stat.stddev)
            # Convert to numpy (uint8, H x W x 3) for additional analysis
            img_array = np.array(image)
            # Color distribution
            r_mean, g_mean, b_mean = np.mean(img_array, axis=(0, 1))
            color_variance = np.var(img_array, axis=(0, 1))
            # Edge detection as a proxy for visual complexity
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 50, 150)
            edge_density = np.sum(edges > 0) / (width * height)
            properties = {
                'dimensions': (width, height),
                'aspect_ratio': aspect_ratio,
                'brightness': avg_brightness,
                'contrast': avg_contrast,
                'color_means': (float(r_mean), float(g_mean), float(b_mean)),
                'color_variance': color_variance.tolist(),
                'edge_density': float(edge_density),
                'file_size_pixels': width * height
            }
            logger.info(f"Analyzed image properties: {properties}")
            return properties
        except Exception as e:
            logger.error(f"Error analyzing image properties: {e}")
            return {}

    def generate_image_hash(self, image: Image.Image) -> str:
        """
        Generate a hash for image deduplication.

        The image is encoded as PNG in memory and the encoded bytes are hashed
        with MD5. This is a dedup key only, not a security hash.
        NOTE(review): Pillow's PNG encoder may carry some metadata from
        ``image.info`` into the output, so pixel-identical images might hash
        differently -- confirm if exact pixel-level dedup matters.

        Args:
            image: PIL Image.

        Returns:
            MD5 hex digest string, or "" if encoding fails.
        """
        try:
            # Encode to PNG bytes in memory
            buffer = io.BytesIO()
            image.save(buffer, format='PNG')
            png_bytes = buffer.getvalue()
            # Generate hash
            return hashlib.md5(png_bytes).hexdigest()
        except Exception as e:
            logger.error(f"Error generating image hash: {e}")
            return ""
class StyleAnalyzer:
    """
    Analyzes style characteristics of reference images.

    Maps the numeric properties produced by
    ``ImageProcessor.analyze_image_properties`` to descriptive prompt keywords.
    """

    def __init__(self):
        # Keyword groups per broad style family.
        # NOTE(review): not read by suggest_style_keywords -- confirm whether
        # other callers use it.
        self.style_keywords = {
            'realistic': ['photo', 'realistic', 'detailed', 'sharp'],
            'artistic': ['painting', 'artistic', 'brushstrokes', 'canvas'],
            'anime': ['anime', 'manga', 'cartoon', 'stylized'],
            'abstract': ['abstract', 'geometric', 'surreal', 'conceptual'],
            'vintage': ['vintage', 'retro', 'aged', 'classic'],
            'modern': ['modern', 'contemporary', 'clean', 'minimal']
        }

    def suggest_style_keywords(self, image_properties: Dict[str, Any]) -> List[str]:
        """
        Suggest style keywords based on image analysis.

        Thresholds (exclusive):
            - brightness < 100 -> dark keywords; > 180 -> bright keywords
            - contrast > 80 -> high-contrast keywords; < 30 -> soft keywords
            - edge_density > 0.2 -> detailed keywords; < 0.05 -> smooth keywords
        Missing keys default to 128 / 50 / 0.1, which trigger no suggestions.

        Args:
            image_properties: Properties from analyze_image_properties.

        Returns:
            De-duplicated list of suggested keywords in a deterministic order:
            brightness, then contrast, then edge-density keywords. An empty
            list is returned if the properties cannot be evaluated (e.g. a
            non-numeric value).
        """
        suggestions = []
        try:
            brightness = image_properties.get('brightness', 128)
            contrast = image_properties.get('contrast', 50)
            edge_density = image_properties.get('edge_density', 0.1)
            # Brightness-based suggestions
            if brightness < 100:
                suggestions.extend(['dark', 'moody', 'dramatic'])
            elif brightness > 180:
                suggestions.extend(['bright', 'light', 'airy'])
            # Contrast-based suggestions
            if contrast > 80:
                suggestions.extend(['high contrast', 'bold', 'striking'])
            elif contrast < 30:
                suggestions.extend(['soft', 'gentle', 'muted'])
            # Edge density-based suggestions
            if edge_density > 0.2:
                suggestions.extend(['detailed', 'complex', 'intricate'])
            elif edge_density < 0.05:
                suggestions.extend(['smooth', 'simple', 'minimalist'])
            # dict.fromkeys removes duplicates while keeping first-seen order.
            # A set would yield a per-process order, since str hashing is
            # randomized between interpreter runs.
            return list(dict.fromkeys(suggestions))
        except Exception as e:
            logger.error(f"Error suggesting style keywords: {e}")
            return []