import torch
import torchaudio
import numpy as np
import pandas as pd
import soundfile as sf

from typing import Union, List, Tuple, Dict, Any

from transformers import Pipeline
from speechbrain.inference.speaker import EncoderClassifier
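
# Design note: this pipeline pairs a pretrained SpeechBrain ECAPA-TDNN speaker
# encoder (feature extraction) with a fitted scaler and SVM loaded via joblib
# (classification), exposed through the transformers Pipeline interface.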

class GenderClassificationPipeline(Pipeline):
    def __init__(self, svm_model, scaler, device="cpu"):
        """
        Initialize the pipeline with the fitted SVM model and feature scaler.

        The ECAPA-TDNN embedding model is loaded with SpeechBrain's
        EncoderClassifier.
        """
        self.device = torch.device(device)

        # Pretrained speaker encoder producing 192-dim ECAPA-TDNN embeddings.
        self.model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            run_opts={"device": str(self.device)},
        )
        # Column names the scaler and SVM were trained on (192-dim embedding).
        self.feature_names = [f"{i}_speechbrain_embedding" for i in range(192)]

        # Minimal stand-in for a Hugging Face feature extractor.
        self.feature_extractor = lambda x: {"sampling_rate": 16000, "raw_speech": x}

        # Target audio specification.
        self.target_sample_rate = 16000
        self.target_bitrate = 256000
        self.bits_per_sample = 16

        self.svm_model = svm_model
        self.scaler = scaler
        self.labels = ["female", "male"]

        # Attributes that transformers.Pipeline.__init__ would normally set;
        # the parent constructor is not called here because there is no
        # Hugging Face model object to pass to it.
        self.framework = "pt"
        self._batch_size = 1
        self._num_workers = None
        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters()
        self._framework = "pt"
        self.call_count = 0
        self.sequential = True
        self.is_encoder_decoder = False

    def _sanitize_parameters(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        """Split kwargs into preprocess, forward, and postprocess parameters."""
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def _process_audio(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Downmix, resample, normalize, and quantize audio to the target spec."""
        # Downmix multi-channel audio to mono.
        if len(waveform.shape) > 1 and waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to the target sample rate if necessary.
        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            waveform = resampler(waveform)

        # Normalize if the signal peaks above full scale.
        if waveform.abs().max() > 1:
            waveform = waveform / waveform.abs().max()

        # Simulate 16-bit quantization.
        waveform = (waveform * 32767).round() / 32767

        # Trim or pad to the sample count implied by the target bitrate. With
        # the defaults (256000 bps = 16000 Hz * 16 bits), target_samples equals
        # the input length, so this branch is a no-op unless the settings change.
        target_samples = int((self.target_bitrate * waveform.shape[1]) /
                             (self.target_sample_rate * self.bits_per_sample))
        if waveform.shape[1] != target_samples:
            if waveform.shape[1] > target_samples:
                waveform = waveform[:, :target_samples]
            else:
                padding = target_samples - waveform.shape[1]
                waveform = torch.nn.functional.pad(waveform, (0, padding))

        return waveform

    def preprocess(self, audio_input: Union[str, List[str], np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Load and convert the audio input into a batched tensor on the device."""
        if isinstance(audio_input, list):
            waveforms = []
            for audio_file in audio_input:
                if isinstance(audio_file, str):
                    wave, sr = sf.read(audio_file)
                    wave = torch.from_numpy(wave).float()
                    # soundfile returns (samples,) for mono, (samples, channels) otherwise.
                    if len(wave.shape) == 1:
                        wave = wave.unsqueeze(0)
                    else:
                        wave = wave.T
                    wave = self._process_audio(wave, sr)
                    waveforms.append(wave)
                else:
                    raise ValueError(f"Unsupported audio input type in list: {type(audio_file)}")

            # Note: stacking assumes all processed waveforms share the same length.
            waveform = torch.stack(waveforms)
            return {"inputs": waveform.to(self.device)}

        if isinstance(audio_input, str):
            waveform, sample_rate = sf.read(audio_input)
            waveform = torch.from_numpy(waveform).float()
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            else:
                waveform = waveform.T
        elif isinstance(audio_input, np.ndarray):
            waveform = torch.from_numpy(audio_input).float()
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            # Raw arrays and tensors are assumed to be at the target sample rate.
            sample_rate = self.target_sample_rate
        else:
            waveform = audio_input
            sample_rate = self.target_sample_rate

        waveform = self._process_audio(waveform, sample_rate)
        return {"inputs": waveform.to(self.device)}

    def _forward(self, model_inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Extract speaker embeddings with the ECAPA-TDNN encoder."""
        with torch.no_grad():
            embeddings = self.model.encode_batch(model_inputs["inputs"])
        return embeddings

    def postprocess(self, model_outputs: torch.Tensor) -> List[str]:
        """Scale the embedding, run the SVM, and map predictions to labels."""
        # Flatten to a single 192-dim feature row (batch size is 1).
        embeddings = model_outputs.cpu().numpy().ravel().reshape(1, -1)
        df_embeddings = pd.DataFrame(embeddings, columns=self.feature_names)

        # Apply the same standardization used when the SVM was trained.
        scaled_features = self.scaler.transform(df_embeddings)

        predictions = self.svm_model.predict(scaled_features)

        # Map integer class indices to label strings.
        results = [self.labels[p] for p in predictions]
        return results

    @classmethod
    def from_pretrained(cls, model_path: str, device="cpu"):
        """Load the SVM, scaler, and label config from a local directory."""
        import joblib
        import json

        with open(f"{model_path}/config.json", "r") as f:
            config = json.load(f)

        svm_model = joblib.load(f"{model_path}/svm_model.joblib")
        scaler = joblib.load(f"{model_path}/scaler.joblib")

        pipeline = cls(svm_model=svm_model, scaler=scaler, device=device)
        pipeline.labels = config["labels"]

        return pipeline
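

# Minimal usage sketch: chains preprocess -> _forward -> postprocess directly
# rather than relying on Pipeline.__call__. The "model" directory (containing
# config.json, svm_model.joblib, and scaler.joblib) and "sample.wav" are
# placeholder paths, not artifacts shipped with this file.
if __name__ == "__main__":
    pipe = GenderClassificationPipeline.from_pretrained("model", device="cpu")
    batch = pipe.preprocess("sample.wav")
    embeddings = pipe._forward(batch)
    print(pipe.postprocess(embeddings))  # e.g. ['female']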