flosstradamus committed
Commit 89589a7 · verified · 1 Parent(s): c5125f1

Update app.py

Files changed (1)
  1. app.py +215 -4
app.py CHANGED
@@ -12,7 +12,143 @@ import re
import requests
import time

-# ... (keep the imports and global variables)
+# Import necessary functions and classes
+from utils import load_t5, load_clap
+from train import RF
+from constants import build_model
+
+# Global variables to store loaded models and resources
+global_model = None
+global_t5 = None
+global_clap = None
+global_vae = None
+global_vocoder = None
+global_diffusion = None
+current_model_name = None
+
+# Set the models and generations directories
+MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
+GENERATIONS_DIR = os.path.join(os.path.dirname(__file__), "generations")
+
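+# (These globals are populated lazily: load_resources() below fills the encoders,
+# VAE, vocoder, and sampler; load_model() fills global_model.)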
+def prepare(t5, clip, img, prompt):
+    bs, c, h, w = img.shape
+    if bs == 1 and not isinstance(prompt, str):
+        bs = len(prompt)
+
+    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+    if img.shape[0] == 1 and bs > 1:
+        img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+    img_ids = torch.zeros(h // 2, w // 2, 3)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    if isinstance(prompt, str):
+        prompt = [prompt]
+
+    # Generate text embeddings
+    txt = t5(prompt)
+
+    if txt.shape[0] == 1 and bs > 1:
+        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+    txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+    vec = clip(prompt)
+    if vec.shape[0] == 1 and bs > 1:
+        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+    return img, {
+        "img_ids": img_ids.to(img.device),
+        "txt": txt.to(img.device),
+        "txt_ids": txt_ids.to(img.device),
+        "y": vec.to(img.device),
+    }
+
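+# Usage sketch (hypothetical shapes, not taken from this commit's call sites):
+#   img = torch.randn(1, 8, 16, 128)  # latent from the VAE encoder
+#   img, conds = prepare(global_t5, global_clap, img, "a jazzy piano loop")
+# conds then carries the img_ids/txt/txt_ids/y conditioning for the sampler.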
+def unload_current_model():
+    global global_model, current_model_name
+    if global_model is not None:
+        del global_model
+        torch.cuda.empty_cache()
+        global_model = None
+        current_model_name = None
+
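+# Note: torch.cuda.empty_cache() is a safe no-op when CUDA was never initialized,
+# so unloading also works on CPU-only machines; dropping the reference is what
+# actually frees the weights.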
+def load_model(model_name, device, model_url=None):
+    global global_model, current_model_name
+
+    unload_current_model()
+
+    if model_url:
+        print(f"Downloading model from URL: {model_url}")
+        response = requests.get(model_url)
+        if response.status_code == 200:
+            model_path = os.path.join(MODELS_DIR, "downloaded_model.pt")
+            with open(model_path, 'wb') as f:
+                f.write(response.content)
+            model_name = "downloaded_model.pt"
+        else:
+            return f"Failed to download model from URL: {model_url}"
+    else:
+        model_path = os.path.join(MODELS_DIR, model_name)
+
+    if not os.path.exists(model_path):
+        return f"Model file not found: {model_path}"
+
+    # Determine model size from filename
+    if 'musicflow_b' in model_name:
+        model_size = "base"
+    elif 'musicflow_g' in model_name:
+        model_size = "giant"
+    elif 'musicflow_l' in model_name:
+        model_size = "large"
+    elif 'musicflow_s' in model_name:
+        model_size = "small"
+    else:
+        model_size = "base"  # Default to base if unrecognized
+
+    print(f"Loading {model_size} model: {model_name}")
+
+    try:
+        start_time = time.time()
+        global_model = build_model(model_size).to(device)
+        state_dict = torch.load(model_path, map_location=device, weights_only=True)
+        global_model.load_state_dict(state_dict['ema'], strict=False)
+        global_model.eval()
+
+        global_model.model_path = model_path
+        current_model_name = model_name
+        end_time = time.time()
+        load_time = end_time - start_time
+        return f"Successfully loaded model: {model_name} in {load_time:.2f} seconds"
+    except Exception as e:
+        global_model = None
+        current_model_name = None
+        print(f"Error loading model {model_name}: {str(e)}")
+        return f"Failed to load model: {model_name}. Error: {str(e)}"
+
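+# Usage sketch: load_model("musicflow_b.pt", "cpu") loads models/musicflow_b.pt;
+# passing model_url instead downloads the checkpoint to models/downloaded_model.pt first.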
+def load_resources(device):
+    global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
+
+    try:
+        start_time = time.time()
+        print("Loading T5 and CLAP models...")
+        global_t5 = load_t5(device, max_length=256)
+        global_clap = load_clap(device, max_length=256)
+
+        print("Loading VAE and vocoder...")
+        global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
+        global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)
+
+        print("Initializing diffusion...")
+        global_diffusion = RF()
+
+        end_time = time.time()
+        load_time = end_time - start_time
+        print(f"Base resources loaded successfully in {load_time:.2f} seconds!")
+        return f"Resources loaded successfully in {load_time:.2f} seconds!"
+    except Exception as e:
+        print(f"Error loading resources: {str(e)}")
+        return f"Failed to load resources. Error: {str(e)}"
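+# Usage sketch: load_resources("cpu") is meant to run once per process; it fills
+# the global T5/CLAP encoders, the AudioLDM2 VAE and vocoder, and the RF sampler.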
 
def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=1, progress=gr.Progress()):
    global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
 
@@ -93,13 +229,88 @@ def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=

        all_waveforms.append(waveform)

-    # ... (keep the rest of the function unchanged)
+    # Concatenate all waveforms
+    final_waveform = np.concatenate(all_waveforms)
+
+    # Trim to exact duration
+    sample_rate = 16000
+    final_waveform = final_waveform[:int(duration * sample_rate)]
+
+    progress(0.9, desc="Saving audio file")
+
+    # Create 'generations' folder
+    os.makedirs(GENERATIONS_DIR, exist_ok=True)
+
+    # Generate filename
+    prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
+    model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
+    model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
+    base_filename = f"{prompt_part}_{seed}{model_suffix}"
+    output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")
+
+    # Check if file exists and add numerical suffix if needed
+    counter = 1
+    while os.path.exists(output_path):
+        output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
+        counter += 1
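+    # e.g. prompt "ambient pad" with seed 42 and musicflow_b.pt yields
+    # generations/ambient_pa_42_mf_b.wav, then ambient_pa_42_mf_b_1.wav on a clash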

-# ... (keep the rest of the file unchanged)
+    wavfile.write(output_path, sample_rate, final_waveform)
+
+    progress(1.0, desc="Audio generation complete")
+    return f"Generated with seed: {seed}", output_path
+
+# Get list of .pt files in the models directory
+model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
+model_choices = [os.path.basename(f) for f in model_files]
+
+# Ensure we have at least one model
+if not model_choices:
+    print(f"No models found in the models directory: {MODELS_DIR}")
+    print("Available files in the directory:")
+    print(os.listdir(MODELS_DIR))
+    model_choices = ["No models available"]
+
+# Set default model
+default_model = 'musicflow_b.pt' if 'musicflow_b.pt' in model_choices else model_choices[0]
+
+# Set up dark grey theme
+theme = gr.themes.Monochrome(
+    primary_hue="gray",
+    secondary_hue="gray",
+    neutral_hue="gray",
+    radius_size=gr.themes.sizes.radius_sm,
+)

# Gradio Interface
with gr.Blocks(theme=theme) as iface:
-    # ... (keep the interface definition unchanged)
+    gr.Markdown(
+        """
+        <div style="text-align: center;">
+            <h1>FluxMusic Generator</h1>
+            <p>Generate music based on text prompts using the FluxMusic model.</p>
+            <p>Feel free to clone this space and run it on a GPU locally or on Hugging Face.</p>
+        </div>
+        """)
+
+    with gr.Row():
+        model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
+        model_url = gr.Textbox(label="Or enter model URL")
+        device_choice = gr.Radio(["cpu", "cuda"], label="Device", value="cpu")
+        load_model_button = gr.Button("Load Model")
+        model_status = gr.Textbox(label="Model Status", value="No model loaded")
+
+    with gr.Row():
+        prompt = gr.Textbox(label="Prompt")
+        seed = gr.Number(label="Seed", value=0)
+
+    with gr.Row():
+        cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
+        steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)
+        duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)
+
+    generate_button = gr.Button("Generate Music")
+    output_status = gr.Textbox(label="Generation Status")
+    output_audio = gr.Audio(type="filepath")

    def on_load_model_click(model_name, device, url):
        # Ensure we're using CPU if CUDA is not available