from .hidream_model import HidreamModel
from .src.pipelines.hidream_image.pipeline_hidream_image_editing import (
HiDreamImageEditingPipeline,
)
from .src.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from toolkit.accelerator import unwrap_model
import torch
from toolkit.prompt_utils import PromptEmbeds
from toolkit.config_modules import GenerateImageConfig
from diffusers.models import HiDreamImageTransformer2DModel
import torch.nn.functional as F
from PIL import Image
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
class HidreamE1Model(HidreamModel):
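    """HiDream E1 image-editing model wrapper.

    Extends HidreamModel to sample through HiDreamImageEditingPipeline and to
    condition training latents on a VAE-encoded control image.
    """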
arch = "hidream_e1"
hidream_transformer_class = HiDreamImageTransformer2DModel
hidream_pipeline_class = HiDreamImageEditingPipeline
def get_generation_pipeline(self):
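        """Build the editing pipeline used for sample generation.

        Wires a FlowUniPCMultistepScheduler, the four text encoder/tokenizer
        pairs, and the unwrapped transformer into a HiDreamImageEditingPipeline
        on the training device.
        """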
scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=1000, shift=3.0, use_dynamic_shifting=False
)
pipeline: HiDreamImageEditingPipeline = HiDreamImageEditingPipeline(
scheduler=scheduler,
vae=self.vae,
text_encoder=self.text_encoder[0],
tokenizer=self.tokenizer[0],
text_encoder_2=self.text_encoder[1],
tokenizer_2=self.tokenizer[1],
text_encoder_3=self.text_encoder[2],
tokenizer_3=self.tokenizer[2],
text_encoder_4=self.text_encoder[3],
tokenizer_4=self.tokenizer[3],
transformer=unwrap_model(self.model),
aggressive_unloading=self.low_vram,
)
pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
self,
pipeline: HiDreamImageEditingPipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
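        """Generate a single edited image.

        The control image from gen_config is mandatory; it is converted to RGB,
        resized to the requested width/height, and passed to the editing
        pipeline together with the T5/Llama-3 and pooled prompt embeds.
        """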
if gen_config.ctrl_img is None:
            raise ValueError(
                "Control image is required for HiDream E1 model generation."
            )
else:
control_img = Image.open(gen_config.ctrl_img)
control_img = control_img.convert("RGB")
# resize to width and height
if control_img.size != (gen_config.width, gen_config.height):
control_img = control_img.resize(
(gen_config.width, gen_config.height), Image.BILINEAR
)
img = pipeline(
prompt_embeds_t5=conditional_embeds.text_embeds[0],
prompt_embeds_llama3=conditional_embeds.text_embeds[1],
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
negative_prompt_embeds_t5=unconditional_embeds.text_embeds[0],
negative_prompt_embeds_llama3=unconditional_embeds.text_embeds[1],
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
image=control_img,
**extra,
).images[0]
return img
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
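        """Encode a prompt with all four text encoders (max_sequence_length=128)
        and return the conditional T5/Llama-3 embeds plus pooled embeds wrapped
        in a PromptEmbeds object."""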
self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
max_sequence_length = 128
(
prompt_embeds_t5,
negative_prompt_embeds_t5,
prompt_embeds_llama3,
negative_prompt_embeds_llama3,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipeline.encode_prompt(
prompt=prompt,
prompt_2=prompt,
prompt_3=prompt,
prompt_4=prompt,
device=self.device_torch,
dtype=self.torch_dtype,
num_images_per_prompt=1,
max_sequence_length=max_sequence_length,
do_classifier_free_guidance=False,
)
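        # do_classifier_free_guidance is False above, so the negative embeds
        # returned here are unused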
prompt_embeds = [prompt_embeds_t5, prompt_embeds_llama3]
pe = PromptEmbeds([prompt_embeds, pooled_prompt_embeds])
return pe
def condition_noisy_latents(
self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
):
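        """Concatenate the VAE-encoded control image onto the noisy latents.

        The control tensor is rescaled from [0, 1] to [-1, 1], resized to the
        target resolution if needed, encoded with the VAE, and appended along
        the channel dimension.
        """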
with torch.no_grad():
control_tensor = batch.control_tensor
if control_tensor is not None:
self.vae.to(self.device_torch)
                # latents are not packed at this point; just pass the control tensor through so it can be packed later
control_tensor = control_tensor * 2 - 1
control_tensor = control_tensor.to(
self.vae_device_torch, dtype=self.torch_dtype
)
                # if the control tensor does not match the target (bs, ch, h, w) size, resize it
if batch.tensor is not None:
target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
else:
# When caching latents, batch.tensor is None. We get the size from the file_items instead.
target_h = batch.file_items[0].crop_height
target_w = batch.file_items[0].crop_width
if (
control_tensor.shape[2] != target_h
or control_tensor.shape[3] != target_w
):
control_tensor = F.interpolate(
control_tensor, size=(target_h, target_w), mode="bilinear"
)
control_latent = self.encode_images(control_tensor).to(
latents.device, latents.dtype
)
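                # append the control latent along the channel dim; get_noise_prediction
                # detects the doubled (32) channel count and splits it back out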
latents = torch.cat((latents, control_latent), dim=1)
return latents.detach()
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs,
):
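        """Predict noise for training.

        When the latents carry a concatenated control image (32 channels), the
        control half is split off and placed next to the noisy latents along
        the width axis; the prediction is cropped back to the original width
        afterwards. The transformer output sign is flipped before returning.
        """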
with torch.no_grad():
            # make sure the transformer is configured to force inference-style output
self.model.config.force_inference_output = True
has_control = False
lat_size = latent_model_input.shape[-1]
if latent_model_input.shape[1] == 32:
                # split the channel-concatenated latents and place the control latent
                # next to the noisy latents along the width dimension (batch size unchanged)
lat, control = torch.chunk(latent_model_input, 2, dim=1)
latent_model_input = torch.cat([lat, control], dim=-1)
has_control = True
dtype = self.model.dtype
device = self.device_torch
text_embeds = text_embeddings.text_embeds
        # move each text embedding to the model device and dtype
text_embeds = [te.to(device, dtype=dtype) for te in text_embeds]
noise_pred = self.transformer(
hidden_states=latent_model_input,
timesteps=timestep,
encoder_hidden_states_t5=text_embeds[0],
encoder_hidden_states_llama3=text_embeds[1],
pooled_embeds=text_embeddings.pooled_embeds.to(device, dtype=dtype),
return_dict=False,
)[0]
if has_control:
noise_pred = -1.0 * noise_pred[..., :lat_size]
else:
noise_pred = -1.0 * noise_pred
return noise_pred