electroma's picture
Update tasks/audio.py
ee725de verified
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
import pickle
import xgboost
import random
import os
from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info
from .utils.preprocess import resample_audio, create_mel_spectrogram
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment
# (e.g. HF_TOKEN, which is read when fetching the gated dataset).
load_dotenv()
router = APIRouter()
# Human-readable model name: used as the route description and reported as
# "model_description" in every evaluation result, so it must name the model
# that is actually run.
DESCRIPTION = "XGBoost on mel spectrograms"
ROUTE = "/audio"
# Sampling rate (Hz) every test clip is resampled to before feature extraction.
_TARGET_SR = 12000
# Spectrogram width (number of frames, i.e. shape[1]) the pickled model was
# fitted on; clips producing any other width cannot be fed to it and are
# excluded from evaluation.
_EXPECTED_N_FRAMES = 71
# Predicted probabilities strictly above this value map to label 1.
_DECISION_THRESHOLD = 0.5
# Repository-local serialized model used for inference.
_MODEL_PATH = "./train_models/xgboost_audio_model.pkl"


def _prepare_test_frame(test_split, target_sr=_TARGET_SR, n_frames=_EXPECTED_N_FRAMES):
    """Resample test clips, compute mel spectrograms and keep the usable rows.

    Args:
        test_split: the dataset's "test" split. Each row is expected to carry
            an "audio" dict with "array" and "sampling_rate" keys, plus a
            "label" column (assumed from how the split is consumed).
        target_sr: sampling rate passed to ``resample_audio`` and then to
            ``create_mel_spectrogram``.
        n_frames: required spectrogram width (``shape[1]``); rows whose
            spectrogram has another width are dropped.

    Returns:
        ``(test_df, n_dropped)``: a DataFrame of the retained rows with a
        "basic_melspect" column holding each clip's spectrogram, and the
        number of rows removed by the width filter.
    """
    test_df = pd.DataFrame(test_split)
    resampled = [
        resample_audio(audio["array"], audio["sampling_rate"], target_sr=target_sr)
        for audio in test_df["audio"]
    ]
    test_df["basic_melspect"] = [
        create_mel_spectrogram(signal, target_sr) for signal in resampled
    ]
    keep = test_df["basic_melspect"].map(lambda spec: spec.shape[1]) == n_frames
    return test_df[keep], int((~keep).sum())


@router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """
    Evaluate audio classification for rainforest sound detection.

    Current model: a pickled XGBoost model loaded from ``_MODEL_PATH``, applied
    to the flattened mel spectrograms of the test clips after resampling to
    ``_TARGET_SR`` Hz. Probabilities above ``_DECISION_THRESHOLD`` are mapped to
    label 1 (labels: 0 = chainsaw, 1 = environment).

    Only test clips whose spectrogram has exactly ``_EXPECTED_N_FRAMES`` frames
    are scored; the number of excluded clips is printed.

    Energy use and emissions are tracked only around the inference section
    (model load, feature flattening, prediction). The tracking task is always
    closed, even if inference raises.

    Raises:
        ValueError: if no test clip survives the frame-count filter.
    """
    # NOTE(review): dataset download and model loading are blocking calls inside
    # an async endpoint; they stall the event loop while they run.
    print("start audio")
    username, space_url = get_space_info()
    print(username)
    print(space_url)

    # The dataset is gated, so authenticate with the HF_TOKEN environment variable.
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
    test_df, n_dropped = _prepare_test_frame(dataset["test"])
    if n_dropped:
        print(
            f"audio: excluded {n_dropped} test clips whose spectrogram width "
            f"is not {_EXPECTED_N_FRAMES} frames"
        )
    # Fail before starting the tracker so it is never left with an open task.
    if test_df.empty:
        raise ValueError(
            f"No test clip produced a mel spectrogram with {_EXPECTED_N_FRAMES} "
            "frames; nothing to evaluate."
        )

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")
    try:
        #--------------------------------------------------------------------------------------------
        # YOUR MODEL INFERENCE CODE HERE
        #--------------------------------------------------------------------------------------------
        # The pickle is a trusted repository artifact; never point this at
        # user-supplied files, since unpickling can execute arbitrary code.
        with open(_MODEL_PATH, "rb") as f:
            loaded_model = pickle.load(f)

        # Flatten each 2-D spectrogram into one feature row. Built as a plain
        # list to avoid writing a new column into the filtered DataFrame slice.
        X = np.stack([spec.flatten() for spec in test_df["basic_melspect"]])
        y = test_df["label"].values  # 0 = chainsaw, 1 = environment

        dtest = xgboost.DMatrix(X, label=y)
        y_pred_probs = loaded_model.predict(dtest)
        y_pred = (y_pred_probs > _DECISION_THRESHOLD).astype(int)
        #--------------------------------------------------------------------------------------------
        # YOUR MODEL INFERENCE STOPS HERE
        #--------------------------------------------------------------------------------------------
    finally:
        # Always close the task so the shared tracker is not left mid-task
        # when inference fails.
        emissions_data = tracker.stop_task()

    accuracy = accuracy_score(y, y_pred)

    return {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }