File size: 5,363 Bytes
8750a83
805097b
 
 
8750a83
 
d83fb5a
8750a83
d83fb5a
 
8750a83
d83fb5a
 
 
 
 
 
 
 
 
dc155d4
 
58fd9c6
 
8750a83
 
dc155d4
 
 
58fd9c6
ac9d3e7
58fd9c6
 
b6730f2
048bf77
b34c056
58fd9c6
 
 
 
b34c056
58fd9c6
ac9d3e7
58fd9c6
 
 
048bf77
58fd9c6
8750a83
58fd9c6
 
 
 
8750a83
 
dc155d4
8750a83
 
 
09a6fb7
8750a83
 
6e8eb03
8750a83
 
 
 
 
 
 
 
 
 
 
 
 
3faf8ae
8750a83
 
 
 
 
 
 
 
3faf8ae
8750a83
 
 
 
 
58fd9c6
8750a83
 
dc155d4
8750a83
 
dc155d4
8750a83
6e8eb03
8750a83
 
 
 
 
 
dc155d4
8750a83
 
 
 
 
 
58fd9c6
8750a83
 
 
 
6e8eb03
8750a83
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import Any, Callable, ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
from optimization_utils import capture_component_call, aoti_compile, ZeroGPUCompiledModel, drain_module_parameters

P = ParamSpec('P')

TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
TRANSFORMER_DYNAMIC_SHAPES = {'hidden_states': {2: TRANSFORMER_NUM_FRAMES_DIM}}

INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}

def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Quantize and ahead-of-time compile the pipeline's transformers in place.

    Steps, in order:

    1. Quantize ``pipeline.text_encoder`` with ``Int8WeightOnlyConfig``.
    2. Inside a ``spaces.GPU`` task: load the two Lightning LoRA files, fuse
       them into ``transformer`` / ``transformer_2`` and unload them; run the
       pipeline once with the given arguments to capture the call made to the
       transformer; quantize both transformers with
       ``Float8DynamicActivationFloat8WeightConfig``; export ``transformer``
       with a landscape ``hidden_states`` example and ``transformer_2`` with a
       portrait one; compile both exports with ``aoti_compile``; and pair each
       compiled archive with the other model's weights so that both
       transformers have a landscape and a portrait variant.
    3. Replace ``forward`` on both transformers with a dispatcher that picks
       the landscape variant when the last dimension of ``hidden_states`` is
       larger than the second-to-last one, and the portrait variant
       otherwise, then call ``drain_module_parameters`` on each transformer.

    Args:
        pipeline: Callable pipeline exposing ``text_encoder``, ``transformer``
            and ``transformer_2``; it is mutated in place.
        *args: Positional arguments for the example pipeline call.
        **kwargs: Keyword arguments for the example pipeline call.

    Note:
        The installed dispatchers read ``hidden_states`` from their keyword
        arguments, so it has to be passed by keyword.
    """
    print("[optimize_pipeline_] Starting pipeline optimization")

    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
    print("[optimize_pipeline_] Text encoder quantized")

    @spaces.GPU(duration=1500)
    def compile_transformer():
        """Fuse LoRAs, quantize, export and compile; return the four compiled models."""
        print("[compile_transformer] Loading LoRA weights")
        lora_repo = "DeepBeepMeep/Wan2.2"
        pipeline.load_lora_weights(
            lora_repo,
            weight_name="loras_accelerators/Wan2.2-Lightning_T2V-A14B-4steps-lora_HIGH_fp16.safetensors",
            adapter_name="lightning",
        )
        pipeline.load_lora_weights(
            lora_repo,
            weight_name="loras_accelerators/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
            adapter_name="lightning_2",
            load_into_transformer_2=True,
        )
        pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1.0, 1.0])

        print("[compile_transformer] Fusing LoRA weights")
        # One fuse call per adapter, in the same order as the adapters were loaded.
        for adapter, scale, component in (
            ("lightning", 3.0, "transformer"),
            ("lightning_2", 1.0, "transformer_2"),
        ):
            pipeline.fuse_lora(adapter_names=[adapter], lora_scale=scale, components=[component])
        pipeline.unload_lora_weights()

        print("[compile_transformer] Running dummy forward pass to capture component call")
        with torch.inference_mode(), capture_component_call(pipeline, 'transformer') as captured:
            pipeline(*args, **kwargs)

        # Tensor and bool leaves get a None spec; only the entries listed in
        # TRANSFORMER_DYNAMIC_SHAPES are then overridden.
        shape_spec = tree_map_only((torch.Tensor, bool), lambda _leaf: None, captured.kwargs)
        shape_spec |= TRANSFORMER_DYNAMIC_SHAPES

        print("[compile_transformer] Quantizing transformers with Float8DynamicActivationFloat8WeightConfig")
        for module in (pipeline.transformer, pipeline.transformer_2):
            quantize_(module, Float8DynamicActivationFloat8WeightConfig())

        # Build one example input per orientation from the captured tensor by
        # swapping its last two dimensions.
        sample: torch.Tensor = captured.kwargs['hidden_states']
        sample_swapped = sample.transpose(-1, -2).contiguous()
        captured_is_landscape = sample.shape[-1] > sample.shape[-2]
        landscape_input, portrait_input = (
            (sample, sample_swapped) if captured_is_landscape else (sample_swapped, sample)
        )

        print("[compile_transformer] Exporting transformer landscape model")
        exported_landscape_1 = torch.export.export(
            mod=pipeline.transformer,
            args=captured.args,
            kwargs=dict(captured.kwargs, hidden_states=landscape_input),
            dynamic_shapes=shape_spec,
        )
        torch.cuda.synchronize()

        print("[compile_transformer] Exporting transformer portrait model")
        exported_portrait_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=captured.args,
            kwargs=dict(captured.kwargs, hidden_states=portrait_input),
            dynamic_shapes=shape_spec,
        )
        torch.cuda.synchronize()

        print("[compile_transformer] Compiling models with AoT compilation")
        landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
        portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)

        # Only two graphs are compiled; the remaining two variants reuse a
        # compiled archive together with the other transformer's weights.
        landscape_2 = ZeroGPUCompiledModel(landscape_1.archive_file, portrait_2.weights)
        portrait_1 = ZeroGPUCompiledModel(portrait_2.archive_file, landscape_1.weights)

        print("[compile_transformer] Compilation done")
        return landscape_1, landscape_2, portrait_1, portrait_2

    compiled_l1, compiled_l2, compiled_p1, compiled_p2 = compile_transformer()

    def _orientation_dispatcher(landscape_model, portrait_model):
        """Return a forward replacement choosing a model from ``hidden_states``' shape."""

        def forward(*args, **kwargs):
            hidden_states: torch.Tensor = kwargs['hidden_states']
            if hidden_states.shape[-1] > hidden_states.shape[-2]:
                return landscape_model(*args, **kwargs)
            return portrait_model(*args, **kwargs)

        return forward

    for module, dispatcher in (
        (pipeline.transformer, _orientation_dispatcher(compiled_l1, compiled_p1)),
        (pipeline.transformer_2, _orientation_dispatcher(compiled_l2, compiled_p2)),
    ):
        module.forward = dispatcher
        drain_module_parameters(module)

    print("[optimize_pipeline_] Optimization complete")