import json
import os
import random

import gradio as gr
import librosa
import librosa.display
import matplotlib.pyplot as plt
import noisereduce as nr
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from torchaudio.compliance import kaldi
from torchaudio.functional import resample
from transformers import ASTFeatureExtractor

from config import Config
from model import BirdAST
# Shared AST feature extractor, constructed with its default settings, that
# turns raw waveforms into the `input_values` tensor fed to every BirdAST model.
FEATURE_EXTRACTOR = ASTFeatureExtractor()
def plot_mel(sr, x):
    """Render a min-max normalised mel spectrogram of ``x``.

    Args:
        sr: Sample rate of ``x`` in Hz.
        x: Audio samples; expected to be a mono 1-D float array.

    Returns:
        A ``matplotlib.figure.Figure`` holding a single mel-spectrogram axis.
    """
    # Never request mel bands above Nyquist: with sr < 20 kHz a fixed 10 kHz
    # ceiling produces empty mel filters.
    fmax = min(10000, sr / 2)
    mel_spec = librosa.feature.melspectrogram(y=x, sr=sr, n_mels=128, fmax=fmax)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Scale to [0, 1]. A perfectly flat spectrogram (e.g. digital silence) has
    # max == min, which would otherwise turn every value into 0/0 = NaN.
    db_min = mel_spec_db.min()
    db_range = mel_spec_db.max() - db_min
    if db_range > 0:
        mel_spec_db = (mel_spec_db - db_min) / db_range
    else:
        mel_spec_db = np.zeros_like(mel_spec_db)

    # Build the figure outside pyplot's global figure registry so a
    # long-running server does not keep every returned figure alive.
    fig = Figure()
    ax = fig.subplots()
    librosa.display.specshow(
        mel_spec_db, sr=sr, x_axis='time', y_axis='mel', fmin=0, fmax=fmax, ax=ax
    )
    return fig
def plot_wave(sr, x):
    """Plot the original and the noise-reduced waveform one above the other.

    Args:
        sr: Sample rate of ``x`` in Hz.
        x: Audio samples; expected to be a mono 1-D float array.

    Returns:
        A ``matplotlib.figure.Figure`` with two stacked axes: the original
        signal on top and the ``noisereduce`` output below.
    """
    reduced = nr.reduce_noise(y=x, sr=sr)

    # Stand-alone Figure instead of plt.subplots(): avoids leaking figures into
    # pyplot's global registry and lets us lay out *this* figure explicitly,
    # rather than whichever one is "current" when requests run concurrently.
    fig = Figure(figsize=(12, 8))
    ax_original, ax_reduced = fig.subplots(2, 1)
    panels = (
        (ax_original, x, 'Original Waveform'),
        (ax_reduced, reduced, 'Noise Reduced Waveform'),
    )
    for ax, signal, title in panels:
        librosa.display.waveshow(signal, sr=sr, ax=ax)
        ax.set(title=title, xlabel='Time (s)', ylabel='Amplitude')
    fig.tight_layout()
    return fig
def predict(audio, start, end):
    """Classify the ``[start, end)`` window of a clip and plot that window.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio``. Samples are
            integer PCM (e.g. int16) or float, shaped ``(n,)`` or
            ``(n, channels)``; multi-channel audio is averaged down to mono.
        start: Window start in seconds (``gr.Number`` delivers a float).
        end: Window end in seconds; clamped to the clip duration.

    Returns:
        ``(class_rows, species_rows, waveform_figure, spectrogram_figure)``,
        matching the order of the Predict button's output components.

    Raises:
        gr.Error: if audio or a time bound is missing, ``start`` is negative,
            ``start >= end``, ``start`` lies past the end of the clip, or the
            window contains no samples.
    """
    # Validate cheap user inputs before doing any decoding or inference.
    if audio is None:
        raise gr.Error("Please upload or record an audio clip first")
    if start is None or end is None:
        raise gr.Error("Please provide both a start and an end time")
    if start < 0:
        raise gr.Error(f"`start` ({start}) must not be negative")
    if start >= end:
        raise gr.Error(f"`start` ({start}) must be smaller than end ({end}s)")

    sr, x = audio
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.integer):
        # Integer PCM: divide by the dtype's full scale so samples fall in
        # [-1, 1) (int16 -> 1/32768, int32 -> 1/2**31).
        x = x.astype(np.float32) / (float(np.iinfo(x.dtype).max) + 1.0)
    else:
        x = x.astype(np.float32)
    if x.ndim > 1:
        # The feature extractor would read an (n, channels) array as a batch
        # of n tiny clips, so down-mix to mono first.
        x = x.mean(axis=1)

    # Check the window against the *full* clip, before slicing.
    duration = x.shape[0] / sr
    if start >= duration:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({duration:.0f}s)")
    end = min(end, duration)

    # Slice indices must be ints; gr.Number yields floats.
    x = x[int(start * sr):int(end * sr)]
    if x.size == 0:
        raise gr.Error("The selected time window contains no audio samples")

    res = preprocess_for_inference(x, sr)
    waveform_fig = plot_wave(sr, x)
    spectrogram_fig = plot_mel(sr, x)
    # NOTE(review): the same top-10 list feeds both the "Class" and "Species"
    # tables - confirm whether a separate class-level prediction was intended.
    return res, res, waveform_fig, spectrogram_fig
def download_model(url, model_path, timeout=60):
    """Download ``url`` to ``model_path`` unless that file already exists.

    The response is streamed in chunks into ``<model_path>.part`` and renamed
    into place only after the transfer completes. An interrupted download
    therefore never leaves a truncated file at ``model_path``, which the
    existence check would otherwise accept as a finished checkpoint on the
    next start.

    Args:
        url: Remote location of the file.
        model_path: Local destination path (str or path-like).
        timeout: Seconds to wait on the connection / between received bytes;
            passed through to ``requests.get``.

    Raises:
        requests.HTTPError: if the server replies with an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    if os.path.exists(model_path):
        return
    tmp_path = os.fspath(model_path) + '.part'
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Atomic rename: model_path only ever appears fully written.
        os.replace(tmp_path, model_path)
    finally:
        # Clean up a partial download; after a successful replace this is a no-op.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
# Number of cross-validation fold checkpoints that are ensembled at inference.
N_FOLDS = 5
MODEL_REPO_URL = 'https://huggingface.co/shiyi-li/BirdAST/resolve/main'

# Local file names double as the remote file names inside the repo.
model_paths = [f'BirdAST_Baseline_5folds_fold_{i}.pth' for i in range(N_FOLDS)]
model_urls = [f'{MODEL_REPO_URL}/{path}' for path in model_paths]
for model_url, model_path in zip(model_urls, model_paths):
    download_model(model_url, model_path)

# Build one model per fold, load its weights and switch it to eval mode.
# Each checkpoint is loaded right before use and not kept around afterwards,
# so only the model parameters (not a second copy of every state dict) stay
# resident for the lifetime of the app.
cfg = Config()
eval_models = []
for model_path in model_paths:
    fold_model = BirdAST(cfg.backbone_name, cfg.n_classes, n_mlp_layers=1, activation='silu')
    # SECURITY NOTE(review): torch.load unpickles the file, and these files
    # come from the network. Consider weights_only=True if the checkpoints
    # are plain tensor state dicts.
    fold_model.load_state_dict(torch.load(model_path, map_location='cpu'))
    fold_model.eval()
    eval_models.append(fold_model)

# species_id -> scientific name, used to label the model's output indices.
label_mapping = pd.read_csv('BirdAST_Baseline_5folds_label_map.csv')
species_id_to_name = dict(zip(label_mapping['species_id'], label_mapping['scientific_name']))
def preprocess_for_inference(audio_arr, sr):
    """Run the fold ensemble on one clip and return its top species.

    Args:
        audio_arr: Mono float waveform (1-D array).
        sr: Sample rate of ``audio_arr`` in Hz. If it differs from the
            feature extractor's rate, the clip is resampled first.

    Returns:
        Up to 10 ``[scientific_name, probability]`` pairs, highest first.
        ``probability`` is the softmax score averaged over ``eval_models``.
    """
    target_sr = FEATURE_EXTRACTOR.sampling_rate
    if sr != target_sr:
        # ASTFeatureExtractor raises ValueError for any other sampling_rate,
        # so bring the clip to the expected rate instead of failing.
        audio_arr = resample(
            torch.as_tensor(audio_arr, dtype=torch.float32),
            orig_freq=int(sr),
            new_freq=int(target_sr),
        ).numpy()

    # NOTE(review): padding="max_length" pads/truncates to the extractor's
    # fixed frame count, so only the leading part of long clips is scored.
    spec = FEATURE_EXTRACTOR(audio_arr, sampling_rate=target_sr, padding="max_length", return_tensors="pt")
    input_values = spec['input_values']

    with torch.no_grad():
        # One (1, n_classes) probability row per fold model.
        model_outputs = [F.softmax(model(input_values)['logits'], dim=1) for model in eval_models]
    # Stack the folds along dim 0 and average them into one score per class.
    avg_predictions = torch.mean(torch.cat(model_outputs), dim=0)

    k = min(10, avg_predictions.numel())
    topk_values, topk_indices = torch.topk(avg_predictions, k)
    # NOTE(review): assumes model output index == species_id in the label CSV.
    return [
        [species_id_to_name[idx.item()], score.item()]
        for idx, score in zip(topk_indices, topk_values)
    ]
# Markdown blurb rendered under the page title.
DESCRIPTION = """
Bird audio classification using SOTA Voice of Jungle Technology.
"""
# Custom stylesheet passed to gr.Blocks; the class names below are attached to
# the time inputs, their column and the audio widget via `elem_classes`.
css = """
.number-input {
height: 100%;
padding-bottom: 60px; /* Adust the value as needed for more or less space */
}
.full-height {
height: 100%;
}
.column-container {
height: 100%;
}
"""
# Gradio UI: time window + audio input on top, two prediction tables below,
# then the waveform and spectrogram plots. Components appear on the page in
# the order they are created inside these `with` blocks.
with gr.Blocks(css = css) as demo:
    gr.Markdown("# Bird Species Audio Classification")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Analysis window in seconds, forwarded to predict() as start / end.
        with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time", value=1, elem_classes="number-input full-height")
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")
    with gr.Row():
        raw_class_output = gr.Dataframe(headers=["class", "score"], row_count=10, label="Class Prediction")
        species_output = gr.Dataframe(headers=["class", "score"], row_count=10, label="Species Prediction")
    with gr.Row():
        waveform_output = gr.Plot(label="Waveform")
        spectrogram_output = gr.Plot(label="Spectrogram")
    # Preset rows of (audio file, start seconds, end seconds) for the inputs above.
    gr.Examples(
        examples=[
            ["312_Cissopis_leverinia_1.wav", 0, 5],
            ["1094_Pionus_fuscus_2.wav", 0, 10],
        ],
        inputs=[audio_input, start_time_input, end_time_input]
    )
    # The four output components are filled, in order, from predict's return
    # tuple. NOTE(review): predict must return the waveform figure before the
    # spectrogram figure to match this list.
    gr.Button("Predict").click(predict, [audio_input, start_time_input, end_time_input], [raw_class_output, species_output, waveform_output, spectrogram_output])
# NOTE(review): share=True asks Gradio for a public share link - confirm that
# exposing the app publicly is intended.
demo.launch(share = True)