"""Gradio demo that classifies Indian bird species from an uploaded audio clip.

A fine-tuned Wav2Vec2 sequence-classification model is loaded from the
Hugging Face Hub at import time; ``predict`` returns the top-scoring species
for a single audio file and ``demo`` wraps it in a simple upload UI.
"""
import gradio as gr
import librosa
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

model_name = "greenarcade/wav2vec2-vd-bird-sound-classification"

model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

# Resample to whatever rate the extractor was configured with, so the two
# can never disagree.
SAMPLE_RATE = feature_extractor.sampling_rate
# Clips longer than this are truncated to their first MAX_SECONDS seconds.
MAX_SECONDS = 5
# Number of species returned by predict() and shown in the Label widget.
TOP_K = 3


def predict(audio_file):
    """Classify the bird species heard in an audio file.

    Args:
        audio_file: Path to the uploaded audio file (Gradio ``type="filepath"``),
            or ``None`` when the user submits without uploading anything.

    Returns:
        Dict mapping up to ``TOP_K`` species labels to their softmax
        probabilities (plain floats), as consumed by ``gr.Label``.

    Raises:
        gr.Error: If no file was provided or the decoded audio is empty.
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file.")

    # librosa decodes MP3/WAV, mixes down to mono (default mono=True) and
    # resamples to SAMPLE_RATE in a single call.
    audio, _ = librosa.load(audio_file, sr=SAMPLE_RATE)
    if audio.size == 0:
        raise gr.Error("The uploaded audio file contains no samples.")

    inputs = feature_extractor(
        audio,
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=SAMPLE_RATE * MAX_SECONDS,
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # Take the single batch item explicitly; squeeze() would also drop the
    # class axis for a one-label model and make tolist() return a bare float.
    probs = torch.softmax(logits, dim=-1)[0]
    top_probs, top_ids = torch.topk(probs, k=min(TOP_K, probs.numel()))

    # gr.Label expects raw float confidences rather than formatted strings.
    return {
        model.config.id2label[idx]: prob
        for idx, prob in zip(top_ids.tolist(), top_probs.tolist())
    }


demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(sources=["upload"], type="filepath"),
    outputs=gr.Label(num_top_classes=3),
    title="🦜 Bird Sound Classifier (Indian birds)",
    description="Upload a 5-second audio clip to identify bird species",
    examples=[["greyheron-sample.wav"], ["blue-tail-sample.mp3"]],
)

# Launch only when run as a script (as Hugging Face Spaces does), so the
# module can be imported without starting a web server.
if __name__ == "__main__":
    demo.launch()