import os
import glob
import random
import re

import numpy as np
import torch
import gradio as gr
from einops import rearrange, repeat
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan
from scipy.io import wavfile

# Import necessary functions and classes from the FluxMusic codebase
from utils import load_t5, load_clap
from train import RF
from constants import build_model

# Disable flash attention if not available
torch.backends.cuda.enable_flash_sdp(False)

# Global variables to store loaded models and resources
global_model = None
global_t5 = None
global_clap = None
global_vae = None
global_vocoder = None
global_diffusion = None

# Model checkpoint and output directories
MODELS_DIR = "/content/models"
GENERATIONS_DIR = "/content/generations"


def prepare(t5, clip, img, prompt):
    # ... [The prepare function remains unchanged]
    pass


def unload_current_model():
    """Release the currently loaded checkpoint and reclaim GPU memory."""
    global global_model
    if global_model is not None:
        del global_model
        torch.cuda.empty_cache()
        global_model = None


def load_model(model_name):
    """Load a FluxMusic checkpoint, inferring the model size from its filename."""
    global global_model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    unload_current_model()

    # Determine model size from the filename
    if 'musicflow_b' in model_name:
        model_size = "base"
    elif 'musicflow_g' in model_name:
        model_size = "giant"
    elif 'musicflow_l' in model_name:
        model_size = "large"
    elif 'musicflow_s' in model_name:
        model_size = "small"
    else:
        model_size = "base"  # Default to base if unrecognized

    print(f"Loading {model_size} model: {model_name}")
    model_path = os.path.join(MODELS_DIR, model_name)
    global_model = build_model(model_size).to(device)
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
    global_model.load_state_dict(state_dict['ema'])
    global_model.eval()
    global_model.model_path = model_path


def load_resources():
    """Load the shared text encoders, VAE, vocoder, and diffusion wrapper once at startup."""
    global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Loading T5 and CLAP models...")
    global_t5 = load_t5(device, max_length=256)
    global_clap = load_clap(device, max_length=256)

    print("Loading VAE and vocoder...")
    global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
    global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)

    print("Initializing diffusion...")
    global_diffusion = RF()

    print("Base resources loaded successfully!")


def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
    # ... [The generate_music function remains largely unchanged]
    # Update the output directory
    output_dir = GENERATIONS_DIR
    os.makedirs(output_dir, exist_ok=True)
    # ... [Rest of the function remains the same]
    pass


# Load base resources at startup
load_resources()

# Get the list of .pt checkpoints in the models directory
model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
model_choices = [os.path.basename(f) for f in model_files]

# Ensure 'musicflow_b.pt' is the default choice if it exists
default_model = 'musicflow_b.pt'
if default_model in model_choices:
    model_choices.remove(default_model)
    model_choices.insert(0, default_model)

# Set up a dark grey theme
theme = gr.themes.Monochrome(
    primary_hue="gray",
    secondary_hue="gray",
    neutral_hue="gray",
    radius_size=gr.themes.sizes.radius_sm,
)

# Gradio interface
with gr.Blocks(theme=theme) as iface:
    gr.Markdown(
        """

# FluxMusic Generator

Generate music from text prompts using the FluxMusic model.

""") with gr.Row(): model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model if default_model in model_choices else model_choices[0]) with gr.Row(): prompt = gr.Textbox(label="Prompt") seed = gr.Number(label="Seed", value=0) with gr.Row(): cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20) steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100) duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1) generate_button = gr.Button("Generate Music") output_status = gr.Textbox(label="Generation Status") output_audio = gr.Audio(type="filepath") def on_model_change(model_name): load_model(model_name) model_dropdown.change(on_model_change, inputs=[model_dropdown]) generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio]) # Load default model on startup default_model_path = os.path.join(MODELS_DIR, default_model) if os.path.exists(default_model_path): iface.load(lambda: load_model(default_model), inputs=None, outputs=None) # Launch the interface iface.launch()