ginipick committed
Commit 28fa1bf · verified · 1 Parent(s): b57dbb1

Upload 2 files

Files changed (2):
  1. optimization.py +130 -0
  2. optimization_utils.py +107 -0
optimization.py ADDED
@@ -0,0 +1,130 @@
+ """
+ Optimize a two-transformer Wan image-to-video pipeline for ZeroGPU:
+ fuse the Lightx2v distillation LoRA, quantize with torchao, and
+ ahead-of-time compile both transformers with AOTInductor.
+ """
+
+ from typing import Any
+ from typing import Callable
+ from typing import ParamSpec
+
+ import spaces
+ import torch
+ from torch.utils._pytree import tree_map_only
+ from torchao.quantization import quantize_
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
+ from torchao.quantization import Int8WeightOnlyConfig
+
+ from optimization_utils import capture_component_call
+ from optimization_utils import aoti_compile
+ from optimization_utils import ZeroGPUCompiledModel
+ from optimization_utils import drain_module_parameters
+
+
+ P = ParamSpec('P')
+
+
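+ # Export-time dynamic shape: only the latent frame axis (dim 2) of
+ # `hidden_states` may vary, between 3 and 21 frames; every other input
+ # dimension is treated as static.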
+ TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
+
+ TRANSFORMER_DYNAMIC_SHAPES = {
+     'hidden_states': {
+         2: TRANSFORMER_NUM_FRAMES_DIM,
+     },
+ }
+
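+ # Aggressive Inductor autotuning options applied to both AOT compiles below.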
+ INDUCTOR_CONFIGS = {
+     'conv_1x1_as_mm': True,
+     'epilogue_fusion': False,
+     'coordinate_descent_tuning': True,
+     'coordinate_descent_check_all_directions': True,
+     'max_autotune': True,
+     'triton.cudagraphs': True,
+ }
+
+
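+ # Fuses the LoRAs, captures one real transformer call, exports and compiles
+ # both transformers for landscape and portrait inputs, then swaps the
+ # compiled models in and drops the original weights.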
+ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
+
+     @spaces.GPU(duration=1500)
+     def compile_transformer():
+
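+         # Load the Lightx2v distillation LoRA into both transformers, fuse it
+         # (lora_scale 3.0 into `transformer`, 1.0 into `transformer_2`), then
+         # unload the adapter machinery so only the fused weights remain.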
+         pipeline.load_lora_weights(
+             "Kijai/WanVideo_comfy",
+             weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+             adapter_name="lightx2v"
+         )
+         kwargs_lora = {}
+         kwargs_lora["load_into_transformer_2"] = True
+         pipeline.load_lora_weights(
+             "Kijai/WanVideo_comfy",
+             weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+             adapter_name="lightx2v_2", **kwargs_lora
+         )
+         pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
+         pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
+         pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
+         pipeline.unload_lora_weights()
+
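+         # Run the pipeline once with the transformer's forward patched out so
+         # the exact (args, kwargs) of a real call are recorded for export.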
+         with capture_component_call(pipeline, 'transformer') as call:
+             pipeline(*args, **kwargs)
+
+         dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
+         dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
+
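+         # Dynamic float8 activation + float8 weight quantization on both
+         # transformers before export.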
+         quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
+         quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
+
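+         # Derive a landscape and a portrait variant of the captured
+         # hidden_states by transposing its last two (spatial) dimensions.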
+         hidden_states: torch.Tensor = call.kwargs['hidden_states']
+         hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
+         if hidden_states.shape[-1] > hidden_states.shape[-2]:
+             hidden_states_landscape = hidden_states
+             hidden_states_portrait = hidden_states_transposed
+         else:
+             hidden_states_landscape = hidden_states_transposed
+             hidden_states_portrait = hidden_states
+
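+         # Export each transformer once, with example inputs in one orientation
+         # apiece; the frame axis stays dynamic via `dynamic_shapes`.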
+         exported_landscape_1 = torch.export.export(
+             mod=pipeline.transformer,
+             args=call.args,
+             kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
+             dynamic_shapes=dynamic_shapes,
+         )
+
+         exported_portrait_2 = torch.export.export(
+             mod=pipeline.transformer_2,
+             args=call.args,
+             kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
+             dynamic_shapes=dynamic_shapes,
+         )
+
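+         # The two transformers share an architecture, so one compiled archive
+         # per orientation is enough: cross-pairing each archive with the other
+         # model's weights covers all four combinations with two compiles.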
+         compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
+         compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)
+
+         compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
+         compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
+
+         return (
+             compiled_landscape_1,
+             compiled_landscape_2,
+             compiled_portrait_1,
+             compiled_portrait_2,
+         )
+
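+     # Int8 weight-only quantization suffices for the text encoder; the heavy
+     # compile work runs once inside the ZeroGPU worker above.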
+     quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
+     cl1, cl2, cp1, cp2 = compile_transformer()
+
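+     # Dispatch at call time on orientation: inputs wider than tall go to the
+     # landscape compile, everything else to the portrait compile.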
+     def combined_transformer_1(*args, **kwargs):
+         hidden_states: torch.Tensor = kwargs['hidden_states']
+         if hidden_states.shape[-1] > hidden_states.shape[-2]:
+             return cl1(*args, **kwargs)
+         else:
+             return cp1(*args, **kwargs)
+
+     def combined_transformer_2(*args, **kwargs):
+         hidden_states: torch.Tensor = kwargs['hidden_states']
+         if hidden_states.shape[-1] > hidden_states.shape[-2]:
+             return cl2(*args, **kwargs)
+         else:
+             return cp2(*args, **kwargs)
+
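+     # Swap in the compiled forwards and replace the original parameters with
+     # empty stubs so the uncompiled weights no longer hold memory.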
+     pipeline.transformer.forward = combined_transformer_1
+     drain_module_parameters(pipeline.transformer)
+
+     pipeline.transformer_2.forward = combined_transformer_2
+     drain_module_parameters(pipeline.transformer_2)
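
A minimal usage sketch (hypothetical: the pipeline class, model id, and call arguments are illustrative and not part of this commit):

    import torch
    from diffusers import WanImageToVideoPipeline
    from optimization import optimize_pipeline_

    pipeline = WanImageToVideoPipeline.from_pretrained(
        "Wan-AI/Wan2.2-I2V-A14B-Diffusers", torch_dtype=torch.bfloat16
    )
    # One representative call: optimize_pipeline_ replays it under
    # capture_component_call to record the transformer's real inputs,
    # then compiles and swaps in the optimized transformers in place.
    optimize_pipeline_(pipeline, image=example_image, prompt="...", num_frames=81)
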
optimization_utils.py ADDED
@@ -0,0 +1,107 @@
+ """
+ Helpers for AOTInductor compilation on ZeroGPU: capture a component's call,
+ compile an exported program into a portable archive, and move compiled
+ weights across process boundaries.
+ """
+ import contextlib
+ from contextvars import ContextVar
+ from io import BytesIO
+ from typing import Any
+ from typing import cast
+ from unittest.mock import patch
+
+ import torch
+ from torch._inductor.package.package import package_aoti
+ from torch.export.pt2_archive._package import AOTICompiledModel
+ from torch.export.pt2_archive._package_weights import Weights
+
+
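+ # Always package AOTI artifacts and keep constants out of the shared object,
+ # so weights can be shipped separately and attached as user-managed constants.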
+ INDUCTOR_CONFIGS_OVERRIDES = {
+     'aot_inductor.package_constants_in_so': False,
+     'aot_inductor.package_constants_on_disk': True,
+     'aot_inductor.package': True,
+ }
+
+
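+ # Wraps a model's constants map. Pickling copies each tensor through pinned,
+ # shared CPU memory; unpickling moves it straight to CUDA. This is how weights
+ # travel from the ZeroGPU worker back to the caller.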
+ class ZeroGPUWeights:
+     def __init__(self, constants_map: dict[str, torch.Tensor], to_cuda: bool = False):
+         if to_cuda:
+             self.constants_map = {name: tensor.to('cuda') for name, tensor in constants_map.items()}
+         else:
+             self.constants_map = constants_map
+     def __reduce__(self):
+         constants_map: dict[str, torch.Tensor] = {}
+         for name, tensor in self.constants_map.items():
+             tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
+             constants_map[name] = tensor_.copy_(tensor).detach().share_memory_()
+         return ZeroGPUWeights, (constants_map, True)
+
+
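+ # A picklable handle on an AOTI-compiled model: the package archive plus its
+ # weights. The package is loaded lazily, once per context, on first call.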
+ class ZeroGPUCompiledModel:
+     def __init__(self, archive_file: torch.types.FileLike, weights: ZeroGPUWeights):
+         self.archive_file = archive_file
+         self.weights = weights
+         self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
+     def __call__(self, *args, **kwargs):
+         if (compiled_model := self.compiled_model.get()) is None:
+             compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
+             compiled_model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
+             self.compiled_model.set(compiled_model)
+         return compiled_model(*args, **kwargs)
+     def __reduce__(self):
+         return ZeroGPUCompiledModel, (self.archive_file, self.weights)
+
+
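+ # Compile an exported program with AOTInductor, package everything except the
+ # weights into an in-memory archive, and return a ZeroGPUCompiledModel.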
+ def aoti_compile(
+     exported_program: torch.export.ExportedProgram,
+     inductor_configs: dict[str, Any] | None = None,
+ ):
+     inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
+     gm = cast(torch.fx.GraphModule, exported_program.module())
+     assert exported_program.example_inputs is not None
+     args, kwargs = exported_program.example_inputs
+     artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
+     archive_file = BytesIO()
+     files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
+     package_aoti(archive_file, files)
+     weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
+     zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
+     return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
+
+
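+ # Context manager that patches a pipeline component's method to raise a
+ # sentinel exception, aborting the pipeline run while capturing the exact
+ # (args, kwargs) the component would have received.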
+ @contextlib.contextmanager
+ def capture_component_call(
+     pipeline: Any,
+     component_name: str,
+     component_method='forward',
+ ):
+
+     class CapturedCallException(Exception):
+         def __init__(self, *args, **kwargs):
+             super().__init__()
+             self.args = args
+             self.kwargs = kwargs
+
+     class CapturedCall:
+         def __init__(self):
+             self.args: tuple[Any, ...] = ()
+             self.kwargs: dict[str, Any] = {}
+
+     component = getattr(pipeline, component_name)
+     captured_call = CapturedCall()
+
+     def capture_call(*args, **kwargs):
+         raise CapturedCallException(*args, **kwargs)
+
+     with patch.object(component, component_method, new=capture_call):
+         try:
+             yield captured_call
+         except CapturedCallException as e:
+             captured_call.args = e.args
+             captured_call.kwargs = e.kwargs
+
+
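+ # Replace a module's parameters with empty stubs (keeping each original
+ # device/dtype as metadata) so its memory can be reclaimed once a compiled
+ # model owns the real weights.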
+ def drain_module_parameters(module: torch.nn.Module):
+     state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
+     state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
+     module.load_state_dict(state_dict, assign=True)
+     for name, param in state_dict.items():
+         meta = state_dict_meta[name]
+         param.data = torch.Tensor([]).to(**meta)
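
A minimal round-trip sketch of these utilities (hypothetical model and inputs, for illustration only; a real caller would pass a diffusers pipeline):

    import torch
    from optimization_utils import capture_component_call, aoti_compile

    model = torch.nn.Linear(8, 8).cuda()

    class Holder:  # stand-in for a pipeline object holding the component
        transformer = model

    pipeline = Holder()

    # Record the call the component would receive, without actually running it.
    with capture_component_call(pipeline, 'transformer') as call:
        pipeline.transformer(torch.randn(4, 8, device='cuda'))

    exported = torch.export.export(model, args=call.args, kwargs=call.kwargs)
    compiled = aoti_compile(exported)

    # `compiled` is picklable: its weights round-trip through shared memory,
    # so it can be returned from a @spaces.GPU worker and called afterwards.
    out = compiled(*call.args, **call.kwargs)
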