# Numini / util.py
import re
import numpy as np
from prettytable import PrettyTable
from tqdm import tqdm
import torch
def normalize_text(s):
"""
    Lower-cases the text, removes articles and punctuation, and standardizes whitespace, which are all typical text-normalization steps.
Copied from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
:param s: string to clean
:return: cleaned string
"""
    import string
def remove_articles(text):
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
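# Illustrative example (a sketch of the behavior): normalization lower-cases the text,
# strips punctuation and articles, and collapses whitespace, e.g.
#   normalize_text("The  Quick, Brown Fox jumped!")  ->  "quick brown fox jumped"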
def compute_exact_match(prediction, truth):
"""
    Returns 1 if the normalized prediction exactly matches the normalized ground truth, else 0
Retrieved from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
:param prediction: predicted answer
:param truth: ground truth
:return: 1 if exact match, else 0
"""
return int(normalize_text(prediction) == normalize_text(truth))
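# Usage sketch: both strings are normalized before comparison, so differences in case,
# punctuation and articles do not break a match, e.g.
#   compute_exact_match("The cat sat.", "cat sat")  ->  1
#   compute_exact_match("a cat", "the dog")         ->  0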
def compute_f1(prediction, truth):
"""
    Computes the token-level F-1 score of a prediction against the ground truth
Retrieved from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
:param prediction: predicted answer
:param truth: ground truth
:return: the f-1 score of the prediction
"""
pred_tokens = normalize_text(prediction).split()
truth_tokens = normalize_text(truth).split()
# if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
if len(pred_tokens) == 0 or len(truth_tokens) == 0:
return int(pred_tokens == truth_tokens)
# get tokens that are in the prediction and gt
common_tokens = set(pred_tokens) & set(truth_tokens)
# if there are no common tokens then f1 = 0
if len(common_tokens) == 0:
return 0
# calculate precision and recall
prec = len(common_tokens) / len(pred_tokens)
rec = len(common_tokens) / len(truth_tokens)
return 2 * (prec * rec) / (prec + rec)
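# Worked example: with prediction "the cat sat on the mat" (tokens: cat, sat, on, mat) and
# truth "a cat on a mat" (tokens: cat, on, mat), the overlap is 3 tokens, so
# precision = 3/4, recall = 3/3 and compute_f1 returns 2 * (0.75 * 1.0) / (0.75 + 1.0) ≈ 0.857.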
def eval_test_set(model, tokenizer, test_loader, device):
"""
    Calculates and prints the mean EM and mean F-1 score on the test set
:param model: pytorch model
:param tokenizer: tokenizer used to encode the samples
:param test_loader: dataloader object with test data
:param device: device the model is on
"""
mean_em = []
mean_f1 = []
model.to(device)
model.eval()
for batch in tqdm(test_loader):
# get test data and transfer to device
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
start = batch['start_positions'].to(device)
end = batch['end_positions'].to(device)
# predict
outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)
# iterate over samples, calculate EM and F-1 for all
for input_i, s, e, trues, truee in zip(input_ids, outputs['start_logits'], outputs['end_logits'], start, end):
            # get predicted start and end token indices (argmax over the start/end logits)
start_logits = torch.argmax(s)
end_logits = torch.argmax(e)
# get predicted answer as string
ans_tokens = input_i[start_logits: end_logits + 1]
answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
predicted = tokenizer.convert_tokens_to_string(answer_tokens)
# get ground truth as string
ans_tokens = input_i[trues: truee + 1]
answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
true = tokenizer.convert_tokens_to_string(answer_tokens)
# compute score
em_score = compute_exact_match(predicted, true)
f1_score = compute_f1(predicted, true)
mean_em.append(em_score)
mean_f1.append(f1_score)
print("Mean EM: ", np.mean(mean_em))
print("Mean F-1: ", np.mean(mean_f1))
def count_parameters(model):
"""
    Prints a per-module breakdown of the trainable parameters
    :param model: pytorch model
    :return: total number of trainable parameters
"""
table = PrettyTable(["Modules", "Parameters"])
total_params = 0
for name, parameter in model.named_parameters():
if not parameter.requires_grad: continue
params = parameter.numel()
table.add_row([name, params])
total_params += params
print(table)
print(f"Total Trainable Params: {total_params}")
return total_params
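# Minimal smoke-test sketch: running the file directly exercises the metric helpers on toy strings
# and count_parameters on a small stand-in model; the toy model is purely illustrative.
if __name__ == "__main__":
    # token-level metrics on toy answers
    print("EM :", compute_exact_match("The cat sat.", "cat sat"))                    # -> 1
    print("F-1:", round(compute_f1("the cat sat on the mat", "a cat on a mat"), 3))  # -> 0.857

    # parameter statistics for a tiny throwaway model
    toy_model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
    count_parameters(toy_model)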