flosstradamus committed on
Commit
368ac79
·
verified ·
1 Parent(s): 313d650

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -30
app.py CHANGED
@@ -11,7 +11,7 @@ import numpy as np
11
  import re
12
 
13
  # Import necessary functions and classes
14
- from utils import load_t5, load_clap
15
  from train import RF
16
  from constants import build_model
17
 
@@ -22,6 +22,7 @@ global_clap = None
22
  global_vae = None
23
  global_vocoder = None
24
  global_diffusion = None
 
25
 
26
  # Set the models directory
27
  MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
@@ -63,15 +64,16 @@ def prepare(t5, clip, img, prompt):
63
  }
64
 
65
  def unload_current_model():
66
- global global_model
67
  if global_model is not None:
68
  del global_model
69
  torch.cuda.empty_cache()
70
  global_model = None
 
71
 
72
  def load_model(model_name):
73
- global global_model
74
- device = "cuda" if torch.cuda.is_available() else "cpu"
75
 
76
  unload_current_model()
77
 
@@ -90,16 +92,27 @@ def load_model(model_name):
90
  print(f"Loading {model_size} model: {model_name}")
91
 
92
  model_path = os.path.join(MODELS_DIR, model_name)
93
- global_model = build_model(model_size).to(device)
94
- state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
95
- global_model.load_state_dict(state_dict['ema'])
96
- global_model.eval()
97
- global_model.model_path = model_path
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  def load_resources():
100
  global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
101
 
102
- device = "cuda" if torch.cuda.is_available() else "cpu"
103
 
104
  print("Loading T5 and CLAP models...")
105
  global_t5 = load_t5(device, max_length=256)
@@ -114,17 +127,17 @@ def load_resources():
114
 
115
  print("Base resources loaded successfully!")
116
 
117
- def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
118
  global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
119
 
120
  if global_model is None:
121
- return "Please select a model first.", None
122
 
123
  if seed == 0:
124
  seed = random.randint(1, 1000000)
125
  print(f"Using seed: {seed}")
126
 
127
- device = "cuda" if torch.cuda.is_available() else "cpu"
128
  torch.manual_seed(seed)
129
  torch.set_grad_enabled(False)
130
 
@@ -150,11 +163,25 @@ def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progres
150
  img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
151
  _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
152
 
153
- with torch.autocast(device_type='cuda'):
154
- images = global_diffusion.sample_with_xps(global_model, img, conds=conds, null_cond=unconds, sample_steps=steps, cfg=cfg_scale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  images = rearrange(
157
- images[-1],
158
  "b (h w) (c ph pw) -> b c (h ph) (w pw)",
159
  h=128,
160
  w=8,
@@ -239,6 +266,8 @@ with gr.Blocks(theme=theme) as iface:
239
 
240
  with gr.Row():
241
  model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
 
 
242
 
243
  with gr.Row():
244
  prompt = gr.Textbox(label="Prompt")
@@ -253,22 +282,15 @@ with gr.Blocks(theme=theme) as iface:
253
  output_status = gr.Textbox(label="Generation Status")
254
  output_audio = gr.Audio(type="filepath")
255
 
256
- def on_model_change(model_name):
257
- if model_name != "No models available":
258
- try:
259
- load_model(model_name)
260
- print(f"Successfully loaded model: {model_name}")
261
- except Exception as e:
262
- print(f"Error loading model {model_name}: {str(e)}")
263
- else:
264
- print("No valid model selected.")
265
-
266
- model_dropdown.change(on_model_change, inputs=[model_dropdown])
267
  generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
268
 
269
- # Load default model on startup if it's a valid model
270
- if default_model != "No models available":
271
- iface.load(lambda: load_model(default_model), inputs=None, outputs=None)
272
 
273
  # Launch the interface
274
  iface.launch()
 
11
  import re
12
 
13
  # Import necessary functions and classes
14
+ from utils import load_t5, load_clap, quantize_model
15
  from train import RF
16
  from constants import build_model
17
 
 
22
  global_vae = None
23
  global_vocoder = None
24
  global_diffusion = None
25
+ current_model_name = None
26
 
27
  # Set the models directory
28
  MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
 
64
  }
65
 
66
  def unload_current_model():
67
+ global global_model, current_model_name
68
  if global_model is not None:
69
  del global_model
70
  torch.cuda.empty_cache()
71
  global_model = None
72
+ current_model_name = None
73
 
74
  def load_model(model_name):
75
+ global global_model, current_model_name
76
+ device = "cpu" # Force CPU usage
77
 
78
  unload_current_model()
79
 
 
92
  print(f"Loading {model_size} model: {model_name}")
93
 
94
  model_path = os.path.join(MODELS_DIR, model_name)
95
+ global_model = build_model(model_size, device="cpu").to(device)
96
+
97
+ try:
98
+ state_dict = torch.load(model_path, map_location=device, weights_only=True)
99
+ global_model.load_state_dict(state_dict['ema'], strict=False)
100
+ global_model.eval()
101
+
102
+ # Quantize the model for CPU inference
103
+ global_model = quantize_model(global_model)
104
+
105
+ global_model.model_path = model_path
106
+ current_model_name = model_name
107
+ return f"Successfully loaded and quantized model: {model_name}"
108
+ except Exception as e:
109
+ print(f"Error loading model {model_name}: {str(e)}")
110
+ return f"Failed to load model: {model_name}. Error: {str(e)}"
111
 
112
  def load_resources():
113
  global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
114
 
115
+ device = "cpu"
116
 
117
  print("Loading T5 and CLAP models...")
118
  global_t5 = load_t5(device, max_length=256)
 
127
 
128
  print("Base resources loaded successfully!")
129
 
130
+ def generate_music(prompt, seed, cfg_scale, steps, duration, batch_size=4, progress=gr.Progress()):
131
  global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
132
 
133
  if global_model is None:
134
+ return "Please select and load a model first.", None
135
 
136
  if seed == 0:
137
  seed = random.randint(1, 1000000)
138
  print(f"Using seed: {seed}")
139
 
140
+ device = "cpu"
141
  torch.manual_seed(seed)
142
  torch.set_grad_enabled(False)
143
 
 
163
  img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
164
  _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
165
 
166
+ # Implement batching for CPU inference
167
+ images = []
168
+ for batch_start in range(0, img.shape[0], batch_size):
169
+ batch_end = min(batch_start + batch_size, img.shape[0])
170
+ batch_img = img[batch_start:batch_end]
171
+ batch_conds = {k: v[batch_start:batch_end] for k, v in conds.items()}
172
+ batch_unconds = {k: v[batch_start:batch_end] for k, v in unconds.items()}
173
+
174
+ with torch.no_grad():
175
+ batch_images = global_diffusion.sample_with_xps(
176
+ global_model, batch_img, conds=batch_conds, null_cond=batch_unconds,
177
+ sample_steps=steps, cfg=cfg_scale
178
+ )
179
+ images.append(batch_images[-1])
180
+
181
+ images = torch.cat(images, dim=0)
182
 
183
  images = rearrange(
184
+ images,
185
  "b (h w) (c ph pw) -> b c (h ph) (w pw)",
186
  h=128,
187
  w=8,
 
266
 
267
  with gr.Row():
268
  model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
269
+ load_model_button = gr.Button("Load Model")
270
+ model_status = gr.Textbox(label="Model Status", value="No model loaded")
271
 
272
  with gr.Row():
273
  prompt = gr.Textbox(label="Prompt")
 
282
  output_status = gr.Textbox(label="Generation Status")
283
  output_audio = gr.Audio(type="filepath")
284
 
285
+ def on_load_model_click(model_name):
286
+ result = load_model(model_name)
287
+ return result
288
+
289
+ load_model_button.click(on_load_model_click, inputs=[model_dropdown], outputs=[model_status])
 
 
 
 
 
 
290
  generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
291
 
292
+ # Load default model on startup
293
+ iface.load(lambda: on_load_model_click(default_model), inputs=None, outputs=None)
 
294
 
295
  # Launch the interface
296
  iface.launch()