from dataclasses import dataclass
from typing import Tuple

from PIL import Image
import numpy as np

from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
import torch

from src.preprocess import HWC3
from src.unet.predictor import generate_mask, load_seg_model
from config import PipelineConfig


@dataclass
class PipelineOutput:
    """Container for the resized control mask and the generated image."""

    control_mask: Image.Image
    generated_image: Image.Image


class FashionPipeline:
    """SDXL + ControlNet pipeline that generates fashion images conditioned on clothing segmentation masks."""

    def __init__(
        self,
        config: PipelineConfig,
        device: torch.device,
    ):
        self.config = config
        self.device = device

        self.segmentation_model = None
        self.controlnet = None
        self.pipeline = None

        self.__init_pipeline()

    def __call__(
        self,
        control_image: np.ndarray,
        prompt: str,
        negative_prompt: str,
        generate_from_mask: bool,
        num_inference_steps: int,
        guidance_scale: float,
        conditioning_scale: float,
        target_image_size: int,
        max_image_size: int,
        seed: int,
    ) -> PipelineOutput:
        """Runs image generation pipeline."""
        # check image format
        control_image = HWC3(control_image)

        # extract segmentation mask
        if generate_from_mask:
            control_mask = Image.fromarray(control_image.astype('uint8'), 'RGB')
        else:
            segm_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
            control_mask = self.create_control_mask(segm_mask)

        control_mask = self.adaptive_resize(
            image=control_mask,
            initial_shape=(control_image.shape[1], control_image.shape[0]),
            target_image_size=target_image_size,
            max_image_size=max_image_size,
        )

        # generate image with a dedicated, seeded generator so the global RNG state is untouched
        generator = torch.Generator(device='cpu').manual_seed(seed)
        generated_image = self.pipeline(
            image=control_mask,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=conditioning_scale,
            generator=generator,
        ).images[0]

        return PipelineOutput(
            control_mask=control_mask,
            generated_image=generated_image,
        )

    def create_control_mask(self, segm_mask: np.ndarray) -> Image.Image:
        """Create an RGB control mask from the segmentation output."""
        ch1 = (segm_mask == 1) * 255  # upper-body clothes (red channel)
        ch2 = (segm_mask == 2) * 255  # lower-body clothes (green channel)
        ch3 = (segm_mask == 3) * 255  # full-body clothes (blue channel)
        return Image.fromarray(np.stack([ch1, ch2, ch3], axis=-1).astype('uint8'), 'RGB')
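    # For illustration (assuming the usual convention that label 0 is
    # background): a pixel labeled 1 becomes pure red (255, 0, 0), a pixel
    # labeled 2 pure green, 3 pure blue, and everything else stays black.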

    def adaptive_resize(
        self,
        image: Image.Image,
        initial_shape: Tuple[int, int],
        target_image_size: int = 512,
        max_image_size: int = 768,
        divisible: int = 64,
    ) -> Image.Image:
        """Resizes the image so that width and height are divisible by
        `divisible` while approximately maintaining the aspect ratio.
        The shorter side is set to `target_image_size` and the longer
        side is capped at `max_image_size`.
        """
        assert target_image_size % divisible == 0
        assert max_image_size % divisible == 0
        assert max_image_size >= target_image_size

        width, height = initial_shape
        aspect_ratio = width / height

        if height > width:
            # portrait: fix the width, derive the height
            new_width = target_image_size
            new_height = new_width / aspect_ratio
            new_height = (new_height // divisible) * divisible
            new_height = int(min(new_height, max_image_size))
        else:
            # landscape or square: fix the height, derive the width
            new_height = target_image_size
            new_width = new_height * aspect_ratio
            new_width = (new_width // divisible) * divisible
            new_width = int(min(new_width, max_image_size))

        return image.resize((new_width, new_height))
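    # Worked example: a 1024x2048 (width x height) input with the defaults
    # fixes the width at 512 and derives 512 / 0.5 = 1024 for the height,
    # which is then capped at max_image_size, so the mask is resized to
    # 512x768. Note that the cap can slightly change the aspect ratio.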

    def __init_pipeline(self):
        """Init the segmentation model, ControlNet and the SDXL pipeline."""
        self.segmentation_model = load_seg_model(
            self.config.segmentation_model_path,
            device=self.device,
        )

        self.controlnet = ControlNetModel.from_pretrained(
            self.config.controlnet_path,
            torch_dtype=torch.float16,
        )

        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.config.base_model_path,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        )

        # swap in the faster UniPC scheduler, keeping the original scheduler config
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)

        # model offloading manages device placement itself (loading each
        # sub-model onto the GPU only while it runs), so the pipeline is
        # deliberately not moved to `self.device` beforehand
        self.pipeline.enable_model_cpu_offload()
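

# Example usage: a minimal sketch, not part of the pipeline itself. The
# checkpoint paths and the input file name below are hypothetical, and it is
# assumed that PipelineConfig can be constructed from the three path fields
# read in __init_pipeline; adjust both to the actual definitions in config.py.
if __name__ == "__main__":
    config = PipelineConfig(
        segmentation_model_path="checkpoints/cloth_segm.pth",  # hypothetical path
        controlnet_path="checkpoints/controlnet",              # hypothetical path
        base_model_path="stabilityai/stable-diffusion-xl-base-1.0",
    )
    pipeline = FashionPipeline(config, device=torch.device("cuda"))

    photo = np.asarray(Image.open("person.jpg").convert("RGB"))  # hypothetical input
    output = pipeline(
        control_image=photo,
        prompt="a model wearing a red evening dress, studio photo",
        negative_prompt="low quality, blurry, deformed",
        generate_from_mask=False,
        num_inference_steps=30,
        guidance_scale=7.5,
        conditioning_scale=1.0,
        target_image_size=512,
        max_image_size=768,
        seed=42,
    )
    output.generated_image.save("generated.png")
    output.control_mask.save("control_mask.png")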