# BirdCLEF-2024 / app.py
import gradio as gr
import numpy as np
import pandas as pd
import joblib
import librosa
import plotly.express as px
import os
# Load the model and class mapping data once at the beginning
model = joblib.load('model.joblib')
class_mapping_data = pd.read_csv('data.csv')
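
# NOTE: model.joblib is assumed to be a scikit-learn-style classifier exposing
# predict() and predict_proba(); data.csv is assumed to provide at least the columns
# 'encoded_label', 'scientific_name', 'latitude', 'longitude', 'primary_label'.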
# Preprocess the class mapping data for faster lookups
class_mapping_dict = dict(zip(class_mapping_data['encoded_label'],
class_mapping_data[['scientific_name', 'latitude', 'longitude', 'primary_label']].values))

# Define the feature extraction function
def extract_features(file_path):
    # Load the audio at librosa's default sample rate (22050 Hz)
    audio, sr = librosa.load(file_path)
    # Compute 40 MFCCs and average over time to get a fixed-length feature vector
    mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=40)
    return np.mean(mfccs.T, axis=0)

# Define the prediction function
def predict_bird(audio_file, sample_name):
    # If a sample was selected, use it instead of an uploaded file
    if sample_name in sample_audio_files:
        audio_file = sample_audio_files[sample_name]
    if audio_file is None:
        raise gr.Error("Please upload an audio file or select a sample.")

    # Extract features from the audio file
    features = extract_features(audio_file)
    features = features.reshape(1, -1)

    # Predict the bird species
    prediction = model.predict(features)[0]

    # Retrieve bird information directly from the preprocessed dictionary
    bird_info = class_mapping_dict[prediction]
    predicted_bird = bird_info[0]

    # Calculate prediction confidence: probability of the predicted class,
    # looked up via the model's class ordering
    probabilities = model.predict_proba(features)[0]
    confidence = probabilities[list(model.classes_).index(prediction)] * 100

    # Create a DataFrame for plotting the recording location
    tmp = pd.DataFrame([bird_info[1:]], columns=['latitude', 'longitude', 'primary_label'])
    # Ensure lat/lon are numeric (the mapping values come from a mixed-type array)
    tmp[['latitude', 'longitude']] = tmp[['latitude', 'longitude']].astype(float)

    # Create a scatter mapbox plot
    fig = px.scatter_mapbox(
        tmp,
        lat="latitude",
        lon="longitude",
        color="primary_label",
        zoom=10,
        title='Bird Recordings Location',
        mapbox_style="open-street-map"
    )
    fig.update_layout(margin={"r": 0, "t": 30, "l": 0, "b": 0})

    # Generate a mel spectrogram of the input audio
    audio, sr = librosa.load(audio_file)
    spectrogram = librosa.feature.melspectrogram(y=audio, sr=sr)
    spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max)
    fig_spectrogram = px.imshow(
        spectrogram_db,
        origin='lower',  # low frequencies at the bottom
        aspect='auto',
        color_continuous_scale='viridis'
    )
    fig_spectrogram.update_layout(title='Spectrogram', xaxis_title='Time (frames)', yaxis_title='Mel frequency bin')

    return predicted_bird, f"Confidence: {confidence:.2f}%", fig, fig_spectrogram

# Define sample audio files with full paths
sample_audio_files = {
    "Audio 1": os.path.join('sounds', 'asbfly.ogg'),
    "Audio 2": os.path.join('sounds', 'bkwsti.ogg'),
    "Audio 3": os.path.join('sounds', 'comros.ogg'),
}
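# NOTE: these sample recordings are expected to ship with the app in a local 'sounds/' directory.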

# Create Gradio interface
iface = gr.Interface(
    fn=predict_bird,
    inputs=[
        gr.Audio(type="filepath", label="Upload Bird Sound"),
        gr.Dropdown(choices=list(sample_audio_files.keys()), label="Or Select a Sample", value=None),
    ],
    outputs=[
        gr.Textbox(label="Predicted Bird"),
        gr.Textbox(label="Prediction Confidence"),
        gr.Plot(label="Bird Recordings Location"),
        gr.Plot(label="Spectrogram"),
    ],
    title="Bird ID: Identify Bird Species from Audio Recordings",
    description="Upload an audio recording of a bird or select a sample to identify the species!",
)

# Launch the Gradio interface
iface.launch()