from functools import partial

import torch
from toolkit.prompt_utils import PromptEmbeds
from PIL import Image
from diffusers import UniPCMultistepScheduler
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from .wan22_pipeline import Wan22Pipeline
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from torchvision.transforms import functional as TF
from toolkit.models.wan21.wan21 import Wan21, AggressiveWanUnloadPipeline
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning_v22
# Scheduler config used only for generation/sampling (see get_generation_pipeline below).
scheduler_configUniPC = {
    "_class_name": "UniPCMultistepScheduler",
    "_diffusers_version": "0.35.0.dev0",
    "beta_end": 0.02,
    "beta_schedule": "linear",
    "beta_start": 0.0001,
    "disable_corrector": [],
    "dynamic_thresholding_ratio": 0.995,
    "final_sigmas_type": "zero",
    "flow_shift": 5.0,
    "lower_order_final": True,
    "num_train_timesteps": 1000,
    "predict_x0": True,
    "prediction_type": "flow_prediction",
    "rescale_betas_zero_snr": False,
    "sample_max_value": 1.0,
    "solver_order": 2,
    "solver_p": None,
    "solver_type": "bh2",
    "steps_offset": 0,
    "thresholding": False,
    "time_shift_type": "exponential",
    "timestep_spacing": "linspace",
    "trained_betas": None,
    "use_beta_sigmas": False,
    "use_dynamic_shifting": False,
    "use_exponential_sigmas": False,
    "use_flow_sigmas": True,
    "use_karras_sigmas": False,
}
# Scheduler config used for training (CustomFlowMatchEulerDiscreteScheduler).
scheduler_config = {
    "num_train_timesteps": 1000,
    "shift": 5.0,
    "use_dynamic_shifting": False,
}
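# Note: with use_dynamic_shifting=False, the stock diffusers flow-match Euler scheduler
# applies a static shift to its sigma schedule, sigma' = shift * sigma / (1 + (shift - 1) * sigma),
# so shift=5.0 biases timesteps toward the higher-noise end. This mirrors flow_shift=5.0 in the
# UniPC generation config above. (Descriptive note only; CustomFlowMatchEulerDiscreteScheduler
# may further customize this behavior.)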
# TODO: temporary monkeypatch of the condition embedder forward so the time/text embedding
# works with batch sizes greater than 1. Remove once this is fixed in the diffusers library.
def time_text_monkeypatch(
    self,
    timestep: torch.Tensor,
    encoder_hidden_states,
    encoder_hidden_states_image=None,
    timestep_seq_len=None,
):
    timestep = self.timesteps_proj(timestep)
    if timestep_seq_len is not None:
        # restore (batch_size, seq_len) using the batch size from encoder_hidden_states
        timestep = timestep.unflatten(
            0, (encoder_hidden_states.shape[0], timestep_seq_len)
        )

    time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
    if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
        timestep = timestep.to(time_embedder_dtype)
    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
    timestep_proj = self.time_proj(self.act_fn(temb))

    encoder_hidden_states = self.text_embedder(encoder_hidden_states)
    if encoder_hidden_states_image is not None:
        encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

    return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
class Wan225bModel(Wan21):
    arch = "wan22_5b"
    _wan_generation_scheduler_config = scheduler_configUniPC
    _wan_expand_timesteps = True
    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device=device,
            model_config=model_config,
            dtype=dtype,
            custom_pipeline=custom_pipeline,
            noise_scheduler=noise_scheduler,
            **kwargs,
        )
        self._wan_cache = None
    def load_model(self):
        super().load_model()
        # patch the condition embedder
        self.model.condition_embedder.forward = partial(
            time_text_monkeypatch, self.model.condition_embedder
        )
    def get_bucket_divisibility(self):
        # 16x VAE spatial compression and a 2x2 transformer patch size -> 32
        return 32
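    # Example: with a divisibility of 32, a requested 720x1280 sample rounds to 704x1280
    # (720 // 32 * 32 = 704, 1280 // 32 * 32 = 1280); generate_single_image below applies
    # this rounding when a control image is provided.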
    def get_generation_pipeline(self):
        scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
        pipeline = Wan22Pipeline(
            vae=self.vae,
            transformer=self.model,
            transformer_2=self.model,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            scheduler=scheduler,
            expand_timesteps=self._wan_expand_timesteps,
            device=self.device_torch,
            aggressive_offload=self.model_config.low_vram,
        )
        pipeline = pipeline.to(self.device_torch)
        return pipeline
    # static method to get the training scheduler
    @staticmethod
    def get_train_scheduler():
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler
    def get_base_model_version(self):
        return "wan_2.2_5b"
    def generate_single_image(
        self,
        pipeline: AggressiveWanUnloadPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # re-enable the progress bar since video generation is slow
        pipeline.set_progress_bar_config(disable=False)
        # frame count must be of the form 4n + 1 (a multiple of 4, plus 1)
        num_frames = (gen_config.num_frames - 1) // 4 * 4 + 1
        gen_config.num_frames = num_frames
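        # Worked example of the 4n + 1 rounding above: 81 -> 81, 16 -> 13, 50 -> 49.
        # (The Wan VAE compresses frames temporally in groups of 4 after the first frame,
        # so the latent frame count is (num_frames - 1) // 4 + 1.)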
        height = gen_config.height
        width = gen_config.width

        noise_mask = None
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img).convert("RGB")
            d = self.get_bucket_divisibility()
            # make sure height and width are divisible by d
            height = height // d * d
            width = width // d * d
            # resize the control image to match
            control_img = control_img.resize((width, height), Image.LANCZOS)

            # prepare latent variables
            num_channels_latents = self.transformer.config.in_channels
            latents = pipeline.prepare_latents(
                1,
                num_channels_latents,
                height,
                width,
                gen_config.num_frames,
                torch.float32,
                self.device_torch,
                generator,
                None,
            ).to(self.torch_dtype)

            # normalize the control image to [-1, 1]
            first_frame_n1p1 = (
                TF.to_tensor(control_img)
                .unsqueeze(0)
                .to(self.device_torch, dtype=self.torch_dtype)
                * 2.0
                - 1.0
            )
            gen_config.latents, noise_mask = add_first_frame_conditioning_v22(
                latent_model_input=latents,
                first_frame=first_frame_n1p1,
                vae=self.vae,
            )
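        # Note: add_first_frame_conditioning_v22 (defined in toolkit.models.wan21.wan_utils,
        # not shown here) presumably VAE-encodes the first frame and splices it into the
        # latent input, returning a noise mask that marks which latent positions should still
        # be denoised; the same helper is reused for i2v training in get_noise_prediction below.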
        output = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            height=height,
            width=width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            num_frames=gen_config.num_frames,
            generator=generator,
            return_dict=False,
            output_type="pil",
            noise_mask=noise_mask,
            **extra,
        )[0]

        # shape = [1, frames, channels, height, width]
        batch_item = output[0]  # list of PIL images
        if gen_config.num_frames > 1:
            # return the list of frames
            return batch_item
        else:
            # get just the first image
            return batch_item[0]
    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: DataLoaderBatchDTO,
        **kwargs,
    ):
        # videos come in as (bs, num_frames, channels, height, width)
        # images come in as (bs, channels, height, width)
        # for wan, only do i2v for video for now; images do normal t2i
        conditioned_latent = latent_model_input
        noise_mask = None
        if batch.dataset_config.do_i2v:
            with torch.no_grad():
                frames = batch.tensor
                if len(frames.shape) == 4:
                    first_frames = frames
                elif len(frames.shape) == 5:
                    first_frames = frames[:, 0]
                    # Add conditioning using the standalone function
                    conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
                        latent_model_input=latent_model_input.to(
                            self.device_torch, self.torch_dtype
                        ),
                        first_frame=first_frames.to(self.device_torch, self.torch_dtype),
                        vae=self.vae,
                    )
                else:
                    raise ValueError(f"Unknown frame shape {frames.shape}")
        # make the noise mask if conditioning did not provide one
        if noise_mask is None:
            noise_mask = torch.ones(
                conditioned_latent.shape,
                dtype=conditioned_latent.dtype,
                device=conditioned_latent.device,
            )

        # TODO: write this more cleanly
        # expand the scalar timestep of each batch item into per-token timesteps
        t_chunks = torch.chunk(timestep, timestep.shape[0])
        out_t_chunks = []
        for t in t_chunks:
            # seq_len = num_latent_frames * (latent_height // 2) * (latent_width // 2)
            temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
            # shape: (1, seq_len)
            temp_ts = temp_ts.unsqueeze(0)
            out_t_chunks.append(temp_ts)
        timestep = torch.cat(out_t_chunks, dim=0)
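        # After this loop, timestep has shape (batch_size, seq_len) rather than (batch_size,),
        # where the ::2, ::2 subsampling matches the transformer's 2x2 spatial patchification
        # so each patch token gets one timestep value. This per-token timestep path is what
        # requires the batch-size-aware unflatten in time_text_monkeypatch above.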
        noise_pred = self.model(
            hidden_states=conditioned_latent,
            timestep=timestep,
            encoder_hidden_states=text_embeddings.text_embeds,
            return_dict=False,
            **kwargs,
        )[0]

        return noise_pred