Spaces:
Runtime error
Runtime error
| from audiotools import AudioSignal | |
| import torch | |
| from pathlib import Path | |
| import argbind | |
| from tqdm import tqdm | |
| import random | |
| from typing import List | |
| from collections import defaultdict | |
def coarse2fine_infer(
    signal,
    model,
    vqvae,
    device,
    signal_window=3,
    signal_hop=1.5,
    max_excerpts=20,
):
    """Run coarse-to-fine inference on random excerpts of ``signal``.

    Splits the signal into overlapping windows, picks up to ``max_excerpts``
    of them at random, and for each window produces:

    - ``"reconstructed"``: straight VQ-VAE round-trip through ``model.to_signal``
    - ``"sampled"``: 12-step gumbel sampling from ``model.sample``
    - ``"argmax"``: single-step argmax decoding from ``model.sample``

    Args:
        signal: input audio signal (audiotools ``AudioSignal``-like; must
            support ``clone()`` and ``windows()`` — TODO confirm exact type).
        model: VampNet-like model exposing ``to``, ``to_signal``, ``sample``
            and ``n_conditioning_codebooks``.
        vqvae: codec exposing ``encode(audio_data, sample_rate) -> {"codes": ...}``.
        device: torch device to run inference on.
        signal_window: window length in seconds.
        signal_hop: hop length in seconds.
        max_excerpts: maximum number of random windows to process.

    Returns:
        dict mapping output kind -> list of CPU signals, one per excerpt.
    """
    output = defaultdict(list)

    # Split into fixed-length windows; drop the first window since it is
    # half zero-padded by the windowing routine.
    windows = list(signal.clone().windows(signal_window, signal_hop))
    windows = windows[1:]
    random.shuffle(windows)

    # Moving the model to the device is loop-invariant — do it once.
    model.to(device)

    for w in windows[:max_excerpts]:
        with torch.no_grad():
            w = w.to(device)
            # Encode the excerpt to discrete codes.
            z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]

            # FIX: append to the list (the original assigned, clobbering the
            # defaultdict list on every window and breaking the caller's
            # per-item iteration over output["reconstructed"]).
            output["reconstructed"].append(model.to_signal(z, vqvae).cpu())

            # Full mask everywhere except the conditioning codebooks.
            mask = torch.ones_like(z)
            mask[:, : model.n_conditioning_codebooks, :] = 0

            output["sampled"].append(model.sample(
                codec=vqvae,
                time_steps=z.shape[-1],
                sampling_steps=12,
                start_tokens=z,
                mask=mask,
                temperature=0.85,
                top_k=None,
                sample="gumbel",
                typical_filtering=True,
                return_signal=True
            ).cpu())

            output["argmax"].append(model.sample(
                codec=vqvae,
                time_steps=z.shape[-1],
                sampling_steps=1,
                start_tokens=z,
                mask=mask,
                temperature=1.0,
                top_k=None,
                sample="argmax",
                typical_filtering=True,
                return_signal=True
            ).cpu())

    return output
def main(
    sources=[
        "/data/spotdl/audio/val", "/data/spotdl/audio/test"
    ],
    audio_ext="mp3",
    exp_name="noise_mode",
    model_paths=[
        "runs/c2f-exp-03.22.23/ckpt/mask/best/vampnet/weights.pth",
        "runs/c2f-exp-03.22.23/ckpt/random/best/vampnet/weights.pth",
    ],
    model_keys=[
        "mask",
        "random",
    ],
    vqvae_path: str = "runs/codec-ckpt/codec.pth",
    device: str = "cuda",
    output_dir: str = ".",
):
    """Render coarse-to-fine samples from several VampNet checkpoints.

    For every audio file under each source directory, runs each loaded model
    through ``coarse2fine_infer`` and writes the resulting excerpts as wav
    files under ``<output_dir>/<exp_name>-samples/<file stem>/``.

    NOTE(review): the mutable list defaults are kept as-is because argbind
    binds defaults from the CLI; they are never mutated here.

    Args:
        sources: directories to scan recursively for audio files.
        audio_ext: audio file extension to glob for.
        exp_name: experiment name used to build the output directory.
        model_paths: checkpoint paths, parallel to ``model_keys``.
        model_keys: short names used in output filenames, parallel to
            ``model_paths``.
        vqvae_path: path to the codec checkpoint.
        device: torch device for inference.
        output_dir: root directory for rendered samples.
    """
    from vampnet.modules.transformer import VampNet
    from lac.model.lac import LAC
    from audiotools.post import audio_zip

    models = {
        k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
    }
    for model in models.values():
        model.eval()
    print(f"Loaded {len(models)} models.")

    vqvae = LAC.load(vqvae_path)
    vqvae.to(device)
    vqvae.eval()
    print("Loaded VQVAE.")

    output_dir = Path(output_dir) / f"{exp_name}-samples"

    for source in sources:
        print(f"Processing {source}...")
        source_files = list(Path(source).glob(f"**/*.{audio_ext}"))
        random.shuffle(source_files)
        for path in tqdm(source_files):
            out_dir = output_dir / path.stem

            # FIX: check for an existing output dir BEFORE creating it.
            # The original called mkdir first, so exists() was always true
            # and the skip branch fired for every file, silently skipping
            # all work after the first mkdir.
            if out_dir.exists():
                print(f"Skipping {path.stem} since {out_dir} already exists.")
                continue
            out_dir.mkdir(parents=True, exist_ok=True)

            sig = AudioSignal(path)
            # Match the codec's sample rate and normalize loudness;
            # audiotools transforms operate in place.
            sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)

            for model_key, model in models.items():
                out = coarse2fine_infer(sig, model, vqvae, device)
                for k, sig_list in out.items():
                    for i, s in enumerate(sig_list):
                        s.write(out_dir / f"{model_key}-{k}-{i}.wav")
if __name__ == "__main__":
    # Parse CLI arguments with argbind and execute main() inside their scope,
    # so bound defaults (sources, model paths, device, ...) are overridable
    # from the command line.
    args = argbind.parse_args()
    with argbind.scope(args):
        main()