"""Text-to-image inference for a FLUX schnell checkpoint.

Loads the model from the Hugging Face cache, applies torch.compile and
memory-format optimizations, and serves seeded generation requests."""

import os
from functools import wraps
from typing import Callable

import torch
import torch._dynamo
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from torch import Generator
from torch._inductor import config as ind_config

from diffusers import DiffusionPipeline, FluxTransformer2DModel
from pipelines.models import TextToImageRequest


def error_handler(func: Callable):
    """Run `func`, logging any exception and returning None instead of raising."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")
            return None
    return wrapper


class TorchOptimizer:
    """Global PyTorch performance settings and CUDA memory housekeeping."""

    def optimize_settings(self):
        # Allow TF32 for matmul/cuDNN and let cuDNN autotune kernels for the
        # fixed shapes used here.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.set_float32_matmul_precision("high")

    def clear_cache(self):
        # Release cached CUDA blocks and reset peak-memory statistics
        # (reset_peak_memory_stats subsumes the deprecated
        # reset_max_memory_allocated).
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


class PipelineManager:
    """Loads, optimizes, and runs the FLUX text-to-image pipeline."""

    def __init__(self):
        self.ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
        self.revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
        self.pipeline = None
        self.optimizer = TorchOptimizer()

        # Swallow Dynamo compilation errors (falling back to eager), let the
        # CUDA allocator grow in expandable segments, and keep tokenizer
        # parallelism enabled.
        torch._dynamo.config.suppress_errors = True
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        os.environ["TOKENIZERS_PARALLELISM"] = "True"

        self.optimizer.optimize_settings()

    def load_transformer(self):
        # Load the transformer weights straight from the local HF Hub snapshot.
        transformer_path = os.path.join(
            HF_HUB_CACHE,
            "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665",
        )
        return FluxTransformer2DModel.from_pretrained(
            transformer_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )

    @error_handler
    def optimize_pipeline(self, pipe):
        # Fuse the attention QKV projections to reduce kernel launches.
        pipe.transformer.fuse_qkv_projections()
        pipe.vae.fuse_qkv_projections()

        # channels_last memory layout suits the convolution-heavy VAE and the
        # transformer blocks.
        pipe.transformer.to(memory_format=torch.channels_last)
        pipe.vae.to(memory_format=torch.channels_last)

        # Inductor settings: show compile progress and lower 1x1 convolutions
        # to matmuls. Use the imported alias rather than rebinding `config`.
        ind_config.disable_progress = False
        ind_config.conv_1x1_as_mm = True

        # Compile the two hot paths with autotuning; fullgraph=True forbids
        # graph breaks.
        pipe.transformer = torch.compile(
            pipe.transformer,
            mode="max-autotune",
            fullgraph=True,
        )
        pipe.vae.decode = torch.compile(
            pipe.vae.decode,
            mode="max-autotune",
            fullgraph=True,
        )

        return pipe

    def load_pipeline(self):
        transformer_model = self.load_transformer()

        pipe = DiffusionPipeline.from_pretrained(
            self.ckpt_root,
            revision=self.revision_root,
            transformer=transformer_model,
            torch_dtype=torch.bfloat16,
        )
        pipe.to("cuda")

        # error_handler returns None on failure, so fall back to the
        # unoptimized pipeline rather than propagating None.
        pipe = self.optimize_pipeline(pipe) or pipe

        # Warm-up generation so torch.compile autotuning happens at load time,
        # not on the first real request.
        print("Running torch compilation...")
        pipe(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4,
        ).images[0]
        print("Finished torch compilation")

        return pipe

    def run_inference(self, request: TextToImageRequest) -> Image:
        # Build the pipeline lazily on the first request.
        if self.pipeline is None:
            self.pipeline = self.load_pipeline()

        self.optimizer.clear_cache()
        generator = Generator(self.pipeline.device).manual_seed(request.seed)

        # Schnell is a guidance-distilled model: guidance_scale=0.0 and four
        # inference steps are the intended settings.
        return self.pipeline(
            request.prompt,
            generator=generator,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
        ).images[0]
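

# Usage sketch (assumption: TextToImageRequest accepts these exact keyword
# fields; adjust to the real definition in pipelines.models). The first call
# triggers torch.compile, so expect a long warm-up before the image is saved.
if __name__ == "__main__":
    manager = PipelineManager()
    request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        height=1024,
        width=1024,
        seed=42,
    )
    manager.run_inference(request).save("output.png")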