cbensimon HF Staff commited on
Commit
674e245
·
1 Parent(s): a3d55a6

ZeroGPUCompiledModel

Browse files
Files changed (2) hide show
  1. app.py +5 -33
  2. utils/zerogpu.py +60 -0
app.py CHANGED
@@ -19,9 +19,8 @@ import spaces
19
  import torch
20
  import torch._inductor
21
  from diffusers import FluxPipeline
22
- from torch._inductor.package import package_aoti
23
- from torch.export.pt2_archive._package import AOTICompiledModel
24
- from torch.export.pt2_archive._package_weights import Weights
25
 
26
 
27
  pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16).to('cuda')
@@ -61,43 +60,16 @@ def compile_transformer():
61
 
62
  exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)
63
 
64
- artifacts = torch._inductor.aot_compile(exported.module(), *exported.example_inputs, options=inductor_configs | {
65
- 'aot_inductor.package_constants_in_so': False,
66
- 'aot_inductor.package_constants_on_disk': True,
67
- 'aot_inductor.package': True,
68
- })
69
-
70
- files = [file for file in artifacts if isinstance(file, str)]
71
- package_aoti(package_path, files)
72
-
73
- weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
74
- weights_: dict[str, torch.Tensor] = {}
75
-
76
- for name in weights:
77
- tensor, _properties = weights.get_weight(name)
78
- tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
79
- weights_[name] = tensor_.copy_(tensor).detach().share_memory_()
80
 
81
- return weights_
82
-
83
-
84
- weights = compile_transformer()
85
- weights = {name: tensor.to('cuda') for name, tensor in weights.items()}
86
- print('compile_transformer', -(t0 - (t0 := datetime.now())))
87
 
88
  transformer_config = pipeline.transformer.config
89
- pipeline.transformer = None
90
-
91
 
92
  @spaces.GPU
93
  def _generate_image(prompt: str, t0: datetime):
94
  print('@spaces.GPU', -(t0 - (t0 := datetime.now())))
95
- compiled_transformer: AOTICompiledModel = torch._inductor.aoti_load_package(package_path)
96
- print('aoti_load_package', -(t0 - (t0 := datetime.now())))
97
- compiled_transformer.load_constants(weights, check_full_update=True, user_managed=True)
98
- print('load_constants', -(t0 - (t0 := datetime.now())))
99
- pipeline.transformer = compiled_transformer
100
- pipeline.transformer.config = transformer_config
101
  images = []
102
  for _ in range(4):
103
  images += pipeline(prompt, num_inference_steps=4).images
 
19
  import torch
20
  import torch._inductor
21
  from diffusers import FluxPipeline
22
+
23
+ from .utils.zerogpu import aoti_compile
 
24
 
25
 
26
  pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16).to('cuda')
 
60
 
61
  exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)
62
 
63
+ return aoti_compile(exported, inductor_configs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
 
 
 
 
 
 
65
 
66
  transformer_config = pipeline.transformer.config
67
+ pipeline.transformer = compile_transformer()
68
+ pipeline.transformer.config = transformer_config
69
 
70
  @spaces.GPU
71
  def _generate_image(prompt: str, t0: datetime):
72
  print('@spaces.GPU', -(t0 - (t0 := datetime.now())))
 
 
 
 
 
 
73
  images = []
74
  for _ in range(4):
75
  images += pipeline(prompt, num_inference_steps=4).images
utils/zerogpu.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from io import BytesIO
4
+ from typing import Any
5
+
6
+ import torch
7
+ from torch._inductor.package.package import package_aoti
8
+ from torch.export.pt2_archive._package import AOTICompiledModel
9
+ from torch.export.pt2_archive._package_weights import TensorProperties
10
+ from torch.export.pt2_archive._package_weights import Weights
11
+
12
+
13
+ INDUCTOR_CONFIGS_OVERRIDES = {
14
+ 'aot_inductor.package_constants_in_so': False,
15
+ 'aot_inductor.package_constants_on_disk': True,
16
+ 'aot_inductor.package': True,
17
+ }
18
+
19
+
20
+ class ZeroGPUCompiledModel:
21
+ def __init__(self, archive_file: BytesIO, weights: Weights, cuda: bool = False):
22
+ self.archive_file = archive_file
23
+ self.weights = weights
24
+ if cuda:
25
+ self.weights_to_cuda_()
26
+ self.compiled_model: AOTICompiledModel | None = None
27
+ def weights_to_cuda_(self):
28
+ for name in self.weights:
29
+ tensor, properties = self.weights.get_weight(name)
30
+ self.weights[name] = (tensor.to('cuda'), properties)
31
+ def __call__(self, *args, **kwargs):
32
+ if self.compiled_model is None:
33
+ constants_map = {name: value[1] for name, value in self.weights.items()}
34
+ compiled_model: AOTICompiledModel = torch._inductor.aoti_load_package(self.archive_file)
35
+ compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
36
+ self.compiled_model = compiled_model
37
+ return self.compiled_model(*args, **kwargs)
38
+ def __reduce__(self):
39
+ weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]] = {}
40
+ for name in self.weights:
41
+ tensor, properties = self.weights.get_weight(name)
42
+ tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
43
+ weight_dict[name] = (tensor_.copy_(tensor).detach().share_memory_(), properties)
44
+ return ZeroGPUCompiledModel, (self.archive_file, Weights(weight_dict), True)
45
+
46
+
47
+ def aoti_compile(
48
+ exported_program: torch.export.ExportedProgram,
49
+ inductor_configs: dict[str, Any] | None = None,
50
+ ):
51
+ inductor_configs = inductor_configs or {} | INDUCTOR_CONFIGS_OVERRIDES
52
+ gm = exported_program.module()
53
+ assert exported_program.example_inputs is not None
54
+ args, kwargs = exported_program.example_inputs
55
+ artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
56
+ archive_file = BytesIO()
57
+ files = [file for file in artifacts if isinstance(file, str)]
58
+ package_aoti(archive_file, files)
59
+ weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
60
+ return ZeroGPUCompiledModel(archive_file, weights)