# digiPal/models/image_generator.py
# Hugging Face file-viewer residue kept for provenance:
# BladeSzaSza's picture | new design | fe24641 | raw | history blame | 9.69 kB
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import numpy as np
from typing import Optional, List, Union
import gc
class OmniGenImageGenerator:
    """Generate monster images with a lazily loaded diffusers pipeline.

    The primary model id is ``self.model_id`` (an OmniGen2 placeholder). When
    that model cannot be loaded, ``runwayml/stable-diffusion-v1-5`` is loaded
    instead. Any failure during generation is reported with ``print`` and
    replaced by a procedurally drawn placeholder image instead of raising.
    """

    def __init__(self, device: str = "cuda"):
        """Store configuration only; no model weights are loaded here.

        Args:
            device: Requested torch device. It is replaced by ``"cpu"``
                whenever ``torch.cuda.is_available()`` is false.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.pipeline = None
        self.model_id = "OmniGen2/OmniGen2"  # Placeholder - actual model path may differ

        # Generation parameters
        self.default_width = 512
        self.default_height = 512
        self.num_inference_steps = 30
        self.guidance_scale = 7.5

        # Memory optimization
        self.enable_attention_slicing = True
        self.enable_vae_slicing = True
        self.enable_cpu_offload = self.device == "cuda"

    def load_model(self):
        """Load the pipeline on first use; later calls are no-ops.

        The primary model is tried first. If loading or configuring it fails
        for any reason, the error is printed, ``self.model_id`` is switched to
        ``runwayml/stable-diffusion-v1-5`` and the fallback loader runs.

        Raises:
            Exception: Whatever the fallback loader raises when the fallback
                model cannot be loaded either. ``self.pipeline`` is left as
                ``None`` in that case so a later call can retry.
        """
        if self.pipeline is not None:
            return

        try:
            on_cuda = self.device == "cuda"
            torch_dtype = torch.float16 if on_cuda else torch.float32

            # Build into a local first: self.pipeline is only published once
            # the pipeline is fully configured, so a failure half-way through
            # never leaves a partially set-up pipeline behind.
            pipeline = DiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch_dtype,
                use_safetensors=True,
                variant="fp16" if on_cuda else None,
            )

            if on_cuda:
                if self.enable_cpu_offload:
                    pipeline.enable_sequential_cpu_offload()
                else:
                    pipeline = pipeline.to(self.device)
                if self.enable_attention_slicing:
                    pipeline.enable_attention_slicing(1)
                if self.enable_vae_slicing:
                    pipeline.enable_vae_slicing()
            else:
                pipeline = pipeline.to(self.device)

            # Compilation is a best-effort speed-up; skip it when the pipeline
            # has no ``unet`` component or torch has no ``compile``.
            # NOTE(review): confirm torch.compile is compatible with
            # sequential CPU offload on the target diffusers version.
            unet = getattr(pipeline, "unet", None)
            if unet is not None and hasattr(torch, "compile") and on_cuda:
                try:
                    pipeline.unet = torch.compile(unet, mode="reduce-overhead")
                except Exception as e:
                    print(f"torch.compile skipped: {e}")

            self.pipeline = pipeline
        except Exception as e:
            print(f"Failed to load image generation model: {e}")
            # Try fallback to stable diffusion; its errors propagate.
            self.model_id = "runwayml/stable-diffusion-v1-5"
            self._load_fallback_model()

    def _load_fallback_model(self):
        """Load ``self.model_id`` as a ``StableDiffusionPipeline``.

        Uses float16 on CUDA and float32 otherwise, then either enables
        sequential CPU offload (CUDA with ``enable_cpu_offload``) or moves the
        pipeline to ``self.device``.
        """
        from diffusers import StableDiffusionPipeline

        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.pipeline = StableDiffusionPipeline.from_pretrained(
            self.model_id,
            torch_dtype=torch_dtype,
            use_safetensors=True,
        )
        if self.device == "cuda" and self.enable_cpu_offload:
            self.pipeline.enable_sequential_cpu_offload()
        else:
            self.pipeline = self.pipeline.to(self.device)

    def _resolve_dimensions(self, width: Optional[int], height: Optional[int]):
        """Return ``(width, height)`` with defaults applied, snapped to 8.

        Falsy values (``None`` or ``0``) are replaced by the instance
        defaults. Each dimension is rounded down to a multiple of 8 and
        clamped to at least 8 so the result is never zero or negative.
        """
        resolved = []
        for value, default in ((width, self.default_width),
                               (height, self.default_height)):
            value = int(value or default)
            resolved.append(max(8, (value // 8) * 8))
        return resolved[0], resolved[1]

    def generate(self,
                 prompt: str,
                 reference_images: Optional[List[Union[str, Image.Image]]] = None,
                 negative_prompt: Optional[str] = None,
                 width: Optional[int] = None,
                 height: Optional[int] = None,
                 num_images: int = 1,
                 seed: Optional[int] = None) -> Union[Image.Image, List[Image.Image]]:
        """Generate monster image(s) from a text prompt.

        Args:
            prompt: Base description; style keywords are appended to it.
            reference_images: Only forwarded to ``_omnigen_generate`` when the
                pipeline object is not callable; the standard path ignores it.
            negative_prompt: Overrides the built-in quality negative prompt.
            width: Output width; defaults to ``self.default_width``.
            height: Output height; defaults to ``self.default_height``.
            num_images: Number of images requested per prompt.
            seed: When given, seeds a ``torch.Generator`` on ``self.device``.

        Returns:
            A single image when ``num_images == 1``, otherwise a list. On any
            error the same shape is returned, filled with placeholder images.
        """
        # Resolve the size before anything that can fail, so the error path
        # below always has concrete integers to draw the placeholder with.
        width, height = self._resolve_dimensions(width, height)

        try:
            # Load model if needed
            self.load_model()

            # Enhance prompt for monster generation
            enhanced_prompt = self._enhance_prompt(prompt)

            # Default negative prompt for quality
            if negative_prompt is None:
                negative_prompt = (
                    "low quality, blurry, distorted, disfigured, "
                    "bad anatomy, wrong proportions, ugly, duplicate, "
                    "morbid, mutilated, extra limbs, malformed"
                )

            # Set seed for reproducibility
            generator = None
            if seed is not None:
                generator = torch.Generator(device=self.device).manual_seed(seed)

            with torch.no_grad():
                if callable(self.pipeline):
                    # Standard diffusion pipeline
                    images = self.pipeline(
                        prompt=enhanced_prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=self.num_inference_steps,
                        guidance_scale=self.guidance_scale,
                        num_images_per_prompt=num_images,
                        generator=generator,
                    ).images
                else:
                    # OmniGen specific generation (if different API)
                    images = self._omnigen_generate(
                        enhanced_prompt,
                        reference_images,
                        width,
                        height,
                        num_images,
                    )

            # Return single image or list
            if num_images == 1:
                return images[0]
            return images
        except Exception as e:
            print(f"Image generation error: {e}")
            # Keep the return shape consistent with the success path.
            if num_images == 1:
                return self._generate_fallback_image(width, height)
            return [self._generate_fallback_image(width, height)
                    for _ in range(num_images)]
        finally:
            # Release cached CUDA memory after success and after failure.
            if self.device == "cuda":
                torch.cuda.empty_cache()

    def _enhance_prompt(self, base_prompt: str) -> str:
        """Return ``base_prompt`` followed by fixed style keywords."""
        enhancements = [
            "digital art",
            "creature design",
            "game character",
            "detailed",
            "vibrant colors",
            "fantasy creature",
            "high quality",
            "professional artwork",
        ]
        # Combine base prompt with enhancements
        return f"{base_prompt}, {', '.join(enhancements)}"

    def _omnigen_generate(self, prompt: str, reference_images: Optional[List],
                          width: int, height: int, num_images: int) -> List[Image.Image]:
        """Placeholder for an OmniGen-specific call.

        Currently delegates to the standard pipeline call and does not use
        ``reference_images``.
        """
        # This would be implemented based on OmniGen's specific API
        # For now, fall back to standard generation
        return self.pipeline(
            prompt=prompt,
            width=width,
            height=height,
            num_images_per_prompt=num_images,
        ).images

    def _generate_fallback_image(self, width: int, height: int) -> Image.Image:
        """Draw a placeholder monster: a coloured disc with two white eyes.

        The body colour is drawn from ``np.random``, so the output differs
        between calls.
        """
        img_array = np.zeros((height, width, 3), dtype=np.uint8)

        center_x, center_y = width // 2, height // 2
        radius = min(width, height) // 3

        # Open grids broadcast into a full (height, width) boolean mask.
        y, x = np.ogrid[:height, :width]
        mask = (x - center_x) ** 2 + (y - center_y) ** 2 <= radius ** 2

        # Random monster color
        color = np.random.randint(50, 200, size=3)
        img_array[mask] = color

        # Eyes sit a third of the radius above and to each side of the centre.
        eye_y = center_y - radius // 3
        eye_radius = radius // 8
        for eye_x in (center_x - radius // 3, center_x + radius // 3):
            eye_mask = (x - eye_x) ** 2 + (y - eye_y) ** 2 <= eye_radius ** 2
            img_array[eye_mask] = [255, 255, 255]

        return Image.fromarray(img_array)

    def edit_image(self,
                   image: Union[str, Image.Image],
                   prompt: str,
                   mask: Optional[Union[str, Image.Image]] = None) -> Image.Image:
        """Edit an existing image (reserved for future monster customization).

        Raises:
            NotImplementedError: Always; editing is not implemented.
        """
        raise NotImplementedError("Image editing not yet implemented")

    def to(self, device: str):
        """Record ``device`` and move an already loaded pipeline to it.

        When no pipeline is loaded yet only ``self.device`` changes. For
        ``"cuda"`` with ``enable_cpu_offload`` set, sequential CPU offload is
        enabled instead of a direct move.
        """
        self.device = device
        if self.pipeline is None:
            return
        if device == "cuda" and self.enable_cpu_offload:
            self.pipeline.enable_sequential_cpu_offload()
        else:
            self.pipeline = self.pipeline.to(device)

    def __del__(self):
        """Drop the pipeline reference and release cached memory.

        Safe to call on a partially constructed instance and more than once.
        """
        try:
            if getattr(self, "pipeline", None) is not None:
                # Rebind rather than ``del`` so the attribute keeps existing.
                self.pipeline = None
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        except Exception:
            # Deliberate best-effort: module globals may already be torn down
            # at interpreter shutdown and a finalizer must not raise.
            pass