import gc
import os

import torch
import torch._dynamo
from torch import Generator

from diffusers import DiffusionPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from transformers import T5EncoderModel

from pipelines.models import TextToImageRequest


# Fall back to eager execution if torch.compile hits an error.
torch._dynamo.config.suppress_errors = True
# Let the CUDA caching allocator grow segments to reduce fragmentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Allow the fast tokenizers to use multiple threads.
os.environ["TOKENIZERS_PARALLELISM"] = "True"

# Pinned checkpoint and revision of the optimized FLUX schnell weights.
ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
# Type alias used in the signatures below.
Pipeline = DiffusionPipeline
# Gates the torch.compile wrapper in load_pipeline(); currently disabled.
use_com = False

def remove_cache():
    """Release cached CUDA memory and reset peak-memory statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
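
# Hypothetical check (not part of the original flow): peak allocation can be
# inspected after a generation to confirm the resets are effective, e.g.
#     print(f"peak CUDA memory: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")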



def text_t5_loader() -> T5EncoderModel:
    """Load the bf16 T5-XXL encoder used as FLUX's second text encoder."""
    print("Loading text encoder...")
    text_encoder = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
    return text_encoder.to(memory_format=torch.channels_last)


class StableDiffusionTransformerCompile:
    """Thin wrapper that optionally torch.compile()s the pipeline's denoiser."""

    def __init__(self, pipeline, optimize=False):
        self.pipeline = pipeline
        self.optimize = optimize
        if self.optimize:
            self.model_compiling()

    def model_compiling(self):
        # FLUX pipelines expose the denoiser as `transformer`; older Stable
        # Diffusion pipelines expose it as `unet`, so handle both.
        if hasattr(self.pipeline, "transformer"):
            self.pipeline.transformer = torch.compile(self.pipeline.transformer)
        else:
            self.pipeline.unet = torch.compile(self.pipeline.unet)

    def __call__(self, *args, **kwargs):
        return self.pipeline(*args, **kwargs)
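
# A minimal usage sketch (hypothetical; this module keeps `use_com = False`
# and never takes this path): wrapping a loaded pipeline with compilation on.
#
#     pipe = DiffusionPipeline.from_pretrained(ckpt_root, torch_dtype=torch.bfloat16).to("cuda")
#     compiled = StableDiffusionTransformerCompile(pipe, optimize=True)
#     image = compiled(prompt="a cat", num_inference_steps=4, guidance_scale=0.0).images[0]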
    
def load_pipeline() -> Pipeline:

    text_t5_encoder = text_t5_loader()

    # Load the pre-optimized transformer directly from the local HF cache.
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path, torch_dtype=torch.bfloat16, use_safetensors=False
    )

    try:
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )
    except Exception:
        # Fall back to the transformer bundled with the checkpoint.
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            torch_dtype=torch.bfloat16,
        )
    pipeline.to("cuda")

    try:
        compiled_pipeline = StableDiffusionTransformerCompile(pipeline, optimize=False)

        if use_com:
            pipeline = compiled_pipeline
        else:
            print("torch.compile is currently disabled")

        # Guarded because these attributes may not exist on every pipeline class.
        pipeline.disable_vae_compress()
        pipeline.text_encoder_2 = text_t5_encoder
    except Exception:
        print("Pipeline post-processing failed; continuing with defaults")

    # Two warm-up generations so CUDA kernel selection and lazy initialization
    # happen before the first timed request.
    prompt_1 = "albaspidin, pillmonger, palaeocrystalline"
    pipeline(
        prompt=prompt_1,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )

    prompt_2 = "obe, kilometrage, circuition"
    pipeline(
        prompt=prompt_2,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )

    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    # Clear cached CUDA memory before each request to limit fragmentation.
    remove_cache()

    generator = Generator(pipeline.device).manual_seed(request.seed)

    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
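

# A minimal smoke test, assuming TextToImageRequest accepts prompt, width,
# height, and seed keyword arguments (the serving harness normally builds the
# request itself, so this block is illustrative only):
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(
        prompt="a lighthouse on a cliff at dusk",
        width=1024,
        height=1024,
        seed=0,
    )
    infer(request, pipe).save("out.png")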