Spaces:
Running
on
Zero
Running
on
Zero
Compilation
Browse files- app.py +3 -0
- optimization.py +54 -0
- zerogpu.py +62 -0
app.py
CHANGED
|
@@ -8,9 +8,12 @@ from PIL import Image
|
|
| 8 |
from diffusers import FluxKontextPipeline
|
| 9 |
from diffusers.utils import load_image
|
| 10 |
|
|
|
|
|
|
|
| 11 |
MAX_SEED = np.iinfo(np.int32).max
|
| 12 |
|
| 13 |
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
|
|
|
|
| 14 |
|
| 15 |
@spaces.GPU
|
| 16 |
def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
|
|
|
|
| 8 |
from diffusers import FluxKontextPipeline
|
| 9 |
from diffusers.utils import load_image
|
| 10 |
|
| 11 |
+
from optimization import optimize_pipeline_
|
| 12 |
+
|
| 13 |
MAX_SEED = np.iinfo(np.int32).max
|
| 14 |
|
| 15 |
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
|
| 16 |
+
optimize_pipeline_(pipe)
|
| 17 |
|
| 18 |
@spaces.GPU
|
| 19 |
def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
|
optimization.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
"""
|
| 3 |
+
|
| 4 |
+
import spaces
|
| 5 |
+
import torch
|
| 6 |
+
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
| 7 |
+
from torchao.quantization import quantize_
|
| 8 |
+
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
|
| 9 |
+
|
| 10 |
+
from zerogpu import aoti_compile
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def optimize_pipeline_(pipeline: FluxPipeline):
|
| 14 |
+
|
| 15 |
+
@spaces.GPU(duration=1500)
|
| 16 |
+
def compile_transformer():
|
| 17 |
+
|
| 18 |
+
pipeline.transformer.fuse_qkv_projections()
|
| 19 |
+
quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
|
| 20 |
+
|
| 21 |
+
def _example_tensor(*shape):
|
| 22 |
+
return torch.randn(*shape, device='cuda', dtype=torch.bfloat16)
|
| 23 |
+
|
| 24 |
+
is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
|
| 25 |
+
seq_length = 256 if is_timestep_distilled else 512
|
| 26 |
+
|
| 27 |
+
transformer_kwargs = {
|
| 28 |
+
'hidden_states': _example_tensor(1, 4096, 64),
|
| 29 |
+
'timestep': torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
|
| 30 |
+
'guidance': None if is_timestep_distilled else torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
|
| 31 |
+
'pooled_projections': _example_tensor(1, 768),
|
| 32 |
+
'encoder_hidden_states': _example_tensor(1, seq_length, 4096),
|
| 33 |
+
'txt_ids': _example_tensor(seq_length, 3),
|
| 34 |
+
'img_ids': _example_tensor(4096, 3),
|
| 35 |
+
'joint_attention_kwargs': {},
|
| 36 |
+
'return_dict': False,
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
inductor_configs = {
|
| 40 |
+
'conv_1x1_as_mm': True,
|
| 41 |
+
'epilogue_fusion': False,
|
| 42 |
+
'coordinate_descent_tuning': True,
|
| 43 |
+
'coordinate_descent_check_all_directions': True,
|
| 44 |
+
'max_autotune': True,
|
| 45 |
+
'triton.cudagraphs': True,
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)
|
| 49 |
+
|
| 50 |
+
return aoti_compile(exported, inductor_configs)
|
| 51 |
+
|
| 52 |
+
transformer_config = pipeline.transformer.config
|
| 53 |
+
pipeline.transformer = compile_transformer()
|
| 54 |
+
pipeline.transformer.config = transformer_config
|
zerogpu.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
"""
|
| 3 |
+
from contextvars import ContextVar
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
from typing import Any
|
| 6 |
+
from typing import cast
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch._inductor.package.package import package_aoti
|
| 10 |
+
from torch.export.pt2_archive._package import AOTICompiledModel
|
| 11 |
+
from torch.export.pt2_archive._package_weights import TensorProperties
|
| 12 |
+
from torch.export.pt2_archive._package_weights import Weights
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
INDUCTOR_CONFIGS_OVERRIDES = {
|
| 16 |
+
'aot_inductor.package_constants_in_so': False,
|
| 17 |
+
'aot_inductor.package_constants_on_disk': True,
|
| 18 |
+
'aot_inductor.package': True,
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ZeroGPUCompiledModel:
|
| 23 |
+
def __init__(self, archive_file: torch.types.FileLike, weights: Weights, cuda: bool = False):
|
| 24 |
+
self.archive_file = archive_file
|
| 25 |
+
self.weights = weights
|
| 26 |
+
if cuda:
|
| 27 |
+
self.weights_to_cuda_()
|
| 28 |
+
self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
|
| 29 |
+
def weights_to_cuda_(self):
|
| 30 |
+
for name in self.weights:
|
| 31 |
+
tensor, properties = self.weights.get_weight(name)
|
| 32 |
+
self.weights[name] = (tensor.to('cuda'), properties)
|
| 33 |
+
def __call__(self, *args, **kwargs):
|
| 34 |
+
if (compiled_model := self.compiled_model.get()) is None:
|
| 35 |
+
constants_map = {name: value[0] for name, value in self.weights.items()}
|
| 36 |
+
compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
|
| 37 |
+
compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
|
| 38 |
+
self.compiled_model.set(compiled_model)
|
| 39 |
+
return compiled_model(*args, **kwargs)
|
| 40 |
+
def __reduce__(self):
|
| 41 |
+
weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]] = {}
|
| 42 |
+
for name in self.weights:
|
| 43 |
+
tensor, properties = self.weights.get_weight(name)
|
| 44 |
+
tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
|
| 45 |
+
weight_dict[name] = (tensor_.copy_(tensor).detach().share_memory_(), properties)
|
| 46 |
+
return ZeroGPUCompiledModel, (self.archive_file, Weights(weight_dict), True)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def aoti_compile(
|
| 50 |
+
exported_program: torch.export.ExportedProgram,
|
| 51 |
+
inductor_configs: dict[str, Any] | None = None,
|
| 52 |
+
):
|
| 53 |
+
inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
|
| 54 |
+
gm = exported_program.module()
|
| 55 |
+
assert exported_program.example_inputs is not None
|
| 56 |
+
args, kwargs = exported_program.example_inputs
|
| 57 |
+
artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
|
| 58 |
+
archive_file = BytesIO()
|
| 59 |
+
files = [file for file in artifacts if isinstance(file, str)]
|
| 60 |
+
package_aoti(archive_file, files)
|
| 61 |
+
weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
|
| 62 |
+
return ZeroGPUCompiledModel(archive_file, weights)
|