import torch
import torchmetrics
import pytorch_lightning as pl
from torch import Tensor, nn
from einops import rearrange
from torchaudio.models import HDemucs
from audio_diffusion_pytorch import DiffusionModel
from auraloss.time import SISDRLoss
from auraloss.freq import MultiResolutionSTFTLoss
from umx.openunmix.model import OpenUnmix, Separator
from utils import FADLoss, spectrogram, log_wandb_audio_batch
from dptnet import DPTNet_base
from dcunet import RefineSpectrogramUnet


class RemFX(pl.LightningModule):
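    """LightningModule wrapping an audio effect-removal network.

    Logs SI-SDR, multi-resolution STFT, and FAD metrics; FAD is only
    evaluated on the test set.
    """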
def __init__(
self,
lr: float,
lr_beta1: float,
lr_beta2: float,
lr_eps: float,
lr_weight_decay: float,
sample_rate: float,
network: nn.Module,
):
super().__init__()
self.lr = lr
self.lr_beta1 = lr_beta1
self.lr_beta2 = lr_beta2
self.lr_eps = lr_eps
self.lr_weight_decay = lr_weight_decay
self.sample_rate = sample_rate
self.model = network
self.metrics = nn.ModuleDict(
{
"SISDR": SISDRLoss(),
"STFT": MultiResolutionSTFTLoss(),
"FAD": FADLoss(sample_rate=sample_rate),
}
)
        # Log the first training batch's audio (input vs. target) only once
self.log_train_audio = True
@property
def device(self):
return next(self.model.parameters()).device
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
list(self.model.parameters()),
lr=self.lr,
betas=(self.lr_beta1, self.lr_beta2),
eps=self.lr_eps,
weight_decay=self.lr_weight_decay,
)
return optimizer
    # Step-based learning rate schedule, implemented via the optimizer_step hook
def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_idx,
optimizer_closure,
on_tpu,
using_lbfgs,
):
# update params
optimizer.step(closure=optimizer_closure)
        # Update the learning rate: reduce by a factor of 10 at 80% and 95% of
        # training (cast to int so the equality check against global_step can match)
        if self.trainer.global_step == int(0.8 * self.trainer.max_steps):
            for pg in optimizer.param_groups:
                pg["lr"] = 0.1 * pg["lr"]
        if self.trainer.global_step == int(0.95 * self.trainer.max_steps):
            for pg in optimizer.param_groups:
                pg["lr"] = 0.1 * pg["lr"]
def training_step(self, batch, batch_idx):
loss = self.common_step(batch, batch_idx, mode="train")
return loss
def validation_step(self, batch, batch_idx):
loss = self.common_step(batch, batch_idx, mode="valid")
return loss
def test_step(self, batch, batch_idx):
loss = self.common_step(batch, batch_idx, mode="test")
return loss
def common_step(self, batch, batch_idx, mode: str = "train"):
x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
loss, output = self.model((x, y))
self.log(f"{mode}_loss", loss)
# Metric logging
with torch.no_grad():
for metric in self.metrics:
                # auraloss's SISDRLoss returns negative SI-SDR (a loss), so flip the sign
if metric == "SISDR":
negate = -1
else:
negate = 1
# Only Log FAD on test set
if metric == "FAD" and mode != "test":
continue
self.log(
f"{mode}_{metric}",
negate * self.metrics[metric](output, y),
on_step=False,
on_epoch=True,
logger=True,
prog_bar=True,
sync_dist=True,
)
return loss
def on_train_batch_start(self, batch, batch_idx):
# Log initial audio
if self.log_train_audio:
x, y, _, _ = batch
            # Concatenate samples for easier viewing in the dashboard
input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
log_wandb_audio_batch(
logger=self.logger,
id="input_effected_audio",
samples=input_samples.cpu(),
sampling_rate=self.sample_rate,
caption="Training Data",
)
log_wandb_audio_batch(
logger=self.logger,
id="target_audio",
samples=target_samples.cpu(),
sampling_rate=self.sample_rate,
caption="Target Data",
)
self.log_train_audio = False
def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
x, target, _, _ = batch
        # Log metrics on the unprocessed input as a baseline
for metric in self.metrics:
            # auraloss's SISDRLoss returns negative SI-SDR (a loss), so flip the sign
if metric == "SISDR":
negate = -1
else:
negate = 1
            # FAD is skipped here; input FAD is logged in on_test_batch_start
if metric == "FAD":
continue
self.log(
f"Input_{metric}",
negate * self.metrics[metric](x, target),
on_step=False,
on_epoch=True,
logger=True,
prog_bar=True,
sync_dist=True,
)
# Only run on first batch
if batch_idx == 0:
self.model.eval()
with torch.no_grad():
y = self.model.sample(x)
            # Concatenate samples for easier viewing in the dashboard, with
            # 2 seconds of silence between each sample (slice the time axis)
            silence = torch.zeros_like(x)[..., : self.sample_rate * 2]
concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
log_wandb_audio_batch(
logger=self.logger,
id="prediction_input_target",
samples=concat_samples.cpu(),
sampling_rate=self.sample_rate,
caption=f"Epoch {self.current_epoch}",
)
def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
# Log FAD
x, target, _, _ = batch
self.log(
"Input_FAD",
self.metrics["FAD"](x, target),
on_step=False,
on_epoch=True,
logger=True,
prog_bar=True,
sync_dist=True,
)


class OpenUnmixModel(nn.Module):
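    """OpenUnmix-based effect removal model.

    The Separator performs the STFT/iSTFT round trip internally and returns a
    waveform, trained with a multi-resolution STFT loss plus a weighted L1 loss.
    """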
def __init__(
self,
n_fft: int = 2048,
hop_length: int = 512,
n_channels: int = 1,
alpha: float = 0.3,
sample_rate: int = 22050,
):
super().__init__()
self.n_channels = n_channels
self.n_fft = n_fft
self.hop_length = hop_length
self.alpha = alpha
window = torch.hann_window(n_fft)
self.register_buffer("window", window)
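        # Number of frequency bins of the n_fft-point STFT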
self.num_bins = self.n_fft // 2 + 1
self.sample_rate = sample_rate
self.model = OpenUnmix(
nb_channels=self.n_channels,
nb_bins=self.num_bins,
)
self.separator = Separator(
target_models={"other": self.model},
nb_channels=self.n_channels,
sample_rate=self.sample_rate,
n_fft=self.n_fft,
n_hop=self.hop_length,
)
self.mrstftloss = MultiResolutionSTFTLoss(
n_bins=self.num_bins, sample_rate=self.sample_rate
)
self.l1loss = nn.L1Loss()
def forward(self, batch):
x, target = batch
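        # The direct OpenUnmix spectrogram prediction below is not used in the
        # loss; training is driven by the Separator's waveform output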
X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
Y = self.model(X)
sep_out = self.separator(x).squeeze(1)
loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target) * 100
return loss, sep_out
def sample(self, x: Tensor) -> Tensor:
return self.separator(x).squeeze(1)


class DemucsModel(nn.Module):
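    """Hybrid Demucs (torchaudio's HDemucs) effect removal model, trained with
    a multi-resolution STFT loss plus a weighted L1 loss."""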
def __init__(self, sample_rate, **kwargs) -> None:
super().__init__()
self.model = HDemucs(**kwargs)
self.num_bins = kwargs["nfft"] // 2 + 1
self.mrstftloss = MultiResolutionSTFTLoss(
n_bins=self.num_bins, sample_rate=sample_rate
)
self.l1loss = nn.L1Loss()
def forward(self, batch):
x, target = batch
output = self.model(x).squeeze(1)
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
return loss, output
def sample(self, x: Tensor) -> Tensor:
return self.model(x).squeeze(1)


class DiffusionGenerationModel(nn.Module):
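    """Unconditional waveform diffusion model (audio_diffusion_pytorch);
    forward returns the diffusion loss together with a sampled output."""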
def __init__(self, n_channels: int = 1):
super().__init__()
self.model = DiffusionModel(in_channels=n_channels)
def forward(self, batch):
        x, target = batch
        # Draw a sample for logging, using the same step count as sample() below
        sampled_out = self.model.sample(x, num_steps=10)
        return self.model(x), sampled_out
def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
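        # Generate from fresh Gaussian noise matching the input's shape, dtype, and device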
noise = torch.randn(x.shape).to(x)
return self.model.sample(noise, num_steps=num_steps)


class DPTNetModel(nn.Module):
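    """Dual-path transformer network (DPTNet) effect removal model, trained
    with a multi-resolution STFT loss plus a weighted L1 loss."""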
    def __init__(self, sample_rate, num_bins, **kwargs):
        super().__init__()
        self.model = DPTNet_base(**kwargs)
        self.num_bins = num_bins
        self.mrstftloss = MultiResolutionSTFTLoss(
            n_bins=self.num_bins, sample_rate=sample_rate
        )
self.l1loss = nn.L1Loss()
def forward(self, batch):
x, target = batch
output = self.model(x).squeeze(1)
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
return loss, output
def sample(self, x: Tensor) -> Tensor:
return self.model.sample(x)


class DCUNetModel(nn.Module):
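    """Deep complex U-Net (RefineSpectrogramUnet) effect removal model, trained
    with a multi-resolution STFT loss plus a weighted L1 loss."""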
    def __init__(self, sample_rate, num_bins, **kwargs):
        super().__init__()
        self.model = RefineSpectrogramUnet(**kwargs)
        self.num_bins = num_bins
        self.mrstftloss = MultiResolutionSTFTLoss(
            n_bins=self.num_bins, sample_rate=sample_rate
        )
self.l1loss = nn.L1Loss()
def forward(self, batch):
x, target = batch
output = self.model(x).squeeze(1)
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
return loss, output
def sample(self, x: Tensor) -> Tensor:
return self.model.sample(x)


class FXClassifier(pl.LightningModule):
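    """Classifier that predicts the effects present in an audio clip, trained
    with cross-entropy; logs loss and mean average precision."""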
def __init__(
self,
lr: float,
lr_weight_decay: float,
sample_rate: float,
network: nn.Module,
):
super().__init__()
self.lr = lr
self.lr_weight_decay = lr_weight_decay
self.sample_rate = sample_rate
self.network = network
def forward(self, x: torch.Tensor):
return self.network(x)
def common_step(self, batch, batch_idx, mode: str = "train"):
x, y, dry_label, wet_label = batch
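        # Predict effect-class logits from the input audio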
pred_label = self.network(x)
loss = nn.functional.cross_entropy(pred_label, dry_label)
self.log(
f"{mode}_loss",
loss,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
sync_dist=True,
)
self.log(
f"{mode}_mAP",
torchmetrics.functional.retrieval_average_precision(
pred_label, dry_label.long()
),
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
sync_dist=True,
)
return loss
def training_step(self, batch, batch_idx):
return self.common_step(batch, batch_idx, mode="train")
def validation_step(self, batch, batch_idx):
return self.common_step(batch, batch_idx, mode="valid")
def test_step(self, batch, batch_idx):
return self.common_step(batch, batch_idx, mode="test")
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.network.parameters(),
lr=self.lr,
weight_decay=self.lr_weight_decay,
)
return optimizer