yiren98 committed
Commit ea28f87 · 1 Parent(s): 21599ad
Files changed (1)
  1. gradio_app.py +32 -61
gradio_app.py CHANGED
@@ -94,35 +94,33 @@ def load_target_model(selected_model):
     AE_PATH = download_file(ae_repo_id, ae_file)
     LORA_WEIGHTS_PATH = download_file(lora_repo, lora_file)
 
-    return "Models loaded successfully. Using Recraft: {}".format(selected_model)
-
-    # logger.info("Loading models...")
-    # try:
-    #     if model is None is None or clip_l is None or t5xxl is None or ae is None:
-    #         _, model = flux_utils.load_flow_model(
-    #             BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
-    #         )
-    #         clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
-    #         clip_l.eval()
-    #         t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
-    #         t5xxl.eval()
-    #         ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
-
-    #         # Load LoRA weights
-    #         multiplier = 1.0
-    #         weights_sd = load_file(LORA_WEIGHTS_PATH)
-    #         lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
-    #         lora_model.apply_to([clip_l, t5xxl], model)
-    #         info = lora_model.load_state_dict(weights_sd, strict=True)
-    #         logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
-    #         lora_model.eval()
-
-    #     logger.info("Models loaded successfully.")
-    #     return "Models loaded successfully. Using Recraft: {}".format(selected_model)
-
-    # except Exception as e:
-    #     logger.error(f"Error loading models: {e}")
-    #     return f"Error loading models: {e}"
+    logger.info("Loading models...")
+    try:
+        if model is None is None or clip_l is None or t5xxl is None or ae is None:
+            clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
+            clip_l.eval()
+            t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
+            t5xxl.eval()
+            ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
+
+            # Load flux & LoRA weights
+            _, model = flux_utils.load_flow_model(
+                BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
+            )
+            multiplier = 1.0
+            weights_sd = load_file(LORA_WEIGHTS_PATH)
+            lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
+            lora_model.apply_to([clip_l, t5xxl], model)
+            info = lora_model.load_state_dict(weights_sd, strict=True)
+            logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
+            lora_model.eval()
+
+        logger.info("Models loaded successfully.")
+        return "Models loaded successfully. Using Recraft: {}".format(selected_model)
+
+    except Exception as e:
+        logger.error(f"Error loading models: {e}")
+        return f"Error loading models: {e}"
 
 # Image pre-processing (resize and padding)
 class ResizeWithPadding:
@@ -156,37 +154,10 @@ class ResizeWithPadding:
 # The function to generate image from a prompt and conditional image
 @spaces.GPU(duration=180)
 def infer(prompt, sample_image, recraft_model, seed=0):
-    # global model, clip_l, t5xxl, ae, lora_model
-    # if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
-    #     logger.error("Models not loaded. Please load the models first.")
-    #     return None
-    logger.info("Loading models...")
-    try:
-        if model is None is None or clip_l is None or t5xxl is None or ae is None:
-            _, model = flux_utils.load_flow_model(
-                BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cuda", disable_mmap=False
-            )
-            clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cuda", disable_mmap=False)
-            clip_l.eval()
-            t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cuda", disable_mmap=False)
-            t5xxl.eval()
-            ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cuda", disable_mmap=False)
-
-            # Load LoRA weights
-            multiplier = 1.0
-            weights_sd = load_file(LORA_WEIGHTS_PATH)
-            lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
-            lora_model.apply_to([clip_l, t5xxl], model)
-            info = lora_model.load_state_dict(weights_sd, strict=True)
-            logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
-            lora_model.eval()
-
-        logger.info("Models loaded successfully.")
-        # return "Models loaded successfully. Using Recraft: {}".format(selected_model)
-
-    except Exception as e:
-        logger.error(f"Error loading models: {e}")
-        return f"Error loading models: {e}"
+    global model, clip_l, t5xxl, ae, lora_model
+    if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
+        logger.error("Models not loaded. Please load the models first.")
+        return None
 
     model_path = model_paths[recraft_model]
     frame_num = model_path['Frame']
@@ -317,7 +288,7 @@ def infer(prompt, sample_image, recraft_model, seed=0):
 
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("## FLUX Image Generation")
+    gr.Markdown("## Recraft Generation")
 
     with gr.Row():
         with gr.Column(scale=1):
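Taken together, the hunks above move weight loading out of the GPU-decorated infer path: load_target_model now populates module-level globals (on CPU), and infer only checks that they exist before generating. A minimal, self-contained sketch of that load-once / check-at-inference pattern, with hypothetical stand-ins (DummyModel, the __main__ demo) rather than the app's real flux_utils API:

# Illustrative sketch only: DummyModel is a hypothetical stand-in for the
# FLUX/CLIP/T5/AE/LoRA stack the real app loads via flux_utils.
from typing import Optional

class DummyModel:
    """Stands in for the loaded model stack (model, clip_l, t5xxl, ae, lora_model)."""
    def __call__(self, prompt: str) -> str:
        return f"generated image for: {prompt}"

model: Optional[DummyModel] = None  # module-level cache, like the app's globals

def load_target_model() -> str:
    """Load weights once and cache them; repeated calls reuse the cache."""
    global model
    if model is None:
        model = DummyModel()  # the commit does the heavy loading here, on CPU
    return "Models loaded successfully."

def infer(prompt: str) -> Optional[str]:
    """Inference only verifies the cache; it never loads weights itself."""
    if model is None:
        return None  # mirrors the commit: log an error and bail out early
    return model(prompt)

if __name__ == "__main__":
    load_target_model()
    print(infer("a red chair"))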