|
""" |
|
https://alexander-stasiuk.medium.com/pytorch-weights-averaging-e2c0fa611a0c |
|
""" |
|
|
|
import os |
|
|
|
import torch |
|
|
|
from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS |
|
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN |
|
from Utility.storage_config import MODELS_DIR |
|
|
|
|
|
def load_net_toucan(path):
    """Load a ToucanTTS checkpoint from *path* on CPU.

    Returns a tuple of (instantiated ToucanTTS network, the checkpoint's
    stored default embedding tensor).
    """
    checkpoint = torch.load(path, map_location=torch.device("cpu"))
    network = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"])
    return network, checkpoint["default_emb"]
|
|
|
|
|
def load_net_bigvgan(path):
    """Load a HiFiGAN vocoder generator checkpoint from *path* on CPU.

    Returns (instantiated HiFiGAN generator, None) — the None keeps the
    return shape consistent with load_net_toucan, which also yields a
    default embedding.
    """
    checkpoint = torch.load(path, map_location=torch.device("cpu"))
    generator = HiFiGAN(weights=checkpoint["generator"])
    return generator, None
|
|
|
|
|
def get_n_recent_checkpoints_paths(checkpoint_dir, n=5):
    """Return the paths of the *n* most recent step checkpoints in *checkpoint_dir*.

    Files are expected to be named ``checkpoint_<step>.pt``; "most recent"
    means highest step number. Files that match the prefix/suffix but do not
    carry a numeric step (e.g. ``checkpoint_best.pt``) are skipped.

    Returns a list of absolute-ish paths (joined onto *checkpoint_dir*)
    sorted from newest to oldest, at most *n* long, or None when no step
    checkpoints are present.
    """
    print("selecting checkpoints...")
    steps = []
    for filename in os.listdir(checkpoint_dir):
        if filename.endswith(".pt") and filename.startswith("checkpoint_"):
            try:
                steps.append(int(filename.split(".")[0].split("_")[1]))
            except ValueError:
                # BUGFIX: int() raises ValueError, not RuntimeError, on a
                # non-numeric step (e.g. "checkpoint_best.pt") — the old
                # handler never fired and the function crashed instead of
                # skipping such files.
                pass
    if not steps:
        return None
    steps.sort(reverse=True)  # newest (highest step) first
    return [os.path.join(checkpoint_dir, "checkpoint_{}.pt".format(step)) for step in steps[:n]]
|
|
|
|
|
def average_checkpoints(list_of_checkpoint_paths, load_func):
    """Average the parameters of several checkpoints into one model.

    Each path in *list_of_checkpoint_paths* is loaded via
    ``load_func(path=...)``, which must return ``(model, default_embed)``.
    The parameters of all loaded models are averaged element-wise and
    written into the last loaded model, which is returned in eval mode
    together with the last checkpoint's default embedding.

    Returns None when the path list is None or empty.
    """
    if not list_of_checkpoint_paths:
        return None
    checkpoints_weights = {}
    model = None
    default_embed = None

    for path_to_checkpoint in list_of_checkpoint_paths:
        print("loading model {}".format(path_to_checkpoint))
        model, default_embed = load_func(path=path_to_checkpoint)
        checkpoints_weights[path_to_checkpoint] = dict(model.named_parameters())

    # Parameters of the last loaded model; the averaged values are written here.
    dict_params = dict(model.named_parameters())
    checkpoint_amount = len(checkpoints_weights)
    print("averaging...")
    for name in dict_params:
        summed = None
        for checkpoint_parameters in checkpoints_weights.values():
            param_data = checkpoint_parameters[name].data
            if summed is None:
                # BUGFIX: clone instead of aliasing — the old code did
                # `custom_params = ...data` and then `+=`, which mutated the
                # first checkpoint's tensors in place while accumulating.
                summed = param_data.clone()
            else:
                summed += param_data
        dict_params[name].data.copy_(summed / checkpoint_amount)
    model_dict = model.state_dict()
    model_dict.update(dict_params)
    model.load_state_dict(model_dict)
    model.eval()
    return model, default_embed
|
|
|
|
|
def save_model_for_use(model, name="", default_embed=None, dict_name="model"): |
|
print("saving model...") |
|
torch.save({dict_name: model.state_dict(), "default_emb": default_embed, "config": model.config}, name) |
|
print("...done!") |
|
|
|
|
|
def make_best_in_all():
    """Refresh the averaged "best.pt" checkpoint in every ToucanTTS model dir.

    Scans MODELS_DIR for directories whose name contains "ToucanTTS",
    averages their 3 most recent step checkpoints, and writes the result
    as "best.pt" next to them. Directories without step checkpoints are
    skipped.
    """
    for model_dir in os.listdir(MODELS_DIR):
        dir_path = os.path.join(MODELS_DIR, model_dir)
        if not os.path.isdir(dir_path):
            continue
        if "ToucanTTS" not in model_dir:
            continue
        checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=dir_path, n=3)
        if checkpoint_paths is None:
            continue
        averaged_model, default_embed = average_checkpoints(checkpoint_paths, load_func=load_net_toucan)
        save_model_for_use(model=averaged_model,
                           default_embed=default_embed,
                           name=os.path.join(dir_path, "best.pt"))
|
|
|
|
|
def count_parameters(net):
    """Return the total number of trainable (requires_grad) parameters in *net*."""
    total = 0
    for parameter in net.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build an averaged "best.pt" for every ToucanTTS
    # model directory under MODELS_DIR.
    make_best_in_all()
|
|