arhanv committed on
Commit
1f5d38f
·
1 Parent(s): f6a3d7e

ported audio_utils and inference from colab

Browse files
Files changed (5) hide show
  1. app.py +33 -0
  2. audio_utils.py +12 -0
  3. dataset/audio_embeddings.pkl +3 -0
  4. inference.py +65 -0
  5. requirements.txt +79 -0
app.py CHANGED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import librosa
4
+ import soundfile as sf
5
+ import numpy as np
6
+ from inference import generate_drum_kit
7
+ from audio_utils import play_audio
8
+
9
# Streamlit UI
st.title("Generate Drum Kits with Text")

# User inputs: a free-text description of the kit and the number of sounds
# to retrieve for each instrument.
prompt = st.text_input("Describe your drum kit (e.g., 'warm vintage')", "8-bit video game")
kit_size = st.slider("Number of sounds per instrument:", 1, 10, 4)

# Run the inference only on an explicit click. The result is stored in
# session_state so it is still available on the reruns Streamlit performs
# when one of the play buttons below is pressed.
if st.button("Generate Drum Kit"):
    st.session_state["drum_kit"] = generate_drum_kit(prompt, kit_size)

# Display results
if "drum_kit" in st.session_state:
    drum_kit = st.session_state["drum_kit"]
    st.subheader("Generated Drum Kit")

    for instrument, sounds in drum_kit.items():
        st.write(f"**{instrument}**")

        # st.columns() rejects a column count of zero, so an instrument with
        # no matching sounds is reported instead of crashing the whole page.
        if not sounds:
            st.write("No sounds found.")
            continue

        cols = st.columns(len(sounds))
        for col, sound_file in zip(cols, sounds):
            with col:
                # NOTE(review): sound_file is handed to play_audio() unchanged;
                # confirm it is a path librosa can open from the working
                # directory and not just a bare file name.
                if st.button(f"▶️ {os.path.basename(sound_file)}", key=sound_file):
                    play_audio(sound_file)
audio_utils.py CHANGED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import soundfile as sf
3
+ import librosa
4
+ import numpy as np
5
+ import io
6
+
7
def play_audio(file_path, sample_rate=16000):
    """Load an audio file and render it in a Streamlit audio player.

    Args:
        file_path: Location of the audio file, passed to ``librosa.load``.
        sample_rate: Rate the audio is resampled to while loading. Defaults
            to 16000, the value that used to be hard-coded; pass ``None`` to
            keep the file's native sampling rate.
    """
    audio, sr = librosa.load(file_path, sr=sample_rate)

    # Encode to WAV in memory so no temporary file has to be written.
    audio_buffer = io.BytesIO()
    sf.write(audio_buffer, audio, sr, format="wav")
    # Writing leaves the cursor at the end of the buffer; rewind so any
    # consumer that reads from the current position sees the whole file.
    audio_buffer.seek(0)

    st.audio(audio_buffer, format="audio/wav")
dataset/audio_embeddings.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:734662f4d6f61035f7519a883918c52b6abc9d9eee0a03df1c4aaebb559d8408
3
+ size 21482036
inference.py CHANGED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import laion_clap
2
+ import numpy as np
3
+ import librosa
4
+ import pickle
5
+ import os
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ import pandas as pd
8
+ import zipfile
9
+ import json
10
+
11
# Locations of the dataset artifacts, relative to the working directory.
dataset_zip = "dataset/one_shot_percussive_sounds.zip"
extracted_folder = "dataset/unzipped"
metadata_path = "dataset/licenses.txt"
audio_embeddings_path = "dataset/audio_embeddings.pkl"

# Unzip if not already extracted.
# NOTE(review): the guard checks for "dataset/unzipped" but the archive is
# extracted into "dataset"; this only avoids re-extraction if the archive's
# top-level folder is named "unzipped" -- confirm against the zip contents.
if not os.path.exists(extracted_folder):
    with zipfile.ZipFile(dataset_zip, "r") as zip_ref:
        zip_ref.extractall("dataset")

# Load the model once at import time so every request reuses it.
model = laion_clap.CLAP_Module(enable_fusion=True)
model.load_ckpt(model_id=3)

# Load dataset metadata. The file is parsed as JSON despite its .txt suffix;
# the encoding is given explicitly so parsing does not depend on the
# platform's default locale encoding.
with open(metadata_path, "r", encoding="utf-8") as file:
    data = json.load(file)

# Convert the JSON data into a Pandas DataFrame with one row per top-level
# key, then turn each key into a "<key>.wav" file name.
metadata = pd.DataFrame.from_dict(data, orient="index")
metadata.index = metadata.index.astype(str) + '.wav'

# Load precomputed audio embeddings (to avoid recomputing on every request).
# SECURITY: pickle.load executes arbitrary code from the file; this is only
# acceptable because the pickle is an artifact shipped with the project.
# Never point this path at user-supplied data.
with open(audio_embeddings_path, "rb") as f:
    audio_embeddings = pickle.load(f)
36
+
37
def get_clap_embeddings_from_text(text):
    """Embed a text prompt with the module-level CLAP model.

    The prompt is wrapped in a one-element list for
    ``model.get_text_embedding`` and the first row of the result is returned.
    """
    batch = [text]
    embeddings = model.get_text_embedding(batch)
    return embeddings[0, :]
41
+
42
def find_top_sounds(text_embed, instrument, top_N=4):
    """Return the names of the sounds closest to a text embedding.

    Only sounds whose metadata ``Instrument`` equals ``instrument`` and that
    have an entry in ``audio_embeddings`` are considered.

    Args:
        text_embed: Embedding vector of the prompt (presumably 1-D and the
            same width as the stored audio embeddings -- confirm).
        instrument: Value matched against the ``Instrument`` metadata column.
        top_N: Maximum number of sound names to return.

    Returns:
        Up to ``top_N`` keys of ``audio_embeddings``, ordered from most to
        least similar. Empty when the instrument has no embedded sounds or
        ``top_N`` is not positive.
    """
    # A set gives O(1) membership tests while scanning the embeddings.
    valid_sounds = set(metadata[metadata["Instrument"] == instrument].index)

    # Collect names and vectors together so they share one ordering. Indexing
    # the metadata list with positions taken from the filtered embeddings
    # returned the wrong files whenever the two orderings differed or a
    # metadata entry had no embedding.
    candidates = [
        (name, embed)
        for name, embed in audio_embeddings.items()
        if name in valid_sounds
    ]
    if not candidates or top_N <= 0:
        return []

    names = [name for name, _ in candidates]
    all_embeds = np.array([embed for _, embed in candidates])

    # Compute cosine similarity between the prompt and every candidate.
    similarities = cosine_similarity([text_embed], all_embeds)[0]

    # Highest similarity first, truncated to the requested count.
    top_indices = np.argsort(similarities)[::-1][:top_N]
    return [names[i] for i in top_indices]
56
+
57
def generate_drum_kit(prompt, kit_size=4):
    """Build a mapping from instrument name to its best-matching sounds.

    The prompt is embedded once, then ``find_top_sounds`` is queried for each
    instrument category with ``kit_size`` as the number of sounds requested.
    """
    instruments = (
        "Kick",
        "Snare",
        "Hi-Hat",
        "Tom",
        "Cymbal",
        "Clap",
        "Percussion",
        "Other",
    )
    prompt_embedding = get_clap_embeddings_from_text(prompt)
    return {
        name: find_top_sounds(prompt_embedding, name, top_N=kit_size)
        for name in instruments
    }
requirements.txt ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair==5.5.0
2
+ annotated-types==0.7.0
3
+ attrs==25.3.0
4
+ audioread==3.0.1
5
+ blinker==1.9.0
6
+ braceexpand==0.1.7
7
+ cachetools==5.5.2
8
+ certifi==2025.1.31
9
+ cffi==1.17.1
10
+ charset-normalizer==3.4.1
11
+ click==8.1.8
12
+ decorator==5.2.1
13
+ docker-pycreds==0.4.0
14
+ filelock==3.17.0
15
+ fsspec==2025.3.0
16
+ ftfy==6.3.1
17
+ gitdb==4.0.12
18
+ GitPython==3.1.44
19
+ h5py==3.13.0
20
+ huggingface-hub==0.29.3
21
+ idna==3.10
22
+ Jinja2==3.1.6
23
+ joblib==1.4.2
24
+ jsonschema==4.23.0
25
+ jsonschema-specifications==2024.10.1
26
+ laion_clap==1.1.6
27
+ lazy_loader==0.4
28
+ librosa==0.11.0
29
+ llvmlite==0.43.0
30
+ MarkupSafe==3.0.2
31
+ msgpack==1.1.0
32
+ narwhals==1.30.0
33
+ numba==0.60.0
34
+ numpy==1.23.5
35
+ packaging==24.2
36
+ pandas==2.2.3
37
+ pillow==11.1.0
38
+ platformdirs==4.3.6
39
+ pooch==1.8.2
40
+ progressbar==2.5
41
+ protobuf==5.29.3
42
+ psutil==7.0.0
43
+ pyarrow==19.0.1
44
+ pycparser==2.22
45
+ pydantic==2.10.6
46
+ pydantic_core==2.27.2
47
+ pydeck==0.9.1
48
+ python-dateutil==2.9.0.post0
49
+ pytz==2025.1
50
+ PyYAML==6.0.2
51
+ referencing==0.36.2
52
+ regex==2024.11.6
53
+ requests==2.32.3
54
+ rpds-py==0.23.1
55
+ safetensors==0.5.3
56
+ scikit-learn==1.6.1
57
+ scipy==1.15.2
58
+ sentry-sdk==2.22.0
59
+ setproctitle==1.3.5
60
+ six==1.17.0
61
+ smmap==5.0.2
62
+ soundfile==0.13.1
63
+ soxr==0.5.0.post1
64
+ streamlit==1.43.2
65
+ tenacity==9.0.0
66
+ threadpoolctl==3.6.0
67
+ tokenizers==0.21.1
68
+ toml==0.10.2
69
+ torchlibrosa==0.1.0
70
+ tornado==6.4.2
71
+ tqdm==4.67.1
72
+ transformers==4.49.0
73
+ typing_extensions==4.12.2
74
+ tzdata==2025.1
75
+ urllib3==2.3.0
76
+ wandb==0.19.8
77
+ wcwidth==0.2.13
78
+ webdataset==0.2.111
79
+ wget==3.2