# denoising/train.py
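"""Training entry point for the denoising models.

Builds the model, optimizer, loss, and datasets from an OmegaConf config,
then alternates training and validation epochs, logging loss/PESQ/STOI and
denoised audio examples to Weights & Biases and checkpointing on PESQ.
"""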
import torch
from torch.utils.data import DataLoader
import omegaconf
from omegaconf import DictConfig
import wandb
from checkpoing_saver import CheckpointSaver  # sic: matches the module's filename
from denoisers import get_model
from optimizers import get_optimizer
from losses import get_loss
from datasets import get_datasets
from testing.metrics import Metrics
from datasets.minimal import Minimal
from tqdm import tqdm


def init_wandb(cfg):
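    """Log in to Weights & Biases and start (or resume) the configured run.

    When resuming, look up the matching finished run by name and download
    its logged model artifact.
    """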
    wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
    wandb.init(project=cfg['wandb']['project'],
               notes=cfg['wandb']['notes'],
               config=omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
               resume=cfg['wandb']['resume'],
               name=cfg['wandb']['run_name'])
    if wandb.run.resumed:
        api = wandb.Api()
        runs = api.runs(f"{cfg['wandb']['entity']}/{cfg['wandb']['project']}",
                        order='train_pesq')
        run = [run for run in runs if run.name == cfg['wandb']['run_name'] and run.state != 'running'][0]
        artifacts = run.logged_artifacts()
        best_model = [artifact for artifact in artifacts if artifact.type == 'model'][0]
        best_model.download()


def train(cfg: DictConfig):
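    """Train, validate, and checkpoint a denoiser according to `cfg`."""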
    device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
    init_wandb(cfg)
    checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name,
                                       decreasing=False)  # higher metric (PESQ) is better
    metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)
    model = get_model(cfg['model']).to(device)
    optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
    loss_fn = get_loss(cfg['loss'], device)
    train_dataset, valid_dataset = get_datasets(cfg)
    minimal_dataset = Minimal(cfg)
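    # Three loaders: shuffled training batches, deterministic validation
    # batches, and a small "minimal" set used only to log audio examples.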
    dataloaders = {
        'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True,
                            num_workers=cfg['dataloader']['num_workers']),
        'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=False,
                          num_workers=cfg['dataloader']['num_workers']),
        'minimal': DataLoader(minimal_dataset)
    }
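    # Have W&B log the model's gradients every `log_interval` batches.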
    wandb.watch(model, log_freq=cfg['wandb']['log_interval'])

    epoch = 0
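    # Each epoch runs a full training pass followed by a validation pass.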
    while epoch < cfg['training']['num_epochs']:
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
            loop = tqdm(dataloaders[phase])
            for i, (inputs, labels) in enumerate(loop):
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Build the autograd graph only in the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_metrics = metrics(denoised=outputs, clean=labels)
                running_loss += loss.item() * inputs.size(0)
                running_pesq += running_metrics['PESQ']
                running_stoi += running_metrics['STOI']
                # Progress-bar averages; exact only while every batch is full-sized.
                loop.set_description(f"Epoch [{epoch}/{cfg['training']['num_epochs']}][{phase}]")
                loop.set_postfix(loss=running_loss / (i + 1) / inputs.size(0),
                                 pesq=running_pesq / (i + 1) / inputs.size(0),
                                 stoi=running_stoi / (i + 1) / inputs.size(0))
                if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
                    wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
                               "train_pesq": running_pesq / (i + 1) / inputs.size(0),
                               "train_stoi": running_stoi / (i + 1) / inputs.size(0)})
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_pesq = running_pesq / len(dataloaders[phase].dataset)
            epoch_stoi = running_stoi / len(dataloaders[phase].dataset)
            wandb.log({f"{phase}_loss": epoch_loss,
                       f"{phase}_pesq": epoch_pesq,
                       f"{phase}_stoi": epoch_stoi})
            if phase == 'val':
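                # Log denoised versions of a few fixed example clips after each
                # validation pass so the run can be checked by ear in W&B.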
                for i, (wav, rate) in enumerate(dataloaders['minimal']):
                    with torch.no_grad():  # pure inference; no graph needed
                        if cfg['dataloader']['normalize']:
                            std = torch.std(wav)
                            wav = wav / std
                            prediction = model(wav.to(device))
                            prediction = prediction * std
                        else:
                            prediction = model(wav.to(device))
                    wandb.log({
                        f"{i}_example": wandb.Audio(
                            prediction.detach().cpu().numpy()[0][0],
                            sample_rate=int(rate))})  # collation turns rate into a 1-element tensor
                # CheckpointSaver keeps the model with the best (highest) PESQ.
                checkpoint_saver(model, epoch, metric_val=epoch_pesq,
                                 optimizer=optimizer, loss=epoch_loss)
        epoch += 1


if __name__ == "__main__":
    # Intentionally empty: train(cfg) is presumably invoked from a separate
    # entry point (e.g. a Hydra launcher) elsewhere in the repo.
    pass