import torch
import torchaudio
import numpy as np
import pandas as pd
import soundfile as sf
from transformers import Pipeline
from typing import Union, List, Tuple, Dict, Any
from speechbrain.inference.speaker import EncoderClassifier

class GenderClassificationPipeline(Pipeline):
    def __init__(self, svm_model, scaler, device="cpu"):
        """

        Initialize the pipeline with SVM model and scaler.

        The ECAPA-TDNN model is loaded using SpeechBrain's EncoderClassifier.

        """
        # Convert device string to torch.device
        self.device = torch.device(device)
        
        # Initialize the model with the proper device
        self.model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            run_opts={"device": str(self.device)}  # SpeechBrain expects string
        )

        self.feature_names = [f"{i}_speechbrain_embedding" for i in range(192)]
        
        # Add feature_extractor for handling multiple files
        self.feature_extractor = lambda x: {"sampling_rate": 16000, "raw_speech": x}

        # Audio processing parameters
        self.target_sample_rate = 16000
        self.target_bitrate = 256000  # 256 kbps
        self.bits_per_sample = 16
        
        self.svm_model = svm_model
        self.scaler = scaler
        self.labels = ["female", "male"]
        
        # Required by Pipeline class
        self.framework = "pt"
        self._batch_size = 1
        self._num_workers = None
        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters()
        self._framework = "pt"
        self.call_count = 0
        self.sequential = True
        # self.torch_dtype = None
        self.is_encoder_decoder = False

    def _sanitize_parameters(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        """Sanitize parameters for preprocess, forward, and postprocess steps"""
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}
        
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs
    
    def _process_audio(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Process audio to match target specifications"""
        # Convert to mono if needed
        if len(waveform.shape) > 1 and waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to 16kHz if needed
        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            waveform = resampler(waveform)

        # Normalize the audio to be between -1 and 1
        if waveform.abs().max() > 1:
            waveform = waveform / waveform.abs().max()

        # Convert to 16-bit precision
        waveform = (waveform * 32767).round() / 32767

        # Calculate target samples based on bitrate
        # bitrate = sample_rate * bits_per_sample * channels
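        # With the defaults here (256 kbps, 16 kHz, 16 bits per sample) the ratio
        # 256000 / (16000 * 16) equals 1, so target_samples matches the current
        # length and the adjustment below is effectively a no-op.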
        target_samples = int((self.target_bitrate * waveform.shape[1]) / 
                           (self.target_sample_rate * self.bits_per_sample))
        
        # Adjust number of samples if needed
        if waveform.shape[1] != target_samples:
            # Either truncate or pad with zeros
            if waveform.shape[1] > target_samples:
                waveform = waveform[:, :target_samples]
            else:
                padding = target_samples - waveform.shape[1]
                waveform = torch.nn.functional.pad(waveform, (0, padding))

        return waveform

    def preprocess(self, audio_input: Union[str, List[str], np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Preprocess a file path, list of file paths, NumPy array, or tensor"""
        if isinstance(audio_input, list):
            waveforms = []
            for audio_file in audio_input:
                if isinstance(audio_file, str):
                    wave, sr = sf.read(audio_file)
                    wave = torch.from_numpy(wave).float()
                    if len(wave.shape) == 1:
                        wave = wave.unsqueeze(0)
                    else:
                        wave = wave.T
                    wave = self._process_audio(wave, sr)
                    waveforms.append(wave)
                else:
                    raise ValueError(f"Unsupported audio input type in list: {type(audio_file)}")
            
            # Pad to the longest clip and concatenate into a [batch, time]
            # tensor, which is the layout encode_batch expects
            max_len = max(w.shape[1] for w in waveforms)
            waveforms = [torch.nn.functional.pad(w, (0, max_len - w.shape[1])) for w in waveforms]
            waveform = torch.cat(waveforms, dim=0)
            return {"inputs": waveform.to(self.device)}

        # Handle single input
        if isinstance(audio_input, str):
            waveform, sample_rate = sf.read(audio_input)
            waveform = torch.from_numpy(waveform).float()
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            else:
                waveform = waveform.T
        elif isinstance(audio_input, np.ndarray):
            waveform = torch.from_numpy(audio_input).float()
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            sample_rate = self.target_sample_rate
        else:
            waveform = audio_input
            sample_rate = self.target_sample_rate

        waveform = self._process_audio(waveform, sample_rate)
        return {"inputs": waveform.to(self.device)}

    def _forward(self, model_inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Extract embeddings using the model"""
        with torch.no_grad():
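            # encode_batch returns one 192-dimensional ECAPA-TDNN speaker
            # embedding per waveform in the batch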
            embeddings = self.model.encode_batch(model_inputs["inputs"])
            return embeddings

    def postprocess(self, model_outputs: torch.Tensor) -> List[str]:
        """Process model outputs to final predictions"""
        # Convert to numpy and reshape to one row of 192 features per input
        embeddings = model_outputs.cpu().numpy().reshape(-1, len(self.feature_names))
        df_embeddings = pd.DataFrame(embeddings, columns=self.feature_names)

        # Scale features with the scaler fitted at training time
        scaled_features = self.scaler.transform(df_embeddings)

        # Get SVM predictions
        predictions = self.svm_model.predict(scaled_features)

        # Map class indices to label strings
        results = [self.labels[p] for p in predictions]
        
        return results

    @classmethod
    def from_pretrained(cls, model_path: str, device="cpu"):
        """Load all model components"""
        import joblib
        import json
        
        # Load configuration
        with open(f"{model_path}/config.json", "r") as f:
            config = json.load(f)
        
        # Load SVM and scaler
        svm_model = joblib.load(f"{model_path}/svm_model.joblib")
        scaler = joblib.load(f"{model_path}/scaler.joblib")
        
        # Create pipeline instance
        pipeline = cls(svm_model=svm_model, scaler=scaler, device=device)
        pipeline.labels = config["labels"]
        
        return pipeline
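

# Example usage: a minimal sketch. The model directory and audio path below are
# hypothetical placeholders; the directory is assumed to contain the config.json,
# svm_model.joblib, and scaler.joblib files loaded by from_pretrained above.
if __name__ == "__main__":
    pipeline = GenderClassificationPipeline.from_pretrained("path/to/model_dir", device="cpu")

    # Run the three pipeline stages explicitly on a single audio file
    model_inputs = pipeline.preprocess("example.wav")
    embeddings = pipeline._forward(model_inputs)
    predictions = pipeline.postprocess(embeddings)
    print(predictions)  # e.g. ["female"] or ["male"]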