# fashion_controlnet/src/pipeline.py
from dataclasses import dataclass
from typing import Tuple

import numpy as np
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, UniPCMultistepScheduler
from PIL import Image

from config import PipelineConfig
from src.preprocess import HWC3
from src.unet.predictor import generate_mask, load_seg_model


@dataclass
class PipelineOutput:
    """Container for the resized control mask and the generated image."""

    control_mask: Image.Image
    generated_image: Image.Image


class FashionPipeline:
    """SDXL + ControlNet pipeline that generates fashion images conditioned on clothing segmentation masks."""

    def __init__(
        self,
        config: PipelineConfig,
        device: torch.device,
    ):
        self.config = config
        self.device = device
        self.segmentation_model = None
        self.controlnet = None
        self.pipeline = None
        self.__init_pipeline()
    def __call__(
        self,
        control_image: np.ndarray,
        prompt: str,
        negative_prompt: str,
        generate_from_mask: bool,
        num_inference_steps: int,
        guidance_scale: float,
        conditioning_scale: float,
        target_image_size: int,
        max_image_size: int,
        seed: int,
    ) -> PipelineOutput:
        """Runs the image generation pipeline and returns the control mask together with the generated image."""
        # Ensure a 3-channel HWC uint8 image.
        control_image = HWC3(control_image)

        if generate_from_mask:
            # The input is already an RGB control mask.
            control_mask = Image.fromarray(control_image.astype('uint8'), 'RGB')
        else:
            # Extract a segmentation mask and convert it to an RGB control mask.
            segm_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
            control_mask = self.create_control_mask(segm_mask)

        control_mask = self.adaptive_resize(
            image=control_mask,
            initial_shape=(control_image.shape[1], control_image.shape[0]),  # (width, height)
            target_image_size=target_image_size,
            max_image_size=max_image_size,
        )

        # Generate an image conditioned on the control mask; the seeded
        # generator makes the result reproducible.
        generator = torch.manual_seed(seed)
        generated_image = self.pipeline(
            image=control_mask,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=conditioning_scale,
            generator=generator,
        ).images[0]

        return PipelineOutput(
            control_mask=control_mask,
            generated_image=generated_image,
        )
    def create_control_mask(self, segm_mask: np.ndarray) -> Image.Image:
        """Creates an RGB control mask from the segmentation output."""
        ch1 = (segm_mask == 1) * 255  # Upper body (red channel).
        ch2 = (segm_mask == 2) * 255  # Lower body (green channel).
        ch3 = (segm_mask == 3) * 255  # Full body (blue channel).
        return Image.fromarray(np.stack([ch1, ch2, ch3], axis=-1).astype('uint8'), 'RGB')
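    # Example of the mapping above (label values follow the comments: 1 = upper
    # body, 2 = lower body, 3 = full body, anything else is background): a pixel
    # labeled 2 becomes pure green (0, 255, 0) in the control mask, while
    # background pixels stay black (0, 0, 0).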
    def adaptive_resize(
        self,
        image: Image.Image,
        initial_shape: Tuple[int, int],
        target_image_size: int = 512,
        max_image_size: int = 768,
        divisible: int = 64,
    ) -> Image.Image:
        """Resizes the image so that width and height are divisible by
        `divisible` while approximately preserving the aspect ratio:
        the shorter side is set to `target_image_size`, the longer side
        is scaled to match, floored to a multiple of `divisible`, and
        clamped at `max_image_size`.
        """
        assert target_image_size % divisible == 0
        assert max_image_size % divisible == 0
        assert max_image_size >= target_image_size

        width, height = initial_shape
        aspect_ratio = width / height

        if height > width:
            # Portrait: fix the width, scale the height.
            new_width = target_image_size
            new_height = new_width / aspect_ratio
            new_height = (new_height // divisible) * divisible
            new_height = int(min(new_height, max_image_size))
        else:
            # Landscape or square: fix the height, scale the width.
            new_height = target_image_size
            new_width = new_height * aspect_ratio
            new_width = (new_width // divisible) * divisible
            new_width = int(min(new_width, max_image_size))

        return image.resize((new_width, new_height))
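    # Worked example for adaptive_resize (hypothetical sizes): a 600x900
    # portrait with target 512 / max 768 fixes the width at 512 and scales
    # the height to 512 / (600 / 900) = 768, already a multiple of 64 and
    # equal to the cap, giving (512, 768). A 1024x768 landscape fixes the
    # height at 512 and scales the width to 512 * (1024 / 768) = 682.67,
    # floored to 640, giving (640, 512).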
    def __init_pipeline(self):
        """Initializes the segmentation model, the ControlNet, and the SDXL pipeline."""
        self.segmentation_model = load_seg_model(
            self.config.segmentation_model_path,
            device=self.device,
        )
        self.controlnet = ControlNetModel.from_pretrained(
            self.config.controlnet_path,
            torch_dtype=torch.float16,
        )
        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.config.base_model_path,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        )
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        # Model CPU offload manages device placement itself, so the pipeline
        # is not moved to the device with `.to()` beforehand.
        self.pipeline.enable_model_cpu_offload()
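
# A minimal usage sketch, not part of the original module. The config field
# names match how they are used above, assuming PipelineConfig accepts them
# as constructor arguments; the concrete values (paths, model ids, prompts)
# are placeholders.
if __name__ == "__main__":
    config = PipelineConfig(
        segmentation_model_path="checkpoints/cloth_segm.pth",  # placeholder path
        controlnet_path="path/to/fashion_controlnet",  # placeholder path
        base_model_path="stabilityai/stable-diffusion-xl-base-1.0",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fashion_pipeline = FashionPipeline(config=config, device=device)

    # Any HWC uint8 RGB image works as input.
    image = np.array(Image.open("person.jpg").convert("RGB"))
    output = fashion_pipeline(
        control_image=image,
        prompt="a red silk evening dress, studio photo",
        negative_prompt="low quality, deformed",
        generate_from_mask=False,
        num_inference_steps=30,
        guidance_scale=7.5,
        conditioning_scale=0.8,
        target_image_size=512,
        max_image_size=768,
        seed=42,
    )
    output.generated_image.save("generated.png")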