import torch
import julius
import torchopenl3
import torchmetrics
import pytorch_lightning as pl
from typing import Tuple, List, Dict
from argparse import ArgumentParser

from deepafx_st.probes.cdpam_encoder import CDPAMEncoder
from deepafx_st.probes.random_mel import RandomMelProjection
import deepafx_st.utils as utils
from deepafx_st.utils import DSPMode
from deepafx_st.system import System
from deepafx_st.data.style import StyleDataset


class ProbeSystem(pl.LightningModule):
    def __init__(
        self,
        audio_dir=None,
        num_classes=5,
        task="style",
        encoder_type="deepafx_st_autodiff",
        deepafx_st_autodiff_ckpt=None,
        deepafx_st_spsa_ckpt=None,
        deepafx_st_proxy0_ckpt=None,
        probe_type="linear",
        batch_size=32,
        lr=3e-4,
        lr_patience=20,
        patience=10,
        preload=False,
        sample_rate=24000,
        shuffle=True,
        num_workers=16,
        **kwargs,
    ):
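        """Linear/MLP probe on top of a frozen audio encoder.

        The encoder selected by `encoder_type` (a DeepAFx-ST variant,
        OpenL3, CDPAM, or a random mel projection) is frozen; only the
        probe head selected by `probe_type` ("linear" or "mlp") is trained
        on the `task` (style classification over `num_classes` classes).
        """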
        super().__init__()
        self.save_hyperparameters()

        if "deepafx_st" in self.hparams.encoder_type:
            # select the checkpoint that matches the requested variant
            if "autodiff" in self.hparams.encoder_type:
                self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_autodiff_ckpt
            elif "spsa" in self.hparams.encoder_type:
                self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_spsa_ckpt
            elif "proxy0" in self.hparams.encoder_type:
                self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_proxy0_ckpt
            else:
                raise RuntimeError(f"Invalid encoder_type: {self.hparams.encoder_type}")

            if self.hparams.deepafx_st_ckpt is None:
                raise RuntimeError(
                    f"Must supply {self.hparams.encoder_type}_ckpt checkpoint."
                )

            use_dsp = DSPMode.NONE
            system = System.load_from_checkpoint(
                self.hparams.deepafx_st_ckpt,
                use_dsp=use_dsp,
                batch_size=self.hparams.batch_size,
                spsa_parallel=False,
                proxy_ckpts=[],
                strict=False,
            )
            system.eval()
            self.encoder = system.encoder
            self.hparams.embed_dim = self.encoder.embed_dim

            # freeze encoder weights; only the probe is trained
            for param in self.encoder.parameters():
                param.requires_grad = False
        elif self.hparams.encoder_type == "openl3":
            self.encoder = torchopenl3.models.load_audio_embedding_model(
                input_repr=self.hparams.openl3_input_repr,
                embedding_size=self.hparams.openl3_embedding_size,
                content_type=self.hparams.openl3_content_type,
            )
            self.hparams.embed_dim = 6144
        elif self.hparams.encoder_type == "random_mel":
            self.encoder = RandomMelProjection(
                self.hparams.sample_rate,
                self.hparams.random_mel_embedding_size,
                self.hparams.random_mel_n_mels,
                self.hparams.random_mel_n_fft,
                self.hparams.random_mel_hop_size,
            )
            self.hparams.embed_dim = self.hparams.random_mel_embedding_size
        elif self.hparams.encoder_type == "cdpam":
            self.encoder = CDPAMEncoder(self.hparams.cdpam_ckpt)
            self.encoder.eval()
            self.hparams.embed_dim = self.encoder.embed_dim
        else:
            raise ValueError(f"Invalid encoder_type: {self.hparams.encoder_type}")
        if self.hparams.probe_type == "linear":
            if self.hparams.task == "style":
                # no softmax: cross_entropy expects raw logits
                self.probe = torch.nn.Sequential(
                    torch.nn.Linear(self.hparams.embed_dim, self.hparams.num_classes),
                )
        elif self.hparams.probe_type == "mlp":
            if self.hparams.task == "style":
                self.probe = torch.nn.Sequential(
                    torch.nn.Linear(self.hparams.embed_dim, 512),
                    torch.nn.ReLU(),
                    torch.nn.Linear(512, 512),
                    torch.nn.ReLU(),
                    torch.nn.Linear(512, self.hparams.num_classes),
                )
        else:
            raise ValueError(f"Invalid probe_type: {self.hparams.probe_type}")

        self.accuracy = torchmetrics.Accuracy()
        self.f1_score = torchmetrics.F1Score(num_classes=self.hparams.num_classes)

    def forward(self, x):
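        """Embed audio with the frozen encoder and classify it with the probe.

        Args:
            x: audio tensor with shape (batch, channels, samples).

        Returns:
            Logits with shape (batch, num_classes).
        """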
        bs, chs, samp = x.size()

        with torch.no_grad():
            if "deepafx_st" in self.hparams.encoder_type:
                # peak normalize, leaving 12 dB of headroom
                # (clamp guards against division by zero on silent input)
                x = x / x.abs().max().clamp(min=1e-8)
                x = x * 10 ** (-12.0 / 20)
                e = self.encoder(x)

                # normalize by L2 norm
                norm = torch.norm(e, p=2, dim=-1, keepdim=True)
                e = e / norm
            elif self.hparams.encoder_type == "openl3":
                # x = julius.resample_frac(x, self.hparams.sample_rate, 48000)
                e, ts = torchopenl3.get_audio_embedding(
                    x,
                    48000,
                    model=self.encoder,
                    input_repr="mel128",
                    content_type="music",
                )
                # average embeddings across time frames
                e = e.permute(0, 2, 1)
                e = e.mean(dim=-1)

                # normalize by L2 norm
                norm = torch.norm(e, p=2, dim=-1, keepdim=True)
                e = e / norm
            elif self.hparams.encoder_type == "random_mel":
                e = self.encoder(x)

                # normalize by L2 norm
                norm = torch.norm(e, p=2, dim=-1, keepdim=True)
                e = e / norm
            elif self.hparams.encoder_type == "cdpam":
                # x = julius.resample_frac(x, self.hparams.sample_rate, 22050)
                # scale to the int16 range expected by CDPAM
                x = torch.round(x * 32768)
                e = self.encoder(x)

        return self.probe(e)

    def common_step(
        self,
        batch: Tuple,
        batch_idx: int,
        optimizer_idx: int = 0,
        train: bool = True,
    ):
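        """Shared step for training and validation.

        Computes the cross-entropy loss for the style task and logs it; on
        validation steps, also logs step-level accuracy and F1 and returns
        the input audio in `data_dict` for later inspection.
        """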
        loss = 0
        data_dict = {}
        x, y = batch

        y_hat = self(x)

        # compute cross-entropy loss on the logits
        if self.hparams.task == "style":
            loss = torch.nn.functional.cross_entropy(y_hat, y)
            if not train:
                # store audio data
                data_dict = {"x": x.float().cpu()}

        self.log(
            "train_loss" if train else "val_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        if not train and self.hparams.task == "style":
            self.log("val_acc_step", self.accuracy(y_hat, y))
            self.log("val_f1_step", self.f1_score(y_hat, y))

        return loss, data_dict

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        loss, _ = self.common_step(batch, batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, data_dict = self.common_step(batch, batch_idx, train=False)

        if batch_idx == 0:
            return data_dict

    def validation_epoch_end(self, outputs) -> None:
        if self.hparams.task == "style":
            self.log("val_acc_epoch", self.accuracy.compute())
            self.log("val_f1_epoch", self.f1_score.compute())
            # reset metric state so epochs do not accumulate
            self.accuracy.reset()
            self.f1_score.reset()
        return super().validation_epoch_end(outputs)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.probe.parameters(),
            lr=self.hparams.lr,
            betas=(0.9, 0.999),
        )

        # step the learning rate down by 10x at 80% and 95% of training
        ms1 = int(self.hparams.max_epochs * 0.8)
        ms2 = int(self.hparams.max_epochs * 0.95)
        print(
            "Learning rate schedule:",
            f"0 {self.hparams.lr:0.2e} -> ",
            f"{ms1} {self.hparams.lr*0.1:0.2e} -> ",
            f"{ms2} {self.hparams.lr*0.01:0.2e}",
        )
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[ms1, ms2],
            gamma=0.1,
        )

        return [optimizer], [{"scheduler": scheduler, "monitor": "val_loss"}]

    def train_dataloader(self):
        if self.hparams.task == "style":
            train_dataset = StyleDataset(
                self.hparams.audio_dir,
                "train",
                sample_rate=self.hparams.encoder_sample_rate,
            )

        # seed workers for reproducible data loading
        g = torch.Generator()
        g.manual_seed(0)

        return torch.utils.data.DataLoader(
            train_dataset,
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            worker_init_fn=utils.seed_worker,
            generator=g,
            pin_memory=True,
        )

    def val_dataloader(self):
        if self.hparams.task == "style":
            val_dataset = StyleDataset(
                self.hparams.audio_dir,
                subset="val",
                sample_rate=self.hparams.encoder_sample_rate,
            )

        g = torch.Generator()
        g.manual_seed(0)

        return torch.utils.data.DataLoader(
            val_dataset,
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
            worker_init_fn=utils.seed_worker,
            generator=g,
            pin_memory=True,
        )

    # add any model hyperparameters here
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # --- Model ---
        parser.add_argument("--encoder_type", type=str, default="deepafx_st_autodiff")
        parser.add_argument("--probe_type", type=str, default="linear")
        parser.add_argument("--task", type=str, default="style")
        parser.add_argument("--encoder_sample_rate", type=int, default=24000)
        # --- deepafx_st ---
        parser.add_argument("--deepafx_st_autodiff_ckpt", type=str)
        parser.add_argument("--deepafx_st_spsa_ckpt", type=str)
        parser.add_argument("--deepafx_st_proxy0_ckpt", type=str)
        # --- cdpam ---
        parser.add_argument("--cdpam_ckpt", type=str)
        # --- openl3 ---
        parser.add_argument("--openl3_input_repr", type=str, default="mel128")
        parser.add_argument("--openl3_content_type", type=str, default="env")
        parser.add_argument("--openl3_embedding_size", type=int, default=6144)
        # --- random_mel ---
        parser.add_argument("--random_mel_embedding_size", type=int, default=4096)
        parser.add_argument("--random_mel_n_fft", type=int, default=4096)
        parser.add_argument("--random_mel_hop_size", type=int, default=1024)
        parser.add_argument("--random_mel_n_mels", type=int, default=128)
        # --- Training ---
        parser.add_argument("--audio_dir", type=str)
        parser.add_argument("--num_classes", type=int, default=5)
        parser.add_argument("--batch_size", type=int, default=32)
        parser.add_argument("--lr", type=float, default=3e-4)
        parser.add_argument("--lr_patience", type=int, default=20)
        parser.add_argument("--patience", type=int, default=10)
        parser.add_argument("--preload", action="store_true")
        parser.add_argument("--sample_rate", type=int, default=24000)
        parser.add_argument("--num_workers", type=int, default=8)

        return parser
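

# A minimal training sketch, assuming the dataset layout expected by
# StyleDataset and a PyTorch Lightning 1.x Trainer. This entry point is an
# illustration, not the repository's official training script; the
# command-line values shown in the comment are placeholders.
if __name__ == "__main__":
    parser = ArgumentParser()
    parser = ProbeSystem.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # e.g. --audio_dir ./data/style --encoder_type deepafx_st_autodiff \
    #      --deepafx_st_autodiff_ckpt ./checkpoints/autodiff.ckpt --max_epochs 400
    # note: max_epochs reaches configure_optimizers() via **kwargs and
    # save_hyperparameters(), so it must be supplied on the command line
    system = ProbeSystem(**vars(args))

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(system)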