# wan2.2-14B-TI2V-ALL / optimization.py
# Author: rahul7star — commit 8750a83 ("Update optimization.py"), 5.36 kB.
from typing import Any, Callable, ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
from optimization_utils import capture_component_call, aoti_compile, ZeroGPUCompiledModel, drain_module_parameters
P = ParamSpec('P')
TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
TRANSFORMER_DYNAMIC_SHAPES = {'hidden_states': {2: TRANSFORMER_NUM_FRAMES_DIM}}
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
def _is_landscape(hidden_states: torch.Tensor) -> bool:
    """Return True when the last axis of ``hidden_states`` is strictly longer than the second-to-last.

    This is the single orientation predicate shared by export-time example
    selection and runtime dispatch in ``optimize_pipeline_``, so both always
    classify a tensor the same way. Equal last two dims count as portrait.
    """
    return hidden_states.shape[-1] > hidden_states.shape[-2]


def _orientation_dispatcher(
    landscape_model: Callable[..., Any],
    portrait_model: Callable[..., Any],
) -> Callable[..., Any]:
    """Return a ``forward`` replacement that routes each call by orientation.

    The returned function reads ``kwargs['hidden_states']`` (it must be passed
    by keyword, otherwise ``KeyError`` is raised) and forwards all arguments
    unchanged to ``landscape_model`` or ``portrait_model``.
    """

    def dispatch(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs['hidden_states']
        target = landscape_model if _is_landscape(hidden_states) else portrait_model
        return target(*args, **kwargs)

    return dispatch


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
    """Quantize, LoRA-fuse and AoT-compile both transformers of ``pipeline`` in place.

    Steps:
      1. Int8 weight-only quantization of ``pipeline.text_encoder``.
      2. Inside the ``@spaces.GPU(duration=1500)`` wrapper: load and fuse the
         Lightning LoRAs, run ``pipeline(*args, **kwargs)`` once to capture the
         call into ``pipeline.transformer``, fp8-quantize both transformers,
         then export and AoT-compile them.
      3. Replace ``transformer.forward`` / ``transformer_2.forward`` with
         dispatchers that pick a compiled model by hidden-state orientation,
         and drain the eager modules' parameters.

    Args:
        pipeline: object exposing ``text_encoder``, ``transformer``,
            ``transformer_2`` and the LoRA methods ``load_lora_weights``,
            ``set_adapters``, ``fuse_lora`` and ``unload_lora_weights``.
        *args, **kwargs: a representative call of ``pipeline``; the transformer
            call it makes supplies the example inputs for ``torch.export``.
    """
    print("[optimize_pipeline_] Starting pipeline optimization")
    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
    print("[optimize_pipeline_] Text encoder quantized")

    @spaces.GPU(duration=1500)
    def compile_transformer():
        """Fuse LoRAs, capture, quantize, export and AoT-compile both transformers.

        Returns ``(landscape_1, landscape_2, portrait_1, portrait_2)``; the digit
        names whose weights the model uses (1 = ``transformer``,
        2 = ``transformer_2``).
        """
        print("[compile_transformer] Loading LoRA weights")
        pipeline.load_lora_weights(
            "DeepBeepMeep/Wan2.2",
            weight_name="loras_accelerators/Wan2.2-Lightning_T2V-A14B-4steps-lora_HIGH_fp16.safetensors",
            adapter_name="lightning",
        )
        pipeline.load_lora_weights(
            "DeepBeepMeep/Wan2.2",
            weight_name="loras_accelerators/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
            adapter_name="lightning_2",
            load_into_transformer_2=True,
        )
        pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1.0, 1.0])

        print("[compile_transformer] Fusing LoRA weights")
        # HIGH adapter is fused into `transformer` at scale 3.0, LOW adapter into
        # `transformer_2` at scale 1.0; the LoRA weights are then unloaded.
        pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3.0, components=["transformer"])
        pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1.0, components=["transformer_2"])
        pipeline.unload_lora_weights()

        print("[compile_transformer] Running dummy forward pass to capture component call")
        # The call into `pipeline.transformer` made during this run is captured
        # as `call`; call.args / call.kwargs become the export example inputs.
        with torch.inference_mode():
            with capture_component_call(pipeline, 'transformer') as call:
                pipeline(*args, **kwargs)

        # Every tensor/bool kwarg is exported static (None) except
        # hidden_states, whose axis 2 uses TRANSFORMER_NUM_FRAMES_DIM.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        print("[compile_transformer] Quantizing transformers with Float8DynamicActivationFloat8WeightConfig")
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())

        # Build one example per orientation by swapping the last two axes of the
        # captured hidden_states; the same predicate is used at dispatch time.
        hidden_states: torch.Tensor = call.kwargs['hidden_states']
        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
        if _is_landscape(hidden_states):
            hidden_states_landscape, hidden_states_portrait = hidden_states, hidden_states_transposed
        else:
            hidden_states_landscape, hidden_states_portrait = hidden_states_transposed, hidden_states

        print("[compile_transformer] Exporting transformer landscape model")
        exported_landscape_1 = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs={**call.kwargs, 'hidden_states': hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )
        torch.cuda.synchronize()

        print("[compile_transformer] Exporting transformer portrait model")
        exported_portrait_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=call.args,
            kwargs={**call.kwargs, 'hidden_states': hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )
        torch.cuda.synchronize()

        print("[compile_transformer] Compiling models with AoT compilation")
        compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
        compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)
        # Only two graphs are exported/compiled; the remaining two combinations
        # pair each compiled archive with the other transformer's weights. This
        # relies on `transformer` and `transformer_2` sharing one architecture.
        compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
        compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
        print("[compile_transformer] Compilation done")
        return compiled_landscape_1, compiled_landscape_2, compiled_portrait_1, compiled_portrait_2

    cl1, cl2, cp1, cp2 = compile_transformer()

    # Route each transformer's forward to its compiled landscape/portrait pair,
    # then drain the eager module's parameters (presumably to free their
    # memory — see optimization_utils.drain_module_parameters).
    pipeline.transformer.forward = _orientation_dispatcher(cl1, cp1)
    drain_module_parameters(pipeline.transformer)
    pipeline.transformer_2.forward = _orientation_dispatcher(cl2, cp2)
    drain_module_parameters(pipeline.transformer_2)
    print("[optimize_pipeline_] Optimization complete")