File size: 3,365 Bytes
4fcd1d5
 
 
 
 
 
bc8ab2f
 
 
4fcd1d5
bc8ab2f
4fcd1d5
bc8ab2f
df41d99
 
bc8ab2f
 
 
 
 
 
 
 
 
 
df41d99
 
 
 
 
 
6e6f409
df41d99
 
 
bc8ab2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df41d99
bc8ab2f
 
 
 
 
df41d99
e09c84c
bc8ab2f
e09c84c
df41d99
bc8ab2f
 
 
 
 
df41d99
bc8ab2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import torch
import torch._dynamo
import gc
import transformers
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from pipelines.models import TextToImageRequest
from typing import Dict, Any
from functools import wraps

# --- Process-wide runtime configuration (applied once, at import time) ---

# If TorchDynamo hits an error, suppress it (per PyTorch docs it then falls
# back to eager execution) instead of raising.
torch._dynamo.config.suppress_errors = True

os.environ.update({
    # Expandable segments for the CUDA caching allocator. NOTE(review): this
    # is read when the allocator initializes, so it only takes effect if CUDA
    # has not been used yet at import time — confirm.
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    # Opt in to parallelism in the Hugging Face `tokenizers` library.
    "TOKENIZERS_PARALLELISM": "True",
})

# Hugging Face Hub repository id and the pinned commit to load it at.
ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"

# Annotation placeholder for the pipeline type; its runtime value is None.
Pipeline = None

def error_handler(func):
    """Wrap *func* so that any ``Exception`` it raises is reported and swallowed.

    On failure the wrapper prints ``Error in <name>: <message>`` to stdout,
    writes the full traceback to stderr, and returns ``None`` instead of
    raising. Exceptions that do not derive from ``Exception`` (for example
    ``KeyboardInterrupt`` and ``SystemExit``) still propagate.

    Args:
        func: The callable to protect.

    Returns:
        A wrapper with the same name/docstring as *func* (via ``functools.wraps``).
    """
    import traceback

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {str(e)}")
            # str(e) drops the exception type and location (and is empty for
            # many exceptions), so also emit the traceback to aid debugging.
            traceback.print_exc()
            return None
    return wrapper


def remove_cache():
    """Free unreferenced memory and reset CUDA peak-memory statistics.

    The Python garbage collector runs first, so tensors kept alive only by
    reference cycles are released before ``torch.cuda.empty_cache()`` hands
    cached, unused blocks back to the driver. After that, the CUDA peak-memory
    counters are reset. The CUDA calls are skipped when CUDA is unavailable,
    so the function is safe to call on CPU-only hosts.
    """
    # Collect first: memory freed by GC can then be released by empty_cache().
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_peak_memory_stats() supersedes the deprecated
        # reset_max_memory_allocated(), which only warns and delegates here.
        torch.cuda.reset_peak_memory_stats()

@error_handler
def optimize_pipeline(pipe):
    """Apply in-place speed tweaks to *pipe*'s transformer and VAE, then return it.

    Steps, in order:
      1. fuse the QKV attention projections of the transformer, then the VAE;
      2. move the transformer, then the VAE, to ``channels_last`` memory format;
      3. set two process-global TorchInductor options.

    Because of ``@error_handler``, a failing step prints an error and makes the
    call return ``None``. Steps that already ran remain applied.
    """
    component_names = ("transformer", "vae")

    for name in component_names:
        getattr(pipe, name).fuse_qkv_projections()

    for name in component_names:
        getattr(pipe, name).to(memory_format=torch.channels_last)

    from torch._inductor import config as inductor_cfg

    # Global Inductor settings: they only influence code compiled later with
    # torch.compile; nothing in this function compiles anything itself.
    inductor_cfg.disable_progress = False  # re-enable Inductor progress output
    inductor_cfg.conv_1x1_as_mm = True     # lower 1x1 convolutions to matmuls

    return pipe


def load_pipeline() -> Pipeline:
    """Build the FLUX pipeline on the GPU, optimize it, and run one warm-up.

    The transformer is loaded separately, in bfloat16, from the local Hugging
    Face cache snapshot of ``ckpt_root`` at ``revision_root``. It is then
    injected into ``DiffusionPipeline.from_pretrained``. If building the
    pipeline with that transformer fails, the error is printed and the
    pipeline is loaded again with the checkpoint's own transformer.

    ``optimize_pipeline`` returns ``None`` when any optimization step fails.
    In that case the un-wrapped pipeline is kept, so a failed optimization no
    longer turns into a crash during warm-up.

    Returns:
        The pipeline, moved to ``"cuda"`` and warmed up with one 1024x1024,
        4-step generation.
    """
    # Hugging Face cache layout: models--{org}--{name}/snapshots/{revision}
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        f"models--{ckpt_root.replace('/', '--')}",
        "snapshots",
        revision_root,
    )

    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=torch.bfloat16,
        # NOTE(review): disabling safetensors presumably means pickle-based
        # .bin weights are loaded via torch.load; acceptable only because the
        # checkpoint is a pinned, trusted revision. Prefer safetensors if the
        # snapshot provides them.
        use_safetensors=False,
    )

    try:
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )
    except Exception as err:
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
        # propagate, and report why the custom transformer was rejected.
        print(
            f"Loading pipeline with custom transformer failed ({err!r}); "
            "falling back to the checkpoint's default transformer."
        )
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            torch_dtype=torch.bfloat16,
        )

    pipeline.to("cuda")

    # optimize_pipeline is wrapped in error_handler and yields None on failure;
    # only adopt its result when it succeeded.
    optimized = optimize_pipeline(pipeline)
    if optimized is not None:
        pipeline = optimized

    # One warm-up generation with the same sampling settings infer() uses,
    # presumably so one-time start-up costs are not paid by the first request.
    warmup_prompt = "pantomorphia, dorsilateral, nonlife, unenthusiastic, quadriform, throatlet, bluntish, soldierize"
    pipeline(
        prompt=warmup_prompt,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )

    return pipeline

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate a single image for *request* using *pipeline*.

    Clears cached CUDA memory and peak-memory stats via ``remove_cache``, then
    seeds a generator on the pipeline's device with ``request.seed``. It runs
    4 denoising steps with ``guidance_scale=0.0`` and
    ``max_sequence_length=256`` at ``request.height`` x ``request.width``.
    Gradients are disabled for the whole call.

    NOTE(review): ``manual_seed`` requires an int, so this assumes
    ``request.seed`` is never None — confirm against TextToImageRequest.

    Returns:
        The first image in the pipeline output's ``images``.
    """
    remove_cache()
    rng = Generator(pipeline.device).manual_seed(request.seed)

    # Fixed sampling settings, shared by every request.
    sampling = dict(
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    output = pipeline(
        request.prompt,
        generator=rng,
        **sampling,
        height=request.height,
        width=request.width,
    )
    return output.images[0]