import sys
import os
sys.path.append("./")
sys.path.append("../")
sys.path.append("../../")
sys.path.append("../../../")
from models import SPAN_EXTRACTION_MODEL_CLASSES
from models import TOKENIZER_CLASSES
import numpy as np
import torch
class HugIEAPI:
    """Thin inference wrapper around a HugIE global-pointer span-extraction model.

    Loads a pretrained model/tokenizer pair and exposes :meth:`request`, which
    builds a Chinese extraction instruction from the query text and returns the
    predicted answer spans with their probabilities.
    """

    def __init__(self, model_type, hugie_model_name_or_path) -> None:
        """Load the span-extraction model and its tokenizer.

        Args:
            model_type: key into ``SPAN_EXTRACTION_MODEL_CLASSES["global_pointer"]``
                (e.g. ``"bert"``).
            hugie_model_name_or_path: HuggingFace model id or local path.

        Raises:
            KeyError: if ``model_type`` is not a supported model key.
        """
        supported = SPAN_EXTRACTION_MODEL_CLASSES["global_pointer"]
        if model_type not in supported:
            raise KeyError(
                "You must choose one of the following model: {}".format(
                    ", ".join(list(supported.keys()))))
        self.model_type = model_type
        self.model = supported[self.model_type].from_pretrained(
            hugie_model_name_or_path)
        self.tokenizer = TOKENIZER_CLASSES[self.model_type].from_pretrained(
            hugie_model_name_or_path)
        # Fixed sequence length; also defines the (L, L) grid that flattened
        # span indices are decoded against in get_predict_result.
        self.max_seq_length = 512

    def fush_multi_answer(self, has_answer, new_answer):
        """Fuse the top-k answers of a new example into the accumulated ones.

        When one test sample produces several examples (e.g. the same sample
        rendered through multiple prompt templates), their top-k predictions
        are merged: probabilities of identical answer strings are summed and
        their position lists concatenated.

        Both arguments have the shape
        ``{answer_text: {"prob": float, "pos": [(start, end), ...]}, ...}``.
        NOTE: ``has_answer`` is mutated in place and also returned.
        """
        for ans, value in new_answer.items():
            if ans not in has_answer:
                has_answer[ans] = value
            else:
                has_answer[ans]["prob"] += value["prob"]
                has_answer[ans]["pos"].extend(value["pos"])
        return has_answer

    def get_predict_result(self, probs, indices, examples):
        """Decode top-k model outputs into answer strings and character spans.

        Args:
            probs: top-k probabilities, tensor of shape [n, 1, k].
            indices: matching flattened span indices, tensor of shape
                [n, 1, k]; each index encodes start * max_seq_length + end.
            examples: dict with "content" (list of input strings) and
                "offset_mapping" (per-example token -> character offsets,
                as returned by the tokenizer with return_offsets_mapping).

        Returns:
            ``(predictions, topk_predictions)`` where ``predictions`` maps
            example index -> list of answer strings, and ``topk_predictions``
            maps example index -> list of {"answer", "prob", "pos"} dicts.
        """
        # TODO: tune the threshold; overlapping entities are not resolved yet.
        threshold = 0.1
        probs = probs.squeeze(1)    # [n, k] top-k probabilities
        indices = indices.squeeze(1)  # [n, k] top-k flattened span indices
        predictions = {}
        topk_predictions = {}
        for idx, (prob, index) in enumerate(zip(probs, indices)):
            keep = prob > threshold
            entity_index = index[keep]
            index_ids = torch.arange(len(index))[keep]
            answer = []
            topk_answer_dict = dict()
            for ei, entity in enumerate(entity_index):
                # Convert the flattened 1D index back to (start, end) tokens.
                start_end = np.unravel_index(
                    entity, (self.max_seq_length, self.max_seq_length))
                s = examples["offset_mapping"][idx][start_end[0]][0]
                e = examples["offset_mapping"][idx][start_end[1]][1]
                ans = examples["content"][idx][s:e]
                if ans not in answer:
                    answer.append(ans)
                    topk_answer_dict[ans] = {
                        "prob": float(prob[index_ids[ei]]),
                        "pos": [(s.detach().cpu().numpy().tolist(),
                                 e.detach().cpu().numpy().tolist())]
                    }
            predictions[idx] = answer
            if idx not in topk_predictions:
                topk_predictions[idx] = topk_answer_dict
            else:
                # Same example index seen again: fuse the new top-k answers
                # with the ones accumulated so far.
                topk_predictions[idx] = self.fush_multi_answer(
                    topk_predictions[idx], topk_answer_dict)
        # Flatten {answer: {...}} dicts into lists of answer records.
        for idx, values in topk_predictions.items():
            topk_predictions[idx] = [{
                "answer": ans,
                "prob": value["prob"],
                "pos": value["pos"]
            } for ans, value in values.items()]
        return predictions, topk_predictions

    def request(self, text: str, entity_type: str, relation: str = None):
        """Extract spans from ``text`` for an entity type or a relation.

        If ``relation`` is None, the instruction asks for every entity of
        ``entity_type``; otherwise it asks for the ``relation`` of the
        entity named by ``entity_type``.

        Returns:
            ``(predictions, topk_predictions)`` as produced by
            :meth:`get_predict_result`.
        """
        assert text is not None and entity_type is not None
        if relation is None:
            instruction = "找到文章中所有【{}】类型的实体?文章:【{}】".format(entity_type, text)
        else:
            instruction = "找到文章中【{}】的【{}】?文章:【{}】".format(
                entity_type, relation, text)
        # truncation=True keeps long inputs within max_seq_length; otherwise
        # the fixed (512, 512) decode grid in get_predict_result would not
        # match the tokenized sequence.
        inputs = self.tokenizer(instruction,
                                max_length=self.max_seq_length,
                                truncation=True,
                                padding="max_length",
                                return_tensors="pt",
                                return_offsets_mapping=True)
        examples = {
            "content": [instruction],
            "offset_mapping": inputs["offset_mapping"]
        }
        batch_input = {
            "input_ids": inputs["input_ids"],
            "token_type_ids": inputs["token_type_ids"],
            "attention_mask": inputs["attention_mask"],
        }
        outputs = self.model(**batch_input)
        probs, indices = outputs["topk_probs"], outputs["topk_indices"]
        predictions, topk_predictions = self.get_predict_result(
            probs, indices, examples=examples)
        return predictions, topk_predictions
if __name__ == "__main__":
    from applications.information_extraction.HugIE.api_test import HugIEAPI

    model_type = "bert"
    hugie_model_name_or_path = "wjn1996/wjn1996-hugnlp-hugie-large-zh"
    hugie = HugIEAPI("bert", hugie_model_name_or_path)
    text = "央广网北京2月23日消息 据中国地震台网正式测定,2月23日8时37分在塔吉克斯坦发生7.2级地震,震源深度10公里,震中位于北纬37.98度,东经73.29度,距我国边境线最近约82公里,地震造成新疆喀什等地震感强烈。"

    # Named entity recognition: find every entity of the given type.
    entity_type = "国家"
    predictions, topk_predictions = hugie.request(text, entity_type)
    print("entity_type:{}".format(entity_type))
    print("predictions:\n{}".format(predictions))
    print("topk_predictions:\n{}".format(topk_predictions))
    print("\n\n")

    # Event extraction: query several relations of the same event entity.
    entity = "塔吉克斯坦地震"
    for relation in ("震源深度", "震源位置", "时间", "影响"):
        predictions, topk_predictions = hugie.request(text,
                                                      entity,
                                                      relation=relation)
        print("entity:{}, relation:{}".format(entity, relation))
        print("predictions:\n{}".format(predictions))
        print("topk_predictions:\n{}".format(topk_predictions))
        print("\n\n")

"""
Output results:
entity_type:国家
predictions:
{0: ["塔吉克斯坦"]}
predictions:
{0: [{"answer": "塔吉克斯坦", "prob": 0.9999997615814209, "pos": [(tensor(57), tensor(62))]}]}
entity:塔吉克斯坦地震, relation:震源深度
predictions:
{0: ["10公里"]}
predictions:
{0: [{"answer": "10公里", "prob": 0.999994158744812, "pos": [(tensor(80), tensor(84))]}]}
entity:塔吉克斯坦地震, relation:震源位置
predictions:
{0: ["10公里", "距我国边境线最近约82公里", "北纬37.98度,东经73.29度", "北纬37.98度,东经73.29度,距我国边境线最近约82公里"]}
predictions:
{0: [{"answer": "10公里", "prob": 0.9895901083946228, "pos": [(tensor(80), tensor(84))]}, {"answer": "距我国边境线最近约82公里", "prob": 0.8584909439086914, "pos": [(tensor(107), tensor(120))]}, {"answer": "北纬37.98度,东经73.29度", "prob": 0.7202121615409851, "pos": [(tensor(89), tensor(106))]}, {"answer": "北纬37.98度,东经73.29度,距我国边境线最近约82公里", "prob": 0.11628123372793198, "pos": [(tensor(89), tensor(120))]}]}
entity:塔吉克斯坦地震, relation:时间
predictions:
{0: ["2月23日8时37分"]}
predictions:
{0: [{"answer": "2月23日8时37分", "prob": 0.9999995231628418, "pos": [(tensor(49), tensor(59))]}]}
entity:塔吉克斯坦地震, relation:影响
predictions:
{0: ["新疆喀什等地震感强烈"]}
predictions:
{0: [{"answer": "新疆喀什等地震感强烈", "prob": 0.9525265693664551, "pos": [(tensor(123), tensor(133))]}]}
"""