import os
import torch
import torch._dynamo
from functools import wraps
from typing import Callable

from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from torch import Generator

from pipelines.models import TextToImageRequest

def error_handler(func: Callable):
    """Log and swallow exceptions; the wrapped call returns None on failure."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")
    return wrapper
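
# Example (a minimal sketch, not part of the original module): because the
# decorator swallows exceptions, a decorated function that raises simply
# returns None, so call sites must tolerate a None result (see the fallback
# in load_pipeline below).
#
#   @error_handler
#   def boom():
#       raise ValueError("boom")
#
#   assert boom() is None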

class TorchOptimizer:
    """Global torch knobs for faster matmuls, plus a CUDA-cache helper."""

    def optimize_settings(self):
        # Allow TF32 on Ampere+ GPUs and let cuDNN benchmark for the fastest kernels.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.set_float32_matmul_precision("high")

    def clear_cache(self):
        torch.cuda.empty_cache()
        # reset_peak_memory_stats also resets the max-allocated counter, so the
        # deprecated reset_max_memory_allocated call is redundant here.
        torch.cuda.reset_peak_memory_stats()

class PipelineManager:
    def __init__(self):
        self.ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
        self.revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
        self.pipeline = None
        self.optimizer = TorchOptimizer()
        
        # Configure environment
        torch._dynamo.config.suppress_errors = True
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
        os.environ["TOKENIZERS_PARALLELISM"] = "True"
        
        # Initialize torch settings
        self.optimizer.optimize_settings()
    def load_transformer(self):
        # Resolve the snapshot inside the local HF hub cache, which lays out
        # repos as models--{org}--{name}/snapshots/{revision}.
        transformer_path = os.path.join(
            HF_HUB_CACHE,
            f"models--{self.ckpt_root.replace('/', '--')}",
            "snapshots",
            self.revision_root,
        )
        return FluxTransformer2DModel.from_pretrained(
            transformer_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )

    @error_handler
    def optimize_pipeline(self, pipe):
        # Fuse QKV projections
        pipe.transformer.fuse_qkv_projections()
        pipe.vae.fuse_qkv_projections()

        # Optimize memory layout
        pipe.transformer.to(memory_format=torch.channels_last)
        pipe.vae.to(memory_format=torch.channels_last)

        # Configure torch inductor: treat 1x1 convolutions as matmuls
        inductor_cfg = torch._inductor.config
        inductor_cfg.disable_progress = False
        inductor_cfg.conv_1x1_as_mm = True

        # Compile modules
        pipe.transformer = torch.compile(
            pipe.transformer,
            mode="max-autotune",
            fullgraph=True
        )
        pipe.vae.decode = torch.compile(
            pipe.vae.decode,
            mode="max-autotune",
            fullgraph=True
        )

        return pipe

    def load_pipeline(self):
        # Load transformer model
        transformer_model = self.load_transformer()
        
        # Create pipeline
        pipe = DiffusionPipeline.from_pretrained(
            self.ckpt_root,
            revision=self.revision_root,
            transformer=transformer_model,
            torch_dtype=torch.bfloat16
        )
        pipe.to("cuda")

        # Optimize pipeline; fall back to the unoptimized pipe if the
        # @error_handler-wrapped optimization failed and returned None.
        pipe = self.optimize_pipeline(pipe) or pipe

        # Trigger compilation
        print("Running torch compilation...")
        pipe(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4
        ).images[0]
        print("Finished torch compilation")

        return pipe

    def run_inference(self, request: TextToImageRequest) -> Image:
        if self.pipeline is None:
            self.pipeline = self.load_pipeline()

        self.optimizer.clear_cache()
        generator = Generator(self.pipeline.device).manual_seed(request.seed)

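        # FLUX Schnell is a guidance-distilled 4-step model, so CFG is disabled.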
        return self.pipeline(
            request.prompt,
            generator=generator,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
        ).images[0]
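
# A minimal usage sketch (an assumption, not part of the original module): it
# presumes a CUDA device and that pipelines.models.TextToImageRequest accepts
# the prompt/seed/height/width fields read in run_inference above.
if __name__ == "__main__":
    manager = PipelineManager()
    request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        seed=42,
        height=1024,
        width=1024,
    )
    # The first call loads and compile-warms the pipeline; later calls reuse it.
    image = manager.run_inference(request)
    image.save("output.png")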