Hugo Flores Garcia committed · Commit 93b48cb
Parent(s): 128981d

more tweaks
Files changed:
- demo.py +22 -16
- scripts/exp/eval.py +17 -12
- scripts/utils/vamp_folder.py +116 -22
- vampnet/interface.py +0 -2
- vampnet/modules/base.py +10 -3
demo.py
CHANGED

@@ -210,25 +210,30 @@ with gr.Blocks() as demo:
 
     """)
     gr.Markdown("## Input Audio")
-    with gr.Column():
-        gr.Markdown("""
-        ## Mask Hints
-        - most of the original audio will be masked and replaced with audio generated by vampnet
-        - mask hints are used to guide vampnet to generate audio that sounds like the original
-        - the more hints you give, the more the generated audio will sound like the original
 
-        """)
     with gr.Column():
         gr.Markdown("""
         ### Tips
        - use the beat hint button so the output audio has the same beat structure as the input audio
-        - if you want
-        -
-            - decrease the periodic unmasking to anywhere from 2 to 8
+        - if you want more beat structure:
+            - enable beat hints
         - if you want a more "random" generation:
-        -
-            - increase the periodic unmasking to 16 or more
+            - increase the periodic unmasking to 12 or more
             - increase the temperatures!
+            - uncheck the beat hint button (or reduce the beat unmask duration)
+        - if you want the generated audio to sound like the original, but with a different beat structure:
+            - uncheck the beat hint button
+            - decrease the periodic unmasking to anywhere from 2 to 20
+            - slightly decrease the random intensity, to like .95
+
+
+        """)
+    with gr.Column():
+        gr.Markdown("""
+        ## Mask Hints
+        - most of the original audio will be masked and replaced with audio generated by vampnet
+        - mask hints are used to guide vampnet to generate audio that sounds like the original
+        - the more hints you give, the more the generated audio will sound like the original
 
     """)
 
@@ -243,7 +248,8 @@ with gr.Blocks() as demo:
         num_vamps = gr.Number(
             label="number of vamps. more vamps = longer generated audio",
             value=1,
-            precision=0
+            precision=0,
+            visible=False
         )
 
         manual_audio_upload = gr.File(
@@ -286,7 +292,7 @@ with gr.Blocks() as demo:
             minimum=0,
             maximum=64,
             step=1,
-            value=
+            value=9,
         )
 
 
@@ -326,8 +332,8 @@ with gr.Blocks() as demo:
         )
 
         use_beats = gr.Checkbox(
-            label="use beat hints",
-            value=
+            label="use beat hints (helps the output stick to the beat structure of the input)",
+            value=False
         )
 
         snap_to_beats = gr.Checkbox(
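A note on the demo changes: the `num_vamps` control is hidden with `visible=False` rather than deleted, so any callback that lists it among its inputs keeps working unchanged. A minimal standalone sketch of that Gradio pattern (a hypothetical toy app, not demo.py itself):

import gradio as gr

# sketch of a hidden-but-wired control, using only standard Gradio APIs
with gr.Blocks() as app:
    # precision=0 restricts the field to integers; visible=False hides it
    # from the UI while keeping it available as a callback input
    num_vamps = gr.Number(label="number of vamps", value=1, precision=0, visible=False)
    use_beats = gr.Checkbox(label="use beat hints", value=False)
    out = gr.Textbox(label="current settings")

    def show(n, beats):
        return f"num_vamps={int(n)}, use_beats={beats}"

    use_beats.change(show, inputs=[num_vamps, use_beats], outputs=out)

app.launch()
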
scripts/exp/eval.py
CHANGED

@@ -5,6 +5,7 @@ from functools import partial
 from frechet_audio_distance import FrechetAudioDistance
 import pandas
 import argbind
+import torch
 from tqdm import tqdm
 
 import audiotools
@@ -21,15 +22,16 @@ def eval(
     assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"
 
     # set up our metrics
-    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
-    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
+    # sisdr_loss = audiotools.metrics.distance.SISDRLoss()
+    # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
     mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
     frechet = FrechetAudioDistance(
         use_pca=False,
         use_activation=False,
-        verbose=True
+        verbose=True,
+        audio_load_worker=4,
     )
-
+    frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
 
     # figure out what conditions we have
     conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
@@ -44,7 +46,7 @@
     baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
 
     metrics = []
-    for condition in conditions:
+    for condition in tqdm(conditions):
         cond_dir = exp_dir / condition
         cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
 
@@ -68,14 +70,17 @@
         cond_sig.resample(baseline_sig.sample_rate)
         cond_sig.truncate_samples(baseline_sig.length)
 
-        #
-
-
-
-
+        # if our condition is inpainting, we need to trim the conditioning off
+        if "inpaint" in condition:
+            ctx_amt = float(condition.split("_")[-1])
+            ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
+            print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
+            cond_sig.trim(ctx_samples, ctx_samples)
+            baseline_sig.trim(ctx_samples, ctx_samples)
+
         return {
-            "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
-            "stft": stft_loss(baseline_sig, cond_sig).item(),
+            # "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
+            # "stft": stft_loss(baseline_sig, cond_sig).item(),
             "mel": mel_loss(baseline_sig, cond_sig).item(),
             "frechet": frechet_score,
             # "visqol": vsq,
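The new inpainting branch reads its configuration straight out of the condition name: directories are named `inpaint_<seconds>`, so the script parses the context duration from the suffix and trims that much audio from both ends of both signals, keeping the verbatim prefix/suffix context from inflating the scores. A quick worked example of the arithmetic (the 44.1 kHz sample rate here is an assumption for illustration):

# worked example of the trim computed above; 44.1 kHz is assumed
condition = "inpaint_0.5"
sample_rate = 44100

ctx_amt = float(condition.split("_")[-1])   # 0.5 seconds of context per side
ctx_samples = int(ctx_amt * sample_rate)    # 22050 samples
print(ctx_samples)                          # trimmed from each end of both signals

Trimming the baseline by the same amount keeps the two signals aligned sample-for-sample for the mel and Fréchet metrics.
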
scripts/utils/vamp_folder.py
CHANGED

@@ -6,7 +6,7 @@ import subprocess
 
 import argbind
 from tqdm import tqdm
-import
+import torch
 
 from vampnet.interface import Interface
 import audiotools as at
@@ -48,7 +48,6 @@ def coarse2fine_argmax(sig, interface):
     )
     return interface.to_signal(z)
 
-
 class CoarseCond:
 
     def __init__(self, num_codebooks, downsample_factor):
@@ -59,13 +58,12 @@
         n_conditioning_codebooks = interface.coarse.n_codebooks - self.num_codebooks
         zv = interface.coarse_vamp_v2(sig,
             n_conditioning_codebooks=n_conditioning_codebooks,
-            downsample_factor=self.downsample_factor
+            downsample_factor=self.downsample_factor,
         )
 
         zv = interface.coarse_to_fine(zv)
         return interface.to_signal(zv)
 
-
 def opus(sig, interface, bitrate=128):
     sig = interface.preprocess(sig)
 
@@ -97,8 +95,78 @@ def opus(sig, interface, bitrate=128):
     )
     return sig
 
+def token_noise(ratio=1.0):
+    def wrapper(sig, interface):
+        z = interface.encode(sig)
+        r = interface.coarse.invgamma(ratio).to(interface.device)
+        print(f'adding noise with ratio {ratio}')
+        z, mask = interface.coarse.add_noise(
+            z,
+            r,
+            noise_mode="random"
+        )
+        return interface.to_signal(z)
+    return wrapper
+
+def mask_ratio_1_step(ratio=1.0):
+    def wrapper(sig, interface):
+        r = interface.coarse.invgamma(ratio).to(interface.device)
+        intensity = 1-r
+
+        zv = interface.coarse_vamp_v2(
+            sig,
+            sample='argmax',
+            sampling_steps=1,
+            intensity=intensity
+        )
+
+        return interface.to_signal(zv)
+    return wrapper
+
+def num_sampling_steps(num_steps=1):
+    def wrapper(sig, interface):
+        zv = interface.coarse_vamp_v2(
+            sig,
+            downsample_factor=16,
+            sampling_steps=num_steps,
+        )
 
-
+        zv = interface.coarse_to_fine(zv)
+        return interface.to_signal(zv)
+    return wrapper
+
+def beat_mask(ctx_time):
+    def wrapper(sig, interface):
+        beat_mask = interface.make_beat_mask(
+            sig,
+            before_beat_s=0.0,
+            after_beat_s=ctx_time,
+            invert=True
+        )
+        zv = interface.coarse_vamp_v2(
+            sig,
+            ext_mask=beat_mask,
+        )
+
+        zv = interface.coarse_to_fine(zv)
+        return interface.to_signal(zv)
+    return wrapper
+
+def inpaint(ctx_time):
+    def wrapper(sig, interface):
+        zv = interface.coarse_vamp_v2(
+            sig,
+            prefix_dur_s=ctx_time,
+            suffix_dur_s=ctx_time,
+        )
+
+        zv = interface.coarse_to_fine(zv)
+        return interface.to_signal(zv)
+    return wrapper
+
+EXP_REGISTRY = {}
+
+EXP_REGISTRY["gen-compression"] = {
     "baseline": baseline,
     "reconstructed": reconstructed,
     "coarse2fine": coarse2fine,
@@ -119,23 +187,55 @@ COARSE_SAMPLE_CONDS ={
 
 }
 
-
+EXP_REGISTRY["opus-jazzpop"] = {
     f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
     for bitrate in [5620, 1875, 1250, 625]
 }
 
-
+EXP_REGISTRY["opus-spotdl"] = {
     f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
     for bitrate in [8036, 2296, 1148, 574]
 }
 
-
+EXP_REGISTRY["opus-baseline"] = {
+    f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
+    for bitrate in [8000, 12000, 16000]
+}
+
+EXP_REGISTRY["c2f"] = {
     "baseline": baseline,
     "reconstructed": reconstructed,
     "coarse2fine": coarse2fine,
     "coarse2fine_argmax": coarse2fine_argmax,
 }
 
+EXP_REGISTRY["token-noise"] = {
+    f"token_noise_{r}": token_noise(r) for r in [0.25, 0.5, 0.75, 1.0]
+}
+
+EXP_REGISTRY["mask-ratio"] = {
+    "codec": reconstructed,
+    **{f"mask_ratio_{r}": mask_ratio_1_step(r) for r in [0.25, 0.5, 0.75, 0.9]}
+}
+
+EXP_REGISTRY["sampling-steps"] = {
+    "codec": reconstructed,
+    **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 24, 36, 64, 72, 128]},
+}
+
+EXP_REGISTRY["baseline"] = {
+    "baseline": baseline,
+    "codec": reconstructed,
+}
+
+EXP_REGISTRY["musical-sampling"] = {
+    "baseline": baseline,
+    "codec": reconstructed,
+    **{f"downsample_{x}x": CoarseCond(4, downsample_factor=x) for x in [16, 32]},
+    **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
+    **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
+}
+
 @argbind.bind(without_prefix=True)
 def main(
     sources=[
@@ -162,14 +262,8 @@ def main(
         without_replacement=True,
     )
 
-    if exp_type
-        SAMPLE_CONDS =
-    elif exp_type == "opus-spotdl":
-        SAMPLE_CONDS = OPUS_SPOTDL_SAMPLE_CONDS
-    elif exp_type == "coarse":
-        SAMPLE_CONDS = COARSE_SAMPLE_CONDS
-    elif exp_type == "c2f":
-        SAMPLE_CONDS = C2F_SAMPLE_CONDS
+    if exp_type in EXP_REGISTRY:
+        SAMPLE_CONDS = EXP_REGISTRY[exp_type]
     else:
         raise ValueError(f"Unknown exp_type {exp_type}")
 
@@ -178,12 +272,12 @@
     random.shuffle(indices)
     for i in tqdm(indices):
         # if all our files are already there, skip
-
-
-
-
-
-
+        done = []
+        for name in SAMPLE_CONDS:
+            o_dir = Path(output_dir) / name
+            done.append((o_dir / f"{i}.wav").exists())
+        if all(done):
+            continue
 
         sig = dataset[i]["signal"]
         results = {
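Structurally, this file swaps a hard-coded `if`/`elif` ladder for `EXP_REGISTRY`: each experiment name maps to a dict of named conditions, and each condition is a `(sig, interface) -> signal` callable built by a closure factory, so `main` reduces to a dictionary lookup. One caveat worth flagging: the `opus_{bitrate}` comprehensions build lambdas that close over the comprehension variable, and under Python's late-binding rules every lambda in such a dict sees only the final bitrate in its list; binding the value as a default argument avoids that. A toy sketch of the registry pattern with the safe binding (the `gain` factory is a hypothetical stand-in for `token_noise`, `inpaint`, and friends):

# toy registry-dispatch sketch; `gain` stands in for the real condition factories
EXP_REGISTRY = {}

def gain(db):
    # closure factory: each call captures its own `db`
    def wrapper(sig, interface=None):
        return sig * (10 ** (db / 20))
    return wrapper

EXP_REGISTRY["loudness"] = {
    "baseline": gain(0.0),
    # `db=db` binds the current value; a bare `lambda` here would late-bind
    # and every entry would see only the last db in the list
    **{f"gain_{db}db": (lambda sig, interface=None, db=db: sig * (10 ** (db / 20)))
       for db in [-6, -3, 3]},
}

exp_type = "loudness"
if exp_type in EXP_REGISTRY:
    SAMPLE_CONDS = EXP_REGISTRY[exp_type]
else:
    raise ValueError(f"Unknown exp_type {exp_type}")

for name, cond in SAMPLE_CONDS.items():
    print(name, round(cond(1.0), 3))
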
vampnet/interface.py
CHANGED

@@ -183,10 +183,8 @@ class Interface:
                 num_steps = mask[_slice[0]:_slice[1]].shape[0]
                 _m = torch.ones(num_steps, device=self.device)
                 _m = torch.nn.functional.dropout(_m, p=dropout)
-                print(_m)
 
                 mask[_slice[0]:_slice[1]] = _m
-                print(mask)
 
         if mask_downbeats:
             for downbeat_idx in downbeats_z:
vampnet/modules/base.py
CHANGED

@@ -42,6 +42,7 @@ class VampBase(at.ml.BaseModel):
         n_suffix: Optional[torch.Tensor] = None,
         downsample_factor: Optional[int] = None,
         n_conditioning_codebooks: Optional[int] = None,
+        noise_mode: str = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
 
@@ -89,13 +90,14 @@ class VampBase(at.ml.BaseModel):
         if random_x is None:
             random_x = torch.randint_like(x, 0, self.vocab_size)
 
-        if self.noise_mode
+        noise_mode = noise_mode if noise_mode is not None else self.noise_mode
+        if noise_mode == "mask":
             random_x = torch.full_like(x, self.mask_token)
-        elif
+        elif noise_mode == "random":
             if random_x is None:
                 random_x = torch.randint_like(x, 0, self.vocab_size)
         else:
-            raise ValueError(f"invalid noise mode {
+            raise ValueError(f"invalid noise mode {noise_mode}")
 
         # add the external mask if we were given one
         if ext_mask is not None:
@@ -132,6 +134,11 @@ class VampBase(at.ml.BaseModel):
     def gamma(self, r):
         return (r * torch.pi / 2).cos()
 
+    def invgamma(self, y):
+        if not torch.is_tensor(y):
+            y = torch.tensor(y)[None]
+        return 2 * y.acos() / torch.pi
+
     def r_embed(self, r, max_positions=10000):
         """ """
         assert hasattr(self, "r_cond_dim"), "must set r_cond_dim before calling r_embed"