# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import gc
import os
import random
import shutil

import numpy as np
import torch
import tqdm
from examples.textless_nlp.gslm.speech2unit.pretrained.cpc_feature_reader import (
    CpcFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.hubert_feature_reader import (
    HubertFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.logmel_feature_reader import (
    LogMelFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.w2v2_feature_reader import (
    Wav2VecFeatureReader,
)
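

# Resolve a feature type name ("logmel", "hubert", "w2v2", "cpc") to the
# corresponding pretrained feature reader class imported above.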
def get_feature_reader(feature_type):
    if feature_type == "logmel":
        return LogMelFeatureReader
    elif feature_type == "hubert":
        return HubertFeatureReader
    elif feature_type == "w2v2":
        return Wav2VecFeatureReader
    elif feature_type == "cpc":
        return CpcFeatureReader
    else:
        raise NotImplementedError(f"{feature_type} is not supported.")
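

# Build a generator over per-utterance features. The manifest follows the
# fairseq wav2vec convention: the first line is the root directory, and each
# following tab-separated line starts with an audio path relative to that root.
# When sample_pct < 1.0, a random subset of the files is used.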
def get_feature_iterator(
    feature_type, checkpoint_path, layer, manifest_path, sample_pct
):
    feature_reader_cls = get_feature_reader(feature_type)
    with open(manifest_path, "r") as fp:
        lines = fp.read().split("\n")
        root = lines.pop(0).strip()
        file_path_list = [
            os.path.join(root, line.split("\t")[0])
            for line in lines
            if len(line) > 0
        ]
        if sample_pct < 1.0:
            file_path_list = random.sample(
                file_path_list, int(sample_pct * len(file_path_list))
            )
        num_files = len(file_path_list)
        reader = feature_reader_cls(
            checkpoint_path=checkpoint_path, layer=layer
        )

        def iterate():
            for file_path in file_path_list:
                feats = reader.get_feats(file_path)
                yield feats.cpu().numpy()

    return iterate, num_files
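

# Materialize all features in memory by driving the iterator (with a tqdm
# progress bar), then release the iterator, generator, and CUDA cache. With
# flatten=True, the per-file feature arrays are concatenated into one array.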
def get_features(
    feature_type, checkpoint_path, layer, manifest_path, sample_pct, flatten
):
    generator, num_files = get_feature_iterator(
        feature_type=feature_type,
        checkpoint_path=checkpoint_path,
        layer=layer,
        manifest_path=manifest_path,
        sample_pct=sample_pct,
    )
    iterator = generator()

    features_list = []
    for features in tqdm.tqdm(iterator, total=num_files):
        features_list.append(features)

    # Explicit clean up
    del iterator
    del generator
    gc.collect()
    torch.cuda.empty_cache()

    if flatten:
        return np.concatenate(features_list)

    return features_list
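

# Extract features and dump them with np.save, alongside a copy of the
# manifest in the same output directory so the dumped features stay traceable
# to the file list they were computed from.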
def get_and_dump_features(
    feature_type,
    checkpoint_path,
    layer,
    manifest_path,
    sample_pct,
    flatten,
    out_features_path,
):
    # Feature extraction
    features_batch = get_features(
        feature_type=feature_type,
        checkpoint_path=checkpoint_path,
        layer=layer,
        manifest_path=manifest_path,
        sample_pct=sample_pct,
        flatten=flatten,
    )

    # Save features
    out_dir_path = os.path.dirname(out_features_path)
    os.makedirs(out_dir_path, exist_ok=True)
    shutil.copyfile(
        manifest_path,
        os.path.join(out_dir_path, os.path.basename(manifest_path)),
    )
    np.save(out_features_path, features_batch)
    return features_batch
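

if __name__ == "__main__":
    # Minimal usage sketch: the checkpoint, manifest, and output paths below
    # are placeholders, and layer/sample_pct are illustrative values only.
    get_and_dump_features(
        feature_type="hubert",
        checkpoint_path="/path/to/hubert_checkpoint.pt",  # placeholder path
        layer=6,
        manifest_path="/path/to/train.tsv",  # fairseq-style manifest (placeholder)
        sample_pct=0.1,
        flatten=True,
        out_features_path="/path/to/out/features.npy",  # placeholder path
    )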