Spaces:
Runtime error
Runtime error
Commit
·
f424501
1
Parent(s):
d41c21d
Attempt gc again for faster speeds
Browse files
app.py
CHANGED
|
@@ -7,6 +7,7 @@ import lora
|
|
| 7 |
from time import sleep
|
| 8 |
import copy
|
| 9 |
import json
|
|
|
|
| 10 |
|
| 11 |
with open("sdxl_loras.json", "r") as file:
|
| 12 |
data = json.load(file)
|
|
@@ -35,11 +36,14 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
| 35 |
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 36 |
vae=vae,
|
| 37 |
torch_dtype=torch.float16,
|
| 38 |
-
).to(
|
|
|
|
|
|
|
| 39 |
|
| 40 |
last_lora = ""
|
| 41 |
last_merged = False
|
| 42 |
|
|
|
|
| 43 |
def update_selection(selected_state: gr.SelectData):
|
| 44 |
lora_repo = sdxl_loras[selected_state.index]["repo"]
|
| 45 |
instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
|
|
@@ -129,11 +133,10 @@ def run_lora(prompt, negative, lora_scale, selected_state):
|
|
| 129 |
cross_attention_kwargs = None
|
| 130 |
if last_lora != repo_name:
|
| 131 |
if last_merged:
|
| 132 |
-
pipe
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
).to(device)
|
| 137 |
else:
|
| 138 |
pipe.unload_lora_weights()
|
| 139 |
is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
|
|
@@ -260,4 +263,4 @@ with gr.Blocks(css="custom.css") as demo:
|
|
| 260 |
share_button.click(None, [], [], _js=share_js)
|
| 261 |
|
| 262 |
demo.queue(max_size=20)
|
| 263 |
-
demo.launch()
|
|
|
|
| 7 |
from time import sleep
|
| 8 |
import copy
|
| 9 |
import json
|
| 10 |
+
import gc
|
| 11 |
|
| 12 |
with open("sdxl_loras.json", "r") as file:
|
| 13 |
data = json.load(file)
|
|
|
|
| 36 |
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 37 |
vae=vae,
|
| 38 |
torch_dtype=torch.float16,
|
| 39 |
+
).to("cpu")
|
| 40 |
+
original_pipe = copy.deepcopy(pipe)
|
| 41 |
+
pipe.to(device)
|
| 42 |
|
| 43 |
last_lora = ""
|
| 44 |
last_merged = False
|
| 45 |
|
| 46 |
+
|
| 47 |
def update_selection(selected_state: gr.SelectData):
|
| 48 |
lora_repo = sdxl_loras[selected_state.index]["repo"]
|
| 49 |
instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
|
|
|
|
| 133 |
cross_attention_kwargs = None
|
| 134 |
if last_lora != repo_name:
|
| 135 |
if last_merged:
|
| 136 |
+
del pipe
|
| 137 |
+
gc.collect()
|
| 138 |
+
pipe = copy.deepcopy(original_pipe)
|
| 139 |
+
pipe.to(device)
|
|
|
|
| 140 |
else:
|
| 141 |
pipe.unload_lora_weights()
|
| 142 |
is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
|
|
|
|
| 263 |
share_button.click(None, [], [], _js=share_js)
|
| 264 |
|
| 265 |
demo.queue(max_size=20)
|
| 266 |
+
demo.launch()
|