svjack commited on
Commit
391a43c
·
1 Parent(s): 324d529

Upload qa_on_context.py

Browse files
Files changed (1) hide show
  1. qa_on_context.py +141 -0
qa_on_context.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### py39_cp_cp
2
+ from zh_mt5_model import *
3
+ from en_t2t_model import *
4
+
5
+ import os
6
+ os.environ['KMP_DUPLICATE_LIB_OK']='True'
7
+
8
+ import spacy
9
+ import pandas as pd
10
+ import numpy as np
11
+
12
+ import re
13
+ from tqdm import tqdm
14
+ from copy import deepcopy
15
+ import pathlib
16
+
17
+ import json
18
+ import pickle as pkl
19
+
20
+ from tqdm import tqdm
21
+ from easynmt import EasyNMT
22
+
23
+ ### https://huggingface.co/svjack/squad_gen_qst_zh_v0
24
+ path = "svjack/squad_gen_qst_zh_v0"
25
+ asker_zh = T5_B(path,
26
+ device = "cpu")
27
+
28
+ zh_nlp = spacy.load("zh_core_web_sm")
29
+ en_nlp = spacy.load("en_core_web_sm")
30
+
31
+ trans_model = EasyNMT('opus-mt')
32
+
33
+ def detect_language(text):
34
+ assert type(text) == type("")
35
+ # detect_list.append(trans_model.language_detection_fasttext(prompt))
36
+ lang = trans_model.language_detection_fasttext(text)
37
+ lang = lang.lower().strip()
38
+ if "zh" not in lang and "en" not in lang:
39
+ lang = "others"
40
+ if "zh" in lang:
41
+ lang = "zh"
42
+ if "en" in lang:
43
+ lang = "en"
44
+ assert lang in ["en", "zh", "others"]
45
+ return lang
46
+
47
+ def drop_duplicates_by_col(df, on_col = "aug_sparql_query"):
48
+ assert hasattr(df, "size")
49
+ assert on_col in df.columns.tolist()
50
+ req = []
51
+ set_ = set([])
52
+ for i, r in df.iterrows():
53
+ if r[on_col] not in set_:
54
+ set_.add(r[on_col])
55
+ req.append(r)
56
+ return pd.DataFrame(req)
57
+
58
+ def sent_with_ents(sent, en_nlp):
59
+ assert type(sent) == type("")
60
+ doc = en_nlp(sent)
61
+ return (sent, pd.Series(doc.ents).map(
62
+ lambda span: (span.text, span.label_)
63
+ ).values.tolist())
64
+
65
+ def gen_ask_by_span_zh(asker ,sent, span):
66
+ if type(span) == type(""):
67
+ span = [span]
68
+ if not span:
69
+ return []
70
+ sent = sent.replace("|", "")
71
+ span = list(map(lambda x: x.replace("|", ""), span))
72
+ x = list(map(lambda x: "{}|{}".format(sent, x), span))
73
+ return list(map(
74
+ lambda y: asker.predict(y)
75
+ , x))
76
+
77
+ #### list return
78
+ def gen_ask_by_span(asker, sent, span, lang):
79
+ assert lang in ["en", "zh"]
80
+ if lang == "zh":
81
+ return gen_ask_by_span_zh(asker ,sent, span)
82
+ else:
83
+ return gen_ask_by_span_en(t2t, sent, span)
84
+
85
+
86
+ def filter_ent_cate(ent_list, maintain_cate_list = [
87
+ "DATE", "FAC", "GPE", "LOC", "PERSON"
88
+ ]):
89
+ if not ent_list:
90
+ return []
91
+ return list(filter(lambda t2: t2[1] in maintain_cate_list, ent_list))
92
+
93
+ def batch_as_list(a, batch_size = int(100000)):
94
+ req = []
95
+ for ele in a:
96
+ if not req:
97
+ req.append([])
98
+ if len(req[-1]) < batch_size:
99
+ req[-1].append(ele)
100
+ else:
101
+ req.append([])
102
+ req[-1].append(ele)
103
+ return req
104
+
105
+ def gen_qst_to_df(paragraph,
106
+ nlp = zh_nlp,
107
+ asker = asker_zh,
108
+ nlp_input = None,
109
+ maintain_cate_list = [
110
+ "DATE", "FAC", "GPE", "LOC", "PERSON"
111
+ ], limit_ents_size = 10, batch_size = 4
112
+ ):
113
+ if limit_ents_size is None:
114
+ limit_ents_size = 10000
115
+ assert type(paragraph) == type("")
116
+ lang = detect_language(paragraph)
117
+ if lang != "zh":
118
+ lang = "en"
119
+ nlp = en_nlp if lang == "en" else zh_nlp
120
+
121
+ if nlp_input is None:
122
+ _, entity_list = sent_with_ents(paragraph, nlp)
123
+ else:
124
+ _, entity_list = deepcopy(nlp_input)
125
+ if maintain_cate_list:
126
+ entity_list = filter_ent_cate(entity_list, maintain_cate_list = maintain_cate_list)
127
+ entity_list = entity_list[:limit_ents_size]
128
+ if not entity_list:
129
+ return None
130
+ l = batch_as_list(entity_list, batch_size)
131
+ for ele in tqdm(l):
132
+ ents = list(map(lambda x: x[0], ele))
133
+ ent_cates = list(map(lambda x: x[1], ele))
134
+ #questions = gen_ask_by_span_zh(asker, paragraph, ents)
135
+ questions = gen_ask_by_span(asker, paragraph, ents, lang)
136
+ assert len(ele) == len(ent_cates) == len(questions)
137
+ #return [ele, ent_cates, questions, ans]
138
+ batch_l = list(map(pd.Series, [ents, ent_cates, questions]))
139
+ batch_df = pd.concat(batch_l, axis = 1)
140
+ batch_df.columns = ["entity", "entity_cate", "question",]
141
+ yield batch_df