flosstradamus committed
Commit c5125f1 · verified · 1 Parent(s): e39e85d

Update app.py

Files changed (1)
  1. app.py +14 -215
app.py CHANGED
@@ -12,143 +12,7 @@ import re
 import requests
 import time
 
-# Import necessary functions and classes
-from utils import load_t5, load_clap
-from train import RF
-from constants import build_model
-
-# Global variables to store loaded models and resources
-global_model = None
-global_t5 = None
-global_clap = None
-global_vae = None
-global_vocoder = None
-global_diffusion = None
-current_model_name = None
-
-# Set the models directory
-MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
-GENERATIONS_DIR = os.path.join(os.path.dirname(__file__), "generations")
-
-def prepare(t5, clip, img, prompt):
-    bs, c, h, w = img.shape
-    if bs == 1 and not isinstance(prompt, str):
-        bs = len(prompt)
-
-    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
-    if img.shape[0] == 1 and bs > 1:
-        img = repeat(img, "1 ... -> bs ...", bs=bs)
-
-    img_ids = torch.zeros(h // 2, w // 2, 3)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
-    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
-
-    if isinstance(prompt, str):
-        prompt = [prompt]
-
-    # Generate text embeddings
-    txt = t5(prompt)
-
-    if txt.shape[0] == 1 and bs > 1:
-        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
-    txt_ids = torch.zeros(bs, txt.shape[1], 3)
-
-    vec = clip(prompt)
-    if vec.shape[0] == 1 and bs > 1:
-        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
-
-    return img, {
-        "img_ids": img_ids.to(img.device),
-        "txt": txt.to(img.device),
-        "txt_ids": txt_ids.to(img.device),
-        "y": vec.to(img.device),
-    }
-
-def unload_current_model():
-    global global_model, current_model_name
-    if global_model is not None:
-        del global_model
-        torch.cuda.empty_cache()
-        global_model = None
-        current_model_name = None
-
-def load_model(model_name, device, model_url=None):
-    global global_model, current_model_name
-
-    unload_current_model()
-
-    if model_url:
-        print(f"Downloading model from URL: {model_url}")
-        response = requests.get(model_url)
-        if response.status_code == 200:
-            model_path = os.path.join(MODELS_DIR, "downloaded_model.pt")
-            with open(model_path, 'wb') as f:
-                f.write(response.content)
-            model_name = "downloaded_model.pt"
-        else:
-            return f"Failed to download model from URL: {model_url}"
-    else:
-        model_path = os.path.join(MODELS_DIR, model_name)
-
-    if not os.path.exists(model_path):
-        return f"Model file not found: {model_path}"
-
-    # Determine model size from filename
-    if 'musicflow_b' in model_name:
-        model_size = "base"
-    elif 'musicflow_g' in model_name:
-        model_size = "giant"
-    elif 'musicflow_l' in model_name:
-        model_size = "large"
-    elif 'musicflow_s' in model_name:
-        model_size = "small"
-    else:
-        model_size = "base"  # Default to base if unrecognized
-
-    print(f"Loading {model_size} model: {model_name}")
-
-    try:
-        start_time = time.time()
-        global_model = build_model(model_size).to(device)
-        state_dict = torch.load(model_path, map_location=device, weights_only=True)
-        global_model.load_state_dict(state_dict['ema'], strict=False)
-        global_model.eval()
-
-        global_model.model_path = model_path
-        current_model_name = model_name
-        end_time = time.time()
-        load_time = end_time - start_time
-        return f"Successfully loaded model: {model_name} in {load_time:.2f} seconds"
-    except Exception as e:
-        global_model = None
-        current_model_name = None
-        print(f"Error loading model {model_name}: {str(e)}")
-        return f"Failed to load model: {model_name}. Error: {str(e)}"
-
-def load_resources(device):
-    global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
-
-    try:
-        start_time = time.time()
-        print("Loading T5 and CLAP models...")
-        global_t5 = load_t5(device, max_length=256)
-        global_clap = load_clap(device, max_length=256)
-
-        print("Loading VAE and vocoder...")
-        global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
-        global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)
-
-        print("Initializing diffusion...")
-        global_diffusion = RF()
-
-        end_time = time.time()
-        load_time = end_time - start_time
-        print(f"Base resources loaded successfully in {load_time:.2f} seconds!")
-        return f"Resources loaded successfully in {load_time:.2f} seconds!"
-    except Exception as e:
-        print(f"Error loading resources: {str(e)}")
-        return f"Failed to load resources. Error: {str(e)}"
+# ... (keep the imports and global variables)
 
 def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=1, progress=gr.Progress()):
     global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
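For context on the hunk above: the removed prepare() helper patchifies the VAE latent with einops, turning each 2x2 patch of the (h, w) grid into one token whose features are the c*2*2 values it covered. A minimal standalone sketch of just that shape change (the tensor sizes below are illustrative, not taken from app.py):

import torch
from einops import rearrange

# Illustrative latent: batch 1, 8 channels, 64x32 spatial grid (sizes assumed).
img = torch.zeros(1, 8, 64, 32)

# Same pattern as the removed prepare(): each 2x2 patch becomes one token.
tokens = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

print(tokens.shape)  # torch.Size([1, 512, 32]): (64/2)*(32/2) tokens of 8*2*2 features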
@@ -166,6 +30,11 @@ def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=
     torch.manual_seed(seed)
     torch.set_grad_enabled(False)
 
+    # Ensure we're using CPU if CUDA is not available
+    if device == "cuda" and not torch.cuda.is_available():
+        print("CUDA is not available. Falling back to CPU.")
+        device = "cpu"
+
     # Calculate the number of segments needed for the desired duration
     segment_duration = 10  # Each segment is 10 seconds
     num_segments = int(np.ceil(duration / segment_duration))
@@ -224,90 +93,20 @@ def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=
 
         all_waveforms.append(waveform)
 
-    # Concatenate all waveforms
-    final_waveform = np.concatenate(all_waveforms)
-
-    # Trim to exact duration
-    sample_rate = 16000
-    final_waveform = final_waveform[:int(duration * sample_rate)]
-
-    progress(0.9, desc="Saving audio file")
-
-    # Create 'generations' folder
-    os.makedirs(GENERATIONS_DIR, exist_ok=True)
-
-    # Generate filename
-    prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
-    model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
-    model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
-    base_filename = f"{prompt_part}_{seed}{model_suffix}"
-    output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")
-
-    # Check if file exists and add numerical suffix if needed
-    counter = 1
-    while os.path.exists(output_path):
-        output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
-        counter += 1
-
-    wavfile.write(output_path, sample_rate, final_waveform)
-
-    progress(1.0, desc="Audio generation complete")
-    return f"Generated with seed: {seed}", output_path
-
-# Get list of .pt files in the models directory
-model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
-model_choices = [os.path.basename(f) for f in model_files]
-
-# Ensure we have at least one model
-if not model_choices:
-    print(f"No models found in the models directory: {MODELS_DIR}")
-    print("Available files in the directory:")
-    print(os.listdir(MODELS_DIR))
-    model_choices = ["No models available"]
-
-# Set default model
-default_model = 'musicflow_b.pt' if 'musicflow_b.pt' in model_choices else model_choices[0]
+    # ... (keep the rest of the function unchanged)
 
-# Set up dark grey theme
-theme = gr.themes.Monochrome(
-    primary_hue="gray",
-    secondary_hue="gray",
-    neutral_hue="gray",
-    radius_size=gr.themes.sizes.radius_sm,
-)
+# ... (keep the rest of the file unchanged)
 
 # Gradio Interface
 with gr.Blocks(theme=theme) as iface:
-    gr.Markdown(
-        """
-        <div style="text-align: center;">
-        <h1>FluxMusic Generator</h1>
-        <p>Generate music based on text prompts using FluxMusic model.</p>
-        <p>Feel free to clone this space and run on GPU locally or on Hugging Face.</p>
-        </div>
-        """)
-
-    with gr.Row():
-        model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
-        model_url = gr.Textbox(label="Or enter model URL")
-        device_choice = gr.Radio(["cpu", "cuda"], label="Device", value="cpu")
-        load_model_button = gr.Button("Load Model")
-        model_status = gr.Textbox(label="Model Status", value="No model loaded")
-
-    with gr.Row():
-        prompt = gr.Textbox(label="Prompt")
-        seed = gr.Number(label="Seed", value=0)
-
-    with gr.Row():
-        cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
-        steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)
-        duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)
-
-    generate_button = gr.Button("Generate Music")
-    output_status = gr.Textbox(label="Generation Status")
-    output_audio = gr.Audio(type="filepath")
+    # ... (keep the interface definition unchanged)
 
     def on_load_model_click(model_name, device, url):
+        # Ensure we're using CPU if CUDA is not available
+        if device == "cuda" and not torch.cuda.is_available():
+            print("CUDA is not available. Falling back to CPU.")
+            device = "cpu"
+
         resource_status = load_resources(device)
         if "Failed" in resource_status:
             return resource_status
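The four-line CUDA fallback is added verbatim in two places, generate_music and on_load_model_click. A minimal sketch of factoring it into one shared helper (resolve_device is a hypothetical name, not part of this commit):

import torch

def resolve_device(device: str) -> str:
    # Fall back to CPU when CUDA is requested but unavailable,
    # mirroring the check this commit duplicates at both call sites.
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        return "cpu"
    return device

Both call sites could then read: device = resolve_device(device).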