import os
import tqdm
import torch
import torchaudio
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
from torch.nn import functional as F

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, basedir=None, sampling_rate=16000, max_audio_len=5):
        self.dataset = dataset
        self.basedir = basedir
        self.sampling_rate = sampling_rate
        self.max_audio_len = max_audio_len

    def __len__(self):
        return len(self.dataset)

    def _cutorpad(self, audio):
        # Truncate clips longer than max_audio_len seconds; shorter clips are
        # left as-is and padded later by the collate function.
        effective_length = self.sampling_rate * self.max_audio_len
        if len(audio) > effective_length:
            audio = audio[:effective_length]
        return audio

    def __getitem__(self, index):
        if self.basedir is None:
            filepath = self.dataset[index]
        else:
            filepath = os.path.join(self.basedir, self.dataset[index])

        speech_array, sr = torchaudio.load(filepath)

        # Downmix multi-channel audio to mono
        if speech_array.shape[0] > 1:
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)

        # Resample to the target sampling rate if necessary
        if sr != self.sampling_rate:
            transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
            speech_array = transform(speech_array)
            sr = self.sampling_rate

        speech_array = speech_array.squeeze().numpy()
        speech_array = self._cutorpad(speech_array)

        return {"input_values": speech_array, "attention_mask": None}

class CollateFunc:
    def __init__(self, processor, max_length=None, padding=True, pad_to_multiple_of=None, sampling_rate=16000):
        self.padding = padding
        self.processor = processor
        self.max_length = max_length
        self.sampling_rate = sampling_rate
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(self, batch):
        input_features = []
        for audio in batch:
            input_tensor = self.processor(audio["input_values"], sampling_rate=self.sampling_rate).input_values
            input_tensor = np.squeeze(input_tensor)
            input_features.append({"input_values": input_tensor})

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        return batch

def predict(test_dataloader, model, device):
    model.to(device)
    model.eval()

    preds = []
    with torch.no_grad():
        for batch in tqdm.tqdm(test_dataloader):
            input_values = batch["input_values"].to(device)
            logits = model(input_values).logits
            scores = F.softmax(logits, dim=-1)
            pred = torch.argmax(scores, dim=1).cpu().detach().numpy()
            preds.extend(pred)
    return preds

def get_gender(model_name_or_path, audio_paths, device):
    num_labels = 2

    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
    model = AutoModelForAudioClassification.from_pretrained(
        pretrained_model_name_or_path=model_name_or_path,
        num_labels=num_labels,
    )

    test_dataset = CustomDataset(audio_paths)
    data_collator = CollateFunc(
        processor=feature_extractor,
        padding=True,
        sampling_rate=16000,
    )
    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=16,
        collate_fn=data_collator,
        shuffle=False,
        num_workers=10,
    )

    preds = predict(test_dataloader=test_dataloader, model=model, device=device)

    # Map class indices to labels
    label_mapping = {0: "female", 1: "male"}

    # Determine the most common predicted label across all audio files
    most_common_label = max(set(preds), key=preds.count)
    predicted_label = label_mapping[most_common_label]

    return predicted_label
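
For reference, a minimal usage sketch is shown below. The checkpoint id and the audio file names are placeholders, not values taken from the script above; any audio-classification checkpoint with two labels matching the 0 = "female" / 1 = "male" mapping should work, and the device falls back to CPU when no GPU is available.

# Minimal usage sketch. "some-org/wav2vec2-gender-recognition" and the .wav
# paths are hypothetical placeholders; substitute your own checkpoint and files.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name_or_path = "some-org/wav2vec2-gender-recognition"  # placeholder checkpoint
    audio_paths = ["sample_1.wav", "sample_2.wav"]  # placeholder local audio files

    gender = get_gender(model_name_or_path, audio_paths, device)
    print(f"Predicted speaker gender: {gender}")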