Spaces:
Sleeping
Sleeping
File size: 2,108 Bytes
a559a3b 64a6fed a559a3b 6c005f0 a559a3b 6c005f0 64a6fed a559a3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import hydra
from omegaconf import DictConfig
import torch
from remfx.models import RemFXChainInference
import torchaudio
@hydra.main(
version_base=None,
config_path="../cfg",
config_name="config.yaml",
)
def main(cfg: DictConfig):
print("Loading models...")
models = {}
for effect in cfg.ckpts:
model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
ckpt_path = cfg.ckpts[effect].ckpt_path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
model.load_state_dict(state_dict)
model.to(device)
models[effect] = model
classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
ckpt_path = cfg.classifier_ckpt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
classifier.load_state_dict(state_dict)
classifier.to(device)
inference_model = RemFXChainInference(
models,
sample_rate=cfg.sample_rate,
num_bins=cfg.num_bins,
effect_order=cfg.inference_effects_ordering,
classifier=classifier,
shuffle_effect_order=cfg.inference_effects_shuffle,
use_all_effect_models=cfg.inference_use_all_effect_models,
)
audio_file = cfg.audio_input
print("Loading", audio_file)
audio, sr = torchaudio.load(audio_file)
# Resample
audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
# Convert to mono
audio = audio.mean(0, keepdim=True)
# Add dimension for batch
audio = audio.unsqueeze(0)
audio = audio.to(device)
batch = [audio, audio, None, None]
_, y = inference_model(batch, 0, verbose=True)
y = y.cpu()
if "output_path" in cfg:
output_path = cfg.output_path
else:
output_path = "./output.wav"
print("Saving output to", output_path)
torchaudio.save(output_path, y[0], sample_rate=cfg.sample_rate)
if __name__ == "__main__":
main()
|