# ======================================================================
from huggingface_hub.constants import HF_HUB_CACHE | |
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel | |
import torch | |
import torch._dynamo | |
import gc | |
import os | |
from pipelines.models import TextToImageRequest | |
from torch import Generator | |
from diffusers import FluxTransformer2DModel, DiffusionPipeline | |
from diffusers import FluxPipeline, AutoencoderKL | |
from PIL.Image import Image | |
from pipelines.models import TextToImageRequest | |
from torch import Generator | |
from diffusers import FluxTransformer2DModel, DiffusionPipeline | |
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe | |
# Must be set before the first CUDA allocation: lets the caching allocator
# grow segments instead of fragmenting them (helps avoid OOM over time).
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
# Re-enable HF tokenizer thread parallelism explicitly.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# If torch.compile / dynamo tracing fails anywhere, silently fall back to
# eager execution instead of raising.
torch._dynamo.config.suppress_errors = True
# Placeholder alias used only in the annotations below; it evaluates to None
# at def time. NOTE(review): `Pipeline = FluxPipeline` would be a truthful hint.
Pipeline = None
from utils import Config
# ======================================================================
# Model locations come from the project-level Config object.
config = Config()
CKPT = config.main_path  # checkpoint repo/path for the base Flux pipeline
REV = config.rev  # pinned revision of that checkpoint
def load_pipeline() -> Pipeline:
    """Assemble the Flux text-to-image pipeline and prepare it for serving.

    Loads a custom transformer from the local HF hub cache, a pinned VAE from
    the reference schnell checkpoint, moves everything to the GPU, enables
    first-block caching, and runs two throwaway warm-up generations.
    """
    # Custom transformer weights live inside the local HF hub cache.
    transformer_dir = os.path.join(HF_HUB_CACHE, config.trans_path)
    flux_transformer = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )
    # VAE comes from the reference checkpoint at a pinned revision.
    flux_vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        revision="741f7c3ce8b383c54771c7003378a50191e9efe9",
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxPipeline.from_pretrained(
        CKPT,
        revision=REV,
        transformer=flux_transformer,
        vae=flux_vae,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")
    # First-block cache: reuse work across denoising steps when residuals
    # change less than the configured threshold.
    pipe = apply_cache_on_pipe(pipe, residual_diff_threshold=0.756)
    # Two warm-up runs so compilation/caching happens before real requests.
    for _ in range(2):
        pipe("a beautiful girl only")
    return pipe
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Render one image for *request* with the pre-loaded pipeline.

    The generator is re-seeded per request so output is reproducible for a
    given (prompt, seed, size) triple.
    """
    gc.collect()  # drop leftover host-side garbage between requests
    rng = Generator(pipeline.device).manual_seed(request.seed)
    output = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,  # guidance disabled (scale 0)
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    return output.images[0]