Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
ported audio_utils and inference from colab
Browse files- app.py +33 -0
- audio_utils.py +12 -0
- dataset/audio_embeddings.pkl +3 -0
- inference.py +65 -0
- requirements.txt +79 -0
app.py
CHANGED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os

import streamlit as st

from audio_utils import play_audio
from inference import generate_drum_kit

# NOTE: the original also imported librosa, soundfile and numpy here, but this
# module never uses them (audio decoding lives in audio_utils) — removed.

# --- Streamlit UI ---
st.title("Generate Drum Kits with Text")

# --- User inputs ---
prompt = st.text_input("Describe your drum kit (e.g., 'warm vintage')", "8-bit video game")
kit_size = st.slider("Number of sounds per instrument:", 1, 10, 4)

# --- Run the inference ---
if st.button("Generate Drum Kit"):
    # Stored in session_state so the kit survives the rerun triggered by the
    # per-sound play buttons below.
    drum_kit = generate_drum_kit(prompt, kit_size)
    st.session_state["drum_kit"] = drum_kit  # Store results

# --- Display results ---
if "drum_kit" in st.session_state:
    drum_kit = st.session_state["drum_kit"]
    st.subheader("Generated Drum Kit")

    for instrument, sounds in drum_kit.items():
        st.write(f"**{instrument}**")
        cols = st.columns(len(sounds))

        for i, sound_file in enumerate(sounds):
            with cols[i]:
                # key=sound_file keeps button keys unique across instruments.
                if st.button(f"▶️ {os.path.basename(sound_file)}", key=sound_file):
                    play_audio(sound_file)
|
audio_utils.py
CHANGED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import soundfile as sf
|
3 |
+
import librosa
|
4 |
+
import numpy as np
|
5 |
+
import io
|
6 |
+
|
7 |
+
def play_audio(file_path):
    """Load an audio file and render it as an inline Streamlit audio player.

    Parameters
    ----------
    file_path : str
        Path to an audio file on disk.
    """
    # Decode and resample to 16 kHz (librosa loads mono by default).
    audio, sr = librosa.load(file_path, sr=16000)

    # Re-encode as WAV into an in-memory buffer instead of a temp file.
    audio_buffer = io.BytesIO()
    sf.write(audio_buffer, audio, sr, format="wav")

    # BUG FIX: sf.write leaves the buffer positioned at its end; rewind so the
    # consumer reads the WAV data from the start instead of an empty stream.
    audio_buffer.seek(0)

    st.audio(audio_buffer, format="audio/wav")
|
dataset/audio_embeddings.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:734662f4d6f61035f7519a883918c52b6abc9d9eee0a03df1c4aaebb559d8408
|
3 |
+
size 21482036
|
inference.py
CHANGED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
import os
import pickle
import zipfile

import laion_clap
import librosa
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

# Dataset locations, relative to the app's working directory.
dataset_zip = "dataset/one_shot_percussive_sounds.zip"
extracted_folder = "dataset/unzipped"
metadata_path = "dataset/licenses.txt"
audio_embeddings_path = "dataset/audio_embeddings.pkl"

# Extract the one-shot samples on first run only.
if not os.path.exists(extracted_folder):
    with zipfile.ZipFile(dataset_zip, "r") as zip_ref:
        zip_ref.extractall("dataset")

# Load the CLAP model once at import time; reused for every request.
model = laion_clap.CLAP_Module(enable_fusion=True)
model.load_ckpt(model_id=3)

# Load dataset metadata.
# NOTE(review): despite the .txt extension, this file is parsed as JSON —
# confirm dataset/licenses.txt really contains a JSON object.
with open(metadata_path, "r") as file:
    data = json.load(file)

# One row per sound; index the frame by each sound's .wav filename so that
# metadata rows line up with the audio-embedding dictionary keys.
metadata = pd.DataFrame.from_dict(data, orient="index")
metadata.index = metadata.index.astype(str) + '.wav'

# Precomputed audio embeddings (avoids recomputing on every request).
with open(audio_embeddings_path, "rb") as f:
    audio_embeddings = pickle.load(f)
|
36 |
+
|
37 |
+
def get_clap_embeddings_from_text(text):
    """Embed a free-text prompt into CLAP's shared audio/text space."""
    # get_text_embedding expects a batch, so wrap the prompt in a one-element
    # list and return the single resulting row.
    embeddings = model.get_text_embedding([text])
    return embeddings[0, :]
|
41 |
+
|
42 |
+
def find_top_sounds(text_embed, instrument, top_N=4):
    """Find the closest ``top_N`` sounds for one instrument.

    Parameters
    ----------
    text_embed : np.ndarray
        1-D CLAP text embedding to match against.
    instrument : str
        Instrument label used to filter the metadata (e.g. ``"Kick"``).
    top_N : int
        Maximum number of matches to return (fewer if the instrument has
        fewer embedded sounds; empty list if it has none).

    Returns
    -------
    list[str]
        Sound filenames ranked by descending cosine similarity.
    """
    # Set membership makes the filter below O(1) per key instead of O(n).
    valid_sounds = set(metadata[metadata["Instrument"] == instrument].index)

    # Keep only the candidate sounds that actually have a precomputed embedding.
    relevant_embeddings = {k: v for k, v in audio_embeddings.items() if k in valid_sounds}
    if not relevant_embeddings:
        # Avoid crashing cosine_similarity on an empty matrix.
        return []

    # BUG FIX: the similarity indices are positions within relevant_embeddings,
    # but the original indexed into valid_sounds — returning the wrong files
    # whenever a valid sound had no embedding or the two orderings diverged.
    # Rank the keys of relevant_embeddings themselves.
    files = list(relevant_embeddings.keys())
    all_embeds = np.array(list(relevant_embeddings.values()))

    # Compute cosine similarity of the prompt against every candidate.
    similarities = cosine_similarity([text_embed], all_embeds)[0]

    # Indices of the top N scores, highest similarity first.
    top_indices = np.argsort(similarities)[-top_N:][::-1]
    top_files = [files[i] for i in top_indices]

    return top_files
|
56 |
+
|
57 |
+
def generate_drum_kit(prompt, kit_size=4):
    """Generate a drum kit (instrument -> list of sound files) from user text.

    Parameters
    ----------
    prompt : str
        Free-text description of the desired kit.
    kit_size : int
        Number of sounds requested per instrument.
    """
    text_embed = get_clap_embeddings_from_text(prompt)
    instruments = ("Kick", "Snare", "Hi-Hat", "Tom", "Cymbal", "Clap", "Percussion", "Other")
    # One similarity search per instrument category.
    return {
        instrument: find_top_sounds(text_embed, instrument, top_N=kit_size)
        for instrument in instruments
    }
|
requirements.txt
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
altair==5.5.0
|
2 |
+
annotated-types==0.7.0
|
3 |
+
attrs==25.3.0
|
4 |
+
audioread==3.0.1
|
5 |
+
blinker==1.9.0
|
6 |
+
braceexpand==0.1.7
|
7 |
+
cachetools==5.5.2
|
8 |
+
certifi==2025.1.31
|
9 |
+
cffi==1.17.1
|
10 |
+
charset-normalizer==3.4.1
|
11 |
+
click==8.1.8
|
12 |
+
decorator==5.2.1
|
13 |
+
docker-pycreds==0.4.0
|
14 |
+
filelock==3.17.0
|
15 |
+
fsspec==2025.3.0
|
16 |
+
ftfy==6.3.1
|
17 |
+
gitdb==4.0.12
|
18 |
+
GitPython==3.1.44
|
19 |
+
h5py==3.13.0
|
20 |
+
huggingface-hub==0.29.3
|
21 |
+
idna==3.10
|
22 |
+
Jinja2==3.1.6
|
23 |
+
joblib==1.4.2
|
24 |
+
jsonschema==4.23.0
|
25 |
+
jsonschema-specifications==2024.10.1
|
26 |
+
laion_clap==1.1.6
|
27 |
+
lazy_loader==0.4
|
28 |
+
librosa==0.11.0
|
29 |
+
llvmlite==0.43.0
|
30 |
+
MarkupSafe==3.0.2
|
31 |
+
msgpack==1.1.0
|
32 |
+
narwhals==1.30.0
|
33 |
+
numba==0.60.0
|
34 |
+
numpy==1.23.5
|
35 |
+
packaging==24.2
|
36 |
+
pandas==2.2.3
|
37 |
+
pillow==11.1.0
|
38 |
+
platformdirs==4.3.6
|
39 |
+
pooch==1.8.2
|
40 |
+
progressbar==2.5
|
41 |
+
protobuf==5.29.3
|
42 |
+
psutil==7.0.0
|
43 |
+
pyarrow==19.0.1
|
44 |
+
pycparser==2.22
|
45 |
+
pydantic==2.10.6
|
46 |
+
pydantic_core==2.27.2
|
47 |
+
pydeck==0.9.1
|
48 |
+
python-dateutil==2.9.0.post0
|
49 |
+
pytz==2025.1
|
50 |
+
PyYAML==6.0.2
|
51 |
+
referencing==0.36.2
|
52 |
+
regex==2024.11.6
|
53 |
+
requests==2.32.3
|
54 |
+
rpds-py==0.23.1
|
55 |
+
safetensors==0.5.3
|
56 |
+
scikit-learn==1.6.1
|
57 |
+
scipy==1.15.2
|
58 |
+
sentry-sdk==2.22.0
|
59 |
+
setproctitle==1.3.5
|
60 |
+
six==1.17.0
|
61 |
+
smmap==5.0.2
|
62 |
+
soundfile==0.13.1
|
63 |
+
soxr==0.5.0.post1
|
64 |
+
streamlit==1.43.2
|
65 |
+
tenacity==9.0.0
|
66 |
+
threadpoolctl==3.6.0
|
67 |
+
tokenizers==0.21.1
|
68 |
+
toml==0.10.2
|
69 |
+
torchlibrosa==0.1.0
|
70 |
+
tornado==6.4.2
|
71 |
+
tqdm==4.67.1
|
72 |
+
transformers==4.49.0
|
73 |
+
typing_extensions==4.12.2
|
74 |
+
tzdata==2025.1
|
75 |
+
urllib3==2.3.0
|
76 |
+
wandb==0.19.8
|
77 |
+
wcwidth==0.2.13
|
78 |
+
webdataset==0.2.111
|
79 |
+
wget==3.2
|