import os
from datetime import datetime
from pathlib import Path

import torch
from torch.nn import Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import RandomCrop
from omegaconf import DictConfig, OmegaConf

from utils import load_wav
from datasets import Valentini  # assumed to be exposed by the local datasets package
from denoisers import get_model
from optimizers import get_optimizer
from losses import get_loss

# Pin the job to the second GPU; fall back to CPU if CUDA is unavailable.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Valentini (DS_10283_2791) noisy/clean speech dataset and a few fixed noisy
# test clips of varying difficulty that are logged to TensorBoard.
DATASET_PATH = Path('/media/public/dataset/denoising/DS_10283_2791/')
VALID_WAVS = {'hard': 'p257_171.wav',
              'medium': 'p232_071.wav',
              'easy': 'p232_284.wav'}
MAX_SECONDS = 2
SAMPLE_RATE = 16000

# Random fixed-length crops keep every training example MAX_SECONDS long.
transform = Sequential(RandomCrop((1, int(MAX_SECONDS * SAMPLE_RATE)), pad_if_needed=True))

training_loader = DataLoader(Valentini(valid=False, transform=transform), batch_size=12, shuffle=True)
validation_loader = DataLoader(Valentini(valid=True, transform=transform), batch_size=12, shuffle=True)


def train(cfg: DictConfig):
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Model, optimizer and loss are all built from the config.
    model = get_model(cfg['model']).to(device)
    optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
    loss_fn = get_loss(cfg['loss'])

    writer = SummaryWriter('runs/denoising_trainer_{}'.format(timestamp))
    epoch_number = 0
    EPOCHS = 5
    best_vloss = 1_000_000.

    # Log the raw noisy reference clips once, before training starts.
    for tag, wav_path in VALID_WAVS.items():
        wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
        writer.add_audio(tag=tag, snd_tensor=wav, sample_rate=SAMPLE_RATE)
    writer.flush()

    for epoch in range(EPOCHS):
        print('EPOCH {}:'.format(epoch_number + 1))

        model.train(True)
        running_loss = 0.
        last_loss = 0.
        for i, data in enumerate(training_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 1000 == 999:
                last_loss = running_loss / 1000  # average loss over the last 1000 batches
                print('  batch {} loss: {}'.format(i + 1, last_loss))
                tb_x = epoch_number * len(training_loader) + i + 1
                writer.add_scalar('Loss/train', last_loss, tb_x)
                running_loss = 0.
        avg_loss = last_loss  # last logged 1000-batch average

        # Validation pass.
        model.eval()
        running_vloss = 0.0
        with torch.no_grad():
            for i, vdata in enumerate(validation_loader):
                vinputs, vlabels = vdata
                vinputs, vlabels = vinputs.to(device), vlabels.to(device)
                voutputs = model(vinputs)
                vloss = loss_fn(voutputs, vlabels)
                running_vloss += vloss.item()
        avg_vloss = running_vloss / (i + 1)
        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
        writer.add_scalars('Training vs. Validation Loss',
                           {'Training': avg_loss, 'Validation': avg_vloss},
                           epoch_number + 1)

        # Denoise the fixed test clips with the current model and log the audio.
        with torch.no_grad():
            for tag, wav_path in VALID_WAVS.items():
                wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
                wav = torch.reshape(wav, (1, 1, -1)).to(device)
                prediction = model(wav).squeeze(0).cpu()
                writer.add_audio(tag=f"Model predicted {tag} on epoch {epoch}",
                                 snd_tensor=prediction, sample_rate=SAMPLE_RATE)
        writer.flush()

        # Keep a checkpoint whenever the validation loss improves.
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            Path('checkpoints').mkdir(exist_ok=True)
            model_path = 'checkpoints/model_{}_{}'.format(timestamp, epoch_number)
            torch.save(model.state_dict(), model_path)

        epoch_number += 1


if __name__ == '__main__':
    # The config path below is an assumption; point it at whatever YAML file
    # describes the model, optimizer and loss in this repository.
    cfg = OmegaConf.load('config.yaml')
    train(cfg)
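# Sketch of an assumed config layout consumed by train(). The exact keys and
# values are hypothetical and depend on how get_model / get_optimizer /
# get_loss are implemented in this repository; adjust them to match.
#
#   model:
#     name: demucs
#   optimizer:
#     name: adam
#     lr: 3e-4
#   loss:
#     name: l1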