Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| from typing import Tuple | |
| from PIL import Image | |
| import numpy as np | |
| from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler | |
| import torch | |
| from src.preprocess import HWC3 | |
| from src.unet.predictor import generate_mask, load_seg_model | |
| from config import PipelineConfig | |
@dataclass
class PipelineOutput:
    """Result of one FashionPipeline call.

    Without ``@dataclass`` the class has no keyword constructor, so
    ``PipelineOutput(control_mask=..., generated_image=...)`` would raise
    ``TypeError``.
    """

    # Resized control mask that was fed to the ControlNet.
    # NOTE(review): FashionPipeline currently stores a PIL image here (built via
    # Image.fromarray), not an np.ndarray — confirm the intended type.
    control_mask: np.ndarray
    # First image returned by the diffusion pipeline.
    # NOTE(review): presumably a PIL image (diffusers' default output type) —
    # confirm against callers.
    generated_image: np.ndarray
class FashionPipeline:
    """Segmentation-guided SDXL ControlNet image generator.

    An input image is turned into an RGB control mask. The mask comes either
    from a segmentation model or is the input itself. It is resized and fed to
    a ``StableDiffusionXLControlNetPipeline``.
    """

    def __init__(
        self,
        config: PipelineConfig,
        device: torch.device,
    ):
        """Store settings and load all models.

        Args:
            config: Supplies ``segmentation_model_path``, ``controlnet_path``
                and ``base_model_path``.
            device: Device passed to ``load_seg_model`` and ``generate_mask``.
        """
        self.config = config
        self.device = device
        self.segmentation_model = None
        self.controlnet = None
        self.pipeline = None
        self.__init_pipeline()

    def __call__(
        self,
        control_image: np.ndarray,
        prompt: str,
        negative_prompt: str,
        generate_from_mask: bool,
        num_inference_steps: int,
        guidance_scale: float,
        conditioning_scale: float,
        guess_mode: bool,
        target_image_size: int,
        max_image_size: int,
        seed: int,
    ) -> PipelineOutput:
        """Generate an image conditioned on a mask derived from ``control_image``.

        Args:
            control_image: Input image. When ``generate_from_mask`` is True it
                is used directly as the control mask; otherwise it is segmented.
            prompt: Text prompt for the diffusion pipeline.
            negative_prompt: Negative text prompt.
            generate_from_mask: Skip segmentation and treat the input as a mask.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            conditioning_scale: Forwarded as ``controlnet_conditioning_scale``.
            guess_mode: Forwarded to the pipeline's ``guess_mode``.
            target_image_size: Size of the short output side (see
                ``adaptive_resize``).
            max_image_size: Upper bound for the long output side.
            seed: Seed for the sampling generator. Equal seeds with equal
                inputs give reproducible results.

        Returns:
            PipelineOutput with the resized control mask and the first image
            produced by the pipeline.
        """
        # Normalize the input to a 3-channel HWC array.
        # NOTE(review): assumes HWC3 returns a uint8 (H, W, 3) array — TODO confirm.
        control_image = HWC3(control_image)

        if generate_from_mask:
            # The input already is the control mask; adaptive_resize converts
            # the ndarray to a PIL image before resizing.
            control_mask = control_image
        else:
            segm_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
            control_mask = self.create_control_mask(segm_mask)

        height, width = control_image.shape[:2]
        control_mask = self.adaptive_resize(
            image=control_mask,
            initial_shape=(width, height),
            target_image_size=target_image_size,
            max_image_size=max_image_size,
        )

        # Use a dedicated CPU generator seeded from the caller's `seed`.
        # The old code always used seed 0 and leaked it into the global RNG.
        # A stray `seed=` keyword the pipeline does not define was also dropped.
        generator = torch.Generator(device="cpu").manual_seed(int(seed))
        generated_image = self.pipeline(
            image=control_mask,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=conditioning_scale,
            guess_mode=guess_mode,
            generator=generator,
        ).images[0]

        return PipelineOutput(
            control_mask=control_mask,
            generated_image=generated_image,
        )

    def create_control_mask(self, segm_mask: np.ndarray) -> Image.Image:
        """Create an RGB control mask from a segmentation label map.

        Pixels labelled 1 (upper body) become red, 2 (lower body) green and
        3 (full body) blue; all other labels stay black.
        """
        channels = [(segm_mask == label) * 255 for label in (1, 2, 3)]
        return Image.fromarray(np.stack(channels, axis=-1).astype('uint8'), 'RGB')

    def adaptive_resize(
        self,
        image: Image.Image,
        initial_shape: Tuple[int, int],
        target_image_size: int = 512,
        max_image_size: int = 768,
        divisible: int = 64,
    ) -> Image.Image:
        """Resize ``image`` to the aspect ratio of ``initial_shape``.

        The short side becomes ``target_image_size``. The long side is scaled
        by the aspect ratio, floored to a multiple of ``divisible`` and capped
        at ``max_image_size``.

        Args:
            image: PIL image or numpy array to resize. Arrays are converted
                with ``Image.fromarray``.
            initial_shape: Reference ``(width, height)`` whose aspect ratio is
                kept.
            target_image_size: Size of the short side; multiple of ``divisible``.
            max_image_size: Cap for the long side; multiple of ``divisible``.
            divisible: Granularity of the long side.

        Returns:
            The resized PIL image.

        Raises:
            AssertionError: If the size arguments violate the constraints above.
            ValueError: If ``initial_shape`` has a non-positive side.
        """
        assert target_image_size % divisible == 0
        assert max_image_size % divisible == 0
        assert max_image_size >= target_image_size

        width, height = initial_shape
        if width <= 0 or height <= 0:
            raise ValueError(f"initial_shape must have positive sides, got {initial_shape}")

        if isinstance(image, np.ndarray):
            # ndarray.resize is an in-place reshape returning None; use PIL instead.
            image = Image.fromarray(image)

        aspect_ratio = width / height
        if height > width:
            # Portrait: width is the short side.
            new_width = target_image_size
            new_height = ((new_width / aspect_ratio) // divisible) * divisible
            new_height = int(min(new_height, max_image_size))
        else:
            # Landscape or square: height is the short side, so the width grows
            # by the aspect ratio (previously divided, shrinking landscape output).
            new_height = target_image_size
            new_width = ((new_height * aspect_ratio) // divisible) * divisible
            new_width = int(min(new_width, max_image_size))
        return image.resize((new_width, new_height))

    def __init_pipeline(self):
        """Load the segmentation model, the ControlNet and the SDXL pipeline."""
        self.segmentation_model = load_seg_model(
            self.config.segmentation_model_path,
            device=self.device,
        )
        self.controlnet = ControlNetModel.from_pretrained(
            self.config.controlnet_path,
            torch_dtype=torch.float16,
        )
        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.config.base_model_path,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        )
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )
        # The pipeline is never moved to self.device explicitly; model CPU
        # offload lets diffusers place sub-models on the accelerator on demand.
        self.pipeline.enable_model_cpu_offload()