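"""
Preprocess a transcription list into TTS training data: normalize and g2p each
`utt|spk|language|text` line, drop duplicate or missing audio files, split the
result into train/validation lists per speaker, and write the speaker map back
into the JSON config.
"""
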
import argparse
import json
from collections import defaultdict
from pathlib import Path
from random import sample
from typing import Optional

from tqdm import tqdm

from config import get_config
from style_bert_vits2.logging import logger
from style_bert_vits2.nlp import clean_text
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker
from style_bert_vits2.nlp.japanese.user_dict import update_dict
from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT


# Start the pyopenjtalk worker process first, so that the user dictionary
# applied below is also visible to the worker.
pyopenjtalk_worker.initialize_worker()

# Apply the user dictionary to pyopenjtalk.
update_dict()

preprocess_text_config = get_config().preprocess_text_config


def count_lines(file_path: Path) -> int:
    with file_path.open("r", encoding="utf-8") as file:
        return sum(1 for _ in file)


def write_error_log(error_log_path: Path, line: str, error: Exception) -> None:
    with error_log_path.open("a", encoding="utf-8") as error_log:
        error_log.write(f"{line.strip()}\n{error}\n\n")


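# A minimal example of the 4-field input line that `process_line` expects
# (the path and speaker name here are illustrative):
#
#   Data/MyModel/wavs/utt_0001.wav|MySpeaker|JP|こんにちは、元気ですか？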
def process_line(
    line: str,
    transcription_path: Path,
    correct_path: bool,
    use_jp_extra: bool,
    yomi_error: str,
) -> str:
    """Clean one `utt|spk|language|text` line and return it with phones, tones and word2ph appended."""
    splitted_line = line.strip().split("|")
    if len(splitted_line) != 4:
        raise ValueError(f"Invalid line format: {line.strip()}")
    utt, spk, language, text = splitted_line
    norm_text, phones, tones, word2ph = clean_text(
        text=text,
        language=language,
        use_jp_extra=use_jp_extra,
        raise_yomi_error=(yomi_error != "use"),
    )
    if correct_path:
        utt = str(transcription_path.parent / "wavs" / utt)

    return "{}|{}|{}|{}|{}|{}|{}\n".format(
        utt,
        spk,
        language,
        norm_text,
        " ".join(phones),
        " ".join([str(i) for i in tones]),
        " ".join([str(i) for i in word2ph]),
    )


def preprocess(
    transcription_path: Path,
    cleaned_path: Optional[Path],
    train_path: Path,
    val_path: Path,
    config_path: Path,
    val_per_lang: int,
    max_val_total: int,
    use_jp_extra: bool,
    yomi_error: str,
    correct_path: bool,
) -> None:
    assert yomi_error in ["raise", "skip", "use"]
    if cleaned_path is None:
        cleaned_path = transcription_path.with_name(
            transcription_path.name + ".cleaned"
        )

    # Start from a fresh error log on every run.
    error_log_path = transcription_path.parent / "text_error.log"
    if error_log_path.exists():
        error_log_path.unlink()
    error_count = 0

    total_lines = count_lines(transcription_path)

    with (
        transcription_path.open("r", encoding="utf-8") as trans_file,
        cleaned_path.open("w", encoding="utf-8") as out_file,
    ):
        for line in tqdm(trans_file, file=SAFE_STDOUT, total=total_lines):
            try:
                processed_line = process_line(
                    line,
                    transcription_path,
                    correct_path,
                    use_jp_extra,
                    yomi_error,
                )
                out_file.write(processed_line)
            except Exception as e:
                logger.error(f"An error occurred at line:\n{line.strip()}\n{e}")
                write_error_log(error_log_path, line, e)
                error_count += 1

    # From here on, read back from the cleaned file.
    transcription_path = cleaned_path

    spk_utt_map: dict[str, list[str]] = defaultdict(list)
    spk_id_map: dict[str, int] = {}
    current_sid: int = 0

    # De-duplicate audio paths, skip missing files, and group lines by speaker.
    with transcription_path.open("r", encoding="utf-8") as f:
        audio_paths: set[str] = set()
        count_same = 0
        count_not_found = 0
        for line in f:
            utt, spk = line.strip().split("|")[:2]
            if utt in audio_paths:
                logger.warning(f"Same audio file appears multiple times: {utt}")
                count_same += 1
                continue
            if not Path(utt).is_file():
                logger.warning(f"Audio not found: {utt}")
                count_not_found += 1
                continue
            audio_paths.add(utt)
            spk_utt_map[spk].append(line)

            # Assign a new contiguous speaker ID the first time a speaker appears.
            if spk not in spk_id_map:
                spk_id_map[spk] = current_sid
                current_sid += 1
    if count_same > 0 or count_not_found > 0:
        logger.warning(
            f"Total repeated audios: {count_same}, total audios not found: {count_not_found}"
        )

    train_list: list[str] = []
    val_list: list[str] = []

    # Sample `val_per_lang` validation utterances per speaker. (Despite the
    # name, this is per speaker, not per language.)
    for spk, utts in spk_utt_map.items():
        if val_per_lang == 0:
            train_list.extend(utts)
            continue
        # min() keeps random.sample() from raising when a speaker has fewer
        # than `val_per_lang` utterances.
        val_indices = set(sample(range(len(utts)), min(val_per_lang, len(utts))))
        for index, utt in enumerate(utts):
            if index in val_indices:
                val_list.append(utt)
            else:
                train_list.append(utt)

    # Cap the validation set at `max_val_total` and move the overflow to train.
    if len(val_list) > max_val_total:
        extra_val = val_list[max_val_total:]
        val_list = val_list[:max_val_total]
        train_list.extend(extra_val)

    with train_path.open("w", encoding="utf-8") as f:
        for line in train_list:
            f.write(line)

    with val_path.open("w", encoding="utf-8") as f:
        for line in val_list:
            f.write(line)

    # Write the speaker map and speaker count back into the JSON config.
    with config_path.open("r", encoding="utf-8") as f:
        json_config = json.load(f)

    json_config["data"]["spk2id"] = spk_id_map
    json_config["data"]["n_speakers"] = len(spk_id_map)

    with config_path.open("w", encoding="utf-8") as f:
        json.dump(json_config, f, indent=2, ensure_ascii=False)

    if error_count > 0:
        if yomi_error == "skip":
            logger.warning(
                f"An error occurred in {error_count} lines. Proceeding with the error-free lines only. Please check {error_log_path} for details."
            )
        else:
            # yomi_error is "raise" or "use"; even with "use", errors other
            # than reading errors can still have occurred above.
            logger.error(
                f"An error occurred in {error_count} lines. Please check {error_log_path} for details."
            )
            raise Exception(
                f"An error occurred in {error_count} lines. Please check `Data/your_model_name/text_error.log` for details."
            )
    else:
        logger.info(
            "Training set and validation set generation from texts is complete!"
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transcription-path", default=preprocess_text_config.transcription_path
    )
    parser.add_argument("--cleaned-path", default=preprocess_text_config.cleaned_path)
    parser.add_argument("--train-path", default=preprocess_text_config.train_path)
    parser.add_argument("--val-path", default=preprocess_text_config.val_path)
    parser.add_argument("--config-path", default=preprocess_text_config.config_path)
    parser.add_argument(
        "--val-per-lang",
        default=preprocess_text_config.val_per_lang,
        help="Number of validation samples per SPEAKER, not per language (the name is kept for compatibility with the original code).",
    )
    parser.add_argument("--max-val-total", default=preprocess_text_config.max_val_total)
    parser.add_argument("--use_jp_extra", action="store_true")
    parser.add_argument("--yomi_error", default="raise")
    parser.add_argument("--correct_path", action="store_true")

    args = parser.parse_args()

    transcription_path = Path(args.transcription_path)
    cleaned_path = Path(args.cleaned_path) if args.cleaned_path else None
    train_path = Path(args.train_path)
    val_path = Path(args.val_path)
    config_path = Path(args.config_path)
    val_per_lang = int(args.val_per_lang)
    max_val_total = int(args.max_val_total)
    use_jp_extra: bool = args.use_jp_extra
    yomi_error: str = args.yomi_error
    correct_path: bool = args.correct_path

    preprocess(
        transcription_path=transcription_path,
        cleaned_path=cleaned_path,
        train_path=train_path,
        val_path=val_path,
        config_path=config_path,
        val_per_lang=val_per_lang,
        max_val_total=max_val_total,
        use_jp_extra=use_jp_extra,
        yomi_error=yomi_error,
        correct_path=correct_path,
    )
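
# Example invocation, assuming this script is saved as preprocess_text.py and
# the dataset lives under Data/MyModel/ (both names are illustrative; every
# flag defaults to the preprocess_text section of the config):
#
#   python preprocess_text.py --transcription-path Data/MyModel/esd.list \
#       --use_jp_extra --yomi_error skip --correct_path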