# sovits-test / svc_train_retrieval.py
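"""Build per-speaker FAISS indexes for feature retrieval.

For every speaker found under data_svc/waves-16k, this script trains an
IVF-Flat index over the speaker's hubert and whisper feature matrices and
writes it to data_svc/indexes/<speaker>/. Speakers with very large feature
sets are first compressed to cluster centroids with MiniBatchKMeans.
"""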
import argparse
import logging
import multiprocessing
from functools import partial
from pathlib import Path

import faiss
from feature_retrieval import (
train_index,
FaissIVFFlatTrainableFeatureIndexBuilder,
OnConditionFeatureTransform,
MinibatchKmeansFeatureTransform,
DummyFeatureTransform,
)

logger = logging.getLogger(__name__)


def get_speaker_list(base_path: Path) -> list[str]:
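    """Return the names of all speaker directories under <base_path>/waves-16k."""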
speakers_path = base_path / "waves-16k"
if not speakers_path.exists():
        raise FileNotFoundError(f"path {speakers_path} does not exist")
    return [speaker_dir.name for speaker_dir in speakers_path.iterdir() if speaker_dir.is_dir()]


def create_indexes_path(base_path: Path) -> Path:
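    """Create the <base_path>/indexes folder if needed and return its path."""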
indexes_path = base_path / "indexes"
logger.info("create indexes folder %s", indexes_path)
indexes_path.mkdir(exist_ok=True)
    return indexes_path


def create_index(
feature_name: str,
prefix: str,
speaker: str,
base_path: Path,
indexes_path: Path,
compress_features_after: int,
n_clusters: int,
n_parallel: int,
train_batch_size: int = 8192,
) -> None:
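    """Train a FAISS index over one speaker's features and write it to disk."""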
features_path = base_path / feature_name / speaker
if not features_path.exists():
        raise ValueError(f"features not found at path {features_path}")
index_path = indexes_path / speaker
index_path.mkdir(exist_ok=True)
index_filename = f"{prefix}{feature_name}.index"
index_filepath = index_path / index_filename
    logger.debug("index will be saved to %s", index_filepath)
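    # IVF-Flat index with L2 distance, trained in batches of train_batch_size vectors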
builder = FaissIVFFlatTrainableFeatureIndexBuilder(train_batch_size, distance=faiss.METRIC_L2)
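    # Compress large feature sets to n_clusters centroids with MiniBatchKMeans;
    # smaller speakers keep their raw feature vectors unchanged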
transform = OnConditionFeatureTransform(
condition=lambda matrix: matrix.shape[0] > compress_features_after,
on_condition=MinibatchKmeansFeatureTransform(n_clusters, n_parallel),
otherwise=DummyFeatureTransform()
)
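    # Build the index from the (possibly compressed) features and write it to index_filepath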
    train_index(features_path, index_filepath, builder, transform)


def main() -> None:
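    """Parse CLI arguments and build hubert and whisper indexes for each speaker."""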
    arg_parser = argparse.ArgumentParser("create faiss indexes for feature retrieval")
arg_parser.add_argument("--debug", action="store_true")
arg_parser.add_argument("--prefix", default='', help="add prefix to index filename")
    arg_parser.add_argument("--speakers", nargs="+",
                            help="speaker names to build indexes for; defaults to all speakers found in data_svc")
arg_parser.add_argument("--compress-features-after", type=int, default=200_000,
help="If the number of features is greater than the value compress "
"feature vectors using MiniBatchKMeans.")
arg_parser.add_argument("--n-clusters", type=int, default=10_000,
help="Number of centroids to which features will be compressed")
arg_parser.add_argument("--n-parallel", type=int, default=multiprocessing.cpu_count()-1,
help="Nuber of parallel job of MinibatchKmeans. Default is cpus-1")
args = arg_parser.parse_args()
if args.debug:
logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
base_path = Path(".").absolute() / "data_svc"
if args.speakers:
speakers = args.speakers
else:
speakers = get_speaker_list(base_path)
logger.info("got %s speakers: %s", len(speakers), speakers)
indexes_path = create_indexes_path(base_path)
create_index_func = partial(
create_index,
prefix=args.prefix,
base_path=base_path,
indexes_path=indexes_path,
compress_features_after=args.compress_features_after,
n_clusters=args.n_clusters,
n_parallel=args.n_parallel,
)
for speaker in speakers:
logger.info("create hubert index for speaker %s", speaker)
create_index_func(feature_name="hubert", speaker=speaker)
logger.info("create whisper index for speaker %s", speaker)
create_index_func(feature_name="whisper", speaker=speaker)
logger.info("done!")
if __name__ == '__main__':
main()
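
# Example invocations (a usage sketch; the speaker names below are placeholders,
# and the data_svc folder is expected in the current working directory):
#
#   python svc_train_retrieval.py                       # index every speaker in data_svc
#   python svc_train_retrieval.py --debug --prefix my_ --speakers speaker0 speaker1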