|
|
import csv
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import tensorflow as tf
import tensorflow_hub as hub
from scipy.signal import resample
|
|
|
|
|
|
|
|
# Load the pretrained YAMNet audio-event classifier from TF-Hub once at import
# time (hub.load caches the model locally; first run needs network access).
yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")
|
|
|
|
|
|
|
|
def load_class_map():
    """Download the YAMNet class map and return its list of display names.

    Returns:
        list[str]: class display names, ordered so that the list index
        matches the class index in YAMNet's score vector.

    Raises:
        Exception: propagates any download/parse failure from
        ``tf.keras.utils.get_file`` or the CSV read.
    """
    # get_file caches the download under ~/.keras, so this hits the
    # network only on the first run.
    class_map_path = tf.keras.utils.get_file(
        'yamnet_class_map.csv',
        'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
    )
    # Use csv.reader instead of a naive split(','): several display names
    # in this file are quoted and contain commas (e.g. "Chirp, tweet"),
    # which split(',')[2] would truncate, corrupting the label table.
    with open(class_map_path, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row: index,mid,display_name
        return [row[2] for row in reader]
|
|
|
|
|
# Build the class-index -> human-readable-label table once at import time;
# classify_audio indexes into this with YAMNet score indices.
class_names = load_class_map()
|
|
|
|
|
|
|
|
def classify_audio(audio_input):
    """Classify an audio clip with YAMNet and plot its waveform.

    Args:
        audio_input: either a filesystem path to an audio file (what
            ``gr.Audio(type="filepath")`` passes) or a file-like object
            with a ``read()`` method.

    Returns:
        tuple: ``(top_prediction, top_scores, fig)`` where
        ``top_prediction`` is the highest-scoring class name,
        ``top_scores`` maps the top-5 class names to float scores, and
        ``fig`` is a matplotlib waveform figure. On any failure returns
        ``("Error processing audio: ...", {}, None)`` instead of raising,
        so the Gradio UI always gets renderable outputs.
    """
    try:
        # Resolve the input to a readable path. tmp_path tracks a temp
        # file we created so we can always delete it, even on failure.
        tmp_path = None
        if isinstance(audio_input, str):
            file_path = audio_input
        elif hasattr(audio_input, "read"):
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp.write(audio_input.read())
            tmp_path = tmp.name
            file_path = tmp_path
        else:
            raise ValueError("Unsupported input format")

        # try/finally (rather than the previous "'tmp' in locals()" check
        # after the read) guarantees the temp file is removed even when
        # sf.read raises, fixing a temp-file leak on bad input.
        try:
            audio_data, sample_rate = sf.read(file_path)
        finally:
            if tmp_path is not None:
                os.unlink(tmp_path)

        # Downmix multi-channel audio to mono by averaging channels.
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Peak-normalize to [-1, 1]. Guard against empty input (np.max on
        # an empty array raises) and all-zero/silent input (division by
        # zero would fill the signal with NaNs and poison the model).
        peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
        if peak > 0:
            audio_data = audio_data / peak

        # YAMNet expects a 16 kHz mono waveform; resample if needed.
        target_rate = 16000
        if sample_rate != target_rate:
            duration = audio_data.shape[0] / sample_rate
            new_length = int(duration * target_rate)
            audio_data = resample(audio_data, new_length)
            sample_rate = target_rate

        waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)

        # YAMNet returns per-frame scores; average over frames to get one
        # score per class for the whole clip, then take the top 5.
        scores, embeddings, spectrogram = yamnet_model(waveform)
        mean_scores = tf.reduce_mean(scores, axis=0).numpy()
        top_5 = np.argsort(mean_scores)[::-1][:5]

        top_prediction = class_names[top_5[0]]
        top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}

        # Waveform plot for the UI. NOTE(review): figures returned to
        # gr.Plot are never plt.close()'d here, so repeated requests
        # accumulate open figures — consider closing after render.
        fig, ax = plt.subplots()
        ax.plot(audio_data)
        ax.set_title("Waveform")
        ax.set_xlabel("Time (samples)")
        ax.set_ylabel("Amplitude")
        plt.tight_layout()

        return top_prediction, top_scores, fig

    except Exception as e:
        # Top-level UI boundary: report the error instead of crashing the app.
        return f"Error processing audio: {str(e)}", {}, None
|
|
|
|
|
|
|
|
# Gradio UI definition: one audio input (passed to classify_audio as a file
# path, per type="filepath") wired to the function's three return values —
# top label text, top-5 score dict, and the waveform plot.
interface = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
    outputs=[
        gr.Textbox(label="Top Prediction"),
        gr.Label(label="Top 5 Classes with Scores"),
        gr.Plot(label="Waveform")
    ],
    title="Audtheia YAMNet Audio Classifier",
    description="Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform."
)
|
|
|
|
|
# Start the web UI only when executed as a script (not when imported).
if __name__ == "__main__":
    interface.launch()
|
|
|