import gc
import os
from functools import wraps

import torch
import torch._dynamo
from diffusers import DiffusionPipeline, FluxPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from torch import Generator

from pipelines.models import TextToImageRequest

# Let torch._dynamo fall back to eager execution instead of raising when it
# hits a graph break it cannot handle.
torch._dynamo.config.suppress_errors = True

# Reduce CUDA memory fragmentation and allow the fast (Rust) tokenizers to
# use multiple threads.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"

# Type alias for the annotations below; DiffusionPipeline.from_pretrained
# resolves this FLUX checkpoint to a FluxPipeline.
Pipeline = FluxPipeline


def error_handler(func):
    """Run ``func``; on failure, log the error and return None instead of
    raising, so callers can fall back to the unmodified input."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")
            return None

    return wrapper


def remove_cache():
    """Release cached CUDA memory and reset the allocator's peak stats."""
    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is a deprecated alias of
    # reset_peak_memory_stats(), so a single call covers both.
    torch.cuda.reset_peak_memory_stats()


@error_handler
def optimize_pipeline(pipe):
    # Fuse the attention QKV projections into single matmuls in both the
    # transformer and the VAE.
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()

    # channels_last memory layout is usually faster for the conv-heavy VAE
    # and costs nothing for the transformer.
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)

    # Inductor tweaks: keep the compilation progress bar visible and lower
    # 1x1 convolutions to plain matmuls.
    from torch._inductor import config as ind_config
    ind_config.disable_progress = False
    ind_config.conv_1x1_as_mm = True

    return pipe
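

# Note: the dynamo/inductor settings above only take effect if modules are
# actually compiled, which this file never requests explicitly. A minimal,
# optional sketch of opting in (an assumption, not part of the original flow;
# the compile targets and mode here are illustrative):
#
#     pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")
#     pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune")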


def load_pipeline() -> Pipeline:
    # Resolve the transformer weights directly from the local Hugging Face
    # cache (same repo and revision as ckpt_root/revision_root above).
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665",
    )

    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )

    try:
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )
    except Exception:
        # Fall back to the checkpoint's own transformer if injecting the
        # locally cached one fails.
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            torch_dtype=torch.bfloat16,
        )

    pipeline.to("cuda")

    # optimize_pipeline is wrapped in @error_handler and returns None on
    # failure, so only adopt the result when the optimizations succeeded.
    optimized_pipeline = optimize_pipeline(pipeline)
    if optimized_pipeline is not None:
        pipeline = optimized_pipeline

    # Warm-up generation: pays the compilation/caching cost up front with a
    # throwaway prompt so the first real request runs at full speed.
    warmup_prompt = (
        "pantomorphia, dorsilateral, nonlife, unenthusiastic, quadriform, "
        "throatlet, bluntish, soldierize"
    )
    pipeline(
        prompt=warmup_prompt,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )

    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    # Free memory held over from any previous request before generating.
    remove_cache()
    generator = Generator(pipeline.device).manual_seed(request.seed)

    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
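

# Hypothetical smoke test, not part of the serving flow. It assumes
# TextToImageRequest accepts prompt/width/height/seed as keyword arguments
# (its actual fields live in pipelines.models) and that a CUDA device with
# the cached checkpoint is available.
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        width=1024,
        height=1024,
        seed=0,
    )
    infer(request, pipe).save("sample.png")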