"""Evaluate a trained NER model on a CoNLL-formatted test set: run the model over the
test sentences and report span-level precision, recall and F1."""
import argparse
import re

import torch
from tqdm.auto import tqdm

# StagedEnsemble is assumed to be defined alongside EntNet in this repo's network
# module; adjust the import if it lives elsewhere.
from network import EntNet, StagedEnsemble
from utils import read_conll_ner, split_conll_docs, create_context_data, extract_spans

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

def classify(model, sents, pos, batch_size):
    """Run the model over the sentences in batches and pair each token with its predicted tag."""
    model.eval()
    result = []
    for i in tqdm(range(0, len(sents), batch_size), desc='classifying... '):
        tag_seqs = model(sentences=sents[i:i + batch_size],
                         pos=pos[i:i + batch_size])
        result.extend(tag_seqs['pred_tags'])
    return [[[w, t] for w, t in zip(s, r)] for s, r in zip(sents, result)]
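
# For illustration only (tag strings hypothetical): classify returns, per sentence,
# a list of [word, predicted_tag] pairs, e.g.
#   [[['John', 'B-PER'], ['lives', 'O'], ['here', 'O']], ...]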

def entities_from_token_classes(tokens):
    ENTITY_BEGIN_REGEX = r"^B"  # -(\w+)"
    ENTITY_MIDDLE_REGEX = r"^I"  # -(\w+)"

    entities = []
    current_entity = None
    start_index_of_current_entity = 0
    end_index_of_current_entity = 0
    for i, kls in enumerate(tokens):
        m = re.match(ENTITY_BEGIN_REGEX, kls)
        if m is not None:
            if current_entity is not None:
                entities.append({
                    "type": current_entity,
                    "index": [start_index_of_current_entity,
                              end_index_of_current_entity]
                })
            # start of entity
            current_entity = m.string.split('-')[1] if '-' in m.string else ''
            start_index_of_current_entity = i
            end_index_of_current_entity = i
            continue

        m = re.match(ENTITY_MIDDLE_REGEX, kls)
        if current_entity is not None:
            if m is None:
                # after the end of this entity
                entities.append({
                    "type": current_entity,
                    "index": [start_index_of_current_entity,
                              end_index_of_current_entity]
                })
                current_entity = None
                continue
            # in the middle of this entity
            end_index_of_current_entity = i

    # Add any remaining entity
    if current_entity is not None:
        entities.append({
            "type": current_entity,
            "index": [start_index_of_current_entity,
                      end_index_of_current_entity]
        })
    return entities
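
# A minimal illustration of the BIO decoding above (the tag strings are hypothetical,
# but the output shape matches what the function builds):
#   entities_from_token_classes(['O', 'B-PER', 'I-PER', 'O', 'B-LOC'])
#   -> [{'type': 'PER', 'index': [1, 2]}, {'type': 'LOC', 'index': [4, 4]}]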

def calc_f1(targs, preds):
    """Collect per-class and overall exact-span match counts and turn them into P/R/F1."""
    stat_dict = {
        'overall': {'unl_tp': 0, 'lab_tp': 0, 'targs': 0, 'preds': 0}
    }
    for sent_targs, sent_preds in zip(targs, preds):
        stat_dict['overall']['targs'] += len(sent_targs)
        stat_dict['overall']['preds'] += len(sent_preds)
        for pred in sent_preds:
            if pred['type'] not in stat_dict.keys():
                stat_dict[pred['type']] = {'lab_tp': 0, 'targs': 0, 'preds': 0}
            stat_dict[pred['type']]['preds'] += 1
        for targ in sent_targs:
            if targ['type'] not in stat_dict.keys():
                stat_dict[targ['type']] = {'lab_tp': 0, 'targs': 0, 'preds': 0}
            stat_dict[targ['type']]['targs'] += 1
            # is there a span that matches exactly?
            for pred in sent_preds:
                if pred['index'][0] == targ['index'][0] and pred['index'][1] == targ['index'][1]:
                    stat_dict['overall']['unl_tp'] += 1
                    # if so, do the tags match exactly?
                    if pred['type'] == targ['type']:
                        stat_dict['overall']['lab_tp'] += 1
                        stat_dict[targ['type']]['lab_tp'] += 1

    for k in stat_dict.keys():
        if k == 'overall':
            # unlabelled metrics are only tracked at the overall level
            stat_dict[k]['unl_p'] = stat_dict[k]['unl_tp'] / stat_dict[k]['preds'] if stat_dict[k]['preds'] else 0
            stat_dict[k]['unl_r'] = stat_dict[k]['unl_tp'] / stat_dict[k]['targs'] if stat_dict[k]['targs'] else 0
            stat_dict[k]['unl_f1'] = 2 * stat_dict[k]['unl_p'] * stat_dict[k]['unl_r'] / (
                stat_dict[k]['unl_p'] + stat_dict[k]['unl_r']) if (
                stat_dict[k]['unl_p'] + stat_dict[k]['unl_r']) else 0
        # labelled metrics are computed for 'overall' and for every entity class
        stat_dict[k]['lab_p'] = stat_dict[k]['lab_tp'] / stat_dict[k]['preds'] if stat_dict[k]['preds'] else 0
        stat_dict[k]['lab_r'] = stat_dict[k]['lab_tp'] / stat_dict[k]['targs'] if stat_dict[k]['targs'] else 0
        stat_dict[k]['lab_f1'] = 2 * stat_dict[k]['lab_p'] * stat_dict[k]['lab_r'] / (
            stat_dict[k]['lab_p'] + stat_dict[k]['lab_r']) if (stat_dict[k]['lab_p'] + stat_dict[k]['lab_r']) else 0

    class_f1s = [v['lab_f1'] for k, v in stat_dict.items() if k != 'overall']
    stat_dict['overall']['macro_lab_f1'] = sum(class_f1s) / len(class_f1s) if class_f1s else 0
    return stat_dict
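
# calc_f1 scores exact span matches: the 'unl_*' metrics ignore the entity type while
# the 'lab_*' metrics require it to match.  As a quick worked example with hypothetical
# counts, 8 labelled true positives out of 10 predicted and 16 gold spans gives
#   P = 8 / 10 = 0.8, R = 8 / 16 = 0.5, F1 = 2 * 0.8 * 0.5 / (0.8 + 0.5) ≈ 0.615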

def main(args):
    global device
    device = torch.device('cuda' if use_cuda else 'cpu')

    test_columns = read_conll_ner(args.test_path)
    test_docs = split_conll_docs(test_columns[0])
    test_data = create_context_data(test_docs, args.context_size)
    sents = [td[0] for td in test_data]
    pos = [td[1] for td in test_data]

    if len(args.model_path) > 1 or args.span_model_path is not None:
        model = StagedEnsemble(model_paths=args.model_path, span_model_paths=args.span_model_path, device=device)
    else:
        model = EntNet.load_model(args.model_path[0], device=device)
    model.to(device)

    BATCH_SIZE = args.batch_size
    res = classify(model, sents, pos, BATCH_SIZE)

    targets = [td[2] for td in test_data]
    targ_tags = [entities_from_token_classes(td[2]) for td in test_data]
    pred_tags = [entities_from_token_classes([t[1] for t in r]) for r in res]
    result = calc_f1(targ_tags, pred_tags)

    print(f'Overall unlabelled - F1:{result["overall"]["unl_f1"]}, '
          f'P:{result["overall"]["unl_p"]}, '
          f'R:{result["overall"]["unl_r"]}')
    print(f'Overall labelled - Micro F1:{result["overall"]["lab_f1"]}, '
          f'P:{result["overall"]["lab_p"]}, '
          f'R:{result["overall"]["lab_r"]}')
    print(f'Overall labelled - Macro F1:{result["overall"]["macro_lab_f1"]}')
    for k, v in result.items():
        if k == 'overall':
            continue
        print(f'{k} - F1:{v["lab_f1"]}, P:{v["lab_p"]}, R:{v["lab_r"]}')

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, nargs='+', default=None, required=True,
                        help='Path(s) to trained model checkpoint(s); more than one path triggers the ensemble.')
    parser.add_argument('--span_model_path', type=str, nargs='*', default=None,
                        help='Optional path(s) to span model checkpoint(s) for the staged ensemble.')
    # parser.add_argument('--network_type', type=str,
    #                     choices=['span', 'entity', 'joint'], required=True,
    #                     default=None, help='If entity is chosen, a path to a '
    #                                        'span model is required also')
    parser.add_argument('--test_path', type=str, default=None,
                        help='Path to the CoNLL-formatted test file.')
    parser.add_argument('--context_size', type=int, default=1,
                        help='Number of context sentences used by create_context_data.')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Batch size used during classification.')
    # parser.add_argument('--cuda_id', type=int, default=0, help='')
    args = parser.parse_args()
    main(args)
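
# Example invocation (the script name and paths are placeholders, not files shipped
# with this repo):
#   python evaluate_ner.py --model_path checkpoints/ent_net.pt \
#       --test_path data/conll/test.txt --context_size 1 --batch_size 8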