Spaces:
Running
on
Zero
Running
on
Zero
update aoti compile (#3)
Browse files- update aoti compile (9aa2a04d546ddb4bb6deb5d4ea24e016061b33cd)
- Update optimization_utils.py (8e5bed5da4bd6b198ad33fdbb55dbe760f3ea32a)
- optimization.py +4 -9
- optimization_utils.py +9 -0
optimization.py
CHANGED
|
@@ -122,13 +122,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
| 122 |
else:
|
| 123 |
return cp2(*args, **kwargs)
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
|
| 128 |
-
pipeline.
|
| 129 |
-
pipeline.
|
| 130 |
-
pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
|
| 131 |
-
|
| 132 |
-
pipeline.transformer_2 = combined_transformer_2
|
| 133 |
-
pipeline.transformer_2.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
|
| 134 |
-
pipeline.transformer_2.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
|
|
|
|
| 122 |
else:
|
| 123 |
return cp2(*args, **kwargs)
|
| 124 |
|
| 125 |
+
pipeline.transformer.forward = combined_transformer_1
|
| 126 |
+
drain_module_parameters(pipeline.transformer)
|
| 127 |
|
| 128 |
+
pipeline.transformer_2.forward = combined_transformer_2
|
| 129 |
+
drain_module_parameters(pipeline.transformer_2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optimization_utils.py
CHANGED
|
@@ -96,3 +96,12 @@ def capture_component_call(
|
|
| 96 |
except CapturedCallException as e:
|
| 97 |
captured_call.args = e.args
|
| 98 |
captured_call.kwargs = e.kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
except CapturedCallException as e:
|
| 97 |
captured_call.args = e.args
|
| 98 |
captured_call.kwargs = e.kwargs
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def drain_module_parameters(module: torch.nn.Module):
|
| 102 |
+
state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
|
| 103 |
+
state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
|
| 104 |
+
module.load_state_dict(state_dict, assign=True)
|
| 105 |
+
for name, param in state_dict.items():
|
| 106 |
+
meta = state_dict_meta[name]
|
| 107 |
+
param.data = torch.Tensor([]).to(**meta)
|