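"""Text-to-image inference pipeline for MyApricity/FLUX_OPT_SCHNELL_1.2.

Loads a Flux Schnell checkpoint with a bf16 T5-XXL text encoder, optionally
wraps the transformer in torch.compile, warms the pipeline up with two dummy
generations, and serves seeded requests through infer().
"""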
import gc
import os

import torch
import torch._dynamo
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from torch import Generator
from transformers import T5EncoderModel

from pipelines.models import TextToImageRequest
# Dynamo can fail on some graph breaks in diffusers; fall back to eager instead of raising.
torch._dynamo.config.suppress_errors = True
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"

# Placeholder alias used only in the type annotations below.
Pipeline = None

# Toggle for wrapping the pipeline in the torch.compile helper defined below.
use_com = False
def remove_cache():
    """Free cached CUDA memory and reset allocator statistics between requests."""
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    gc.collect()
    torch.cuda.reset_peak_memory_stats()
def text_t5_loader() -> T5EncoderModel:
    """Load the bf16 T5-XXL encoder used as Flux's second text encoder."""
    print("Loading text encoder...")
    text_encoder = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
    return text_encoder.to(memory_format=torch.channels_last)
class StableDiffusionTransformerCompile:
    """Thin wrapper that optionally runs torch.compile over the pipeline's denoiser."""

    def __init__(self, pipeline, optimize=False):
        self.pipeline = pipeline
        self.optimize = optimize
        if self.optimize:
            self.model_compiling()

    def model_compiling(self):
        # Flux pipelines expose the denoiser as `transformer`, not `unet`.
        self.pipeline.transformer = torch.compile(self.pipeline.transformer)

    def __call__(self, *args, **kwargs):
        return self.pipeline(*args, **kwargs)
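# A minimal sketch of how the wrapper above would be enabled, assuming a pipeline
# with a `transformer` attribute is already loaded. The first call after compiling
# pays a one-time graph-capture cost, so enabling it pairs naturally with the
# warm-up generations in load_pipeline() below:
#
#     compiled = StableDiffusionTransformerCompile(pipeline, optimize=True)
#     image = compiled(prompt="a photo of a cat", num_inference_steps=4).images[0]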
def load_pipeline() -> Pipeline:
    """Build the Flux Schnell pipeline with the bf16 T5 encoder and warm it up."""
    text_t5_encoder = text_t5_loader()
    transformer_path__ = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665",
    )
    transformer__ = FluxTransformer2DModel.from_pretrained(
        transformer_path__, torch_dtype=torch.bfloat16, use_safetensors=False
    )
    try:
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            transformer=transformer__,
            torch_dtype=torch.bfloat16,
        )
    except Exception:
        # Fall back to the checkpoint's own transformer if the local snapshot is unusable.
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            torch_dtype=torch.bfloat16,
        )
    pipeline.to("cuda")
    # Attach the bf16 T5 encoder regardless of whether compilation is enabled.
    pipeline.text_encoder_2 = text_t5_encoder
    try:
        compiled_pipeline = StableDiffusionTransformerCompile(pipeline, optimize=False)
        if use_com:
            pipeline = compiled_pipeline
        else:
            print("Currently not compiling the transformer")
            # Non-standard method assumed to exist on this custom checkpoint's
            # pipeline class; the except below swallows the call if it is absent.
            pipeline.disable_vae_compress()
    except Exception:
        print("Pipeline post-processing step failed; continuing with defaults")
    # Two throwaway generations warm up CUDA kernels and lazy initialization so
    # that timed requests in infer() run at full speed.
    prompt_1 = "albaspidin, pillmonger, palaeocrystalline"
    pipeline(
        prompt=prompt_1,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    prompt_2 = "obe, kilometrage, circuition"
    pipeline(
        prompt=prompt_2,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    # Clear cached allocations from the previous request before generating.
    remove_cache()
    generator = Generator(pipeline.device).manual_seed(request.seed)
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
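# A minimal local smoke test, assuming TextToImageRequest accepts the fields used
# above (prompt, seed, height, width); the exact constructor signature lives in
# pipelines.models and may differ.
#
#     if __name__ == "__main__":
#         pipe = load_pipeline()
#         request = TextToImageRequest(
#             prompt="a lighthouse at dusk", seed=0, height=1024, width=1024
#         )
#         infer(request, pipe).save("out.png")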