axrzce committed
Commit b558f4c · verified · 1 Parent(s): 6c917de

Deploy from GitHub main

.gitattributes CHANGED
@@ -2,3 +2,4 @@
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
+exports/compi_export_20250823_171107.zip filter=lfs diff=lfs merge=lfs -text
exports/compi_export_20250823_171107.zip ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5bd409356326d49795e641662c102b854fb2177798c827f84b05ec190b8bd197
+size 330651
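Note: the three added lines are a Git LFS pointer, not the archive itself: a spec version line, the SHA-256 of the real blob, and its size in bytes. A minimal sketch (hypothetical helper, not part of this commit) that reads such a pointer:

# Hypothetical helper: parse a Git LFS pointer file
# (spec: https://git-lfs.github.com/spec/v1). Each line is "key value".
def parse_lfs_pointer(path):
    fields = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

ptr = parse_lfs_pointer("exports/compi_export_20250823_171107.zip")
print(ptr["oid"], ptr["size"])  # sha256:5bd40935... 330651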
src/ui/compi_phase3_final_dashboard.py CHANGED
@@ -18,6 +18,10 @@ Features:
 
 import gc
 import os
+
+# Set PyTorch memory management for better VRAM handling
+os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
+
 import io
 import csv
 import json
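Note on the allocator setting above: expandable_segments:True asks PyTorch's caching allocator to grow its CUDA memory in expandable segments, which reduces fragmentation when large models are repeatedly loaded and unloaded. The variable is only read when CUDA is first initialized, which is why the commit sets it immediately after import os, before any CUDA work. A minimal standalone sketch (not from the repo) of the required ordering:

# Set the allocator config before torch is imported, so it is in place
# before the first CUDA allocation.
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch

if torch.cuda.is_available():
    _ = torch.zeros(1, device="cuda")    # first allocation initializes the allocator
    print(torch.cuda.memory_reserved())  # bytes reserved under the new policy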
@@ -292,17 +296,87 @@ def load_sd15(txt2img=True):
     )
     return pipe.to(DEVICE)
 
+def force_clear_vram():
+    """Nuclear VRAM cleanup - clears everything possible"""
+    if DEVICE == "cuda":
+        try:
+            # Clear PyTorch cache multiple times
+            for _ in range(3):
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()
+
+            # Force Python garbage collection
+            import gc
+            gc.collect()
+
+            # Try to reset memory stats (if available)
+            try:
+                torch.cuda.reset_peak_memory_stats()
+                torch.cuda.reset_accumulated_memory_stats()
+            except Exception:
+                pass
+
+            # Final cache clear
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+            # Show memory status
+            allocated = torch.cuda.memory_allocated() / (1024**3)
+            reserved = torch.cuda.memory_reserved() / (1024**3)
+            st.info(f"🧹 Memory cleared - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
+
+        except Exception as e:
+            st.warning(f"Memory clearing failed: {e}")
+
 @st.cache_resource(show_spinner=True)
 def load_sdxl():
-    """Load SDXL pipeline"""
+    """Load SDXL pipeline with nuclear VRAM management"""
     if not HAS_SDXL:
         return None
-    pipe = StableDiffusionXLPipeline.from_pretrained(
-        "stabilityai/stable-diffusion-xl-base-1.0",
-        safety_checker=None,
-        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
-    )
-    return pipe.to(DEVICE)
+
+    # Nuclear cleanup before loading
+    force_clear_vram()
+
+    # Try loading SDXL with retry logic
+    for attempt in range(3):  # Try up to 3 times
+        try:
+            if attempt > 0:
+                st.info(f"🔄 SDXL loading attempt {attempt + 1}/3 - nuclear VRAM cleanup...")
+                force_clear_vram()
+                # Wait a moment for cleanup to take effect
+                import time
+                time.sleep(1)
+
+            pipe = StableDiffusionXLPipeline.from_pretrained(
+                "stabilityai/stable-diffusion-xl-base-1.0",
+                safety_checker=None,
+                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+                low_cpu_mem_usage=True,  # Load model parts progressively
+                use_safetensors=True,  # More memory-efficient loading
+            )
+            result = pipe.to(DEVICE)
+            if attempt > 0:
+                st.success(f"✅ SDXL loaded successfully on attempt {attempt + 1}")
+            return result
+
+        except torch.cuda.OutOfMemoryError as e:
+            if attempt < 2:  # Not the last attempt
+                st.warning(f"⚠️ CUDA OOM on attempt {attempt + 1} - nuclear cleanup and retry...")
+                force_clear_vram()
+                # Longer wait for memory to actually be freed
+                import time
+                time.sleep(2)
+                continue
+            else:
+                st.error("🚫 SDXL failed after 3 attempts. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")
+                st.error(f"Error details: {e}")
+                force_clear_vram()
+                return None
+        except Exception as e:
+            st.error(f"Failed to load SDXL: {e}")
+            return None
+
+    return None
 
 @st.cache_resource(show_spinner=True)
 def load_upscaler():
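The clear-and-retry loop in load_sdxl generalizes to any memory-hungry loader. A minimal sketch (hypothetical helper, not from this repo) of the same pattern in isolation:

# Hypothetical helper mirroring the attempt/force_clear_vram loop above:
# retry a loader after aggressive CUDA cleanup, surfacing the OOM only
# once the retries are exhausted.
import gc
import time

import torch

def load_with_retry(loader, attempts=3, wait_s=2.0):
    for attempt in range(attempts):
        try:
            return loader()
        except torch.cuda.OutOfMemoryError:
            if attempt == attempts - 1:
                raise                     # out of retries
            gc.collect()                  # drop unreachable CUDA tensors
            torch.cuda.empty_cache()      # return cached blocks to the driver
            torch.cuda.synchronize()
            time.sleep(wait_s)            # give the driver time to reclaim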
@@ -1046,12 +1120,14 @@ with tab_inputs:
     # Use first chosen style reference as init image
     init_image = ref_images[style_idxs[0]-1].resize((int(width), int(height)))
 
-    # Generation + Clear buttons side-by-side
-    col_gen, col_clear = st.columns([3, 1])
+    # Generation + Clear + Memory buttons
+    col_gen, col_clear, col_mem = st.columns([3, 1, 1])
     with col_gen:
         go = st.button("🚀 Generate Multimodal Art", type="primary", use_container_width=True)
     with col_clear:
         clear = st.button("🧹 Clear", use_container_width=True)
+    with col_mem:
+        clear_mem = st.button("💾 Free VRAM", use_container_width=True, help="Clear model cache and free VRAM")
 
     # Clear logic: reset prompt fields and any generated output state
     if 'generated_images' not in st.session_state:
@@ -1065,6 +1141,17 @@ with tab_inputs:
         st.success("Cleared current prompt and output. Ready for a new prompt.")
         st.rerun()
 
+    # Define clear function before using it
+    def clear_model_cache():
+        """Clear all cached models to free VRAM"""
+        st.cache_resource.clear()
+        force_clear_vram()
+        st.success("🧹 All model caches cleared!")
+
+    if clear_mem:
+        clear_model_cache()
+        st.rerun()
+
     # Cached pipeline getters
     @st.cache_resource(show_spinner=True)
     def get_txt2img():
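The 💾 Free VRAM button works because the @st.cache_resource entries hold the only long-lived references to the pipelines; st.cache_resource.clear() drops those references, after which force_clear_vram() can actually release the memory. A minimal standalone sketch (hypothetical loader name) of the same pattern:

# Hypothetical sketch of the cache-clear pattern used by clear_model_cache().
import streamlit as st

@st.cache_resource
def get_pipeline():
    return load_big_model()    # hypothetical loader holding GBs of VRAM

if st.button("Free VRAM"):
    st.cache_resource.clear()  # drop cached pipelines; next call reloads
    # gc.collect() / torch.cuda.empty_cache() can now reclaim the memory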
@@ -1078,6 +1165,8 @@ with tab_inputs:
     def get_sdxl():
         return load_sdxl()
 
+
+
     @st.cache_resource(show_spinner=True)
     def get_upscaler():
         return load_upscaler()
@@ -1142,7 +1231,7 @@ with tab_inputs:
     # Choose pipeline based on model selection
     if model_choice.startswith("SDXL") and HAS_SDXL and gen_mode == "txt2img":
         pipe = get_sdxl()
-        model_id = "SDXL-Base-1.0"
+        model_id = "SDXL-Base-1.0" if pipe else "SD-1.5-fallback"
     else:
         if gen_mode == "txt2img":
             pipe = get_txt2img()
@@ -1151,6 +1240,10 @@ with tab_inputs:
             pipe = get_img2img()
             model_id = "SD-1.5 (img2img)"
 
+    if not pipe:
+        st.error("❌ Failed to load pipeline after all retry attempts. Try restarting the app or use a different model.")
+        st.stop()
+
     # Apply performance optimizations
     xformed = attempt_enable_xformers(pipe) if use_xformers else False
     apply_perf(pipe, attn_slice, vae_slice, vae_tile)
 