Update app.py
app.py CHANGED
@@ -11,7 +11,7 @@ import numpy as np
 import re
 
 # Import necessary functions and classes
-from utils import load_t5, load_clap
+from utils import load_t5, load_clap, quantize_model
 from train import RF
 from constants import build_model
 
@@ -22,6 +22,7 @@ global_clap = None
 global_vae = None
 global_vocoder = None
 global_diffusion = None
+current_model_name = None
 
 # Set the models directory
 MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
@@ -63,15 +64,16 @@ def prepare(t5, clip, img, prompt):
     }
 
 def unload_current_model():
-    global global_model
+    global global_model, current_model_name
     if global_model is not None:
         del global_model
         torch.cuda.empty_cache()
     global_model = None
+    current_model_name = None
 
 def load_model(model_name):
-    global global_model
-    device = "
+    global global_model, current_model_name
+    device = "cpu"  # Force CPU usage
 
     unload_current_model()
 
@@ -90,16 +92,27 @@ def load_model(model_name):
     print(f"Loading {model_size} model: {model_name}")
 
     model_path = os.path.join(MODELS_DIR, model_name)
-    global_model = build_model(model_size).to(device)
-
-
-
-
+    global_model = build_model(model_size, device="cpu").to(device)
+
+    try:
+        state_dict = torch.load(model_path, map_location=device, weights_only=True)
+        global_model.load_state_dict(state_dict['ema'], strict=False)
+        global_model.eval()
+
+        # Quantize the model for CPU inference
+        global_model = quantize_model(global_model)
+
+        global_model.model_path = model_path
+        current_model_name = model_name
+        return f"Successfully loaded and quantized model: {model_name}"
+    except Exception as e:
+        print(f"Error loading model {model_name}: {str(e)}")
+        return f"Failed to load model: {model_name}. Error: {str(e)}"
 
 def load_resources():
     global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
 
-    device = "
+    device = "cpu"
 
     print("Loading T5 and CLAP models...")
     global_t5 = load_t5(device, max_length=256)
@@ -114,17 +127,17 @@ def load_resources():
 
     print("Base resources loaded successfully!")
 
-def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
+def generate_music(prompt, seed, cfg_scale, steps, duration, batch_size=4, progress=gr.Progress()):
     global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
 
     if global_model is None:
-        return "Please select a model first.", None
+        return "Please select and load a model first.", None
 
     if seed == 0:
         seed = random.randint(1, 1000000)
     print(f"Using seed: {seed}")
 
-    device = "
+    device = "cpu"
     torch.manual_seed(seed)
     torch.set_grad_enabled(False)
 
@@ -150,11 +163,25 @@ def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
     img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
     _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
 
-
-
+    # Implement batching for CPU inference
+    images = []
+    for batch_start in range(0, img.shape[0], batch_size):
+        batch_end = min(batch_start + batch_size, img.shape[0])
+        batch_img = img[batch_start:batch_end]
+        batch_conds = {k: v[batch_start:batch_end] for k, v in conds.items()}
+        batch_unconds = {k: v[batch_start:batch_end] for k, v in unconds.items()}
+
+        with torch.no_grad():
+            batch_images = global_diffusion.sample_with_xps(
+                global_model, batch_img, conds=batch_conds, null_cond=batch_unconds,
+                sample_steps=steps, cfg=cfg_scale
+            )
+        images.append(batch_images[-1])
+
+    images = torch.cat(images, dim=0)
 
     images = rearrange(
-        images
+        images,
         "b (h w) (c ph pw) -> b c (h ph) (w pw)",
         h=128,
         w=8,
@@ -239,6 +266,8 @@ with gr.Blocks(theme=theme) as iface:
 
     with gr.Row():
        model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
+        load_model_button = gr.Button("Load Model")
+        model_status = gr.Textbox(label="Model Status", value="No model loaded")
 
     with gr.Row():
         prompt = gr.Textbox(label="Prompt")
@@ -253,22 +282,15 @@ with gr.Blocks(theme=theme) as iface:
     output_status = gr.Textbox(label="Generation Status")
     output_audio = gr.Audio(type="filepath")
 
-    def
-
-
-
-
-        except Exception as e:
-            print(f"Error loading model {model_name}: {str(e)}")
-        else:
-            print("No valid model selected.")
-
-    model_dropdown.change(on_model_change, inputs=[model_dropdown])
+    def on_load_model_click(model_name):
+        result = load_model(model_name)
+        return result
+
+    load_model_button.click(on_load_model_click, inputs=[model_dropdown], outputs=[model_status])
     generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
 
-    # Load default model on startup
-
-    iface.load(lambda: load_model(default_model), inputs=None, outputs=None)
+    # Load default model on startup
+    iface.load(lambda: on_load_model_click(default_model), inputs=None, outputs=None)
 
     # Launch the interface
     iface.launch()
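Note: the change imports a new quantize_model helper from utils, but its definition is not part of this diff. A minimal sketch of what such a helper could look like, assuming it wraps PyTorch's dynamic quantization of Linear layers for CPU inference (the name and behavior of the real helper in utils.py may differ):

# Hypothetical sketch only; the actual utils.quantize_model is not shown in this diff.
import torch

def quantize_model(model: torch.nn.Module) -> torch.nn.Module:
    # Dynamic quantization stores Linear weights as int8 and quantizes
    # activations on the fly, which typically speeds up CPU inference.
    return torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

Dynamic quantization of this kind only affects CPU execution, which is consistent with the device = "cpu" forcing introduced elsewhere in this change.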