YOURNAME committed on
Commit e09c84c · 1 parent: 4870f5c
Files changed (2)
  1. pyproject.toml +2 -8
  2. src/pipeline.py +59 -53
pyproject.toml CHANGED
@@ -23,20 +23,14 @@ dependencies = [
 ]
 
 [[tool.edge-maxxing.models]]
-repository = "black-forest-labs/FLUX.1-schnell"
-revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
+repository = "MyApricity/FLUX_OPT_SCHNELL_1.2"
+revision = "488528b6f815bff1bbc747cf1e0947c77c544665"
 
 [[tool.edge-maxxing.models]]
 repository = "city96/t5-v1_1-xxl-encoder-bf16"
 revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86"
 
-[[tool.edge-maxxing.models]]
-repository = "MyApricity/Vae_Only"
-revision = "a47d57702caf8ff0c0e21d30b93f9d3297b81920"
 
-[[tool.edge-maxxing.models]]
-repository = "MyApricity/Flux_Transformer_float8"
-revision = "66c5f182385555a00ec90272ab711bb6d3c197db"
 
 [project.scripts]
 start_inference = "main:main"
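The `[[tool.edge-maxxing.models]]` entries above pin each repository to an exact revision. As a rough illustration (not part of this commit), the pinned snapshots can be prefetched into the local Hugging Face cache with `huggingface_hub.snapshot_download`; the repo IDs and revisions below come from the new manifest, while the list and loop are purely illustrative:

# Hedged sketch: prefetch the models pinned in pyproject.toml.
# snapshot_download caches each repo under HF_HUB_CACHE using the layout
# models--{org}--{name}/snapshots/{revision}, which src/pipeline.py relies on.
from huggingface_hub import snapshot_download

PINNED_MODELS = [
    ("MyApricity/FLUX_OPT_SCHNELL_1.2", "488528b6f815bff1bbc747cf1e0947c77c544665"),
    ("city96/t5-v1_1-xxl-encoder-bf16", "1b9c856aadb864af93c1dcdc226c2774fa67bc86"),
]

for repo_id, revision in PINNED_MODELS:
    local_dir = snapshot_download(repo_id=repo_id, revision=revision)
    print(f"{repo_id} @ {revision[:8]} -> {local_dir}")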
src/pipeline.py CHANGED
@@ -8,7 +8,8 @@ import transformers
 from huggingface_hub.constants import HF_HUB_CACHE
 from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
 
-from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
+# ApricityApricityApricityApricityApricityApricityApricityApricityApricityApricityApricity
+
 from torch import Generator
 from diffusers import FluxTransformer2DModel, DiffusionPipeline
 
 
@@ -19,60 +20,49 @@ from optimum.quanto import requantize
 import json
 
 
+# ApricityApricityApricityApricityApricityApricityApricityApricityApricityApricityApricity
 
 
 torch._dynamo.config.suppress_errors = True
 os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
 os.environ["TOKENIZERS_PARALLELISM"] = "True"
 
-CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
-REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
+ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
+revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
 Pipeline = None
-
+use_com = False
 
 import torch
 import math
 from typing import Dict, Any
 
 def remove_cache():
-    gc.collect()
     torch.cuda.empty_cache()
     torch.cuda.reset_max_memory_allocated()
+    gc.collect()
     torch.cuda.reset_peak_memory_stats()
 
 
-class InitializingModel:
 
-    @staticmethod
-    def load_text_encoder() -> T5EncoderModel:
-        print("Loading text encoder...")
-        text_encoder = T5EncoderModel.from_pretrained(
-            "city96/t5-v1_1-xxl-encoder-bf16",
-            revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
-            torch_dtype=torch.bfloat16,
-        )
-        return text_encoder.to(memory_format=torch.channels_last)
+def text_t5_loader() -> T5EncoderModel:
+    print("Loading text encoder...")
+    text_encoder = T5EncoderModel.from_pretrained(
+        "city96/t5-v1_1-xxl-encoder-bf16",
+        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
+        torch_dtype=torch.bfloat16,
+    )
+    return text_encoder.to(memory_format=torch.channels_last)
 
-    @staticmethod
-    def load_transformer(trans_path: str) -> FluxTransformer2DModel:
-        print("Loading transformer model...")
-        transformer = FluxTransformer2DModel.from_pretrained(
-            trans_path,
-            torch_dtype=torch.bfloat16,
-            use_safetensors=False,
-        )
-        return transformer.to(memory_format=torch.channels_last)
 
-
-class CompileTransformerDiffusion:
+class StableDiffusionTransformerCompile:
     def __init__(self, pipeline, optimize=False):
         self.pipeline = pipeline
         self.optimize = optimize
         if self.optimize:
-            self._compile_model()
+            self.model_compiling()
 
-    def _compile_model(self):
-        print("Compiling transformer model for optimized diffusion...")
+    def model_compiling(self):
+        # note: FLUX pipelines name the denoiser `transformer`; there is no `unet`
+        # attribute, so enabling `optimize` would need `self.pipeline.transformer`
         self.pipeline.unet = torch.compile(self.pipeline.unet)
 
     def __call__(self, *args, **kwargs):
@@ -80,39 +70,55 @@ class CompileTransformerDiffusion:
 
 def load_pipeline() -> Pipeline:
 
-
-    base_transformer_path = os.path.join(HF_HUB_CACHE, "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db")
-    base_transformer = InitializingModel.load_transformer(base_transformer_path)
+    text_t5_encoder = text_t5_loader()
 
-    text_encoder_2 = InitializingModel.load_text_encoder()
+    transformer_path__ = os.path.join(HF_HUB_CACHE, "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665")
 
+    transformer__ = FluxTransformer2DModel.from_pretrained(transformer_path__, torch_dtype=torch.bfloat16, use_safetensors=False)
 
-    pipeline = DiffusionPipeline.from_pretrained(CHECKPOINT,
-                                                 revision=REVISION,
-                                                 transformer=base_transformer,
-                                                 text_encoder_2=text_encoder_2,
-                                                 torch_dtype=torch.bfloat16)
+    # load with the custom transformer; fall back to the stock checkpoint components if that fails
+    try:
+        pipeline = DiffusionPipeline.from_pretrained(ckpt_root,
+                                                     revision=revision_root,
+                                                     transformer=transformer__,
+                                                     torch_dtype=torch.bfloat16)
+    except:
+        pipeline = DiffusionPipeline.from_pretrained(ckpt_root,
+                                                     revision=revision_root,
+                                                     torch_dtype=torch.bfloat16)
 
     pipeline.to("cuda")
+
     try:
-        pipeline.disable_vae_slice()
-        compiled_pipeline = CompileTransformerDiffusion(pipeline, optimize=False)
+        compiled_pipeline = StableDiffusionTransformerCompile(pipeline, optimize=False)
+
+        if use_com:
+            pipeline = compiled_pipeline
+        else:
+            print("Compilation disabled; running the uncompiled pipeline")
+
+        # best-effort tweaks: `disable_vae_compress` is not a standard diffusers
+        # method, so the bare except below also skips the text-encoder swap if it raises
+        pipeline.disable_vae_compress()
+        pipeline.text_encoder_2 = text_t5_encoder
+
     except:
-        print("Stay safe here pipeline")
+        print("Optional pipeline tweaks skipped")
 
 
-    promts_listing = [
-        "sellate, Tremellales, thro, albescent",
-        "must return non duplicate",
-        "albaspidin, pillmonger, palaeocrystalline"
-    ]
-
-    for p in promts_listing:
-        pipeline(prompt=p,
-                 width=1024,
-                 height=1024,
-                 guidance_scale=0.0,
-                 num_inference_steps=4,
-                 max_sequence_length=256)
+    # two warm-up generations so later requests run at steady-state speed
+    prompt_1 = "albaspidin, pillmonger, palaeocrystalline"
+    pipeline(prompt=prompt_1,
+             width=1024,
+             height=1024,
+             guidance_scale=0.0,
+             num_inference_steps=4,
+             max_sequence_length=256)
+
+    prompt_2 = "obe, kilometrage, circuition"
+    pipeline(prompt=prompt_2,
+             width=1024,
+             height=1024,
+             guidance_scale=0.0,
+             num_inference_steps=4,
+             max_sequence_length=256)
 
     return pipeline
124