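"""Train the void CNN voice classifier on mel spectrograms.

Feeds VoiceDataset audio through CNNetwork, logs per-epoch metrics to
Weights & Biases, and saves a timestamped checkpoint.
"""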
from datetime import datetime
from tqdm import tqdm
import wandb

# torch
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader

# internal
from dataset import VoiceDataset
from cnn import CNNetwork

BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 0.001

TRAIN_FILE = "data/train"
AISF_TRAIN_FILE = "data/aisf/train"
TEST_FILE = "data/test"
SAMPLE_RATE = 48000

def train(model, train_dataloader, loss_fn, optimizer, device, epochs, test_dataloader=None):
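    """Run the full training loop for `epochs` epochs.

    Returns lists of per-epoch training/testing loss and accuracy; the test
    metrics stay empty when no test_dataloader is supplied.
    """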
    training_acc = []
    training_loss = []
    testing_acc = []
    testing_loss = []

    for i in range(epochs):
        print(f"Epoch {i + 1}/{epochs}")

        # train model
        train_epoch_loss, train_epoch_acc = train_epoch(model, train_dataloader, loss_fn, optimizer, device)

        # average the summed batch metrics over the number of batches
        training_loss.append(train_epoch_loss/len(train_dataloader))
        training_acc.append(train_epoch_acc/len(train_dataloader))

        print(f"Training Loss: {training_loss[i]:.2f}, Training Accuracy: {training_acc[i]:.2f}")
        wandb.log({'training_loss': training_loss[i], 'training_acc': training_acc[i]})

        if test_dataloader:
            # test model
            test_epoch_loss, test_epoch_acc = validate_epoch(model, test_dataloader, loss_fn, device)
            
            # average the summed batch metrics over the number of batches
            testing_loss.append(test_epoch_loss/len(test_dataloader))
            testing_acc.append(test_epoch_acc/len(test_dataloader))

            print(f"Testing Loss: {testing_loss[i]:.2f}, Testing Accuracy: {testing_acc[i]:.2f}")
            wandb.log({'testing_loss': testing_loss[i], 'testing_acc': testing_acc[i]})

        print ("-------------------------------------------- \n")
    
    print("---- Finished Training ----")
    return training_acc, training_loss, testing_acc, testing_loss
  

def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
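    """Train for a single epoch; returns loss and accuracy summed over batches."""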
    train_loss = 0.0
    train_acc = 0.0

    model.train()

    for wav, target in tqdm(train_dataloader, "Training batch..."):
        wav, target = wav.to(device), target.to(device)

        # calculate loss
        output = model(wav)
        loss = loss_fn(output, target)

        # backprop and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # metrics
        train_loss += loss.item()
        prediction = torch.argmax(output, 1)
        train_acc += (prediction == target).sum().item()/len(prediction)
       
    return train_loss, train_acc

def validate_epoch(model, test_dataloader, loss_fn, device):
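    """Evaluate for a single epoch under torch.no_grad(); returns summed loss and accuracy."""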
    test_loss = 0.0
    test_acc = 0.0

    model.eval()

    with torch.no_grad():
        for wav, target in tqdm(test_dataloader, "Testing batch..."):
            wav, target = wav.to(device), target.to(device)

            output = model(wav)
            loss = loss_fn(output, target)

            test_loss += loss.item()
            prediction = torch.argmax(output, 1)
            test_acc += (prediction == target).sum().item()/len(prediction)
    
    return test_loss, test_acc

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device.")

    # instantiate the dataset's mel spectrogram transform and create the data loader;
    # n_fft, hop_length, and n_mels set the time-frequency resolution of the input
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=2048,
        hop_length=512,
        n_mels=128
    )

    train_dataset = VoiceDataset(AISF_TRAIN_FILE, mel_spectrogram, device, time_limit_in_secs=3)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # construct model
    model = CNNetwork().to(device)
    print(model)
    print(train_dataset.label_mapping)

    # init loss function and optimizer (SGD with momentum; an Adam variant is kept for reference)
    loss_fn = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

    # start the wandb run, recording the hyperparameters as run config
    wandb.init(project="void-train", config={
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "learning_rate": LEARNING_RATE,
    })

    # train model
    train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)

    # save model; attributes set on the module (like label_mapping) are not part of
    # state_dict(), so the mapping is stored alongside the weights in one checkpoint
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_filename = f"models/aisf/void_{timestamp}.pth"
    torch.save({
        "state_dict": model.state_dict(),
        "label_mapping": train_dataset.label_mapping,
    }, model_filename)
    print(f"Trained void model saved at {model_filename}")
    wandb.finish()
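
    # A minimal loading sketch for the checkpoint format above (hypothetical consumer;
    # adjust the path and device to your setup):
    #   checkpoint = torch.load(model_filename, map_location="cpu")
    #   model = CNNetwork()
    #   model.load_state_dict(checkpoint["state_dict"])
    #   label_mapping = checkpoint["label_mapping"]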