Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
fixed dataset and inference
Browse files- .gitignore +3 -1
- dataset/{one_shot_percussive_sounds.zip → all_sounds.zip} +2 -2
- inference.py +38 -12
- requirements.txt +7 -21
.gitignore
CHANGED
@@ -1,3 +1,5 @@
|
|
1 |
/.env
|
2 |
._*
|
3 |
-
/dataset/unzipped
|
|
|
|
|
|
1 |
/.env
|
2 |
._*
|
3 |
+
/dataset/unzipped
|
4 |
+
/dataset/all_sounds
|
5 |
+
*.pyc
|
dataset/{one_shot_percussive_sounds.zip → all_sounds.zip}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5d8773333b5f600f968c3d2c2c5fb09332440f19867ce69e9a20b76ab5aed618
|
3 |
+
size 112639857
|
inference.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import laion_clap
|
2 |
import numpy as np
|
3 |
import librosa
|
4 |
import pickle
|
@@ -7,20 +6,23 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
7 |
import pandas as pd
|
8 |
import zipfile
|
9 |
import json
|
|
|
|
|
|
|
10 |
|
11 |
-
dataset_zip = "dataset/
|
12 |
-
extracted_folder = "dataset/
|
13 |
metadata_path = "dataset/licenses.txt"
|
14 |
audio_embeddings_path = "dataset/audio_embeddings.pkl"
|
15 |
|
16 |
# Unzip if not already extracted
|
17 |
if not os.path.exists(extracted_folder):
|
18 |
with zipfile.ZipFile(dataset_zip, "r") as zip_ref:
|
19 |
-
zip_ref.extractall(
|
20 |
|
21 |
-
# Load
|
22 |
-
|
23 |
-
model.
|
24 |
|
25 |
# Load dataset metadata
|
26 |
with open(metadata_path, "r") as file:
|
@@ -30,14 +32,38 @@ with open(metadata_path, "r") as file:
|
|
30 |
metadata = pd.DataFrame.from_dict(data, orient="index")
|
31 |
metadata.index = metadata.index.astype(str) + '.wav'
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
# Load precomputed audio embeddings (to avoid recomputing on every request)
|
34 |
with open(audio_embeddings_path, "rb") as f:
|
35 |
audio_embeddings = pickle.load(f)
|
36 |
|
37 |
def get_clap_embeddings_from_text(text):
|
38 |
-
"""Convert user text input to a CLAP embedding."""
|
39 |
-
|
40 |
-
|
|
|
|
|
41 |
|
42 |
def find_top_sounds(text_embed, instrument, top_N=4):
|
43 |
"""Finds the closest N sounds for an instrument."""
|
@@ -50,7 +76,7 @@ def find_top_sounds(text_embed, instrument, top_N=4):
|
|
50 |
|
51 |
# Get top N matches
|
52 |
top_indices = np.argsort(similarities)[-top_N:][::-1]
|
53 |
-
top_files = [valid_sounds[i] for i in top_indices]
|
54 |
|
55 |
return top_files
|
56 |
|
@@ -62,4 +88,4 @@ def generate_drum_kit(prompt, kit_size=4):
|
|
62 |
for instrument in ["Kick", "Snare", "Hi-Hat", "Tom", "Cymbal", "Clap", "Percussion", "Other"]:
|
63 |
drum_kit[instrument] = find_top_sounds(text_embed, instrument, top_N=kit_size)
|
64 |
|
65 |
-
return drum_kit
|
|
|
|
|
1 |
import numpy as np
|
2 |
import librosa
|
3 |
import pickle
|
|
|
6 |
import pandas as pd
|
7 |
import zipfile
|
8 |
import json
|
9 |
+
from transformers import ClapModel, ClapProcessor
|
10 |
+
import torch
|
11 |
+
import shutil
|
12 |
|
13 |
+
dataset_zip = "dataset/all_sounds.zip"
|
14 |
+
extracted_folder = "dataset/all_sounds"
|
15 |
metadata_path = "dataset/licenses.txt"
|
16 |
audio_embeddings_path = "dataset/audio_embeddings.pkl"
|
17 |
|
18 |
# Unzip if not already extracted
|
19 |
if not os.path.exists(extracted_folder):
|
20 |
with zipfile.ZipFile(dataset_zip, "r") as zip_ref:
|
21 |
+
zip_ref.extractall(extracted_folder)
|
22 |
|
23 |
+
# Load Hugging Face's CLAP model
|
24 |
+
processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
|
25 |
+
model = ClapModel.from_pretrained("laion/clap-htsat-fused")
|
26 |
|
27 |
# Load dataset metadata
|
28 |
with open(metadata_path, "r") as file:
|
|
|
32 |
metadata = pd.DataFrame.from_dict(data, orient="index")
|
33 |
metadata.index = metadata.index.astype(str) + '.wav'
|
34 |
|
35 |
+
instrument_categories = {
|
36 |
+
"Kick": ["kick", "bd", "bass", "808", "kd"],
|
37 |
+
"Snare": ["snare", "sd", "sn"],
|
38 |
+
"Hi-Hat": ["hihat", "hh", "hi_hat", "hi-hat"],
|
39 |
+
"Tom": ["tom"],
|
40 |
+
"Cymbal": ["crash", "ride", "splash", "cymbal"],
|
41 |
+
"Clap": ["clap"],
|
42 |
+
"Percussion": ["shaker", "perc", "tamb", "cowbell", "bongo", "conga", "egg"]
|
43 |
+
}
|
44 |
+
|
45 |
+
# Function to categorize filenames based on keywords
|
46 |
+
def categorize_instrument(filename):
|
47 |
+
lower_filename = filename.lower()
|
48 |
+
for category, keywords in instrument_categories.items():
|
49 |
+
if any(keyword in lower_filename for keyword in keywords):
|
50 |
+
return category
|
51 |
+
return "Other" # Default category if no match is found
|
52 |
+
|
53 |
+
# Apply function to create a new 'Instrument' column
|
54 |
+
metadata["Instrument"] = metadata["name"].apply(categorize_instrument)
|
55 |
+
metadata["Instrument"].value_counts()
|
56 |
+
|
57 |
# Load precomputed audio embeddings (to avoid recomputing on every request)
|
58 |
with open(audio_embeddings_path, "rb") as f:
|
59 |
audio_embeddings = pickle.load(f)
|
60 |
|
61 |
def get_clap_embeddings_from_text(text):
|
62 |
+
"""Convert user text input to a CLAP embedding using Hugging Face's CLAP."""
|
63 |
+
inputs = processor(text=text, return_tensors="pt")
|
64 |
+
with torch.no_grad():
|
65 |
+
text_embeddings = model.get_text_features(**inputs)
|
66 |
+
return text_embeddings.squeeze(0).numpy()
|
67 |
|
68 |
def find_top_sounds(text_embed, instrument, top_N=4):
|
69 |
"""Finds the closest N sounds for an instrument."""
|
|
|
76 |
|
77 |
# Get top N matches
|
78 |
top_indices = np.argsort(similarities)[-top_N:][::-1]
|
79 |
+
top_files = [os.path.join(extracted_folder, valid_sounds[i]) for i in top_indices]
|
80 |
|
81 |
return top_files
|
82 |
|
|
|
88 |
for instrument in ["Kick", "Snare", "Hi-Hat", "Tom", "Cymbal", "Clap", "Percussion", "Other"]:
|
89 |
drum_kit[instrument] = find_top_sounds(text_embed, instrument, top_N=kit_size)
|
90 |
|
91 |
+
return drum_kit
|
requirements.txt
CHANGED
@@ -1,49 +1,41 @@
|
|
1 |
altair==5.5.0
|
2 |
-
annotated-types==0.7.0
|
3 |
attrs==25.3.0
|
4 |
audioread==3.0.1
|
5 |
blinker==1.9.0
|
6 |
-
braceexpand==0.1.7
|
7 |
cachetools==5.5.2
|
8 |
certifi==2025.1.31
|
9 |
cffi==1.17.1
|
10 |
charset-normalizer==3.4.1
|
11 |
click==8.1.8
|
12 |
decorator==5.2.1
|
13 |
-
docker-pycreds==0.4.0
|
14 |
filelock==3.17.0
|
15 |
fsspec==2025.3.0
|
16 |
-
ftfy==6.3.1
|
17 |
gitdb==4.0.12
|
18 |
GitPython==3.1.44
|
19 |
-
h5py==3.13.0
|
20 |
huggingface-hub==0.29.3
|
21 |
idna==3.10
|
22 |
Jinja2==3.1.6
|
23 |
joblib==1.4.2
|
24 |
jsonschema==4.23.0
|
25 |
jsonschema-specifications==2024.10.1
|
26 |
-
laion_clap==1.1.6
|
27 |
lazy_loader==0.4
|
28 |
librosa==0.11.0
|
29 |
-
llvmlite==0.
|
30 |
MarkupSafe==3.0.2
|
|
|
31 |
msgpack==1.1.0
|
32 |
narwhals==1.30.0
|
33 |
-
|
34 |
-
|
|
|
35 |
packaging==24.2
|
36 |
pandas==2.2.3
|
37 |
pillow==11.1.0
|
38 |
platformdirs==4.3.6
|
39 |
pooch==1.8.2
|
40 |
-
progressbar==2.5
|
41 |
protobuf==5.29.3
|
42 |
-
psutil==7.0.0
|
43 |
pyarrow==19.0.1
|
44 |
pycparser==2.22
|
45 |
-
pydantic==2.10.6
|
46 |
-
pydantic_core==2.27.2
|
47 |
pydeck==0.9.1
|
48 |
python-dateutil==2.9.0.post0
|
49 |
pytz==2025.1
|
@@ -55,26 +47,20 @@ rpds-py==0.23.1
|
|
55 |
safetensors==0.5.3
|
56 |
scikit-learn==1.6.1
|
57 |
scipy==1.15.2
|
58 |
-
sentry-sdk==2.22.0
|
59 |
-
setproctitle==1.3.5
|
60 |
six==1.17.0
|
61 |
smmap==5.0.2
|
62 |
soundfile==0.13.1
|
63 |
soxr==0.5.0.post1
|
64 |
streamlit==1.43.2
|
|
|
65 |
tenacity==9.0.0
|
66 |
threadpoolctl==3.6.0
|
67 |
tokenizers==0.21.1
|
68 |
toml==0.10.2
|
69 |
-
|
70 |
tornado==6.4.2
|
71 |
tqdm==4.67.1
|
72 |
transformers==4.49.0
|
73 |
typing_extensions==4.12.2
|
74 |
tzdata==2025.1
|
75 |
urllib3==2.3.0
|
76 |
-
wandb==0.19.8
|
77 |
-
wcwidth==0.2.13
|
78 |
-
webdataset==0.2.111
|
79 |
-
wget==3.2
|
80 |
-
torch
|
|
|
1 |
altair==5.5.0
|
|
|
2 |
attrs==25.3.0
|
3 |
audioread==3.0.1
|
4 |
blinker==1.9.0
|
|
|
5 |
cachetools==5.5.2
|
6 |
certifi==2025.1.31
|
7 |
cffi==1.17.1
|
8 |
charset-normalizer==3.4.1
|
9 |
click==8.1.8
|
10 |
decorator==5.2.1
|
|
|
11 |
filelock==3.17.0
|
12 |
fsspec==2025.3.0
|
|
|
13 |
gitdb==4.0.12
|
14 |
GitPython==3.1.44
|
|
|
15 |
huggingface-hub==0.29.3
|
16 |
idna==3.10
|
17 |
Jinja2==3.1.6
|
18 |
joblib==1.4.2
|
19 |
jsonschema==4.23.0
|
20 |
jsonschema-specifications==2024.10.1
|
|
|
21 |
lazy_loader==0.4
|
22 |
librosa==0.11.0
|
23 |
+
llvmlite==0.44.0
|
24 |
MarkupSafe==3.0.2
|
25 |
+
mpmath==1.3.0
|
26 |
msgpack==1.1.0
|
27 |
narwhals==1.30.0
|
28 |
+
networkx==3.4.2
|
29 |
+
numba==0.61.0
|
30 |
+
numpy==2.1.3
|
31 |
packaging==24.2
|
32 |
pandas==2.2.3
|
33 |
pillow==11.1.0
|
34 |
platformdirs==4.3.6
|
35 |
pooch==1.8.2
|
|
|
36 |
protobuf==5.29.3
|
|
|
37 |
pyarrow==19.0.1
|
38 |
pycparser==2.22
|
|
|
|
|
39 |
pydeck==0.9.1
|
40 |
python-dateutil==2.9.0.post0
|
41 |
pytz==2025.1
|
|
|
47 |
safetensors==0.5.3
|
48 |
scikit-learn==1.6.1
|
49 |
scipy==1.15.2
|
|
|
|
|
50 |
six==1.17.0
|
51 |
smmap==5.0.2
|
52 |
soundfile==0.13.1
|
53 |
soxr==0.5.0.post1
|
54 |
streamlit==1.43.2
|
55 |
+
sympy==1.13.1
|
56 |
tenacity==9.0.0
|
57 |
threadpoolctl==3.6.0
|
58 |
tokenizers==0.21.1
|
59 |
toml==0.10.2
|
60 |
+
torch==2.6.0
|
61 |
tornado==6.4.2
|
62 |
tqdm==4.67.1
|
63 |
transformers==4.49.0
|
64 |
typing_extensions==4.12.2
|
65 |
tzdata==2025.1
|
66 |
urllib3==2.3.0
|
|
|
|
|
|
|
|
|
|