flosstradamus committed on
Commit 0dacaeb · verified · 1 Parent(s): 41e9bb3

Update app.py

Files changed (1)
  1. app.py +200 -3
app.py CHANGED
@@ -29,7 +29,40 @@ current_model_name = None
  MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
  GENERATIONS_DIR = os.path.join(os.path.dirname(__file__), "generations")

- # ... (keep other functions unchanged)
+ def prepare(t5, clip, img, prompt):
+     bs, c, h, w = img.shape
+     if bs == 1 and not isinstance(prompt, str):
+         bs = len(prompt)
+
+     img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+     if img.shape[0] == 1 and bs > 1:
+         img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+     img_ids = torch.zeros(h // 2, w // 2, 3)
+     img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+     img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+     if isinstance(prompt, str):
+         prompt = [prompt]
+
+     # Generate text embeddings
+     txt = t5(prompt)
+
+     if txt.shape[0] == 1 and bs > 1:
+         txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+     txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+     vec = clip(prompt)
+     if vec.shape[0] == 1 and bs > 1:
+         vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+     return img, {
+         "img_ids": img_ids.to(img.device),
+         "txt": txt.to(img.device),
+         "txt_ids": txt_ids.to(img.device),
+         "y": vec.to(img.device),
+     }

  def unload_current_model():
      global global_model, current_model_name
 
@@ -89,11 +122,175 @@ def load_model(model_name, device, model_url=None):
          print(f"Error loading model {model_name}: {str(e)}")
          return f"Failed to load model: {model_name}. Error: {str(e)}"

- # ... (keep the rest of the file unchanged)
+ def load_resources(device):
+     global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
+
+     print("Loading T5 and CLAP models...")
+     global_t5 = load_t5(device, max_length=256)
+     global_clap = load_clap(device, max_length=256)
+
+     print("Loading VAE and vocoder...")
+     global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
+     global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)
+
+     print("Initializing diffusion...")
+     global_diffusion = RF()
+
+     print("Base resources loaded successfully!")
+
+ def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=4, progress=gr.Progress()):
+     global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
+
+     if global_model is None:
+         return "Please select and load a model first.", None
+
+     if seed == 0:
+         seed = random.randint(1, 1000000)
+     print(f"Using seed: {seed}")
+
+     torch.manual_seed(seed)
+     torch.set_grad_enabled(False)
+
+     # Calculate the number of segments needed for the desired duration
+     segment_duration = 10  # Each segment is 10 seconds
+     num_segments = int(np.ceil(duration / segment_duration))
+
+     all_waveforms = []
+
+     for i in range(num_segments):
+         progress(i / num_segments, desc=f"Generating segment {i+1}/{num_segments}")
+
+         # Use the same seed for all segments
+         torch.manual_seed(seed + i)  # Add i to slightly vary each segment while maintaining consistency
+
+         latent_size = (256, 16)
+         conds_txt = [prompt]
+         unconds_txt = ["low quality, gentle"]
+         L = len(conds_txt)
+
+         init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)
+
+         img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
+         _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
+
+         # Implement batching for CPU inference
+         images = []
+         for batch_start in range(0, img.shape[0], batch_size):
+             batch_end = min(batch_start + batch_size, img.shape[0])
+             batch_img = img[batch_start:batch_end]
+             batch_conds = {k: v[batch_start:batch_end] for k, v in conds.items()}
+             batch_unconds = {k: v[batch_start:batch_end] for k, v in unconds.items()}
+
+             with torch.no_grad():
+                 batch_images = global_diffusion.sample_with_xps(
+                     global_model, batch_img, conds=batch_conds, null_cond=batch_unconds,
+                     sample_steps=steps, cfg=cfg_scale
+                 )
+             images.append(batch_images[-1])
+
+         images = torch.cat(images, dim=0)
+
+         images = rearrange(
+             images,
+             "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+             h=128,
+             w=8,
+             ph=2,
+             pw=2,)
+
+         latents = 1 / global_vae.config.scaling_factor * images
+         mel_spectrogram = global_vae.decode(latents).sample
+
+         x_i = mel_spectrogram[0]
+         if x_i.dim() == 4:
+             x_i = x_i.squeeze(1)
+         waveform = global_vocoder(x_i)
+         waveform = waveform[0].cpu().float().detach().numpy()
+
+         all_waveforms.append(waveform)
+
+     # Concatenate all waveforms
+     final_waveform = np.concatenate(all_waveforms)
+
+     # Trim to exact duration
+     sample_rate = 16000
+     final_waveform = final_waveform[:int(duration * sample_rate)]
+
+     progress(0.9, desc="Saving audio file")
+
+     # Create 'generations' folder
+     os.makedirs(GENERATIONS_DIR, exist_ok=True)
+
+     # Generate filename
+     prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
+     model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
+     model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
+     base_filename = f"{prompt_part}_{seed}{model_suffix}"
+     output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")
+
+     # Check if file exists and add numerical suffix if needed
+     counter = 1
+     while os.path.exists(output_path):
+         output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
+         counter += 1
+
+     wavfile.write(output_path, sample_rate, final_waveform)
+
+     progress(1.0, desc="Audio generation complete")
+     return f"Generated with seed: {seed}", output_path
+
+ # Get list of .pt files in the models directory
+ model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
+ model_choices = [os.path.basename(f) for f in model_files]
+
+ # Ensure we have at least one model
+ if not model_choices:
+     print(f"No models found in the models directory: {MODELS_DIR}")
+     print("Available files in the directory:")
+     print(os.listdir(MODELS_DIR))
+     model_choices = ["No models available"]
+
+ # Set default model
+ default_model = 'musicflow_b.pt' if 'musicflow_b.pt' in model_choices else model_choices[0]
+
+ # Set up dark grey theme
+ theme = gr.themes.Monochrome(
+     primary_hue="gray",
+     secondary_hue="gray",
+     neutral_hue="gray",
+     radius_size=gr.themes.sizes.radius_sm,
+ )
 
  # Gradio Interface
  with gr.Blocks(theme=theme) as iface:
-     # ... (keep the interface definition unchanged)
+     gr.Markdown(
+         """
+         <div style="text-align: center;">
+             <h1>FluxMusic Generator</h1>
+             <p>Generate music based on text prompts using FluxMusic model.</p>
+             <p>Feel free to clone this space and run on GPU locally or on Hugging Face.</p>
+         </div>
+         """)
+
+     with gr.Row():
+         model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
+         model_url = gr.Textbox(label="Or enter model URL")
+         device_choice = gr.Radio(["cpu", "cuda"], label="Device", value="cpu")
+         load_model_button = gr.Button("Load Model")
+         model_status = gr.Textbox(label="Model Status", value="No model loaded")
+
+     with gr.Row():
+         prompt = gr.Textbox(label="Prompt")
+         seed = gr.Number(label="Seed", value=0)
+
+     with gr.Row():
+         cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
+         steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)
+         duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)
+
+     generate_button = gr.Button("Generate Music")
+     output_status = gr.Textbox(label="Generation Status")
+     output_audio = gr.Audio(type="filepath")
 
      def on_load_model_click(model_name, device, url):
          if url:
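
For orientation only: a minimal sketch of how the controls and the generate_music() function added in this commit could be wired together. The actual event bindings live in the unchanged remainder of app.py and are not part of this diff, so the component-to-handler mapping below is an assumption.

# Assumed wiring inside `with gr.Blocks(theme=theme) as iface:` -- not part of this commit.
load_model_button.click(
    on_load_model_click,
    inputs=[model_dropdown, device_choice, model_url],
    outputs=model_status,  # assumed output component
)
generate_button.click(
    generate_music,
    inputs=[prompt, seed, cfg_scale, steps, duration, device_choice],
    outputs=[output_status, output_audio],
)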