import re
import string

import numpy as np
import torch
from prettytable import PrettyTable
from tqdm import tqdm


def normalize_text(s):
    """
    Lower-cases the text, removes articles and punctuation, and standardizes
    whitespace; these are all typical text-processing steps for QA evaluation.
    Copied from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
    :param s: string to clean
    :return: cleaned string
    """

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
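
# Quick illustration (values are my own, not from the original source):
# articles and punctuation are dropped, case is folded and whitespace is
# collapsed, e.g.
#   normalize_text("The  Quick, Brown Fox!")  ->  "quick brown fox"
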
def compute_exact_match(prediction, truth):
    """
    Returns 1 if the normalized prediction exactly matches the normalized
    ground truth, else 0.
    Retrieved from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
    :param prediction: predicted answer
    :param truth: ground truth
    :return: 1 if exact match, else 0
    """
    return int(normalize_text(prediction) == normalize_text(truth))
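
# A minimal sketch of the behavior (example values are mine): normalization
# makes the match insensitive to case, articles and punctuation, e.g.
#   compute_exact_match("The Eiffel Tower!", "eiffel tower")  ->  1
#   compute_exact_match("Eiffel Tower", "Louvre")             ->  0
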
def compute_f1(prediction, truth):
    """
    Computes the token-level F1 score of a prediction against the ground truth.
    Retrieved from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
    :param prediction: predicted answer
    :param truth: ground truth
    :return: the F1 score of the prediction
    """
    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # If either answer is empty, F1 is 1 only when both are empty.
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)

    # Overlap is computed on unique tokens (set intersection).
    common_tokens = set(pred_tokens) & set(truth_tokens)

    if len(common_tokens) == 0:
        return 0

    prec = len(common_tokens) / len(pred_tokens)
    rec = len(common_tokens) / len(truth_tokens)

    return 2 * (prec * rec) / (prec + rec)
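
# Worked example (illustrative, not from the original): for the prediction
# "the cat sat" and the truth "a cat sat down", the normalized token sets
# share {"cat", "sat"}, so precision = 2/2, recall = 2/3 and
#   compute_f1("the cat sat", "a cat sat down")  ->  0.8
# Note that, unlike the official SQuAD script, which intersects token
# multisets via collections.Counter, this version counts each overlapping
# token once, so the two can differ on answers with repeated tokens.
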
def eval_test_set(model, tokenizer, test_loader, device):
    """
    Calculates the mean EM and mean F-1 score on the test set
    :param model: pytorch model
    :param tokenizer: tokenizer used to encode the samples
    :param test_loader: dataloader object with test data
    :param device: device the model is on
    """
    em_scores = []
    f1_scores = []
    model.to(device)
    model.eval()
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start = batch['start_positions'].to(device)
        end = batch['end_positions'].to(device)

        # No gradients are needed for evaluation.
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask,
                            start_positions=start, end_positions=end)

        for input_i, s, e, true_start, true_end in zip(input_ids, outputs['start_logits'],
                                                       outputs['end_logits'], start, end):
            # The predicted span runs from the highest-scoring start
            # position to the highest-scoring end position.
            start_idx = torch.argmax(s)
            end_idx = torch.argmax(e)

            ans_tokens = input_i[start_idx: end_idx + 1]
            answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
            predicted = tokenizer.convert_tokens_to_string(answer_tokens)

            # Decode the ground-truth span the same way.
            ans_tokens = input_i[true_start: true_end + 1]
            answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
            true = tokenizer.convert_tokens_to_string(answer_tokens)

            em_scores.append(compute_exact_match(predicted, true))
            f1_scores.append(compute_f1(predicted, true))
    print("Mean EM: ", np.mean(em_scores))
    print("Mean F-1: ", np.mean(f1_scores))
def count_parameters(model):
    """
    Prints statistics about the trainable parameters, one row per named parameter.
    :param model: pytorch model
    :return: total number of parameters to be fine-tuned
    """
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        # Skip frozen parameters; only count what will be fine-tuned.
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
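
# Example (hypothetical toy model, for illustration): a bare linear layer has
# one weight matrix and one bias vector, so the table shows two rows:
#   count_parameters(torch.nn.Linear(10, 2))
#   # 'weight' -> 20 params, 'bias' -> 2 params, Total Trainable Params: 22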