diff --git a/.gitignore b/.gitignore
index 685d6448d5a197ed8a434f72d6f2279552613c91..81d14bc405b5624e69bceda90e746519ff9930a5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -182,3 +182,4 @@ audiotools/
 descript-audio-codec/
 # *.pth
 .git-old
+conf/generated/*
diff --git a/README.md b/README.md
index 687fb086b0db5ec747e42728d8d25be07f51e7cb..231407a1cc1ca55dc8ea22d351de801f7dd069f4 100644
--- a/README.md
+++ b/README.md
@@ -7,12 +7,14 @@ sdk: gradio
 sdk_version: 3.36.1
 app_file: app.py
 pinned: false
-duplicated_from: hugggof/vampnet
 ---
 
 # VampNet
 
-This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec.
+This repository contains recipes for training generative music models on top of the Descript Audio Codec.
+
+## Try `unloop`
+You can try VampNet in `unloop`, a co-creative looper: https://github.com/hugofloresgarcia/unloop
 
 # Setting up
 
@@ -35,7 +37,7 @@ Config files are stored in the `conf/` folder.
 ### Licensing for Pretrained Models: 
 The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml).
 
-Download the pretrained models from [this link](https://zenodo.org/record/8136545). Then, extract the models to the `models/` folder. 
+Download the pretrained models from [this link](https://zenodo.org/record/8136629). Then, extract the models to the `models/` folder. 
 
 
 # Usage
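
For reference, the download-and-extract step above can be scripted. A minimal sketch, assuming the weights ship as a tarball; the archive filename below is a placeholder, not confirmed against the Zenodo record:

```python
# Minimal sketch: fetch pretrained weights and extract them into ./models.
# The archive name "vampnet_models.tar.gz" is hypothetical; check the Zenodo
# record (8136629) for the actual filename(s).
import tarfile
import urllib.request
from pathlib import Path

url = "https://zenodo.org/record/8136629/files/vampnet_models.tar.gz"
archive = Path("vampnet_models.tar.gz")
urllib.request.urlretrieve(url, archive)

Path("models").mkdir(exist_ok=True)
with tarfile.open(archive) as tf:
    tf.extractall("models/")
```
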
diff --git a/app.py b/app.py
index dfa3ff80e64cacf3e395fe74855f2a3f7370156f..de30215d0e4c43a5731ec014ae27412cd9ee80fa 100644
--- a/app.py
+++ b/app.py
@@ -124,7 +124,7 @@ def _vamp(data, return_mask=False):
     )
 
     if use_coarse2fine: 
-        zv = interface.coarse_to_fine(zv, temperature=data[temp])
+        zv = interface.coarse_to_fine(zv, temperature=data[temp], mask=mask)
 
     sig = interface.to_signal(zv).cpu()
     print("done")
@@ -407,7 +407,8 @@ with gr.Blocks() as demo:
 
             use_coarse2fine = gr.Checkbox(
                 label="use coarse2fine",
-                value=True
+                value=True, 
+                visible=False
             )
 
             num_steps = gr.Slider(
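
The `_vamp` change above threads the vamp mask into the coarse-to-fine pass. A minimal usage sketch with setup abbreviated; `interface`, `zv`, `mask`, and `temperature` are assumed to be prepared as in the surrounding (unchanged) app.py code, and the intent attributed to the mask is an inference, not stated in this diff:

```python
# Sketch of the updated tail of _vamp(); names mirror app.py above.
if use_coarse2fine:
    # the mask is now forwarded, presumably so fine-level resampling stays
    # aligned with the regions regenerated during the coarse pass
    zv = interface.coarse_to_fine(zv, temperature=temperature, mask=mask)

sig = interface.to_signal(zv).cpu()
```
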
diff --git a/conf/generated-v0/berta-goldman-speech/c2f.yml b/conf/generated-v0/berta-goldman-speech/c2f.yml
deleted file mode 100644
index 0f5a4cd57e7a801121d7c77a62a0e8767b7fe61c..0000000000000000000000000000000000000000
--- a/conf/generated-v0/berta-goldman-speech/c2f.yml
+++ /dev/null
@@ -1,15 +0,0 @@
-$include:
-- conf/lora/lora.yml
-AudioDataset.duration: 3.0
-AudioDataset.loudness_cutoff: -40.0
-VampNet.embedding_dim: 1280
-VampNet.n_codebooks: 14
-VampNet.n_conditioning_codebooks: 4
-VampNet.n_heads: 20
-VampNet.n_layers: 16
-fine_tune: true
-save_path: ./runs/berta-goldman-speech/c2f
-train/AudioLoader.sources:
-- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
-val/AudioLoader.sources:
-- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
diff --git a/conf/generated-v0/berta-goldman-speech/coarse.yml b/conf/generated-v0/berta-goldman-speech/coarse.yml
deleted file mode 100644
index 7c1207e9cfe83bac59f76fcf21068405cd6c9551..0000000000000000000000000000000000000000
--- a/conf/generated-v0/berta-goldman-speech/coarse.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-$include:
-- conf/lora/lora.yml
-fine_tune: true
-save_path: ./runs/berta-goldman-speech/coarse
-train/AudioLoader.sources:
-- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
-val/AudioLoader.sources:
-- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
diff --git a/conf/generated-v0/berta-goldman-speech/interface.yml b/conf/generated-v0/berta-goldman-speech/interface.yml
deleted file mode 100644
index d1ba35ec732a0148f3bced5542e27e85575c4d4e..0000000000000000000000000000000000000000
--- a/conf/generated-v0/berta-goldman-speech/interface.yml
+++ /dev/null
@@ -1,5 +0,0 @@
-AudioLoader.sources:
-- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
-Interface.coarse2fine_ckpt: ./runs/berta-goldman-speech/c2f/best/vampnet/weights.pth
-Interface.coarse_ckpt: ./runs/berta-goldman-speech/coarse/best/vampnet/weights.pth
-Interface.codec_ckpt: ./models/spotdl/codec.pth
diff --git a/conf/generated-v0/gamelan-xeno-canto/c2f.yml b/conf/generated-v0/gamelan-xeno-canto/c2f.yml
deleted file mode 100644
index 9e6fec4ddc7dd0a2e02d1be66cc7f6eafa669ed1..0000000000000000000000000000000000000000
--- a/conf/generated-v0/gamelan-xeno-canto/c2f.yml
+++ /dev/null
@@ -1,17 +0,0 @@
-$include:
-- conf/lora/lora.yml
-AudioDataset.duration: 3.0
-AudioDataset.loudness_cutoff: -40.0
-VampNet.embedding_dim: 1280
-VampNet.n_codebooks: 14
-VampNet.n_conditioning_codebooks: 4
-VampNet.n_heads: 20
-VampNet.n_layers: 16
-fine_tune: true
-save_path: ./runs/gamelan-xeno-canto/c2f
-train/AudioLoader.sources:
-- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
-- /media/CHONK/hugo/loras/xeno-canto-2
-val/AudioLoader.sources:
-- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
-- /media/CHONK/hugo/loras/xeno-canto-2
diff --git a/conf/generated-v0/gamelan-xeno-canto/coarse.yml b/conf/generated-v0/gamelan-xeno-canto/coarse.yml
deleted file mode 100644
index 7e8d38e18d714cb08db7ed456939737404533c3e..0000000000000000000000000000000000000000
--- a/conf/generated-v0/gamelan-xeno-canto/coarse.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-$include:
-- conf/lora/lora.yml
-fine_tune: true
-save_path: ./runs/gamelan-xeno-canto/coarse
-train/AudioLoader.sources:
-- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
-- /media/CHONK/hugo/loras/xeno-canto-2
-val/AudioLoader.sources:
-- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
-- /media/CHONK/hugo/loras/xeno-canto-2
diff --git a/conf/generated-v0/gamelan-xeno-canto/interface.yml b/conf/generated-v0/gamelan-xeno-canto/interface.yml
deleted file mode 100644
index e567800477816ac1cc41719744c1ba40562e35b1..0000000000000000000000000000000000000000
--- a/conf/generated-v0/gamelan-xeno-canto/interface.yml
+++ /dev/null
@@ -1,6 +0,0 @@
-AudioLoader.sources:
-- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
-- /media/CHONK/hugo/loras/xeno-canto-2
-Interface.coarse2fine_ckpt: ./runs/gamelan-xeno-canto/c2f/best/vampnet/weights.pth
-Interface.coarse_ckpt: ./runs/gamelan-xeno-canto/coarse/best/vampnet/weights.pth
-Interface.codec_ckpt: ./models/spotdl/codec.pth
diff --git a/conf/generated-v0/nasralla/c2f.yml b/conf/generated-v0/nasralla/c2f.yml
deleted file mode 100644
index 9d9db7bed268c18f3ca4047dcde34dd18a5a2301..0000000000000000000000000000000000000000
--- a/conf/generated-v0/nasralla/c2f.yml
+++ /dev/null
@@ -1,15 +0,0 @@
-$include:
-- conf/lora/lora.yml
-AudioDataset.duration: 3.0
-AudioDataset.loudness_cutoff: -40.0
-VampNet.embedding_dim: 1280
-VampNet.n_codebooks: 14
-VampNet.n_conditioning_codebooks: 4
-VampNet.n_heads: 20
-VampNet.n_layers: 16
-fine_tune: true
-save_path: ./runs/nasralla/c2f
-train/AudioLoader.sources:
-- /media/CHONK/hugo/nasralla
-val/AudioLoader.sources:
-- /media/CHONK/hugo/nasralla
diff --git a/conf/generated-v0/nasralla/coarse.yml b/conf/generated-v0/nasralla/coarse.yml
deleted file mode 100644
index 43a4d18c7f955e38200ded0d2a4fa0959ddb639e..0000000000000000000000000000000000000000
--- a/conf/generated-v0/nasralla/coarse.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-$include:
-- conf/lora/lora.yml
-fine_tune: true
-save_path: ./runs/nasralla/coarse
-train/AudioLoader.sources:
-- /media/CHONK/hugo/nasralla
-val/AudioLoader.sources:
-- /media/CHONK/hugo/nasralla
diff --git a/conf/generated-v0/nasralla/interface.yml b/conf/generated-v0/nasralla/interface.yml
deleted file mode 100644
index c93e872d1e4b66567812755882a996814794ad8f..0000000000000000000000000000000000000000
--- a/conf/generated-v0/nasralla/interface.yml
+++ /dev/null
@@ -1,5 +0,0 @@
-AudioLoader.sources:
-- /media/CHONK/hugo/nasralla
-Interface.coarse2fine_ckpt: ./runs/nasralla/c2f/best/vampnet/weights.pth
-Interface.coarse_ckpt: ./runs/nasralla/coarse/best/vampnet/weights.pth
-Interface.codec_ckpt: ./models/spotdl/codec.pth
diff --git a/conf/generated/breaks-steps/c2f.yml b/conf/generated/breaks-steps/c2f.yml
deleted file mode 100644
index 49617a6d52de00a9bc7c82c6e820168076402fac..0000000000000000000000000000000000000000
--- a/conf/generated/breaks-steps/c2f.yml
+++ /dev/null
@@ -1,15 +0,0 @@
-$include:
-- conf/lora/lora.yml
-AudioDataset.duration: 3.0
-AudioDataset.loudness_cutoff: -40.0
-VampNet.embedding_dim: 1280
-VampNet.n_codebooks: 14
-VampNet.n_conditioning_codebooks: 4
-VampNet.n_heads: 20
-VampNet.n_layers: 16
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/c2f.pth
-save_path: ./runs/breaks-steps/c2f
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/breaks-steps
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/breaks-steps/coarse.yml b/conf/generated/breaks-steps/coarse.yml
deleted file mode 100644
index 71d9b27fbc4aac7d407d3606e98c4eaca35e2d3f..0000000000000000000000000000000000000000
--- a/conf/generated/breaks-steps/coarse.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-$include:
-- conf/lora/lora.yml
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/coarse.pth
-save_path: ./runs/breaks-steps/coarse
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/breaks-steps
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/breaks-steps/interface.yml b/conf/generated/breaks-steps/interface.yml
deleted file mode 100644
index b4b5182c4a378884e1614d89bc39abdf78a4eaa2..0000000000000000000000000000000000000000
--- a/conf/generated/breaks-steps/interface.yml
+++ /dev/null
@@ -1,7 +0,0 @@
-AudioLoader.sources:
-- - /media/CHONK/hugo/breaks-steps
-Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
-Interface.coarse2fine_lora_ckpt: ./runs/breaks-steps/c2f/latest/lora.pth
-Interface.coarse_ckpt: ./models/spotdl/coarse.pth
-Interface.coarse_lora_ckpt: ./runs/breaks-steps/coarse/latest/lora.pth
-Interface.codec_ckpt: ./models/spotdl/codec.pth
diff --git a/conf/generated/bulgarian-tv-choir/c2f.yml b/conf/generated/bulgarian-tv-choir/c2f.yml
deleted file mode 100644
index 7bc54bf54bc8cc5c599a11f30c036822fa4b84c5..0000000000000000000000000000000000000000
--- a/conf/generated/bulgarian-tv-choir/c2f.yml
+++ /dev/null
@@ -1,15 +0,0 @@
-$include:
-- conf/lora/lora.yml
-AudioDataset.duration: 3.0
-AudioDataset.loudness_cutoff: -40.0
-VampNet.embedding_dim: 1280
-VampNet.n_codebooks: 14
-VampNet.n_conditioning_codebooks: 4
-VampNet.n_heads: 20
-VampNet.n_layers: 16
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/c2f.pth
-save_path: ./runs/bulgarian-tv-choir/c2f
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/bulgarian-tv-choir/coarse.yml b/conf/generated/bulgarian-tv-choir/coarse.yml
deleted file mode 100644
index 06f27f140dbd8c6d6315aab0787435ff501f8958..0000000000000000000000000000000000000000
--- a/conf/generated/bulgarian-tv-choir/coarse.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-$include:
-- conf/lora/lora.yml
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/coarse.pth
-save_path: ./runs/bulgarian-tv-choir/coarse
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/bulgarian-tv-choir/interface.yml b/conf/generated/bulgarian-tv-choir/interface.yml
deleted file mode 100644
index b56e8d721adf99da361dadf423a669bb576478e1..0000000000000000000000000000000000000000
--- a/conf/generated/bulgarian-tv-choir/interface.yml
+++ /dev/null
@@ -1,7 +0,0 @@
-AudioLoader.sources:
-- - /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
-Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
-Interface.coarse2fine_lora_ckpt: ./runs/bulgarian-tv-choir/c2f/latest/lora.pth
-Interface.coarse_ckpt: ./models/spotdl/coarse.pth
-Interface.coarse_lora_ckpt: ./runs/bulgarian-tv-choir/coarse/latest/lora.pth
-Interface.codec_ckpt: ./models/spotdl/codec.pth
diff --git a/conf/generated/dariacore/c2f.yml b/conf/generated/dariacore/c2f.yml
deleted file mode 100644
index e8e52fc05be63fe891d3adf0c2115efd5e06ecef..0000000000000000000000000000000000000000
--- a/conf/generated/dariacore/c2f.yml
+++ /dev/null
@@ -1,15 +0,0 @@
-$include:
-- conf/lora/lora.yml
-AudioDataset.duration: 3.0
-AudioDataset.loudness_cutoff: -40.0
-VampNet.embedding_dim: 1280
-VampNet.n_codebooks: 14
-VampNet.n_conditioning_codebooks: 4
-VampNet.n_heads: 20
-VampNet.n_layers: 16
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/c2f.pth
-save_path: ./runs/dariacore/c2f
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/loras/dariacore
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/dariacore/coarse.yml b/conf/generated/dariacore/coarse.yml
deleted file mode 100644
index 42044d7bbafbf890d6d6bc504beb49edf977c39b..0000000000000000000000000000000000000000
--- a/conf/generated/dariacore/coarse.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-$include:
-- conf/lora/lora.yml
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/coarse.pth
-save_path: ./runs/dariacore/coarse
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/loras/dariacore
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/dariacore/interface.yml b/conf/generated/dariacore/interface.yml
deleted file mode 100644
index 29342d2fe9d97f20d9521885869f1cca16d2aeba..0000000000000000000000000000000000000000
--- a/conf/generated/dariacore/interface.yml
+++ /dev/null
@@ -1,7 +0,0 @@
-AudioLoader.sources:
-- - /media/CHONK/hugo/loras/dariacore
-Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
-Interface.coarse2fine_lora_ckpt: ./runs/dariacore/c2f/latest/lora.pth
-Interface.coarse_ckpt: ./models/spotdl/coarse.pth
-Interface.coarse_lora_ckpt: ./runs/dariacore/coarse/latest/lora.pth
-Interface.codec_ckpt: ./models/spotdl/codec.pth
diff --git a/conf/generated/musica-bolero-marimba/c2f.yml b/conf/generated/musica-bolero-marimba/c2f.yml
deleted file mode 100644
index cd06c72814deaf9fd41d3dabc8e6046e050ad968..0000000000000000000000000000000000000000
--- a/conf/generated/musica-bolero-marimba/c2f.yml
+++ /dev/null
@@ -1,18 +0,0 @@
-$include:
-- conf/lora/lora.yml
-AudioDataset.duration: 3.0
-AudioDataset.loudness_cutoff: -40.0
-VampNet.embedding_dim: 1280
-VampNet.n_codebooks: 14
-VampNet.n_conditioning_codebooks: 4
-VampNet.n_heads: 20
-VampNet.n_layers: 16
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/c2f.pth
-save_path: ./runs/musica-bolero-marimba/c2f
-train/AudioLoader.sources:
-- /media/CHONK/hugo/loras/boleros
-- /media/CHONK/hugo/loras/marimba-honduras
-val/AudioLoader.sources:
-- /media/CHONK/hugo/loras/boleros
-- /media/CHONK/hugo/loras/marimba-honduras
diff --git a/conf/generated/musica-bolero-marimba/coarse.yml b/conf/generated/musica-bolero-marimba/coarse.yml
deleted file mode 100644
index a3e1c0ee8e8593528cb389fb84c56894727cfca5..0000000000000000000000000000000000000000
--- a/conf/generated/musica-bolero-marimba/coarse.yml
+++ /dev/null
@@ -1,11 +0,0 @@
-$include:
-- conf/lora/lora.yml
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/coarse.pth
-save_path: ./runs/musica-bolero-marimba/coarse
-train/AudioLoader.sources:
-- /media/CHONK/hugo/loras/boleros
-- /media/CHONK/hugo/loras/marimba-honduras
-val/AudioLoader.sources:
-- /media/CHONK/hugo/loras/boleros
-- /media/CHONK/hugo/loras/marimba-honduras
diff --git a/conf/generated/musica-bolero-marimba/interface.yml b/conf/generated/musica-bolero-marimba/interface.yml
deleted file mode 100644
index 08b42e3120a3cedbb5aafb9a39ca879d8958127a..0000000000000000000000000000000000000000
--- a/conf/generated/musica-bolero-marimba/interface.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-AudioLoader.sources:
-- /media/CHONK/hugo/loras/boleros
-- /media/CHONK/hugo/loras/marimba-honduras
-Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
-Interface.coarse2fine_lora_ckpt: ./runs/musica-bolero-marimba/c2f/latest/lora.pth
-Interface.coarse_ckpt: ./models/spotdl/coarse.pth
-Interface.coarse_lora_ckpt: ./runs/musica-bolero-marimba/coarse/latest/lora.pth
-Interface.codec_ckpt: ./models/spotdl/codec.pth
diff --git a/conf/generated/panchos/c2f.yml b/conf/generated/panchos/c2f.yml
deleted file mode 100644
index 4efd6fb4caf409382929dcf61d40ed37e3773eac..0000000000000000000000000000000000000000
--- a/conf/generated/panchos/c2f.yml
+++ /dev/null
@@ -1,15 +0,0 @@
-$include:
-- conf/lora/lora.yml
-AudioDataset.duration: 3.0
-AudioDataset.loudness_cutoff: -40.0
-VampNet.embedding_dim: 1280
-VampNet.n_codebooks: 14
-VampNet.n_conditioning_codebooks: 4
-VampNet.n_heads: 20
-VampNet.n_layers: 16
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/c2f.pth
-save_path: ./runs/panchos/c2f
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/loras/panchos/
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/panchos/coarse.yml b/conf/generated/panchos/coarse.yml
deleted file mode 100644
index c4f21a3f4deb58cd6b98680e82d59ad32098542e..0000000000000000000000000000000000000000
--- a/conf/generated/panchos/coarse.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-$include:
-- conf/lora/lora.yml
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/coarse.pth
-save_path: ./runs/panchos/coarse
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/loras/panchos/
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/panchos/interface.yml b/conf/generated/panchos/interface.yml
deleted file mode 100644
index 8bae11c225a0fa49c27efdfc808a63d53c21755a..0000000000000000000000000000000000000000
--- a/conf/generated/panchos/interface.yml
+++ /dev/null
@@ -1,7 +0,0 @@
-AudioLoader.sources:
-- - /media/CHONK/hugo/loras/panchos/
-Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
-Interface.coarse2fine_lora_ckpt: ./runs/panchos/c2f/latest/lora.pth
-Interface.coarse_ckpt: ./models/spotdl/coarse.pth
-Interface.coarse_lora_ckpt: ./runs/panchos/coarse/latest/lora.pth
-Interface.codec_ckpt: ./models/spotdl/codec.pth
diff --git a/conf/generated/titi-monkey/c2f.yml b/conf/generated/titi-monkey/c2f.yml
deleted file mode 100644
index 456912ab1589eee1dfe6c5768e70ede4e455c828..0000000000000000000000000000000000000000
--- a/conf/generated/titi-monkey/c2f.yml
+++ /dev/null
@@ -1,15 +0,0 @@
-$include:
-- conf/lora/lora.yml
-AudioDataset.duration: 3.0
-AudioDataset.loudness_cutoff: -40.0
-VampNet.embedding_dim: 1280
-VampNet.n_codebooks: 14
-VampNet.n_conditioning_codebooks: 4
-VampNet.n_heads: 20
-VampNet.n_layers: 16
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/c2f.pth
-save_path: ./runs/titi-monkey/c2f
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/loras/titi-monkey.mp3
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/titi-monkey/coarse.yml b/conf/generated/titi-monkey/coarse.yml
deleted file mode 100644
index c2af934aa5aff33c26ae95a2d7a46eb19f9b7194..0000000000000000000000000000000000000000
--- a/conf/generated/titi-monkey/coarse.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-$include:
-- conf/lora/lora.yml
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/coarse.pth
-save_path: ./runs/titi-monkey/coarse
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/loras/titi-monkey.mp3
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/titi-monkey/interface.yml b/conf/generated/titi-monkey/interface.yml
deleted file mode 100644
index cbc4ffad24c7c3b34e930aff08404955348b49a2..0000000000000000000000000000000000000000
--- a/conf/generated/titi-monkey/interface.yml
+++ /dev/null
@@ -1,7 +0,0 @@
-AudioLoader.sources:
-- - /media/CHONK/hugo/loras/titi-monkey.mp3
-Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
-Interface.coarse2fine_lora_ckpt: ./runs/titi-monkey/c2f/latest/lora.pth
-Interface.coarse_ckpt: ./models/spotdl/coarse.pth
-Interface.coarse_lora_ckpt: ./runs/titi-monkey/coarse/latest/lora.pth
-Interface.codec_ckpt: ./models/spotdl/codec.pth
diff --git a/conf/generated/xeno-canto/c2f.yml b/conf/generated/xeno-canto/c2f.yml
deleted file mode 100644
index 251b0e361ee15d01f7715608480cb3d5e9fdb122..0000000000000000000000000000000000000000
--- a/conf/generated/xeno-canto/c2f.yml
+++ /dev/null
@@ -1,15 +0,0 @@
-$include:
-- conf/lora/lora.yml
-AudioDataset.duration: 3.0
-AudioDataset.loudness_cutoff: -40.0
-VampNet.embedding_dim: 1280
-VampNet.n_codebooks: 14
-VampNet.n_conditioning_codebooks: 4
-VampNet.n_heads: 20
-VampNet.n_layers: 16
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/c2f.pth
-save_path: ./runs/xeno-canto/c2f
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/loras/xeno-canto-2/
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/xeno-canto/coarse.yml b/conf/generated/xeno-canto/coarse.yml
deleted file mode 100644
index ea151dbb64ff13982b0004685901da2b58c8e596..0000000000000000000000000000000000000000
--- a/conf/generated/xeno-canto/coarse.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-$include:
-- conf/lora/lora.yml
-fine_tune: true
-fine_tune_checkpoint: ./models/spotdl/coarse.pth
-save_path: ./runs/xeno-canto/coarse
-train/AudioLoader.sources: &id001
-- /media/CHONK/hugo/loras/xeno-canto-2/
-val/AudioLoader.sources: *id001
diff --git a/conf/generated/xeno-canto/interface.yml b/conf/generated/xeno-canto/interface.yml
deleted file mode 100644
index 1a8b1420f142cef024471073e674cd9db59ffad0..0000000000000000000000000000000000000000
--- a/conf/generated/xeno-canto/interface.yml
+++ /dev/null
@@ -1,7 +0,0 @@
-AudioLoader.sources:
-- - /media/CHONK/hugo/loras/xeno-canto-2/
-Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
-Interface.coarse2fine_lora_ckpt: ./runs/xeno-canto/c2f/latest/lora.pth
-Interface.coarse_ckpt: ./models/spotdl/coarse.pth
-Interface.coarse_lora_ckpt: ./runs/xeno-canto/coarse/latest/lora.pth
-Interface.codec_ckpt: ./models/spotdl/codec.pth
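
These checked-in `conf/generated/<name>/` trios go away because the directory is now gitignored (see the `.gitignore` hunk) and the files are regenerated per run by `scripts/exp/fine_tune.py`. A sketch of how one such file is produced; the run name and source path are hypothetical, and the `&id001`/`*id001` anchors in the deleted files are simply what `yaml.dump` emits when train and val share the same list object:

```python
# Sketch: regenerate a coarse fine-tune config like the ones deleted above.
# "my-dataset" and "/path/to/audio" are hypothetical placeholders.
import yaml
from pathlib import Path

name = "my-dataset"
sources = ["/path/to/audio"]

coarse_conf = {
    "$include": ["conf/lora/lora.yml"],
    "fine_tune": True,
    "fine_tune_checkpoint": "./models/vampnet/coarse.pth",
    "save_path": f"./runs/{name}/coarse",
    "train/AudioLoader.sources": sources,
    "val/AudioLoader.sources": sources,  # same object -> &id001/*id001 anchors
}

out_dir = Path("conf/generated") / name
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / "coarse.yml", "w") as f:
    yaml.dump(coarse_conf, f)
```
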
diff --git a/conf/lora/birds.yml b/conf/lora/birds.yml
deleted file mode 100644
index de413ec0dec4f974e664923c9319861a1c957e87..0000000000000000000000000000000000000000
--- a/conf/lora/birds.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-$include:
-  - conf/lora/lora.yml
-
-fine_tune: True
-
-train/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/birds
-
-val/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/birds
diff --git a/conf/lora/birdss.yml b/conf/lora/birdss.yml
deleted file mode 100644
index 3526de67d24e296de2cc0a7d2e5ebbc18245a6c8..0000000000000000000000000000000000000000
--- a/conf/lora/birdss.yml
+++ /dev/null
@@ -1,12 +0,0 @@
-$include:
-  - conf/lora/lora.yml
-
-fine_tune: True
-
-train/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/birds
-  - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/
-
-val/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/birds
-  - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/
diff --git a/conf/lora/constructions.yml b/conf/lora/constructions.yml
deleted file mode 100644
index f513b4898e06339fa0d0b4af24e98fdf5289094a..0000000000000000000000000000000000000000
--- a/conf/lora/constructions.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-$include:
-  - conf/lora/lora.yml
-
-fine_tune: True
-
-train/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3
-
-val/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3
diff --git a/conf/lora/ella-baila-sola.yml b/conf/lora/ella-baila-sola.yml
deleted file mode 100644
index 24eeada8013ea0d56d7d6474db52a48c3fd43bc1..0000000000000000000000000000000000000000
--- a/conf/lora/ella-baila-sola.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-$include:
-  - conf/lora/lora.yml
-
-fine_tune: True
-
-train/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3
-
-val/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3
diff --git a/conf/lora/gas-station.yml b/conf/lora/gas-station.yml
deleted file mode 100644
index 4369f9203232fa3dcfd21667f3e55d0d0fda108e..0000000000000000000000000000000000000000
--- a/conf/lora/gas-station.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-$include:
-  - conf/lora/lora.yml
-
-fine_tune: True
-
-train/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
-
-val/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
diff --git a/conf/lora/lora-is-this-charlie-parker.yml b/conf/lora/lora-is-this-charlie-parker.yml
deleted file mode 100644
index 9cfaa31a421266fafa60a1ee4bb2d45f1c47577c..0000000000000000000000000000000000000000
--- a/conf/lora/lora-is-this-charlie-parker.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-$include:
-  - conf/lora/lora.yml
-
-fine_tune: True
-
-train/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3
-
-val/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3
diff --git a/conf/lora/lora.yml b/conf/lora/lora.yml
index b901ea00a6008b92f25728d6d01a258c6aba5d1e..c6abe7e0bddac557ea3885309f3425877541cfe9 100644
--- a/conf/lora/lora.yml
+++ b/conf/lora/lora.yml
@@ -3,20 +3,18 @@ $include:
 
 fine_tune: True
 
-train/AudioDataset.n_examples: 10000000
-
-val/AudioDataset.n_examples: 10
+train/AudioDataset.n_examples: 100000000
+val/AudioDataset.n_examples: 100
 
 
 NoamScheduler.warmup: 500
 
 batch_size: 7
 num_workers: 7
-epoch_length: 100
-save_audio_epochs: 10
+save_iters: [100000, 200000, 300000, 400000, 500000]
 
 AdamW.lr: 0.0001
 
 # lets us organize sound classes into folders and choose from those sound classes uniformly
 AudioDataset.without_replacement: False
-max_epochs: 500
\ No newline at end of file
+num_iters: 500000
\ No newline at end of file
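
This config swaps epoch-based bookkeeping (`epoch_length`, `save_audio_epochs`, `max_epochs`) for iteration-based `num_iters` and `save_iters`. Per the `checkpoint()` function in the `scripts/exp/train.py` hunk further down, a step listed in `save_iters` earns an extra checkpoint tag named in thousands of iterations:

```python
# Sketch of the tag logic in checkpoint() (scripts/exp/train.py below).
save_iters = [100000, 200000, 300000, 400000, 500000]

def tags_for(step: int, is_best: bool = False) -> list:
    tags = ["latest"]
    if step in save_iters:
        tags.append(f"{step // 1000}k")  # e.g. 300000 -> "300k"
    if is_best:
        tags.append("best")
    return tags

print(tags_for(300000))  # ['latest', '300k']
```
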
diff --git a/conf/lora/underworld.yml b/conf/lora/underworld.yml
deleted file mode 100644
index 6fd1a6cf1e74220a2b51b1117afb373acda033a7..0000000000000000000000000000000000000000
--- a/conf/lora/underworld.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-$include:
-  - conf/lora/lora.yml
-
-fine_tune: True
-
-train/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/underworld.mp3
-
-val/AudioLoader.sources:
-  - /media/CHONK/hugo/spotdl/subsets/underworld.mp3
diff --git a/conf/lora/xeno-canto/c2f.yml b/conf/lora/xeno-canto/c2f.yml
deleted file mode 100644
index 94f9906189f0b74b6c492bdd53fa56d58a0fa04d..0000000000000000000000000000000000000000
--- a/conf/lora/xeno-canto/c2f.yml
+++ /dev/null
@@ -1,21 +0,0 @@
-$include:
-  - conf/lora/lora.yml
-
-fine_tune: True
-
-train/AudioLoader.sources:
-  - /media/CHONK/hugo/xeno-canto-2
-
-val/AudioLoader.sources:
-  - /media/CHONK/hugo/xeno-canto-2
-
-
-VampNet.n_codebooks: 14
-VampNet.n_conditioning_codebooks: 4
-
-VampNet.embedding_dim: 1280
-VampNet.n_layers: 16
-VampNet.n_heads: 20
-
-AudioDataset.duration: 3.0
-AudioDataset.loudness_cutoff: -40.0
diff --git a/conf/lora/xeno-canto/coarse.yml b/conf/lora/xeno-canto/coarse.yml
deleted file mode 100644
index 223c8f0f8481f55ac1c33816ed79fe45b50f1495..0000000000000000000000000000000000000000
--- a/conf/lora/xeno-canto/coarse.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-$include:
-  - conf/lora/lora.yml
-
-fine_tune: True
-
-train/AudioLoader.sources:
-  - /media/CHONK/hugo/xeno-canto-2
-
-val/AudioLoader.sources:
-  - /media/CHONK/hugo/xeno-canto-2
diff --git a/conf/vampnet-musdb-drums.yml b/conf/vampnet-musdb-drums.yml
deleted file mode 100644
index 010843d81ec9ac3c832b8e88f30af2f99a56ba99..0000000000000000000000000000000000000000
--- a/conf/vampnet-musdb-drums.yml
+++ /dev/null
@@ -1,22 +0,0 @@
-$include:
-  - conf/vampnet.yml
-
-VampNet.embedding_dim: 512
-VampNet.n_layers: 12
-VampNet.n_heads: 8
-
-AudioDataset.duration: 12.0
-
-train/AudioDataset.n_examples: 10000000
-train/AudioLoader.sources:
-  - /data/musdb18hq/train/**/*drums.wav
-
-
-val/AudioDataset.n_examples: 500
-val/AudioLoader.sources:
-  - /data/musdb18hq/test/**/*drums.wav
-
-
-test/AudioDataset.n_examples: 1000
-test/AudioLoader.sources:
-  - /data/musdb18hq/test/**/*drums.wav
diff --git a/conf/vampnet.yml b/conf/vampnet.yml
index d24df3fc1923eeb98f76f5747a52c3e83ef98795..6157577f3435c86e302bee1aa62f48d855128490 100644
--- a/conf/vampnet.yml
+++ b/conf/vampnet.yml
@@ -1,21 +1,17 @@
 
-codec_ckpt: ./models/spotdl/codec.pth
+codec_ckpt: ./models/vampnet/codec.pth
 save_path: ckpt
-max_epochs: 1000
-epoch_length: 1000
-save_audio_epochs: 2
-val_idx: [0,1,2,3,4,5,6,7,8,9]
 
-prefix_amt: 0.0
-suffix_amt: 0.0
-prefix_dropout: 0.1
-suffix_dropout: 0.1
+num_iters: 1000000000
+save_iters: [10000, 50000, 100000, 300000, 500000]
+val_idx: [0,1,2,3,4,5,6,7,8,9]
+sample_freq: 10000
+val_freq: 1000
 
 batch_size: 8
 num_workers: 10
 
 # Optimization
-detect_anomaly: false
 amp: false
 
 CrossEntropyLoss.label_smoothing: 0.1
@@ -25,9 +21,6 @@ AdamW.lr: 0.001
 NoamScheduler.factor: 2.0
 NoamScheduler.warmup: 10000
 
-PitchShift.shift_amount: [const, 0]
-PitchShift.prob: 0.0
-
 VampNet.vocab_size: 1024
 VampNet.n_codebooks: 4
 VampNet.n_conditioning_codebooks: 0
@@ -48,12 +41,9 @@ AudioDataset.duration: 10.0
 
 train/AudioDataset.n_examples: 10000000
 train/AudioLoader.sources:
-  - /data/spotdl/audio/train
+  - /media/CHONK/hugo/spotdl/audio-train
 
 val/AudioDataset.n_examples: 2000
 val/AudioLoader.sources:
-  - /data/spotdl/audio/val
+  - /media/CHONK/hugo/spotdl/audio-val
 
-test/AudioDataset.n_examples: 1000
-test/AudioLoader.sources:
-  - /data/spotdl/audio/test
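
The schedule keys introduced here are consumed by the refactored loop in `scripts/exp/train.py` below: `sample_freq` visibly gates `save_samples`, and `val_freq` presumably gates validation the same way (its call site falls outside the visible hunks). A quick sketch of the cadence these values imply:

```python
# Cadence implied by the new conf/vampnet.yml keys above.
sample_freq = 10000  # write audio samples to TensorBoard every 10k iters
val_freq = 1000      # run validation every 1k iters (assumed gating)

step = 50000
print(step % sample_freq == 0)  # True: save_samples fires (see train loop below)
print(step % val_freq == 0)     # True: validation fires
```
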
diff --git a/scripts/exp/fine_tune.py b/scripts/exp/fine_tune.py
index e2c6c3b768f585242705e5cdabeebe45ced557cf..d3145378c574ee293c96e4973ec6f33ee3cb8713 100644
--- a/scripts/exp/fine_tune.py
+++ b/scripts/exp/fine_tune.py
@@ -35,7 +35,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
         "AudioDataset.duration": 3.0,
         "AudioDataset.loudness_cutoff": -40.0,
         "save_path": f"./runs/{name}/c2f",
-        "fine_tune_checkpoint": "./models/spotdl/c2f.pth"
+        "fine_tune_checkpoint": "./models/vampnet/c2f.pth"
     }
 
     finetune_coarse_conf = {
@@ -44,17 +44,17 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
         "train/AudioLoader.sources": audio_files_or_folders,
         "val/AudioLoader.sources": audio_files_or_folders,
         "save_path": f"./runs/{name}/coarse",
-        "fine_tune_checkpoint": "./models/spotdl/coarse.pth"
+        "fine_tune_checkpoint": "./models/vampnet/coarse.pth"
     }
 
     interface_conf = {
-        "Interface.coarse_ckpt": f"./models/spotdl/coarse.pth",
+        "Interface.coarse_ckpt": f"./models/vampnet/coarse.pth",
         "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
 
-        "Interface.coarse2fine_ckpt": f"./models/spotdl/c2f.pth",
+        "Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
         "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
 
-        "Interface.codec_ckpt": "./models/spotdl/codec.pth",
+        "Interface.codec_ckpt": "./models/vampnet/codec.pth",
         "AudioLoader.sources": [audio_files_or_folders],
     }
 
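
With `interface_conf` now pointing at `./models/vampnet/*`, a hedged sketch of how those keys map onto an `Interface`: argbind binds `Interface.<param>` keys to constructor arguments of the same name, so the keyword names below are taken verbatim from the config, but the direct construction (and the run name) is an assumption for illustration:

```python
# Hedged sketch: building an Interface from the interface_conf keys above.
# "my-dataset" is a hypothetical run name; in practice these values are
# bound by argbind from conf/generated/<name>/interface.yml.
from vampnet.interface import Interface

interface = Interface(
    coarse_ckpt="./models/vampnet/coarse.pth",
    coarse_lora_ckpt="./runs/my-dataset/coarse/latest/lora.pth",
    coarse2fine_ckpt="./models/vampnet/c2f.pth",
    coarse2fine_lora_ckpt="./runs/my-dataset/c2f/latest/lora.pth",
    codec_ckpt="./models/vampnet/codec.pth",
)
```
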
diff --git a/scripts/exp/train.py b/scripts/exp/train.py
index 79251a529c9512b7bf8c2613e6ae173df21c5c61..68ddd88221710645c2e500203aae64dfd8d09257 100644
--- a/scripts/exp/train.py
+++ b/scripts/exp/train.py
@@ -1,9 +1,9 @@
 import os
-import subprocess
-import time
+import sys
 import warnings
 from pathlib import Path
 from typing import Optional
+from dataclasses import dataclass
 
 import argbind
 import audiotools as at
@@ -23,6 +23,12 @@ from vampnet import mask as pmask
 # from dac.model.dac import DAC
 from lac.model.lac import LAC as DAC
 
+from audiotools.ml.decorators import (
+    timer, Tracker, when
+)
+
+import loralib as lora
+
 
 # Enable cudnn autotuner to speed up training
 # (can be altered by the funcs.seed function)
@@ -85,11 +91,7 @@ def build_datasets(args, sample_rate: int):
         )
     with argbind.scope(args, "val"):
         val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
-    with argbind.scope(args, "test"):
-        test_data = AudioDataset(
-            AudioLoader(), sample_rate, transform=build_transform()
-        )
-    return train_data, val_data, test_data
+    return train_data, val_data
 
 
 def rand_float(shape, low, high, rng):
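
`build_datasets` drops the unused test split and leans on argbind scoping, which is what lets the `train/` and `val/` prefixed keys in the conf files above resolve to different `AudioLoader` arguments. An illustrative sketch of that mechanism, simplified; real code obtains `args` from `argbind.parse_args()`:

```python
# Illustrative sketch of argbind scoping as used by build_datasets above.
import argbind

@argbind.bind("train", "val")
def AudioLoader(sources: list = None):
    return sources

args = {
    "train/AudioLoader.sources": ["/media/CHONK/hugo/spotdl/audio-train"],
    "val/AudioLoader.sources": ["/media/CHONK/hugo/spotdl/audio-val"],
}
with argbind.scope(args, "train"):
    print(AudioLoader())  # ['/media/CHONK/hugo/spotdl/audio-train']
with argbind.scope(args, "val"):
    print(AudioLoader())  # ['/media/CHONK/hugo/spotdl/audio-val']
```
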
@@ -100,16 +102,393 @@ def flip_coin(shape, p, rng):
     return rng.draw(shape)[:, 0] < p
 
 
+def num_params_hook(o, p):
+    return o + f" {p/1e6:<.3f}M params."
+
+
+def add_num_params_repr_hook(model):
+    import numpy as np
+    from functools import partial
+
+    for n, m in model.named_modules():
+        o = m.extra_repr()
+        p = sum([np.prod(p.size()) for p in m.parameters()])
+
+        setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
+
+
+def accuracy(
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    top_k: int = 1,
+    ignore_index: Optional[int] = None,
+) -> torch.Tensor:
+    # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
+    preds = rearrange(preds, "b p s -> (b s) p")
+    target = rearrange(target, "b s -> (b s)")
+
+    # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
+    if ignore_index is not None:
+        # Create a mask for the ignored index
+        mask = target != ignore_index
+        # Apply the mask to the target and predictions
+        preds = preds[mask]
+        target = target[mask]
+
+    # Get the top-k predicted classes and their indices
+    _, pred_indices = torch.topk(preds, k=top_k, dim=-1)
+
+    # Determine if the true target is in the top-k predicted classes
+    correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
+
+    # Calculate the accuracy
+    accuracy = torch.mean(correct.float())
+
+    return accuracy
+
+def _metrics(z_hat, r, target, flat_mask, output):
+    for r_range in [(0, 0.5), (0.5, 1.0)]:
+        unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
+        masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
+
+        assert target.shape[0] == r.shape[0]
+        # grab the indices of the r values that are in the range
+        r_idx = (r >= r_range[0]) & (r < r_range[1])
+
+        # grab the target and z_hat values that are in the range
+        r_unmasked_target = unmasked_target[r_idx]
+        r_masked_target = masked_target[r_idx]
+        r_z_hat = z_hat[r_idx]
+
+        for topk in (1, 25):
+            s, e = r_range
+            tag = f"accuracy-{s}-{e}/top{topk}"
+
+            output[f"{tag}/unmasked"] = accuracy(
+                preds=r_z_hat,
+                target=r_unmasked_target,
+                ignore_index=IGNORE_INDEX,
+                top_k=topk,
+            )
+            output[f"{tag}/masked"] = accuracy(
+                preds=r_z_hat,
+                target=r_masked_target,
+                ignore_index=IGNORE_INDEX,
+                top_k=topk,
+            )
+
+
+@dataclass
+class State:
+    model: VampNet
+    codec: DAC
+
+    optimizer: AdamW
+    scheduler: NoamScheduler
+    criterion: CrossEntropyLoss
+    grad_clip_val: float
+
+    rng: torch.quasirandom.SobolEngine
+
+    train_data: AudioDataset
+    val_data: AudioDataset
+
+    tracker: Tracker
+
+
+@timer()
+def train_loop(state: State, batch: dict, accel: Accelerator):
+    state.model.train()
+    batch = at.util.prepare_batch(batch, accel.device)
+    signal = apply_transform(state.train_data.transform, batch)
+
+    output = {}
+    vn = accel.unwrap(state.model)
+    with accel.autocast():
+        with torch.inference_mode():
+            state.codec.to(accel.device)
+            z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
+            z = z[:, : vn.n_codebooks, :]
+
+        n_batch = z.shape[0]
+        r = state.rng.draw(n_batch)[:, 0].to(accel.device)
+
+        mask = pmask.random(z, r)
+        mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
+        z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
+        
+        z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
+
+        dtype = torch.bfloat16 if accel.amp else None
+        with accel.autocast(dtype=dtype):
+            z_hat = state.model(z_mask_latent, r)
+
+        target = codebook_flatten(
+            z[:, vn.n_conditioning_codebooks :, :],
+        )
+
+        flat_mask = codebook_flatten(
+            mask[:, vn.n_conditioning_codebooks :, :],
+        )
+
+        # replace target with ignore index for masked tokens
+        t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
+        output["loss"] = state.criterion(z_hat, t_masked)
+
+        _metrics(
+            r=r,
+            z_hat=z_hat,
+            target=target,
+            flat_mask=flat_mask,
+            output=output,
+        )
+
+    
+    accel.backward(output["loss"])
+
+    output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
+    output["other/batch_size"] = z.shape[0]
+
+
+    accel.scaler.unscale_(state.optimizer)
+    output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
+        state.model.parameters(), state.grad_clip_val
+    )
+
+    accel.step(state.optimizer)
+    state.optimizer.zero_grad()
+
+    state.scheduler.step()
+    accel.update()
+
+
+    return {k: v for k, v in sorted(output.items())}
+
+
+@timer()
+@torch.no_grad()
+def val_loop(state: State, batch: dict, accel: Accelerator):
+    state.model.eval()
+    state.codec.eval()
+    batch = at.util.prepare_batch(batch, accel.device)
+    signal = apply_transform(state.val_data.transform, batch)
+
+    vn = accel.unwrap(state.model)
+    z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
+    z = z[:, : vn.n_codebooks, :]
+
+    n_batch = z.shape[0]
+    r = state.rng.draw(n_batch)[:, 0].to(accel.device)
+
+    mask = pmask.random(z, r)
+    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
+    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
+
+    z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
+
+    z_hat = state.model(z_mask_latent, r)
+
+    target = codebook_flatten(
+        z[:, vn.n_conditioning_codebooks :, :],
+    )
+
+    flat_mask = codebook_flatten(
+        mask[:, vn.n_conditioning_codebooks :, :]
+    )
+
+    output = {}
+    # replace target with ignore index for masked tokens
+    t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
+    output["loss"] = state.criterion(z_hat, t_masked)
+
+    _metrics(
+        r=r,
+        z_hat=z_hat,
+        target=target,
+        flat_mask=flat_mask,
+        output=output,
+    )
+
+    return output
+
+
+def validate(state, val_dataloader, accel):
+    for batch in val_dataloader:
+        output = val_loop(state, batch, accel)
+    # Consolidate state dicts if using ZeroRedundancyOptimizer
+    if hasattr(state.optimizer, "consolidate_state_dict"):
+        state.optimizer.consolidate_state_dict()
+    return output
+
+
+def checkpoint(state, save_iters, save_path, fine_tune):
+    if accel.local_rank != 0:
+        state.tracker.print(f"ERROR: Skipping checkpoint on rank {accel.local_rank}")
+        return
+
+    metadata = {"logs": dict(state.tracker.history)}
+
+    tags = ["latest"]
+    state.tracker.print(f"Saving to {str(Path('.').absolute())}")
+
+    if state.tracker.step in save_iters:
+        tags.append(f"{state.tracker.step // 1000}k")
+
+    if state.tracker.is_best("val", "loss"):
+        state.tracker.print(f"Best model so far")
+        tags.append("best")
+
+    if fine_tune:
+        for tag in tags: 
+            # save the lora model 
+            (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
+            torch.save(
+                lora.lora_state_dict(accel.unwrap(state.model)), 
+                f"{save_path}/{tag}/lora.pth"
+            )
+
+    for tag in tags:
+        model_extra = {
+            "optimizer.pth": state.optimizer.state_dict(),
+            "scheduler.pth": state.scheduler.state_dict(),
+            "tracker.pth": state.tracker.state_dict(),
+            "metadata.pth": metadata,
+        }
+
+        accel.unwrap(state.model).metadata = metadata
+        accel.unwrap(state.model).save_to_folder(
+            f"{save_path}/{tag}", model_extra, package=False
+        )
+
+
+def save_sampled(state, z, writer):
+    num_samples = z.shape[0]
+
+    for i in range(num_samples):
+        sampled = accel.unwrap(state.model).generate(
+            codec=state.codec,
+            time_steps=z.shape[-1],
+            start_tokens=z[i : i + 1],
+        )
+        sampled.cpu().write_audio_to_tb(
+            f"sampled/{i}",
+            writer,
+            step=state.tracker.step,
+            plot_fn=None,
+        )
+
+
+def save_imputation(state, z, val_idx, writer):
+    n_prefix = int(z.shape[-1] * 0.25)
+    n_suffix = int(z.shape[-1] *  0.25)
+
+    vn = accel.unwrap(state.model)
+
+    mask = pmask.inpaint(z, n_prefix, n_suffix)
+    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
+    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
+
+    imputed_noisy = vn.to_signal(z_mask, state.codec)
+    imputed_true = vn.to_signal(z, state.codec)
+
+    imputed = []
+    for i in range(len(z)):
+        imputed.append(
+            vn.generate(
+                codec=state.codec,
+                time_steps=z.shape[-1],
+                start_tokens=z[i][None, ...],
+                mask=mask[i][None, ...],
+            )   
+        )   
+    imputed = AudioSignal.batch(imputed)
+
+    for i in range(len(val_idx)):
+        imputed_noisy[i].cpu().write_audio_to_tb(
+            f"imputed_noisy/{i}",
+            writer,
+            step=state.tracker.step,
+            plot_fn=None,
+        )
+        imputed[i].cpu().write_audio_to_tb(
+            f"imputed/{i}",
+            writer,
+            step=state.tracker.step,
+            plot_fn=None,
+        )
+        imputed_true[i].cpu().write_audio_to_tb(
+            f"imputed_true/{i}",
+            writer,
+            step=state.tracker.step,
+            plot_fn=None,
+        )
+
+
+@torch.no_grad()
+def save_samples(state: State, val_idx: int, writer: SummaryWriter):
+    state.model.eval()
+    state.codec.eval()
+    vn = accel.unwrap(state.model)
+
+    batch = [state.val_data[i] for i in val_idx]
+    batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)
+
+    signal = apply_transform(state.val_data.transform, batch)
+
+    z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
+    z = z[:, : vn.n_codebooks, :]
+
+    r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
+
+
+    mask = pmask.random(z, r)
+    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
+    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
+
+    z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
+
+    z_hat = state.model(z_mask_latent, r)
+
+    z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
+    z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
+    z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
+
+    generated = vn.to_signal(z_pred, state.codec)
+    reconstructed = vn.to_signal(z, state.codec)
+    masked = vn.to_signal(z_mask.squeeze(1), state.codec)
+
+    for i in range(generated.batch_size):
+        audio_dict = {
+            "original": signal[i],
+            "masked": masked[i],
+            "generated": generated[i],
+            "reconstructed": reconstructed[i],
+        }
+        for k, v in audio_dict.items():
+            v.cpu().write_audio_to_tb(
+                f"samples/_{i}.r={r[i]:0.2f}/{k}",
+                writer,
+                step=state.tracker.step,
+                plot_fn=None,
+            )
+
+    save_sampled(state=state, z=z, writer=writer)
+    save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
+
+
+
 @argbind.bind(without_prefix=True)
 def load(
     args,
     accel: at.ml.Accelerator,
+    tracker: Tracker,
     save_path: str,
     resume: bool = False,
     tag: str = "latest",
     load_weights: bool = False,
     fine_tune_checkpoint: Optional[str] = None,
-):
+    grad_clip_val: float = 5.0,
+) -> State:
     codec = DAC.load(args["codec_ckpt"], map_location="cpu")
     codec.eval()
 
@@ -121,6 +500,7 @@ def load(
             "map_location": "cpu",
             "package": not load_weights,
         }
+        tracker.print(f"Loading checkpoint from {kwargs['folder']}")
         if (Path(kwargs["folder"]) / "vampnet").exists():
             model, v_extra = VampNet.load_from_folder(**kwargs)
         else:
@@ -147,89 +527,57 @@ def load(
     scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
     scheduler.step()
 
-    trainer_state = {"state_dict": None, "start_idx": 0}
-
     if "optimizer.pth" in v_extra:
         optimizer.load_state_dict(v_extra["optimizer.pth"])
-    if "scheduler.pth" in v_extra:
         scheduler.load_state_dict(v_extra["scheduler.pth"])
-    if "trainer.pth" in v_extra:
-        trainer_state = v_extra["trainer.pth"]
-
-    return {
-        "model": model,
-        "codec": codec,
-        "optimizer": optimizer,
-        "scheduler": scheduler,
-        "trainer_state": trainer_state,
-    }
-
-
-
-def num_params_hook(o, p):
-    return o + f" {p/1e6:<.3f}M params."
-
-
-def add_num_params_repr_hook(model):
-    import numpy as np
-    from functools import partial
-
-    for n, m in model.named_modules():
-        o = m.extra_repr()
-        p = sum([np.prod(p.size()) for p in m.parameters()])
-
-        setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
-
-
-def accuracy(
-    preds: torch.Tensor,
-    target: torch.Tensor,
-    top_k: int = 1,
-    ignore_index: Optional[int] = None,
-) -> torch.Tensor:
-    # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
-    preds = rearrange(preds, "b p s -> (b s) p")
-    target = rearrange(target, "b s -> (b s)")
-
-    # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
-    if ignore_index is not None:
-        # Create a mask for the ignored index
-        mask = target != ignore_index
-        # Apply the mask to the target and predictions
-        preds = preds[mask]
-        target = target[mask]
+    if "tracker.pth" in v_extra:
+        tracker.load_state_dict(v_extra["tracker.pth"])
+    
+    criterion = CrossEntropyLoss()
 
-    # Get the top-k predicted classes and their indices
-    _, pred_indices = torch.topk(preds, k=top_k, dim=-1)
+    sample_rate = codec.sample_rate
 
-    # Determine if the true target is in the top-k predicted classes
-    correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
+    # a better rng for sampling from our schedule
+    rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])  
 
-    # Calculate the accuracy
-    accuracy = torch.mean(correct.float())
+    # log a model summary w/ num params
+    if accel.local_rank == 0:
+        add_num_params_repr_hook(accel.unwrap(model))
+        with open(f"{save_path}/model.txt", "w") as f:
+            f.write(repr(accel.unwrap(model)))
 
-    return accuracy
+    # load the datasets
+    train_data, val_data = build_datasets(args, sample_rate)
+
+    return State(
+        tracker=tracker,
+        model=model,
+        codec=codec,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        criterion=criterion,
+        rng=rng,
+        train_data=train_data,
+        val_data=val_data,
+        grad_clip_val=grad_clip_val,
+    )
 
 
 @argbind.bind(without_prefix=True)
 def train(
     args,
     accel: at.ml.Accelerator,
-    codec_ckpt: str = None,
     seed: int = 0,
+    codec_ckpt: str = None,
     save_path: str = "ckpt",
-    max_epochs: int = int(100e3),
-    epoch_length: int = 1000,
-    save_audio_epochs: int = 2,
-    save_epochs: list = [10, 50, 100, 200, 300, 400,],
-    batch_size: int = 48,
-    grad_acc_steps: int = 1,
+    num_iters: int = int(1000e6),
+    save_iters: list = [10000, 50000, 100000, 300000, 500000,],
+    sample_freq: int = 10000, 
+    val_freq: int = 1000,
+    batch_size: int = 12,
     val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
     num_workers: int = 10,
-    detect_anomaly: bool = False,
-    grad_clip_val: float = 5.0,
     fine_tune: bool = False, 
-    quiet: bool = False,
 ):
     assert codec_ckpt is not None, "codec_ckpt is required"
 
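
Since the flatten-then-top-k flow in the relocated `accuracy()` helper above is easy to misread, a small worked example; it assumes the helper is in scope:

```python
# Worked example for accuracy() above: batch=1, 3 classes, 4 timesteps.
import torch

preds = torch.tensor([[[2.0, 0.1, 0.1, 0.1],
                       [0.1, 2.0, 0.1, 0.1],
                       [0.1, 0.1, 2.0, 2.0]]])  # (b, p, s) = (1, 3, 4)
target = torch.tensor([[0, 1, 2, 0]])           # (b, s)

# per-step top-1 predictions are [0, 1, 2, 2], so 3 of 4 match the target:
print(accuracy(preds, target, top_k=1))  # tensor(0.7500)
```
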
@@ -241,376 +589,76 @@ def train(
         writer = SummaryWriter(log_dir=f"{save_path}/logs/")
         argbind.dump_args(args, f"{save_path}/args.yml")
 
-    # load the codec model
-    loaded = load(args, accel, save_path)
-    model = loaded["model"]
-    codec = loaded["codec"]
-    optimizer = loaded["optimizer"]
-    scheduler = loaded["scheduler"]
-    trainer_state = loaded["trainer_state"]
-
-    sample_rate = codec.sample_rate
+    else:
+        writer = None
+
+    tracker = Tracker(
+        writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
+    )
 
-    # a better rng for sampling from our schedule
-    rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=seed)  
+    # load the model, codec, datasets, and optimizer/tracker state into a State
+    state: State = load(
+        args=args, 
+        accel=accel, 
+        tracker=tracker, 
+        save_path=save_path)
 
-    # log a model summary w/ num params
-    if accel.local_rank == 0:
-        add_num_params_repr_hook(accel.unwrap(model))
-        with open(f"{save_path}/model.txt", "w") as f:
-            f.write(repr(accel.unwrap(model)))
 
-    # load the datasets
-    train_data, val_data, _ = build_datasets(args, sample_rate)
     train_dataloader = accel.prepare_dataloader(
-        train_data,
-        start_idx=trainer_state["start_idx"],
+        state.train_data,
+        start_idx=state.tracker.step * batch_size,
         num_workers=num_workers,
         batch_size=batch_size,
-        collate_fn=train_data.collate,
+        collate_fn=state.train_data.collate,
     )
     val_dataloader = accel.prepare_dataloader(
-        val_data,
+        state.val_data,
         start_idx=0,
         num_workers=num_workers,
         batch_size=batch_size,
-        collate_fn=val_data.collate,
+        collate_fn=state.val_data.collate,
+        persistent_workers=True,
     )
 
-    criterion = CrossEntropyLoss()
+    
 
     if fine_tune:
-        import loralib as lora
-        lora.mark_only_lora_as_trainable(model)
-
-
-    class Trainer(at.ml.BaseTrainer):
-        _last_grad_norm = 0.0
-
-        def _metrics(self, vn, z_hat, r, target, flat_mask, output):
-            for r_range in [(0, 0.5), (0.5, 1.0)]:
-                unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
-                masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
-
-                assert target.shape[0] == r.shape[0]
-                # grab the indices of the r values that are in the range
-                r_idx = (r >= r_range[0]) & (r < r_range[1])
-
-                # grab the target and z_hat values that are in the range
-                r_unmasked_target = unmasked_target[r_idx]
-                r_masked_target = masked_target[r_idx]
-                r_z_hat = z_hat[r_idx]
-
-                for topk in (1, 25):
-                    s, e = r_range
-                    tag = f"accuracy-{s}-{e}/top{topk}"
-
-                    output[f"{tag}/unmasked"] = accuracy(
-                        preds=r_z_hat,
-                        target=r_unmasked_target,
-                        ignore_index=IGNORE_INDEX,
-                        top_k=topk,
-                    )
-                    output[f"{tag}/masked"] = accuracy(
-                        preds=r_z_hat,
-                        target=r_masked_target,
-                        ignore_index=IGNORE_INDEX,
-                        top_k=topk,
-                    )
-
-        def train_loop(self, engine, batch):
-            model.train()
-            batch = at.util.prepare_batch(batch, accel.device)
-            signal = apply_transform(train_data.transform, batch)
-
-            output = {}
-            vn = accel.unwrap(model)
-            with accel.autocast():
-                with torch.inference_mode():
-                    codec.to(accel.device)
-                    z = codec.encode(signal.samples, signal.sample_rate)["codes"]
-                    z = z[:, : vn.n_codebooks, :]
-
-                n_batch = z.shape[0]
-                r = rng.draw(n_batch)[:, 0].to(accel.device)
-
-                mask = pmask.random(z, r)
-                mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
-                z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
-                
-                z_mask_latent = vn.embedding.from_codes(z_mask, codec)
-
-                dtype = torch.bfloat16 if accel.amp else None
-                with accel.autocast(dtype=dtype):
-                    z_hat = model(z_mask_latent, r)
-
-                target = codebook_flatten(
-                    z[:, vn.n_conditioning_codebooks :, :],
-                )
-
-                flat_mask = codebook_flatten(
-                    mask[:, vn.n_conditioning_codebooks :, :],
-                )
-
-                # replace target with ignore index for masked tokens
-                t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
-                output["loss"] = criterion(z_hat, t_masked)
-
-                self._metrics(
-                    vn=vn,
-                    r=r,
-                    z_hat=z_hat,
-                    target=target,
-                    flat_mask=flat_mask,
-                    output=output,
-                )
-
-            
-            accel.backward(output["loss"] / grad_acc_steps)
-
-            output["other/learning_rate"] = optimizer.param_groups[0]["lr"]
-            output["other/batch_size"] = z.shape[0]
-
-            if (
-                (engine.state.iteration % grad_acc_steps == 0)
-                or (engine.state.iteration % epoch_length == 0)
-                or (engine.state.iteration % epoch_length == 1)
-            ):  # (or we reached the end of the epoch)
-                accel.scaler.unscale_(optimizer)
-                output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
-                    model.parameters(), grad_clip_val
-                )
-                self._last_grad_norm = output["other/grad_norm"]
-
-                accel.step(optimizer)
-                optimizer.zero_grad()
-
-                scheduler.step()
-                accel.update()
-            else:
-                output["other/grad_norm"] = self._last_grad_norm
-
-            return {k: v for k, v in sorted(output.items())}
-
-        @torch.no_grad()
-        def val_loop(self, engine, batch):
-            model.eval()
-            codec.eval()
-            batch = at.util.prepare_batch(batch, accel.device)
-            signal = apply_transform(val_data.transform, batch)
-
-            vn = accel.unwrap(model)
-            z = codec.encode(signal.samples, signal.sample_rate)["codes"]
-            z = z[:, : vn.n_codebooks, :]
-
-            n_batch = z.shape[0]
-            r = rng.draw(n_batch)[:, 0].to(accel.device)
-
-            mask = pmask.random(z, r)
-            mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
-            z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
+        lora.mark_only_lora_as_trainable(state.model)
 
-            z_mask_latent = vn.embedding.from_codes(z_mask, codec)
+    # Wrap the functions so that they neatly track in TensorBoard + progress bars
+    # and only run when specific conditions are met.
+    global train_loop, val_loop, validate, save_samples, checkpoint
 
-            z_hat = model(z_mask_latent, r)
+    train_loop = tracker.log("train", "value", history=False)(
+        tracker.track("train", num_iters, completed=state.tracker.step)(train_loop)
+    )
+    val_loop = tracker.track("val", len(val_dataloader))(val_loop)
+    validate = tracker.log("val", "mean")(validate)
 
-            target = codebook_flatten(
-                z[:, vn.n_conditioning_codebooks :, :],
-            )
+    save_samples = when(lambda: accel.local_rank == 0)(save_samples)
+    checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
 
-            flat_mask = codebook_flatten(
-                mask[:, vn.n_conditioning_codebooks :, :]
-            )
+    with tracker.live:
+        for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
+            train_loop(state, batch, accel)
 
-            output = {}
-            # replace target with ignore index for masked tokens
-            t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
-            output["loss"] = criterion(z_hat, t_masked)
-
-            self._metrics(
-                vn=vn,
-                r=r,
-                z_hat=z_hat,
-                target=target,
-                flat_mask=flat_mask,
-                output=output,
+            last_iter = (
+                tracker.step == num_iters - 1 if num_iters is not None else False
             )
 
-            return output
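+            # periodically render audio samples to the TensorBoard writer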
+            if tracker.step % sample_freq == 0 or last_iter:
+                save_samples(state, val_idx, writer)
 
-        def checkpoint(self, engine):
-            if accel.local_rank != 0:
-                print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
-                return
-
-            metadata = {"logs": dict(engine.state.logs["epoch"])}
-
-            if self.state.epoch % save_audio_epochs == 0:
-                self.save_samples()
-
-            tags = ["latest"]
-            loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
-            self.print(f"Saving to {str(Path('.').absolute())}")
-
-            if self.state.epoch in save_epochs:
-                tags.append(f"epoch={self.state.epoch}")
-
-            if self.is_best(engine, loss_key):
-                self.print(f"Best model so far")
-                tags.append("best")
-
-            if fine_tune:
-                for tag in tags: 
-                    # save the lora model 
-                    (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
-                    torch.save(
-                        lora.lora_state_dict(accel.unwrap(model)), 
-                        f"{save_path}/{tag}/lora.pth"
-                    )
-
-            for tag in tags:
-                model_extra = {
-                    "optimizer.pth": optimizer.state_dict(),
-                    "scheduler.pth": scheduler.state_dict(),
-                    "trainer.pth": {
-                        "start_idx": self.state.iteration * batch_size,
-                        "state_dict": self.state_dict(),
-                    },
-                    "metadata.pth": metadata,
-                }
-
-                accel.unwrap(model).metadata = metadata
-                accel.unwrap(model).save_to_folder(
-                    f"{save_path}/{tag}", model_extra,
-                )
-
-        def save_sampled(self, z):
-            num_samples = z.shape[0]
-
-            for i in range(num_samples):
-                sampled = accel.unwrap(model).generate(
-                    codec=codec,
-                    time_steps=z.shape[-1],
-                    start_tokens=z[i : i + 1],
-                )
-                sampled.cpu().write_audio_to_tb(
-                    f"sampled/{i}",
-                    self.writer,
-                    step=self.state.epoch,
-                    plot_fn=None,
-                )
-
-
-        def save_imputation(self, z: torch.Tensor):
-            n_prefix = int(z.shape[-1] * 0.25)
-            n_suffix = int(z.shape[-1] *  0.25)
-
-            vn = accel.unwrap(model)
-
-            mask = pmask.inpaint(z, n_prefix, n_suffix)
-            mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
-            z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
-
-            imputed_noisy = vn.to_signal(z_mask, codec)
-            imputed_true = vn.to_signal(z, codec)
-
-            imputed = []
-            for i in range(len(z)):
-                imputed.append(
-                    vn.generate(
-                        codec=codec,
-                        time_steps=z.shape[-1],
-                        start_tokens=z[i][None, ...],
-                        mask=mask[i][None, ...],
-                    )   
-                )   
-            imputed = AudioSignal.batch(imputed)
-
-            for i in range(len(val_idx)):
-                imputed_noisy[i].cpu().write_audio_to_tb(
-                    f"imputed_noisy/{i}",
-                    self.writer,
-                    step=self.state.epoch,
-                    plot_fn=None,
-                )
-                imputed[i].cpu().write_audio_to_tb(
-                    f"imputed/{i}",
-                    self.writer,
-                    step=self.state.epoch,
-                    plot_fn=None,
-                )
-                imputed_true[i].cpu().write_audio_to_tb(
-                    f"imputed_true/{i}",
-                    self.writer,
-                    step=self.state.epoch,
-                    plot_fn=None,
-                )
-
-        @torch.no_grad()
-        def save_samples(self):
-            model.eval()
-            codec.eval()
-            vn = accel.unwrap(model)
-
-            batch = [val_data[i] for i in val_idx]
-            batch = at.util.prepare_batch(val_data.collate(batch), accel.device)
-
-            signal = apply_transform(val_data.transform, batch)
-
-            z = codec.encode(signal.samples, signal.sample_rate)["codes"]
-            z = z[:, : vn.n_codebooks, :]
-
-            r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
-
-
-            mask = pmask.random(z, r)
-            mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
-            z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
-
-            z_mask_latent = vn.embedding.from_codes(z_mask, codec)
-
-            z_hat = model(z_mask_latent, r)
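+            # run validation and checkpointing on a fixed iteration schedule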
+            if tracker.step % val_freq == 0 or last_iter:
+                validate(state, val_dataloader, accel)
+                checkpoint(
+                    state=state,
+                    save_iters=save_iters,
+                    save_path=save_path,
+                    fine_tune=fine_tune,
+                )
 
-            z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
-            z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
-            z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
+                # Reset the validation progress bar and print a summary since the last validation.
+                tracker.done("val", f"Iteration {tracker.step}")
 
-            generated = vn.to_signal(z_pred, codec)
-            reconstructed = vn.to_signal(z, codec)
-            masked = vn.to_signal(z_mask.squeeze(1), codec)
-
-            for i in range(generated.batch_size):
-                audio_dict = {
-                    "original": signal[i],
-                    "masked": masked[i],
-                    "generated": generated[i],
-                    "reconstructed": reconstructed[i],
-                }
-                for k, v in audio_dict.items():
-                    v.cpu().write_audio_to_tb(
-                        f"samples/_{i}.r={r[i]:0.2f}/{k}",
-                        self.writer,
-                        step=self.state.epoch,
-                        plot_fn=None,
-                    )
-
-            self.save_sampled(z)
-            self.save_imputation(z)
-
-    trainer = Trainer(writer=writer, quiet=quiet)
-
-    if trainer_state["state_dict"] is not None:
-        trainer.load_state_dict(trainer_state["state_dict"])
-    if hasattr(train_dataloader.sampler, "set_epoch"):
-        train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch)
-
-    trainer.run(
-        train_dataloader,
-        val_dataloader,
-        num_epochs=max_epochs,
-        epoch_length=epoch_length,
-        detect_anomaly=detect_anomaly,
-    )
+            if last_iter:
+                break
 
 
 if __name__ == "__main__":
@@ -618,4 +666,6 @@ if __name__ == "__main__":
     args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
     with argbind.scope(args):
         with Accelerator() as accel:
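+            # trim tracebacks on non-zero ranks so errors are not repeated per process
+            # (assumes sys is imported in this module)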
+            if accel.local_rank != 0:
+                sys.tracebacklimit = 0
             train(args, accel)
diff --git a/setup.py b/setup.py
index 2964e0f810f32dab3abc433912a2de128c081761..fb4d211908e8c6e89a4a59b4cb99a0508fbbfc80 100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,7 @@ setup(
         "numpy==1.22",
         "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
         "lac @ git+https://github.com/hugofloresgarcia/lac.git",
-        "audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git",
+        "descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
         "gradio", 
         "tensorboardX",
         "loralib",
diff --git a/vampnet/beats.py b/vampnet/beats.py
index 317496ef83d7b764fbbc51068c13170ce0c17e13..2b03a4e3df705a059cd34e6e01a72752fc4d8a98 100644
--- a/vampnet/beats.py
+++ b/vampnet/beats.py
@@ -9,6 +9,7 @@ from typing import Tuple
 from typing import Union
 
 import librosa
+import torch
 import numpy as np
 from audiotools import AudioSignal
 
@@ -203,7 +204,7 @@ class WaveBeat(BeatTracker):
     def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
         from wavebeat.dstcn import dsTCNModel
 
-        model = dsTCNModel.load_from_checkpoint(ckpt_path)
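+        # map_location allows loading GPU-trained checkpoints on CPU-only machines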
+        model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
         model.eval()
 
         self.device = device
diff --git a/vampnet/interface.py b/vampnet/interface.py
index 0a6e39182c9d91c1b76bcb18476f9c018a247543..39e313e949d1721df04ab112f3fb40bebba37b61 100644
--- a/vampnet/interface.py
+++ b/vampnet/interface.py
@@ -22,6 +22,7 @@ def signal_concat(
 
     return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
 
+
 def _load_model(
     ckpt: str, 
     lora_ckpt: str = None,
@@ -64,7 +65,7 @@ class Interface(torch.nn.Module):
     ):
         super().__init__()
         assert codec_ckpt is not None, "must provide a codec checkpoint"
-        self.codec = DAC.load(Path(codec_ckpt))
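+        # load the pretrained codec directly from the checkpoint path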
+        self.codec = DAC.load(codec_ckpt)
         self.codec.eval()
         self.codec.to(device)
 
@@ -275,34 +276,44 @@ class Interface(torch.nn.Module):
         
     def coarse_to_fine(
         self, 
-        coarse_z: torch.Tensor,
+        z: torch.Tensor,
+        mask: torch.Tensor = None,
         **kwargs
     ):
         assert self.c2f is not None, "No coarse2fine model loaded"
-        length = coarse_z.shape[-1]
+        length = z.shape[-1]
         chunk_len = self.s2t(self.c2f.chunk_size_s)
-        n_chunks = math.ceil(coarse_z.shape[-1] / chunk_len)
+        n_chunks = math.ceil(z.shape[-1] / chunk_len)
 
         # zero pad to chunk_len
         if length % chunk_len != 0:
             pad_len = chunk_len - (length % chunk_len)
-            coarse_z = torch.nn.functional.pad(coarse_z, (0, pad_len))
+            z = torch.nn.functional.pad(z, (0, pad_len))
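+            # pad the mask as well so it stays aligned with the padded token sequence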
+            mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None
 
-        n_codebooks_to_append = self.c2f.n_codebooks - coarse_z.shape[1]
+        n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
         if n_codebooks_to_append > 0:
-            coarse_z = torch.cat([
-                coarse_z,
-                torch.zeros(coarse_z.shape[0], n_codebooks_to_append, coarse_z.shape[-1]).long().to(self.device)
+            z = torch.cat([
+                z,
+                torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
             ], dim=1)
 
+        # set the mask to 0 for all conditioning codebooks
+        if mask is not None:
+            mask = mask.clone()
+            mask[:, :self.c2f.n_conditioning_codebooks, :] = 0
+
         fine_z = []
         for i in range(n_chunks):
-            chunk = coarse_z[:, :, i * chunk_len : (i + 1) * chunk_len]
+            chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
+            mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None
+
             chunk = self.c2f.generate(
                 codec=self.codec,
                 time_steps=chunk_len,
                 start_tokens=chunk,
                 return_signal=False,
+                mask=mask_chunk,
                 **kwargs
             )
             fine_z.append(chunk)
@@ -337,6 +348,12 @@ class Interface(torch.nn.Module):
             **kwargs
         )
 
+        # add the fine codes back in
+        c_vamp = torch.cat(
+            [c_vamp, z[:, self.coarse.n_codebooks :, :]], 
+            dim=1
+        )
+
         if return_mask:
             return c_vamp, cz_masked
         
@@ -352,17 +369,18 @@ if __name__ == "__main__":
     at.util.seed(42)
 
     interface = Interface(
-        coarse_ckpt="./models/spotdl/coarse.pth", 
-        coarse2fine_ckpt="./models/spotdl/c2f.pth", 
-        codec_ckpt="./models/spotdl/codec.pth",
+        coarse_ckpt="./models/vampnet/coarse.pth", 
+        coarse2fine_ckpt="./models/vampnet/c2f.pth", 
+        codec_ckpt="./models/vampnet/codec.pth",
         device="cuda", 
         wavebeat_ckpt="./models/wavebeat.pth"
     )
 
 
-    sig = at.AudioSignal.zeros(duration=10, sample_rate=44100)
+    sig = at.AudioSignal('assets/example.wav')
 
     z = interface.encode(sig)
 
     # mask = linear_random(z, 1.0)
     # mask = mask_and(
@@ -374,13 +392,14 @@ if __name__ == "__main__":
     #     )
     # )
 
-    mask = interface.make_beat_mask(
-        sig, 0.0, 0.075
-    )
+    # mask = interface.make_beat_mask(
+    #     sig, 0.0, 0.075
+    # )
     # mask = dropout(mask, 0.0)
     # mask = codebook_unmask(mask, 0)
+
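+    # inpaint mask: keep the first and last 100 tokens fixed and regenerate the middle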
+    mask = inpaint(z, n_prefix=100, n_suffix=100)
     
-    breakpoint()
     zv, mask_z = interface.coarse_vamp(
         z, 
         mask=mask,
@@ -389,16 +408,16 @@ if __name__ == "__main__":
         return_mask=True, 
         gen_fn=interface.coarse.generate
     )
 
     use_coarse2fine = True
     if use_coarse2fine: 
-        zv = interface.coarse_to_fine(zv, temperature=0.8)
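+        # pass the mask so coarse-to-fine only regenerates the masked positions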
+        zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
 
     mask = interface.to_signal(mask_z).cpu()
 
     sig = interface.to_signal(zv).cpu()
     print("done")
 
-    sig.write("output3.wav")
-    mask.write("mask.wav")
         
\ No newline at end of file
diff --git a/vampnet/modules/__init__.py b/vampnet/modules/__init__.py
index 3481f32e0287faa9e79ba219f17d18529a4b57ac..3f4c8c226e42d022c60b620e8f21ccaf4e6a57bd 100644
--- a/vampnet/modules/__init__.py
+++ b/vampnet/modules/__init__.py
@@ -2,3 +2,5 @@ import audiotools
 
 audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
 audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
+
+from .transformer import VampNet
\ No newline at end of file