Spaces:
Runtime error
Runtime error
from dataclasses import dataclass | |
from typing import Tuple | |
from PIL import Image | |
import numpy as np | |
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler | |
import torch | |
from src.preprocess import HWC3 | |
from src.unet.predictor import generate_mask, load_seg_model | |
from config import PipelineConfig | |
@dataclass
class PipelineOutput:
    """Result of one ``FashionPipeline`` call."""

    # RGB control mask (after adaptive resizing) that conditioned the ControlNet.
    control_mask: "Image.Image"
    # First image returned by the SDXL ControlNet pipeline (diffusers returns
    # PIL images with its default output_type).
    generated_image: "Image.Image"
class FashionPipeline:
    """Segmentation-conditioned SDXL ControlNet image generator.

    On construction it loads a segmentation model, a ControlNet and an SDXL
    base pipeline from the paths in ``config``. Calling the instance turns an
    input image into a colour-coded control mask, resizes it, and generates a
    new image conditioned on that mask.
    """

    def __init__(
        self,
        config: "PipelineConfig",
        device: torch.device,
    ):
        """Store settings and eagerly load all models.

        Args:
            config: Must provide ``segmentation_model_path``,
                ``controlnet_path`` and ``base_model_path``.
            device: Passed to ``load_seg_model`` and ``generate_mask`` for the
                segmentation step. SDXL components are placed by diffusers'
                model CPU offloading instead.
        """
        self.config = config
        self.device = device
        self.segmentation_model = None
        self.controlnet = None
        self.pipeline = None
        self.__init_pipeline()

    def __call__(
        self,
        control_image: np.ndarray,
        prompt: str,
        negative_prompt: str,
        generate_from_mask: bool,
        num_inference_steps: int,
        guidance_scale: float,
        conditioning_scale: float,
        target_image_size: int,
        max_image_size: int,
        seed: int,
    ) -> "PipelineOutput":
        """Runs image generation pipeline.

        Args:
            control_image: Input image array, normalised with ``HWC3``.
            prompt: Positive text prompt for SDXL.
            negative_prompt: Negative text prompt for SDXL.
            generate_from_mask: If True, ``control_image`` is used directly as
                the RGB control mask. Otherwise a mask is predicted with the
                segmentation model and colour-coded by ``create_control_mask``.
            num_inference_steps: Forwarded to the diffusers pipeline.
            guidance_scale: Forwarded to the diffusers pipeline.
            conditioning_scale: Forwarded as ``controlnet_conditioning_scale``.
            target_image_size: Short-side size passed to ``adaptive_resize``.
            max_image_size: Long-side cap passed to ``adaptive_resize``.
            seed: Seed for the noise generator.

        Returns:
            PipelineOutput holding the resized control mask and the first
            generated image.
        """
        # Normalise the input layout (HWC3 is project code; presumably it
        # converts grayscale/RGBA to 3 channels - confirm).
        control_image = HWC3(control_image)

        if generate_from_mask:
            # The caller already supplied a colour-coded mask.
            control_mask = Image.fromarray(control_image.astype('uint8'), 'RGB')
        else:
            segm_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
            control_mask = self.create_control_mask(segm_mask)

        control_mask = self.adaptive_resize(
            image=control_mask,
            initial_shape=(control_image.shape[1], control_image.shape[0]),  # (width, height)
            target_image_size=target_image_size,
            max_image_size=max_image_size,
        )

        # A dedicated CPU generator yields the same seeded noise as
        # torch.manual_seed(seed) without reseeding torch's global RNG, so the
        # call has no hidden global side effect.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        generated_image = self.pipeline(
            image=control_mask,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=conditioning_scale,
            generator=generator,
        ).images[0]

        return PipelineOutput(
            control_mask=control_mask,
            generated_image=generated_image,
        )

    def create_control_mask(self, segm_mask: np.ndarray) -> "Image.Image":
        """Create RGB control mask from segmentation output.

        Label 1 (upper body) becomes red, label 2 (lower body) green and
        label 3 (full body) blue. Every other label stays black.
        """
        # Build each channel directly as uint8 (0 or 255) rather than as an
        # int64 array that is cast back afterwards.
        channels = [
            np.asarray(segm_mask == label, dtype=np.uint8) * 255
            for label in (1, 2, 3)
        ]
        return Image.fromarray(np.stack(channels, axis=-1).astype('uint8'), 'RGB')

    def adaptive_resize(
        self,
        image: "Image.Image",
        initial_shape: Tuple[int, int],
        target_image_size: int = 512,
        max_image_size: int = 768,
        divisible: int = 64,
    ) -> "Image.Image":
        """Resizes the image so that width and height are
        divided by 'divisible' while maintaining aspect ratio.
        Restrict image size with target_image_size and max_image_size.

        The shorter side of ``initial_shape`` becomes ``target_image_size``.
        The longer side is scaled by the aspect ratio, rounded down to a
        multiple of ``divisible`` and capped at ``max_image_size``. The
        rounding and the cap can change the aspect ratio slightly.

        Args:
            image: Image to resize; only its ``resize`` method is used.
            initial_shape: Reference size as ``(width, height)``.
            target_image_size: Size of the shorter side.
            max_image_size: Upper bound for the longer side.
            divisible: Both output sides are multiples of this value.

        Returns:
            ``image.resize((new_width, new_height))``.

        Raises:
            AssertionError: If a size limit is not a multiple of ``divisible``
                or ``max_image_size < target_image_size``.
        """
        assert target_image_size % divisible == 0
        assert max_image_size % divisible == 0
        assert max_image_size >= target_image_size

        width, height = initial_shape
        aspect_ratio = width / height

        def _snap(length: float) -> int:
            # Round down to a multiple of `divisible`, then apply the cap.
            return int(min((length // divisible) * divisible, max_image_size))

        if height > width:
            # Portrait: pin the width, derive the height.
            new_width = target_image_size
            new_height = _snap(new_width / aspect_ratio)
        else:
            # Landscape or square: pin the height, derive the width.
            # width = height * (width / height), i.e. multiply by the ratio.
            new_height = target_image_size
            new_width = _snap(new_height * aspect_ratio)
        return image.resize((new_width, new_height))

    def __init_pipeline(self):
        """Init models and SDXL pipeline.

        Loads the segmentation model via ``load_seg_model`` on
        ``self.device``. Loads the ControlNet and the SDXL base pipeline in
        float16, swaps in a UniPC scheduler and enables model CPU offloading.
        """
        self.segmentation_model = load_seg_model(
            self.config.segmentation_model_path,
            device=self.device,
        )
        self.controlnet = ControlNetModel.from_pretrained(
            self.config.controlnet_path,
            torch_dtype=torch.float16,
        )
        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.config.base_model_path,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        )
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        # Model offloading handles device placement itself. Moving the
        # ControlNet/pipeline to the GPU first (as done previously) only
        # creates a full-size GPU copy that offloading immediately moves back,
        # raising peak memory for no benefit.
        self.pipeline.enable_model_cpu_offload()