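"""ZeroGPU ahead-of-time (AoT) optimization for the Wan2.2 text-to-video pipeline:
quantizes the text encoder and both transformers, fuses the Lightning 4-step LoRAs,
and exports/compiles the transformers for landscape and portrait inputs.
"""
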
from typing import Any, Callable, ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig

from optimization_utils import capture_component_call, aoti_compile, ZeroGPUCompiledModel, drain_module_parameters

P = ParamSpec('P')
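
# Dim 2 of the transformer's `hidden_states` is the latent frame axis; marking it dynamic
# lets a single exported graph serve every supported video length (3 to 21 latent frames).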
TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
TRANSFORMER_DYNAMIC_SHAPES = {'hidden_states': {2: TRANSFORMER_NUM_FRAMES_DIM}}
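
# Inductor flags for the ahead-of-time compile below: aggressive autotuning plus CUDA
# graphs, trading a longer one-off compile for faster steady-state inference.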
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
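    """Optimize `pipeline` in place: quantize its models, fuse the Lightning LoRAs,
    and replace both transformers with ahead-of-time compiled variants.

    `args` and `kwargs` must form a valid pipeline call; they drive the dummy forward
    pass used to capture example inputs for `torch.export`.
    """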
print("[optimize_pipeline_] Starting pipeline optimization")
quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
print("[optimize_pipeline_] Text encoder quantized")
    @spaces.GPU(duration=1500)
    def compile_transformer():
        print("[compile_transformer] Loading LoRA weights")
        pipeline.load_lora_weights(
            "DeepBeepMeep/Wan2.2",
            weight_name="loras_accelerators/Wan2.2-Lightning_T2V-A14B-4steps-lora_HIGH_fp16.safetensors",
            adapter_name="lightning",
        )
        pipeline.load_lora_weights(
            "DeepBeepMeep/Wan2.2",
            weight_name="loras_accelerators/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
            adapter_name="lightning_2",
            load_into_transformer_2=True,
        )
        pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1.0, 1.0])
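
        # Fuse the 4-step Lightning LoRAs into the base weights (the HIGH-noise adapter
        # into `transformer`, the LOW-noise one into `transformer_2`) so the adapter
        # layers can be unloaded before export.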
print("[compile_transformer] Fusing LoRA weights")
pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3.0, components=["transformer"])
pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1.0, components=["transformer_2"])
pipeline.unload_lora_weights()
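
        # Run the pipeline once so capture_component_call can record the exact args and
        # kwargs the transformer's forward receives; they become the example inputs for
        # torch.export.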
print("[compile_transformer] Running dummy forward pass to capture component call")
with torch.inference_mode():
with capture_component_call(pipeline, 'transformer') as call:
pipeline(*args, **kwargs)
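
        # Build an all-static shape spec (every tensor/bool leaf mapped to None), then
        # mark only the latent-frame dimension of `hidden_states` as dynamic.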
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
print("[compile_transformer] Quantizing transformers with Float8DynamicActivationFloat8WeightConfig")
quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
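
        # torch.export fixes height and width, so one graph is exported per orientation:
        # the captured hidden_states covers one, and its transposed copy stands in for
        # the other.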
        hidden_states: torch.Tensor = call.kwargs['hidden_states']
        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            hidden_states_landscape = hidden_states
            hidden_states_portrait = hidden_states_transposed
        else:
            hidden_states_landscape = hidden_states_transposed
            hidden_states_portrait = hidden_states
print("[compile_transformer] Exporting transformer landscape model")
exported_landscape_1 = torch.export.export(
mod=pipeline.transformer,
args=call.args,
kwargs={**call.kwargs, 'hidden_states': hidden_states_landscape},
dynamic_shapes=dynamic_shapes,
)
torch.cuda.synchronize()
print("[compile_transformer] Exporting transformer portrait model")
exported_portrait_2 = torch.export.export(
mod=pipeline.transformer_2,
args=call.args,
kwargs={**call.kwargs, 'hidden_states': hidden_states_portrait},
dynamic_shapes=dynamic_shapes,
)
torch.cuda.synchronize()
print("[compile_transformer] Compiling models with AoT compilation")
compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)
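
        # The two transformers share one architecture, so each compiled archive can be
        # re-paired with the other model's weights: two exports cover all four
        # (transformer, orientation) combinations.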
        compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
        compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
        print("[compile_transformer] Compilation done")

        return compiled_landscape_1, compiled_landscape_2, compiled_portrait_1, compiled_portrait_2

    cl1, cl2, cp1, cp2 = compile_transformer()
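
    # Dispatch each forward call to the artifact matching the current orientation
    # (width > height means landscape).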
    def combined_transformer_1(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs['hidden_states']
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            return cl1(*args, **kwargs)
        else:
            return cp1(*args, **kwargs)

    def combined_transformer_2(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs['hidden_states']
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            return cl2(*args, **kwargs)
        else:
            return cp2(*args, **kwargs)
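
    # Swap in the compiled dispatchers and drain the eager parameters: the weights now
    # live inside the compiled artifacts, so the original copies can be freed.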
    pipeline.transformer.forward = combined_transformer_1
    drain_module_parameters(pipeline.transformer)

    pipeline.transformer_2.forward = combined_transformer_2
    drain_module_parameters(pipeline.transformer_2)

    print("[optimize_pipeline_] Optimization complete")