MyApricity committed
Commit bc8ab2f · verified · 1 Parent(s): ad29a1a

Update src/pipeline.py

Files changed (1)
  1. src/pipeline.py +96 -120
src/pipeline.py CHANGED
@@ -2,21 +2,27 @@ import os
 import torch
 import torch._dynamo
 import gc
-import json
 import transformers
 from huggingface_hub.constants import HF_HUB_CACHE
-from transformers import T5EncoderModel, T5TokenizerFast
+from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
+from torch import Generator
+from diffusers import FluxTransformer2DModel, DiffusionPipeline
 from PIL.Image import Image
-from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny, FluxTransformer2DModel, DiffusionPipeline
+from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
 from pipelines.models import TextToImageRequest
-from optimum.quanto import requantize
-from torch import Generator
-from torch._dynamo import config
-from torch._inductor import config as ind_config
-from typing import Dict, Any, Callable
+from typing import Dict, Any
 from functools import wraps
 
-def error_handler(func: Callable):
+# Global settings
+torch._dynamo.config.suppress_errors = True
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
+os.environ["TOKENIZERS_PARALLELISM"] = "True"
+
+ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
+revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
+Pipeline = None
+
+def error_handler(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
         try:
@@ -26,121 +32,91 @@ def error_handler(func):
             return None
     return wrapper
 
-class TorchOptimizer:
-    def optimize_settings(self):
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-        torch.backends.cudnn.benchmark = True
-        torch.set_float32_matmul_precision("high")
-
-    def clear_cache(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
-
-class PipelineManager:
-    def __init__(self):
-        self.ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
-        self.revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
-        self.pipeline = None
-        self.optimizer = TorchOptimizer()
-
-        # Configure environment
-        torch._dynamo.config.suppress_errors = True
-        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
-        os.environ["TOKENIZERS_PARALLELISM"] = "True"
-
-        # Initialize torch settings
-        self.optimizer.optimize_settings()
-
-        # Pre-load the pipeline during initialization
-        print("Initializing pipeline...")
-        self.pipeline = self.load_pipeline()
-        print("Pipeline initialization complete.")
-
 
-    def load_transformer(self):
-        transformer_path = os.path.join(
-            HF_HUB_CACHE,
-            "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665"
+def remove_cache():
+    torch.cuda.empty_cache()
+    torch.cuda.reset_max_memory_allocated()
+    gc.collect()
+    torch.cuda.reset_peak_memory_stats()
+
+@error_handler
+def optimize_pipeline(pipe):
+    # Fuse QKV projections
+    pipe.transformer.fuse_qkv_projections()
+    pipe.vae.fuse_qkv_projections()
+
+    # Optimize memory layout
+    pipe.transformer.to(memory_format=torch.channels_last)
+    pipe.vae.to(memory_format=torch.channels_last)
+
+    # Configure torch inductor
+    from torch._inductor import config as ind_config
+    ind_config.disable_progress = False
+    ind_config.conv_1x1_as_mm = True
+
+    return pipe
+
+
+def load_pipeline() -> Pipeline:
+
+    transformer_path = os.path.join(
+        HF_HUB_CACHE,
+        "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665"
+    )
+
+    transformer = FluxTransformer2DModel.from_pretrained(
+        transformer_path,
+        torch_dtype=torch.bfloat16,
+        use_safetensors=False
+    )
+
+    try:
+        pipeline = DiffusionPipeline.from_pretrained(
+            ckpt_root,
+            revision=revision_root,
+            transformer=transformer,
+            torch_dtype=torch.bfloat16
         )
-    return FluxTransformer2DModel.from_pretrained(
-        transformer_path,
-        torch_dtype=torch.bfloat16,
-        use_safetensors=False
+    except Exception:
+        pipeline = DiffusionPipeline.from_pretrained(
+            ckpt_root,
+            revision=revision_root,
+            torch_dtype=torch.bfloat16
         )
 
-    @error_handler
-    def optimize_pipeline(self, pipe):
-        # Fuse QKV projections
-        pipe.transformer.fuse_qkv_projections()
-        pipe.vae.fuse_qkv_projections()
-
-        # Optimize memory layout
-        pipe.transformer.to(memory_format=torch.channels_last)
-        pipe.vae.to(memory_format=torch.channels_last)
-
-        # Configure torch inductor
-        config = torch._inductor.config
-        config.disable_progress = False
-        config.conv_1x1_as_mm = True
-
-        # Compile modules
-        pipe.transformer = torch.compile(
-            pipe.transformer,
-            mode="max-autotune",
-            fullgraph=True
-        )
-        pipe.vae.decode = torch.compile(
-            pipe.vae.decode,
-            mode="max-autotune",
-            fullgraph=True
-        )
+    pipeline.to("cuda")
 
-        return pipe
 
-    def load_pipeline(self):
-        # Load transformer model
-        transformer_model = self.load_transformer()
+    # Apply optimizations
+    ___ops_pipeline = optimize_pipeline(pipeline)
+
+    if ___ops_pipeline is not None:
+        pipeline = ___ops_pipeline
 
-        # Create pipeline
-        pipe = DiffusionPipeline.from_pretrained(
-            self.ckpt_root,
-            revision=self.revision_root,
-            transformer=transformer_model,
-            torch_dtype=torch.bfloat16
-        )
-        pipe.to("cuda")
-
-        # Optimize pipeline
-        pipe_ops = self.optimize_pipeline(pipe)
-        if pipe_ops!=None:
-            pipe = pipe_ops
-
-        # Trigger compilation
-        print("Running torch compilation...")
-        pipe(
-            "dummy prompt to trigger torch compilation",
-            output_type="pil",
-            num_inference_steps=4
-        ).images[0]
-        print("Finished torch compilation")
-
-        return pipe
-
-    def run_inference(self, request: TextToImageRequest) -> Image:
-        if self.pipeline is None:
-            self.pipeline = self.load_pipeline()
-
-        self.optimizer.clear_cache()
-        generator = Generator(self.pipeline.device).manual_seed(request.seed)
-
-        return self.pipeline(
-            request.prompt,
-            generator=generator,
-            guidance_scale=0.0,
-            num_inference_steps=4,
-            max_sequence_length=256,
-            height=request.height,
-            width=request.width,
-        ).images[0]
+    # Warmup runs
+    prompt_xnxx = "pantomorphia, dorsilateral, nonlife, unenthusiastic, quadriform, throatlet, bluntish, soldierize"
+    pipeline(
+        prompt=prompt_xnxx,
+        width=1024,
+        height=1024,
+        guidance_scale=0.0,
+        num_inference_steps=4,
+        max_sequence_length=256
+    )
+
+    return pipeline
+
+@torch.no_grad()
+def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
+    remove_cache()
+    generator = Generator(pipeline.device).manual_seed(request.seed)
+
+    return pipeline(
+        request.prompt,
+        generator=generator,
+        guidance_scale=0.0,
+        num_inference_steps=4,
+        max_sequence_length=256,
+        height=request.height,
+        width=request.width,
+    ).images[0]
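
The refactor replaces the old PipelineManager class with module-level load_pipeline() and infer() functions. A minimal driver sketch for the new API, assuming TextToImageRequest exposes the prompt, seed, height, and width fields that infer() reads (this harness is illustrative, not part of the commit):

    # Hypothetical harness for the module-level API (not part of the commit).
    from pipelines.models import TextToImageRequest
    from pipeline import load_pipeline, infer  # assuming src/ is on the import path

    pipe = load_pipeline()  # loads the transformer, applies optimizations, runs one warmup pass

    request = TextToImageRequest(  # field names assumed from their use in infer()
        prompt="a watercolor fox in a snowy forest",
        seed=42,
        height=1024,
        width=1024,
    )
    image = infer(request, pipe)  # a fixed (prompt, seed) pair reproduces the same image via the seeded Generator
    image.save("out.png")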
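
Note the interplay between @error_handler and load_pipeline(): the decorator swallows any exception raised inside optimize_pipeline and returns None, so the caller keeps its original pipeline reference and only swaps in the optimized one after a None check. A standalone sketch of that contract (the except body falls between the two hunks above, so the error reporting shown here is an assumption):

    from functools import wraps

    def error_handler(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as exc:  # assumed handling; the diff elides these lines
                print(f"{func.__name__} failed: {exc}")
                return None
        return wrapper

    @error_handler
    def might_fail(x):
        return 10 / x

    print(might_fail(2))  # 5.0
    print(might_fail(0))  # prints the failure and returns None, so callers must check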
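
The hard-coded transformer_path follows huggingface_hub's standard cache layout, models--{org}--{repo}/snapshots/{revision} under HF_HUB_CACHE, so the same directory could be derived from the ckpt_root and revision_root constants already defined at module scope. A sketch of that derivation:

    import os
    from huggingface_hub.constants import HF_HUB_CACHE

    ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
    revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"

    # hf_hub cache scheme: models--{org}--{repo}/snapshots/{revision}
    snapshot_dir = os.path.join(
        HF_HUB_CACHE,
        f"models--{ckpt_root.replace('/', '--')}",
        "snapshots",
        revision_root,
    )
    # snapshot_dir matches the literal path hard-coded in load_pipeline()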