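"""Training loop for a speech-denoising model.

Wires together this repo's factories (get_model, get_optimizer, get_loss,
get_datasets), tracks experiments with Weights & Biases, evaluates PESQ/STOI,
and checkpoints the best model by validation PESQ.
"""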
import torch
from torch.utils.data import DataLoader
import omegaconf
from omegaconf import DictConfig
import wandb
from tqdm import tqdm

from checkpoing_saver import CheckpointSaver
from denoisers import get_model
from optimizers import get_optimizer
from losses import get_loss
from datasets import get_datasets
from datasets.minimal import Minimal
from testing.metrics import Metrics


def init_wandb(cfg):
    """Authenticate with W&B, start (or resume) a run and, when resuming,
    download the best model artifact logged by an earlier finished run."""
    wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
    wandb.init(project=cfg['wandb']['project'],
               notes=cfg['wandb']['notes'],
               config=omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
               resume=cfg['wandb']['resume'],
               name=cfg['wandb']['run_name'])
    if wandb.run.resumed:
        api = wandb.Api()
        # Rank past runs by their logged train_pesq summary metric.
        runs = api.runs(f"{cfg['wandb']['entity']}/{cfg['wandb']['project']}",
                        order='summary_metrics.train_pesq')
        run = [run for run in runs
               if run.name == cfg['wandb']['run_name'] and run.state != 'running'][0]
        artifacts = run.logged_artifacts()
        best_model = [artifact for artifact in artifacts if artifact.type == 'model'][0]
        best_model.download()
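
# For reference, the config keys this script reads (collected from the accesses
# below; grouping and values are illustrative, not a confirmed schema):
#
#   gpu: 0
#   wandb:      {api_key, host, entity, project, run_name, notes, resume, log_interval}
#   training:   {model_save_path, num_epochs}
#   dataloader: {sample_rate, train_batch_size, valid_batch_size, num_workers, normalize}
#   model / optimizer / loss: passed through to the corresponding get_* factory
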
def train(cfg: DictConfig):
    device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
    init_wandb(cfg)

    # Higher PESQ is better, hence decreasing=False.
    checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'],
                                       run_name=wandb.run.name, decreasing=False)
    metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)

    model = get_model(cfg['model']).to(device)
    optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
    loss_fn = get_loss(cfg['loss'], device)

    train_dataset, valid_dataset = get_datasets(cfg)
    minimal_dataset = Minimal(cfg)  # fixed example clips, used below for audio logging
    dataloaders = {
        'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'],
                            shuffle=True, num_workers=cfg['dataloader']['num_workers']),
        'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'],
                          shuffle=False, num_workers=cfg['dataloader']['num_workers']),
        'minimal': DataLoader(minimal_dataset),
    }

    wandb.watch(model, log_freq=cfg['wandb']['log_interval'])
    epoch = 0
    while epoch < cfg['training']['num_epochs']:
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
            loop = tqdm(dataloaders[phase])
            for i, (inputs, labels) in enumerate(loop):
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()

                # Gradients are only needed during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    running_metrics = metrics(denoised=outputs, clean=labels)

                # Running sums; the progress-bar averages assume equally sized batches.
                running_loss += loss.item() * inputs.size(0)
                running_pesq += running_metrics['PESQ']
                running_stoi += running_metrics['STOI']
                loop.set_description(f"Epoch [{epoch}/{cfg['training']['num_epochs']}][{phase}]")
                loop.set_postfix(loss=running_loss / (i + 1) / inputs.size(0),
                                 pesq=running_pesq / (i + 1) / inputs.size(0),
                                 stoi=running_stoi / (i + 1) / inputs.size(0))

                if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
                    wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
                               "train_pesq": running_pesq / (i + 1) / inputs.size(0),
                               "train_stoi": running_stoi / (i + 1) / inputs.size(0)})
            # Epoch-level averages over the whole dataset.
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_pesq = running_pesq / len(dataloaders[phase].dataset)
            epoch_stoi = running_stoi / len(dataloaders[phase].dataset)
            wandb.log({f"{phase}_loss": epoch_loss,
                       f"{phase}_pesq": epoch_pesq,
                       f"{phase}_stoi": epoch_stoi})
            if phase == 'val':
                # Denoise a few fixed clips and log them as audio examples.
                with torch.no_grad():
                    for i, (wav, rate) in enumerate(dataloaders['minimal']):
                        if cfg['dataloader']['normalize']:
                            # Scale the input to unit std for the model, then undo the
                            # scaling so the logged audio keeps its original level.
                            std = torch.std(wav)
                            wav = wav / std
                            prediction = model(wav.to(device)) * std
                        else:
                            prediction = model(wav.to(device))
                        wandb.log({
                            f"{i}_example": wandb.Audio(
                                prediction.detach().cpu().numpy()[0][0],
                                sample_rate=int(rate))})
                # Checkpoint on this epoch's validation PESQ.
                checkpoint_saver(model, epoch, metric_val=epoch_pesq,
                                 optimizer=optimizer, loss=epoch_loss)
        epoch += 1
if __name__ == "__main__":
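    # train() is imported and driven elsewhere in this Space; as a minimal sketch,
    # a standalone Hydra launcher (the `conf/` directory and `config.yaml` name
    # are assumptions, not a confirmed part of this repo's layout) could look like:
    #
    #     import hydra
    #
    #     @hydra.main(config_path='conf', config_name='config', version_base=None)
    #     def main(cfg: DictConfig) -> None:
    #         train(cfg)
    #
    #     main()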
    pass