"""Text-to-drum-kit retrieval using CLAP embeddings.

At import time this module extracts the sound archive if needed, loads a
LAION-CLAP model, the dataset metadata, and a pickle of precomputed audio
embeddings. ``generate_drum_kit`` then ranks the sounds of each instrument
by cosine similarity to the embedding of a text prompt.
"""

import json
import os
import pickle
import zipfile

import laion_clap
import numpy as np
import librosa
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd

dataset_zip = "dataset/one_shot_percussive_sounds.zip"
extracted_folder = "dataset/unzipped"
metadata_path = "dataset/licenses.txt"
audio_embeddings_path = "dataset/audio_embeddings.pkl"

# Instrument labels looked up in the metadata "Instrument" column, in the
# order they appear in a generated kit.
INSTRUMENTS = (
    "Kick",
    "Snare",
    "Hi-Hat",
    "Tom",
    "Cymbal",
    "Clap",
    "Percussion",
    "Other",
)

# Unzip if not already extracted.
# NOTE(review): the existence check is on "dataset/unzipped" but the archive
# is extracted into "dataset"; this only skips re-extraction if the archive
# itself contains a top-level "unzipped" directory -- confirm.
if not os.path.exists(extracted_folder):
    with zipfile.ZipFile(dataset_zip, "r") as zip_ref:
        zip_ref.extractall("dataset")

# Load the model
model = laion_clap.CLAP_Module(enable_fusion=True)
model.load_ckpt(model_id=3)

# Load dataset metadata (the file is JSON despite its .txt extension).
# An explicit encoding keeps the read independent of the platform locale.
with open(metadata_path, "r", encoding="utf-8") as file:
    data = json.load(file)

# Convert the JSON data into a Pandas DataFrame with one row per top-level
# key. ".wav" is appended to each key because the index values are compared
# against the keys of ``audio_embeddings`` in ``find_top_sounds``.
metadata = pd.DataFrame.from_dict(data, orient="index")
metadata.index = metadata.index.astype(str) + '.wav'

# Load precomputed audio embeddings (to avoid recomputing on every request).
# SECURITY NOTE(review): pickle.load executes arbitrary code embedded in the
# file; only load embeddings produced by a trusted source.
with open(audio_embeddings_path, "rb") as f:
    audio_embeddings = pickle.load(f)


def get_clap_embeddings_from_text(text):
    """Convert user text input to a CLAP embedding.

    Args:
        text: The prompt string to embed.

    Returns:
        The first row of the result of ``model.get_text_embedding([text])``,
        i.e. the embedding of the single prompt passed in.
    """
    text_embed = model.get_text_embedding([text])
    return text_embed[0, :]


def find_top_sounds(text_embed, instrument, top_N=4):
    """Find the closest N sounds for an instrument.

    Args:
        text_embed: Embedding vector of the text prompt.
        instrument: Value matched against the metadata "Instrument" column.
        top_N: Maximum number of file names to return.

    Returns:
        A list of up to ``top_N`` keys of ``audio_embeddings``, ordered from
        most to least similar to ``text_embed``. Returns an empty list when
        ``top_N`` is not positive or when no embedded sound belongs to
        ``instrument``.
    """
    # A set gives O(1) membership tests while filtering the embeddings.
    valid_sounds = set(
        metadata[metadata["Instrument"] == instrument].index.tolist()
    )

    # Keep the candidate names in the same order as the rows of the
    # similarity matrix. Indexing the metadata list instead would return the
    # wrong files (or raise IndexError) whenever the two orders differ or a
    # metadata entry has no embedding.
    candidate_files = [
        name for name in audio_embeddings if name in valid_sounds
    ]

    # ``[-0:]`` would select every element, and cosine_similarity rejects an
    # empty matrix, so handle both cases explicitly.
    if top_N <= 0 or not candidate_files:
        return []

    # Compute cosine similarity between the prompt and every candidate.
    all_embeds = np.array([audio_embeddings[name] for name in candidate_files])
    similarities = cosine_similarity([text_embed], all_embeds)[0]

    # argsort is ascending: take the last N and reverse for best-first order.
    top_indices = np.argsort(similarities)[-top_N:][::-1]
    return [candidate_files[i] for i in top_indices]


def generate_drum_kit(prompt, kit_size=4):
    """Generate a drum kit dictionary from user input.

    Args:
        prompt: Text description used to rank the sounds.
        kit_size: Number of sounds requested per instrument.

    Returns:
        A dict mapping each name in ``INSTRUMENTS`` to the list returned by
        ``find_top_sounds`` for that instrument.
    """
    # Embed the prompt once and reuse it for every instrument.
    text_embed = get_clap_embeddings_from_text(prompt)
    return {
        instrument: find_top_sounds(text_embed, instrument, top_N=kit_size)
        for instrument in INSTRUMENTS
    }