jiuface committed
Commit 95e4e92 · verified · 1 parent: 5ecfc10

Create optimization.py

Files changed (1)
  1. optimization.py +128 -0
optimization.py ADDED
@@ -0,0 +1,128 @@
+
+from typing import Any
+from typing import Callable
+from typing import ParamSpec
+
+import spaces
+import torch
+from torch.utils._pytree import tree_map_only
+from torchao.quantization import quantize_
+from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
+from torchao.quantization import Int8WeightOnlyConfig
+
+from optimization_utils import capture_component_call
+from optimization_utils import aoti_compile
+from optimization_utils import ZeroGPUCompiledModel
+from optimization_utils import drain_module_parameters
+
+
+P = ParamSpec('P')
+
+
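+# torch.export needs a symbolic dim for anything whose size varies between
+# calls; here only the latent frame count (dim 2 of `hidden_states`) varies.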
+TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
+
+TRANSFORMER_DYNAMIC_SHAPES = {
+    'hidden_states': {
+        2: TRANSFORMER_NUM_FRAMES_DIM,
+    },
+}
+
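+# Inductor settings applied during the ahead-of-time (AOTI) compilation below.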
+INDUCTOR_CONFIGS = {
+    'conv_1x1_as_mm': True,
+    'epilogue_fusion': False,
+    'coordinate_descent_tuning': True,
+    'coordinate_descent_check_all_directions': True,
+    'max_autotune': True,
+    'triton.cudagraphs': True,
+}
+
+
+def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
+
+    @spaces.GPU(duration=1500)
+    def compile_transformer():
+
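+        # Fuse the LightX2V distillation LoRA into both transformers before
+        # export so the compiled graphs already contain the merged weights.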
+        pipeline.load_lora_weights(
+            "Kijai/WanVideo_comfy",
+            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+            adapter_name="lightx2v"
+        )
+        kwargs_lora = {}
+        kwargs_lora["load_into_transformer_2"] = True
+        pipeline.load_lora_weights(
+            "Kijai/WanVideo_comfy",
+            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+            adapter_name="lightx2v_2", **kwargs_lora
+        )
+        pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
+        pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
+        pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
+        pipeline.unload_lora_weights()
+
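+        # Run the pipeline once, recording the exact args/kwargs the
+        # transformer receives; these captured inputs drive the export.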
+        with capture_component_call(pipeline, 'transformer') as call:
+            pipeline(*args, **kwargs)
+
+        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
+        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
+
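+        # Float8 dynamic-activation / float8-weight quantization on both
+        # transformers cuts memory use and speeds up the linear layers.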
+        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
+        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
+
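+        # Build landscape and portrait variants of the captured latents so one
+        # graph per orientation can be exported (width and height swapped).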
+        hidden_states: torch.Tensor = call.kwargs['hidden_states']
+        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
+        if hidden_states.shape[-1] > hidden_states.shape[-2]:
+            hidden_states_landscape = hidden_states
+            hidden_states_portrait = hidden_states_transposed
+        else:
+            hidden_states_landscape = hidden_states_transposed
+            hidden_states_portrait = hidden_states
+
+        exported_landscape_1 = torch.export.export(
+            mod=pipeline.transformer,
+            args=call.args,
+            kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
+            dynamic_shapes=dynamic_shapes,
+        )
+
+        exported_portrait_2 = torch.export.export(
+            mod=pipeline.transformer_2,
+            args=call.args,
+            kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
+            dynamic_shapes=dynamic_shapes,
+        )
+
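+        # The two transformers share one architecture, so each orientation is
+        # only compiled once; the other two variants are assembled below by
+        # pairing an existing archive with the other model's weights.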
+        compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
+        compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)
+
+        compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
+        compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
+
+        return (
+            compiled_landscape_1,
+            compiled_landscape_2,
+            compiled_portrait_1,
+            compiled_portrait_2,
+        )
+
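+    # The text encoder runs once per generation; int8 weight-only is enough.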
+    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
+    cl1, cl2, cp1, cp2 = compile_transformer()
+
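+    # Dispatch to the landscape or portrait graph based on the aspect ratio
+    # of the incoming latents.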
+    def combined_transformer_1(*args, **kwargs):
+        hidden_states: torch.Tensor = kwargs['hidden_states']
+        if hidden_states.shape[-1] > hidden_states.shape[-2]:
+            return cl1(*args, **kwargs)
+        else:
+            return cp1(*args, **kwargs)
+
+    def combined_transformer_2(*args, **kwargs):
+        hidden_states: torch.Tensor = kwargs['hidden_states']
+        if hidden_states.shape[-1] > hidden_states.shape[-2]:
+            return cl2(*args, **kwargs)
+        else:
+            return cp2(*args, **kwargs)
+
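+    # Swap in the compiled forwards and release the original eager weights.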
+    pipeline.transformer.forward = combined_transformer_1
+    drain_module_parameters(pipeline.transformer)
+
+    pipeline.transformer_2.forward = combined_transformer_2
+    drain_module_parameters(pipeline.transformer_2)
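
For context, `optimize_pipeline_` is meant to be called once at startup with the same arguments a real generation would use, so that the captured transformer call matches production shapes. A minimal usage sketch, assuming a Diffusers Wan-style image-to-video pipeline; the pipeline class, checkpoint id, and warm-up sizes below are illustrative assumptions, not taken from this commit:

    import torch
    from PIL import Image
    from diffusers import WanImageToVideoPipeline  # assumed pipeline class

    from optimization import optimize_pipeline_

    pipeline = WanImageToVideoPipeline.from_pretrained(
        'Wan-AI/Wan2.2-I2V-A14B-Diffusers',  # hypothetical checkpoint id
        torch_dtype=torch.bfloat16,
    ).to('cuda')

    # Warm-up arguments should match what the app will actually request,
    # since the captured call defines the exported graphs' static shapes.
    optimize_pipeline_(
        pipeline,
        image=Image.new('RGB', (832, 480)),
        prompt='prompt',
        height=480,
        width=832,
        num_frames=81,
    )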