# Scraped page header (not part of the program): "Spaces: Runtime error"
import os | |
import torch | |
from torch.utils.data import DataLoader | |
import omegaconf | |
from omegaconf import DictConfig | |
import wandb | |
from checkpoing_saver import CheckpointSaver | |
from denoisers import get_model | |
from optimizers import get_optimizer | |
from losses import get_loss | |
from datasets import get_datasets | |
from testing.metrics import Metrics | |
from datasets.minimal import Minimal | |
from tqdm import tqdm | |
def init_wandb(cfg):
    """Log in to Weights & Biases and start the run described by ``cfg['wandb']``.

    The whole config is resolved to a plain container and attached to the run.
    If W&B reports the run as resumed, the first ``model`` artifact logged by
    the first non-running run with the same name is downloaded.

    Args:
        cfg: Config exposing a ``wandb`` section with the keys ``api_key``,
            ``host``, ``project``, ``notes``, ``resume``, ``run_name`` and
            (only read when resuming) ``entity``.

    Raises:
        IndexError: When resuming and no matching run, or no ``model``
            artifact on that run, exists.
    """
    wb_cfg = cfg['wandb']
    wandb.login(key=wb_cfg['api_key'], host=wb_cfg['host'])

    init_kwargs = {
        'project': wb_cfg['project'],
        'notes': wb_cfg['notes'],
        'config': omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
        'resume': wb_cfg['resume'],
        'name': wb_cfg['run_name'],
    }
    wandb.init(**init_kwargs)

    # Fresh run: there is no earlier checkpoint to fetch.
    if not wandb.run.resumed:
        return

    api = wandb.Api()
    # NOTE(review): the W&B docs describe `order` values such as
    # `summary_metrics.<name>`; confirm that the bare 'train_pesq' sorts as intended.
    project_runs = api.runs(f"{wb_cfg['entity']}/{wb_cfg['project']}",
                            order='train_pesq')

    wanted_name = wb_cfg['run_name']
    finished_runs = []
    for candidate in project_runs:
        # The run that is executing right now has the same name; skip it.
        if candidate.name == wanted_name and candidate.state != 'running':
            finished_runs.append(candidate)
    previous_run = finished_runs[0]

    model_artifacts = [item for item in previous_run.logged_artifacts()
                       if item.type == 'model']
    model_artifacts[0].download()
def _log_minimal_examples(model, loader, normalize, device):
    """Denoise each clip yielded by ``loader`` and log the result to W&B as audio.

    Args:
        model: Denoiser to run; the caller is expected to have put it in eval mode.
        loader: Iterable of ``(wav, rate)`` pairs.
        normalize: When true, divide each clip by its standard deviation before
            the forward pass and multiply the prediction by it afterwards.
        device: Device the clip is moved to for the forward pass.
    """
    # Pure inference: without no_grad every example would build an autograd
    # graph that is never used.
    with torch.no_grad():
        for i, (wav, rate) in enumerate(loader):
            if normalize:
                std = torch.std(wav)
                prediction = model((wav / std).to(device)) * std
            else:
                prediction = model(wav.to(device))
            # NOTE(review): [0][0] presumably picks the first item and first
            # channel of the prediction — confirm against the model's output shape.
            wandb.log({
                f"{i}_example": wandb.Audio(
                    prediction.detach().cpu().numpy()[0][0],
                    sample_rate=rate)})


def train(cfg: DictConfig):
    """Train and validate the denoiser described by ``cfg``, logging to W&B.

    Each epoch runs a ``train`` phase followed by a ``val`` phase. After the
    ``val`` phase the clips of the ``Minimal`` dataset are denoised and logged
    as audio, and the checkpoint saver is called with the validation PESQ and
    loss.

    Args:
        cfg: Config providing ``gpu`` plus the sections ``wandb``
            (``log_interval`` and everything ``init_wandb`` reads),
            ``training`` (``model_save_path``, ``num_epochs``), ``dataloader``
            (``sample_rate``, ``train_batch_size``, ``valid_batch_size``,
            ``num_workers``, ``normalize``), ``model``, ``optimizer`` and
            ``loss``.
    """
    device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
    init_wandb(cfg)
    checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name,
                                       decreasing=False)
    metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)

    model = get_model(cfg['model']).to(device)
    optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
    loss_fn = get_loss(cfg['loss'], device)
    train_dataset, valid_dataset = get_datasets(cfg)
    minimal_dataset = Minimal(cfg)

    dataloaders = {
        'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True,
                            num_workers=cfg['dataloader']['num_workers']),
        'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=False,
                          num_workers=cfg['dataloader']['num_workers']),
        'minimal': DataLoader(minimal_dataset)
    }

    log_interval = cfg['wandb']['log_interval']
    num_epochs = cfg['training']['num_epochs']
    wandb.watch(model, log_freq=log_interval)

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
            samples_seen = 0
            loop = tqdm(dataloaders[phase])
            loop.set_description(f"Epoch [{epoch}/{cfg['training']['num_epochs']}][{phase}]")
            for i, (inputs, labels) in enumerate(loop):
                inputs = inputs.to(device)
                labels = labels.to(device)
                batch_size = inputs.size(0)

                optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Metrics are for reporting only: keep them off the autograd
                # graph so the running sums cannot retain it across the epoch.
                with torch.no_grad():
                    batch_metrics = metrics(denoised=outputs.detach(), clean=labels)

                running_loss += loss.item() * batch_size
                running_pesq += batch_metrics['PESQ']
                running_stoi += batch_metrics['STOI']
                samples_seen += batch_size

                # Divide by the samples actually processed so a smaller final
                # batch does not skew the running averages.
                # NOTE(review): assumes Metrics returns per-batch sums, in line
                # with the per-dataset division used for the epoch values below.
                mean_loss = running_loss / samples_seen
                mean_pesq = running_pesq / samples_seen
                mean_stoi = running_stoi / samples_seen
                loop.set_postfix(loss=mean_loss, pesq=mean_pesq, stoi=mean_stoi)

                if is_train and i % log_interval == 0:
                    wandb.log({"train_loss": mean_loss,
                               "train_pesq": mean_pesq,
                               "train_stoi": mean_stoi})

            dataset_size = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / dataset_size
            epoch_pesq = running_pesq / dataset_size
            epoch_stoi = running_stoi / dataset_size

            wandb.log({f"{phase}_loss": epoch_loss,
                       f"{phase}_pesq": epoch_pesq,
                       f"{phase}_stoi": epoch_stoi})

            if phase == 'val':
                _log_minimal_examples(model, dataloaders['minimal'],
                                      cfg['dataloader']['normalize'], device)
                # Checkpoints are ranked by validation PESQ (higher is better,
                # matching decreasing=False above).
                checkpoint_saver(model, epoch, metric_val=epoch_pesq,
                                 optimizer=optimizer, loss=epoch_loss)
# Running this file directly is a deliberate no-op: nothing here calls
# train(), so it is presumably started from another module that builds the config.
if __name__ == "__main__":
    pass