import sys import os sys.path.append("./") sys.path.append("../") sys.path.append("../../") sys.path.append("../../../") from models import SPAN_EXTRACTION_MODEL_CLASSES from models import TOKENIZER_CLASSES import numpy as np import torch class HugIEAPI: def __init__(self, model_type, hugie_model_name_or_path) -> None: if model_type not in SPAN_EXTRACTION_MODEL_CLASSES[ "global_pointer"].keys(): raise KeyError( "You must choose one of the following model: {}".format( ", ".join( list(SPAN_EXTRACTION_MODEL_CLASSES["global_pointer"]. keys())))) self.model_type = model_type self.model = SPAN_EXTRACTION_MODEL_CLASSES["global_pointer"][ self.model_type].from_pretrained(hugie_model_name_or_path) self.tokenizer = TOKENIZER_CLASSES[self.model_type].from_pretrained( hugie_model_name_or_path) self.max_seq_length = 512 def fush_multi_answer(self, has_answer, new_answer): # 对于某个id测试集,出现多个example时(例如同一个测试样本使用了多个模板而生成了多个example),此时将预测的topk结果进行合并 # has为已经合并的结果,new为当前新产生的结果, # has格式为 {"ans": {"prob": float(prob[index_ids[ei]]), "pos": (s, e)}, ...} # new {"ans": {"prob": float(prob[index_ids[ei]]), "pos": (s, e)}, ...} # print("has_answer=", has_answer) for ans, value in new_answer.items(): if ans not in has_answer.keys(): has_answer[ans] = value else: has_answer[ans]["prob"] += value["prob"] has_answer[ans]["pos"].extend(value["pos"]) return has_answer def get_predict_result(self, probs, indices, examples): probs = probs.squeeze(1) # topk结果的概率 indices = indices.squeeze(1) # topk结果的索引 # print("probs=", probs) # [n, m] # print("indices=", indices) # [n, m] predictions = {} topk_predictions = {} idx = 0 for prob, index in zip(probs, indices): index_ids = torch.Tensor([i for i in range(len(index))]).long() topk_answer = list() answer = [] topk_answer_dict = dict() # TODO 1. 调节阈值 2. 处理输出实体重叠问题 entity_index = index[prob > 0.1] index_ids = index_ids[prob > 0.1] for ei, entity in enumerate(entity_index): # 1D index转2D index start_end = np.unravel_index( entity, (self.max_seq_length, self.max_seq_length)) s = examples["offset_mapping"][idx][start_end[0]][0] e = examples["offset_mapping"][idx][start_end[1]][1] ans = examples["content"][idx][s:e] if ans not in answer: answer.append(ans) # topk_answer.append({"answer": ans, "prob": float(prob[index_ids[ei]]), "pos": (s, e)}) topk_answer_dict[ans] = { "prob": float(prob[index_ids[ei]]), "pos": [(s.detach().cpu().numpy().tolist(), e.detach().cpu().numpy().tolist())] } predictions[idx] = answer if idx not in topk_predictions.keys(): # print("topk_answer_dict=", topk_answer_dict) topk_predictions[idx] = topk_answer_dict else: # print("topk_predictions[id_]=", topk_predictions[id_]) topk_predictions[idx] = self.fush_multi_answer( topk_predictions[idx], topk_answer_dict) idx += 1 for idx, values in topk_predictions.items(): # values {"ans": {}, ...} answer_list = list() for ans, value in values.items(): answer_list.append({ "answer": ans, "prob": value["prob"], "pos": value["pos"] }) topk_predictions[idx] = answer_list return predictions, topk_predictions def request(self, text: str, entity_type: str, relation: str = None): assert text is not None and entity_type is not None if relation is None: instruction = "找到文章中所有【{}】类型的实体?文章:【{}】".format(entity_type, text) pre_len = 21 - 2 + len(entity_type) else: instruction = "找到文章中【{}】的【{}】?文章:【{}】".format( entity_type, relation, text) pre_len = 19 - 4 + len(entity_type) + len(relation) inputs = self.tokenizer(instruction, max_length=self.max_seq_length, padding="max_length", return_tensors="pt", return_offsets_mapping=True) examples = { "content": [instruction], "offset_mapping": inputs["offset_mapping"] } batch_input = { "input_ids": inputs["input_ids"], "token_type_ids": inputs["token_type_ids"], "attention_mask": inputs["attention_mask"], } outputs = self.model(**batch_input) probs, indices = outputs["topk_probs"], outputs["topk_indices"] predictions, topk_predictions = self.get_predict_result( probs, indices, examples=examples) return predictions, topk_predictions if __name__ == "__main__": from applications.information_extraction.HugIE.api_test import HugIEAPI model_type = "bert" hugie_model_name_or_path = "wjn1996/wjn1996-hugnlp-hugie-large-zh" hugie = HugIEAPI("bert", hugie_model_name_or_path) text = "央广网北京2月23日消息 据中国地震台网正式测定,2月23日8时37分在塔吉克斯坦发生7.2级地震,震源深度10公里,震中位于北纬37.98度,东经73.29度,距我国边境线最近约82公里,地震造成新疆喀什等地震感强烈。" ## named entity recognition entity_type = "国家" predictions, topk_predictions = hugie.request(text, entity_type) print("entity_type:{}".format(entity_type)) print("predictions:\n{}".format(predictions)) print("topk_predictions:\n{}".format(topk_predictions)) print("\n\n") ## event extraction entity = "塔吉克斯坦地震" relation = "震源深度" predictions, topk_predictions = hugie.request(text, entity, relation=relation) print("entity:{}, relation:{}".format(entity, relation)) print("predictions:\n{}".format(predictions)) print("topk_predictions:\n{}".format(topk_predictions)) print("\n\n") ## event extraction entity = "塔吉克斯坦地震" relation = "震源位置" predictions, topk_predictions = hugie.request(text, entity, relation=relation) print("entity:{}, relation:{}".format(entity, relation)) print("predictions:\n{}".format(predictions)) print("topk_predictions:\n{}".format(topk_predictions)) print("\n\n") ## event extraction entity = "塔吉克斯坦地震" relation = "时间" predictions, topk_predictions = hugie.request(text, entity, relation=relation) print("entity:{}, relation:{}".format(entity, relation)) print("predictions:\n{}".format(predictions)) print("topk_predictions:\n{}".format(topk_predictions)) print("\n\n") ## event extraction entity = "塔吉克斯坦地震" relation = "影响" predictions, topk_predictions = hugie.request(text, entity, relation=relation) print("entity:{}, relation:{}".format(entity, relation)) print("predictions:\n{}".format(predictions)) print("topk_predictions:\n{}".format(topk_predictions)) print("\n\n") """ Output results: entity_type:国家 predictions: {0: ["塔吉克斯坦"]} predictions: {0: [{"answer": "塔吉克斯坦", "prob": 0.9999997615814209, "pos": [(tensor(57), tensor(62))]}]} entity:塔吉克斯坦地震, relation:震源深度 predictions: {0: ["10公里"]} predictions: {0: [{"answer": "10公里", "prob": 0.999994158744812, "pos": [(tensor(80), tensor(84))]}]} entity:塔吉克斯坦地震, relation:震源位置 predictions: {0: ["10公里", "距我国边境线最近约82公里", "北纬37.98度,东经73.29度", "北纬37.98度,东经73.29度,距我国边境线最近约82公里"]} predictions: {0: [{"answer": "10公里", "prob": 0.9895901083946228, "pos": [(tensor(80), tensor(84))]}, {"answer": "距我国边境线最近约82公里", "prob": 0.8584909439086914, "pos": [(tensor(107), tensor(120))]}, {"answer": "北纬37.98度,东经73.29度", "prob": 0.7202121615409851, "pos": [(tensor(89), tensor(106))]}, {"answer": "北纬37.98度,东经73.29度,距我国边境线最近约82公里", "prob": 0.11628123372793198, "pos": [(tensor(89), tensor(120))]}]} entity:塔吉克斯坦地震, relation:时间 predictions: {0: ["2月23日8时37分"]} predictions: {0: [{"answer": "2月23日8时37分", "prob": 0.9999995231628418, "pos": [(tensor(49), tensor(59))]}]} entity:塔吉克斯坦地震, relation:影响 predictions: {0: ["新疆喀什等地震感强烈"]} predictions: {0: [{"answer": "新疆喀什等地震感强烈", "prob": 0.9525265693664551, "pos": [(tensor(123), tensor(133))]}]} """