# OpsTorch/src/pipeline.py
import gc
import os

import torch
import torch._dynamo
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from torch import Generator
from transformers import T5EncoderModel

from pipelines.models import TextToImageRequest
# Fall back to eager execution instead of raising when torch.compile fails.
torch._dynamo.config.suppress_errors = True
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"

# Alias used in the annotations below; load_pipeline() returns a DiffusionPipeline.
Pipeline = DiffusionPipeline
use_com = False  # set True to route calls through the compiled wrapper
def remove_cache():
    # Release cached CUDA memory and reset the allocator's statistics
    # so each request starts from a clean slate.
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    gc.collect()
    torch.cuda.reset_peak_memory_stats()
def text_t5_loader() -> T5EncoderModel:
    print("Loading text encoder...")
    text_encoder = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
    # channels_last can improve memory-access patterns on CUDA.
    return text_encoder.to(memory_format=torch.channels_last)
class StableDiffusionTransformerCompile:
    """Optionally wraps a pipeline so its denoiser runs through torch.compile."""

    def __init__(self, pipeline, optimize=False):
        self.pipeline = pipeline
        self.optimize = optimize
        if self.optimize:
            self.model_compiling()

    def model_compiling(self):
        # Flux pipelines expose their denoiser as `transformer`, not `unet`.
        self.pipeline.transformer = torch.compile(self.pipeline.transformer)

    def __call__(self, *args, **kwargs):
        return self.pipeline(*args, **kwargs)
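
# Hypothetical sketch of enabling the wrapper (not exercised by default, since
# `use_com` is False and `optimize` is passed as False below); the prompt string
# is illustrative only:
#
#     pipe = StableDiffusionTransformerCompile(pipeline, optimize=True)
#     image = pipe(prompt="...", num_inference_steps=4).images[0]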
def load_pipeline() -> Pipeline:
    text_t5_encoder = text_t5_loader()
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path, torch_dtype=torch.bfloat16, use_safetensors=False
    )
    try:
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )
    except Exception:
        # Fall back to the checkpoint's bundled transformer if the local
        # snapshot fails to load.
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            torch_dtype=torch.bfloat16,
        )
    pipeline.to("cuda")
    # Swap in the standalone bf16 T5 encoder loaded above.
    pipeline.text_encoder_2 = text_t5_encoder
    try:
        compiled_pipeline = StableDiffusionTransformerCompile(pipeline, optimize=False)
        if use_com:
            pipeline = compiled_pipeline
        else:
            print("Compilation disabled; running the pipeline eagerly")
    except Exception:
        print("Failed to wrap the pipeline for compilation")
    # Two warm-up passes so CUDA kernels and caches are initialized before serving.
    prompt_1 = "albaspidin, pillmonger, palaeocrystalline"
    pipeline(
        prompt=prompt_1,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    prompt_2 = "obe, kilometrage, circuition"
    pipeline(
        prompt=prompt_2,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    # Free cached GPU memory before each request to reduce fragmentation.
    remove_cache()
    generator = Generator(pipeline.device).manual_seed(request.seed)
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
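
# Minimal smoke test, assuming TextToImageRequest can be constructed with the
# fields infer() reads (prompt, width, height, seed); the prompt string and the
# output filename are illustrative only.
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        width=1024,
        height=1024,
        seed=0,
    )
    infer(request, pipe).save("sample.png")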