griko committed (verified)
Commit 4617266 · 1 Parent(s): 84a863d

Delete modeling_gender.py

Files changed (1)
  1. modeling_gender.py +0 -177
modeling_gender.py DELETED
@@ -1,177 +0,0 @@
- import torch
- import torchaudio
- import numpy as np
- import pandas as pd
- import soundfile as sf
- from transformers import Pipeline
- from typing import Union, List, Tuple, Dict, Any
- from speechbrain.inference.speaker import EncoderClassifier
-
- class GenderClassificationPipeline(Pipeline):
-     def __init__(self, svm_model, scaler, device="cpu"):
-         """
-         Initialize the pipeline with SVM model and scaler.
-         The ECAPA-TDNN model is loaded using SpeechBrain's EncoderClassifier.
-         """
-         # Convert device string to torch.device
-         self.device = torch.device(device)
-
-         # Initialize the model with the proper device
-         self.model = EncoderClassifier.from_hparams(
-             source="speechbrain/spkrec-ecapa-voxceleb",
-             run_opts={"device": str(self.device)}  # SpeechBrain expects string
-         )
-
-         self.feature_names = [f"{i}_speechbrain_embedding" for i in range(192)]
-
-         # Add feature_extractor for handling multiple files
-         self.feature_extractor = lambda x: {"sampling_rate": 16000, "raw_speech": x}
-
-         # Audio processing parameters
-         self.target_sample_rate = 16000
-         self.target_bitrate = 256000  # 256 kbps
-         self.bits_per_sample = 16
-
-         self.svm_model = svm_model
-         self.scaler = scaler
-         self.labels = ["female", "male"]
-
-         # Required by Pipeline class
-         self.framework = "pt"
-         self._batch_size = 1
-         self._num_workers = None
-         self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters()
-         self._framework = "pt"
-         self.call_count = 0
-         self.sequential = True
-         # self.torch_dtype = None
-         self.is_encoder_decoder = False
-
-     def _sanitize_parameters(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
-         """Sanitize parameters for preprocess, forward, and postprocess steps"""
-         preprocess_kwargs = {}
-         forward_kwargs = {}
-         postprocess_kwargs = {}
-
-         return preprocess_kwargs, forward_kwargs, postprocess_kwargs
-
-     def _process_audio(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
-         """Process audio to match target specifications"""
-         # Convert to mono if needed
-         if len(waveform.shape) > 1 and waveform.shape[0] > 1:
-             waveform = torch.mean(waveform, dim=0, keepdim=True)
-
-         # Resample to 16kHz if needed
-         if sample_rate != self.target_sample_rate:
-             resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
-             waveform = resampler(waveform)
-
-         # Normalize the audio to be between -1 and 1
-         if waveform.abs().max() > 1:
-             waveform = waveform / waveform.abs().max()
-
-         # Convert to 16-bit precision
-         waveform = (waveform * 32767).round() / 32767
-
-         # Calculate target samples based on bitrate
-         # bitrate = sample_rate * bits_per_sample * channels
-         target_samples = int((self.target_bitrate * waveform.shape[1]) /
-                              (self.target_sample_rate * self.bits_per_sample))
-
-         # Adjust number of samples if needed
-         if waveform.shape[1] != target_samples:
-             # Either truncate or pad with zeros
-             if waveform.shape[1] > target_samples:
-                 waveform = waveform[:, :target_samples]
-             else:
-                 padding = target_samples - waveform.shape[1]
-                 waveform = torch.nn.functional.pad(waveform, (0, padding))
-
-         return waveform
-
-     def preprocess(self, audio_input: Union[str, np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
-         """Preprocess audio input"""
-         if isinstance(audio_input, list):
-             waveforms = []
-             for audio_file in audio_input:
-                 if isinstance(audio_file, str):
-                     wave, sr = sf.read(audio_file)
-                     wave = torch.from_numpy(wave).float()
-                     if len(wave.shape) == 1:
-                         wave = wave.unsqueeze(0)
-                     else:
-                         wave = wave.T
-                     wave = self._process_audio(wave, sr)
-                     waveforms.append(wave)
-                 else:
-                     raise ValueError(f"Unsupported audio input type in list: {type(audio_file)}")
-
-             # Stack all waveforms
-             waveform = torch.stack(waveforms)
-             return {"inputs": waveform.to(self.device)}
-
-         # Handle single input
-         if isinstance(audio_input, str):
-             waveform, sample_rate = sf.read(audio_input)
-             waveform = torch.from_numpy(waveform).float()
-             if len(waveform.shape) == 1:
-                 waveform = waveform.unsqueeze(0)
-             else:
-                 waveform = waveform.T
-         elif isinstance(audio_input, np.ndarray):
-             waveform = torch.from_numpy(audio_input).float()
-             if len(waveform.shape) == 1:
-                 waveform = waveform.unsqueeze(0)
-             sample_rate = self.target_sample_rate
-         else:
-             waveform = audio_input
-             sample_rate = self.target_sample_rate
-
-         waveform = self._process_audio(waveform, sample_rate)
-         return {"inputs": waveform.to(self.device)}
-
-     def _forward(self, model_inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
-         """Extract embeddings using the model"""
-         with torch.no_grad():
-             embeddings = self.model.encode_batch(model_inputs["inputs"])
-         return embeddings
-
-     def postprocess(self, model_outputs: torch.Tensor) -> List[str]:
-         """Process model outputs to final predictions"""
-         # Convert to numpy and reshape
-         embeddings = model_outputs.cpu().numpy().ravel().reshape(1, -1)
-         df_embeddings = pd.DataFrame(embeddings, columns=self.feature_names)
-         # embeddings = np.squeeze(embeddings, axis=1)
-         # if len(embeddings.shape) == 1:
-         #     embeddings = embeddings.reshape(1, -1)
-
-         # Scale features
-         scaled_features = self.scaler.transform(df_embeddings)
-
-         # Get SVM predictions and probabilities
-         predictions = self.svm_model.predict(scaled_features)
-
-         # Format output
-         results = [self.labels[p] for p in predictions]
-
-         return results
-
-     @classmethod
-     def from_pretrained(cls, model_path: str, device="cpu"):
-         """Load all model components"""
-         import joblib
-         import json
-
-         # Load configuration
-         with open(f"{model_path}/config.json", "r") as f:
-             config = json.load(f)
-
-         # Load SVM and scaler
-         svm_model = joblib.load(f"{model_path}/svm_model.joblib")
-         scaler = joblib.load(f"{model_path}/scaler.joblib")
-
-         # Create pipeline instance
-         pipeline = cls(svm_model=svm_model, scaler=scaler, device=device)
-         pipeline.labels = config["labels"]
-
-         return pipeline
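
For reference, a minimal usage sketch of the class removed by this commit, assuming a checkout in which modeling_gender.py still sits next to the config.json, svm_model.joblib, and scaler.joblib files that from_pretrained() reads; the paths ./checkpoint and speech.wav below are hypothetical:

from modeling_gender import GenderClassificationPipeline

# Load the SVM, scaler, and label names from a (hypothetical) local model directory
pipe = GenderClassificationPipeline.from_pretrained("./checkpoint", device="cpu")

# Drive the three stages defined above explicitly
inputs = pipe.preprocess("speech.wav")  # read, downmix to mono, resample to 16 kHz, normalize
embeddings = pipe._forward(inputs)      # 192-dim ECAPA-TDNN speaker embedding
print(pipe.postprocess(embeddings))     # e.g. ["female"]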