import json
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
from config import Config
from model import BirdAST
import torch
import librosa
# Imported explicitly because the plotting helpers below call
# `librosa.display.*`; a bare `import librosa` is not guaranteed to expose
# the submodule on every librosa release.
import librosa.display
import noisereduce as nr
import timm
from typing import Iterable
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
import time
import pandas as pd
from classpred import predict_class
import torch.nn.functional as F
import random
from torchaudio.compliance import kaldi
from torchaudio.functional import resample
from transformers import ASTFeatureExtractor

#TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
#MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()
#LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
#AUDIOSET_LABELS = list(json.loads(requests.get(LABEL_URL).content).values())

# Shared feature extractor, built once with the library's default settings.
FEATURE_EXTRACTOR = ASTFeatureExtractor()


def plot_mel(sr, x):
    """Return a matplotlib figure showing the mel spectrogram of ``x``.

    Args:
        sr: Sampling rate of ``x`` in Hz.
        x: Audio samples passed as ``y`` to
            ``librosa.feature.melspectrogram``.

    Returns:
        The figure holding the spectrogram (128 mel bands, 0-10 kHz),
        with values scaled to [0, 1].
    """
    mel_spec = librosa.feature.melspectrogram(y=x, sr=sr, n_mels=128, fmax=10000)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Min-max normalise to [0, 1]. A constant spectrogram (e.g. a silent
    # clip) has zero range; dividing by it would fill the plot with NaN,
    # so fall back to an all-zero image instead.
    lowest = mel_spec_db.min()
    value_range = mel_spec_db.max() - lowest
    if value_range > 0:
        mel_spec_db = (mel_spec_db - lowest) / value_range
    else:
        mel_spec_db = np.zeros_like(mel_spec_db)

    fig, ax = plt.subplots(nrows=1, ncols=1, sharex=True)
    librosa.display.specshow(mel_spec_db, sr=sr, x_axis='time', y_axis='mel', fmin = 0, fmax=10000, ax = ax)
    return fig


def plot_wave(sr, x):
    """Return a figure with the original and noise-reduced waveforms.

    Args:
        sr: Sampling rate of ``x`` in Hz.
        x: Audio samples to plot.

    Returns:
        A two-row matplotlib figure: the original waveform on top and the
        output of ``noisereduce.reduce_noise`` below.
    """
    ry = nr.reduce_noise(y=x, sr=sr)

    fig, ax = plt.subplots(2, 1, figsize=(12, 8))

    # Plot the original waveform
    librosa.display.waveshow(x, sr=sr, ax=ax[0])
    ax[0].set(title='Original Waveform')
    ax[0].set_xlabel('Time (s)')
    ax[0].set_ylabel('Amplitude')

    # Plot the noise-reduced waveform
    librosa.display.waveshow(ry, sr=sr, ax=ax[1])
    ax[1].set(title='Noise Reduced Waveform')
    ax[1].set_xlabel('Time (s)')
    ax[1].set_ylabel('Amplitude')

    plt.tight_layout()
    return fig
def predict(audio, start, end):
    """Run class and species prediction on a time window of an audio clip.

    Args:
        audio: ``(sampling_rate, samples)`` pair from the ``gr.Audio`` input.
            Samples are divided by 32768, i.e. assumed to be 16-bit PCM
            (NOTE(review): confirm against the Gradio audio settings).
        start: Window start in seconds.
        end: Window end in seconds; clamped to the clip duration.

    Returns:
        A 4-tuple matching the click handler's output components, in order:
        class predictions from ``predict_class``, species predictions from
        ``preprocess_for_inference``, the waveform figure, and the
        mel-spectrogram figure.

    Raises:
        gr.Error: If no audio or window bounds are supplied, or if the
            window is empty, negative, or starts beyond the clip.
    """
    if audio is None:
        raise gr.Error("Please provide an audio clip before predicting.")
    if start is None or end is None:
        raise gr.Error("Please provide both a start and an end time.")

    sr, x = audio
    x = np.array(x, dtype=np.float32)/32768.0
    if x.ndim > 1:
        # Multi-channel input arrives as (samples, channels); the rest of
        # the pipeline works on a single channel, so average them.
        x = x.mean(axis=1)

    # Validate against the FULL clip before slicing; checking the sliced
    # array would compare the window length with the start offset.
    if start < 0:
        raise gr.Error(f"`start` ({start}) must not be negative")
    if start >= end:
        raise gr.Error(f"`start` ({start}) must be smaller than end ({end}s)")
    if x.shape[0] <= start * sr:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({x.shape[0] / sr:.0f}s)")

    # Never report a window that extends past the end of the clip.
    end = min(end, x.shape[0]/(1.0*sr))

    x = x[int(start*sr) : int(end*sr)]

    res = preprocess_for_inference(x, sr)
    fig_mel = plot_mel(sr, x)
    fig_wave = plot_wave(sr, x)

    # Order must match the outputs wired to the Predict button:
    # [raw_class_output, species_output, waveform_output, spectrogram_output].
    return predict_class(x, sr, start, end), res, fig_wave, fig_mel


def download_model(url, model_path):
    """Download ``url`` to ``model_path`` unless the file already exists.

    The payload is streamed to ``<model_path>.part`` and renamed only after
    the transfer completes, so an interrupted download never leaves a
    truncated file that later runs would treat as a valid checkpoint.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    if os.path.exists(model_path):
        return

    partial_path = f"{model_path}.part"
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()  # Ensure the request was successful
        with open(partial_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(partial_path, model_path)


# Model URL and path
N_FOLDS = 5
model_urls = [f'https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_fold_{i}.pth' for i in range(N_FOLDS)]
model_paths = [f'BirdAST_Baseline_GroupKFold_fold_{i}.pth' for i in range(N_FOLDS)]

for (model_url, model_path) in zip(model_urls, model_paths):
    download_model(model_url, model_path)

# Load the model (assumes you have the model architecture defined)
# NOTE(review): torch.load unpickles the downloaded file; consider
# `weights_only=True` if the installed torch version supports it.
_config = Config()
eval_models = []
for model_path in model_paths:
    fold_model = BirdAST(_config.backbone_name, _config.n_classes, n_mlp_layers=1, activation='silu')
    fold_model.load_state_dict(torch.load(model_path, map_location='cpu'))
    fold_model.eval()  # Set to evaluation mode
    eval_models.append(fold_model)

# Load the species mapping
label_mapping = pd.read_csv('BirdAST_Baseline_GroupKFold_label_map.csv')
species_id_to_name = dict(zip(label_mapping['species_id'], label_mapping['scientific_name']))


def preprocess_for_inference(audio_arr, sr):
    """Predict the most likely bird species for an audio array.

    The clip is converted to model inputs with ``FEATURE_EXTRACTOR``, scored
    by every model in ``eval_models``, and the softmax scores are averaged.

    Args:
        audio_arr: Audio samples (presumably mono float32 — the caller in
            this file passes a 1-D normalised array).
        sr: Sampling rate of ``audio_arr`` in Hz. Audio at a different rate
            than the feature extractor's is resampled first.

    Returns:
        A list of ``[scientific_name, probability_percent]`` pairs for the
        top predictions (at most 10), highest score first.
    """
    # The feature extractor rejects audio whose rate differs from the one
    # it is configured for, so bring the clip to that rate up front.
    target_sr = FEATURE_EXTRACTOR.sampling_rate
    if sr != target_sr:
        audio_arr = librosa.resample(audio_arr, orig_sr=sr, target_sr=target_sr)
        sr = target_sr

    spec = FEATURE_EXTRACTOR(audio_arr, sampling_rate=sr, padding="max_length", return_tensors="pt")
    input_values = spec['input_values']  # Get the input values prepared for model input

    # Accumulate predictions from each model
    model_outputs = []
    with torch.no_grad():
        for model in eval_models:
            output = model(input_values)
            model_outputs.append(F.softmax(output['logits'], dim=1))

    # Average the predictions across all models. Each output has a leading
    # batch dimension of 1 here, so concatenating gives one row per model.
    avg_predictions = torch.mean(torch.cat(model_outputs), dim=0)

    # Get the top predictions based on the average scores; cap k so that
    # fewer than 10 classes does not make topk fail.
    k = min(10, avg_predictions.shape[0])
    topk_values, topk_indices = torch.topk(avg_predictions, k)

    return [
        [species_id_to_name[idx.item()], score.item()*100]
        for idx, score in zip(topk_indices, topk_values)
    ]


DESCRIPTION = """
# Bird audio classification using SOTA Voice of Jungle Technology. \n
# Introduction
It is esimated that 50% of the global economy is threatened by biodiversity loss. As such, efforts have been concerted into estimating bird biodiversity, as birds are a top indicator of biodiversity in the region. One of these efforts is finding the bird species in a region using bird species audio classification. Prediction on left table shows prediction on the type of noise (class), while the right predictions are the species of bird. If class prediction does not output bird, then consequently the species prediction is not confident.
"""

css = """
.number-input {
    height: 100%;
    padding-bottom: 60px; /* Adust the value as needed for more or less space */
}
.full-height {
    height: 100%;
}
.column-container {
    height: 100%;
}
"""


class Seafoam(Base):
    """Gradio theme: ``Base`` with emerald/blue/gray hues and custom fonts."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.emerald,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # All arguments are forwarded unchanged; this subclass only changes
        # the defaults.
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )


seafoam = Seafoam()

## logo: vojlogo
## cactus: spur

with gr.Blocks(theme=seafoam, css = css) as demo:
    #img_src = 'spur'
    #gr.Markdown(f"{img_src}")
    #gr.Markdown(f"# Team Voice of Jungle {img_src} more text")
    gr.Markdown(DESCRIPTION)

    # Inputs: time window on the left, audio clip on the right.
    with gr.Row():
        with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time", value=10, elem_classes="number-input full-height")
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")

    # Prediction tables.
    with gr.Row():
        raw_class_output = gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label="Class Prediction")
        species_output = gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label="Species Prediction")

    # Plots.
    with gr.Row():
        waveform_output = gr.Plot(label="Waveform")
        spectrogram_output = gr.Plot(label="Spectrogram")

    gr.Examples(
        examples=[
            ["312_Cissopis_leverinia_1.wav", 0, 5],
            ["1094_Pionus_fuscus_2.wav", 0, 10],
        ],
        inputs=[audio_input, start_time_input, end_time_input]
    )

    # `predict` returns its values in exactly this output order.
    gr.Button("Predict").click(predict, [audio_input, start_time_input, end_time_input], [raw_class_output, species_output, waveform_output, spectrogram_output])

demo.launch(share = True)