cbensimon HF Staff committed on
Commit
3df4fd5
·
1 Parent(s): 6631fab

Compilation

Browse files
Files changed (3) hide show
  1. app.py +3 -0
  2. optimization.py +54 -0
  3. zerogpu.py +62 -0
app.py CHANGED
@@ -8,9 +8,12 @@ from PIL import Image
8
  from diffusers import FluxKontextPipeline
9
  from diffusers.utils import load_image
10
 
 
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
 
13
  pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
 
14
 
15
  @spaces.GPU
16
  def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
 
8
  from diffusers import FluxKontextPipeline
9
  from diffusers.utils import load_image
10
 
11
+ from optimization import optimize_pipeline_
12
+
13
  MAX_SEED = np.iinfo(np.int32).max
14
 
15
  pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
16
+ optimize_pipeline_(pipe)
17
 
18
  @spaces.GPU
19
  def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
optimization.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ import spaces
5
+ import torch
6
+ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
7
+ from torchao.quantization import quantize_
8
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
9
+
10
+ from zerogpu import aoti_compile
11
+
12
+
13
def optimize_pipeline_(pipeline: FluxPipeline):
    """Replace ``pipeline.transformer`` with a compiled callable, in place.

    The nested ``compile_transformer`` runs under ``spaces.GPU(duration=1500)``.
    It calls ``fuse_qkv_projections()`` on the transformer, quantizes it with
    torchao's ``Float8DynamicActivationFloat8WeightConfig``, exports it with
    ``torch.export.export`` using fixed-shape example keyword inputs, and
    passes the exported program to ``aoti_compile``.

    The transformer's ``config`` is captured before compilation and
    re-attached to the object that replaces ``pipeline.transformer``.

    Args:
        pipeline: Pipeline whose ``transformer`` attribute is optimized.
            The pipeline is mutated; nothing is returned.
    """

    @spaces.GPU(duration=1500)
    def compile_transformer():
        transformer = pipeline.transformer

        # Both calls modify the transformer module itself.
        transformer.fuse_qkv_projections()
        quantize_(transformer, Float8DynamicActivationFloat8WeightConfig())

        tensor_opts = {'device': 'cuda', 'dtype': torch.bfloat16}

        # The code treats a transformer without guidance embeddings as
        # "timestep distilled": it then gets a shorter text sequence and no
        # guidance input.
        has_guidance = transformer.config.guidance_embeds
        seq_length = 512 if has_guidance else 256

        # Example inputs used for export. No dynamic shapes are passed to
        # torch.export.export below.
        # NOTE(review): presumably the compiled model is therefore tied to
        # these exact shapes -- confirm.
        hidden_states = torch.randn(1, 4096, 64, **tensor_opts)
        timestep = torch.tensor([1.], **tensor_opts)
        guidance = torch.tensor([1.], **tensor_opts) if has_guidance else None
        pooled_projections = torch.randn(1, 768, **tensor_opts)
        encoder_hidden_states = torch.randn(1, seq_length, 4096, **tensor_opts)
        txt_ids = torch.randn(seq_length, 3, **tensor_opts)
        img_ids = torch.randn(4096, 3, **tensor_opts)

        transformer_kwargs = {
            'hidden_states': hidden_states,
            'timestep': timestep,
            'guidance': guidance,
            'pooled_projections': pooled_projections,
            'encoder_hidden_states': encoder_hidden_states,
            'txt_ids': txt_ids,
            'img_ids': img_ids,
            'joint_attention_kwargs': {},
            'return_dict': False,
        }

        # Inductor options forwarded to aoti_compile.
        inductor_configs = {
            'conv_1x1_as_mm': True,
            'epilogue_fusion': False,
            'coordinate_descent_tuning': True,
            'coordinate_descent_check_all_directions': True,
            'max_autotune': True,
            'triton.cudagraphs': True,
        }

        exported_program = torch.export.export(
            transformer,
            args=(),
            kwargs=transformer_kwargs,
        )
        return aoti_compile(exported_program, inductor_configs)

    # Keep the original config reachable on the replacement object, since the
    # compiled callable does not carry one of its own.
    saved_config = pipeline.transformer.config
    pipeline.transformer = compile_transformer()
    pipeline.transformer.config = saved_config
zerogpu.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from contextvars import ContextVar
4
+ from io import BytesIO
5
+ from typing import Any
6
+ from typing import cast
7
+
8
+ import torch
9
+ from torch._inductor.package.package import package_aoti
10
+ from torch.export.pt2_archive._package import AOTICompiledModel
11
+ from torch.export.pt2_archive._package_weights import TensorProperties
12
+ from torch.export.pt2_archive._package_weights import Weights
13
+
14
+
15
# Inductor options that aoti_compile() merges on top of the caller-supplied
# options. They are the right-hand operand of the dict union there, so these
# values win over any caller value for the same key.
INDUCTOR_CONFIGS_OVERRIDES = {
    'aot_inductor.package_constants_in_so': False,
    'aot_inductor.package_constants_on_disk': True,
    'aot_inductor.package': True,
}
20
+
21
+
22
class ZeroGPUCompiledModel:
    """Callable wrapper around an AOTInductor package and its weights.

    The instance holds the packaged archive (a file-like object) and a
    ``Weights`` mapping. The ``AOTICompiledModel`` itself is loaded lazily on
    the first call and cached in a ``ContextVar``, so each execution context
    that has not loaded it yet does so on its first call.

    Pickling goes through ``__reduce__``, which rebuilds the object with
    ``cuda=True``; the reconstructed instance therefore moves its weights to
    CUDA during ``__init__``.
    """

    def __init__(self, archive_file: torch.types.FileLike, weights: Weights, cuda: bool = False):
        """Store the archive and weights, optionally moving weights to CUDA.

        Args:
            archive_file: File-like object passed to
                ``torch._inductor.aoti_load_package`` on first call.
            weights: Mapping of weight name to ``(tensor, properties)``.
            cuda: When true, every weight tensor is moved to ``'cuda'``
                immediately.
        """
        self.archive_file = archive_file
        self.weights = weights
        if cuda:
            self.weights_to_cuda_()
        # Per-context cache of the loaded model; None until the first call.
        self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)

    def weights_to_cuda_(self):
        """Replace every stored weight tensor with its ``.to('cuda')`` copy."""
        for weight_name in self.weights:
            cpu_tensor, props = self.weights.get_weight(weight_name)
            self.weights[weight_name] = (cpu_tensor.to('cuda'), props)

    def __call__(self, *args, **kwargs):
        """Run the compiled model, loading it first if this context has not."""
        model = self.compiled_model.get()
        if model is None:
            # Only the tensor (first element of each entry) is handed to the
            # loaded model; the properties are not.
            constants_map = {}
            for weight_name, entry in self.weights.items():
                constants_map[weight_name] = entry[0]
            model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
            model.load_constants(constants_map, check_full_update=True, user_managed=True)
            self.compiled_model.set(model)
        return model(*args, **kwargs)

    def __reduce__(self):
        """Describe how to pickle this object.

        Each weight is copied into a freshly allocated CPU tensor on which
        ``pin_memory()`` and then ``share_memory_()`` are called. The result
        tells pickle to rebuild a ``ZeroGPUCompiledModel`` from the same
        archive, the copied weights, and ``cuda=True``.
        """
        exported_weights: dict[str, tuple[torch.Tensor, TensorProperties]] = {}
        for weight_name in self.weights:
            source, props = self.weights.get_weight(weight_name)
            staging = torch.empty_like(source, device='cpu').pin_memory()
            staging.copy_(source)
            exported_weights[weight_name] = (staging.detach().share_memory_(), props)
        rebuild_args = (self.archive_file, Weights(exported_weights), True)
        return ZeroGPUCompiledModel, rebuild_args
47
+
48
+
49
def aoti_compile(
    exported_program: torch.export.ExportedProgram,
    inductor_configs: dict[str, Any] | None = None,
):
    """Compile an exported program with AOTInductor and wrap the result.

    Args:
        exported_program: Program produced by ``torch.export``. Its
            ``example_inputs`` must be set; they are reused as the
            compilation inputs.
        inductor_configs: Optional Inductor options. Entries in
            ``INDUCTOR_CONFIGS_OVERRIDES`` take precedence over these.

    Returns:
        A ``ZeroGPUCompiledModel`` holding an in-memory package of the
        compiled files together with the ``Weights`` artifact.

    Raises:
        AssertionError: If ``exported_program.example_inputs`` is ``None``.
        ValueError: If the compile artifacts do not contain exactly one
            ``Weights`` object (raised by the single-element unpacking).
    """
    # Overrides are the right-hand operand, so they win on key collisions.
    options = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES

    graph_module = exported_program.module()
    assert exported_program.example_inputs is not None
    example_args, example_kwargs = exported_program.example_inputs

    artifacts = torch._inductor.aot_compile(
        graph_module,
        example_args,
        example_kwargs,
        options=options,
    )

    # String artifacts are packaged into the archive; the single Weights
    # artifact is kept alongside the archive instead.
    package_buffer = BytesIO()
    packaged_files = [item for item in artifacts if isinstance(item, str)]
    package_aoti(package_buffer, packaged_files)

    (weights,) = (item for item in artifacts if isinstance(item, Weights))
    return ZeroGPUCompiledModel(package_buffer, weights)