YAMNet / app.py
Kaworu17's picture
Update app.py
29e4b0d verified
import csv
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import tensorflow as tf
import tensorflow_hub as hub
from scipy.signal import resample
# TensorFlow Hub handle for YAMNet. The model is loaded once at import time
# so that every call to classify_audio reuses the same instance.
YAMNET_HANDLE = "https://tfhub.dev/google/yamnet/1"
yamnet_model = hub.load(YAMNET_HANDLE)
# Load class labels
def load_class_map(class_map_path=None):
    """Return the YAMNet class display names in the order they appear in the CSV.

    Args:
        class_map_path: Optional path to a local copy of
            ``yamnet_class_map.csv``. When omitted, the file is fetched with
            ``tf.keras.utils.get_file``.

    Returns:
        A list with the third column (the display name) of every data row;
        the header row and blank lines are skipped.
    """
    if class_map_path is None:
        class_map_path = tf.keras.utils.get_file(
            'yamnet_class_map.csv',
            'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
        )
    # Display names such as "Child speech, kid speaking" are quoted and contain
    # commas, so a real CSV parser is required; splitting on ',' truncates them.
    with open(class_map_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # Skip the header row.
        return [row[2] for row in reader if row]
# Loaded once at import time; classify_audio looks class names up in this
# list by score index.
class_names = load_class_map()
# Main classification function
def _read_audio_samples(audio_input):
    """Return ``(samples, sample_rate)`` for a file path or a binary file-like object.

    File-like input is written to a temporary ``.wav`` file so it can be read
    by path; that file is removed again whether or not decoding succeeds.

    Raises:
        ValueError: If ``audio_input`` is neither a ``str`` nor has ``read``.
    """
    # Case 1: Filepath from Gradio UI
    if isinstance(audio_input, str):
        return sf.read(audio_input)

    # Case 2: Binary upload (n8n POST) without .name attribute
    if not hasattr(audio_input, "read"):
        raise ValueError("Unsupported input format")

    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        # Close the handle before the file is read back by name.
        with tmp:
            tmp.write(audio_input.read())
        return sf.read(tmp.name)
    finally:
        # Runs on success and on failure, so a bad upload never leaves a
        # stray temp file behind.
        os.unlink(tmp.name)


def classify_audio(audio_input):
    """Classify an audio clip with YAMNet and plot its waveform.

    Args:
        audio_input: Path to an audio file, or a binary file-like object
            exposing ``read()``.

    Returns:
        A ``(top_prediction, top_scores, figure)`` tuple: the highest-scoring
        class name, a dict mapping the five highest-scoring class names to
        their mean scores, and a matplotlib figure of the processed waveform.
        On any failure the tuple is
        ``("Error processing audio: <reason>", {}, None)``.
    """
    try:
        audio_data, sample_rate = _read_audio_samples(audio_input)

        if audio_data.shape[0] == 0:
            raise ValueError("Audio contains no samples")

        # Convert stereo to mono
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Peak-normalize. A silent clip has a zero peak; dividing by it would
        # turn the whole waveform into NaN, so it is left unscaled instead.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak

        # Resample to 16kHz if needed
        target_rate = 16000
        if sample_rate != target_rate:
            duration = audio_data.shape[0] / sample_rate
            new_length = int(duration * target_rate)
            audio_data = resample(audio_data, new_length)

        # Tensor for model
        waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)

        # Run YAMNet model; only the class scores are used, the other two
        # outputs are ignored.
        scores, _embeddings, _spectrogram = yamnet_model(waveform)
        mean_scores = tf.reduce_mean(scores, axis=0).numpy()
        top_5 = np.argsort(mean_scores)[::-1][:5]

        # Output results
        top_prediction = class_names[top_5[0]]
        top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}

        # Plot waveform
        fig, ax = plt.subplots()
        ax.plot(audio_data)
        ax.set_title("Waveform")
        ax.set_xlabel("Time (samples)")
        ax.set_ylabel("Amplitude")
        fig.tight_layout()
        # pyplot keeps a reference to every figure it creates until it is
        # closed, which would grow memory on each request. Closing only drops
        # that registry entry; the returned Figure object is still usable.
        plt.close(fig)

        return top_prediction, top_scores, fig
    except Exception as e:
        # Top-level boundary for the UI: report the failure in the text output
        # rather than raising.
        return f"Error processing audio: {str(e)}", {}, None
# Gradio UI: one audio input (delivered to classify_audio as a file path) and
# three outputs matching the tuple classify_audio returns.
_audio_input = gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file")
_result_outputs = [
    gr.Textbox(label="Top Prediction"),
    gr.Label(label="Top 5 Classes with Scores"),
    gr.Plot(label="Waveform"),
]
_APP_TITLE = "Audtheia YAMNet Audio Classifier"
_APP_DESCRIPTION = "Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform."

interface = gr.Interface(
    fn=classify_audio,
    inputs=_audio_input,
    outputs=_result_outputs,
    title=_APP_TITLE,
    description=_APP_DESCRIPTION,
)

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    interface.launch()