from __future__ import annotations

import logging

from typing import Any, Dict, Tuple, Union, Optional
from einops import rearrange, repeat
from torch import nn
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin, load_state_dict

from ..data.data_util import align_repeat_tensor_single_dim

from .unet_3d_condition import UNet3DConditionModel
from .referencenet import ReferenceNet2D
from ip_adapter.ip_adapter import ImageProjModel

logger = logging.getLogger(__name__)


class SuperUNet3DConditionModel(nn.Module):
    """封装了各种子模型的超模型,与 diffusers 的 pipeline 很像,只不过这里是模型定义。
    主要作用
    1. 将支持controlnet、referencenet等功能的计算封装起来,简洁些;
    2. 便于 accelerator 的分布式训练;

    wrap the sub-models, such as unet, referencenet, controlnet, vae, text_encoder, tokenizer, text_emb_extractor, clip_vision_extractor, ip_adapter_image_proj
    1. support controlnet, referencenet, etc.
    2. support accelerator distributed training
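
    Example (an illustrative sketch only; the sub-model constructors, the key names
    inside unet_params, and the tensor shapes shown are assumptions based on how
    forward() consumes them, not a verified recipe):

        super_unet = SuperUNet3DConditionModel(
            unet=unet,                         # e.g. UNet3DConditionModel
            referencenet=referencenet,         # optional ReferenceNet2D
            controlnet=controlnet,             # optional
            ip_adapter_image_proj=image_proj,  # optional IP-Adapter projection
        )
        out = super_unet(
            unet_params={"sample": latents, "timestep": timesteps},  # latents: b c t h w
            encoder_hidden_states=text_emb,  # b t n d
            vision_clip_emb=clip_emb,        # b t d or b t n d
            controlnet_params=controlnet_kwargs,
            controlnet_scale=1.0,
        )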
    """

    _supports_gradient_checkpointing = True
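    # Class-level counter: the debug logging in forward() only fires while this is 0,
    # i.e. on the first forward pass of an instance.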
    print_idx = 0

    # @register_to_config
    def __init__(
        self,
        unet: nn.Module,
        referencenet: Optional[nn.Module] = None,
        controlnet: Optional[nn.Module] = None,
        vae: Optional[nn.Module] = None,
        text_encoder: Optional[nn.Module] = None,
        tokenizer: Optional[nn.Module] = None,
        text_emb_extractor: Optional[nn.Module] = None,
        clip_vision_extractor: Optional[nn.Module] = None,
        ip_adapter_image_proj: Optional[nn.Module] = None,
    ) -> None:
        """_summary_

        Args:
            unet (nn.Module): _description_
            referencenet (nn.Module, optional): _description_. Defaults to None.
            controlnet (nn.Module, optional): _description_. Defaults to None.
            vae (nn.Module, optional): _description_. Defaults to None.
            text_encoder (nn.Module, optional): _description_. Defaults to None.
            tokenizer (nn.Module, optional): _description_. Defaults to None.
            text_emb_extractor (nn.Module, optional): wrap text_encoder and tokenizer for str2emb. Defaults to None.
            clip_vision_extractor (nn.Module, optional): _description_. Defaults to None.
        """
        super().__init__()
        self.unet = unet
        self.referencenet = referencenet
        self.controlnet = controlnet
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.text_emb_extractor = text_emb_extractor
        self.clip_vision_extractor = clip_vision_extractor
        self.ip_adapter_image_proj = ip_adapter_image_proj

    def forward(
        self,
        unet_params: Dict,
        encoder_hidden_states: torch.Tensor,
        referencenet_params: Optional[Dict] = None,
        controlnet_params: Optional[Dict] = None,
        controlnet_scale: float = 1.0,
        vision_clip_emb: Union[torch.Tensor, None] = None,
        prompt_only_use_image_prompt: bool = False,
    ):
        """_summary_

        Args:
            unet_params (Dict): _description_
            encoder_hidden_states (torch.Tensor): b t n d
            referencenet_params (Dict, optional): _description_. Defaults to None.
            controlnet_params (Dict, optional): _description_. Defaults to None.
            controlnet_scale (float, optional): _description_. Defaults to 1.0.
            vision_clip_emb (Union[torch.Tensor, None], optional): b t d. Defaults to None.
            prompt_only_use_image_prompt (bool, optional): _description_. Defaults to False.

        Returns:
            _type_: _description_
        """
        batch_size = unet_params["sample"].shape[0]
        time_size = unet_params["sample"].shape[2]

        # ip_adapter_cross_attn, prepare image prompt
        if vision_clip_emb is not None:
            # prepare the image prompt: b t d or b t n d -> b t n d
            if self.print_idx == 0:
                logger.debug(
                    f"vision_clip_emb, before ip_adapter_image_proj, shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
                )
            if vision_clip_emb.ndim == 3:
                vision_clip_emb = rearrange(vision_clip_emb, "b t d-> b t 1 d")
            if self.ip_adapter_image_proj is not None:
                vision_clip_emb = rearrange(vision_clip_emb, "b t n d ->(b t) n d")
                vision_clip_emb = self.ip_adapter_image_proj(vision_clip_emb)
                if self.print_idx == 0:
                    logger.debug(
                        f"vision_clip_emb, after ip_adapter_image_proj shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
                    )
                if vision_clip_emb.ndim == 2:
                    vision_clip_emb = rearrange(vision_clip_emb, "b d-> b 1 d")
                vision_clip_emb = rearrange(
                    vision_clip_emb, "(b t) n d -> b t n d", b=batch_size
                )
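            # repeat/align the image prompt along the time dim so each of the time_size frames has an embedding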
            vision_clip_emb = align_repeat_tensor_single_dim(
                vision_clip_emb, target_length=time_size, dim=1
            )
            if self.print_idx == 0:
                logger.debug(
                    f"vision_clip_emb, after reshape shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
                )

        if vision_clip_emb is None and encoder_hidden_states is not None:
            vision_clip_emb = encoder_hidden_states
        if vision_clip_emb is not None and encoder_hidden_states is None:
            encoder_hidden_states = vision_clip_emb
        # when prompt_only_use_image_prompt is True:
        # 1. referencenet always uses vision_clip_emb;
        # 2. the unet uses vision_clip_emb if it has no dual_cross_attn; sometimes it is not updated;
        # 3. controlnet currently uses the text prompt.

        # extract referencenet emb
        if self.referencenet is not None and referencenet_params is not None:
            referencenet_encoder_hidden_states = align_repeat_tensor_single_dim(
                vision_clip_emb,
                target_length=referencenet_params["num_frames"],
                dim=1,
            )
            referencenet_params["encoder_hidden_states"] = rearrange(
                referencenet_encoder_hidden_states, "b t n d->(b t) n d"
            )
            referencenet_out = self.referencenet(**referencenet_params)
            (
                down_block_refer_embs,
                mid_block_refer_emb,
                refer_self_attn_emb,
            ) = referencenet_out
            if down_block_refer_embs is not None:
                if self.print_idx == 0:
                    logger.debug(
                        f"len(down_block_refer_embs)={len(down_block_refer_embs)}"
                    )
                for i, down_emb in enumerate(down_block_refer_embs):
                    if self.print_idx == 0:
                        logger.debug(
                            f"down_emb, {i}, {down_emb.shape}, mean={down_emb.mean()}"
                        )
            else:
                if self.print_idx == 0:
                    logger.debug(f"down_block_refer_embs is None")
            if mid_block_refer_emb is not None:
                if self.print_idx == 0:
                    logger.debug(
                        f"mid_block_refer_emb, {mid_block_refer_emb.shape}, mean={mid_block_refer_emb.mean()}"
                    )
            else:
                if self.print_idx == 0:
                    logger.debug(f"mid_block_refer_emb is None")
            if refer_self_attn_emb is not None:
                if self.print_idx == 0:
                    logger.debug(f"refer_self_attn_emb, num={len(refer_self_attn_emb)}")
                for i, self_attn_emb in enumerate(refer_self_attn_emb):
                    if self.print_idx == 0:
                        logger.debug(
                            f"referencenet, self_attn_emb, {i}th, shape={self_attn_emb.shape}, mean={self_attn_emb.mean()}"
                        )
            else:
                if self.print_idx == 0:
                    logger.debug(f"refer_self_attn_emb is None")
        else:
            down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb = (
                None,
                None,
                None,
            )

        # extract controlnet emb
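        # the controlnet consumes the text prompt (encoder_hidden_states); its residuals
        # are scaled by controlnet_scale and passed to the unet below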
        if self.controlnet is not None and controlnet_params is not None:
            controlnet_encoder_hidden_states = align_repeat_tensor_single_dim(
                encoder_hidden_states,
                target_length=time_size,
                dim=1,
            )
            controlnet_params["encoder_hidden_states"] = rearrange(
                controlnet_encoder_hidden_states, " b t n d -> (b t) n d"
            )
            (
                down_block_additional_residuals,
                mid_block_additional_residual,
            ) = self.controlnet(**controlnet_params)
            if controlnet_scale != 1.0:
                down_block_additional_residuals = [
                    x * controlnet_scale for x in down_block_additional_residuals
                ]
                mid_block_additional_residual = (
                    mid_block_additional_residual * controlnet_scale
                )
            for i, down_block_additional_residual in enumerate(
                down_block_additional_residuals
            ):
                if self.print_idx == 0:
                    logger.debug(
                        f"{i}, down_block_additional_residual mean={torch.mean(down_block_additional_residual)}"
                    )

            if self.print_idx == 0:
                logger.debug(
                    f"mid_block_additional_residual mean={torch.mean(mid_block_additional_residual)}"
                )
        else:
            down_block_additional_residuals = None
            mid_block_additional_residual = None

        if prompt_only_use_image_prompt and vision_clip_emb is not None:
            encoder_hidden_states = vision_clip_emb

        # run unet
        out = self.unet(
            **unet_params,
            down_block_refer_embs=down_block_refer_embs,
            mid_block_refer_emb=mid_block_refer_emb,
            refer_self_attn_emb=refer_self_attn_emb,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
            encoder_hidden_states=encoder_hidden_states,
            vision_clip_emb=vision_clip_emb,
        )
        self.print_idx += 1
        return out

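    # Follows the diffusers ModelMixin convention (_supports_gradient_checkpointing /
    # _set_gradient_checkpointing) so gradient checkpointing can be toggled on the
    # wrapped UNet3DConditionModel / ReferenceNet2D sub-models.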
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (UNet3DConditionModel, ReferenceNet2D)):
            module.gradient_checkpointing = value