Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import laion_clap | |
import numpy as np | |
import librosa | |
import pickle | |
import os | |
from sklearn.metrics.pairwise import cosine_similarity | |
import pandas as pd | |
import zipfile | |
import json | |
# Locations of the dataset archive, its extracted contents, the per-sound
# metadata file and the precomputed audio embeddings.
dataset_zip = "dataset/one_shot_percussive_sounds.zip"
extracted_folder = "dataset/unzipped"
metadata_path = "dataset/licenses.txt"
audio_embeddings_path = "dataset/audio_embeddings.pkl"

# Unzip if not already extracted.
# NOTE(review): the archive is extracted into "dataset", while the check looks
# for "dataset/unzipped" -- presumably the archive contains a top-level
# "unzipped/" folder; confirm, otherwise extraction reruns on every start.
if not os.path.exists(extracted_folder):
    with zipfile.ZipFile(dataset_zip, "r") as zip_ref:
        zip_ref.extractall("dataset")

# Load the model once at import time so every request reuses it.
# model_id=3 selects one of laion_clap's pretrained checkpoints (see the
# library documentation for the id -> checkpoint mapping).
model = laion_clap.CLAP_Module(enable_fusion=True)
model.load_ckpt(model_id=3)

# Load dataset metadata. The file holds JSON despite its .txt extension; the
# encoding is pinned so decoding does not depend on the platform locale.
with open(metadata_path, "r", encoding="utf-8") as file:
    data = json.load(file)

# Convert the JSON data into a Pandas DataFrame: one row per top-level key.
metadata = pd.DataFrame.from_dict(data, orient="index")
# Append ".wav" to every index label; find_top_sounds compares these labels
# against the keys of audio_embeddings.
metadata.index = metadata.index.astype(str) + '.wav'

# Load precomputed audio embeddings (to avoid recomputing on every request).
# SECURITY: pickle.load executes arbitrary code embedded in the file; only
# ever load an embeddings file that ships with the application.
with open(audio_embeddings_path, "rb") as f:
    audio_embeddings = pickle.load(f)
def get_clap_embeddings_from_text(text):
    """Embed a single text prompt with the module-level CLAP model.

    Args:
        text: The user's prompt.

    Returns:
        The first row of the result of ``model.get_text_embedding`` called on
        a one-element batch (presumably a 1-D vector -- confirm against the
        laion_clap documentation).
    """
    # The model takes a batch of strings, so wrap the prompt in a list and
    # unwrap the single resulting row afterwards.
    batch_embeddings = model.get_text_embedding([text])
    prompt_embedding = batch_embeddings[0, :]
    return prompt_embedding
def find_top_sounds(text_embed, instrument, top_N=4):
    """Find the sounds of one instrument that are closest to a text embedding.

    Args:
        text_embed: Embedding of the user's prompt, as returned by
            ``get_clap_embeddings_from_text``.
        instrument: Value to match in the ``"Instrument"`` column of the
            module-level ``metadata`` frame.
        top_N: Maximum number of file names to return.

    Returns:
        Up to ``top_N`` keys of ``audio_embeddings``, ordered from most to
        least similar. The list is shorter when fewer matching sounds have an
        embedding, and empty when there are none or ``top_N`` is not positive.
    """
    # A set makes the membership test below O(1) per embedding.
    valid_sounds = set(metadata[metadata["Instrument"] == instrument].index)

    # Keep file names and embedding rows in lockstep. The similarity indices
    # computed below refer to positions in this list, NOT to positions in the
    # metadata index (which may be ordered differently or contain files that
    # have no embedding).
    candidate_files = [name for name in audio_embeddings if name in valid_sounds]
    if not candidate_files or top_N <= 0:
        # Nothing to rank; also avoids passing an empty matrix to
        # cosine_similarity and the "[-0:]" slice that would return everything.
        return []

    # Compute cosine similarity between the prompt and every candidate.
    all_embeds = np.array([audio_embeddings[name] for name in candidate_files])
    similarities = cosine_similarity([text_embed], all_embeds)[0]

    # argsort is ascending: take the last top_N entries and reverse them so
    # the best match comes first.
    top_indices = np.argsort(similarities)[-top_N:][::-1]
    return [candidate_files[i] for i in top_indices]
def generate_drum_kit(prompt, kit_size=4):
    """Build a drum kit for a text prompt.

    Args:
        prompt: Free-text description of the desired kit.
        kit_size: Number of sounds requested per instrument.

    Returns:
        A dict mapping each instrument name to the list of file names that
        ``find_top_sounds`` returns for it.
    """
    instruments = (
        "Kick",
        "Snare",
        "Hi-Hat",
        "Tom",
        "Cymbal",
        "Clap",
        "Percussion",
        "Other",
    )
    # Embed the prompt once and reuse the vector for every instrument.
    prompt_embedding = get_clap_embeddings_from_text(prompt)
    return {
        name: find_top_sounds(prompt_embedding, name, top_N=kit_size)
        for name in instruments
    }