Commit 015668e (verified)
John6666 committed · 1 Parent(s): 4bdfc27

Upload 2 files
Files changed (2)
  1. handler.py +75 -211
  2. requirements.txt +7 -12
handler.py CHANGED
@@ -1,227 +1,91 @@
- # https://github.com/sayakpaul/diffusers-torchao
- # https://github.com/pytorch/ao/releases
- # https://developer.nvidia.com/cuda-gpus
-
  import os
- from typing import Any, Dict
- import gc
- import time
  from PIL import Image
- from huggingface_hub import hf_hub_download
  import torch
- from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight, float8_dynamic_activation_float8_weight
- from torchao.quantization.quant_api import PerRow
- from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
- from transformers import T5EncoderModel, BitsAndBytesConfig
- from optimum.quanto import freeze, qfloat8, quantize
- from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
  from huggingface_inference_toolkit.logging import logger
-
- import subprocess
- subprocess.run("pip list", shell=True)
-
- print("device name:", torch.cuda.get_device_name())
- print("device capability:", torch.cuda.get_device_capability())
-
- IS_TURBO = False
- IS_4BIT = False
- IS_PARA = True
- IS_LVRAM = False
- IS_COMPILE = True
- IS_AUTOQ = False
- IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
- IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False

  # Set high precision for float32 matrix multiplications.
  # This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
  torch.set_float32_matmul_precision("high")

- if IS_COMPILE:
-     import torch._dynamo
-     torch._dynamo.config.suppress_errors = True
-
- def print_vram():
-     free = torch.cuda.mem_get_info()[0] / (1024 ** 3)
-     total = torch.cuda.mem_get_info()[1] / (1024 ** 3)
-     print(f"VRAM: {total - free:.2f}/{total:.2f}GB")
-
- def pil_to_base64(image: Image.Image, modelname: str, prompt: str, height: int, width: int, steps: int, cfg: float, seed: int) -> str:
-     import base64
-     from io import BytesIO
-     import json
-     from PIL import PngImagePlugin
-     metadata = {"prompt": prompt, "num_inference_steps": steps, "guidance_scale": cfg, "seed": seed, "resolution": f"{width} x {height}",
-                 "Model": {"Model": modelname.split("/")[-1]}}
-     info = PngImagePlugin.PngInfo()
-     info.add_text("metadata", json.dumps(metadata))
-     buffered = BytesIO()
-     image.save(buffered, "PNG", pnginfo=info)
-     return base64.b64encode(buffered.getvalue()).decode('ascii')
-
- def load_te2(repo_id: str, dtype: torch.dtype) -> Any:
-     if IS_4BIT:
-         nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
-         text_encoder_2 = T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", torch_dtype=dtype, quantization_config=nf4_config)
-     else:
-         text_encoder_2 = T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", torch_dtype=dtype)
-         quantize(text_encoder_2, weights=qfloat8)
-         freeze(text_encoder_2)
-     return text_encoder_2
-
- def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
-     quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "float8dq" if IS_CC90 else "int8wo")
-     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
-     pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=quantization_config)
-     pipe.transformer.fuse_qkv_projections()
-     pipe.vae.fuse_qkv_projections()
-     return pipe
-
- def load_pipeline_lowvram(repo_id: str, dtype: torch.dtype) -> Any:
-     int4_config = TorchAoConfig("int4dq")
-     float8_config = TorchAoConfig("float8dq")
-     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
-     transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=float8_config)
-     pipe = FluxPipeline.from_pretrained(repo_id, vae=None, transformer=None, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=int4_config)
-     pipe.transformer = transformer
-     pipe.vae = vae
-     #pipe.transformer.fuse_qkv_projections()
-     #pipe.vae.fuse_qkv_projections()
-     pipe.to("cuda")
-     return pipe
-
- def load_pipeline_compile(repo_id: str, dtype: torch.dtype) -> Any:
-     quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "float8dq" if IS_CC90 else "int8wo")
-     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
-     pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=quantization_config)
-     pipe.transformer.fuse_qkv_projections()
-     pipe.vae.fuse_qkv_projections()
-     pipe.transformer.to(memory_format=torch.channels_last)
-     pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
-     pipe.vae.to(memory_format=torch.channels_last)
-     pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
-     return pipe
-
- def load_pipeline_autoquant(repo_id: str, dtype: torch.dtype) -> Any:
-     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype)
-     pipe.transformer.fuse_qkv_projections()
-     pipe.vae.fuse_qkv_projections()
-     pipe.transformer.to(memory_format=torch.channels_last)
-     pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
-     pipe.vae.to(memory_format=torch.channels_last)
-     pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
-     pipe.transformer = autoquant(pipe.transformer, error_on_unseen=False)
-     pipe.vae = autoquant(pipe.vae, error_on_unseen=False)
-     return pipe
-
- def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:
-     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype)
-     pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
-     pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
-     pipe.fuse_lora()
-     pipe.unload_lora_weights()
-     pipe.transformer.fuse_qkv_projections()
-     pipe.vae.fuse_qkv_projections()
-     weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
-     quantize_(pipe.transformer, weight, device="cuda")
-     quantize_(pipe.vae, weight, device="cuda")
-     return pipe
-
- def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
-     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype)
-     pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
-     pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
-     pipe.fuse_lora()
-     pipe.unload_lora_weights()
-     pipe.transformer.fuse_qkv_projections()
-     pipe.vae.fuse_qkv_projections()
-     weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
-     quantize_(pipe.transformer, weight, device="cuda")
-     quantize_(pipe.vae, weight, device="cuda")
-     pipe.transformer.to(memory_format=torch.channels_last)
-     pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
-     pipe.vae.to(memory_format=torch.channels_last)
-     pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
-     return pipe
-
- def load_pipeline_opt(repo_id: str, dtype: torch.dtype) -> Any:
-     quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "float8dq" if IS_CC90 else "int8wo")
-     weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
-     transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype)
-     transformer.fuse_qkv_projections()
-     if IS_CC90: quantize_(transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
-     elif IS_CC89: quantize_(transformer, float8_dynamic_activation_float8_weight(), device="cuda")
-     else: quantize_(transformer, weight, device="cuda")
-     transformer.to(memory_format=torch.channels_last)
-     transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
-     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
-     vae.fuse_qkv_projections()
-     if IS_CC90: quantize_(vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
-     elif IS_CC89: quantize_(vae, float8_dynamic_activation_float8_weight(), device="cuda")
-     else: quantize_(vae, weight, device="cuda")
-     vae.to(memory_format=torch.channels_last)
-     vae = torch.compile(vae, mode="max-autotune", fullgraph=True)
-     pipe = FluxPipeline.from_pretrained(repo_id, transformer=None, vae=None, torch_dtype=dtype, quantization_config=quantization_config)
-     pipe.transformer = transformer
-     pipe.vae = vae
-     return pipe

  class EndpointHandler:
      def __init__(self, path=""):
-         repo_id = "NoMoreCopyrightOrg/flux-dev-8step" if IS_TURBO else "NoMoreCopyrightOrg/flux-dev"
-         self.repo_id = repo_id
-         dtype = torch.bfloat16
-         #dtype = torch.float16 # for older nVidia GPUs
-         print_vram()
-         print("Loading pipeline...")
-         if IS_AUTOQ: self.pipeline = load_pipeline_autoquant(repo_id, dtype)
-         elif IS_COMPILE: self.pipeline = load_pipeline_opt(repo_id, dtype)
-         elif IS_LVRAM and IS_CC89: self.pipeline = load_pipeline_lowvram(repo_id, dtype)
-         else: self.pipeline = load_pipeline_stable(repo_id, dtype)
-         if IS_PARA: apply_cache_on_pipe(self.pipeline, residual_diff_threshold=0.12)
-         gc.collect()
-         torch.cuda.empty_cache()
          self.pipeline.enable_vae_slicing()
          self.pipeline.enable_vae_tiling()
-         self.pipeline.to("cuda")
-         print_vram()
-
-     def __call__(self, data: Dict[str, Any]) -> Image.Image:
-         logger.info(f"Received incoming request with {data=}")
-
-         if "inputs" in data and isinstance(data["inputs"], str):
-             prompt = data.pop("inputs")
-         elif "prompt" in data and isinstance(data["prompt"], str):
-             prompt = data.pop("prompt")
-         else:
-             raise ValueError(
-                 "Provided input body must contain either the key `inputs` or `prompt` with the"
-                 " prompt to use for the image generation, and it needs to be a non-empty string."
-             )
-
-         parameters = data.pop("parameters", {})
-
-         num_inference_steps = parameters.get("num_inference_steps", 8 if IS_TURBO else 28)
-         width = parameters.get("width", 1024)
-         height = parameters.get("height", 1024)
-         guidance_scale = parameters.get("guidance_scale", 3.5)
-
-         # seed generator (seed cannot be provided as is but via a generator)
-         seed = parameters.get("seed", 0)
-         generator = torch.manual_seed(seed)
-
-         start = time.time()
-         image = self.pipeline( # type: ignore
-             prompt,
-             height=height,
-             width=width,
-             guidance_scale=guidance_scale,
-             num_inference_steps=num_inference_steps,
-             generator=generator,
-             output_type="pil",
-         ).images[0]
-         end = time.time()
-         print(f'Elapsed {end - start:.3f} sec. / prompt:"{prompt}" / size:{width}x{height} / steps:{num_inference_steps} / guidance scale:{guidance_scale} / seed:{seed}')
-
-         return pil_to_base64(image, self.repo_id, prompt, height, width, num_inference_steps, guidance_scale, seed)

  import os
+ from typing import Any, Dict, Union
  from PIL import Image
  import torch
+ from diffusers import FluxPipeline
  from huggingface_inference_toolkit.logging import logger
+ from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+ from torchao.quantization import autoquant
+ import time
+ import gc

  # Set high precision for float32 matrix multiplications.
  # This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
  torch.set_float32_matmul_precision("high")

+ import torch._dynamo
+ torch._dynamo.config.suppress_errors = False # for debugging

  class EndpointHandler:
      def __init__(self, path=""):
+         self.pipeline = FluxPipeline.from_pretrained(
+             "NoMoreCopyrightOrg/flux-dev",
+             torch_dtype=torch.bfloat16,
+         ).to("cuda")
          self.pipeline.enable_vae_slicing()
          self.pipeline.enable_vae_tiling()
+         self.pipeline.transformer.fuse_qkv_projections()
+         self.pipeline.vae.fuse_qkv_projections()
+         self.pipeline.transformer.to(memory_format=torch.channels_last)
+         self.pipeline.vae.to(memory_format=torch.channels_last)
+         apply_cache_on_pipe(self.pipeline, residual_diff_threshold=0.12)
+         self.pipeline.transformer = torch.compile(
+             self.pipeline.transformer, mode="max-autotune-no-cudagraphs",
+         )
+         self.pipeline.vae = torch.compile(
+             self.pipeline.vae, mode="max-autotune-no-cudagraphs",
+         )
+         self.pipeline.transformer = autoquant(self.pipeline.transformer, error_on_unseen=False)
+         self.pipeline.vae = autoquant(self.pipeline.vae, error_on_unseen=False)
+
+         gc.collect()
+         torch.cuda.empty_cache()

+         start_time = time.time()
+         print("Start warming-up pipeline")
+         self.pipeline("Hello world!") # Warm-up for compiling
+         end_time = time.time()
+         time_taken = end_time - start_time
+         print(f"Time taken: {time_taken:.2f} seconds")

+     def __call__(self, data: Dict[str, Any]) -> Union[Image.Image, None]:
+         logger.info(f"Received incoming request with {data=}")
+         try:
+             if "inputs" in data and isinstance(data["inputs"], str):
+                 prompt = data.pop("inputs")
+             elif "prompt" in data and isinstance(data["prompt"], str):
+                 prompt = data.pop("prompt")
+             else:
+                 raise ValueError(
+                     "Provided input body must contain either the key `inputs` or `prompt` with the"
+                     " prompt to use for the image generation, and it needs to be a non-empty string."
+                 )
+
+             parameters = data.pop("parameters", {})
+
+             num_inference_steps = parameters.get("num_inference_steps", 28)
+             width = parameters.get("width", 1024)
+             height = parameters.get("height", 1024)
+             #guidance_scale = parameters.get("guidance_scale", 3.5)
+             guidance_scale = parameters.get("guidance", 3.5)
+
+             # seed generator (seed cannot be provided as is but via a generator)
+             seed = parameters.get("seed", 0)
+             generator = torch.manual_seed(seed)
+             start_time = time.time()
+             result = self.pipeline( # type: ignore
+                 prompt,
+                 height=height,
+                 width=width,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+                 generator=generator,
+             ).images[0]
+             end_time = time.time()
+             time_taken = end_time - start_time
+             print(f"Time taken: {time_taken:.2f} seconds")
+
+             return result
+         except Exception as e:
+             print(e)
+             return None
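
For reference, a minimal local smoke test of the new handler might look like the sketch below (hypothetical script, not part of this commit; the payload keys mirror __call__ above, which reads the prompt from "inputs" or "prompt" and optional "parameters" such as num_inference_steps, width, height, guidance, and seed):

# local_test.py (hypothetical) - exercises the new EndpointHandler on a CUDA machine
from handler import EndpointHandler

handler = EndpointHandler()
image = handler({
    "inputs": "A watercolor lighthouse at dawn",
    "parameters": {
        "num_inference_steps": 28,
        "width": 1024,
        "height": 1024,
        "guidance": 3.5,  # the new code reads "guidance", not "guidance_scale"
        "seed": 42,
    },
})
if image is not None:         # __call__ returns a PIL.Image on success, None on error
    image.save("output.png")
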
requirements.txt CHANGED
@@ -1,21 +1,16 @@
- --extra-index-url https://download.pytorch.org/whl/cu126
- torch>=2.6.0
  torchvision
  torchaudio
  huggingface_hub
- torchao>=0.9.0
- diffusers>=0.32.2
  peft
- transformers==4.48.3
- accelerate
- numpy
  scipy
  Pillow
  sentencepiece
  protobuf
  triton
- gemlite
- tabulate
- para-attn
- bitsandbytes
- optimum-quanto

+ --extra-index-url https://download.pytorch.org/whl/cu121
+ torch==2.6.0
  torchvision
  torchaudio
  huggingface_hub
+ torchao==0.9.0
+ diffusers==0.32.2
  peft
+ transformers<=4.48.3
+ numpy<2
  scipy
  Pillow
  sentencepiece
  protobuf
  triton
+ para-attn==0.3.23