Spaces:
Runtime error
Runtime error
| import torch | |
def get_synthesizer(pth_path, device=torch.device("cpu")):
    """Load a synthesizer checkpoint and build the matching inference model.

    The checkpoint is expected to be a mapping with at least the keys
    ``"config"`` (a mutable sequence of positional constructor arguments)
    and ``"weight"`` (a state dict containing ``"emb_g.weight"``).  The
    optional keys ``"f0"`` (default ``1``) and ``"version"`` (default
    ``"v1"``) select which synthesizer class is instantiated.

    Args:
        pth_path: Path (or file-like object) passed to ``torch.load``.
        device: Device the constructed model is moved to.  The checkpoint
            itself is always deserialized onto the CPU first.

    Returns:
        A ``(net_g, cpt)`` tuple: the model in eval mode with float32
        parameters, and the loaded checkpoint mapping.  Note that
        ``cpt["config"][-3]`` has been overwritten in place (see below).

    Raises:
        ValueError: If the checkpoint's ``"version"`` is neither ``"v1"``
            nor ``"v2"``.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    cpt = torch.load(pth_path, map_location=torch.device("cpu"))

    if_f0 = cpt.get("f0", 1)
    version = cpt.get("version", "v1")
    if version not in ("v1", "v2"):
        # Without this guard no branch below would bind ``net_g`` and the
        # function would die later with an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported checkpoint version {version!r}; expected 'v1' or 'v2'"
        )

    # Imported lazily, and only once the checkpoint is known to be usable.
    from infer.lib.infer_pack.models import (
        SynthesizerTrnMs256NSFsid,
        SynthesizerTrnMs256NSFsid_nono,
        SynthesizerTrnMs768NSFsid,
        SynthesizerTrnMs768NSFsid_nono,
    )

    # Replace the third-from-last config entry with the number of rows of the
    # ``emb_g`` embedding table actually stored in the weights.
    # NOTE(review): presumably this is the speaker count -- confirm against
    # the synthesizer constructor signature.
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]

    if version == "v1":
        f0_cls, no_f0_cls = SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono
    else:
        f0_cls, no_f0_cls = SynthesizerTrnMs768NSFsid, SynthesizerTrnMs768NSFsid_nono

    if if_f0 == 1:
        net_g = f0_cls(*cpt["config"], is_half=False)
    else:
        net_g = no_f0_cls(*cpt["config"])

    # NOTE(review): enc_q looks like a sub-module that is not needed for
    # inference -- confirm.  strict=False below tolerates checkpoint keys that
    # no longer match the model after this deletion.
    del net_g.enc_q
    net_g.load_state_dict(cpt["weight"], strict=False)
    net_g = net_g.float()
    # Assumes net_g is a torch.nn.Module, whose .to() moves parameters in
    # place, so discarding the return value is fine -- TODO confirm.
    net_g.eval().to(device)
    net_g.remove_weight_norm()
    return net_g, cpt