import torch


class Dataset(torch.utils.data.Dataset):
    """
    Loads and preprocesses the given text data for masked-language-model training.
    """

    def __init__(self, paths, tokenizer):
        """
        Initialises the dataset with the given file paths and tokeniser.
        """
        # the last file might not have 10000 samples, which makes it difficult to get the total length of the dataset
        self.paths = paths[:-1]
        self.tokenizer = tokenizer
        self.data = self.read_file(self.paths[0])
        self.current_file = 1
        self.remaining = len(self.data)
        self.encodings = self.get_encodings(self.data)

    def __len__(self):
        """
        Returns the length of the dataset (each file is assumed to hold 10000 samples).
        """
        return 10000 * len(self.paths)

    def read_file(self, path):
        """
        Reads the given file and returns its lines.
        """
        with open(path, 'r', encoding='utf-8') as f:
            lines = f.read().split('\n')
        return lines

    def get_encodings(self, lines_all):
        """
        Creates masked-language-model encodings for the given text input.
        """
        # tokenise all text
        batch = self.tokenizer(lines_all, max_length=512, padding='max_length', truncation=True)
        # ground truth
        labels = torch.tensor(batch['input_ids'])
        # attention masks
        mask = torch.tensor(batch['attention_mask'])
        # input to be masked
        input_ids = labels.detach().clone()
        rand = torch.rand(input_ids.shape)
        # with a probability of 15%, mask a given token; leave out CLS, SEP and PAD (token IDs 0, 2 and 3)
        mask_arr = (rand < .15) * (input_ids != 0) * (input_ids != 2) * (input_ids != 3)
        # assign token 4 (= MASK)
        input_ids[mask_arr] = 4
        return {'input_ids': input_ids, 'attention_mask': mask, 'labels': labels}

    def __getitem__(self, i):
        """
        Returns item i.
        Note: do not use shuffling for this dataset; items are served file by file.
        """
        # if we have looked at all items in the current file, load the next one
        if self.remaining == 0:
            self.data = self.read_file(self.paths[self.current_file])
            self.current_file += 1
            self.remaining = len(self.data)
            self.encodings = self.get_encodings(self.data)
            # if we are at the end of the dataset, start over again
            if self.current_file == len(self.paths):
                self.current_file = 0
        self.remaining -= 1
        return {key: tensor[i % 10000] for key, tensor in self.encodings.items()}
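
# --- Illustrative usage sketch (an assumption, not part of the original file) ---
# A minimal example of wiring the Dataset into a DataLoader. The glob pattern,
# tokenizer directory and batch size below are placeholders. Note that the
# special-token IDs hard-coded in get_encodings() (skip 0, 2, 3; mask = 4)
# imply a tokenizer trained from scratch with exactly those IDs, so a stock
# pretrained tokenizer may not match them.
def build_example_loader(data_glob='data/text_*.txt', tokenizer_dir='tokenizer/', batch_size=16):
    from glob import glob
    from transformers import RobertaTokenizerFast

    paths = sorted(glob(data_glob))
    tokenizer = RobertaTokenizerFast.from_pretrained(tokenizer_dir)
    dataset = Dataset(paths, tokenizer)
    # shuffle must stay False: __getitem__ streams through the files sequentially
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
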
def test_model(model, optim, test_ds_loader, device):
    """
    Tests that the frozen parameters of the model do not change, that the non-frozen
    parameters do change, and that no parameter becomes NaN or Inf.

    :param model: model to be tested
    :param optim: optimiser used for training
    :param test_ds_loader: data loader to draw a batch from for the forward pass
    :param device: current device
    :raises Exception: if any of the above conditions is not met
    """
    ## Check if non-frozen parameters changed and frozen ones did not
    # get initial parameters to check against
    params = [np for np in model.named_parameters() if np[1].requires_grad]
    initial_params = [(name, p.clone()) for (name, p) in params]
    params_frozen = [np for np in model.named_parameters() if not np[1].requires_grad]
    initial_params_frozen = [(name, p.clone()) for (name, p) in params_frozen]

    optim.zero_grad()

    # get data
    batch = next(iter(test_ds_loader))
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)

    # forward pass and backpropagation
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs.loss
    loss.backward()
    optim.step()

    # check that the trainable parameters have changed
    for (_, p0), (name, p1) in zip(initial_params, params):
        # check different than initial
        try:
            assert not torch.equal(p0.to(device), p1.to(device))
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(var_name=name, msg='did not change!')
            )
        # check not NaN
        try:
            assert not torch.isnan(p1).byte().any()
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(var_name=name, msg='is NaN!')
            )
        # check finite
        try:
            assert torch.isfinite(p1).byte().all()
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(var_name=name, msg='is Inf!')
            )

    # check that frozen weights have not changed
    for (_, p0), (name, p1) in zip(initial_params_frozen, params_frozen):
        # should be the same as initially
        try:
            assert torch.equal(p0.to(device), p1.to(device))
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(var_name=name, msg='changed!')
            )
        # check not NaN
        try:
            assert not torch.isnan(p1).byte().any()
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(var_name=name, msg='is NaN!')
            )
        # check finite numbers
        try:
            assert torch.isfinite(p1).byte().all()
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(var_name=name, msg='is Inf!')
            )

    print("Passed")