import csv
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import tensorflow as tf
import tensorflow_hub as hub
from matplotlib.figure import Figure
from scipy.signal import resample

# Sample rate the audio is resampled to before being fed to YAMNet.
YAMNET_SAMPLE_RATE = 16000
# Number of highest-scoring classes reported to the UI.
TOP_K = 5

# Load YAMNet model from TensorFlow Hub (runs at import time).
yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")


def load_class_map():
    """Download the YAMNet class map CSV and return its display names.

    Returns:
        list[str]: The ``display_name`` column (third column) of every data
        row, in file order. Callers index it with YAMNet score-column indices,
        so the file's row order is assumed to match the model's class order.
    """
    class_map_path = tf.keras.utils.get_file(
        'yamnet_class_map.csv',
        'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv',
    )
    # Parse with the csv module: AudioSet display names may contain quoted
    # commas (e.g. "Child speech, kid speaking"), which split(',') would cut.
    with open(class_map_path, newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        return [row[2] for row in reader]


class_names = load_class_map()


def _load_audio(audio_input):
    """Read *audio_input* with soundfile and return ``(samples, sample_rate)``.

    Accepts a filesystem path (str or os.PathLike) or a binary file-like
    object with ``read()``; the latter is spooled to a temporary file that
    is always removed, even when decoding fails.

    Raises:
        ValueError: If *audio_input* is neither a path nor readable.
    """
    # Case 1: Filepath from Gradio UI
    if isinstance(audio_input, (str, os.PathLike)):
        return sf.read(os.fspath(audio_input))
    # Case 2: Binary upload (n8n POST) without .name attribute
    if hasattr(audio_input, "read"):
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        try:
            # Close the handle before sf.read reopens the path by name.
            with tmp:
                tmp.write(audio_input.read())
            return sf.read(tmp.name)
        finally:
            os.unlink(tmp.name)
    raise ValueError("Unsupported input format")


def _to_yamnet_waveform(audio_data, sample_rate, target_rate=YAMNET_SAMPLE_RATE):
    """Downmix to mono, peak-normalize, and resample to *target_rate*.

    Args:
        audio_data: Array from ``sf.read``; 1-D for mono, or 2-D with
            channels on axis 1.
        sample_rate: Sample rate of *audio_data* in Hz.
        target_rate: Output sample rate in Hz.

    Returns:
        numpy.ndarray: 1-D waveform at *target_rate*.

    Raises:
        ValueError: If the audio contains no samples.
    """
    # Convert stereo to mono by averaging the channels.
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)
    if audio_data.size == 0:
        raise ValueError("Audio file contains no samples")
    # Peak-normalize; leave silent (all-zero) input alone to avoid 0/0 -> NaN.
    peak = np.max(np.abs(audio_data))
    if peak > 0:
        audio_data = audio_data / peak
    if sample_rate != target_rate:
        # Compute the new frame count directly and round it, rather than
        # truncating a float duration product.
        new_length = int(round(audio_data.shape[0] * target_rate / sample_rate))
        audio_data = resample(audio_data, new_length)
    return audio_data


def _plot_waveform(audio_data):
    """Return a matplotlib Figure plotting *audio_data* against sample index."""
    # Build the Figure without pyplot: pyplot keeps every figure it creates
    # in a global registry, so a long-running server would accumulate one
    # unclosed figure per request.
    fig = Figure()
    ax = fig.subplots()
    ax.plot(audio_data)
    ax.set_title("Waveform")
    ax.set_xlabel("Time (samples)")
    ax.set_ylabel("Amplitude")
    fig.tight_layout()
    return fig


# Main classification function
def classify_audio(audio_input):
    """Classify an audio clip with YAMNet.

    Args:
        audio_input: A filesystem path (what ``gr.Audio(type="filepath")``
            passes) or a binary file-like object with a ``read()`` method.

    Returns:
        tuple: ``(top_prediction, top_scores, fig)`` where ``top_prediction``
        is the name of the highest-scoring class, ``top_scores`` maps the
        top-5 class names to their scores averaged over axis 0 of YAMNet's
        score output, and ``fig`` is a waveform plot. On any error, returns
        ``("Error processing audio: <message>", {}, None)`` instead of raising
        so the message is shown in the UI.
    """
    try:
        audio_data, sample_rate = _load_audio(audio_input)
        audio_data = _to_yamnet_waveform(audio_data, sample_rate)

        # Tensor for model
        waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
        scores, _embeddings, _spectrogram = yamnet_model(waveform)
        # Average the scores over axis 0 (YAMNet's frame axis) to get one
        # clip-level score per class.
        mean_scores = tf.reduce_mean(scores, axis=0).numpy()
        top_indices = np.argsort(mean_scores)[::-1][:TOP_K]

        top_prediction = class_names[top_indices[0]]
        top_scores = {class_names[i]: float(mean_scores[i]) for i in top_indices}
        return top_prediction, top_scores, _plot_waveform(audio_data)
    except Exception as e:
        # UI boundary: report the failure as text rather than crashing Gradio.
        return f"Error processing audio: {str(e)}", {}, None


# Gradio Interface
interface = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
    outputs=[
        gr.Textbox(label="Top Prediction"),
        gr.Label(label="Top 5 Classes with Scores"),
        gr.Plot(label="Waveform"),
    ],
    title="Audtheia YAMNet Audio Classifier",
    description="Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform.",
)

if __name__ == "__main__":
    interface.launch()