# fashion_controlnet / src/pipeline.py
# Author: dragynir — commit 8ed2153 ("add app")
from dataclasses import dataclass
from PIL import Image
import numpy as np
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
import torch
from src.preprocess import HWC3
from src.unet.predictor import generate_mask, load_seg_model
from config import PipelineConfig
@dataclass
class PipelineOutput:
    """Result of one ``FashionPipeline`` call.

    Attributes:
        control_image: Input image after ``HWC3`` normalization.
        control_mask: RGB clothes-segmentation mask (a PIL image produced by
            ``FashionPipeline.extract_mask`` and resized) used as the
            ControlNet condition.
        generated_image: First image returned by the diffusers pipeline
            (a PIL image with diffusers' default ``output_type``).
    """

    control_image: np.ndarray
    # String annotations: PIL is imported at module level, but quoting keeps the
    # dataclass definition free of a runtime dependency on that name.
    control_mask: "Image.Image"
    generated_image: "Image.Image"
class FashionPipeline:
    """Clothes-segmentation-conditioned SDXL ControlNet generation pipeline.

    At construction time it loads a segmentation model, a ``ControlNetModel``
    and a ``StableDiffusionXLControlNetPipeline`` built on that ControlNet.
    Calling the instance segments the input image and then generates a new
    image conditioned on the resulting mask.
    """

    def __init__(
        self,
        config: PipelineConfig,
        device: torch.device,
    ):
        """Load all models.

        Args:
            config: Provides ``segmentation_model_path``, ``controlnet_path``
                and ``base_model_path``.
            device: Device passed to the segmentation model loader and to
                ``generate_mask``. NOTE(review): the SDXL pipeline does not
                receive this value; its placement is handled by
                ``enable_model_cpu_offload()``.
        """
        self.config = config
        self.device = device
        self.segmentation_model = None
        self.controlnet = None
        self.pipeline = None
        self.__init_pipeline()

    def __call__(
        self,
        control_image: np.ndarray,
        prompt: str,
        resolution: int = 512,
        num_inference_steps: int = 40,
        seed: int = 0,
    ) -> PipelineOutput:
        """Segment ``control_image`` and generate an image conditioned on the mask.

        Args:
            control_image: Input image; normalized by ``HWC3`` before use.
            prompt: Text prompt for the SDXL pipeline.
            resolution: Side length in pixels of the square control mask.
                Height/width are not passed to the pipeline explicitly, so the
                output size presumably follows the mask size — confirm against
                the diffusers version in use.
            num_inference_steps: Number of denoising steps.
            seed: Seed for a private ``torch.Generator``. The default of 0
                reproduces the previous fixed-seed behaviour.

        Returns:
            PipelineOutput holding the normalized input, the mask that was fed
            to the ControlNet, and the first generated image.
        """
        # Normalize the input image format before segmentation.
        control_image = HWC3(control_image)

        # Extract the clothes mask and bring it to the requested square size.
        # NOTE(review): Pillow's default resampling filter interpolates across
        # label boundaries; consider Image.NEAREST if crisp labels matter —
        # confirm against how the ControlNet was trained.
        control_mask = self.extract_mask(control_image).resize((resolution, resolution))

        # A dedicated generator makes the result reproducible for a given seed
        # without reseeding torch's global RNG (which ``torch.manual_seed`` does
        # as a side effect). Both are CPU generators yielding the same stream.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        generated_image = self.pipeline(
            image=control_mask,
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]

        return PipelineOutput(
            control_image=control_image,
            control_mask=control_mask,
            generated_image=generated_image,
        )

    def extract_mask(self, control_image: np.ndarray) -> Image.Image:
        """Run the segmentation model and return the clothes mask as an RGB image.

        The mask from ``generate_mask`` is replicated into 3 channels, scaled
        by 255/3 (so a value of 3 maps to 255), clipped to [0, 255] and
        converted to ``uint8``.
        NOTE(review): assumes ``generate_mask`` returns an (H, W) array of
        class ids in 0..3 — confirm against ``src.unet.predictor``.
        """
        control_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
        control_mask = np.stack([control_mask] * 3, axis=-1)
        control_mask = np.clip((control_mask.astype(np.float32) / 3.0) * 255, 0, 255)
        return Image.fromarray(control_mask.astype('uint8'), "RGB")

    def __init_pipeline(self):
        """Load the segmentation model, the ControlNet and the SDXL pipeline."""
        self.segmentation_model = load_seg_model(
            self.config.segmentation_model_path,
            device=self.device,
        )
        # Both diffusion components are loaded in half precision.
        self.controlnet = ControlNetModel.from_pretrained(
            self.config.controlnet_path,
            torch_dtype=torch.float16,
        )
        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.config.base_model_path,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        )
        # Swap the default scheduler for UniPC, reusing the existing scheduler config.
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        self.pipeline.enable_model_cpu_offload()