import gc
import os
from functools import wraps

import torch
import torch._dynamo
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from torch import Generator

from pipelines.models import TextToImageRequest

# Global settings
torch._dynamo.config.suppress_errors = True
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"

# Type alias for the loaded pipeline object.
Pipeline = DiffusionPipeline


def error_handler(func):
    """Run `func`, log any exception, and return None on failure."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {str(e)}")
            return None

    return wrapper


def remove_cache():
    """Release cached CUDA memory and reset allocator statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


@error_handler
def optimize_pipeline(pipe):
    # Fuse QKV projections into single matmuls in the attention blocks.
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()

    # Use channels-last memory layout for the transformer and VAE.
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)

    # Configure torch inductor.
    from torch._inductor import config as ind_config

    ind_config.disable_progress = False
    ind_config.conv_1x1_as_mm = True

    return pipe


def load_pipeline() -> Pipeline:
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )

    try:
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )
    except Exception:
        # Fall back to the checkpoint's own transformer if the local
        # snapshot cannot be used.
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            torch_dtype=torch.bfloat16,
        )
    pipeline.to("cuda")

    # Apply optimizations; optimize_pipeline returns None on failure
    # (via @error_handler), in which case keep the unoptimized pipeline.
    optimized_pipeline = optimize_pipeline(pipeline)
    if optimized_pipeline is not None:
        pipeline = optimized_pipeline

    # Warmup run to trigger compilation and allocator growth.
    warmup_prompt = (
        "pantomorphia, dorsilateral, nonlife, unenthusiastic, "
        "quadriform, throatlet, bluntish, soldierize"
    )
    pipeline(
        prompt=warmup_prompt,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )

    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    remove_cache()

    generator = Generator(pipeline.device).manual_seed(request.seed)

    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
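

# --- Usage sketch (illustrative, not part of the served module) ---
# A minimal sketch of how load_pipeline() and infer() fit together when run
# standalone, assuming TextToImageRequest can be constructed with the
# prompt/seed/height/width fields that infer() reads; its exact schema lives
# in pipelines.models and is not shown here. The prompt and output filename
# are placeholders.
if __name__ == "__main__":
    pipe = load_pipeline()  # loads, optimizes, and warms up the pipeline
    request = TextToImageRequest(  # hypothetical constructor arguments
        prompt="a watercolor fox in a snowy forest",
        seed=42,
        height=1024,
        width=1024,
    )
    image = infer(request, pipe)
    image.save("output.png")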