import os

import torch
import torch._dynamo
from functools import wraps
from typing import Callable

from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from pipelines.models import TextToImageRequest
from torch import Generator


def error_handler(func: Callable):
    """Log and swallow exceptions; the wrapper returns None on failure."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")
    return wrapper

class TorchOptimizer:
    def optimize_settings(self):
        # Allow TF32 matmuls/convolutions and autotuned cuDNN kernels
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.set_float32_matmul_precision("high")

    def clear_cache(self):
        # Release cached CUDA memory and reset the peak-usage counters
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

class PipelineManager:
    def __init__(self):
        self.ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
        self.revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
        self.pipeline = None
        self.optimizer = TorchOptimizer()

        # Configure environment
        torch._dynamo.config.suppress_errors = True
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        os.environ["TOKENIZERS_PARALLELISM"] = "True"

        # Initialize torch settings
        self.optimizer.optimize_settings()
    def load_transformer(self):
        # Load the transformer weights directly from the local HF hub cache
        transformer_path = os.path.join(
            HF_HUB_CACHE,
            "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665",
        )
        return FluxTransformer2DModel.from_pretrained(
            transformer_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )
    @error_handler
    def optimize_pipeline(self, pipe):
        # Fuse QKV projections in the transformer and VAE attention blocks
        pipe.transformer.fuse_qkv_projections()
        pipe.vae.fuse_qkv_projections()

        # Optimize memory layout
        pipe.transformer.to(memory_format=torch.channels_last)
        pipe.vae.to(memory_format=torch.channels_last)

        # Configure torch inductor
        inductor_config = torch._inductor.config
        inductor_config.disable_progress = False
        inductor_config.conv_1x1_as_mm = True

        # Compile the two hot paths: the transformer and the VAE decoder
        pipe.transformer = torch.compile(
            pipe.transformer,
            mode="max-autotune",
            fullgraph=True,
        )
        pipe.vae.decode = torch.compile(
            pipe.vae.decode,
            mode="max-autotune",
            fullgraph=True,
        )
        return pipe
    def load_pipeline(self):
        # Load transformer model
        transformer_model = self.load_transformer()

        # Create pipeline
        pipe = DiffusionPipeline.from_pretrained(
            self.ckpt_root,
            revision=self.revision_root,
            transformer=transformer_model,
            torch_dtype=torch.bfloat16,
        )
        pipe.to("cuda")

        # Optimize pipeline; fall back to the unoptimized pipe if the
        # @error_handler-wrapped call swallowed an exception and returned None
        pipe = self.optimize_pipeline(pipe) or pipe

        # Trigger compilation with a warm-up pass
        print("Running torch compilation...")
        pipe(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4,
        ).images[0]
        print("Finished torch compilation")
        return pipe
    def run_inference(self, request: TextToImageRequest) -> Image:
        # Lazily build and compile the pipeline on first use
        if self.pipeline is None:
            self.pipeline = self.load_pipeline()
        self.optimizer.clear_cache()
        generator = Generator(self.pipeline.device).manual_seed(request.seed)
        return self.pipeline(
            request.prompt,
            generator=generator,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
        ).images[0]
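

# Minimal usage sketch (an addition, not part of the original module). It
# assumes a CUDA GPU, a cached copy of the model weights, and that
# TextToImageRequest accepts the prompt/seed/height/width fields that
# run_inference reads above.
if __name__ == "__main__":
    manager = PipelineManager()
    request = TextToImageRequest(
        prompt="a watercolor painting of a lighthouse at dawn",
        seed=42,
        height=1024,
        width=1024,
    )
    # The first call pays the torch.compile warm-up cost; later calls are fast.
    image = manager.run_inference(request)
    image.save("output.png")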