File size: 3,988 Bytes
a7ff77b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# ======================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
import os

from pipelines.models import TextToImageRequest
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline

from diffusers import FluxPipeline, AutoencoderKL
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
# Process-wide runtime knobs, applied once at import time.
#  - PYTORCH_CUDA_ALLOC_CONF: ask the CUDA caching allocator for expandable
#    segments. NOTE(review): this is set after `import torch`; presumably it
#    still takes effect because the allocator is configured lazily on first
#    CUDA use. Confirm that nothing allocates on the GPU before this line.
#  - TOKENIZERS_PARALLELISM: explicitly enable tokenizer parallelism.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)
# Let torch.compile/dynamo fall back to eager execution instead of raising
# when it cannot compile a region.
torch._dynamo.config.suppress_errors = True
# Placeholder only used as the return/parameter annotation below.
Pipeline = None
from utils import Config
# ======================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================

# Project settings object; load_pipeline() reads `trans_path` from it.
config = Config()

# Checkpoint id/path handed to FluxPipeline.from_pretrained, and the revision
# it is pinned to.
CKPT, REV = config.main_path, config.rev

def load_pipeline() -> Pipeline:
    """Build the bfloat16 FLUX pipeline on CUDA, enable first-block caching, and warm it up.

    The transformer is loaded from ``HF_HUB_CACHE/config.trans_path``, the VAE
    from a pinned FLUX.1-schnell revision, and the remaining components from
    ``CKPT`` at revision ``REV`` (local files only).

    Returns:
        The cache-wrapped pipeline returned by ``apply_cache_on_pipe``.
    """
    path = os.path.join(HF_HUB_CACHE, config.trans_path)
    # NOTE(review): use_safetensors=False makes diffusers load non-safetensors
    # (pickle-based) weights; only acceptable for a trusted local checkpoint.
    transformer = FluxTransformer2DModel.from_pretrained(
        path, torch_dtype=torch.bfloat16, use_safetensors=False
    )
    # Unlike the pipeline below, this call is not restricted to local files,
    # so it may contact the Hub if the pinned revision is not already cached.
    vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        revision="741f7c3ce8b383c54771c7003378a50191e9efe9",
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    )
    pipeline = FluxPipeline.from_pretrained(
        CKPT,
        revision=REV,
        transformer=transformer,
        vae=vae,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")
    # NOTE(review): 0.756 looks like an empirically tuned threshold for
    # ParaAttention's first-block cache (a higher value reuses cached residuals
    # more often, trading fidelity for speed). Confirm before changing.
    pipeline = apply_cache_on_pipe(pipeline, residual_diff_threshold=0.756)

    # Warm up twice with the same sampling settings infer() uses. Relying on
    # FluxPipeline's own defaults (28 steps, max_sequence_length=512) would make
    # warmup far slower and exercise different text-sequence shapes than real
    # requests, defeating the purpose of warming up.
    for _ in range(2):
        pipeline(
            "a beautiful girl only",
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate a single image for ``request`` with an already-loaded pipeline.

    Sampling is fixed at 4 denoising steps, ``guidance_scale=0.0`` and a
    256-token text limit. The output size comes from ``request.height`` and
    ``request.width``. A generator on the pipeline's device is seeded from
    ``request.seed``, so a given seed and prompt reproduce the same image.

    NOTE(review): assumes ``request.seed`` is always an int. A ``None`` seed
    would fail in ``manual_seed``; confirm against TextToImageRequest.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # Collect Python garbage before each generation.
    # NOTE(review): this adds per-request latency.
    gc.collect()

    rng = Generator(pipeline.device)
    rng.manual_seed(request.seed)

    prompt = request.prompt
    sampling_options = {
        "generator": rng,
        "guidance_scale": 0.0,
        "num_inference_steps": 4,
        "max_sequence_length": 256,
        "height": request.height,
        "width": request.width,
    }
    output = pipeline(prompt, **sampling_options)
    return output.images[0]