YOURNAME committed
Commit 0629499 · 1 Parent(s): e09c84c
Files changed (2)
  1. src/main.py +3 -3
  2. src/pipeline.py +72 -73
src/main.py CHANGED
@@ -7,14 +7,14 @@ from pathlib import Path
 from PIL.JpegImagePlugin import JpegImageFile
 from pipelines.models import TextToImageRequest
 
-from pipeline import load_pipeline, infer
+from pipeline import pipeline_loader, inference
 
 SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")
 
 
 def main():
     print(f"Loading pipeline")
-    pipeline = load_pipeline()
+    pipeline = pipeline_loader()
 
     print(f"Pipeline loaded! , creating socket at '{SOCKET}'")
 
@@ -36,7 +36,7 @@ def main():
 
                     return
 
-                image = infer(request, pipeline)
+                image = inference(request, pipeline)
 
                 data = BytesIO()
                 image.save(data, format=JpegImageFile.format)
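
The main.py side of the commit is a pure rename of the two pipeline entry points. For reference, this is the contract main.py now expects pipeline.py to export (a minimal stub sketch; everything beyond the two names and the request/pipeline signature is an assumption, not part of this commit):

# Hypothetical stub of the interface main.py imports after this commit.
from PIL.Image import Image
from pipelines.models import TextToImageRequest

def pipeline_loader():
    """Build and warm up the text-to-image pipeline once, at process start."""
    raise NotImplementedError

def inference(request: TextToImageRequest, pipeline) -> Image:
    """Render one image per request, deterministically seeded via request.seed."""
    raise NotImplementedError
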
src/pipeline.py CHANGED
@@ -6,22 +6,24 @@ import gc
 import json
 import transformers
 from huggingface_hub.constants import HF_HUB_CACHE
-from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
-
-# ApricityApricityApricity...
-
-from torch import Generator
-from diffusers import FluxTransformer2DModel, DiffusionPipeline
-
+from transformers import T5EncoderModel, T5TokenizerFast
 from PIL.Image import Image
 from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
 from pipelines.models import TextToImageRequest
 from optimum.quanto import requantize
 import json
 
+from torch import Generator
+from diffusers import FluxTransformer2DModel, DiffusionPipeline
 
-# ApricityApricityApricity...
+# MYMYMYMY...
+# ApricityApricityApricity...
 
+from torch._dynamo import config
+from torch._inductor import config as ind_config
+import torch
+import math
+from typing import Dict, Any
 
 torch._dynamo.config.suppress_errors = True
 os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
@@ -32,102 +34,99 @@ revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
 Pipeline = None
 use_com = False
 
-import torch
-import math
-from typing import Dict, Any
 
-def remove_cache():
+def optimize_torch():
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    torch.backends.cudnn.benchmark = True
+    # torch.backends.cudnn.benchmark_limit = 20
+    torch.set_float32_matmul_precision("high")
+    # config.cache_size_limit = 10000000000
+    # ind_config.shape_padding = True
+
+try:
+    optimize_torch()
+except Exception:
+    print("could not set torch optimization flags, continuing with defaults")
+
+def delete_cache():
     torch.cuda.empty_cache()
     torch.cuda.reset_max_memory_allocated()
-    gc.collect()
     torch.cuda.reset_peak_memory_stats()
 
 
-def text_t5_loader() -> T5EncoderModel:
+
+def pipeline_loader() -> Pipeline:
+
     print("Loading text encoder...")
-    text_encoder = T5EncoderModel.from_pretrained(
+    en = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
-    return text_encoder.to(memory_format=torch.channels_last)
-
 
-class StableDiffusionTransformerCompile:
-    def __init__(self, pipeline, optimize=False):
-        self.pipeline = pipeline
-        self.optimize = optimize
-        if self.optimize:
-            self.model_compiling()
 
-    def model_compiling(self):
-        # Staff doing here
-        self.pipeline.unet = torch.compile(self.pipeline.unet)
-
-    def __call__(self, *args, **kwargs):
-        return self.pipeline(*args, **kwargs)
-
-def load_pipeline() -> Pipeline:
 
-    text_t5_encoder = text_t5_loader()
 
-    transformer_path__ = os.path.join(HF_HUB_CACHE, "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665")
-
-    transformer__ = FluxTransformer2DModel.from_pretrained(transformer_path__, torch_dtype=torch.bfloat16, use_safetensors=False)
+    transformer_path_main = os.path.join(HF_HUB_CACHE, "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665")
+
+    transformer_model = FluxTransformer2DModel.from_pretrained(transformer_path_main, torch_dtype=torch.bfloat16, use_safetensors=False)
+
+
+    pipe = DiffusionPipeline.from_pretrained(ckpt_root,
+                                             revision=revision_root,
+                                             transformer=transformer_model,
+                                             torch_dtype=torch.bfloat16)
+    pipe.to("cuda")
 
     try:
-        pipeline = DiffusionPipeline.from_pretrained(ckpt_root,
-                                                     revision=revision_root,
-                                                     transformer=transformer__,
-                                                     torch_dtype=torch.bfloat16)
 
-    except:
-        pipeline = DiffusionPipeline.from_pretrained(ckpt_root,
-                                                     revision=revision_root,
-                                                     torch_dtype=torch.bfloat16)
-
-    pipeline.to("cuda")
-
-    try:
-        compiled_pipeline = StableDiffusionTransformerCompile(pipeline, optimize=False)
-
-        if use_com:
-            pipeline = compiled_pipeline
-        else:
-            print("Currently not compling affectively")
-
-        pipeline.disable_vae_compress()
-        pipeline.text_encoder_2 = text_t5_encoder
-
-    except:
-        print("pipeline")
-
-
-    prompt_1 = "albaspidin, pillmonger, palaeocrystalline"
-    pipeline(prompt=prompt_1,
-             width=1024,
-             height=1024,
-             guidance_scale=0.0,
-             num_inference_steps=4,
-             max_sequence_length=256)
-
-    prompt_2 = "obe, kilometrage, circuition"
-    pipeline(prompt=prompt_2,
-             width=1024,
-             height=1024,
-             guidance_scale=0.0,
-             num_inference_steps=4,
-             max_sequence_length=256)
-
-    return pipeline
+        # fuse QKV projections in Transformer and VAE
+        pipe.transformer.fuse_qkv_projections()
+        pipe.vae.fuse_qkv_projections()
+
+        # switch memory layout to Torch's preferred, channels_last
+        pipe.transformer.to(memory_format=torch.channels_last)
+        pipe.vae.to(memory_format=torch.channels_last)
+
+        # set torch compile flags
+        config = torch._inductor.config
+        config.disable_progress = False  # show progress bar
+        config.conv_1x1_as_mm = True  # treat 1x1 convolutions as matrix muls
+
+        # tag the compute-intensive modules, the Transformer and VAE decoder, for compilation
+        pipe.transformer = torch.compile(
+            pipe.transformer, mode="max-autotune", fullgraph=True
+        )
+        pipe.vae.decode = torch.compile(
+            pipe.vae.decode, mode="max-autotune", fullgraph=True
+        )
+
+        # trigger torch compilation
+        print("running torch compilation..")
+
+        pipe(
+            "dummy prompt to trigger torch compilation",
+            output_type="pil",
+            num_inference_steps=4,  # use ~50 for [dev], smaller for [schnell]
+        ).images[0]
+
+        print("finished torch compilation")
+
+    except Exception:
+        # compilation failed; still run one eager warm-up pass
+        pipe(
+            "a beautiful girl",
+            output_type="pil",
+            num_inference_steps=4,  # use ~50 for [dev], smaller for [schnell]
+        ).images[0]
+        print("torch.compile failed, falling back to the eager pipeline")
+
+
+    return pipe
 
 
 @torch.no_grad()
-def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
+def inference(request: TextToImageRequest, pipeline: Pipeline) -> Image:
 
-    remove_cache()
-    # remove cache here for better result
+    delete_cache()
     generator = Generator(pipeline.device).manual_seed(request.seed)
 
     return pipeline(
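The core of the new loader is the standard diffusers fuse-and-compile recipe, which is easier to follow in isolation. A minimal runnable sketch, assuming a stock FLUX.1-schnell checkpoint rather than this repo's ckpt_root plus custom transformer:

import torch
from diffusers import FluxPipeline

# Build the pipeline (the commit instead loads ckpt_root with a custom
# FluxTransformer2DModel; "black-forest-labs/FLUX.1-schnell" is an assumption here).
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# Fuse the attention Q/K/V projections into single matmuls.
pipe.transformer.fuse_qkv_projections()
pipe.vae.fuse_qkv_projections()

# channels_last memory layout is usually faster for these modules on CUDA.
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

# Compile the two compute-heavy pieces; "max-autotune" searches for fast kernels.
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

# One throwaway generation pays the compilation cost up front,
# so later requests run at steady-state speed.
pipe("warm-up prompt", num_inference_steps=4).images[0]

The warm-up call is load-bearing: torch.compile is lazy, so without it the first real request would absorb the entire autotuning cost.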
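
Two smaller pieces of the commit also stand on their own: the global math-mode flags applied at import time, and the cache reset that inference() performs before each request. A behaviorally equivalent sketch (names tidied; restoring the gc.collect() that the old remove_cache() performed is an assumption on my part):

import gc
import torch

def optimize_torch() -> None:
    # Allow TF32 tensor cores for float32 matmuls and cuDNN convolutions (Ampere+).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Let cuDNN benchmark conv algorithms; pays off when input shapes are fixed.
    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision("high")

def delete_cache() -> None:
    # Release cached CUDA allocations and reset the peak-memory counters.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

torch.cuda.reset_max_memory_allocated(), which the diff also calls, is omitted here: in current PyTorch it simply calls reset_peak_memory_stats(). Emptying the cache on every request trades a little allocation latency for a lower steady-state memory footprint.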