# OpsTorch: src/pipeline.py
import os
from functools import wraps
from typing import Callable

import torch
import torch._dynamo
import torch._inductor
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from torch import Generator

from pipelines.models import TextToImageRequest

def error_handler(func: Callable):
    """Log and swallow exceptions; the wrapped call returns None on failure."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")
    return wrapper
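
# Hypothetical usage sketch of the decorator above (not part of the pipeline):
# a failing call is logged and yields None instead of raising, so callers must
# check the result.
#
#   @error_handler
#   def might_fail():
#       raise RuntimeError("boom")
#
#   result = might_fail()  # prints "Error in might_fail: boom"; result is None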

class TorchOptimizer:
    def optimize_settings(self):
        # Allow TF32 matmuls/convolutions and autotuned cuDNN kernels.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.set_float32_matmul_precision("high")

    def clear_cache(self):
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias of this call,
        # so resetting the peak memory stats once is sufficient.
        torch.cuda.reset_peak_memory_stats()

class PipelineManager:
    def __init__(self):
        self.ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
        self.revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
        self.pipeline = None
        self.optimizer = TorchOptimizer()

        # Configure environment
        torch._dynamo.config.suppress_errors = True
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        os.environ["TOKENIZERS_PARALLELISM"] = "true"

        # Initialize torch settings
        self.optimizer.optimize_settings()
    def load_transformer(self):
        # Load the transformer weights from the local HF Hub cache snapshot
        # that corresponds to self.ckpt_root at self.revision_root.
        transformer_path = os.path.join(
            HF_HUB_CACHE,
            "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665",
        )
        return FluxTransformer2DModel.from_pretrained(
            transformer_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )
    @error_handler
    def optimize_pipeline(self, pipe):
        # Fuse QKV projections in the transformer and VAE attention blocks
        pipe.transformer.fuse_qkv_projections()
        pipe.vae.fuse_qkv_projections()

        # Optimize memory layout
        pipe.transformer.to(memory_format=torch.channels_last)
        pipe.vae.to(memory_format=torch.channels_last)

        # Configure torch inductor
        inductor_cfg = torch._inductor.config
        inductor_cfg.disable_progress = False
        inductor_cfg.conv_1x1_as_mm = True

        # Compile the hot paths: the transformer forward and the VAE decoder
        pipe.transformer = torch.compile(
            pipe.transformer,
            mode="max-autotune",
            fullgraph=True,
        )
        pipe.vae.decode = torch.compile(
            pipe.vae.decode,
            mode="max-autotune",
            fullgraph=True,
        )
        return pipe
    def load_pipeline(self):
        # Load transformer model
        transformer_model = self.load_transformer()

        # Create pipeline
        pipe = DiffusionPipeline.from_pretrained(
            self.ckpt_root,
            revision=self.revision_root,
            transformer=transformer_model,
            torch_dtype=torch.bfloat16,
        )
        pipe.to("cuda")

        # Optimize pipeline; error_handler returns None on failure, so fall
        # back to the unoptimized pipeline instead of propagating None.
        optimized = self.optimize_pipeline(pipe)
        pipe = optimized if optimized is not None else pipe

        # Trigger compilation with a warm-up run
        print("Running torch compilation...")
        pipe(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4,
        ).images[0]
        print("Finished torch compilation")
        return pipe
    def run_inference(self, request: TextToImageRequest) -> Image:
        # Lazily build and compile the pipeline on first use
        if self.pipeline is None:
            self.pipeline = self.load_pipeline()
        self.optimizer.clear_cache()
        generator = Generator(self.pipeline.device).manual_seed(request.seed)
        return self.pipeline(
            request.prompt,
            generator=generator,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
        ).images[0]
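

# Minimal smoke-test sketch. Assumptions: a CUDA device is available, and
# TextToImageRequest (from pipelines.models) accepts prompt/height/width/seed
# keyword arguments, as suggested by run_inference above; adjust to the real
# request schema if it differs.
if __name__ == "__main__":
    manager = PipelineManager()
    request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        height=1024,
        width=1024,
        seed=42,
    )
    image = manager.run_inference(request)
    image.save("output.png")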