import os
import torch
from torch.utils.data import DataLoader
import omegaconf
from omegaconf import DictConfig
import wandb

from checkpoing_saver import CheckpointSaver
from denoisers import get_model
from optimizers import get_optimizer
from losses import get_loss
from datasets import get_datasets
from testing.metrics import Metrics
from datasets.minimal import Minimal
from tqdm import tqdm

def init_wandb(cfg):
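    """Log in to Weights & Biases and start (or resume) the configured run."""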
    wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
    wandb.init(project=cfg['wandb']['project'],
               notes=cfg['wandb']['notes'],
               config=omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
               resume=cfg['wandb']['resume'],
               name=cfg['wandb']['run_name'])
    if wandb.run.resumed:
        # Look up the previous finished run with the same name and download
        # its logged model artifact so training can resume from it.
        api = wandb.Api()
        runs = api.runs(f"{cfg['wandb']['entity']}/{cfg['wandb']['project']}",
                        order='train_pesq')
        run = next(r for r in runs
                   if r.name == cfg['wandb']['run_name'] and r.state != 'running')
        best_model = next(a for a in run.logged_artifacts() if a.type == 'model')
        best_model.download()

def train(cfg: DictConfig):
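    """Train the denoiser and validate it each epoch, logging metrics to W&B."""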
    device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
    init_wandb(cfg)
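    # decreasing=False: a larger metric value (PESQ) marks a better checkpoint.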
    checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name,
                                       decreasing=False)
    metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)

    model = get_model(cfg['model']).to(device)
    optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
    loss_fn = get_loss(cfg['loss'], device)
    train_dataset, valid_dataset = get_datasets(cfg)
    minimal_dataset = Minimal(cfg)

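    # One dataloader per phase; 'minimal' is a small fixed set of clips used
    # only for logging audio examples (its batch size defaults to 1).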
    dataloaders = {
        'train':  DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True,
                             num_workers=cfg['dataloader']['num_workers']),
        'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=False,
                          num_workers=cfg['dataloader']['num_workers']),
        'minimal': DataLoader(minimal_dataset)
    }

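    # Ask W&B to log model gradients every log_interval batches.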
    wandb.watch(model, log_freq=cfg['wandb']['log_interval'])
    for epoch in range(cfg['training']['num_epochs']):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
            samples_seen = 0
            loop = tqdm(dataloaders[phase])
            for i, (inputs, labels) in enumerate(loop):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Loss and metrics are accumulated as per-sample sums and
                # averaged over the number of samples processed so far,
                # which stays correct even when the last batch is smaller.
                running_metrics = metrics(denoised=outputs, clean=labels)
                running_loss += loss.item() * inputs.size(0)
                running_pesq += running_metrics['PESQ']
                running_stoi += running_metrics['STOI']
                samples_seen += inputs.size(0)

                loop.set_description(f"Epoch [{epoch}/{cfg['training']['num_epochs']}][{phase}]")
                loop.set_postfix(loss=running_loss / samples_seen,
                                 pesq=running_pesq / samples_seen,
                                 stoi=running_stoi / samples_seen)

                if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
                    wandb.log({"train_loss": running_loss / samples_seen,
                               "train_pesq": running_pesq / samples_seen,
                               "train_stoi": running_stoi / samples_seen})

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_pesq = running_pesq / len(dataloaders[phase].dataset)
            epoch_stoi = running_stoi / len(dataloaders[phase].dataset)

            wandb.log({f"{phase}_loss": epoch_loss,
                       f"{phase}_pesq": epoch_pesq,
                       f"{phase}_stoi": epoch_stoi})

            if phase == 'val':
                # Denoise a small fixed set of clips and log them as audio.
                with torch.no_grad():
                    for i, (wav, rate) in enumerate(dataloaders['minimal']):
                        if cfg['dataloader']['normalize']:
                            # Scale the input to unit variance and undo the
                            # scaling on the model output.
                            std = torch.std(wav)
                            prediction = model((wav / std).to(device)) * std
                        else:
                            prediction = model(wav.to(device))
                        wandb.log({
                            f"{i}_example": wandb.Audio(
                                prediction.cpu().numpy()[0][0],
                                sample_rate=int(rate))})

                checkpoint_saver(model, epoch, metric_val=epoch_pesq,
                                 optimizer=optimizer, loss=epoch_loss)


if __name__ == "__main__":
    # Minimal entry point sketch: load a config and launch training.
    # NOTE: the 'config.yaml' path is an assumption, not part of the original script.
    config = omegaconf.OmegaConf.load('config.yaml')
    train(config)