# File size: 7,227 Bytes (revision 309a4fd)
import torch
import torchaudio
import numpy as np
import pandas as pd
import soundfile as sf
from transformers import Pipeline
from typing import Union, List, Tuple, Dict, Any
from speechbrain.inference.speaker import EncoderClassifier
class GenderClassificationPipeline(Pipeline):
    """Predict a speaker label from audio using speaker embeddings and an SVM.

    Audio is converted to a mono waveform at ``target_sample_rate``, encoded
    with SpeechBrain's ``EncoderClassifier`` loaded from ``ENCODER_SOURCE``,
    scaled with ``scaler`` and classified with ``svm_model``. The predicted
    class indices are mapped to the strings in ``labels``.
    """

    # Number of embedding features handed to the scaler / SVM.
    EMBEDDING_DIM = 192
    # Source identifier passed to EncoderClassifier.from_hparams.
    ENCODER_SOURCE = "speechbrain/spkrec-ecapa-voxceleb"

    def __init__(self, svm_model, scaler, device="cpu"):
        """Initialize the pipeline with an SVM model and a feature scaler.

        Args:
            svm_model: Fitted classifier exposing ``predict``.
            scaler: Fitted transformer exposing ``transform``.
            device: Device specification accepted by ``torch.device``.
        """
        # NOTE(review): Pipeline.__init__ is not invoked here; the attributes
        # the base class is expected to need are assigned by hand below.
        # Confirm against the installed transformers version.
        self.device = torch.device(device)

        # SpeechBrain takes the device as a string in run_opts.
        self.model = EncoderClassifier.from_hparams(
            source=self.ENCODER_SOURCE,
            run_opts={"device": str(self.device)},
        )

        # Column names used when the embeddings are wrapped in a DataFrame.
        self.feature_names = [
            f"{i}_speechbrain_embedding" for i in range(self.EMBEDDING_DIM)
        ]

        # Added for handling multiple files through the base Pipeline.
        self.feature_extractor = lambda x: {"sampling_rate": 16000, "raw_speech": x}

        # Audio processing parameters.
        self.target_sample_rate = 16000
        self.target_bitrate = 256000  # 256 kbps
        self.bits_per_sample = 16

        self.svm_model = svm_model
        self.scaler = scaler
        self.labels = ["female", "male"]

        # Attributes required by the Pipeline base class.
        self.framework = "pt"
        self._batch_size = 1
        self._num_workers = None
        (
            self._preprocess_params,
            self._forward_params,
            self._postprocess_params,
        ) = self._sanitize_parameters()
        self._framework = "pt"
        self.call_count = 0
        self.sequential = True
        self.is_encoder_decoder = False

    def _sanitize_parameters(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        """Return the (preprocess, forward, postprocess) keyword dictionaries.

        This pipeline has no tunable call-time parameters, so all three
        dictionaries are empty and ``kwargs`` is ignored.
        """
        return {}, {}, {}

    def _process_audio(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Convert a waveform to the target mono / sample-rate / length format.

        Args:
            waveform: Tensor of shape ``(channels, samples)`` or ``(samples,)``.
            sample_rate: Sample rate of ``waveform`` in Hz.

        Returns:
            A ``(1, samples)`` tensor (for 1-D or 2-D input) at
            ``target_sample_rate``.

        Raises:
            ValueError: If ``waveform`` contains no samples.
        """
        # A 1-D waveform has no channel axis; add one so the shape[1]
        # accesses below are valid.
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        if waveform.numel() == 0:
            # max() over an empty tensor would fail with an opaque error.
            raise ValueError("Audio input is empty")

        # Down-mix multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample only when the source rate differs from the target.
        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            waveform = resampler(waveform)

        # Scale down only signals whose peak exceeds 1.
        peak = waveform.abs().max()
        if peak > 1:
            waveform = waveform / peak

        # Quantize to 16-bit steps while staying in floating point.
        waveform = (waveform * 32767).round() / 32767

        # bitrate = sample_rate * bits_per_sample * channels; with the default
        # parameters the ratio is 1 and the length is left unchanged.
        num_samples = waveform.shape[1]
        target_samples = int(
            (self.target_bitrate * num_samples)
            / (self.target_sample_rate * self.bits_per_sample)
        )

        if num_samples > target_samples:
            waveform = waveform[:, :target_samples]
        elif num_samples < target_samples:
            waveform = torch.nn.functional.pad(waveform, (0, target_samples - num_samples))

        return waveform

    @staticmethod
    def _read_audio_file(path: str) -> Tuple[torch.Tensor, int]:
        """Read an audio file into a ``(channels, samples)`` float tensor.

        Returns:
            The waveform tensor and the sample rate reported by soundfile.
        """
        samples, sample_rate = sf.read(path)
        waveform = torch.from_numpy(samples).float()
        # 1-D data is treated as mono; otherwise the array is transposed so
        # that channels come first.
        waveform = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform.T
        return waveform, sample_rate

    def _preprocess_batch(self, audio_files: list) -> Dict[str, torch.Tensor]:
        """Load several audio files into one zero-padded ``(batch, samples)`` tensor.

        Returns:
            A dict with ``"inputs"`` (the padded batch) and ``"wav_lens"``
            (each waveform's length relative to the longest one).

        Raises:
            ValueError: If the list is empty or contains a non-string item.
        """
        if not audio_files:
            raise ValueError("Cannot preprocess an empty list of audio files")

        waveforms = []
        for audio_file in audio_files:
            if not isinstance(audio_file, str):
                raise ValueError(f"Unsupported audio input type in list: {type(audio_file)}")
            wave, sr = self._read_audio_file(audio_file)
            # Drop the mono channel axis so the batch is (batch, samples).
            waveforms.append(self._process_audio(wave, sr).squeeze(0))

        # Files may differ in length: right-pad with zeros to the longest one.
        lengths = [wave.shape[-1] for wave in waveforms]
        longest = max(lengths)
        batch = torch.stack(
            [
                torch.nn.functional.pad(wave, (0, longest - length))
                for wave, length in zip(waveforms, lengths)
            ]
        )
        wav_lens = torch.tensor(
            [length / longest for length in lengths], dtype=torch.float32
        )
        return {"inputs": batch.to(self.device), "wav_lens": wav_lens.to(self.device)}

    def preprocess(self, audio_input: Union[str, np.ndarray, torch.Tensor, List[str]]) -> Dict[str, torch.Tensor]:
        """Turn an audio input into model-ready tensors on ``self.device``.

        Args:
            audio_input: A file path, a NumPy array, a torch tensor, or a list
                of file paths. Arrays and tensors are taken to already be at
                ``target_sample_rate``.

        Returns:
            A dict with key ``"inputs"``; list inputs additionally carry
            ``"wav_lens"`` with relative lengths for the padded batch.

        Raises:
            ValueError: If the input type (or a list item type) is unsupported.
        """
        if isinstance(audio_input, list):
            return self._preprocess_batch(audio_input)

        if isinstance(audio_input, str):
            waveform, sample_rate = self._read_audio_file(audio_input)
        elif isinstance(audio_input, np.ndarray):
            waveform = torch.from_numpy(audio_input).float()
            sample_rate = self.target_sample_rate
        elif isinstance(audio_input, torch.Tensor):
            # Same float conversion as the other input kinds.
            waveform = audio_input.float()
            sample_rate = self.target_sample_rate
        else:
            raise ValueError(f"Unsupported audio input type: {type(audio_input)}")

        # _process_audio adds the channel axis for 1-D waveforms.
        waveform = self._process_audio(waveform, sample_rate)
        return {"inputs": waveform.to(self.device)}

    def _forward(self, model_inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Extract embeddings for ``model_inputs["inputs"]`` without gradients.

        When ``"wav_lens"`` is present it is forwarded to ``encode_batch``;
        otherwise ``None`` is passed.
        """
        wav_lens = model_inputs.get("wav_lens")
        with torch.no_grad():
            embeddings = self.model.encode_batch(model_inputs["inputs"], wav_lens=wav_lens)
        return embeddings

    def postprocess(self, model_outputs: torch.Tensor) -> List[str]:
        """Convert embeddings into label strings.

        Args:
            model_outputs: Embedding tensor whose element count is a multiple
                of the number of feature columns.

        Returns:
            One label per embedding row.
        """
        # One row per utterance, so single inputs and batches both work.
        embeddings = model_outputs.cpu().numpy().reshape(-1, len(self.feature_names))
        df_embeddings = pd.DataFrame(embeddings, columns=self.feature_names)

        scaled_features = self.scaler.transform(df_embeddings)
        predictions = self.svm_model.predict(scaled_features)

        # Predictions are used as indices into the label list.
        return [self.labels[p] for p in predictions]

    @classmethod
    def from_pretrained(cls, model_path: str, device="cpu"):
        """Build a pipeline from a directory of saved components.

        The directory must contain ``config.json`` (with a ``"labels"`` entry),
        ``svm_model.joblib`` and ``scaler.joblib``.

        Args:
            model_path: Directory holding the saved files.
            device: Device specification forwarded to the constructor.

        Returns:
            A configured pipeline instance.
        """
        import json
        from pathlib import Path

        import joblib

        model_dir = Path(model_path)

        with open(model_dir / "config.json", "r", encoding="utf-8") as f:
            config = json.load(f)

        # NOTE: joblib.load unpickles arbitrary objects; only load model
        # directories that come from a trusted source.
        svm_model = joblib.load(model_dir / "svm_model.joblib")
        scaler = joblib.load(model_dir / "scaler.joblib")

        pipeline = cls(svm_model=svm_model, scaler=scaler, device=device)
        pipeline.labels = config["labels"]
        return pipeline