import gc
import os
from functools import wraps

import torch
import torch._dynamo
from huggingface_hub.constants import HF_HUB_CACHE
from torch import Generator
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from PIL.Image import Image
from pipelines.models import TextToImageRequest

# Global settings: tolerate dynamo compilation failures, let the CUDA
# allocator grow its segments, and keep tokenizer parallelism enabled.
torch._dynamo.config.suppress_errors = True
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"

# Alias used in the signatures below; load_pipeline returns a DiffusionPipeline.
Pipeline = DiffusionPipeline
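
# For reference, a minimal sketch (a hypothetical helper, not used below) of
# the Hugging Face cache layout that load_pipeline's hardcoded transformer_path
# follows: snapshots live under "models--<org>--<name>/snapshots/<revision>".
def snapshot_path(repo_id: str, revision: str) -> str:
    return os.path.join(
        HF_HUB_CACHE, f"models--{repo_id.replace('/', '--')}", "snapshots", revision
    )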
# Decorator that logs any exception raised by the wrapped function and returns
# None instead of propagating it, so callers can fall back gracefully.
def error_handler(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")
            return None
    return wrapper
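
# For illustration (hypothetical function, assuming only the decorator above):
# a failing call is reported and swallowed rather than raised.
#
#   @error_handler
#   def boom():
#       raise RuntimeError("oops")
#
#   boom()  # prints "Error in boom: oops" and returns None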
# Collect unreferenced Python objects, release cached CUDA memory, and reset
# the peak-memory counters before each inference call.
def remove_cache():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
# Fuse QKV projections, switch to channels_last memory layout, and tune torch
# inductor; decorated so any failure yields None instead of an exception.
@error_handler
def optimize_pipeline(pipe):
    # Fuse QKV projections in the transformer and VAE
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()
    # Optimize memory layout for CUDA kernels
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)
    # Configure torch inductor: show compilation progress and lower 1x1
    # convolutions to matrix multiplies
    from torch._inductor import config as ind_config
    ind_config.disable_progress = False
    ind_config.conv_1x1_as_mm = True
    return pipe
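
# A possible further step (a sketch, not used by load_pipeline below): the
# inductor settings above pair naturally with torch.compile. Assumptions not
# verified here: that the fused transformer traces cleanly, and that the
# longer first-call compilation of mode="max-autotune" pays off.
@error_handler
def compile_pipeline(pipe):
    pipe.transformer = torch.compile(
        pipe.transformer, mode="max-autotune", fullgraph=False
    )
    return pipe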
def load_pipeline() -> Pipeline:
    # Load the fine-tuned transformer directly from the local hub cache
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665"
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False
    )
    try:
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            transformer=transformer,
            torch_dtype=torch.bfloat16
        )
    except Exception:
        # Fall back to the checkpoint's own transformer if injection fails
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            torch_dtype=torch.bfloat16
        )
    pipeline.to("cuda")
    # Apply optimizations; optimize_pipeline returns None on failure, so only
    # adopt the result when it succeeded
    optimized_pipeline = optimize_pipeline(pipeline)
    if optimized_pipeline is not None:
        pipeline = optimized_pipeline
    # Warmup run to trigger compilation and allocator growth before serving
    warmup_prompt = "pantomorphia, dorsilateral, nonlife, unenthusiastic, quadriform, throatlet, bluntish, soldierize"
    pipeline(
        prompt=warmup_prompt,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256
    )
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    remove_cache()
    # Seeded generator on the pipeline's device for reproducible outputs
    generator = Generator(pipeline.device).manual_seed(request.seed)
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
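
# Hedged usage sketch. Assumptions not confirmed by this file: that
# TextToImageRequest accepts these keyword fields directly (infer() only reads
# .prompt, .seed, .height, and .width) and that "out.png" is a sensible
# output path.
if __name__ == "__main__":
    pipe = load_pipeline()
    req = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        seed=0,
        width=1024,
        height=1024,
    )
    infer(req, pipe).save("out.png")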