import torch
import torchaudio
import numpy as np
import pandas as pd
import soundfile as sf
from transformers import Pipeline
from typing import Union, List, Tuple, Dict, Any
from speechbrain.inference.speaker import EncoderClassifier


class GenderClassificationPipeline(Pipeline):
    def __init__(self, svm_model, scaler, device="cpu"):
        """
        Initialize the pipeline with an SVM model and a feature scaler.
        The ECAPA-TDNN embedding model is loaded via SpeechBrain's EncoderClassifier.

        Note: Pipeline.__init__ is intentionally not called; the attributes it
        would normally set are assigned manually below.
        """
        # Convert device string to torch.device
        self.device = torch.device(device)

        # Initialize the embedding model on the requested device
        self.model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            run_opts={"device": str(self.device)},  # SpeechBrain expects a string
        )
        self.feature_names = [f"{i}_speechbrain_embedding" for i in range(192)]

        # Minimal feature_extractor stub so the Pipeline machinery can handle
        # multiple files
        self.feature_extractor = lambda x: {"sampling_rate": 16000, "raw_speech": x}

        # Audio processing parameters
        self.target_sample_rate = 16000
        self.target_bitrate = 256000  # 256 kbps
        self.bits_per_sample = 16

        self.svm_model = svm_model
        self.scaler = scaler
        self.labels = ["female", "male"]

        # Attributes required by the Pipeline base class
        self.framework = "pt"
        self._batch_size = 1
        self._num_workers = None
        self._preprocess_params, self._forward_params, self._postprocess_params = (
            self._sanitize_parameters()
        )
        self._framework = "pt"
        self.call_count = 0
        self.sequential = True
        self.is_encoder_decoder = False

    def _sanitize_parameters(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        """Sanitize parameters for the preprocess, forward, and postprocess steps."""
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def _process_audio(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Process audio to match the target specifications."""
        # Convert to mono if needed
        if len(waveform.shape) > 1 and waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to 16 kHz if needed
        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            waveform = resampler(waveform)

        # Normalize the audio to lie in [-1, 1]
        if waveform.abs().max() > 1:
            waveform = waveform / waveform.abs().max()

        # Quantize to 16-bit precision
        waveform = (waveform * 32767).round() / 32767

        # Calculate the target number of samples from the bitrate:
        # bitrate = sample_rate * bits_per_sample * channels.
        # With the defaults (256 kbps, 16 kHz, 16-bit mono) the ratio is 1,
        # so this leaves the length unchanged unless the targets are altered.
        target_samples = int(
            (self.target_bitrate * waveform.shape[1])
            / (self.target_sample_rate * self.bits_per_sample)
        )

        # Truncate or zero-pad to the target length if needed
        if waveform.shape[1] != target_samples:
            if waveform.shape[1] > target_samples:
                waveform = waveform[:, :target_samples]
            else:
                padding = target_samples - waveform.shape[1]
                waveform = torch.nn.functional.pad(waveform, (0, padding))

        return waveform

    def preprocess(
        self, audio_input: Union[str, List[str], np.ndarray, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Preprocess the audio input into a batched waveform tensor."""
        # Handle a list of file paths
        if isinstance(audio_input, list):
            waveforms = []
            for audio_file in audio_input:
                if isinstance(audio_file, str):
                    wave, sr = sf.read(audio_file)
                    wave = torch.from_numpy(wave).float()
                    if len(wave.shape) == 1:
                        wave = wave.unsqueeze(0)  # (1, time)
                    else:
                        wave = wave.T  # soundfile returns (time, channels)
                    wave = self._process_audio(wave, sr)
                    waveforms.append(wave.squeeze(0))
                else:
                    raise ValueError(f"Unsupported audio input type in list: {type(audio_file)}")
            # Zero-pad to the longest clip so the waveforms can be stacked into
            # a (batch, time) tensor, which is what encode_batch expects
            max_len = max(w.shape[0] for w in waveforms)
            waveforms = [torch.nn.functional.pad(w, (0, max_len - w.shape[0])) for w in waveforms]
            waveform = torch.stack(waveforms)
            return {"inputs": waveform.to(self.device)}
        # Handle a single input
        if isinstance(audio_input, str):
            waveform, sample_rate = sf.read(audio_input)
            waveform = torch.from_numpy(waveform).float()
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            else:
                waveform = waveform.T
        elif isinstance(audio_input, np.ndarray):
            waveform = torch.from_numpy(audio_input).float()
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            sample_rate = self.target_sample_rate
        else:
            waveform = audio_input
            sample_rate = self.target_sample_rate

        waveform = self._process_audio(waveform, sample_rate)
        return {"inputs": waveform.to(self.device)}

    def _forward(self, model_inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Extract speaker embeddings with the ECAPA-TDNN model."""
        with torch.no_grad():
            embeddings = self.model.encode_batch(model_inputs["inputs"])
        return embeddings

    def postprocess(self, model_outputs: torch.Tensor) -> List[str]:
        """Turn embeddings into final gender predictions."""
        # encode_batch returns (batch, 1, 192); reshape to one row of 192
        # features per input so batched inputs are also handled correctly
        embeddings = model_outputs.cpu().numpy().reshape(-1, len(self.feature_names))
        df_embeddings = pd.DataFrame(embeddings, columns=self.feature_names)

        # Scale features with the scaler fitted at training time
        scaled_features = self.scaler.transform(df_embeddings)

        # Get SVM predictions
        predictions = self.svm_model.predict(scaled_features)

        # Map class indices to label strings
        return [self.labels[p] for p in predictions]

    @classmethod
    def from_pretrained(cls, model_path: str, device="cpu"):
        """Load all model components from a local directory."""
        import joblib
        import json

        # Load configuration
        with open(f"{model_path}/config.json", "r") as f:
            config = json.load(f)

        # Load the SVM and the scaler
        svm_model = joblib.load(f"{model_path}/svm_model.joblib")
        scaler = joblib.load(f"{model_path}/scaler.joblib")

        # Create the pipeline instance
        pipeline = cls(svm_model=svm_model, scaler=scaler, device=device)
        pipeline.labels = config["labels"]
        return pipeline
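

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): assumes a hypothetical local directory
# "./gender_model" containing config.json, svm_model.joblib and scaler.joblib,
# plus a hypothetical audio file "sample.wav". It chains the three pipeline
# steps explicitly rather than relying on Pipeline.__call__.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = GenderClassificationPipeline.from_pretrained("./gender_model", device="cpu")
    inputs = pipe.preprocess("sample.wav")  # {"inputs": (1, time) tensor}
    embeddings = pipe._forward(inputs)      # (1, 1, 192) embedding tensor
    print(pipe.postprocess(embeddings))     # e.g. ["female"]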