# Spaces:
# Runtime error
# Runtime error
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import joblib | |
import librosa | |
import plotly.express as px | |
import os | |
# One-time startup work: read the trained classifier and the label metadata
# from disk so that individual prediction requests never touch the filesystem
# for them.
model = joblib.load('model.joblib')
class_mapping_data = pd.read_csv('data.csv')

# Lookup table: encoded label -> row of
# [scientific_name, latitude, longitude, primary_label].
# When an encoded label occurs on several CSV rows, the last row wins.
class_mapping_dict = {
    encoded_label: bird_row
    for encoded_label, bird_row in zip(
        class_mapping_data['encoded_label'],
        class_mapping_data[
            ['scientific_name', 'latitude', 'longitude', 'primary_label']
        ].values,
    )
}
# Feature extraction used by the classifier
def extract_features(file_path):
    """Return the time-averaged MFCC vector of an audio file.

    The file is decoded with librosa's default loading settings, 40 MFCCs
    are computed per frame, and the coefficients are averaged over all
    frames, yielding one value per coefficient.

    Args:
        file_path: Path of the audio file to read.

    Returns:
        1-D NumPy array with 40 entries (one mean per MFCC coefficient).
    """
    signal, _sample_rate = librosa.load(file_path)
    coefficients = librosa.feature.mfcc(y=signal, n_mfcc=40)
    # Transpose to (frames, coefficients) and average across the frames.
    per_frame = coefficients.T
    return np.mean(per_frame, axis=0)
# Define the prediction function
def predict_bird(audio_file, sample_name):
    """Predict the bird species in a recording and build the result plots.

    Args:
        audio_file: Path of the uploaded recording, or None when nothing
            was uploaded.
        sample_name: Key of ``sample_audio_files`` picked in the dropdown,
            or None. A recognised sample name takes precedence over the
            uploaded file.

    Returns:
        A 4-tuple ``(scientific_name, confidence_text, map_figure,
        spectrogram_figure)`` matching the four Gradio outputs.

    Raises:
        gr.Error: If neither an upload nor a sample was provided, or if the
            predicted label has no entry in ``class_mapping_dict``.
    """
    # A selected sample overrides whatever was uploaded.
    audio_file = sample_audio_files.get(sample_name, audio_file)

    # Without this guard a missing input reaches librosa as None and fails
    # with an error message that is meaningless to the user.
    if not audio_file:
        raise gr.Error("Please upload a bird sound or select a sample.")

    # Extract features from the audio file (the model expects a 2-D batch)
    features = extract_features(audio_file).reshape(1, -1)

    # Predict the bird species
    prediction = model.predict(features)[0]

    # Retrieve bird information directly from the preprocessed dictionary
    try:
        bird_info = class_mapping_dict[prediction]
    except KeyError:
        raise gr.Error(
            f"No species information found for predicted label {prediction!r}."
        ) from None
    predicted_bird = bird_info[0]

    # Calculate prediction confidence. predict_proba columns follow the
    # order of model.classes_, so the predicted label must be translated to
    # its column position instead of being used as an index directly.
    probabilities = model.predict_proba(features)[0]
    classes = getattr(model, "classes_", None)
    if classes is None:
        # No class listing available: fall back to treating the label as
        # the column index.
        class_index = prediction
    else:
        class_index = np.flatnonzero(np.asarray(classes) == prediction)[0]
    confidence = probabilities[class_index] * 100

    # Create a single-row DataFrame for plotting the recording location
    tmp = pd.DataFrame(
        [bird_info[1:]],
        columns=['latitude', 'longitude', 'primary_label'],
    )

    # Create a scatter mapbox plot
    fig = px.scatter_mapbox(
        tmp,
        lat="latitude",
        lon="longitude",
        color="primary_label",
        zoom=10,
        title='Bird Recordings Location',
        mapbox_style="open-street-map",
    )
    fig.update_layout(margin={"r": 0, "t": 30, "l": 0, "b": 0})

    # Generate a mel spectrogram in decibels relative to its peak
    audio, sr = librosa.load(audio_file)
    spectrogram = librosa.feature.melspectrogram(y=audio, sr=sr)
    spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max)
    # origin='lower' puts the lowest frequency band at the bottom; the
    # default image origin would draw the spectrogram upside down.
    fig_spectrogram = px.imshow(
        spectrogram_db,
        color_continuous_scale='viridis',
        origin='lower',
    )
    fig_spectrogram.update_layout(
        title='Spectrogram',
        xaxis_title='Time',
        yaxis_title='Frequency',
    )

    return predicted_bird, f"Confidence: {confidence:.2f}%", fig, fig_spectrogram
# Bundled example recordings, keyed by the label shown in the dropdown.
# Each value is the path of the clip inside the 'sounds' directory.
sample_audio_files = {
    f"Audio {position}": os.path.join('sounds', clip_name)
    for position, clip_name in enumerate(
        ('asbfly.ogg', 'bkwsti.ogg', 'comros.ogg'),
        start=1,
    )
}
# Build the web UI: two inputs (an uploaded recording or a bundled sample)
# feed predict_bird, whose four return values fill the four outputs in order.
iface = gr.Interface(
    title="Bird ID: Identify Bird Species from Audio Recordings",
    description="Upload an audio recording of a bird or select a sample to identify the species!",
    fn=predict_bird,
    inputs=[
        # Passed to predict_bird as a path on disk
        gr.Audio(type="filepath", label="Upload Bird Sound"),
        # Nothing is pre-selected, so an upload is used unless a sample is picked
        gr.Dropdown(
            choices=list(sample_audio_files),
            label="Or Select a Sample",
            value=None,
        ),
    ],
    outputs=[
        gr.Textbox(label="Predicted Bird"),
        gr.Textbox(label="Prediction Confidence"),
        gr.Plot(label="Bird Recordings Location"),
        gr.Plot(label="Spectrogram"),
    ],
)

# Start serving the interface
iface.launch()