YOURNAME commited on
Commit
4870f5c
·
1 Parent(s): 4fcd1d5
Files changed (2) hide show
  1. pyproject.toml +0 -1
  2. src/pipeline.py +24 -23
pyproject.toml CHANGED
@@ -25,7 +25,6 @@ dependencies = [
25
  [[tool.edge-maxxing.models]]
26
  repository = "black-forest-labs/FLUX.1-schnell"
27
  revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
28
- exclude = ["transformer", "vae", "text_encoder_2"]
29
 
30
  [[tool.edge-maxxing.models]]
31
  repository = "city96/t5-v1_1-xxl-encoder-bf16"
 
25
  [[tool.edge-maxxing.models]]
26
  repository = "black-forest-labs/FLUX.1-schnell"
27
  revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
 
28
 
29
  [[tool.edge-maxxing.models]]
30
  repository = "city96/t5-v1_1-xxl-encoder-bf16"
src/pipeline.py CHANGED
@@ -41,7 +41,7 @@ def remove_cache():
41
  torch.cuda.reset_peak_memory_stats()
42
 
43
 
44
- class InitModel:
45
 
46
  @staticmethod
47
  def load_text_encoder() -> T5EncoderModel:
@@ -53,16 +53,6 @@ class InitModel:
53
  )
54
  return text_encoder.to(memory_format=torch.channels_last)
55
 
56
- @staticmethod
57
- def load_vae() -> AutoencoderTiny:
58
- print("Loading VAE model...")
59
- vae = AutoencoderTiny.from_pretrained(
60
- "XiangquiAI/FLUX_Vae_Model",
61
- revision="103bcc03998f48ef311c100ee119f1b9942132ab",
62
- torch_dtype=torch.bfloat16,
63
- )
64
- return vae
65
-
66
  @staticmethod
67
  def load_transformer(trans_path: str) -> FluxTransformer2DModel:
68
  print("Loading transformer model...")
@@ -74,35 +64,46 @@ class InitModel:
74
  return transformer.to(memory_format=torch.channels_last)
75
 
76
 
77
-
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def load_pipeline() -> Pipeline:
79
 
80
 
81
- transformer_path = os.path.join(HF_HUB_CACHE, "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db")
82
- transformer = InitModel.load_transformer(transformer_path)
83
 
84
- text_encoder_2 = InitModel.load_text_encoder()
85
- vae = InitModel.load_vae()
86
 
87
 
88
  pipeline = DiffusionPipeline.from_pretrained(CHECKPOINT,
89
  revision=REVISION,
90
- vae=vae,
91
- transformer=transformer,
92
  text_encoder_2=text_encoder_2,
93
  torch_dtype=torch.bfloat16)
94
  pipeline.to("cuda")
95
  try:
96
  pipeline.disable_vae_slice()
 
97
  except:
98
- print("Using origin pipeline")
99
 
100
 
101
  promts_listing = [
102
- "melanogen, endosome",
103
- "buffer, cutie, buttinsky, prototrophic",
104
- "puzzlehead, fistical, must return non duplicate",
105
- "apical, polymyodous, tiptilt"
106
  ]
107
 
108
  for p in promts_listing:
 
41
  torch.cuda.reset_peak_memory_stats()
42
 
43
 
44
+ class InitializingModel:
45
 
46
  @staticmethod
47
  def load_text_encoder() -> T5EncoderModel:
 
53
  )
54
  return text_encoder.to(memory_format=torch.channels_last)
55
 
 
 
 
 
 
 
 
 
 
 
56
  @staticmethod
57
  def load_transformer(trans_path: str) -> FluxTransformer2DModel:
58
  print("Loading transformer model...")
 
64
  return transformer.to(memory_format=torch.channels_last)
65
 
66
 
67
+ class CompileTransformerDiffusion:
68
+ def __init__(self, pipeline, optimize=False):
69
+ self.pipeline = pipeline
70
+ self.optimize = optimize
71
+ if self.optimize:
72
+ self._compile_model()
73
+
74
+ def _compile_model(self):
75
+ print("Compiling transformer model for optimized diffusion...")
76
+ self.pipeline.unet = torch.compile(self.pipeline.unet)
77
+
78
+ def __call__(self, *args, **kwargs):
79
+ return self.pipeline(*args, **kwargs)
80
+
81
  def load_pipeline() -> Pipeline:
82
 
83
 
84
+ base_transformer_path = os.path.join(HF_HUB_CACHE, "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db")
85
+ base_transformer = InitializingModel.load_transformer(base_transformer_path)
86
 
87
+ text_encoder_2 = InitializingModel.load_text_encoder()
 
88
 
89
 
90
  pipeline = DiffusionPipeline.from_pretrained(CHECKPOINT,
91
  revision=REVISION,
92
+ transformer=base_transformer,
 
93
  text_encoder_2=text_encoder_2,
94
  torch_dtype=torch.bfloat16)
95
  pipeline.to("cuda")
96
  try:
97
  pipeline.disable_vae_slice()
98
+ compiled_pipeline = CompileTransformerDiffusion(pipeline, optimize=False)
99
  except:
100
+ print("Stay safe here pipeline")
101
 
102
 
103
  promts_listing = [
104
+ "sellate, Tremellales, thro, albescent",
105
+ "must return non duplicate",
106
+ "albaspidin, pillmonger, palaeocrystalline"
 
107
  ]
108
 
109
  for p in promts_listing: