import sys

import torch
from torch import Tensor
import pytorch_lightning as pl
from einops import rearrange
import wandb
from audio_diffusion_pytorch import AudioDiffusionModel

# Make the local Open-Unmix (umx) checkout importable.
sys.path.append("/Users/matthewrice/Developer/remfx/umx/")
from umx.openunmix.model import OpenUnmix, Separator

SAMPLE_RATE = 22050  # From audio-diffusion-pytorch


class OpenUnmixModel(pl.LightningModule):
    def __init__(
        self,
        n_fft: int = 2048,
        hop_length: int = 512,
        alpha: float = 0.3,
    ):
        super().__init__()
        self.model = OpenUnmix(
            nb_channels=1,
            nb_bins=n_fft // 2 + 1,
        )
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.alpha = alpha  # exponent for magnitude compression in spectrogram()
        window = torch.hann_window(n_fft)
        self.register_buffer("window", window)

    def forward(self, x: torch.Tensor):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss, _ = self.common_step(batch, batch_idx, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        loss, Y = self.common_step(batch, batch_idx, mode="val")
        return loss, Y

    def common_step(self, batch, batch_idx, mode: str = "train"):
        x, target, label = batch
        # Compare compressed magnitude spectrograms of the model output and target.
        X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
        Y = self(X)
        Y_hat = spectrogram(
            target, self.window, self.n_fft, self.hop_length, self.alpha
        )
        loss = torch.nn.functional.mse_loss(Y, Y_hat)
        self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
        return loss, Y

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
        )

    def on_validation_epoch_start(self):
        self.log_next = True

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
        # Log audio examples for the first validation batch of each epoch only.
        if self.log_next:
            x, target, label = batch
            s = Separator(
                target_models={"other": self.model},
                nb_channels=1,
                sample_rate=SAMPLE_RATE,
                n_fft=self.n_fft,
                n_hop=self.hop_length,
            )
            outputs = s(x).squeeze(1)  # drop the single-target dimension
            log_wandb_audio_batch(
                id="sample",
                samples=x,
                sampling_rate=SAMPLE_RATE,
                caption=f"Epoch {self.current_epoch}",
            )
            log_wandb_audio_batch(
                id="prediction",
                samples=outputs,
                sampling_rate=SAMPLE_RATE,
                caption=f"Epoch {self.current_epoch}",
            )
            log_wandb_audio_batch(
                id="target",
                samples=target,
                sampling_rate=SAMPLE_RATE,
                caption=f"Epoch {self.current_epoch}",
            )
            self.log_next = False
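
# A minimal usage sketch, assuming a hypothetical LightningDataModule
# (`my_datamodule`) that yields (x, target, label) batches of mono audio shaped
# (batch, 1, time) at SAMPLE_RATE:
#
#     model = OpenUnmixModel(n_fft=2048, hop_length=512, alpha=0.3)
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(model, datamodule=my_datamodule)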


class DiffusionGenerationModel(pl.LightningModule):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor):
        # The wrapped diffusion model returns its training loss directly.
        return self.model(x)

    def sample(self, *args, **kwargs) -> Tensor:
        return self.model.sample(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx, mode="val")
        return loss

    def common_step(self, batch, batch_idx, mode: str = "train"):
        x, target, label = batch
        loss = self(x)
        self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
        )

    def on_validation_epoch_start(self):
        self.log_next = True

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
        x, target, label = batch
        if self.log_next:
            self.log_sample(x)
            self.log_next = False

    def log_sample(self, batch, num_steps=10):
        # Get start diffusion noise shaped like the batch.
        noise = torch.randn(batch.shape, device=self.device)
        sampled = self.sample(noise=noise, num_steps=num_steps)  # Suggested range: 2-50
        log_wandb_audio_batch(
            id="sample",
            samples=sampled,
            sampling_rate=SAMPLE_RATE,
            caption=f"Sampled in {num_steps} steps",
        )
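
# A usage sketch, assuming AudioDiffusionModel from audio-diffusion-pytorch;
# its constructor arguments vary across releases, so `in_channels=1` here is an
# assumption -- check the installed version:
#
#     diffusion = DiffusionGenerationModel(AudioDiffusionModel(in_channels=1))
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(diffusion, datamodule=my_datamodule)  # hypothetical datamodule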


def log_wandb_audio_batch(
    id: str, samples: Tensor, sampling_rate: int, caption: str = ""
):
    num_items = samples.shape[0]
    # wandb.Audio expects (time, channels), so swap the channel and time axes.
    samples = rearrange(samples, "b c t -> b t c")
    for idx in range(num_items):
        wandb.log(
            {
                f"{id}_{idx}": wandb.Audio(
                    samples[idx].cpu().numpy(),
                    caption=caption,
                    sample_rate=sampling_rate,
                )
            }
        )
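
# A logging sketch, assuming an active wandb run (e.g. started by wandb.init()
# or by Lightning's WandbLogger):
#
#     audio = torch.randn(2, 1, SAMPLE_RATE)
#     log_wandb_audio_batch(
#         id="debug", samples=audio, sampling_rate=SAMPLE_RATE, caption="noise"
#     )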


def spectrogram(
    x: torch.Tensor,
    window: torch.Tensor,
    n_fft: int,
    hop_length: int,
    alpha: float,
) -> torch.Tensor:
    bs, chs, samp = x.size()
    x = x.view(bs * chs, -1)  # move channels onto batch dim
    X = torch.stft(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        return_complex=True,
    )
    # move channels back
    X = X.view(bs, chs, X.shape[-2], X.shape[-1])
    # Compressed magnitude spectrogram; the epsilon avoids pow() on exact zeros.
    return torch.pow(X.abs() + 1e-8, alpha)
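

if __name__ == "__main__":
    # Minimal shape check for spectrogram(), assuming 1-second mono clips at
    # SAMPLE_RATE (an added sketch with hypothetical example values).
    n_fft, hop_length = 2048, 512
    x = torch.randn(2, 1, SAMPLE_RATE)
    win = torch.hann_window(n_fft)
    X = spectrogram(x, win, n_fft, hop_length, alpha=0.3)
    print(X.shape)  # torch.Size([2, 1, 1025, 44]) -> (batch, channels, bins, frames)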