flosstradamus committed
Commit 41e9bb3 · verified · 1 Parent(s): 6a91f5a

Update app.py

Files changed (1)
  1. app.py +9 -202
app.py CHANGED
@@ -29,40 +29,7 @@ current_model_name = None
 MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
 GENERATIONS_DIR = os.path.join(os.path.dirname(__file__), "generations")
 
-def prepare(t5, clip, img, prompt):
-    bs, c, h, w = img.shape
-    if bs == 1 and not isinstance(prompt, str):
-        bs = len(prompt)
-
-    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
-    if img.shape[0] == 1 and bs > 1:
-        img = repeat(img, "1 ... -> bs ...", bs=bs)
-
-    img_ids = torch.zeros(h // 2, w // 2, 3)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
-    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
-
-    if isinstance(prompt, str):
-        prompt = [prompt]
-
-    # Generate text embeddings
-    txt = t5(prompt)
-
-    if txt.shape[0] == 1 and bs > 1:
-        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
-    txt_ids = torch.zeros(bs, txt.shape[1], 3)
-
-    vec = clip(prompt)
-    if vec.shape[0] == 1 and bs > 1:
-        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
-
-    return img, {
-        "img_ids": img_ids.to(img.device),
-        "txt": txt.to(img.device),
-        "txt_ids": txt_ids.to(img.device),
-        "y": vec.to(img.device),
-    }
+# ... (keep other functions unchanged)
 
 def unload_current_model():
     global global_model, current_model_name
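For context on what the removed prepare() was doing: it packs the (b, c, h, w) latent into 2x2 patches before building the conditioning dict. A minimal sketch of that rearrange on the (1, 8, 256, 16) noise shape used elsewhere in app.py (torch and einops assumed installed; shapes are illustrative only):

import torch
from einops import rearrange

# Dummy latent with the same shape generate_music uses for init_noise.
img = torch.randn(1, 8, 256, 16)  # (batch, channels, height, width)

# 2x2 patchify, as in the removed prepare(): each token carries c*ph*pw values.
patches = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(patches.shape)  # torch.Size([1, 1024, 32]) -> 128*8 tokens of 8*2*2 values each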
@@ -90,6 +57,9 @@ def load_model(model_name, device, model_url=None):
     else:
         model_path = os.path.join(MODELS_DIR, model_name)
 
+    if not os.path.exists(model_path):
+        return f"Model file not found: {model_path}"
+
     # Determine model size from filename
     if 'musicflow_b' in model_name:
         model_size = "base"
@@ -104,9 +74,8 @@
 
     print(f"Loading {model_size} model: {model_name}")
 
-    global_model = build_model(model_size).to(device)
-
     try:
+        global_model = build_model(model_size).to(device)
         state_dict = torch.load(model_path, map_location=device, weights_only=True)
         global_model.load_state_dict(state_dict['ema'], strict=False)
        global_model.eval()
@@ -115,178 +84,16 @@
         current_model_name = model_name
         return f"Successfully loaded model: {model_name}"
     except Exception as e:
+        global_model = None
+        current_model_name = None
         print(f"Error loading model {model_name}: {str(e)}")
         return f"Failed to load model: {model_name}. Error: {str(e)}"
 
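Read together, the load_model hunks above make the load path fail cleanly: the checkpoint file is checked before anything is built, the model is constructed and loaded inside the try, and the globals are reset if any step throws. A standalone sketch of that pattern under the same assumptions as app.py (build_model and the 'ema' checkpoint key come from the repository's own code; this is not the literal function):

import os
import torch

global_model = None
current_model_name = None

def load_checkpoint(model_path, model_size, device):
    """Sketch of the guarded load path introduced by this commit."""
    global global_model, current_model_name
    if not os.path.exists(model_path):
        return f"Model file not found: {model_path}"
    try:
        global_model = build_model(model_size).to(device)   # built only after the file check
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        global_model.load_state_dict(state_dict['ema'], strict=False)
        global_model.eval()
        current_model_name = os.path.basename(model_path)
        return f"Successfully loaded model: {current_model_name}"
    except Exception as e:
        global_model = None                                  # never leave a half-initialized model behind
        current_model_name = None
        return f"Failed to load model: {model_path}. Error: {e}"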
-def load_resources(device):
-    global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
-
-    print("Loading T5 and CLAP models...")
-    global_t5 = load_t5(device, max_length=256)
-    global_clap = load_clap(device, max_length=256)
-
-    print("Loading VAE and vocoder...")
-    global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
-    global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)
-
-    print("Initializing diffusion...")
-    global_diffusion = RF()
-
-    print("Base resources loaded successfully!")
-
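The removed load_resources() pulls the AudioLDM2 VAE and HiFi-GAN vocoder from the cvssp/audioldm2 checkpoint, and generate_music later decodes latents with them. A rough sketch of that decode chain in isolation (diffusers and transformers assumed; the dummy latent mirrors the shape above, and the exact mel-spectrogram shape may differ):

import torch
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan

vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae")
vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder")

with torch.no_grad():
    latents = torch.randn(1, 8, 256, 16) / vae.config.scaling_factor  # dummy latent
    mel = vae.decode(latents).sample            # latent -> mel spectrogram
    mel = mel.squeeze(1) if mel.dim() == 4 else mel
    waveform = vocoder(mel)                     # mel -> 16 kHz waveform
print(waveform.shape)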
-def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=4, progress=gr.Progress()):
-    global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
-
-    if global_model is None:
-        return "Please select and load a model first.", None
-
-    if seed == 0:
-        seed = random.randint(1, 1000000)
-    print(f"Using seed: {seed}")
-
-    torch.manual_seed(seed)
-    torch.set_grad_enabled(False)
-
-    # Calculate the number of segments needed for the desired duration
-    segment_duration = 10  # Each segment is 10 seconds
-    num_segments = int(np.ceil(duration / segment_duration))
-
-    all_waveforms = []
-
-    for i in range(num_segments):
-        progress(i / num_segments, desc=f"Generating segment {i+1}/{num_segments}")
-
-        # Use the same seed for all segments
-        torch.manual_seed(seed + i)  # Add i to slightly vary each segment while maintaining consistency
-
-        latent_size = (256, 16)
-        conds_txt = [prompt]
-        unconds_txt = ["low quality, gentle"]
-        L = len(conds_txt)
-
-        init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)
-
-        img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
-        _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
-
-        # Implement batching for CPU inference
-        images = []
-        for batch_start in range(0, img.shape[0], batch_size):
-            batch_end = min(batch_start + batch_size, img.shape[0])
-            batch_img = img[batch_start:batch_end]
-            batch_conds = {k: v[batch_start:batch_end] for k, v in conds.items()}
-            batch_unconds = {k: v[batch_start:batch_end] for k, v in unconds.items()}
-
-            with torch.no_grad():
-                batch_images = global_diffusion.sample_with_xps(
-                    global_model, batch_img, conds=batch_conds, null_cond=batch_unconds,
-                    sample_steps=steps, cfg=cfg_scale
-                )
-            images.append(batch_images[-1])
-
-        images = torch.cat(images, dim=0)
-
-        images = rearrange(
-            images,
-            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
-            h=128,
-            w=8,
-            ph=2,
-            pw=2,)
-
-        latents = 1 / global_vae.config.scaling_factor * images
-        mel_spectrogram = global_vae.decode(latents).sample
-
-        x_i = mel_spectrogram[0]
-        if x_i.dim() == 4:
-            x_i = x_i.squeeze(1)
-        waveform = global_vocoder(x_i)
-        waveform = waveform[0].cpu().float().detach().numpy()
-
-        all_waveforms.append(waveform)
-
-    # Concatenate all waveforms
-    final_waveform = np.concatenate(all_waveforms)
-
-    # Trim to exact duration
-    sample_rate = 16000
-    final_waveform = final_waveform[:int(duration * sample_rate)]
-
-    progress(0.9, desc="Saving audio file")
-
-    # Create 'generations' folder
-    os.makedirs(GENERATIONS_DIR, exist_ok=True)
-
-    # Generate filename
-    prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
-    model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
-    model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
-    base_filename = f"{prompt_part}_{seed}{model_suffix}"
-    output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")
-
-    # Check if file exists and add numerical suffix if needed
-    counter = 1
-    while os.path.exists(output_path):
-        output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
-        counter += 1
-
-    wavfile.write(output_path, sample_rate, final_waveform)
-
-    progress(1.0, desc="Audio generation complete")
-    return f"Generated with seed: {seed}", output_path
-
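The removed generate_music stitches long outputs together from fixed 10-second segments and then trims to the requested length at 16 kHz. The bookkeeping in isolation, with silent dummy arrays standing in for the decoded audio (NumPy only; the 25-second duration is just an example value):

import numpy as np

sample_rate = 16000
segment_duration = 10                                      # each generated segment covers 10 seconds
duration = 25                                              # requested length in seconds

num_segments = int(np.ceil(duration / segment_duration))   # ceil(25 / 10) = 3 segments
segments = [np.zeros(segment_duration * sample_rate, dtype=np.float32)
            for _ in range(num_segments)]                  # dummy 10 s waveforms

final_waveform = np.concatenate(segments)[:int(duration * sample_rate)]
print(num_segments, final_waveform.shape)                  # 3 (400000,) -> exactly 25 s of samples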
-# Get list of .pt files in the models directory
-model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
-model_choices = [os.path.basename(f) for f in model_files]
-
-# Ensure we have at least one model
-if not model_choices:
-    print(f"No models found in the models directory: {MODELS_DIR}")
-    print("Available files in the directory:")
-    print(os.listdir(MODELS_DIR))
-    model_choices = ["No models available"]
-
-# Set default model
-default_model = 'musicflow_b.pt' if 'musicflow_b.pt' in model_choices else model_choices[0]
-
-# Set up dark grey theme
-theme = gr.themes.Monochrome(
-    primary_hue="gray",
-    secondary_hue="gray",
-    neutral_hue="gray",
-    radius_size=gr.themes.sizes.radius_sm,
-)
+# ... (keep the rest of the file unchanged)
 
 # Gradio Interface
 with gr.Blocks(theme=theme) as iface:
-    gr.Markdown(
-        """
-        <div style="text-align: center;">
-        <h1>FluxMusic Generator</h1>
-        <p>Generate music based on text prompts using FluxMusic model.</p>
-        <p>Feel free to clone this space and run on GPU locally or on Hugging Face.</p>
-        </div>
-        """)
-
-    with gr.Row():
-        model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
-        model_url = gr.Textbox(label="Or enter model URL")
-        device_choice = gr.Radio(["cpu", "cuda"], label="Device", value="cpu")
-        load_model_button = gr.Button("Load Model")
-        model_status = gr.Textbox(label="Model Status", value="No model loaded")
-
-    with gr.Row():
-        prompt = gr.Textbox(label="Prompt")
-        seed = gr.Number(label="Seed", value=0)
-
-    with gr.Row():
-        cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
-        steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)
-        duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)
-
-    generate_button = gr.Button("Generate Music")
-    output_status = gr.Textbox(label="Generation Status")
-    output_audio = gr.Audio(type="filepath")
+    # ... (keep the interface definition unchanged)
 
     def on_load_model_click(model_name, device, url):
         if url:
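The visible context ends inside on_load_model_click, so the event wiring itself is not part of this diff. A hypothetical binding using the component names from the removed interface block above (the actual wiring in app.py may differ):

# Hypothetical wiring, not shown in this diff; component names come from the removed interface code.
load_model_button.click(
    fn=on_load_model_click,
    inputs=[model_dropdown, device_choice, model_url],
    outputs=model_status,
)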
 