cbensimon (HF Staff) committed
Commit 69667cb · Parent: da628cb

fa3 + big refresh

Files changed (5):
  1. app.py +4 -55
  2. fa3.py +115 -0
  3. optimization.py +43 -0
  4. requirements.txt +1 -1
  5. zerogpu.py +0 -62
app.py CHANGED
@@ -1,75 +1,24 @@
-"""
-"""
-# Upgrade PyTorch
-import os
-os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
-
-# Actual app.py
-import os
 from datetime import datetime
 
 import gradio as gr
 import spaces
 import torch
 from diffusers import FluxPipeline
-from torchao.quantization import quantize_
-from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
 
-from zerogpu import aoti_compile
+from optimization import optimize_pipeline_
 
 
 pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16).to('cuda')
-
-
-@spaces.GPU(duration=1500)
-def compile_transformer():
-
-    pipeline.transformer.fuse_qkv_projections()
-    quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
-
-    def _example_tensor(*shape):
-        return torch.randn(*shape, device='cuda', dtype=torch.bfloat16)
-
-    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
-    seq_length = 256 if is_timestep_distilled else 512
-
-    transformer_kwargs = {
-        'hidden_states': _example_tensor(1, 4096, 64),
-        'timestep': torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
-        'guidance': None if is_timestep_distilled else torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
-        'pooled_projections': _example_tensor(1, 768),
-        'encoder_hidden_states': _example_tensor(1, seq_length, 4096),
-        'txt_ids': _example_tensor(seq_length, 3),
-        'img_ids': _example_tensor(4096, 3),
-        'joint_attention_kwargs': {},
-        'return_dict': False,
-    }
-
-    inductor_configs = {
-        'conv_1x1_as_mm': True,
-        'epilogue_fusion': False,
-        'coordinate_descent_tuning': True,
-        'coordinate_descent_check_all_directions': True,
-        'max_autotune': True,
-        'triton.cudagraphs': True,
-    }
-
-    exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)
-
-    return aoti_compile(exported, inductor_configs)
-
-
-transformer_config = pipeline.transformer.config
-pipeline.transformer = compile_transformer()
-pipeline.transformer.config = transformer_config
+optimize_pipeline_(pipeline, "prompt")
 
 
 @spaces.GPU
 def generate_image(prompt: str):
+    generator = torch.Generator(device='cuda').manual_seed(42)
     t0 = datetime.now()
     images = []
     for _ in range(9):
-        image = pipeline(prompt, num_inference_steps=4).images[0]
+        image = pipeline(prompt, num_inference_steps=4, generator=generator).images[0]
         elapsed = -(t0 - (t0 := datetime.now()))
         images += [(image, f'{elapsed.total_seconds():.2f}s')]
         yield images
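
The timing expression in `generate_image` uses a walrus assignment: `elapsed = -(t0 - (t0 := datetime.now()))` measures the wall-clock time of the pipeline call that just finished while simultaneously resetting `t0` for the next iteration. A more explicit equivalent (illustrative sketch only, not part of the commit):

    from datetime import datetime

    t0 = datetime.now()
    for _ in range(9):
        # ... run one pipeline call here ...
        now = datetime.now()
        elapsed = now - t0   # wall-clock time of this iteration
        t0 = now             # reset the reference point for the next one
        print(f'{elapsed.total_seconds():.2f}s')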
fa3.py ADDED
@@ -0,0 +1,115 @@
+"""
+"""
+
+import torch
+from kernels import get_kernel
+
+
+_flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
+
+
+@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
+def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    outputs, lse = _flash_attn_func(q, k, v)
+    return outputs
+
+@flash_attn_func.register_fake
+def _(q, k, v, **kwargs):
+    # two outputs:
+    # 1. output: (batch, seq_len, num_heads, head_dim)
+    # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
+    meta_q = torch.empty_like(q).contiguous()
+    return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
+
+# Copied FusedFluxAttnProcessor2_0 but using flash v3 instead of SDPA
+class FlashFusedFluxAttnProcessor3_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __call__(
+        self,
+        attn,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor | None = None,
+        attention_mask: torch.FloatTensor | None = None,
+        image_rotary_emb: torch.Tensor | None = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from diffusers.models.embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        # NB: transposes are necessary to match expected SDPA input shape
+        hidden_states = flash_attn_func(
+            query.transpose(1, 2),
+            key.transpose(1, 2),
+            value.transpose(1, 2))[0].transpose(1, 2)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
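
A minimal standalone sketch of wiring this processor into a FLUX pipeline (assumptions: a CUDA device, the FA3 kernel downloadable via `kernels`, and fused QKV projections, since the processor reads `attn.to_qkv` / `attn.to_added_qkv`; in this Space the hookup happens in `optimization.py`):

    import torch
    from diffusers import FluxPipeline

    from fa3 import FlashFusedFluxAttnProcessor3_0

    pipe = FluxPipeline.from_pretrained(
        'black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16
    ).to('cuda')

    # Fuse projections first so attn.to_qkv / attn.to_added_qkv exist,
    # then swap SDPA for the Flash Attention 3 processor.
    pipe.transformer.fuse_qkv_projections()
    pipe.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

    image = pipe("a tiny astronaut", num_inference_steps=4).images[0]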
optimization.py ADDED
@@ -0,0 +1,43 @@
+"""
+"""
+
+from typing import Any
+from typing import Callable
+from typing import ParamSpec
+import spaces
+import torch
+
+from fa3 import FlashFusedFluxAttnProcessor3_0
+
+
+P = ParamSpec('P')
+
+
+INDUCTOR_CONFIGS = {
+    'conv_1x1_as_mm': True,
+    'epilogue_fusion': False,
+    'coordinate_descent_tuning': True,
+    'coordinate_descent_check_all_directions': True,
+    'max_autotune': True,
+    'triton.cudagraphs': True,
+}
+
+
+def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
+
+    @spaces.GPU(duration=1500)
+    def compile_transformer():
+
+        with spaces.aoti_capture(pipeline.transformer) as call:
+            pipeline(*args, **kwargs)
+
+        exported = torch.export.export(
+            mod=pipeline.transformer,
+            args=call.args,
+            kwargs=call.kwargs,
+        )
+
+        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
+
+    pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
+    spaces.aoti_apply(compile_transformer(), pipeline.transformer)
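
In short, `optimize_pipeline_` swaps in the FA3 attention processor, runs the pipeline once under `spaces.aoti_capture` to record the exact arguments the transformer receives, exports that call with `torch.export.export`, compiles it ahead of time on a ZeroGPU worker, and grafts the compiled artifact back onto the transformer with `spaces.aoti_apply`. Usage mirrors the new `app.py` (illustrative sketch):

    import torch
    from diffusers import FluxPipeline

    from optimization import optimize_pipeline_

    pipeline = FluxPipeline.from_pretrained(
        'black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16
    ).to('cuda')

    # One warm-up call with a representative prompt drives aoti_capture, so the
    # exported graph is specialized to the shapes used at inference time.
    optimize_pipeline_(pipeline, "prompt")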
requirements.txt CHANGED
@@ -3,4 +3,4 @@ diffusers
 transformers
 sentencepiece
 protobuf
-torchao
+kernels
zerogpu.py DELETED
@@ -1,62 +0,0 @@
-"""
-"""
-from contextvars import ContextVar
-from io import BytesIO
-from typing import Any
-from typing import cast
-
-import torch
-from torch._inductor.package.package import package_aoti
-from torch.export.pt2_archive._package import AOTICompiledModel
-from torch.export.pt2_archive._package_weights import TensorProperties
-from torch.export.pt2_archive._package_weights import Weights
-
-
-INDUCTOR_CONFIGS_OVERRIDES = {
-    'aot_inductor.package_constants_in_so': False,
-    'aot_inductor.package_constants_on_disk': True,
-    'aot_inductor.package': True,
-}
-
-
-class ZeroGPUCompiledModel:
-    def __init__(self, archive_file: torch.types.FileLike, weights: Weights, cuda: bool = False):
-        self.archive_file = archive_file
-        self.weights = weights
-        if cuda:
-            self.weights_to_cuda_()
-        self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
-    def weights_to_cuda_(self):
-        for name in self.weights:
-            tensor, properties = self.weights.get_weight(name)
-            self.weights[name] = (tensor.to('cuda'), properties)
-    def __call__(self, *args, **kwargs):
-        if (compiled_model := self.compiled_model.get()) is None:
-            constants_map = {name: value[0] for name, value in self.weights.items()}
-            compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
-            compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
-            self.compiled_model.set(compiled_model)
-        return compiled_model(*args, **kwargs)
-    def __reduce__(self):
-        weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]] = {}
-        for name in self.weights:
-            tensor, properties = self.weights.get_weight(name)
-            tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
-            weight_dict[name] = (tensor_.copy_(tensor).detach().share_memory_(), properties)
-        return ZeroGPUCompiledModel, (self.archive_file, Weights(weight_dict), True)
-
-
-def aoti_compile(
-    exported_program: torch.export.ExportedProgram,
-    inductor_configs: dict[str, Any] | None = None,
-):
-    inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
-    gm = exported_program.module()
-    assert exported_program.example_inputs is not None
-    args, kwargs = exported_program.example_inputs
-    artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
-    archive_file = BytesIO()
-    files = [file for file in artifacts if isinstance(file, str)]
-    package_aoti(archive_file, files)
-    weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
-    return ZeroGPUCompiledModel(archive_file, weights)
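
This hand-rolled helper is superseded by the `spaces` AoT utilities used in `optimization.py`: where the old `app.py` called this module's `aoti_compile(exported, inductor_configs)` and then reassigned `pipeline.transformer` by hand, the new code calls `spaces.aoti_compile(exported, INDUCTOR_CONFIGS)` and attaches the compiled transformer with `spaces.aoti_apply(compile_transformer(), pipeline.transformer)`.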