import os

import gradio as gr
import joblib
import librosa
import numpy as np
import pandas as pd
import plotly.express as px

# Load the model and class-mapping data once, at import time, so every
# request reuses them instead of re-reading from disk.
model = joblib.load('model.joblib')
class_mapping_data = pd.read_csv('data.csv')

# Map encoded_label -> [scientific_name, latitude, longitude, primary_label]
# for O(1) lookups.
# NOTE: dict(zip(...)) keeps only the LAST row for an encoded_label that
# appears more than once in data.csv, so one location is plotted per species.
class_mapping_dict = dict(zip(
    class_mapping_data['encoded_label'],
    class_mapping_data[['scientific_name', 'latitude', 'longitude', 'primary_label']].values,
))

# Bundled example recordings offered in the dropdown. The paths are relative
# to the current working directory, like model.joblib and data.csv above.
sample_audio_files = {
    "Audio 1": os.path.join('sounds', 'asbfly.ogg'),
    "Audio 2": os.path.join('sounds', 'bkwsti.ogg'),
    "Audio 3": os.path.join('sounds', 'comros.ogg'),
}


def _mfcc_features(audio, sr):
    """Return 40 MFCCs of ``audio`` averaged over time frames (shape ``(40,)``)."""
    mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=40)
    return np.mean(mfccs.T, axis=0)


def extract_features(file_path):
    """Load ``file_path`` with librosa's default resampling and return its mean-MFCC vector.

    Args:
        file_path: Path to an audio file readable by ``librosa.load``.

    Returns:
        1-D numpy array of 40 time-averaged MFCC coefficients.
    """
    audio, sr = librosa.load(file_path)
    return _mfcc_features(audio, sr)


def _resolve_audio_path(audio_file, sample_name):
    """Choose the recording to analyse. A selected sample takes precedence over an upload."""
    if sample_name in sample_audio_files:
        return sample_audio_files[sample_name]
    if audio_file is None:
        # Without this check librosa fails with an opaque error on None.
        raise gr.Error("Please upload a bird recording or select a sample.")
    return audio_file


def _prediction_confidence(features, prediction):
    """Return the model's probability for ``prediction`` as a percentage.

    ``predict_proba`` columns are ordered like ``model.classes_``, not by
    label value. The column is therefore looked up by label instead of
    indexing with the label itself.
    """
    probabilities = model.predict_proba(features)[0]
    classes = getattr(model, 'classes_', None)
    if classes is None:
        # NOTE(review): the model exposes no classes_; fall back to treating
        # the label as a column index (the previous behaviour). Confirm labels
        # are 0..n-1 for such models.
        column = int(prediction)
    else:
        column = list(classes).index(prediction)
    return probabilities[column] * 100


def _location_figure(bird_info):
    """Build a map of the (latitude, longitude) stored for the predicted species."""
    tmp = pd.DataFrame([bird_info[1:]], columns=['latitude', 'longitude', 'primary_label'])
    fig = px.scatter_mapbox(
        tmp,
        lat="latitude",
        lon="longitude",
        color="primary_label",
        zoom=10,
        title='Bird Recordings Location',
        mapbox_style="open-street-map",
    )
    fig.update_layout(margin={"r": 0, "t": 30, "l": 0, "b": 0})
    return fig


def _spectrogram_figure(audio, sr):
    """Build a dB-scaled mel-spectrogram heatmap of ``audio``."""
    spectrogram = librosa.feature.melspectrogram(y=audio, sr=sr)
    spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max)
    # Row 0 is the lowest mel band. origin='lower' draws it at the bottom;
    # imshow's default 'upper' origin would flip the frequency axis.
    fig = px.imshow(spectrogram_db, color_continuous_scale='viridis', origin='lower')
    fig.update_layout(title='Spectrogram', xaxis_title='Time', yaxis_title='Frequency')
    return fig


def predict_bird(audio_file, sample_name):
    """Classify a bird recording and build the result widgets.

    Args:
        audio_file: Filepath of the uploaded recording, or None.
        sample_name: Key of ``sample_audio_files`` chosen in the dropdown, or
            None. A valid sample overrides the upload.

    Returns:
        A tuple of the predicted scientific name, a confidence string, a
        location map figure and a spectrogram figure.

    Raises:
        gr.Error: If neither an upload nor a sample was provided.
    """
    audio_path = _resolve_audio_path(audio_file, sample_name)

    # Decode once; the same samples feed both the classifier and the spectrogram.
    audio, sr = librosa.load(audio_path)
    features = _mfcc_features(audio, sr).reshape(1, -1)

    prediction = model.predict(features)[0]
    bird_info = class_mapping_dict[prediction]
    predicted_bird = bird_info[0]
    confidence = _prediction_confidence(features, prediction)

    return (
        predicted_bird,
        f"Confidence: {confidence:.2f}%",
        _location_figure(bird_info),
        _spectrogram_figure(audio, sr),
    )


# Create Gradio interface
iface = gr.Interface(
    fn=predict_bird,
    inputs=[
        gr.Audio(type="filepath", label="Upload Bird Sound"),
        gr.Dropdown(choices=list(sample_audio_files.keys()), label="Or Select a Sample", value=None),
    ],
    outputs=[
        gr.Textbox(label="Predicted Bird"),
        gr.Textbox(label="Prediction Confidence"),
        gr.Plot(label="Bird Recordings Location"),
        gr.Plot(label="Spectrogram"),
    ],
    title="Bird ID: Identify Bird Species from Audio Recordings",
    description="Upload an audio recording of a bird or select a sample to identify the species!",
)

# Launch the Gradio interface
iface.launch()