from functools import partial
import os
from typing import Any, Dict, Optional, Union, List
from typing_extensions import Self
import torch
import yaml
from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from PIL import Image
from diffusers import UniPCMultistepScheduler
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.samplers.custom_flowmatch_sampler import (
CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.util.quantize import quantize_model
from .wan22_pipeline import Wan22Pipeline
from diffusers import WanTransformer3DModel
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from torchvision.transforms import functional as TF
from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline, Wan21
from .wan22_5b_model import (
scheduler_config,
time_text_monkeypatch,
Wan225bModel,
)
from safetensors.torch import load_file, save_file
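# fraction of the 1000-step schedule above which the high-noise transformer
# (transformer_1) is used; below it the low-noise transformer (transformer_2)
# takes over. T2V uses 0.875, I2V uses 0.9.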
boundary_ratio_t2v = 0.875
boundary_ratio_i2v = 0.9
scheduler_configUniPC = {
"_class_name": "UniPCMultistepScheduler",
"_diffusers_version": "0.35.0.dev0",
"beta_end": 0.02,
"beta_schedule": "linear",
"beta_start": 0.0001,
"disable_corrector": [],
"dynamic_thresholding_ratio": 0.995,
"final_sigmas_type": "zero",
"flow_shift": 3.0,
"lower_order_final": True,
"num_train_timesteps": 1000,
"predict_x0": True,
"prediction_type": "flow_prediction",
"rescale_betas_zero_snr": False,
"sample_max_value": 1.0,
"solver_order": 2,
"solver_p": None,
"solver_type": "bh2",
"steps_offset": 0,
"thresholding": False,
"time_shift_type": "exponential",
"timestep_spacing": "linspace",
"trained_betas": None,
"use_beta_sigmas": False,
"use_dynamic_shifting": False,
"use_exponential_sigmas": False,
"use_flow_sigmas": True,
"use_karras_sigmas": False,
}
class DualWanTransformer3DModel(torch.nn.Module):
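    """
    Wraps the high-noise (transformer_1) and low-noise (transformer_2) experts
    of Wan 2.2 14B in a single module and routes each forward call to one of
    them based on the mean timestep of the batch. With low_vram enabled, only
    the active transformer is kept on the GPU; the other is parked on CPU.
    """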
def __init__(
self,
transformer_1: WanTransformer3DModel,
transformer_2: WanTransformer3DModel,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
device: Optional[Union[str, torch.device]] = None,
boundary_ratio: float = boundary_ratio_t2v,
low_vram: bool = False,
) -> None:
super().__init__()
self.transformer_1: WanTransformer3DModel = transformer_1
self.transformer_2: WanTransformer3DModel = transformer_2
self.torch_dtype: torch.dtype = torch_dtype
self.device_torch: torch.device = device
self.boundary_ratio: float = boundary_ratio
self.boundary: float = self.boundary_ratio * 1000
self.low_vram: bool = low_vram
self._active_transformer_name = "transformer_1" # default to transformer_1
@property
def device(self) -> torch.device:
return self.device_torch
@property
def dtype(self) -> torch.dtype:
return self.torch_dtype
@property
def config(self):
return self.transformer_1.config
@property
def transformer(self) -> WanTransformer3DModel:
return getattr(self, self._active_transformer_name)
def enable_gradient_checkpointing(self):
"""
Enable gradient checkpointing for both transformers.
"""
self.transformer_1.enable_gradient_checkpointing()
self.transformer_2.enable_gradient_checkpointing()
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
**kwargs
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
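        """
        Dispatch to transformer_1 (high noise) when the mean timestep is above
        the boundary, otherwise to transformer_2 (low noise), swapping devices
        first if low_vram offloading is enabled.
        """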
        # determine high noise vs. low noise by taking the mean of the timesteps.
        # timesteps are in the range of 0 to 1000, so we can use a simple threshold
with torch.no_grad():
if timestep.float().mean().item() > self.boundary:
t_name = "transformer_1"
else:
t_name = "transformer_2"
# check if we are changing the active transformer, if so, we need to swap the one in
# vram if low_vram is enabled
# todo swap the loras as well
if t_name != self._active_transformer_name:
if self.low_vram:
getattr(self, self._active_transformer_name).to("cpu")
getattr(self, t_name).to(self.device_torch)
torch.cuda.empty_cache()
self._active_transformer_name = t_name
if self.transformer.device != hidden_states.device:
if self.low_vram:
# move other transformer to cpu
other_tname = (
"transformer_1" if t_name == "transformer_2" else "transformer_2"
)
getattr(self, other_tname).to("cpu")
self.transformer.to(hidden_states.device)
return self.transformer(
hidden_states=hidden_states,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_image=encoder_hidden_states_image,
return_dict=return_dict,
attention_kwargs=attention_kwargs,
)
def to(self, *args, **kwargs) -> Self:
        # intentionally a no-op; device placement is handled per-transformer
        # (with optional low_vram offloading) rather than by .to()
return self
class Wan2214bModel(Wan21):
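    """
    Wan 2.2 14B model wrapper. Loads the high-noise and low-noise transformers
    into a DualWanTransformer3DModel and handles training, saving, and loading
    of LoRAs that may be split into separate high-noise and low-noise files.
    """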
arch = "wan22_14b"
_wan_generation_scheduler_config = scheduler_configUniPC
_wan_expand_timesteps = False
_wan_vae_path = "ai-toolkit/wan2.1-vae"
def __init__(
self,
device,
model_config: ModelConfig,
dtype="bf16",
custom_pipeline=None,
noise_scheduler=None,
**kwargs,
):
super().__init__(
device=device,
model_config=model_config,
dtype=dtype,
custom_pipeline=custom_pipeline,
noise_scheduler=noise_scheduler,
**kwargs,
)
        # target the wrapper module so LoRA layers attach to both transformers
self.target_lora_modules = ["DualWanTransformer3DModel"]
self._wan_cache = None
self.is_multistage = True
        # multistage boundaries split the timestep range between the two transformers.
        # for wan 2.2 14b, timesteps 1000-875 go to transformer_1 and 875-0 go to transformer_2
self.multistage_boundaries: List[float] = [0.875, 0.0]
self.train_high_noise = model_config.model_kwargs.get("train_high_noise", True)
self.train_low_noise = model_config.model_kwargs.get("train_low_noise", True)
self.trainable_multistage_boundaries: List[int] = []
if self.train_high_noise:
self.trainable_multistage_boundaries.append(0)
if self.train_low_noise:
self.trainable_multistage_boundaries.append(1)
if len(self.trainable_multistage_boundaries) == 0:
raise ValueError(
"At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
)
# if we are only training one or the other, the target LoRA modules will be the wan transformer class
if not self.train_high_noise or not self.train_low_noise:
self.target_lora_modules = ["WanTransformer3DModel"]
@property
def max_step_saves_to_keep_multiplier(self):
        # the save cleanup mechanism multiplies the number of saves to keep by this value.
        # when the LoRA is split into high noise and low noise files, each save step
        # produces two files, so keep twice as many
if (
self.network is not None
and self.network.network_config.split_multistage_loras
):
return 2
return 1
def load_model(self):
        # load the model via the parent class (Wan21)
        super().load_model()
# we have to split up the model on the pipeline
self.pipeline.transformer = self.model.transformer_1
self.pipeline.transformer_2 = self.model.transformer_2
# patch the condition embedder
self.model.transformer_1.condition_embedder.forward = partial(
time_text_monkeypatch, self.model.transformer_1.condition_embedder
)
self.model.transformer_2.condition_embedder.forward = partial(
time_text_monkeypatch, self.model.transformer_2.condition_embedder
)
def get_bucket_divisibility(self):
# 16x compression and 2x2 patch size
return 32
def load_wan_transformer(self, transformer_path, subfolder=None):
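        """
        Load both Wan 2.2 transformers (transformer / transformer_2), optionally
        quantizing each one, and wrap them in a DualWanTransformer3DModel.
        With low_vram enabled, each transformer is moved to CPU after loading.
        """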
if self.model_config.split_model_over_gpus:
raise ValueError(
"Splitting model over gpus is not supported for Wan2.2 models"
)
if (
self.model_config.assistant_lora_path is not None
or self.model_config.inference_lora_path is not None
):
raise ValueError(
"Assistant LoRA is not supported for Wan2.2 models currently"
)
if self.model_config.lora_path is not None:
raise ValueError(
"Loading LoRA is not supported for Wan2.2 models currently"
)
# transformer path can be a directory that ends with /transformer or a hf path.
transformer_path_1 = transformer_path
subfolder_1 = subfolder
transformer_path_2 = transformer_path
subfolder_2 = subfolder
if subfolder_2 is None:
# we have a local path, replace it with transformer_2 folder
transformer_path_2 = os.path.join(
os.path.dirname(transformer_path_1), "transformer_2"
)
else:
# we have a hf path, replace it with transformer_2 subfolder
subfolder_2 = "transformer_2"
self.print_and_status_update("Loading transformer 1")
dtype = self.torch_dtype
transformer_1 = WanTransformer3DModel.from_pretrained(
transformer_path_1,
subfolder=subfolder_1,
torch_dtype=dtype,
).to(dtype=dtype)
flush()
if not self.model_config.low_vram:
# quantize on the device
transformer_1.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
# todo handle two ARAs
self.print_and_status_update("Quantizing Transformer 1")
quantize_model(self, transformer_1)
flush()
if self.model_config.low_vram:
self.print_and_status_update("Moving transformer 1 to CPU")
transformer_1.to("cpu")
self.print_and_status_update("Loading transformer 2")
dtype = self.torch_dtype
transformer_2 = WanTransformer3DModel.from_pretrained(
transformer_path_2,
subfolder=subfolder_2,
torch_dtype=dtype,
).to(dtype=dtype)
flush()
if not self.model_config.low_vram:
# quantize on the device
transformer_2.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
# todo handle two ARAs
self.print_and_status_update("Quantizing Transformer 2")
quantize_model(self, transformer_2)
flush()
if self.model_config.low_vram:
self.print_and_status_update("Moving transformer 2 to CPU")
transformer_2.to("cpu")
# make the combined model
self.print_and_status_update("Creating DualWanTransformer3DModel")
transformer = DualWanTransformer3DModel(
transformer_1=transformer_1,
transformer_2=transformer_2,
torch_dtype=self.torch_dtype,
device=self.device_torch,
boundary_ratio=boundary_ratio_t2v,
low_vram=self.model_config.low_vram,
)
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is not None:
# apply the accuracy recovery adapter to both transformers
self.print_and_status_update("Applying Accuracy Recovery Adapter to Transformers")
quantize_model(self, transformer)
flush()
return transformer
def get_generation_pipeline(self):
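        """
        Build a Wan22Pipeline for sampling, using the UniPC scheduler and both
        transformers, with aggressive offloading when low_vram is enabled.
        """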
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
pipeline = Wan22Pipeline(
vae=self.vae,
transformer=self.model.transformer_1,
transformer_2=self.model.transformer_2,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
expand_timesteps=self._wan_expand_timesteps,
device=self.device_torch,
aggressive_offload=self.model_config.low_vram,
# todo detect if it is i2v or t2v
boundary_ratio=boundary_ratio_t2v,
)
# pipeline = pipeline.to(self.device_torch)
return pipeline
# static method to get the scheduler
@staticmethod
def get_train_scheduler():
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
return scheduler
def get_base_model_version(self):
return "wan_2.2_14b"
def generate_single_image(
self,
pipeline: AggressiveWanUnloadPipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
return super().generate_single_image(
pipeline=pipeline,
gen_config=gen_config,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
generator=generator,
extra=extra,
)
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
batch: DataLoaderBatchDTO,
**kwargs,
):
# todo do we need to override this? Adjust timesteps?
return super().get_noise_prediction(
latent_model_input=latent_model_input,
timestep=timestep,
text_embeddings=text_embeddings,
batch=batch,
**kwargs,
)
def get_model_has_grad(self):
return False
def get_te_has_grad(self):
return False
def save_model(self, output_path, meta, save_dtype):
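        """
        Save both transformers in diffusers format (transformer/ and
        transformer_2/) along with the training metadata yaml.
        """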
transformer_combo: DualWanTransformer3DModel = unwrap_model(self.model)
transformer_combo.transformer_1.save_pretrained(
save_directory=os.path.join(output_path, "transformer"),
safe_serialization=True,
)
transformer_combo.transformer_2.save_pretrained(
save_directory=os.path.join(output_path, "transformer_2"),
safe_serialization=True,
)
meta_path = os.path.join(output_path, "aitk_meta.yaml")
with open(meta_path, "w") as f:
yaml.dump(meta, f)
def save_lora(
self,
state_dict: Dict[str, torch.Tensor],
output_path: str,
metadata: Optional[Dict[str, Any]] = None,
):
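        """
        Save the LoRA state dict. If split_multistage_loras is enabled, split the
        keys into separate *_high_noise.safetensors (transformer_1) and
        *_low_noise.safetensors (transformer_2) files; otherwise save one file.
        """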
if not self.network.network_config.split_multistage_loras:
# just save as a combo lora
save_file(state_dict, output_path, metadata=metadata)
return
# we need to build out both dictionaries for high and low noise LoRAs
high_noise_lora = {}
low_noise_lora = {}
only_train_high_noise = self.train_high_noise and not self.train_low_noise
only_train_low_noise = self.train_low_noise and not self.train_high_noise
for key in state_dict:
if ".transformer_1." in key or only_train_high_noise:
# this is a high noise LoRA
new_key = key.replace(".transformer_1.", ".")
high_noise_lora[new_key] = state_dict[key]
elif ".transformer_2." in key or only_train_low_noise:
# this is a low noise LoRA
new_key = key.replace(".transformer_2.", ".")
low_noise_lora[new_key] = state_dict[key]
# loras have either LORA_MODEL_NAME_000005000.safetensors or LORA_MODEL_NAME.safetensors
if len(high_noise_lora.keys()) > 0:
# save the high noise LoRA
high_noise_lora_path = output_path.replace(
".safetensors", "_high_noise.safetensors"
)
save_file(high_noise_lora, high_noise_lora_path, metadata=metadata)
if len(low_noise_lora.keys()) > 0:
# save the low noise LoRA
low_noise_lora_path = output_path.replace(
".safetensors", "_low_noise.safetensors"
)
save_file(low_noise_lora, low_noise_lora_path, metadata=metadata)
def load_lora(self, file: str):
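        """
        Load a LoRA file. Combined LoRAs are returned as-is; split high/low noise
        files are both loaded (for the stages being trained) and their keys are
        remapped onto transformer_1 / transformer_2.
        """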
        # if the filename doesn't contain high_noise or low_noise, it is a combined LoRA
if (
"_high_noise.safetensors" not in file
and "_low_noise.safetensors" not in file
):
            # this is a combined LoRA, we don't need to split it up
sd = load_file(file)
return sd
# we may have been passed the high_noise or the low_noise LoRA path, but we need to load both
high_noise_lora_path = file.replace(
"_low_noise.safetensors", "_high_noise.safetensors"
)
low_noise_lora_path = file.replace(
"_high_noise.safetensors", "_low_noise.safetensors"
)
combined_dict = {}
if os.path.exists(high_noise_lora_path) and self.train_high_noise:
# load the high noise LoRA
high_noise_lora = load_file(high_noise_lora_path)
for key in high_noise_lora:
new_key = key.replace(
"diffusion_model.", "diffusion_model.transformer_1."
)
combined_dict[new_key] = high_noise_lora[key]
if os.path.exists(low_noise_lora_path) and self.train_low_noise:
# load the low noise LoRA
low_noise_lora = load_file(low_noise_lora_path)
for key in low_noise_lora:
new_key = key.replace(
"diffusion_model.", "diffusion_model.transformer_2."
)
combined_dict[new_key] = low_noise_lora[key]
        # if we are not training both stages, the keys won't have transformer designations
        if not (self.train_high_noise and self.train_low_noise):
new_dict = {}
for key in combined_dict:
if ".transformer_1." in key:
new_key = key.replace(".transformer_1.", ".")
elif ".transformer_2." in key:
new_key = key.replace(".transformer_2.", ".")
else:
new_key = key
new_dict[new_key] = combined_dict[key]
combined_dict = new_dict
return combined_dict
def get_model_to_train(self):
        # todo: LoRAs won't load correctly unless their keys contain transformer_1 or transformer_2.
# called when setting up the LoRA. We only need to get the model for the stages we want to train.
if self.train_high_noise and self.train_low_noise:
# we are training both stages, return the unified model
return self.model
elif self.train_high_noise:
# we are only training the high noise stage, return transformer_1
return self.model.transformer_1
elif self.train_low_noise:
# we are only training the low noise stage, return transformer_2
return self.model.transformer_2
else:
raise ValueError(
"At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
)