import re |
import numpy as np |
from prettytable import PrettyTable |
from tqdm import tqdm |
import torch |
def normalize_text(s):
    """
    Normalize an answer string so that trivially different answers compare equal.

    The steps are applied in this order:
    1. lowercase;
    2. delete every character in ``string.punctuation``;
    3. replace the standalone words "a", "an" and "the" with a space;
    4. collapse runs of whitespace to single spaces and trim both ends.

    Adapted from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
    :param s: string to clean
    :return: cleaned string
    """
    import re
    import string

    punctuation = frozenset(string.punctuation)
    article_pattern = re.compile(r"\b(a|an|the)\b", re.UNICODE)

    lowered = s.lower()
    # Punctuation is stripped before article removal, so "the," still matches as "the".
    without_punc = "".join(filter(lambda ch: ch not in punctuation, lowered))
    without_articles = article_pattern.sub(" ", without_punc)
    return " ".join(without_articles.split())
def compute_exact_match(prediction, truth):
    """
    Score a prediction by exact string match after normalization with ``normalize_text``.

    Retrieved from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
    :param prediction: predicted answer
    :param truth: ground truth
    :return: 1 (int) if the normalized strings are identical, otherwise 0
    """
    normalized_prediction = normalize_text(prediction)
    normalized_truth = normalize_text(truth)
    return 1 if normalized_prediction == normalized_truth else 0
def compute_f1(prediction, truth):
    """
    Compute the token-level F-1 score of a prediction against the ground truth.

    Both strings are normalized with ``normalize_text`` and split on whitespace.
    Overlap is counted as a multiset intersection, so a token occurring k times
    in the prediction and m times in the truth contributes min(k, m) matches.
    This follows the SQuAD v1.1 evaluation metric.

    Retrieved from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
    :param prediction: predicted answer
    :param truth: ground truth
    :return: the F-1 score in [0, 1]. When either side normalizes to no tokens,
        returns 1 if both are empty and 0 otherwise.
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # Empty answers (e.g. no-answer samples): only an empty prediction matches.
    if not pred_tokens or not truth_tokens:
        return int(pred_tokens == truth_tokens)

    # Multiset intersection. A plain set would count a repeated token once,
    # so identical answers like "cat cat" vs "cat cat" would score 0.5 instead of 1.0.
    overlap = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if overlap == 0:
        return 0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(truth_tokens)
    return 2 * (precision * recall) / (precision + recall)
def _decode_span(tokenizer, input_ids, start, end):
    """Decode input_ids[start : end + 1] back to text, dropping special tokens."""
    span_ids = input_ids[int(start): int(end) + 1].tolist()
    tokens = tokenizer.convert_ids_to_tokens(span_ids, skip_special_tokens=True)
    return tokenizer.convert_tokens_to_string(tokens)


def eval_test_set(model, tokenizer, test_loader, device):
    """
    Calculate and print the mean exact-match (EM) and mean F-1 score on the test set.

    For every sample, the predicted answer span is taken from the argmax of the
    model's start/end logits, and the reference span from the batch's
    ``start_positions``/``end_positions``. Both spans are decoded to text and then
    compared with ``compute_exact_match`` and ``compute_f1``.

    :param model: pytorch model. Its output must support ``['start_logits']`` and ``['end_logits']``.
    :param tokenizer: tokenizer used to encode the samples
    :param test_loader: dataloader yielding dicts with ``input_ids``, ``attention_mask``,
        ``start_positions`` and ``end_positions``
    :param device: device the model is on
    :return: tuple ``(mean_em, mean_f1)``; both are ``nan`` if the loader yields no samples
    """
    em_scores = []
    f1_scores = []
    model.to(device)
    model.eval()
    # Inference only: without no_grad an autograd graph would be kept for every batch.
    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            start = batch['start_positions'].to(device)
            end = batch['end_positions'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)
            for sample_ids, s_logits, e_logits, true_start, true_end in zip(
                    input_ids, outputs['start_logits'], outputs['end_logits'], start, end):
                # The argmax gives token indices, not logits.
                pred_start = torch.argmax(s_logits)
                pred_end = torch.argmax(e_logits)
                # If pred_end < pred_start the slice is empty and the prediction is "".
                predicted = _decode_span(tokenizer, sample_ids, pred_start, pred_end)
                true = _decode_span(tokenizer, sample_ids, true_start, true_end)
                em_scores.append(compute_exact_match(predicted, true))
                f1_scores.append(compute_f1(predicted, true))

    # Avoid np.mean([]) (RuntimeWarning) when the loader is empty.
    mean_em = float(np.mean(em_scores)) if em_scores else float("nan")
    mean_f1 = float(np.mean(f1_scores)) if f1_scores else float("nan")
    print("Mean EM: ", mean_em)
    print("Mean F-1: ", mean_f1)
    return mean_em, mean_f1
def count_parameters(model):
    """
    Print a table of the model's trainable parameters and their element counts.

    Only parameters with ``requires_grad`` set are listed and counted; frozen
    parameters are skipped. The total is printed after the table.
    :param model: pytorch model
    :return: total number of trainable parameter elements (what will be fine-tuned)
    """
    trainable = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for param_name, count in trainable:
        table.add_row([param_name, count])

    grand_total = sum(count for _, count in trainable)
    print(table)
    print(f"Total Trainable Params: {grand_total}")
    return grand_total