"""
CompI Phase 2.E: Style Reference/Example Image to AI Art Generation

This module implements multimodal AI art generation that combines:
- Text prompts with style and mood conditioning
- Reference image style transfer and guidance
- Image-to-image generation with controllable strength
- Support for both local files and web URLs
- Advanced style analysis and prompt enhancement

Features:
- Support for various image formats and web sources
- Real-time image analysis and style suggestion
- Controllable reference strength for creative flexibility
- Comprehensive metadata logging and filename conventions
- Batch processing capabilities with multiple variations
"""
|
|
|
import os |
|
import sys |
|
import torch |
|
import json |
|
from datetime import datetime |
|
from typing import Dict, List, Optional, Tuple, Union |
|
from pathlib import Path |
|
import logging |
|
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) |
|
|
|
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline |
|
from PIL import Image |
|
import numpy as np |
|
|
|
from src.utils.image_utils import ImageProcessor, StyleAnalyzer |
|
from src.utils.logging_utils import setup_logger |
|
from src.utils.file_utils import ensure_directory_exists, generate_filename |
|
from src.config import ( |
|
STABLE_DIFFUSION_IMG2IMG_MODEL, |
|
OUTPUTS_DIR, |
|
DEFAULT_IMAGE_SIZE, |
|
DEFAULT_INFERENCE_STEPS, |
|
DEFAULT_GUIDANCE_SCALE |
|
) |
|
|
|
|
|
logger = setup_logger(__name__) |
|
|
|
class CompIPhase2ERefImageToImage:
    """
    CompI Phase 2.E: Style Reference/Example Image to AI Art Generation System

    Combines text prompts with reference image style guidance for enhanced
    creativity. The diffusion pipelines are created lazily on first access, so
    constructing an instance does not load any model weights.
    """

    # Characters that are path separators or invalid in file names on common
    # platforms (plus ASCII control characters); each is replaced by "_" when
    # user-supplied text is embedded in an output file name.
    _UNSAFE_FILENAME_TABLE = str.maketrans(
        dict.fromkeys('<>:"/\\|?*' + "".join(map(chr, range(32))), "_")
    )

    # Upper bound for the prompt part of a file name, so that a prompt made of
    # very long "words" cannot push the name past file-system limits.
    _MAX_PROMPT_SLUG_LENGTH = 80

    def __init__(
        self,
        model_name: str = STABLE_DIFFUSION_IMG2IMG_MODEL,
        device: Optional[str] = None,
        enable_attention_slicing: bool = True,
        enable_memory_efficient_attention: bool = True
    ):
        """
        Initialize the CompI Phase 2.E system

        Args:
            model_name: Hugging Face model identifier
            device: Device to run on ('cuda', 'cuda:0', 'cpu', ... or None for auto)
            enable_attention_slicing: Enable attention slicing for memory efficiency
            enable_memory_efficient_attention: Try to enable memory efficient
                attention on the pipelines (best effort; skipped with a log
                message when the backend is unavailable)
        """
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Helpers for loading/analysing reference images
        self.image_processor = ImageProcessor()
        self.style_analyzer = StyleAnalyzer()

        # Pipelines are populated on first use by the properties below
        self._img2img_pipeline = None
        self._txt2img_pipeline = None

        self.enable_attention_slicing = enable_attention_slicing
        self.enable_memory_efficient_attention = enable_memory_efficient_attention

        logger.info(f"Initialized CompI Phase 2.E on device: {self.device}")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _uses_cuda(self) -> bool:
        """Return True for any CUDA device string ('cuda', 'cuda:0', ...)."""
        return str(self.device).startswith("cuda")

    def _inference_context(self):
        """Return the context manager used around a pipeline call."""
        if self._uses_cuda():
            # autocast takes the device *type*, not an indexed device string
            return torch.autocast("cuda")
        return torch.no_grad()

    def _seeded_generator(self, seed: Optional[int], index: int):
        """Build the RNG for image number ``index``; returns (generator, seed used)."""
        # A caller-supplied seed is offset per image so a batch is reproducible
        # yet every image differs; otherwise torch picks a fresh random seed.
        current_seed = seed + index if seed is not None else torch.seed()
        generator = torch.Generator(device=self.device).manual_seed(current_seed)
        return generator, current_seed

    def _enable_memory_efficient_attention(self, pipeline) -> None:
        """Best-effort activation of memory efficient attention on ``pipeline``."""
        # Kept for pipelines that expose the method name the original code probed for.
        legacy_enable = getattr(pipeline, "enable_memory_efficient_attention", None)
        if legacy_enable is not None:
            legacy_enable()
            return

        # diffusers exposes this feature as enable_xformers_memory_efficient_attention;
        # it is only attempted on CUDA and may raise when xformers is missing.
        xformers_enable = getattr(pipeline, "enable_xformers_memory_efficient_attention", None)
        if xformers_enable is None or not self._uses_cuda():
            return
        try:
            xformers_enable()
        except Exception as e:
            logger.warning(f"Memory efficient attention unavailable, continuing without it: {e}")

    def _load_pipeline(self, pipeline_cls):
        """Create ``pipeline_cls`` for ``self.model_name``, move it to the device and configure it."""
        pipeline = pipeline_cls.from_pretrained(
            self.model_name,
            # half precision on GPU, full precision elsewhere
            torch_dtype=torch.float16 if self._uses_cuda() else torch.float32,
            safety_checker=None,
            requires_safety_checker=False
        )
        pipeline = pipeline.to(self.device)

        if self.enable_attention_slicing:
            pipeline.enable_attention_slicing()
        if self.enable_memory_efficient_attention:
            self._enable_memory_efficient_attention(pipeline)

        return pipeline

    @classmethod
    def _filename_slug(cls, text: Optional[str], limit: int) -> str:
        """Return ``text`` made safe for use inside a file name, cut to ``limit`` chars."""
        return (text or "").translate(cls._UNSAFE_FILENAME_TABLE)[:limit]

    # ------------------------------------------------------------------
    # Lazy pipelines
    # ------------------------------------------------------------------

    @property
    def img2img_pipeline(self) -> StableDiffusionImg2ImgPipeline:
        """Lazy load img2img pipeline"""
        if self._img2img_pipeline is None:
            logger.info(f"Loading img2img pipeline: {self.model_name}")
            # Cached only once fully configured, so a failed load is retried next time
            self._img2img_pipeline = self._load_pipeline(StableDiffusionImg2ImgPipeline)
        return self._img2img_pipeline

    @property
    def txt2img_pipeline(self) -> StableDiffusionPipeline:
        """Lazy load txt2img pipeline for fallback"""
        if self._txt2img_pipeline is None:
            logger.info(f"Loading txt2img pipeline: {self.model_name}")
            self._txt2img_pipeline = self._load_pipeline(StableDiffusionPipeline)
        return self._txt2img_pipeline

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def load_reference_image(
        self,
        source: Union[str, Path, Image.Image],
        preprocess: bool = True
    ) -> Optional[Tuple[Image.Image, Dict]]:
        """
        Load and analyze reference image from various sources

        Args:
            source: Image source (file path, URL, or PIL Image)
            preprocess: Whether to preprocess the image to DEFAULT_IMAGE_SIZE

        Returns:
            Tuple of (processed_image, analysis_results) or None if failed.
            The analysis dict has the keys 'source', 'properties',
            'style_suggestions', 'hash' and 'processed_size'.
        """
        try:
            if isinstance(source, Image.Image):
                image = source.convert('RGB')
                source_info = "PIL Image object"
            elif isinstance(source, (str, Path)):
                source_str = str(source)
                # Anything that looks like an http(s) URL is downloaded,
                # everything else is treated as a local file path
                if source_str.startswith(('http://', 'https://')):
                    image = self.image_processor.load_image_from_url(source_str)
                    source_info = f"URL: {source_str}"
                else:
                    image = self.image_processor.load_image_from_file(source_str)
                    source_info = f"File: {source_str}"

                if image is None:
                    return None
            else:
                logger.error(f"Unsupported source type: {type(source)}")
                return None

            if preprocess:
                image = self.image_processor.preprocess_image(image, DEFAULT_IMAGE_SIZE)

            # Analyse the (possibly preprocessed) image
            properties = self.image_processor.analyze_image_properties(image)
            style_suggestions = self.style_analyzer.suggest_style_keywords(properties)
            image_hash = self.image_processor.generate_image_hash(image)

            analysis = {
                'source': source_info,
                'properties': properties,
                'style_suggestions': style_suggestions,
                'hash': image_hash,
                'processed_size': image.size
            }

            logger.info(f"Successfully loaded and analyzed reference image: {analysis}")
            return image, analysis

        except Exception as e:
            logger.error(f"Error loading reference image: {e}")
            return None

    def enhance_prompt_with_style(
        self,
        base_prompt: str,
        style: str = "",
        mood: str = "",
        style_suggestions: Optional[List[str]] = None
    ) -> str:
        """
        Enhance prompt with style information from reference image

        Args:
            base_prompt: Base text prompt
            style: Additional style keywords
            mood: Mood/atmosphere keywords
            style_suggestions: Suggested keywords from image analysis
                (only the first three are used)

        Returns:
            Enhanced prompt string: the non-empty parts joined by ", ".
            Falls back to ``base_prompt`` if the parts cannot be combined.
        """
        try:
            candidates = [base_prompt, style, mood]
            if style_suggestions:
                # Limit to the top suggestions so the prompt stays focused
                candidates.extend(style_suggestions[:3])

            # Empty/None parts are dropped so the prompt never contains
            # dangling separators such as ", watercolor"
            prompt_parts = [part.strip() for part in candidates if part and part.strip()]

            enhanced_prompt = ", ".join(prompt_parts)
            logger.info(f"Enhanced prompt: {enhanced_prompt}")
            return enhanced_prompt

        except Exception as e:
            logger.error(f"Error enhancing prompt: {e}")
            return base_prompt

    def generate_with_reference(
        self,
        prompt: str,
        reference_image: Image.Image,
        style: str = "",
        mood: str = "",
        strength: float = 0.5,
        num_images: int = 1,
        num_inference_steps: int = DEFAULT_INFERENCE_STEPS,
        guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None,
        style_suggestions: Optional[List[str]] = None
    ) -> List[Dict]:
        """
        Generate images using reference image guidance

        Args:
            prompt: Text prompt
            reference_image: Reference PIL Image
            style: Style keywords
            mood: Mood keywords
            strength: Denoising strength passed to the img2img pipeline
                (0.0-1.0). Lower values keep the output closer to the
                reference image; higher values transform it more.
            num_images: Number of images to generate
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            seed: Random seed for reproducibility (image i uses seed + i)
            style_suggestions: Style suggestions from image analysis

        Returns:
            List of generation results, each a dict with 'image', 'metadata'
            and 'index'. An empty list is returned if generation fails.
        """
        try:
            enhanced_prompt = self.enhance_prompt_with_style(
                prompt, style, mood, style_suggestions
            )

            results = []

            for i in range(num_images):
                generator, current_seed = self._seeded_generator(seed, i)

                logger.info(f"Generating image {i+1}/{num_images} with reference guidance")

                with self._inference_context():
                    result = self.img2img_pipeline(
                        prompt=enhanced_prompt,
                        image=reference_image,
                        strength=strength,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator
                    )

                generated_image = result.images[0]

                metadata = {
                    'prompt': prompt,
                    'enhanced_prompt': enhanced_prompt,
                    'style': style,
                    'mood': mood,
                    'strength': strength,
                    'num_inference_steps': num_inference_steps,
                    'guidance_scale': guidance_scale,
                    'seed': current_seed,
                    'model': self.model_name,
                    'generation_type': 'img2img_reference',
                    'timestamp': datetime.now().isoformat(),
                    'device': self.device,
                    'reference_size': reference_image.size,
                    'output_size': generated_image.size,
                    'style_suggestions': style_suggestions or []
                }

                results.append({
                    'image': generated_image,
                    'metadata': metadata,
                    'index': i
                })

            logger.info(f"Successfully generated {len(results)} images with reference guidance")
            return results

        except Exception as e:
            logger.error(f"Error generating images with reference: {e}")
            return []

    def generate_without_reference(
        self,
        prompt: str,
        style: str = "",
        mood: str = "",
        num_images: int = 1,
        num_inference_steps: int = DEFAULT_INFERENCE_STEPS,
        guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None
    ) -> List[Dict]:
        """
        Generate images without reference (fallback to text-to-image)

        Args:
            prompt: Text prompt
            style: Style keywords
            mood: Mood keywords
            num_images: Number of images to generate
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            seed: Random seed for reproducibility (image i uses seed + i)

        Returns:
            List of generation results, each a dict with 'image', 'metadata'
            and 'index'. An empty list is returned if generation fails.
        """
        try:
            enhanced_prompt = self.enhance_prompt_with_style(prompt, style, mood)

            results = []

            for i in range(num_images):
                generator, current_seed = self._seeded_generator(seed, i)

                logger.info(f"Generating image {i+1}/{num_images} without reference")

                with self._inference_context():
                    result = self.txt2img_pipeline(
                        prompt=enhanced_prompt,
                        # DEFAULT_IMAGE_SIZE is indexed as (width, height)
                        height=DEFAULT_IMAGE_SIZE[1],
                        width=DEFAULT_IMAGE_SIZE[0],
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator
                    )

                generated_image = result.images[0]

                metadata = {
                    'prompt': prompt,
                    'enhanced_prompt': enhanced_prompt,
                    'style': style,
                    'mood': mood,
                    'num_inference_steps': num_inference_steps,
                    'guidance_scale': guidance_scale,
                    'seed': current_seed,
                    'model': self.model_name,
                    'generation_type': 'txt2img_fallback',
                    'timestamp': datetime.now().isoformat(),
                    'device': self.device,
                    'output_size': generated_image.size
                }

                results.append({
                    'image': generated_image,
                    'metadata': metadata,
                    'index': i
                })

            logger.info(f"Successfully generated {len(results)} images without reference")
            return results

        except Exception as e:
            logger.error(f"Error generating images without reference: {e}")
            return []

    def save_results(
        self,
        results: List[Dict],
        output_dir: Path = OUTPUTS_DIR,
        reference_info: Optional[Dict] = None
    ) -> List[str]:
        """
        Save generation results with comprehensive metadata

        Each result is written as a PNG plus a "<name>_metadata.json" sidecar.
        A result that cannot be saved is logged and skipped; the remaining
        results are still processed.

        Args:
            results: List of generation results (dicts with 'image',
                'metadata' and 'index')
            output_dir: Output directory (``Path`` or path string)
            reference_info: Reference image information; when given it is
                stored under metadata['reference_info'] of every result

        Returns:
            List of saved file paths (image path followed by its metadata
            path for each saved result); empty if the directory is unusable
        """
        try:
            # Accept plain strings as well as Path objects
            output_dir = Path(output_dir)
            ensure_directory_exists(output_dir)
        except Exception as e:
            logger.error(f"Error saving results: {e}")
            return []

        saved_files = []

        for result in results:
            try:
                image = result['image']
                metadata = result['metadata']
                index = result['index']

                # File name: first five prompt words, short style/mood tags,
                # timestamp, seed, reference marker and 1-based variation number.
                # User text is sanitised so it cannot inject path separators.
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                prompt_slug = self._filename_slug(
                    "_".join((metadata['prompt'] or '').lower().split()[:5]),
                    self._MAX_PROMPT_SLUG_LENGTH
                )
                style_slug = self._filename_slug(
                    (metadata.get('style') or '').replace(' ', ''), 10
                )
                mood_slug = self._filename_slug(
                    (metadata.get('mood') or '').replace(' ', ''), 10
                )

                ref_indicator = "REFIMG" if metadata['generation_type'] == 'img2img_reference' else "NOREFIMG"

                filename = f"{prompt_slug}_{style_slug}_{mood_slug}_{timestamp}_seed{metadata['seed']}_{ref_indicator}_v{index+1}.png"
                filepath = output_dir / filename

                image.save(filepath)
                saved_files.append(str(filepath))

                # Note: this updates the metadata dict held by the caller's result
                if reference_info:
                    metadata['reference_info'] = reference_info

                metadata_filename = filepath.stem + "_metadata.json"
                metadata_filepath = output_dir / metadata_filename

                # default=str stringifies values json cannot encode natively
                with open(metadata_filepath, 'w', encoding='utf-8') as f:
                    json.dump(metadata, f, indent=2, default=str)
                saved_files.append(str(metadata_filepath))

                logger.info(f"Saved: {filepath}")

            except Exception as e:
                # Keep going so one bad result does not hide files already written
                logger.error(f"Error saving results: {e}")

        return saved_files

    def generate_batch(
        self,
        prompt: str,
        reference_source: Optional[Union[str, Path, Image.Image]] = None,
        style: str = "",
        mood: str = "",
        strength: float = 0.5,
        num_images: int = 1,
        num_inference_steps: int = DEFAULT_INFERENCE_STEPS,
        guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None,
        save_results: bool = True,
        output_dir: Path = OUTPUTS_DIR
    ) -> Dict:
        """
        Complete batch generation pipeline

        Loads the optional reference image, generates with img2img when a
        reference is available (text-to-image otherwise) and optionally saves
        the outputs.

        Args:
            prompt: Text prompt
            reference_source: Reference image source (file, URL, or PIL Image)
            style: Style keywords
            mood: Mood keywords
            strength: img2img denoising strength (only used if a reference is
                provided); lower values stay closer to the reference
            num_images: Number of images to generate
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            seed: Random seed for reproducibility
            save_results: Whether to save results to disk
            output_dir: Output directory for saved files

        Returns:
            Dictionary with 'results', 'reference_info', 'saved_files' and
            'generation_summary'. If the pipeline itself raises, the dict
            instead contains empty results plus an 'error' message.
        """
        try:
            logger.info(f"Starting batch generation: {num_images} images")

            reference_image = None
            reference_info = None
            style_suggestions = []

            # Load and analyse the reference image, if one was supplied
            if reference_source is not None:
                ref_result = self.load_reference_image(reference_source)
                if ref_result:
                    reference_image, reference_info = ref_result
                    style_suggestions = reference_info.get('style_suggestions', [])
                    logger.info(f"Using reference image with suggestions: {style_suggestions}")
                else:
                    logger.warning("Failed to load reference image, falling back to text-only generation")

            # Pick the generation path based on reference availability
            if reference_image is not None:
                results = self.generate_with_reference(
                    prompt=prompt,
                    reference_image=reference_image,
                    style=style,
                    mood=mood,
                    strength=strength,
                    num_images=num_images,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    seed=seed,
                    style_suggestions=style_suggestions
                )
            else:
                results = self.generate_without_reference(
                    prompt=prompt,
                    style=style,
                    mood=mood,
                    num_images=num_images,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    seed=seed
                )

            # The boolean parameter shadows the method name locally, so the
            # method is reached through self
            saved_files = []
            if save_results and results:
                saved_files = self.save_results(results, output_dir, reference_info)

            batch_result = {
                'results': results,
                'reference_info': reference_info,
                'saved_files': saved_files,
                'generation_summary': {
                    'total_images': len(results),
                    'prompt': prompt,
                    'style': style,
                    'mood': mood,
                    'has_reference': reference_image is not None,
                    'style_suggestions': style_suggestions,
                    'timestamp': datetime.now().isoformat()
                }
            }

            logger.info(f"Batch generation complete: {len(results)} images generated")
            return batch_result

        except Exception as e:
            logger.error(f"Error in batch generation: {e}")
            return {
                'results': [],
                'reference_info': None,
                'saved_files': [],
                'error': str(e)
            }
|
|