Spaces:
Runtime error
Runtime error
import os | |
import torch | |
from torch.utils.data import DataLoader | |
import omegaconf | |
from omegaconf import DictConfig | |
import wandb | |
from checkpoing_saver import CheckpointSaver | |
from denoisers import get_model | |
from optimizers import get_optimizer | |
from losses import get_loss | |
from datasets import get_datasets | |
from testing.metrics import Metrics | |
from datasets.minimal import Minimal | |
def _run_phase(phase, model, dataloader, loss_fn, optimizer, metrics, device,
               log_interval):
    """Run ``model`` over ``dataloader`` once and return the epoch averages.

    In the ``'train'`` phase the model is put in train mode, gradients are
    enabled and the optimizer is stepped after every batch. Any other phase
    runs in eval mode with gradients disabled. During training, running
    averages are logged to wandb every ``log_interval`` batches.

    Returns:
        A ``(loss, pesq, stoi)`` tuple, each accumulated total divided by the
        number of samples seen in this pass. An empty loader yields zeros.
    """
    is_train = phase == 'train'
    model.train(is_train)  # same as model.eval() when is_train is False

    running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
    num_samples = 0
    for i, (inputs, labels) in enumerate(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        batch_size = inputs.size(0)

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Metrics only need values; detach so they never receive the graph.
        batch_metrics = metrics.calculate(denoised=outputs.detach(), clean=labels)
        # The loss is scaled by the batch size to build a per-sample sum.
        # NOTE(review): PESQ/STOI are accumulated unscaled and later divided by
        # the sample count. That is only a per-sample average if
        # Metrics.calculate returns per-batch *sums*; confirm against
        # testing.metrics.
        running_loss += loss.item() * batch_size
        running_pesq += batch_metrics['PESQ']
        running_stoi += batch_metrics['STOI']
        num_samples += batch_size

        if is_train and i % log_interval == 0:
            # Divide by the samples actually seen, so a short batch does not
            # skew the running average.
            wandb.log({"train_loss": running_loss / num_samples,
                       "train_pesq": running_pesq / num_samples,
                       "train_stoi": running_stoi / num_samples})

    # Avoid ZeroDivisionError when the loader produced no batches.
    denominator = max(num_samples, 1)
    return (running_loss / denominator,
            running_pesq / denominator,
            running_stoi / denominator)


def _log_minimal_examples(model, dataloader, device):
    """Log the model's prediction for every item of ``dataloader`` as wandb audio.

    Each item is expected to be a ``(wav, rate)`` pair. Inference runs without
    gradient tracking, because the outputs are only logged.
    """
    with torch.no_grad():
        for i, (wav, rate) in enumerate(dataloader):
            prediction = model(wav.to(device))
            # NOTE(review): [0][0] assumes the prediction is shaped
            # (batch, channel, time), so this picks the first channel of the
            # only item; confirm.
            # The default collate turns a scalar rate into a one-element
            # tensor, so convert it back to a plain int for wandb.
            wandb.log({
                f"{i}_example": wandb.Audio(
                    prediction.cpu().numpy()[0][0],
                    sample_rate=int(rate))})


def train(cfg: DictConfig):
    """Train a denoising model and log progress to Weights & Biases.

    Each epoch runs a ``train`` pass (with backpropagation) followed by a
    ``val`` pass (without gradients). Per-phase loss, PESQ and STOI averages
    are logged to wandb. After every validation pass, the model's predictions
    on the ``Minimal`` example set are logged as audio, and the checkpoint
    saver is called with the validation PESQ and loss.

    Args:
        cfg: Experiment configuration. Keys read directly here are ``gpu``,
            ``wandb`` (``api_key``, ``host``, ``project``, ``notes``,
            ``tags``, ``run_name``, ``log_interval``), ``training``
            (``model_save_path``, ``num_epochs``), ``dataloader``
            (``sample_rate``, ``train_batch_size``, ``valid_batch_size``),
            ``model``, ``optimizer`` and ``loss``. The whole config is also
            passed to ``get_datasets`` and ``Minimal``, and is uploaded to
            wandb fully resolved. Missing values raise an error at upload.
    """
    device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')

    wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
    wandb.init(project=cfg['wandb']['project'],
               notes=cfg['wandb']['notes'],
               tags=cfg['wandb']['tags'],
               config=omegaconf.OmegaConf.to_container(
                   cfg, resolve=True, throw_on_missing=True))
    wandb.run.name = cfg['wandb']['run_name']

    checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'],
                                       run_name=wandb.run.name)
    metrics = Metrics(rate=cfg['dataloader']['sample_rate'])

    model = get_model(cfg['model']).to(device)
    optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
    loss_fn = get_loss(cfg['loss'], device)

    train_dataset, valid_dataset = get_datasets(cfg)
    minimal_dataset = Minimal(cfg)
    dataloaders = {
        'train': DataLoader(train_dataset,
                            batch_size=cfg['dataloader']['train_batch_size'],
                            shuffle=True),
        # Validation averages are order-independent, so skip shuffling.
        'val': DataLoader(valid_dataset,
                          batch_size=cfg['dataloader']['valid_batch_size'],
                          shuffle=False),
        'minimal': DataLoader(minimal_dataset),
    }

    wandb.watch(model, log_freq=100)

    log_interval = cfg['wandb']['log_interval']
    for epoch in range(cfg['training']['num_epochs']):
        for phase in ('train', 'val'):
            epoch_loss, epoch_pesq, epoch_stoi = _run_phase(
                phase, model, dataloaders[phase], loss_fn, optimizer,
                metrics, device, log_interval)
            wandb.log({f"{phase}_loss": epoch_loss,
                       f"{phase}_pesq": epoch_pesq,
                       f"{phase}_stoi": epoch_stoi})

            if phase == 'val':
                # The model is still in eval mode from the validation pass.
                _log_minimal_examples(model, dataloaders['minimal'], device)
                checkpoint_saver(model, epoch, metric_val=epoch_pesq,
                                 optimizer=optimizer, loss=epoch_loss)
if __name__ == "__main__":
    # Running this file as a script currently does nothing; training is
    # started by calling train(cfg) with a configuration object.
    ...