from .hidream_model import HidreamModel
from .src.pipelines.hidream_image.pipeline_hidream_image_editing import (
    HiDreamImageEditingPipeline,
)
from .src.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from toolkit.accelerator import unwrap_model
import torch
from toolkit.prompt_utils import PromptEmbeds
from toolkit.config_modules import GenerateImageConfig
from diffusers.models import HiDreamImageTransformer2DModel

import torch.nn.functional as F
from PIL import Image
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO


class HidreamE1Model(HidreamModel):
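    """HiDream E1 (image editing) variant of HidreamModel.

    Uses HiDreamImageEditingPipeline for sampling and conditions both
    sampling and training on a control (source) image, whose latents are
    stacked onto the noisy latents along the channel dimension.
    """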
    arch = "hidream_e1"
    hidream_transformer_class = HiDreamImageTransformer2DModel
    hidream_pipeline_class = HiDreamImageEditingPipeline

    def get_generation_pipeline(self):
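        """Assemble the HiDream image editing pipeline used for sampling.

        Wires the four text encoder/tokenizer pairs, the VAE, and the
        unwrapped transformer to a FlowUniPCMultistepScheduler
        (shift=3.0, dynamic shifting disabled).
        """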
        scheduler = FlowUniPCMultistepScheduler(
            num_train_timesteps=1000, shift=3.0, use_dynamic_shifting=False
        )

        pipeline: HiDreamImageEditingPipeline = HiDreamImageEditingPipeline(
            scheduler=scheduler,
            vae=self.vae,
            text_encoder=self.text_encoder[0],
            tokenizer=self.tokenizer[0],
            text_encoder_2=self.text_encoder[1],
            tokenizer_2=self.tokenizer[1],
            text_encoder_3=self.text_encoder[2],
            tokenizer_3=self.tokenizer[2],
            text_encoder_4=self.text_encoder[3],
            tokenizer_4=self.tokenizer[3],
            transformer=unwrap_model(self.model),
            aggressive_unloading=self.low_vram,
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: HiDreamImageEditingPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
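        """Generate one edited image with the editing pipeline.

        A control image (gen_config.ctrl_img) is required: it is loaded,
        converted to RGB, resized to the target resolution if needed, and
        passed to the pipeline as the source image to edit.
        """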
        if gen_config.ctrl_img is None:
            raise ValueError(
                "Control image is required for HiDream E1 model generation."
            )
        else:
            control_img = Image.open(gen_config.ctrl_img)
            control_img = control_img.convert("RGB")
            # resize to width and height
            if control_img.size != (gen_config.width, gen_config.height):
                control_img = control_img.resize(
                    (gen_config.width, gen_config.height), Image.BILINEAR
                )
        img = pipeline(
            prompt_embeds_t5=conditional_embeds.text_embeds[0],
            prompt_embeds_llama3=conditional_embeds.text_embeds[1],
            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
            negative_prompt_embeds_t5=unconditional_embeds.text_embeds[0],
            negative_prompt_embeds_llama3=unconditional_embeds.text_embeds[1],
            negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            image=control_img,
            **extra,
        ).images[0]
        return img

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
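        """Encode a prompt with all four text encoders.

        Returns a PromptEmbeds whose text_embeds is [T5 embeds, Llama3
        embeds] and whose pooled_embeds is the pooled embedding, which is
        the layout generate_single_image and get_noise_prediction expect.
        """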
        self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
        max_sequence_length = 128
        (
            prompt_embeds_t5,
            negative_prompt_embeds_t5,
            prompt_embeds_llama3,
            negative_prompt_embeds_llama3,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            prompt_3=prompt,
            prompt_4=prompt,
            device=self.device_torch,
            dtype=self.torch_dtype,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
            do_classifier_free_guidance=False,
        )
        prompt_embeds = [prompt_embeds_t5, prompt_embeds_llama3]
        pe = PromptEmbeds([prompt_embeds, pooled_prompt_embeds])
        return pe

    def condition_noisy_latents(
        self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
    ):
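        """Append control-image latents to the noisy latents for training.

        The batch's control tensor is rescaled to [-1, 1], resized to the
        target resolution if needed, encoded with the VAE, and concatenated
        to the noisy latents along the channel dimension so that
        get_noise_prediction can split it back out.
        """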
        with torch.no_grad():
            control_tensor = batch.control_tensor
            if control_tensor is not None:
                self.vae.to(self.device_torch)
                # nothing is packed at this point; rescale the control image from [0, 1] to [-1, 1] before encoding so it can be packed later
                control_tensor = control_tensor * 2 - 1
                control_tensor = control_tensor.to(
                    self.vae_device_torch, dtype=self.torch_dtype
                )

                # if the control tensor does not match the target (bs, ch, h, w) size, resize it to match
                if batch.tensor is not None:
                    target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
                else:
                    # When caching latents, batch.tensor is None. We get the size from the file_items instead.
                    target_h = batch.file_items[0].crop_height
                    target_w = batch.file_items[0].crop_width

                if (
                    control_tensor.shape[2] != target_h
                    or control_tensor.shape[3] != target_w
                ):
                    control_tensor = F.interpolate(
                        control_tensor, size=(target_h, target_w), mode="bilinear"
                    )

                control_latent = self.encode_images(control_tensor).to(
                    latents.device, latents.dtype
                )
                latents = torch.cat((latents, control_latent), dim=1)

        return latents.detach()

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs,
    ):
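        """Run the transformer and return the (negated) noise prediction.

        When the latents carry stacked control channels (32 channels), the
        control latent is split off and placed next to the noisy latent
        along the width dimension before the forward pass; only the left
        half of the prediction is kept afterwards. The raw transformer
        output is negated before being returned.
        """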
        with torch.no_grad():
            # make sure the transformer is configured to return inference-style output
            self.model.config.force_inference_output = True
            has_control = False
            lat_size = latent_model_input.shape[-1]
            if latent_model_input.shape[1] == 32:
                # split off the channel-stacked control latent and concatenate it to the
                # noisy latent along the width dimension, keeping the batch size unchanged
                lat, control = torch.chunk(latent_model_input, 2, dim=1)
                latent_model_input = torch.cat([lat, control], dim=-1)
                has_control = True

        dtype = self.model.dtype
        device = self.device_torch

        text_embeds = text_embeddings.text_embeds
        # move each text embedding in the list to the model device and dtype
        text_embeds = [te.to(device, dtype=dtype) for te in text_embeds]

        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            timesteps=timestep,
            encoder_hidden_states_t5=text_embeds[0],
            encoder_hidden_states_llama3=text_embeds[1],
            pooled_embeds=text_embeddings.pooled_embeds.to(device, dtype=dtype),
            return_dict=False,
        )[0]

        if has_control:
            noise_pred = -1.0 * noise_pred[..., :lat_size]
        else:
            noise_pred = -1.0 * noise_pred

        return noise_pred