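"""Evaluate a trained NER model (or a staged ensemble of models) on a
CoNLL-format test set and report span-level precision, recall and F1:
unlabelled (span boundaries only) and labelled (boundaries plus type),
overall and per entity type."""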
import argparse
import re
import torch
from tqdm.auto import tqdm
from network import EntNet, StagedEnsemble  # StagedEnsemble is used in main(); assumed here to live in network alongside EntNet
from utils import read_conll_ner, split_conll_docs, create_context_data, extract_spans

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


def classify(model, sents, pos, batch_size):
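    """Run the model over the sentences in mini-batches and return, for each
    sentence, a list of [token, predicted_tag] pairs."""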
model.eval()
result = []
for i in tqdm(range(0, len(sents), batch_size), desc='classifying... '):
tag_seqs = model(sentences=sents[i:i + batch_size],
pos=pos[i:i + batch_size])
result.extend(tag_seqs['pred_tags'])
    # Pair every token with its predicted tag, sentence by sentence.
    return [[[w, t] for w, t in zip(s, r)] for s, r in zip(sents, result)]


def entities_from_token_classes(tokens):
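    """Convert a BIO-style tag sequence into a list of entity dicts of the form
    {"type": <label>, "index": [start, end]} with inclusive token indices."""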
ENTITY_BEGIN_REGEX = r"^B" # -(\w+)"
ENTITY_MIDDLE_REGEX = r"^I" # -(\w+)"
entities = []
current_entity = None
start_index_of_current_entity = 0
end_index_of_current_entity = 0
for i, kls in enumerate(tokens):
m = re.match(ENTITY_BEGIN_REGEX, kls)
if m is not None:
if current_entity is not None:
entities.append({
"type": current_entity,
"index": [start_index_of_current_entity,
end_index_of_current_entity]
})
# start of entity
current_entity = m.string.split('-')[1] if '-' in m.string else ''
start_index_of_current_entity = i
end_index_of_current_entity = i
continue
m = re.match(ENTITY_MIDDLE_REGEX, kls)
if current_entity is not None:
if m is None:
# after the end of this entity
entities.append({
"type": current_entity,
"index": [start_index_of_current_entity,
end_index_of_current_entity]
})
current_entity = None
continue
# in the middle of this entity
end_index_of_current_entity = i
# Add any remaining entity
if current_entity is not None:
entities.append({
"type": current_entity,
"index": [start_index_of_current_entity,
end_index_of_current_entity]
})
    return entities


def calc_f1(targs, preds):
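    """Compare gold and predicted entity spans (exact match) and return a dict with
    unlabelled and labelled precision/recall/F1 overall, labelled P/R/F1 per entity
    type, and a macro-averaged labelled F1."""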
stat_dict = {
'overall': {'unl_tp': 0, 'lab_tp': 0, 'targs': 0, 'preds': 0}
}
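    # Per-type entries are created lazily as entity types appear in the targets
    # and predictions; only labelled counts are tracked per type.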
for sent_targs, sent_preds in zip(targs, preds):
stat_dict['overall']['targs'] += len(sent_targs)
stat_dict['overall']['preds'] += len(sent_preds)
for pred in sent_preds:
if pred['type'] not in stat_dict.keys():
stat_dict[pred['type']] = {'lab_tp': 0, 'targs': 0, 'preds': 0}
stat_dict[pred['type']]['preds'] += 1
for targ in sent_targs:
if targ['type'] not in stat_dict.keys():
stat_dict[targ['type']] = {'lab_tp': 0, 'targs': 0, 'preds': 0}
stat_dict[targ['type']]['targs'] += 1
# is there a span that matches exactly?
for pred in sent_preds:
if pred['index'][0] == targ['index'][0] and pred['index'][1] == targ['index'][1]:
stat_dict['overall']['unl_tp'] += 1
# if so do the tags match exactly?
if pred['type'] == targ['type']:
stat_dict['overall']['lab_tp'] += 1
stat_dict[targ['type']]['lab_tp'] += 1
for k in stat_dict.keys():
if k == 'overall':
stat_dict[k]['unl_p'] = stat_dict[k]['unl_tp'] / stat_dict[k]['preds'] if stat_dict[k]['preds'] else 0
stat_dict[k]['unl_r'] = stat_dict[k]['unl_tp'] / stat_dict[k]['targs'] if stat_dict[k]['targs'] else 0
stat_dict[k]['unl_f1'] = 2 * stat_dict[k]['unl_p'] * stat_dict[k]['unl_r'] / (
stat_dict[k]['unl_p'] + stat_dict[k]['unl_r']) if (
stat_dict[k]['unl_p'] + stat_dict[k]['unl_r']) else 0
stat_dict[k]['lab_p'] = stat_dict[k]['lab_tp'] / stat_dict[k]['preds'] if stat_dict[k]['preds'] else 0
stat_dict[k]['lab_r'] = stat_dict[k]['lab_tp'] / stat_dict[k]['targs'] if stat_dict[k]['targs'] else 0
stat_dict[k]['lab_f1'] = 2 * stat_dict[k]['lab_p'] * stat_dict[k]['lab_r'] / (
stat_dict[k]['lab_p'] + stat_dict[k]['lab_r']) if (stat_dict[k]['lab_p'] + stat_dict[k]['lab_r']) else 0
    class_f1s = [v['lab_f1'] for k, v in stat_dict.items() if k != 'overall']
    stat_dict['overall']['macro_lab_f1'] = sum(class_f1s) / len(class_f1s) if class_f1s else 0
    return stat_dict


def main(args):
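    """Load the test data and model(s), classify the test sentences and print
    span-level evaluation metrics."""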
global device
device = torch.device('cuda' if use_cuda else 'cpu')
test_columns = read_conll_ner(args.test_path)
test_docs = split_conll_docs(test_columns[0])
test_data = create_context_data(test_docs, args.context_size)
sents = [td[0] for td in test_data]
pos = [td[1] for td in test_data]
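    # Several model paths (or any span model paths) imply a staged ensemble;
    # otherwise a single EntNet checkpoint is loaded.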
if len(args.model_path) > 1 or args.span_model_path is not None:
model = StagedEnsemble(model_paths=args.model_path, span_model_paths=args.span_model_path, device=device)
else:
model = EntNet.load_model(args.model_path[0], device=device)
model.to(device)
BATCH_SIZE = args.batch_size
res = classify(model, sents, pos, BATCH_SIZE)
    targets = [td[2] for td in test_data]
    targ_tags = [entities_from_token_classes(t) for t in targets]
pred_tags = [entities_from_token_classes([t[1] for t in r]) for r in res]
result = calc_f1(targ_tags, pred_tags)
print(f'Overall unlabelled - F1:{result["overall"]["unl_f1"]}, '
f'P:{result["overall"]["unl_p"]}, '
f'R:{result["overall"]["unl_r"]}')
print(f'Overall labelled - Micro F1:{result["overall"]["lab_f1"]}, '
f'P:{result["overall"]["lab_p"]}, '
f'R:{result["overall"]["lab_r"]}')
print(f'Overall labelled - Macro F1:{result["overall"]["macro_lab_f1"]}')
for k, v in result.items():
if k == 'overall':
continue
        print(f'{k} - F1:{v["lab_f1"]}, P:{v["lab_p"]}, R:{v["lab_r"]}')


if __name__ == "__main__":
parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, nargs='+', default=None, required=True, help='path(s) to trained model checkpoint(s); more than one selects the staged ensemble')
    parser.add_argument('--span_model_path', type=str, nargs='*', default=None, help='optional path(s) to span model checkpoint(s) for the staged ensemble')
# parser.add_argument('--network_type', type=str,
# choices=['span', 'entity', 'joint'], required=True,
# default=None, help='If entity is chosen, a path to a '
# 'span model is required also')
    parser.add_argument('--test_path', type=str, default=None, help='path to the CoNLL-format test file')
    parser.add_argument('--context_size', type=int, default=1, help='context size used when building test examples')
    parser.add_argument('--batch_size', type=int, default=8, help='number of sentences classified per batch')
# parser.add_argument('--cuda_id', type=int, default=0, help='')
args = parser.parse_args()
main(args)
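# Example invocation (the paths below are illustrative placeholders, not files
# shipped with this repo):
#   python validate.py --model_path checkpoints/entnet.pt \
#       --test_path data/conll/test.txt --context_size 1 --batch_size 8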