# OpsTorch/src/pipeline.py
import gc
import os
from functools import wraps

import torch
import torch._dynamo
from huggingface_hub.constants import HF_HUB_CACHE
from torch import Generator
from diffusers import DiffusionPipeline, FluxPipeline, FluxTransformer2DModel
from PIL.Image import Image
from pipelines.models import TextToImageRequest
# Global settings
torch._dynamo.config.suppress_errors = True  # fall back to eager mode on dynamo errors
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # reduce allocator fragmentation
os.environ["TOKENIZERS_PARALLELISM"] = "True"

ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"

# Type alias used in the annotations below; the loaded pipeline is a FluxPipeline.
Pipeline = FluxPipeline
def error_handler(func):
    """Run `func`, logging and returning None instead of raising on failure."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")
            return None
    return wrapper
def remove_cache():
    """Release cached CUDA memory and reset peak-memory statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is a deprecated alias for this call
    torch.cuda.reset_peak_memory_stats()
@error_handler
def optimize_pipeline(pipe):
    # Fuse QKV projections in attention blocks (fewer, larger matmuls)
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()
    # Use channels-last memory layout for better kernel performance
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)
    # Configure torch inductor
    from torch._inductor import config as ind_config
    ind_config.disable_progress = False  # show compilation progress
    ind_config.conv_1x1_as_mm = True  # treat 1x1 convolutions as matmuls
    return pipe
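
# Not called in the original flow: the inductor settings above only take effect
# once a module is actually compiled. A minimal sketch of that step, assuming a
# standard torch.compile "max-autotune" pass is wanted; `maybe_compile_transformer`
# is a hypothetical helper added here for illustration.
@error_handler
def maybe_compile_transformer(pipe):
    # The transformer dominates denoising time; the first call after this
    # will be slow while inductor autotunes its kernels.
    pipe.transformer = torch.compile(
        pipe.transformer, mode="max-autotune", fullgraph=True
    )
    return pipe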
def load_pipeline() -> Pipeline:
    # Load the transformer from the local HF cache snapshot of ckpt_root
    # pinned at revision_root
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )
    try:
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )
    except Exception:
        # Fall back to the checkpoint's own transformer weights
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            torch_dtype=torch.bfloat16,
        )
    pipeline.to("cuda")

    # Apply optimizations; error_handler returns None on failure,
    # in which case the unoptimized pipeline is kept
    optimized_pipeline = optimize_pipeline(pipeline)
    if optimized_pipeline is not None:
        pipeline = optimized_pipeline

    # Warmup run: triggers kernel selection ahead of the first real request
    warmup_prompt = "pantomorphia, dorsilateral, nonlife, unenthusiastic, quadriform, throatlet, bluntish, soldierize"
    pipeline(
        prompt=warmup_prompt,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for `request` using the preloaded pipeline."""
    remove_cache()
    generator = Generator(pipeline.device).manual_seed(request.seed)
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,  # Schnell is distilled for guidance-free sampling
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
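
# Minimal local smoke test (not part of the original file). infer() above reads
# .prompt, .height, .width, and .seed, so those fields exist on
# TextToImageRequest; the keyword-argument constructor is an assumption.
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        height=1024,
        width=1024,
        seed=0,
    )
    infer(request, pipe).save("output.png")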