Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| from PIL import Image | |
| import numpy as np | |
| from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler | |
| import torch | |
| from src.preprocess import HWC3 | |
| from src.unet.predictor import generate_mask, load_seg_model | |
| from config import PipelineConfig | |
@dataclass
class PipelineOutput:
    """Result bundle returned by ``FashionPipeline.__call__``.

    The ``@dataclass`` decorator is required: the pipeline builds this object
    with keyword arguments, which a plain annotated class would reject with a
    ``TypeError``.
    """

    # Input image after it has been passed through ``HWC3``.
    control_image: np.ndarray
    # Segmentation mask rendered as an RGB PIL image and resized to the
    # requested resolution (it is the result of ``Image.resize``).
    control_mask: "Image.Image"
    # First image produced by the diffusion pipeline.
    # NOTE(review): a PIL image under diffusers' default output type -- confirm
    # if the pipeline is ever called with a different ``output_type``.
    generated_image: "Image.Image"
class FashionPipeline:
    """Generate images with SDXL + ControlNet, conditioned on a clothing mask.

    The constructor eagerly loads three models: a segmentation model, a
    ControlNet, and the SDXL ControlNet pipeline that uses it. Calling the
    instance segments the input image, turns the segmentation into an RGB
    control image, and runs the diffusion pipeline on it.
    """

    def __init__(
        self,
        config: PipelineConfig,
        device: torch.device,
    ):
        """Store settings and load all models.

        Args:
            config: Provides ``segmentation_model_path``, ``controlnet_path``
                and ``base_model_path``.
            device: Device passed to the segmentation model loader and to
                mask generation.
        """
        self.config = config
        self.device = device

        # Placeholders so the attributes exist before loading starts.
        self.segmentation_model = None
        self.controlnet = None
        self.pipeline = None

        self.__init_pipeline()

    def __call__(
        self,
        control_image: np.ndarray,
        prompt: str,
        resolution: int = 512,
        num_inference_steps: int = 40,
        seed: int = 0,
    ) -> PipelineOutput:
        """Segment ``control_image`` and generate an image guided by the mask.

        Args:
            control_image: Input image as an array; it is normalised with
                ``HWC3`` before use.
            prompt: Text prompt forwarded to the diffusion pipeline.
            resolution: Side length of the square the mask is resized to.
            num_inference_steps: Number of denoising steps.
            seed: Seed for the random generator. The default of ``0`` matches
                the previously hard-coded value, so existing callers get the
                same results as before.

        Returns:
            A ``PipelineOutput`` holding the normalised input image, the
            resized control mask, and the first generated image.
        """
        # check image format
        control_image = HWC3(control_image)

        # extract segmentation mask
        control_mask = self.extract_mask(control_image).resize((resolution, resolution))

        # generate image
        # torch.manual_seed seeds the global RNG and returns the default
        # generator; kept as-is so results stay identical for a given seed.
        generator = torch.manual_seed(seed)
        generated_image = self.pipeline(
            image=control_mask,
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]

        return PipelineOutput(
            control_image=control_image,
            control_mask=control_mask,
            generated_image=generated_image,
        )

    def extract_mask(self, control_image: np.ndarray) -> Image.Image:
        """Run the segmentation model and render its output as an RGB image.

        Args:
            control_image: Image array handed to ``generate_mask``.

        Returns:
            A PIL image in ``RGB`` mode whose three channels are identical
            copies of the rescaled mask.
        """
        control_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
        # Replicate the single-channel mask across three channels.
        control_mask = np.stack([control_mask] * 3, axis=-1)
        # Rescale so that a mask value of 3 maps to 255, clipping anything
        # outside [0, 255].
        # NOTE(review): presumably the segmentation labels are 0..3 -- confirm
        # against the segmentation model.
        control_mask = np.clip((control_mask.astype(np.float32) / 3.0) * 255, 0, 255)
        return Image.fromarray(control_mask.astype('uint8'), 'RGB')

    def __init_pipeline(self):
        """Init models and SDXL pipeline."""
        self.segmentation_model = load_seg_model(
            self.config.segmentation_model_path,
            device=self.device,
        )

        # ControlNet and base pipeline are both loaded in half precision.
        self.controlnet = ControlNetModel.from_pretrained(
            self.config.controlnet_path,
            torch_dtype=torch.float16,
        )
        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.config.base_model_path,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        )

        # Swap in the UniPC scheduler, reusing the loaded scheduler's config.
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        # Let diffusers manage moving sub-models between CPU and accelerator.
        self.pipeline.enable_model_cpu_offload()