fa3 + big refresh
- app.py +4 -55
- fa3.py +115 -0
- optimization.py +43 -0
- requirements.txt +1 -1
- zerogpu.py +0 -62
app.py
CHANGED
```diff
@@ -1,75 +1,24 @@
-"""
-"""
-# Upgrade PyTorch
-import os
-os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
-
-# Actual app.py
-import os
 from datetime import datetime
 
 import gradio as gr
 import spaces
 import torch
 from diffusers import FluxPipeline
-from torchao.quantization import quantize_
-from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
 
-from zerogpu import aoti_compile
+from optimization import optimize_pipeline_
 
 
 pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16).to('cuda')
-
-
-@spaces.GPU(duration=1500)
-def compile_transformer():
-
-    pipeline.transformer.fuse_qkv_projections()
-    quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
-
-    def _example_tensor(*shape):
-        return torch.randn(*shape, device='cuda', dtype=torch.bfloat16)
-
-    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
-    seq_length = 256 if is_timestep_distilled else 512
-
-    transformer_kwargs = {
-        'hidden_states': _example_tensor(1, 4096, 64),
-        'timestep': torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
-        'guidance': None if is_timestep_distilled else torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
-        'pooled_projections': _example_tensor(1, 768),
-        'encoder_hidden_states': _example_tensor(1, seq_length, 4096),
-        'txt_ids': _example_tensor(seq_length, 3),
-        'img_ids': _example_tensor(4096, 3),
-        'joint_attention_kwargs': {},
-        'return_dict': False,
-    }
-
-    inductor_configs = {
-        'conv_1x1_as_mm': True,
-        'epilogue_fusion': False,
-        'coordinate_descent_tuning': True,
-        'coordinate_descent_check_all_directions': True,
-        'max_autotune': True,
-        'triton.cudagraphs': True,
-    }
-
-    exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)
-
-    return aoti_compile(exported, inductor_configs)
-
-
-transformer_config = pipeline.transformer.config
-pipeline.transformer = compile_transformer()
-pipeline.transformer.config = transformer_config
+optimize_pipeline_(pipeline, "prompt")
 
 
 @spaces.GPU
 def generate_image(prompt: str):
+    generator = torch.Generator(device='cuda').manual_seed(42)
     t0 = datetime.now()
     images = []
     for _ in range(9):
-        image = pipeline(prompt, num_inference_steps=4).images[0]
+        image = pipeline(prompt, num_inference_steps=4, generator=generator).images[0]
         elapsed = -(t0 - (t0 := datetime.now()))
         images += [(image, f'{elapsed.total_seconds():.2f}s')]
         yield images
```
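Note on the timing line kept in `generate_image`: `elapsed = -(t0 - (t0 := datetime.now()))` measures each image's latency while also resetting the reference timestamp via the walrus operator. A minimal sketch of the equivalent explicit form (the loop body is illustrative):

```python
from datetime import datetime

t0 = datetime.now()
for _ in range(3):
    # ... one pipeline call goes here ...
    now = datetime.now()
    elapsed = now - t0  # same value as -(t0 - (t0 := datetime.now()))
    t0 = now            # reset the reference point for the next image
    print(f'{elapsed.total_seconds():.2f}s')
```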
fa3.py
ADDED
```diff
@@ -0,0 +1,115 @@
+"""
+"""
+
+import torch
+from kernels import get_kernel
+
+
+_flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
+
+
+@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
+def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    outputs, lse = _flash_attn_func(q, k, v)
+    return outputs
+
+@flash_attn_func.register_fake
+def _(q, k, v, **kwargs):
+    # two outputs:
+    # 1. output: (batch, seq_len, num_heads, head_dim)
+    # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
+    meta_q = torch.empty_like(q).contiguous()
+    return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
+
+# Copied FusedFluxAttnProcessor2_0 but using flash v3 instead of SDPA
+class FlashFusedFluxAttnProcessor3_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __call__(
+        self,
+        attn,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor | None = None,
+        attention_mask: torch.FloatTensor | None = None,
+        image_rotary_emb: torch.Tensor | None = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from diffusers.models.embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        # NB: transposes are necessary to match expected SDPA input shape
+        hidden_states = flash_attn_func(
+            query.transpose(1, 2),
+            key.transpose(1, 2),
+            value.transpose(1, 2))[0].transpose(1, 2)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
```
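Background on the pattern above: wrapping the kernel in `torch.library.custom_op` and giving it a `register_fake` (meta) implementation is what lets `torch.export` trace through flash-attention without launching the real CUDA kernel. A self-contained sketch of the same mechanism, using a made-up `mylib::scale_add` op rather than anything from this Space:

```python
import torch

# Real implementation: runs only in eager mode.
@torch.library.custom_op("mylib::scale_add", mutates_args=())
def scale_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return 2.0 * x + y

# Fake (meta) implementation: torch.export / torch.compile use it to
# propagate shapes and dtypes without executing the kernel.
@scale_add.register_fake
def _(x, y):
    return torch.empty_like(x)

class M(torch.nn.Module):
    def forward(self, x, y):
        return scale_add(x, y)

ep = torch.export.export(M(), (torch.randn(4), torch.randn(4)))
print(ep.graph)  # the custom op shows up as a single opaque node
```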
optimization.py
ADDED
```diff
@@ -0,0 +1,43 @@
+"""
+"""
+
+from typing import Any
+from typing import Callable
+from typing import ParamSpec
+import spaces
+import torch
+
+from fa3 import FlashFusedFluxAttnProcessor3_0
+
+
+P = ParamSpec('P')
+
+
+INDUCTOR_CONFIGS = {
+    'conv_1x1_as_mm': True,
+    'epilogue_fusion': False,
+    'coordinate_descent_tuning': True,
+    'coordinate_descent_check_all_directions': True,
+    'max_autotune': True,
+    'triton.cudagraphs': True,
+}
+
+
+def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
+
+    @spaces.GPU(duration=1500)
+    def compile_transformer():
+
+        with spaces.aoti_capture(pipeline.transformer) as call:
+            pipeline(*args, **kwargs)
+
+        exported = torch.export.export(
+            mod=pipeline.transformer,
+            args=call.args,
+            kwargs=call.kwargs,
+        )
+
+        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
+
+    pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
+    spaces.aoti_apply(compile_transformer(), pipeline.transformer)
```
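Read together with app.py, the new flow is: install the FA3 attention processor, run the pipeline once under `spaces.aoti_capture` to record the exact args/kwargs the transformer receives, export and AOT-compile that call inside a `@spaces.GPU` worker, then graft the compiled artifact back onto the transformer with `spaces.aoti_apply`. As a rough mental model of the capture step only (a plain forward hook, purely illustrative and not the actual `spaces` implementation):

```python
import contextlib
import torch

@contextlib.contextmanager
def capture_call(module: torch.nn.Module):
    """Record the (args, kwargs) of the next call to `module`."""
    class _Call:
        args: tuple = ()
        kwargs: dict = {}

    call = _Call()

    def hook(_module, args, kwargs, _output):
        call.args, call.kwargs = args, kwargs

    handle = module.register_forward_hook(hook, with_kwargs=True)
    try:
        yield call
    finally:
        handle.remove()

# Usage mirrors optimization.py:
#   with capture_call(pipeline.transformer) as call:
#       pipeline("prompt")
#   exported = torch.export.export(pipeline.transformer, args=call.args, kwargs=call.kwargs)
```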
requirements.txt
CHANGED
```diff
@@ -3,4 +3,4 @@ diffusers
 transformers
 sentencepiece
 protobuf
-
+kernels
```
zerogpu.py
DELETED
```diff
@@ -1,62 +0,0 @@
-"""
-"""
-from contextvars import ContextVar
-from io import BytesIO
-from typing import Any
-from typing import cast
-
-import torch
-from torch._inductor.package.package import package_aoti
-from torch.export.pt2_archive._package import AOTICompiledModel
-from torch.export.pt2_archive._package_weights import TensorProperties
-from torch.export.pt2_archive._package_weights import Weights
-
-
-INDUCTOR_CONFIGS_OVERRIDES = {
-    'aot_inductor.package_constants_in_so': False,
-    'aot_inductor.package_constants_on_disk': True,
-    'aot_inductor.package': True,
-}
-
-
-class ZeroGPUCompiledModel:
-    def __init__(self, archive_file: torch.types.FileLike, weights: Weights, cuda: bool = False):
-        self.archive_file = archive_file
-        self.weights = weights
-        if cuda:
-            self.weights_to_cuda_()
-        self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
-    def weights_to_cuda_(self):
-        for name in self.weights:
-            tensor, properties = self.weights.get_weight(name)
-            self.weights[name] = (tensor.to('cuda'), properties)
-    def __call__(self, *args, **kwargs):
-        if (compiled_model := self.compiled_model.get()) is None:
-            constants_map = {name: value[0] for name, value in self.weights.items()}
-            compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
-            compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
-            self.compiled_model.set(compiled_model)
-        return compiled_model(*args, **kwargs)
-    def __reduce__(self):
-        weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]] = {}
-        for name in self.weights:
-            tensor, properties = self.weights.get_weight(name)
-            tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
-            weight_dict[name] = (tensor_.copy_(tensor).detach().share_memory_(), properties)
-        return ZeroGPUCompiledModel, (self.archive_file, Weights(weight_dict), True)
-
-
-def aoti_compile(
-    exported_program: torch.export.ExportedProgram,
-    inductor_configs: dict[str, Any] | None = None,
-):
-    inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
-    gm = exported_program.module()
-    assert exported_program.example_inputs is not None
-    args, kwargs = exported_program.example_inputs
-    artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
-    archive_file = BytesIO()
-    files = [file for file in artifacts if isinstance(file, str)]
-    package_aoti(archive_file, files)
-    weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
-    return ZeroGPUCompiledModel(archive_file, weights)
```
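The deleted `ZeroGPUCompiledModel.__reduce__` appears designed to make the compiled model picklable across ZeroGPU worker processes: weights are copied into pinned, shared CPU memory, and the reconstructed object moves them back to CUDA (the `True` in the returned constructor args is the `cuda` flag). A minimal sketch of the `__reduce__` protocol itself, with an illustrative `Box` class:

```python
import pickle

class Box:
    def __init__(self, payload, restored=False):
        self.payload = payload
        self.restored = restored

    def __reduce__(self):
        # Pickle stores (callable, args) and calls Box(payload, True) on load,
        # so the rebuilt object can differ from the original (here: a flipped flag;
        # in zerogpu.py: shared-memory weights plus cuda=True).
        return Box, (self.payload, True)

box = pickle.loads(pickle.dumps(Box("weights")))
assert box.restored
```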