Spaces:
Runtime error
Runtime error
File size: 4,695 Bytes
8ed2153 d50676a 8ed2153 9240d37 8ed2153 6cfa606 4240411 6cfa606 d91aa80 6cfa606 8ed2153 903b52c 8ed2153 4240411 36c070e 4240411 d50676a 4f8bfe3 d50676a d91aa80 4f8bfe3 6fb3a6d 8ed2153 36c070e 8ed2153 6cfa606 8ed2153 6cfa606 8ed2153 9240d37 8ed2153 9240d37 6fb3a6d d50676a 4f00f4a 903b52c 4f8bfe3 d50676a 4f8bfe3 8ed2153 903b52c 8ed2153 903b52c 8ed2153 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from dataclasses import dataclass
from typing import Tuple
from PIL import Image
import numpy as np
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
import torch
from src.preprocess import HWC3
from src.unet.predictor import generate_mask, load_seg_model
from config import PipelineConfig
@dataclass
class PipelineOutput:
    """Result of one ``FashionPipeline`` call.

    Both fields hold PIL images rather than NumPy arrays: the control mask is
    the resized PIL image that was fed to ControlNet, and the generated image
    is the first entry of the diffusers pipeline's ``.images`` list (PIL images
    under diffusers' default ``output_type="pil"``).
    """

    control_mask: "Image.Image"  # resized RGB control mask given to ControlNet
    generated_image: "Image.Image"  # first image returned by the SDXL pipeline
class FashionPipeline:
    """Segmentation-guided SDXL ControlNet image generation.

    An input image is turned into an RGB control mask, either by running the
    segmentation model or by treating the input as a ready-made mask. The mask
    is resized to diffusion-friendly dimensions and used to condition a Stable
    Diffusion XL ControlNet pipeline.
    """

    def __init__(
        self,
        config: "PipelineConfig",
        device: torch.device,
    ):
        """Store settings and eagerly load every model.

        Args:
            config: provides ``segmentation_model_path``, ``controlnet_path``
                and ``base_model_path``.
            device: device for the segmentation model. On CUDA the SDXL
                pipeline uses model CPU offloading; on any other device the
                pipeline is moved there directly.
        """
        self.config = config
        self.device = device
        self.segmentation_model = None
        self.controlnet = None
        self.pipeline = None
        self.__init_pipeline()

    def __call__(
        self,
        control_image: np.ndarray,
        prompt: str,
        negative_prompt: str,
        generate_from_mask: bool,
        num_inference_steps: int,
        guidance_scale: float,
        conditioning_scale: float,
        target_image_size: int,
        max_image_size: int,
        seed: int,
    ) -> "PipelineOutput":
        """Build a control mask from ``control_image`` and generate an image.

        Args:
            control_image: input image array; passed through ``HWC3`` first
                (presumably coerces it to 3-channel HWC -- confirm in
                ``src.preprocess``).
            prompt: text prompt for SDXL.
            negative_prompt: negative text prompt for SDXL.
            generate_from_mask: if True, ``control_image`` is used directly as
                the RGB control mask and segmentation is skipped.
            num_inference_steps: number of denoising steps.
            guidance_scale: classifier-free guidance scale.
            conditioning_scale: ControlNet conditioning strength.
            target_image_size: forwarded to ``adaptive_resize``.
            max_image_size: forwarded to ``adaptive_resize``.
            seed: seed for a dedicated RNG so results are reproducible.

        Returns:
            PipelineOutput holding the resized control mask and the generated
            image.
        """
        control_image = HWC3(control_image)

        if generate_from_mask:
            control_mask = Image.fromarray(control_image.astype('uint8'), 'RGB')
        else:
            segm_mask = generate_mask(
                control_image, self.segmentation_model, device=self.device
            )
            control_mask = self.create_control_mask(segm_mask)

        height, width = control_image.shape[:2]
        control_mask = self.adaptive_resize(
            image=control_mask,
            initial_shape=(width, height),
            target_image_size=target_image_size,
            max_image_size=max_image_size,
        )

        # A private CPU generator yields the same seeded stream as
        # torch.manual_seed without reseeding torch's global RNG.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        generated_image = self.pipeline(
            image=control_mask,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=conditioning_scale,
            generator=generator,
        ).images[0]

        return PipelineOutput(
            control_mask=control_mask,
            generated_image=generated_image,
        )

    def create_control_mask(self, segm_mask: np.ndarray) -> "Image.Image":
        """Encode a segmentation label map as an RGB control image.

        Label 1 (upper body) becomes red, label 2 (lower body) green and
        label 3 (full body) blue; every other label becomes black.
        """
        channels = [(segm_mask == label) * 255 for label in (1, 2, 3)]
        return Image.fromarray(np.stack(channels, axis=-1).astype('uint8'), 'RGB')

    def adaptive_resize(
        self,
        image: "Image.Image",
        initial_shape: Tuple[int, int],
        target_image_size: int = 512,
        max_image_size: int = 768,
        divisible: int = 64,
    ) -> "Image.Image":
        """Resize ``image`` to dimensions derived from ``initial_shape``.

        The shorter side of ``initial_shape`` becomes ``target_image_size``.
        The longer side is scaled by the same factor to keep the aspect
        ratio, rounded down to a multiple of ``divisible`` and capped at
        ``max_image_size``. The cap can distort the ratio of very elongated
        inputs. Square inputs become ``target_image_size`` squared.

        Args:
            image: image to resize (only its ``resize`` method is used).
            initial_shape: ``(width, height)`` of the original input.
            target_image_size: length of the shorter output side.
            max_image_size: upper bound on the longer output side.
            divisible: both output sides are multiples of this value.

        Returns:
            The result of ``image.resize((new_width, new_height))``.

        Raises:
            ValueError: if a size is non-positive, a size limit is not a
                multiple of ``divisible``, or
                ``max_image_size < target_image_size``.
        """
        if divisible <= 0:
            raise ValueError(f"divisible must be positive, got {divisible}")
        if target_image_size <= 0 or target_image_size % divisible:
            raise ValueError(
                f"target_image_size must be a positive multiple of {divisible}, "
                f"got {target_image_size}"
            )
        if max_image_size % divisible:
            raise ValueError(
                f"max_image_size must be a multiple of {divisible}, "
                f"got {max_image_size}"
            )
        if max_image_size < target_image_size:
            raise ValueError(
                f"max_image_size ({max_image_size}) must be >= "
                f"target_image_size ({target_image_size})"
            )

        width, height = initial_shape
        if width <= 0 or height <= 0:
            raise ValueError(f"initial_shape must be positive, got {initial_shape}")

        def snap(side):
            # Round down to a multiple of `divisible`, then apply the size cap.
            return int(min(side // divisible * divisible, max_image_size))

        # Multiply before dividing so integer shapes are computed exactly.
        if height > width:
            # Portrait: width is the shorter side.
            new_width = target_image_size
            new_height = snap(target_image_size * height // width)
        else:
            # Landscape or square: height is the shorter side, so width grows
            # by width / height.
            new_height = target_image_size
            new_width = snap(target_image_size * width // height)
        return image.resize((new_width, new_height))

    def __init_pipeline(self):
        """Load the segmentation model, ControlNet and SDXL pipeline."""
        device = torch.device(self.device)

        self.segmentation_model = load_seg_model(
            self.config.segmentation_model_path,
            device=self.device,
        )
        # Device placement is handled once for the whole pipeline below; the
        # ControlNet is registered as a pipeline component.
        self.controlnet = ControlNetModel.from_pretrained(
            self.config.controlnet_path,
            torch_dtype=torch.float16,
        )
        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.config.base_model_path,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        )
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )

        if device.type == "cuda":
            # Offloading moves each sub-model to the GPU only while it runs.
            # Moving the whole pipeline there first would load everything at
            # once and defeat the memory saving.
            self.pipeline.enable_model_cpu_offload()
        else:
            self.pipeline.to(device)
|