import os

import librosa
import numpy as np
import soundfile as sf
from pedalboard import Pedalboard, Reverb, HighpassFilter, LowpassFilter, Distortion, Bitcrush
from sklearn.metrics.pairwise import cosine_similarity
from skopt import gp_minimize
from skopt.space import Real

from inference import get_clap_embeddings_from_audio, get_clap_embeddings_from_text

# Default location of the temporary file holding the whole kit stitched together.
concat_file_path = "temp_concat.wav"

# Every load / write in this module uses the same sample rate.
SAMPLE_RATE = 48000

# Length of the silence inserted after each sound in the concatenated file.
GAP_SECONDS = 0.2


def concatenate_sounds(drum_kit, output_path=concat_file_path):
    """Stitch together all drum sounds into one audio file.

    Args:
        drum_kit: Mapping of instrument name to an iterable of audio file
            paths. Only the paths are used; the instrument names are ignored.
        output_path: Where the concatenated audio is written.

    Returns:
        ``output_path``, after the file has been written.

    Raises:
        ValueError: If ``drum_kit`` contains no samples at all.
    """
    sr = SAMPLE_RATE
    all_audio = [
        librosa.load(sample, sr=sr)[0]
        for samples in drum_kit.values()
        for sample in samples
    ]
    if not all_audio:
        # np.concatenate would raise a ValueError here anyway, but with a
        # message that does not point at the real cause.
        raise ValueError("drum_kit contains no samples to concatenate")

    # Each sound is followed by a short silence gap (including the last one).
    gap = np.zeros(int(sr * GAP_SECONDS))
    full_audio = np.concatenate([item for audio in all_audio for item in (audio, gap)])

    sf.write(output_path, full_audio, sr)
    return output_path


def evaluate_fitness(audio_path, text_embed):
    """Compute similarity between processed audio and text query.

    Args:
        audio_path: Path of the audio file to embed with CLAP.
        text_embed: Embedding of the text query, as a 1-D vector.

    Returns:
        The cosine similarity between the audio and text embeddings.
    """
    audio_embed = get_clap_embeddings_from_audio(audio_path)
    return cosine_similarity([text_embed], [audio_embed])[0][0]


def apply_fx(audio_path, params, write_wav=True, output_dir="processed_audio"):
    """Run an audio file through the FX chain.

    The chain is, in order: low-pass filter, high-pass filter, distortion,
    bitcrush, reverb.

    Args:
        audio_path: Path of the audio file to process.
        params: Dict with the keys ``lowpass``, ``highpass``, ``drive_db``,
            ``bit_depth``, ``reverb_size`` and ``reverb_wet``.
        write_wav: If true, write the result to disk and return its path;
            otherwise return the processed samples.
        output_dir: Name of the directory that receives the processed file.
            It is created two directory levels above ``audio_path``.

    Returns:
        The path of the written ``*_processed.wav`` file when ``write_wav`` is
        true, otherwise the processed audio array.
    """
    audio, sr = librosa.load(audio_path, sr=SAMPLE_RATE)

    board = Pedalboard([
        LowpassFilter(cutoff_frequency_hz=params['lowpass']),
        HighpassFilter(cutoff_frequency_hz=params['highpass']),
        Distortion(drive_db=params['drive_db']),
        Bitcrush(bit_depth=params['bit_depth']),
        Reverb(room_size=params['reverb_size'], wet_level=params['reverb_wet'])
    ])
    processed_audio = board(audio, sr)

    if not write_wav:
        return processed_audio

    # The output directory sits next to the grandparent of the input file
    # (e.g. the 'dataset' level for dataset/<instrument>/<file>).
    base_dir = os.path.dirname(os.path.dirname(audio_path))
    target_dir = os.path.join(base_dir, output_dir)
    os.makedirs(target_dir, exist_ok=True)

    # Build the name from the stem so that only the real extension is
    # swapped, and inputs that are not named *.wav still get a .wav name
    # matching the format written below.
    stem, _ = os.path.splitext(os.path.basename(audio_path))
    output_path = os.path.join(target_dir, f"{stem}_processed.wav")

    sf.write(output_path, processed_audio, sr)
    return output_path


def _fx_params_from_list(params_list):
    """Map an optimizer parameter list (in search_space order) to apply_fx keys."""
    return {
        "lowpass": params_list[0],
        "highpass": params_list[1],
        "reverb_size": params_list[2],
        "reverb_wet": params_list[3],
        "drive_db": params_list[4],
        "bit_depth": params_list[5]
    }


def objective_function(params, audio_file, text_embedding):
    """Objective function for Bayesian Optimization using the concatenated file.

    Args:
        params: Parameter values in the order of ``search_space``.
        audio_file: Path of the audio file to process and score.
        text_embedding: Embedding of the text query.

    Returns:
        The negated similarity, so that minimizing it maximizes similarity.
    """
    processed_audio = apply_fx(audio_file, _fx_params_from_list(params), write_wav=True)
    similarity = evaluate_fitness(processed_audio, text_embedding)
    return -similarity  # Minimize negative similarity (maximize similarity)


def get_params_dict(params_list):
    """Return the parameter list (in search_space order) as a labelled dict."""
    return {
        "lowpass cutoff (Hz)": params_list[0],
        "highpass cutoff (Hz)": params_list[1],
        "reverb size": params_list[2],
        "reverb mix": params_list[3],
        "distortion - gain_db": params_list[4],
        "bitcrush - bit depth": params_list[5]
    }


# Define parameter search space.
# The order matters: the helpers above index parameter lists by position.
search_space = [
    Real(4000, 20000, name="lowpass"),
    Real(50, 1000, name="highpass"),
    Real(0.0, 0.8, name="reverb_size"),
    Real(0.2, 1.0, name="reverb_wet"),
    Real(0.0, 10.0, name="drive_db"),
    Real(4.0, 32.0, name="bit_depth")
]


##### Main function #####
def get_fx(drum_kit, fx_prompt):
    """Optimize FX settings for the entire drum kit by using a concatenated audio file.

    Args:
        drum_kit: Mapping of instrument name to an iterable of audio file paths.
        fx_prompt: Text description of the desired sound.

    Returns:
        A tuple ``(optimized_kit, params, pre_fx_fitness, post_fx_fitness)``:
        the kit with each sample replaced by the path of its processed file,
        the best parameters as a labelled dict, and the objective value before
        and after applying FX. Both fitness values are negated similarities,
        so lower means closer to the prompt.
    """
    text_embedding = get_clap_embeddings_from_text(fx_prompt)

    # Concatenate all drum sounds
    concat_file = concatenate_sounds(drum_kit)

    # Define the objective function for the concatenated file
    def obj_func(params):
        return objective_function(params, concat_file, text_embedding)

    # Get CLAP similarity without FX (for evaluation purposes). Score the file
    # that was actually written rather than the module-level default path.
    pre_fx_fitness = -evaluate_fitness(concat_file, text_embedding)

    # Run Bayesian optimization
    res = gp_minimize(obj_func, search_space, n_calls=30, random_state=42)
    best_params = res.x

    # Get post-FX fitness (for evaluation purposes)
    post_fx_fitness = obj_func(best_params)

    # Apply the best FX parameters to each individual sound
    best_fx = _fx_params_from_list(best_params)
    optimized_kit = {
        instrument: [apply_fx(sample, best_fx, write_wav=True) for sample in samples]
        for instrument, samples in drum_kit.items()
    }

    return optimized_kit, get_params_dict(best_params), pre_fx_fitness, post_fx_fitness