arhanv committed on
Commit
ac3dd61
·
1 Parent(s): ab0f8ba

fixed dataset and inference

Browse files
.gitignore CHANGED
@@ -1,3 +1,5 @@
1
  /.env
2
  ._*
3
- /dataset/unzipped
 
 
 
1
  /.env
2
  ._*
3
+ /dataset/unzipped
4
+ /dataset/all_sounds
5
+ *.pyc
dataset/{one_shot_percussive_sounds.zip → all_sounds.zip} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c45401b3cbdd56606f0d9e5e494a18efbae1ca830f835504dccc316c1934720c
3
- size 112614838
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d8773333b5f600f968c3d2c2c5fb09332440f19867ce69e9a20b76ab5aed618
3
+ size 112639857
inference.py CHANGED
@@ -1,4 +1,3 @@
1
- import laion_clap
2
  import numpy as np
3
  import librosa
4
  import pickle
@@ -7,20 +6,23 @@ from sklearn.metrics.pairwise import cosine_similarity
7
  import pandas as pd
8
  import zipfile
9
  import json
 
 
 
10
 
11
- dataset_zip = "dataset/one_shot_percussive_sounds.zip"
12
- extracted_folder = "dataset/unzipped"
13
  metadata_path = "dataset/licenses.txt"
14
  audio_embeddings_path = "dataset/audio_embeddings.pkl"
15
 
16
  # Unzip if not already extracted
17
  if not os.path.exists(extracted_folder):
18
  with zipfile.ZipFile(dataset_zip, "r") as zip_ref:
19
- zip_ref.extractall("dataset")
20
 
21
- # Load the model
22
- model = laion_clap.CLAP_Module(enable_fusion=True)
23
- model.load_ckpt(model_id=3)
24
 
25
  # Load dataset metadata
26
  with open(metadata_path, "r") as file:
@@ -30,14 +32,38 @@ with open(metadata_path, "r") as file:
30
  metadata = pd.DataFrame.from_dict(data, orient="index")
31
  metadata.index = metadata.index.astype(str) + '.wav'
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # Load precomputed audio embeddings (to avoid recomputing on every request)
34
  with open(audio_embeddings_path, "rb") as f:
35
  audio_embeddings = pickle.load(f)
36
 
37
  def get_clap_embeddings_from_text(text):
38
- """Convert user text input to a CLAP embedding."""
39
- text_embed = model.get_text_embedding([text])
40
- return text_embed[0, :]
 
 
41
 
42
  def find_top_sounds(text_embed, instrument, top_N=4):
43
  """Finds the closest N sounds for an instrument."""
@@ -50,7 +76,7 @@ def find_top_sounds(text_embed, instrument, top_N=4):
50
 
51
  # Get top N matches
52
  top_indices = np.argsort(similarities)[-top_N:][::-1]
53
- top_files = [valid_sounds[i] for i in top_indices]
54
 
55
  return top_files
56
 
@@ -62,4 +88,4 @@ def generate_drum_kit(prompt, kit_size=4):
62
  for instrument in ["Kick", "Snare", "Hi-Hat", "Tom", "Cymbal", "Clap", "Percussion", "Other"]:
63
  drum_kit[instrument] = find_top_sounds(text_embed, instrument, top_N=kit_size)
64
 
65
- return drum_kit
 
 
1
  import numpy as np
2
  import librosa
3
  import pickle
 
6
  import pandas as pd
7
  import zipfile
8
  import json
9
+ from transformers import ClapModel, ClapProcessor
10
+ import torch
11
+ import shutil
12
 
13
+ dataset_zip = "dataset/all_sounds.zip"
14
+ extracted_folder = "dataset/all_sounds"
15
  metadata_path = "dataset/licenses.txt"
16
  audio_embeddings_path = "dataset/audio_embeddings.pkl"
17
 
18
  # Unzip if not already extracted
19
  if not os.path.exists(extracted_folder):
20
  with zipfile.ZipFile(dataset_zip, "r") as zip_ref:
21
+ zip_ref.extractall(extracted_folder)
22
 
23
+ # Load Hugging Face's CLAP model
24
+ processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
25
+ model = ClapModel.from_pretrained("laion/clap-htsat-fused")
26
 
27
  # Load dataset metadata
28
  with open(metadata_path, "r") as file:
 
32
  metadata = pd.DataFrame.from_dict(data, orient="index")
33
  metadata.index = metadata.index.astype(str) + '.wav'
34
 
35
# Keyword table used to infer an instrument category from a sound's name.
# Order matters: categories are checked top to bottom and the first category
# with a matching keyword wins (e.g. "808 snare" is classified as "Kick").
instrument_categories = {
    "Kick": ["kick", "bd", "bass", "808", "kd"],
    "Snare": ["snare", "sd", "sn"],
    "Hi-Hat": ["hihat", "hh", "hi_hat", "hi-hat"],
    "Tom": ["tom"],
    "Cymbal": ["crash", "ride", "splash", "cymbal"],
    "Clap": ["clap"],
    "Percussion": ["shaker", "perc", "tamb", "cowbell", "bongo", "conga", "egg"],
}


def categorize_instrument(filename):
    """Return the instrument category whose keywords match *filename*.

    Matching is a case-insensitive substring search against the keywords in
    ``instrument_categories``; the first category (in table order) with any
    matching keyword is returned.

    Args:
        filename: Name of the sound. Non-string values (e.g. ``None`` or a
            NaN produced by a missing metadata entry) are accepted and are
            treated as having no match.

    Returns:
        The matching category name, or ``"Other"`` when nothing matches or
        *filename* is not a string.
    """
    # Missing metadata names show up as None/NaN when this is applied over a
    # pandas column; calling .lower() on those would raise AttributeError.
    if not isinstance(filename, str):
        return "Other"

    lower_filename = filename.lower()
    for category, keywords in instrument_categories.items():
        if any(keyword in lower_filename for keyword in keywords):
            return category
    return "Other"  # Default category if no match is found
52
+
53
+ # Apply function to create a new 'Instrument' column
54
+ metadata["Instrument"] = metadata["name"].apply(categorize_instrument)
55
+ metadata["Instrument"].value_counts()
56
+
57
  # Load precomputed audio embeddings (to avoid recomputing on every request)
58
  with open(audio_embeddings_path, "rb") as f:
59
  audio_embeddings = pickle.load(f)
60
 
61
  def get_clap_embeddings_from_text(text):
62
+ """Convert user text input to a CLAP embedding using Hugging Face's CLAP."""
63
+ inputs = processor(text=text, return_tensors="pt")
64
+ with torch.no_grad():
65
+ text_embeddings = model.get_text_features(**inputs)
66
+ return text_embeddings.squeeze(0).numpy()
67
 
68
  def find_top_sounds(text_embed, instrument, top_N=4):
69
  """Finds the closest N sounds for an instrument."""
 
76
 
77
  # Get top N matches
78
  top_indices = np.argsort(similarities)[-top_N:][::-1]
79
+ top_files = [os.path.join(extracted_folder, valid_sounds[i]) for i in top_indices]
80
 
81
  return top_files
82
 
 
88
  for instrument in ["Kick", "Snare", "Hi-Hat", "Tom", "Cymbal", "Clap", "Percussion", "Other"]:
89
  drum_kit[instrument] = find_top_sounds(text_embed, instrument, top_N=kit_size)
90
 
91
+ return drum_kit
requirements.txt CHANGED
@@ -1,49 +1,41 @@
1
  altair==5.5.0
2
- annotated-types==0.7.0
3
  attrs==25.3.0
4
  audioread==3.0.1
5
  blinker==1.9.0
6
- braceexpand==0.1.7
7
  cachetools==5.5.2
8
  certifi==2025.1.31
9
  cffi==1.17.1
10
  charset-normalizer==3.4.1
11
  click==8.1.8
12
  decorator==5.2.1
13
- docker-pycreds==0.4.0
14
  filelock==3.17.0
15
  fsspec==2025.3.0
16
- ftfy==6.3.1
17
  gitdb==4.0.12
18
  GitPython==3.1.44
19
- h5py==3.13.0
20
  huggingface-hub==0.29.3
21
  idna==3.10
22
  Jinja2==3.1.6
23
  joblib==1.4.2
24
  jsonschema==4.23.0
25
  jsonschema-specifications==2024.10.1
26
- laion_clap==1.1.6
27
  lazy_loader==0.4
28
  librosa==0.11.0
29
- llvmlite==0.43.0
30
  MarkupSafe==3.0.2
 
31
  msgpack==1.1.0
32
  narwhals==1.30.0
33
- numba==0.60.0
34
- numpy==1.23.5
 
35
  packaging==24.2
36
  pandas==2.2.3
37
  pillow==11.1.0
38
  platformdirs==4.3.6
39
  pooch==1.8.2
40
- progressbar==2.5
41
  protobuf==5.29.3
42
- psutil==7.0.0
43
  pyarrow==19.0.1
44
  pycparser==2.22
45
- pydantic==2.10.6
46
- pydantic_core==2.27.2
47
  pydeck==0.9.1
48
  python-dateutil==2.9.0.post0
49
  pytz==2025.1
@@ -55,26 +47,20 @@ rpds-py==0.23.1
55
  safetensors==0.5.3
56
  scikit-learn==1.6.1
57
  scipy==1.15.2
58
- sentry-sdk==2.22.0
59
- setproctitle==1.3.5
60
  six==1.17.0
61
  smmap==5.0.2
62
  soundfile==0.13.1
63
  soxr==0.5.0.post1
64
  streamlit==1.43.2
 
65
  tenacity==9.0.0
66
  threadpoolctl==3.6.0
67
  tokenizers==0.21.1
68
  toml==0.10.2
69
- torchlibrosa==0.1.0
70
  tornado==6.4.2
71
  tqdm==4.67.1
72
  transformers==4.49.0
73
  typing_extensions==4.12.2
74
  tzdata==2025.1
75
  urllib3==2.3.0
76
- wandb==0.19.8
77
- wcwidth==0.2.13
78
- webdataset==0.2.111
79
- wget==3.2
80
- torch
 
1
  altair==5.5.0
 
2
  attrs==25.3.0
3
  audioread==3.0.1
4
  blinker==1.9.0
 
5
  cachetools==5.5.2
6
  certifi==2025.1.31
7
  cffi==1.17.1
8
  charset-normalizer==3.4.1
9
  click==8.1.8
10
  decorator==5.2.1
 
11
  filelock==3.17.0
12
  fsspec==2025.3.0
 
13
  gitdb==4.0.12
14
  GitPython==3.1.44
 
15
  huggingface-hub==0.29.3
16
  idna==3.10
17
  Jinja2==3.1.6
18
  joblib==1.4.2
19
  jsonschema==4.23.0
20
  jsonschema-specifications==2024.10.1
 
21
  lazy_loader==0.4
22
  librosa==0.11.0
23
+ llvmlite==0.44.0
24
  MarkupSafe==3.0.2
25
+ mpmath==1.3.0
26
  msgpack==1.1.0
27
  narwhals==1.30.0
28
+ networkx==3.4.2
29
+ numba==0.61.0
30
+ numpy==2.1.3
31
  packaging==24.2
32
  pandas==2.2.3
33
  pillow==11.1.0
34
  platformdirs==4.3.6
35
  pooch==1.8.2
 
36
  protobuf==5.29.3
 
37
  pyarrow==19.0.1
38
  pycparser==2.22
 
 
39
  pydeck==0.9.1
40
  python-dateutil==2.9.0.post0
41
  pytz==2025.1
 
47
  safetensors==0.5.3
48
  scikit-learn==1.6.1
49
  scipy==1.15.2
 
 
50
  six==1.17.0
51
  smmap==5.0.2
52
  soundfile==0.13.1
53
  soxr==0.5.0.post1
54
  streamlit==1.43.2
55
+ sympy==1.13.1
56
  tenacity==9.0.0
57
  threadpoolctl==3.6.0
58
  tokenizers==0.21.1
59
  toml==0.10.2
60
+ torch==2.6.0
61
  tornado==6.4.2
62
  tqdm==4.67.1
63
  transformers==4.49.0
64
  typing_extensions==4.12.2
65
  tzdata==2025.1
66
  urllib3==2.3.0