from dataclasses import dataclass
from typing import Tuple

import numpy as np
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image

from config import PipelineConfig
from src.preprocess import HWC3
from src.unet.predictor import generate_mask, load_seg_model


@dataclass
class PipelineOutput:
    """Result of a single ``FashionPipeline`` call."""

    # Control image that was passed to the diffusion pipeline (after
    # ``adaptive_resize``); built with PIL, so it is a PIL image.
    control_mask: Image.Image
    # First entry of the pipeline's ``.images`` output.
    # NOTE(review): a PIL image with diffusers' default output type - confirm
    # if a non-default ``output_type`` is ever configured.
    generated_image: Image.Image


class FashionPipeline:
    """SDXL + ControlNet image generation conditioned on a clothing mask."""

    def __init__(
        self,
        config: PipelineConfig,
        device: torch.device,
    ):
        """Store the configuration and eagerly load all models.

        Args:
            config: Provides ``segmentation_model_path``, ``controlnet_path``
                and ``base_model_path``.
            device: Device the models are loaded onto.
        """
        self.config = config
        self.device = device

        self.segmentation_model = None
        self.controlnet = None
        self.pipeline = None

        self.__init_pipeline()

    def __call__(
        self,
        control_image: np.ndarray,
        prompt: str,
        negative_prompt: str,
        generate_from_mask: bool,
        num_inference_steps: int,
        guidance_scale: float,
        conditioning_scale: float,
        target_image_size: int,
        max_image_size: int,
        seed: int,
    ) -> PipelineOutput:
        """Run the image generation pipeline.

        Args:
            control_image: Input image array; passed through ``HWC3`` first.
            prompt: Text prompt forwarded to the diffusion pipeline.
            negative_prompt: Negative prompt forwarded to the pipeline.
            generate_from_mask: If true, ``control_image`` is used directly
                as the RGB control mask; otherwise a mask is produced by the
                segmentation model.
            num_inference_steps: Forwarded to the diffusion pipeline.
            guidance_scale: Forwarded to the diffusion pipeline.
            conditioning_scale: Forwarded as
                ``controlnet_conditioning_scale``.
            target_image_size: Size of the shorter output side, see
                ``adaptive_resize``.
            max_image_size: Upper bound for the longer output side.
            seed: Seed for ``torch.manual_seed``.

        Returns:
            The resized control mask and the first generated image.
        """
        # check image format
        control_image = HWC3(control_image)

        # extract segmentation mask
        if generate_from_mask:
            control_mask = Image.fromarray(control_image.astype('uint8'), 'RGB')
        else:
            segm_mask = generate_mask(
                control_image,
                self.segmentation_model,
                device=self.device,
            )
            control_mask = self.create_control_mask(segm_mask)

        # numpy shape is (height, width, ...); adaptive_resize wants (width, height)
        control_mask = self.adaptive_resize(
            image=control_mask,
            initial_shape=(control_image.shape[1], control_image.shape[0]),
            target_image_size=target_image_size,
            max_image_size=max_image_size,
        )

        # generate image
        # torch.manual_seed seeds the global RNG and returns its generator.
        generator = torch.manual_seed(seed)
        generated_image = self.pipeline(
            image=control_mask,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=conditioning_scale,
            generator=generator,
        ).images[0]

        return PipelineOutput(
            control_mask=control_mask,
            generated_image=generated_image,
        )

    def create_control_mask(self, segm_mask: np.ndarray) -> Image.Image:
        """Create an RGB control mask from a segmentation label map.

        Pixels labelled 1, 2 and 3 become pure red, green and blue
        respectively; every other label becomes black.
        """
        ch1 = (segm_mask == 1) * 255  # Upper body (red)
        ch2 = (segm_mask == 2) * 255  # Lower body (green)
        ch3 = (segm_mask == 3) * 255  # Full body (blue)

        return Image.fromarray(
            np.stack([ch1, ch2, ch3], axis=-1).astype('uint8'),
            'RGB',
        )

    def adaptive_resize(
        self,
        image: Image.Image,
        initial_shape: Tuple[int, int],
        target_image_size: int = 512,
        max_image_size: int = 768,
        divisible: int = 64,
    ) -> Image.Image:
        """Resize ``image`` to dimensions that are multiples of ``divisible``.

        The shorter side becomes ``target_image_size``. The longer side is
        scaled by the aspect ratio of ``initial_shape``, rounded down to a
        multiple of ``divisible`` and capped at ``max_image_size``. The
        aspect ratio is therefore preserved only approximately.

        Args:
            image: Image to resize.
            initial_shape: ``(width, height)`` the aspect ratio is taken from.
            target_image_size: Length of the shorter output side.
            max_image_size: Maximum length of the longer output side.
            divisible: Both output sides are multiples of this value.

        Returns:
            The resized image.

        Raises:
            AssertionError: If the size arguments are inconsistent.
            ValueError: If ``initial_shape`` has a non-positive dimension.
        """
        assert target_image_size % divisible == 0
        assert max_image_size % divisible == 0
        assert max_image_size >= target_image_size

        width, height = initial_shape
        if width <= 0 or height <= 0:
            raise ValueError(
                f"initial_shape must be positive, got {initial_shape!r}"
            )

        def fit_long_side(long_side, short_side):
            # Integer arithmetic keeps the floor exact; float division could
            # land just below a multiple of `divisible` and lose a full step.
            scaled = target_image_size * long_side // short_side
            snapped = (scaled // divisible) * divisible
            return int(min(snapped, max_image_size))

        if height > width:
            # Portrait: width is the short side.
            new_width = target_image_size
            new_height = fit_long_side(height, width)
        else:
            # Landscape or square: height is the short side, so the width
            # must grow with the aspect ratio (multiply, not divide).
            new_height = target_image_size
            new_width = fit_long_side(width, height)

        return image.resize((new_width, new_height))

    def __init_pipeline(self):
        """Init models and SDXL pipeline."""
        self.segmentation_model = load_seg_model(
            self.config.segmentation_model_path,
            device=self.device,
        )

        self.controlnet = ControlNetModel.from_pretrained(
            self.config.controlnet_path,
            torch_dtype=torch.float16,
        ).to(self.device)

        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.config.base_model_path,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        ).to(self.device)

        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )
        # NOTE(review): the models are moved to self.device above and CPU
        # offload is enabled afterwards - confirm against the diffusers
        # documentation that this combination is intended.
        self.pipeline.enable_model_cpu_offload()