# gender_cls_svm_ecapa_voxceleb/modeling_gender.py
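"""Gender classification pipeline: ECAPA-TDNN speaker embeddings + SVM.

Audio is embedded with SpeechBrain's `spkrec-ecapa-voxceleb` encoder (192-dim
vectors), scaled, and classified by a pre-fitted scikit-learn SVM. As inferred
from `from_pretrained` below, the model directory is expected to contain:

    config.json       - includes a "labels" list (e.g. ["female", "male"])
    svm_model.joblib  - fitted SVM classifier
    scaler.joblib     - fitted feature scaler
"""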
import torch
import torchaudio
import numpy as np
import pandas as pd
import soundfile as sf
from transformers import Pipeline
from typing import Union, List, Tuple, Dict, Any
from speechbrain.inference.speaker import EncoderClassifier
class GenderClassificationPipeline(Pipeline):
def __init__(self, svm_model, scaler, device="cpu"):
"""
Initialize the pipeline with SVM model and scaler.
The ECAPA-TDNN model is loaded using SpeechBrain's EncoderClassifier.
"""
# Convert device string to torch.device
self.device = torch.device(device)
# Initialize the model with the proper device
self.model = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
run_opts={"device": str(self.device)} # SpeechBrain expects string
)
self.feature_names = [f"{i}_speechbrain_embedding" for i in range(192)]
# Add feature_extractor for handling multiple files
self.feature_extractor = lambda x: {"sampling_rate": 16000, "raw_speech": x}
# Audio processing parameters
self.target_sample_rate = 16000
self.target_bitrate = 256000 # 256 kbps
self.bits_per_sample = 16
self.svm_model = svm_model
self.scaler = scaler
self.labels = ["female", "male"]
# Required by Pipeline class
self.framework = "pt"
self._batch_size = 1
self._num_workers = None
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters()
self._framework = "pt"
self.call_count = 0
self.sequential = True
# self.torch_dtype = None
self.is_encoder_decoder = False
def _sanitize_parameters(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
"""Sanitize parameters for preprocess, forward, and postprocess steps"""
preprocess_kwargs = {}
forward_kwargs = {}
postprocess_kwargs = {}
return preprocess_kwargs, forward_kwargs, postprocess_kwargs
def _process_audio(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
"""Process audio to match target specifications"""
# Convert to mono if needed
if len(waveform.shape) > 1 and waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
# Resample to 16kHz if needed
if sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
waveform = resampler(waveform)
# Normalize the audio to be between -1 and 1
if waveform.abs().max() > 1:
waveform = waveform / waveform.abs().max()
        # Quantize amplitudes to 16-bit resolution (values stay float in [-1, 1])
        waveform = (waveform * 32767).round() / 32767
        # Derive the target length from the target bitrate:
        # bitrate = sample_rate * bits_per_sample * channels, so for mono audio
        # target_samples = n_samples * target_bitrate / (sample_rate * bits_per_sample).
        # With the defaults here (256 kbps, 16 kHz, 16-bit) this ratio is exactly 1,
        # so the pad/truncate step below only takes effect if those values change.
        target_samples = int((self.target_bitrate * waveform.shape[1]) /
                             (self.target_sample_rate * self.bits_per_sample))
# Adjust number of samples if needed
if waveform.shape[1] != target_samples:
# Either truncate or pad with zeros
if waveform.shape[1] > target_samples:
waveform = waveform[:, :target_samples]
else:
padding = target_samples - waveform.shape[1]
waveform = torch.nn.functional.pad(waveform, (0, padding))
return waveform
    def preprocess(self, audio_input: Union[str, List[str], np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Preprocess a file path, list of file paths, NumPy array, or tensor into a batched waveform"""
if isinstance(audio_input, list):
waveforms = []
for audio_file in audio_input:
if isinstance(audio_file, str):
wave, sr = sf.read(audio_file)
wave = torch.from_numpy(wave).float()
if len(wave.shape) == 1:
wave = wave.unsqueeze(0)
else:
wave = wave.T
wave = self._process_audio(wave, sr)
waveforms.append(wave)
else:
raise ValueError(f"Unsupported audio input type in list: {type(audio_file)}")
            # Clips may differ in length; pad each with trailing silence to the
            # longest clip, then stack into the (batch, time) tensor expected by
            # EncoderClassifier.encode_batch.
            max_len = max(w.shape[1] for w in waveforms)
            waveform = torch.stack(
                [torch.nn.functional.pad(w.squeeze(0), (0, max_len - w.shape[1])) for w in waveforms]
            )
return {"inputs": waveform.to(self.device)}
# Handle single input
if isinstance(audio_input, str):
waveform, sample_rate = sf.read(audio_input)
waveform = torch.from_numpy(waveform).float()
if len(waveform.shape) == 1:
waveform = waveform.unsqueeze(0)
else:
waveform = waveform.T
        elif isinstance(audio_input, np.ndarray):
            waveform = torch.from_numpy(audio_input).float()
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            # Raw arrays carry no sample-rate metadata; assume 16 kHz input
            sample_rate = self.target_sample_rate
        else:
            # Assume a torch.Tensor, likewise taken to be 16 kHz
            waveform = audio_input
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            sample_rate = self.target_sample_rate
waveform = self._process_audio(waveform, sample_rate)
return {"inputs": waveform.to(self.device)}
def _forward(self, model_inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Extract embeddings using the model"""
with torch.no_grad():
embeddings = self.model.encode_batch(model_inputs["inputs"])
return embeddings
    def postprocess(self, model_outputs: torch.Tensor) -> List[str]:
        """Map ECAPA-TDNN embeddings to gender labels via the fitted scaler and SVM"""
        # encode_batch returns (batch, 1, 192); flatten to one 192-dim row per clip
        embeddings = model_outputs.cpu().numpy().reshape(-1, len(self.feature_names))
        df_embeddings = pd.DataFrame(embeddings, columns=self.feature_names)
        # Scale features with the scaler fitted at training time
        scaled_features = self.scaler.transform(df_embeddings)
        # Get SVM class predictions (indices into self.labels)
        predictions = self.svm_model.predict(scaled_features)
        # Format output as label strings
        results = [self.labels[p] for p in predictions]
        return results
@classmethod
def from_pretrained(cls, model_path: str, device="cpu"):
"""Load all model components"""
import joblib
import json
# Load configuration
with open(f"{model_path}/config.json", "r") as f:
config = json.load(f)
# Load SVM and scaler
svm_model = joblib.load(f"{model_path}/svm_model.joblib")
scaler = joblib.load(f"{model_path}/scaler.joblib")
# Create pipeline instance
pipeline = cls(svm_model=svm_model, scaler=scaler, device=device)
pipeline.labels = config["labels"]
return pipeline
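

if __name__ == "__main__":
    # Minimal usage sketch; the model directory and audio path below are
    # placeholders, not files shipped with this module.
    clf = GenderClassificationPipeline.from_pretrained(
        "path/to/gender_cls_svm_ecapa_voxceleb", device="cpu"
    )
    # Run the three pipeline stages explicitly; the class is also set up to be
    # called like a standard transformers Pipeline (e.g. clf("audio.wav")).
    inputs = clf.preprocess("path/to/audio.wav")
    embeddings = clf._forward(inputs)
    print(clf.postprocess(embeddings))  # e.g. ["female"] or ["male"]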