import argparse
import re

import torch
from tqdm.auto import tqdm

from network import EntNet
from utils import read_conll_ner, split_conll_docs, create_context_data, extract_spans

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Tag-prefix patterns used to decode tag sequences into spans. Only the
# leading letter is matched; the entity type of a B tag is taken from the
# text after its first '-' (see entities_from_token_classes).
ENTITY_BEGIN_REGEX = re.compile(r"^B")  # -(\w+)"
ENTITY_MIDDLE_REGEX = re.compile(r"^I")  # -(\w+)"


def classify(model, sents, pos, batch_size):
    """Run ``model`` over all sentences in mini-batches and pair tokens with tags.

    Args:
        model: tagger called as ``model(sentences=..., pos=...)``; its result
            must support ``['pred_tags']`` yielding one tag sequence per
            input sentence.
        sents: list of token sequences.
        pos: list parallel to ``sents``, forwarded to the model unchanged
            (presumably part-of-speech tags — TODO confirm).
        batch_size: number of sentences per forward call.

    Returns:
        One list per sentence of ``[token, predicted_tag]`` pairs.
    """
    model.eval()
    result = []
    # Pure inference: disable autograd so no computation graph is built or
    # kept alive across batches.
    with torch.no_grad():
        for i in tqdm(range(0, len(sents), batch_size), desc='classifying... '):
            tag_seqs = model(sentences=sents[i:i + batch_size],
                             pos=pos[i:i + batch_size])
            result.extend(tag_seqs['pred_tags'])
    return [[[w, t] for w, t in zip(s, r)] for s, r in zip(sents, result)]


def entities_from_token_classes(tokens):
    """Decode a per-token tag sequence into entity spans.

    Decoding rules:
      * a tag starting with ``B`` closes any open entity and opens a new one
        whose type is the text after the first ``-`` (``''`` if there is no
        ``-``);
      * a tag starting with ``I`` extends the open entity, whatever its type;
        an ``I`` tag with no open entity is ignored;
      * any other tag closes the open entity.

    Args:
        tokens: sequence of tag strings.

    Returns:
        List of ``{"type": str, "index": [start, end]}`` dicts, where
        ``start`` and ``end`` are inclusive token indices.
    """
    entities = []
    current_entity = None
    start_index_of_current_entity = 0
    end_index_of_current_entity = 0

    for i, kls in enumerate(tokens):
        if ENTITY_BEGIN_REGEX.match(kls) is not None:
            if current_entity is not None:
                entities.append({
                    "type": current_entity,
                    "index": [start_index_of_current_entity, end_index_of_current_entity]
                })
            # start of a new entity
            current_entity = kls.split('-')[1] if '-' in kls else ''
            start_index_of_current_entity = i
            end_index_of_current_entity = i
            continue

        if current_entity is None:
            # outside any entity; stray I tags are ignored
            continue

        if ENTITY_MIDDLE_REGEX.match(kls) is None:
            # first token after the end of the open entity
            entities.append({
                "type": current_entity,
                "index": [start_index_of_current_entity, end_index_of_current_entity]
            })
            current_entity = None
        else:
            # still inside the open entity
            end_index_of_current_entity = i

    # flush an entity that runs to the end of the sequence
    if current_entity is not None:
        entities.append({
            "type": current_entity,
            "index": [start_index_of_current_entity, end_index_of_current_entity]
        })
    return entities


def _precision_recall_f1(tp, n_preds, n_targs):
    """Return ``(precision, recall, f1)``, using 0 wherever a denominator is 0."""
    p = tp / n_preds if n_preds else 0
    r = tp / n_targs if n_targs else 0
    f1 = 2 * p * r / (p + r) if (p + r) else 0
    return p, r, f1


def calc_f1(targs, preds):
    """Compute exact-span precision/recall/F1 for predicted vs. target entities.

    Args:
        targs: per-sentence lists of target entity dicts, as returned by
            ``entities_from_token_classes``.
        preds: per-sentence lists of predicted entity dicts, parallel to
            ``targs`` (``zip`` stops at the shorter of the two).

    Returns:
        Dict keyed by ``'overall'`` plus one key per entity type seen.
        Every entry has ``lab_tp``, ``targs``, ``preds``, ``lab_p``,
        ``lab_r`` and ``lab_f1``. ``'overall'`` additionally has the
        unlabelled (span-only) ``unl_tp``, ``unl_p``, ``unl_r``, ``unl_f1``
        and ``macro_lab_f1``, the mean of the per-type ``lab_f1`` values
        (0 when there are no types).
    """
    stat_dict = {
        'overall': {'unl_tp': 0, 'lab_tp': 0, 'targs': 0, 'preds': 0}
    }
    overall = stat_dict['overall']

    for sent_targs, sent_preds in zip(targs, preds):
        overall['targs'] += len(sent_targs)
        overall['preds'] += len(sent_preds)
        for pred in sent_preds:
            if pred['type'] not in stat_dict:
                stat_dict[pred['type']] = {'lab_tp': 0, 'targs': 0, 'preds': 0}
            stat_dict[pred['type']]['preds'] += 1
        for targ in sent_targs:
            if targ['type'] not in stat_dict:
                stat_dict[targ['type']] = {'lab_tp': 0, 'targs': 0, 'preds': 0}
            stat_dict[targ['type']]['targs'] += 1
            # is there a span that matches exactly?
            for pred in sent_preds:
                if pred['index'][0] == targ['index'][0] and pred['index'][1] == targ['index'][1]:
                    overall['unl_tp'] += 1
                    # if so, do the types match exactly too?
                    if pred['type'] == targ['type']:
                        overall['lab_tp'] += 1
                        stat_dict[targ['type']]['lab_tp'] += 1

    for key, stats in stat_dict.items():
        if key == 'overall':
            stats['unl_p'], stats['unl_r'], stats['unl_f1'] = _precision_recall_f1(
                stats['unl_tp'], stats['preds'], stats['targs'])
        stats['lab_p'], stats['lab_r'], stats['lab_f1'] = _precision_recall_f1(
            stats['lab_tp'], stats['preds'], stats['targs'])

    class_f1s = [v['lab_f1'] for k, v in stat_dict.items() if k != 'overall']
    # Guard: with no entities at all there are no classes to average.
    overall['macro_lab_f1'] = sum(class_f1s) / len(class_f1s) if class_f1s else 0
    return stat_dict


def main(args):
    """Load the test set and model(s), tag the data, and print F1 scores.

    Args:
        args: parsed CLI namespace with ``model_path`` (list of paths),
            ``span_model_path`` (list or None), ``test_path``,
            ``context_size`` and ``batch_size``.
    """
    test_columns = read_conll_ner(args.test_path)
    test_docs = split_conll_docs(test_columns[0])
    test_data = create_context_data(test_docs, args.context_size)
    # Each item of test_data is indexed as [0]=tokens, [1]=pos, [2]=gold tags.
    sents = [td[0] for td in test_data]
    pos = [td[1] for td in test_data]

    if len(args.model_path) > 1 or args.span_model_path is not None:
        # NOTE(review): StagedEnsemble is not imported anywhere in this file,
        # so this branch raises NameError as written — TODO import it from
        # the module that defines it.
        model = StagedEnsemble(model_paths=args.model_path,
                               span_model_paths=args.span_model_path,
                               device=device)
    else:
        model = EntNet.load_model(args.model_path[0], device=device)
    model.to(device)

    res = classify(model, sents, pos, args.batch_size)

    targ_tags = [entities_from_token_classes(td[2]) for td in test_data]
    pred_tags = [entities_from_token_classes([t[1] for t in r]) for r in res]
    result = calc_f1(targ_tags, pred_tags)
    overall = result["overall"]

    print(f'Overall unlabelled - F1:{overall["unl_f1"]}, '
          f'P:{overall["unl_p"]}, '
          f'R:{overall["unl_r"]}')
    print(f'Overall labelled - Micro F1:{overall["lab_f1"]}, '
          f'P:{overall["lab_p"]}, '
          f'R:{overall["lab_r"]}')
    print(f'Overall labelled - Macro F1:{overall["macro_lab_f1"]}')
    for k, v in result.items():
        if k == 'overall':
            continue
        print(f'{k} - F1:{v["lab_f1"]}, P:{v["lab_p"]}, R:{v["lab_r"]}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, nargs='+', required=True,
                        help='')
    parser.add_argument('--span_model_path', type=str, nargs='*', default=None,
                        help='')
    # main() reads the test file unconditionally, so require it up front.
    parser.add_argument('--test_path', type=str, required=True, help='')
    parser.add_argument('--context_size', type=int, default=1, help='')
    parser.add_argument('--batch_size', type=int, default=8, help='')
    args = parser.parse_args()
    main(args)