# FLUX.1-Kontext-Dev / optimization.py
# Author: cbensimon (HF Staff)
# Last commit: b63cd34 ("capture_component_call")
# File size: 1.12 kB
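"""Ahead-of-time (AOT) compilation of a pipeline's transformer component.

The transformer is exported with ``torch.export`` and compiled with
``zerogpu.aoti_compile`` inside a ZeroGPU (``spaces.GPU``) task.
"""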
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from pipeline_utils import capture_component_call
from zerogpu import aoti_compile
P = ParamSpec('P')
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
@spaces.GPU(duration=1500)
def compile_transformer():
with capture_component_call(pipeline, 'transformer') as call:
pipeline(*args, **kwargs)
pipeline.transformer.fuse_qkv_projections()
exported = torch.export.export(pipeline.transformer, args=call.args, kwargs=call.kwargs)
return aoti_compile(exported, INDUCTOR_CONFIGS)
transformer_config = pipeline.transformer.config
pipeline.transformer = compile_transformer()
pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]