|
import os
from typing import Iterable

import gradio as gr
import librosa
import librosa.display
import matplotlib.pyplot as plt
import noisereduce as nr
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from transformers import ASTFeatureExtractor

from classpred import predict_class
from config import Config
from model import BirdAST

FEATURE_EXTRACTOR = ASTFeatureExtractor() |
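# With default settings the AST extractor expects 16 kHz audio and produces
# 128-bin log-mel features padded/truncated to 1024 frames.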
|
|
|
def plot_mel(sr, x):
    """Plot a mel spectrogram of the selected audio segment."""
    mel_spec = librosa.feature.melspectrogram(y=x, sr=sr, n_mels=128, fmax=10000)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Normalize to [0, 1] purely for display purposes.
    mel_spec_db = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min())

    fig, ax = plt.subplots(nrows=1, ncols=1)
    librosa.display.specshow(mel_spec_db, sr=sr, x_axis='time', y_axis='mel', fmin=0, fmax=10000, ax=ax)
    return fig
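# fmax=10000 restricts the display to the band where most bird vocalizations
# are expected to fall.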
|
|
|
def plot_wave(sr, x):
    """Plot the original waveform next to a noise-reduced version."""
    # noisereduce attenuates stationary background noise via spectral gating.
    ry = nr.reduce_noise(y=x, sr=sr)

    fig, ax = plt.subplots(2, 1, figsize=(12, 8))

    librosa.display.waveshow(x, sr=sr, ax=ax[0])
    ax[0].set(title='Original Waveform')
    ax[0].set_xlabel('Time (s)')
    ax[0].set_ylabel('Amplitude')

    librosa.display.waveshow(ry, sr=sr, ax=ax[1])
    ax[1].set(title='Noise-Reduced Waveform')
    ax[1].set_xlabel('Time (s)')
    ax[1].set_ylabel('Amplitude')

    plt.tight_layout()
    return fig
|
|
|
def predict(audio, start, end):
    sr, x = audio

    # Validate the requested window before any slicing happens.
    if start >= end:
        raise gr.Error(f"`start` ({start}s) must be smaller than `end` ({end}s)")

    # Gradio delivers 16-bit integer PCM; rescale to floats in [-1, 1].
    x = np.array(x, dtype=np.float32) / 32768.0

    # Collapse stereo to mono if needed so downstream processing sees a 1-D signal.
    if x.ndim > 1:
        x = x.mean(axis=1)

    duration = x.shape[0] / sr
    if start > duration:
        raise gr.Error(f"`start` ({start}s) must be smaller than the audio duration ({duration:.0f}s)")
    if end > duration:
        end = duration

    # Cut out the requested segment before featurizing and plotting.
    x = x[int(start * sr):int(end * sr)]
    res = preprocess_for_inference(x, sr)

    fig_mel = plot_mel(sr, x)
    fig_wave = plot_wave(sr, x)

    # Figure order matches the `outputs` wiring of the Predict button below.
    return predict_class(x, sr, start, end), res, fig_wave, fig_mel
|
|
|
def download_model(url, model_path):
    """Download a model checkpoint unless it is already cached locally."""
    if not os.path.exists(model_path):
        response = requests.get(url)
        response.raise_for_status()
        with open(model_path, 'wb') as f:
            f.write(response.content)
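# For very large checkpoints it may be worth streaming instead
# (requests.get(url, stream=True) plus response.iter_content) so the whole
# file is never held in memory; the simple approach above is fine for these folds.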
|
|
|
|
|
# Five GroupKFold checkpoints are ensembled at inference time.
model_urls = [f'https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_fold_{i}.pth' for i in range(5)]
model_paths = [f'BirdAST_Baseline_GroupKFold_fold_{i}.pth' for i in range(5)]

for model_url, model_path in zip(model_urls, model_paths):
    download_model(model_url, model_path)

eval_models = [BirdAST(Config().backbone_name, Config().n_classes, n_mlp_layers=1, activation='silu') for _ in range(5)]
for model, model_path in zip(eval_models, model_paths):
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

# Map model output indices to scientific species names.
label_mapping = pd.read_csv('BirdAST_Baseline_GroupKFold_label_map.csv')
species_id_to_name = {row['species_id']: row['scientific_name'] for _, row in label_mapping.iterrows()}
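# The label map CSV is expected to provide one row per class, with
# `species_id` and `scientific_name` columns matching the model's output indices.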
|
|
|
def preprocess_for_inference(audio_arr, sr):
    """Featurize `audio_arr` and return the ensemble's top-10 species predictions."""
    # ASTFeatureExtractor only accepts its configured sampling rate (16 kHz),
    # so resample first if the input arrives at a different rate.
    if sr != 16000:
        audio_arr = librosa.resample(audio_arr, orig_sr=sr, target_sr=16000)
        sr = 16000

    spec = FEATURE_EXTRACTOR(audio_arr, sampling_rate=sr, padding="max_length", return_tensors="pt")
    input_values = spec['input_values']

    # Average the per-fold softmax scores into a single ensemble prediction.
    model_outputs = []
    with torch.no_grad():
        for model in eval_models:
            output = model(input_values)
            model_outputs.append(F.softmax(output['logits'], dim=1))

    avg_predictions = torch.mean(torch.cat(model_outputs), dim=0)
    topk_values, topk_indices = torch.topk(avg_predictions, 10)

    # Map class indices back to scientific names and express scores in percent.
    results = []
    for idx, score in zip(topk_indices, topk_values):
        results.append([species_id_to_name[idx.item()], score.item() * 100])

    return results
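# Illustrative return value (scores are made up):
# [['Pionus fuscus', 37.2], ['Cissopis leveriana', 12.8], ...]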
|
|
|
DESCRIPTION = """
# Bird audio classification using SOTA Voice of Jungle Technology

# Introduction

It is estimated that 50% of the global economy is threatened by biodiversity loss. Because birds are a top indicator of regional biodiversity, considerable effort has gone into estimating bird biodiversity, including identifying the bird species present in a region through audio classification.

The left table shows the predicted type of sound (class), while the right table shows the predicted bird species. If the class prediction does not output "bird", the species prediction is consequently not confident.
"""
|
|
|
|
|
css = """
.number-input {
    height: 100%;
    padding-bottom: 60px; /* Adjust the value as needed for more or less space */
}
.full-height {
    height: 100%;
}
.column-container {
    height: 100%;
}
"""
|
|
|
|
|
|
|
|
|
class Seafoam(Base):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.emerald,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
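# Seafoam only re-parameterizes Gradio's Base theme defaults; every argument
# is forwarded unchanged to Base.__init__.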
|
|
|
|
|
seafoam = Seafoam() |
|
|
|
|
|
|
|
with gr.Blocks(theme=seafoam, css=css) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time", value=10, elem_classes="number-input full-height")
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")

    with gr.Row():
        raw_class_output = gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label="Class Prediction")
        species_output = gr.Dataframe(headers=["Species", "Score [%]"], row_count=10, label="Species Prediction")

    with gr.Row():
        waveform_output = gr.Plot(label="Waveform")
        spectrogram_output = gr.Plot(label="Spectrogram")

    gr.Examples(
        examples=[
            ["312_Cissopis_leverinia_1.wav", 0, 5],
            ["1094_Pionus_fuscus_2.wav", 0, 10],
        ],
        inputs=[audio_input, start_time_input, end_time_input],
    )

    gr.Button("Predict").click(
        predict,
        inputs=[audio_input, start_time_input, end_time_input],
        outputs=[raw_class_output, species_output, waveform_output, spectrogram_output],
    )
|
|
|
demo.launch(share=True)