# Spaces: Runtime error
from audiotools import AudioSignal | |
import torch | |
from pathlib import Path | |
import argbind | |
from tqdm import tqdm | |
import random | |
from typing import List | |
from collections import defaultdict | |
def coarse2fine_infer(
    signal,
    model,
    vqvae,
    device,
    sampling_steps: int = 12,
    temperature: float = 0.85,
):
    """Run coarse-to-fine inference on one excerpt and return three renderings.

    The excerpt is tokenized with ``vqvae``. Every codebook after the first
    ``model.n_conditioning_codebooks`` is masked and regenerated by ``model``.

    Args:
        signal: AudioSignal to process. ``signal.to(device)`` is called on it;
            presumably this moves the caller's object as well (audiotools
            moves data in place) -- confirm.
        model: VampNet-style model exposing ``to_signal``, ``sample`` and
            ``n_conditioning_codebooks``. It is moved to ``device``.
        vqvae: Codec used both to encode ``signal`` and to decode tokens.
        device: Torch device to run on.
        sampling_steps: Number of iterative sampling steps used for the
            ``"sampled"`` output (default 12, as before).
        temperature: Sampling temperature for the ``"sampled"`` output
            (default 0.85, as before).

    Returns:
        dict of CPU signals:
            ``"reconstructed"``: codec decode of the input tokens.
            ``"sampled"``: gumbel sampling with typical filtering.
            ``"argmax"``: single-step greedy decoding.
    """
    output = {}
    signal = signal.to(device)
    model.to(device)

    # Inference only. model.eval() does not turn off autograd. Without
    # no_grad, the whole encode/sample/decode graph would be recorded,
    # wasting memory and leaving outputs that require grad.
    with torch.no_grad():
        # Codes are assumed to be (batch, n_codebooks, time) -- TODO confirm.
        # The mask slice below fixes rank 3; time_steps reads the last axis.
        z = vqvae.encode(signal.audio_data, signal.sample_rate)["codes"]
        output["reconstructed"] = model.to_signal(z, vqvae).cpu()

        # Full mask except the conditioning codebooks. Presumably 1 marks a
        # token for regeneration in model.sample -- confirm.
        mask = torch.ones_like(z)
        mask[:, : model.n_conditioning_codebooks, :] = 0

        output["sampled"] = model.sample(
            codec=vqvae,
            time_steps=z.shape[-1],
            sampling_steps=sampling_steps,
            start_tokens=z,
            mask=mask,
            temperature=temperature,
            top_k=None,
            sample="gumbel",
            typical_filtering=True,
            return_signal=True,
        ).cpu()

        output["argmax"] = model.sample(
            codec=vqvae,
            time_steps=z.shape[-1],
            sampling_steps=1,
            start_tokens=z,
            mask=mask,
            temperature=1.0,
            top_k=None,
            sample="argmax",
            typical_filtering=True,
            return_signal=True,
        ).cpu()

    return output
@argbind.bind(without_prefix=True)
def main(
    sources: List[str] = [
        "/data/spotdl/audio/val", "/data/spotdl/audio/test"
    ],
    exp_name: str = "noise_mode",
    model_paths: List[str] = [
        "runs/c2f-exp-03.22.23/ckpt/mask/epoch=400/vampnet/weights.pth",
        "runs/c2f-exp-03.22.23/ckpt/random/epoch=400/vampnet/weights.pth",
    ],
    model_keys: List[str] = [
        "mask",
        "random",
    ],
    vqvae_path: str = "runs/codec-ckpt/codec.pth",
    device: str = "cuda",
    output_dir: str = ".",
    max_excerpts: int = 5000,
    duration: float = 3.0,
):
    """Render coarse-to-fine samples from several VampNet checkpoints.

    Loads one VampNet per (key, path) pair and a LAC codec. It then draws
    ``max_excerpts`` excerpts of ``duration`` seconds from ``sources``. For
    every excerpt and every model, it writes each rendering returned by
    ``coarse2fine_infer`` to
    ``<output_dir>/<exp_name>-samples/<model_key>/<file stem>/<name>.wav``.

    The function is registered with argbind, so the scope opened by the
    script entry point (CLI flags / YAML config) supplies its arguments. The
    list defaults are kept because argbind reads defaults from the signature.
    They are only iterated here, never mutated.

    Raises:
        ValueError: if ``model_keys`` and ``model_paths`` differ in length.
    """
    # Project imports are deferred until the script actually runs.
    from vampnet.modules.transformer import VampNet
    from lac.model.lac import LAC
    from audiotools.data.datasets import AudioLoader, AudioDataset

    # zip() would silently drop the unmatched models; fail loudly instead.
    if len(model_keys) != len(model_paths):
        raise ValueError(
            f"model_keys ({len(model_keys)}) and model_paths "
            f"({len(model_paths)}) must have the same length"
        )

    models = {
        key: VampNet.load(path) for key, path in zip(model_keys, model_paths)
    }
    for model in models.values():
        model.eval()
    print(f"Loaded {len(models)} models.")

    vqvae = LAC.load(vqvae_path)
    vqvae.to(device)
    vqvae.eval()
    print("Loaded VQVAE.")

    output_dir = Path(output_dir) / f"{exp_name}-samples"

    loader = AudioLoader(sources=sources)
    dataset = AudioDataset(
        loader,
        sample_rate=vqvae.sample_rate,
        duration=duration,
        n_examples=max_excerpts,
        without_replacement=True,
    )

    # Pure inference: keep autograd off for the whole sampling loop.
    with torch.no_grad():
        for i in tqdm(range(max_excerpts)):
            sig = dataset[i]["signal"]
            sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)
            for model_key, model in models.items():
                out = coarse2fine_infer(sig, model, vqvae, device)
                out_dir = output_dir / model_key / Path(sig.path_to_file).stem
                out_dir.mkdir(parents=True, exist_ok=True)
                for name, rendered in out.items():
                    rendered.write(out_dir / f"{name}.wav")
if __name__ == "__main__":
    # Parse CLI flags / config into argbind's format, then make that the
    # active scope so bound functions receive their arguments from it.
    cli_config = argbind.parse_args()
    with argbind.scope(cli_config):
        main()