griko committed · Commit 309a4fd · verified · 1 Parent(s): 010c2a5

Upload folder using huggingface_hub

Files changed (2):
1. README.md +8 -1
2. modeling_gender.py +177 -0
README.md CHANGED
@@ -54,8 +54,15 @@ pip install scikit-learn pandas soundfile speechbrain torch torchaudio transform
 
 ```python
 from transformers import pipeline
+from modeling_gender import GenderClassificationPipeline
+
+# Load the pipeline
+classifier = pipeline(
+    "audio-classification",
+    model="griko/gender_cls_svm_ecapa_voxceleb",
+    pipeline_class=GenderClassificationPipeline
+)
 
-classifier = pipeline("audio-classification", model="{repo_id}")
 result = classifier("path/to/audio.wav")
 print(result)  # ["female"] or ["male"]
 ```
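Note on the new snippet: `modeling_gender.py` ships inside the model repository, so `from modeling_gender import GenderClassificationPipeline` only resolves once the repo files are available locally. A minimal sketch of one way to fetch them first (this step is an assumption, not part of this commit), using `huggingface_hub`:

```python
import sys

from huggingface_hub import snapshot_download

# Download the repo contents (README.md, modeling_gender.py, and the
# joblib artifacts), then make modeling_gender importable.
local_dir = snapshot_download("griko/gender_cls_svm_ecapa_voxceleb")
sys.path.append(local_dir)
```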
modeling_gender.py ADDED
@@ -0,0 +1,177 @@
+import torch
+import torchaudio
+import numpy as np
+import pandas as pd
+import soundfile as sf
+from transformers import Pipeline
+from typing import Union, List, Tuple, Dict, Any
+from speechbrain.inference.speaker import EncoderClassifier
+
+class GenderClassificationPipeline(Pipeline):
+    def __init__(self, svm_model, scaler, device="cpu"):
+        """
+        Initialize the pipeline with SVM model and scaler.
+        The ECAPA-TDNN model is loaded using SpeechBrain's EncoderClassifier.
+        """
+        # Convert device string to torch.device
+        self.device = torch.device(device)
+
+        # Initialize the model with the proper device
+        self.model = EncoderClassifier.from_hparams(
+            source="speechbrain/spkrec-ecapa-voxceleb",
+            run_opts={"device": str(self.device)}  # SpeechBrain expects string
+        )
+
+        self.feature_names = [f"{i}_speechbrain_embedding" for i in range(192)]
+
+        # Add feature_extractor for handling multiple files
+        self.feature_extractor = lambda x: {"sampling_rate": 16000, "raw_speech": x}
+
+        # Audio processing parameters
+        self.target_sample_rate = 16000
+        self.target_bitrate = 256000  # 256 kbps
+        self.bits_per_sample = 16
+
+        self.svm_model = svm_model
+        self.scaler = scaler
+        self.labels = ["female", "male"]
+
+        # Required by Pipeline class
+        self.framework = "pt"
+        self._batch_size = 1
+        self._num_workers = None
+        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters()
+        self._framework = "pt"
+        self.call_count = 0
+        self.sequential = True
+        # self.torch_dtype = None
+        self.is_encoder_decoder = False
+
+    def _sanitize_parameters(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+        """Sanitize parameters for preprocess, forward, and postprocess steps"""
+        preprocess_kwargs = {}
+        forward_kwargs = {}
+        postprocess_kwargs = {}
+
+        return preprocess_kwargs, forward_kwargs, postprocess_kwargs
+
+    def _process_audio(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        """Process audio to match target specifications"""
+        # Convert to mono if needed
+        if len(waveform.shape) > 1 and waveform.shape[0] > 1:
+            waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+        # Resample to 16 kHz if needed
+        if sample_rate != self.target_sample_rate:
+            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
+            waveform = resampler(waveform)
+
+        # Normalize the audio to be between -1 and 1
+        if waveform.abs().max() > 1:
+            waveform = waveform / waveform.abs().max()
+
+        # Convert to 16-bit precision
+        waveform = (waveform * 32767).round() / 32767
+
+        # Calculate target samples based on bitrate
+        # bitrate = sample_rate * bits_per_sample * channels; with the defaults above 256000 / (16000 * 16) = 1, so the length is unchanged
+        target_samples = int((self.target_bitrate * waveform.shape[1]) /
+                             (self.target_sample_rate * self.bits_per_sample))
+
+        # Adjust number of samples if needed
+        if waveform.shape[1] != target_samples:
+            # Either truncate or pad with zeros
+            if waveform.shape[1] > target_samples:
+                waveform = waveform[:, :target_samples]
+            else:
+                padding = target_samples - waveform.shape[1]
+                waveform = torch.nn.functional.pad(waveform, (0, padding))
+
+        return waveform
+
+    def preprocess(self, audio_input: Union[str, List[str], np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """Preprocess audio input"""
+        if isinstance(audio_input, list):
+            waveforms = []
+            for audio_file in audio_input:
+                if isinstance(audio_file, str):
+                    wave, sr = sf.read(audio_file)
+                    wave = torch.from_numpy(wave).float()
+                    if len(wave.shape) == 1:
+                        wave = wave.unsqueeze(0)
+                    else:
+                        wave = wave.T
+                    wave = self._process_audio(wave, sr)
+                    waveforms.append(wave)
+                else:
+                    raise ValueError(f"Unsupported audio input type in list: {type(audio_file)}")
+
+            # Stack all waveforms
+            waveform = torch.stack(waveforms)
+            return {"inputs": waveform.to(self.device)}
+
+        # Handle single input
+        if isinstance(audio_input, str):
+            waveform, sample_rate = sf.read(audio_input)
+            waveform = torch.from_numpy(waveform).float()
+            if len(waveform.shape) == 1:
+                waveform = waveform.unsqueeze(0)
+            else:
+                waveform = waveform.T
+        elif isinstance(audio_input, np.ndarray):
+            waveform = torch.from_numpy(audio_input).float()
+            if len(waveform.shape) == 1:
+                waveform = waveform.unsqueeze(0)
+            sample_rate = self.target_sample_rate
+        else:
+            waveform = audio_input
+            sample_rate = self.target_sample_rate
+
+        waveform = self._process_audio(waveform, sample_rate)
+        return {"inputs": waveform.to(self.device)}
+
+    def _forward(self, model_inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """Extract embeddings using the model"""
+        with torch.no_grad():
+            embeddings = self.model.encode_batch(model_inputs["inputs"])
+        return embeddings
+
+    def postprocess(self, model_outputs: torch.Tensor) -> List[str]:
+        """Process model outputs to final predictions"""
+        # Convert to numpy and reshape to a single row
+        embeddings = model_outputs.cpu().numpy().ravel().reshape(1, -1)
+        df_embeddings = pd.DataFrame(embeddings, columns=self.feature_names)
+        # embeddings = np.squeeze(embeddings, axis=1)
+        # if len(embeddings.shape) == 1:
+        #     embeddings = embeddings.reshape(1, -1)
+
+        # Scale features
+        scaled_features = self.scaler.transform(df_embeddings)
+
+        # Get SVM predictions
+        predictions = self.svm_model.predict(scaled_features)
+
+        # Format output
+        results = [self.labels[p] for p in predictions]
+
+        return results
+
+    @classmethod
+    def from_pretrained(cls, model_path: str, device="cpu"):
+        """Load all model components"""
+        import joblib
+        import json
+
+        # Load configuration
+        with open(f"{model_path}/config.json", "r") as f:
+            config = json.load(f)
+
+        # Load SVM and scaler
+        svm_model = joblib.load(f"{model_path}/svm_model.joblib")
+        scaler = joblib.load(f"{model_path}/scaler.joblib")
+
+        # Create pipeline instance
+        pipeline = cls(svm_model=svm_model, scaler=scaler, device=device)
+        pipeline.labels = config["labels"]
+
+        return pipeline
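For reference, the `from_pretrained` classmethod added above expects a local directory containing `config.json` (with a `"labels"` key), `svm_model.joblib`, and `scaler.joblib`. A minimal usage sketch, assuming those files sit in a hypothetical local checkout `./gender_cls_svm_ecapa_voxceleb`:

```python
from modeling_gender import GenderClassificationPipeline

# Hypothetical local path; must contain config.json, svm_model.joblib,
# and scaler.joblib, exactly as loaded by from_pretrained above.
classifier = GenderClassificationPipeline.from_pretrained(
    "./gender_cls_svm_ecapa_voxceleb", device="cpu"
)

result = classifier("path/to/audio.wav")
print(result)  # ["female"] or ["male"]
```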