import torch


class Dataset(torch.utils.data.Dataset):
    """
    Loads the given text files and preprocesses them into masked-language-model encodings.
    """

    def __init__(self, paths, tokenizer):
        """
        Initialises the dataset from a list of file paths and a tokeniser.
        """
        # Note: the last path is dropped here (presumably held out, e.g. for evaluation).
        self.paths = paths[:-1]
        self.tokenizer = tokenizer
        # The first file is loaded eagerly; later files are read lazily in __getitem__.
        self.data = self.read_file(self.paths[0])
        self.current_file = 1            # index of the next file to load
        self.remaining = len(self.data)  # lines of the current file not yet served
        self.encodings = self.get_encodings(self.data)

    def __len__(self):
        """
        Returns the length of the dataset.
        """
        # Each file is assumed to contain exactly 10,000 lines.
        return 10000 * len(self.paths)

    def read_file(self, path):
        """
        Reads a given file and returns its lines.
        """
        with open(path, 'r', encoding='utf-8') as f:
            lines = f.read().split('\n')
        return lines

    def get_encodings(self, lines_all):
        """
        Creates masked-language-model encodings for the given lines of text.
        """
        # Tokenise every line to a fixed length of 512 tokens.
        batch = self.tokenizer(lines_all, max_length=512, padding='max_length', truncation=True)

        # The unmasked token ids serve as the labels.
        labels = torch.tensor(batch['input_ids'])
        mask = torch.tensor(batch['attention_mask'])

        # Mask roughly 15% of the input tokens.
        input_ids = labels.detach().clone()
        rand = torch.rand(input_ids.shape)

        # Ids 0, 2 and 3 are assumed to be special tokens of the tokeniser and are never masked.
        mask_arr = (rand < .15) * (input_ids != 0) * (input_ids != 2) * (input_ids != 3)

        # Id 4 is assumed to be the tokeniser's mask token.
        input_ids[mask_arr] = 4

        return {'input_ids': input_ids, 'attention_mask': mask, 'labels': labels}
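
    # Illustrative output (a minimal sketch; 'dataset' is a placeholder instance and the
    # tokeniser assumptions are as described above):
    #
    #     enc = dataset.get_encodings(["some line of text"])
    #     enc['input_ids'].shape       # torch.Size([1, 512]), ~15% of tokens replaced by id 4
    #     enc['attention_mask'].shape  # torch.Size([1, 512])
    #     enc['labels'].shape          # torch.Size([1, 512]), holding the original token ids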

    def __getitem__(self, i):
        """
        Returns item i.

        Note: do not use shuffling with this dataset; items are served in file order.
        """
        # Once the current file is exhausted, load the next one and re-encode it.
        if self.remaining == 0:
            self.data = self.read_file(self.paths[self.current_file])
            self.current_file += 1
            self.remaining = len(self.data)
            self.encodings = self.get_encodings(self.data)

        # Wrap around to the first file once the last one has been loaded.
        if self.current_file == len(self.paths):
            self.current_file = 0

        self.remaining -= 1
        # Index within the current file, assuming 10,000 lines per file.
        return {key: tensor[i % 10000] for key, tensor in self.encodings.items()}
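

# Example usage (a minimal sketch; the file paths and tokeniser location below are
# placeholders, and the tokeniser is assumed to use the special-token ids hard-coded
# in get_encodings, with 4 as the mask token). shuffle must stay False because
# __getitem__ walks through the files sequentially:
#
#     from torch.utils.data import DataLoader
#     from transformers import PreTrainedTokenizerFast
#
#     tokenizer = PreTrainedTokenizerFast.from_pretrained('path/to/tokenizer')
#     paths = ['data/part_00.txt', 'data/part_01.txt', 'data/part_02.txt']
#     dataset = Dataset(paths, tokenizer)
#     loader = DataLoader(dataset, batch_size=16, shuffle=False)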


def test_model(model, optim, test_ds_loader, device):
    """
    Performs a single optimisation step and checks that the trainable parameters change,
    that the frozen parameters do not change, and that no parameter becomes NaN or Inf.

    :param model: model to be tested
    :param optim: optimiser used for training
    :param test_ds_loader: data loader to draw a single batch from for the forward pass
    :param device: current device
    :raises Exception: if any of the above conditions is not met
    """
    # Record the trainable parameters before the update.
    params = [np for np in model.named_parameters() if np[1].requires_grad]
    initial_params = [(name, p.clone()) for (name, p) in params]

    # Record the frozen parameters before the update.
    params_frozen = [np for np in model.named_parameters() if not np[1].requires_grad]
    initial_params_frozen = [(name, p.clone()) for (name, p) in params_frozen]

    optim.zero_grad()

    # Run a single forward/backward pass on one batch and take an optimiser step.
    batch = next(iter(test_ds_loader))

    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)

    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs.loss
    loss.backward()
    optim.step()

    # Trainable parameters should have been updated and must remain finite.
    for (_, p0), (name, p1) in zip(initial_params, params):
        if torch.equal(p0.to(device), p1.to(device)):
            raise Exception(f"{name} did not change!")
        if torch.isnan(p1).any():
            raise Exception(f"{name} is NaN!")
        if not torch.isfinite(p1).all():
            raise Exception(f"{name} is Inf!")

    # Frozen parameters must stay exactly as they were and remain finite.
    for (_, p0), (name, p1) in zip(initial_params_frozen, params_frozen):
        if not torch.equal(p0.to(device), p1.to(device)):
            raise Exception(f"{name} changed!")
        if torch.isnan(p1).any():
            raise Exception(f"{name} is NaN!")
        if not torch.isfinite(p1).all():
            raise Exception(f"{name} is Inf!")

    print("Passed")