Update app.py
app.py CHANGED
@@ -133,8 +133,8 @@ models_rbm.generator.eval().requires_grad_(False)
 def infer(ref_style_file, style_description, caption, use_low_vram, progress):
     global models_rbm, models_b, device
 
-
-
+    models_to(models_rbm, device=device)
+
     try:
 
         caption = f"{caption} in {style_description}"
@@ -234,6 +234,8 @@ def infer(ref_style_file, style_description, caption, use_low_vram, progress):
         return sampled_image  # Return the sampled_image PIL image
 
     finally:
+        if use_low_vram:
+            models_to(models_rbm, device=device)
         # Clear CUDA cache
         torch.cuda.empty_cache()
         gc.collect()
@@ -241,10 +243,9 @@ def infer(ref_style_file, style_description, caption, use_low_vram, progress):
 def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_low_vram, progress):
     global models_rbm, models_b, device
     sam_model = LangSAM()
-
-
-
-    models_to(sam_model.sam, device=device)
+    models_to(models_rbm, device=device)
+    models_to(sam_model, device=device)
+    models_to(sam_model.sam, device=device)
     try:
         caption = f"{caption} in {style_description}"
         sam_prompt = f"{caption}"
@@ -361,6 +362,10 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_low_vram, progress):
         return sampled_image  # Return the sampled_image PIL image
 
     finally:
+        if use_low_vram:
+            models_to(models_rbm, device=device, excepts=["generator", "previewer"])
+            models_to(sam_model, device=device)
+            models_to(sam_model.sam, device=device)
         # Clear CUDA cache
         torch.cuda.empty_cache()
         gc.collect()
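The excepts=["generator", "previewer"] argument on the final models_to call is worth unpacking: it moves the models_rbm bundle as a whole while leaving the named submodules where they are. The helper itself is defined elsewhere in the space and is not shown in this diff, so the following is only a sketch of what a models_to with that signature plausibly does, not the repo's actual implementation:

import torch.nn as nn

def models_to(obj, device="cpu", excepts=None):
    # Sketch only: the space defines its own models_to; this diff shows
    # just the call sites. Move every nn.Module reachable as an attribute
    # of `obj` to `device`, skipping any attribute named in `excepts`.
    if isinstance(obj, nn.Module):
        obj.to(device)
        return
    for name, value in vars(obj).items():
        if excepts is not None and name in excepts:
            continue
        if isinstance(value, nn.Module):
            value.to(device)

Under that reading, excluding "generator" and "previewer" would leave those two modules in place, presumably because the preview path still needs them.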
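Both functions now follow the same discipline: place the models unconditionally before sampling, and put the use_low_vram handling together with the cache cleanup in the finally block, so an exception raised mid-generation cannot skip it. A minimal sketch of that shape, with a hypothetical sample() standing in for the space's actual pipeline and reusing the models_to sketch above:

import gc
import torch

def sample(models, caption):
    raise NotImplementedError  # stand-in for the space's real sampling code

def generate(models, caption, use_low_vram, device="cuda"):
    models_to(models, device=device)  # place weights up front, unconditionally
    try:
        return sample(models, caption)
    finally:
        # Runs whether sampling returned or raised.
        if use_low_vram:
            models_to(models, device=device)  # re-apply placement, as the commit does
        torch.cuda.empty_cache()  # release cached CUDA allocations
        gc.collect()              # then collect unreachable Python objects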