Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files
rvc/lib/process/model_fusion.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def extract(ckpt):
|
| 6 |
+
model = ckpt["model"]
|
| 7 |
+
opt = OrderedDict()
|
| 8 |
+
opt["weight"] = {key: value for key, value in model.items() if "enc_q" not in key}
|
| 9 |
+
return opt
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def model_fusion(model_name, pth_path_1, pth_path_2):
|
| 13 |
+
ckpt1 = torch.load(pth_path_1, map_location="cpu")
|
| 14 |
+
ckpt2 = torch.load(pth_path_2, map_location="cpu")
|
| 15 |
+
if "model" in ckpt1:
|
| 16 |
+
ckpt1 = extract(ckpt1)
|
| 17 |
+
else:
|
| 18 |
+
ckpt1 = ckpt1["weight"]
|
| 19 |
+
if "model" in ckpt2:
|
| 20 |
+
ckpt2 = extract(ckpt2)
|
| 21 |
+
else:
|
| 22 |
+
ckpt2 = ckpt2["weight"]
|
| 23 |
+
if sorted(ckpt1.keys()) != sorted(ckpt2.keys()):
|
| 24 |
+
return "Fail to merge the models. The model architectures are not the same."
|
| 25 |
+
opt = OrderedDict(
|
| 26 |
+
weight={
|
| 27 |
+
key: 1 * value.float() + (1 - 1) * ckpt2[key].float()
|
| 28 |
+
for key, value in ckpt1.items()
|
| 29 |
+
}
|
| 30 |
+
)
|
| 31 |
+
opt["info"] = f"Model fusion of {pth_path_1} and {pth_path_2}"
|
| 32 |
+
torch.save(opt, f"logs/{model_name}.pth")
|
| 33 |
+
print(f"Model fusion of {pth_path_1} and {pth_path_2} is done.")
|
rvc/lib/process/model_information.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def model_information(path):
|
| 4 |
+
model_data = torch.load(path, map_location="cpu")
|
| 5 |
+
|
| 6 |
+
print(f"Loaded model from {path}")
|
| 7 |
+
|
| 8 |
+
data = model_data
|
| 9 |
+
|
| 10 |
+
epochs = data.get("info", "None")
|
| 11 |
+
sr = data.get("sr", "None")
|
| 12 |
+
f0 = data.get("f0", "None")
|
| 13 |
+
version = data.get("version", "None")
|
| 14 |
+
|
| 15 |
+
return(f"Epochs: {epochs}\nSampling rate: {sr}\nPitch guidance: {f0}\nVersion: {version}")
|