Spaces:
Runtime error
Runtime error
File size: 4,503 Bytes
a38e25f e59cfff a38e25f 232d8f8 a38e25f 232d8f8 1f644df 85adfed a38e25f 7c8662f a38e25f 1f644df 80802e9 a38e25f 7c8662f 232d8f8 e59cfff a38e25f 7c8662f a38e25f 232d8f8 cf26dbd a38e25f 7c8662f a38e25f 7c8662f 1f644df 7c8662f 1f644df 7c8662f a38e25f 9fe5149 a38e25f 9fe5149 a38e25f e59cfff a38e25f 232d8f8 a38e25f 232d8f8 a38e25f e59cfff a38e25f 3e88903 a38e25f 232d8f8 a38e25f 232d8f8 a38e25f e59cfff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
from datetime import datetime
from tqdm import tqdm
import wandb
# torch
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
# internal
from dataset import VoiceDataset
from cnn import CNNetwork
# --- Training hyperparameters ---
BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 0.001
# --- Dataset directories (relative to the working directory) ---
TRAIN_FILE="data/train"
AISF_TRAIN_FILE="data/aisf/train"
TEST_FILE="data/test"
# Audio sample rate in Hz fed to the MelSpectrogram transform
SAMPLE_RATE=48000
def train(model, train_dataloader, loss_fn, optimizer, device, epochs, test_dataloader=None):
    """Run the full training loop for ``epochs`` epochs.

    Each epoch trains on ``train_dataloader``, optionally evaluates on
    ``test_dataloader``, prints the batch-averaged metrics, and logs them
    to Weights & Biases.

    Returns:
        (training_acc, training_loss, testing_acc, testing_loss): lists of
        per-epoch batch-averaged values; the testing lists stay empty when
        no ``test_dataloader`` is supplied.
    """
    training_acc, training_loss = [], []
    testing_acc, testing_loss = [], []
    n_train_batches = len(train_dataloader)
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        # One pass over the training set; returns metrics summed over batches.
        epoch_loss, epoch_acc = train_epoch(model, train_dataloader, loss_fn, optimizer, device)
        avg_train_loss = epoch_loss / n_train_batches
        avg_train_acc = epoch_acc / n_train_batches
        training_loss.append(avg_train_loss)
        training_acc.append(avg_train_acc)
        print(f"Training Loss: {avg_train_loss:.2f}, Training Accuracy {avg_train_acc}")
        wandb.log({'training_loss': avg_train_loss, 'training_acc': avg_train_acc})
        if test_dataloader:
            # Evaluate on the held-out set with gradients disabled.
            epoch_test_loss, epoch_test_acc = validate_epoch(model, test_dataloader, loss_fn, device)
            avg_test_loss = epoch_test_loss / len(test_dataloader)
            avg_test_acc = epoch_test_acc / len(test_dataloader)
            testing_loss.append(avg_test_loss)
            testing_acc.append(avg_test_acc)
            print(f"Testing Loss: {avg_test_loss:.2f}, Testing Accuracy {avg_test_acc}")
            wandb.log({'testing_loss': avg_test_loss, 'testing_acc': avg_test_acc})
        print("-------------------------------------------- \n")
    print("---- Finished Training ----")
    return training_acc, training_loss, testing_acc, testing_loss
def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    """Train ``model`` for a single epoch.

    Runs forward pass, loss, backprop, and optimizer step for every batch.

    Returns:
        (train_loss, train_acc): loss and accuracy summed over batches —
        the caller is expected to divide by ``len(train_dataloader)`` to
        get per-batch averages.
    """
    train_loss = 0.0
    train_acc = 0.0
    model.train()
    for wav, target in tqdm(train_dataloader, "Training batch..."):
        wav, target = wav.to(device), target.to(device)
        # forward pass + loss
        output = model(wav)
        loss = loss_fn(output, target)
        # backprop and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # accumulate metrics; accuracy is the fraction correct in this batch
        train_loss += loss.item()
        prediction = torch.argmax(output, 1)
        train_acc += (prediction == target).sum().item() / len(prediction)
    return train_loss, train_acc
def validate_epoch(model, test_dataloader, loss_fn, device):
    """Evaluate ``model`` on ``test_dataloader`` without updating weights.

    Gradients are disabled and the model is put in eval mode (affects
    dropout/batch-norm layers).

    Returns:
        (test_loss, test_acc): loss and accuracy summed over batches —
        the caller divides by ``len(test_dataloader)`` for averages.
    """
    test_loss = 0.0
    test_acc = 0.0
    model.eval()
    with torch.no_grad():
        for wav, target in tqdm(test_dataloader, "Testing batch..."):
            wav, target = wav.to(device), target.to(device)
            output = model(wav)
            loss = loss_fn(output, target)
            test_loss += loss.item()
            # fraction of correct predictions in this batch
            prediction = torch.argmax(output, 1)
            test_acc += (prediction == target).sum().item() / len(prediction)
    return test_loss, test_acc
if __name__ == "__main__":
    # Prefer GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device.")
    # Build the mel-spectrogram transform applied to each audio clip.
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=2048,
        hop_length=512,
        n_mels=128
    )
    # Dataset and loader over the AISF training split (3-second clips).
    train_dataset = VoiceDataset(AISF_TRAIN_FILE, mel_spectrogram, device, time_limit_in_secs=3)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Construct the CNN classifier on the chosen device.
    model = CNNetwork().to(device)
    print(model)
    print(train_dataset.label_mapping)
    # Loss and optimizer (SGD with momentum; Adam left disabled).
    loss_fn = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
    wandb.init(project="void-train")
    # Run the training loop (no test split passed here).
    train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)
    # NOTE(review): this sets a plain attribute on the module; state_dict()
    # only captures parameters/buffers, so label_mapping is likely NOT saved
    # below — verify against the loading code.
    model.label_mapping = train_dataset.label_mapping
    # Persist the trained weights under a timestamped filename.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_filename = f"models/aisf/void_{timestamp}.pth"
    torch.save(model.state_dict(), model_filename)
    print(f"Trained void model saved at {model_filename}")
    wandb.finish()