fclong's picture
Upload 396 files
8ebda9e
import json
from tqdm import tqdm
import os
from sklearn.utils import shuffle
import re
import argparse
def cut_sent(para):
para = re.sub('([。,,!?\?])([^”’])', r"\1\n\2", para) # 单字符断句符
para = re.sub('(\.{6})([^”’])', r"\1\n\2", para) # 英文省略号
para = re.sub('(\…{2})([^”’])', r"\1\n\2", para) # 中文省略号
para = re.sub('([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)
# 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号
para = para.rstrip() # 段尾如果有多余的\n就去掉它
# 很多规则中会考虑分号;,但是这里
return para.split("\n")
def search(pattern, sequence):
n = len(pattern)
res=[]
for i in range(len(sequence)):
if sequence[i:i + n] == pattern:
res.append([i,i + n-1])
return res
max_length=512
stride=128
def stride_split(question, context, answer, start):
end = start + len(answer) -1
results, n = [], 0
max_c_len = max_length - len(question) - 3
while True:
left, right = n * stride, n * stride + max_c_len
if left <= start < end <= right:
results.append((question, context[left:right], answer, start - left, end - left))
elif right < start or end < right:
results.append((question, context[left:right], '', -1, -1))
if right >= len(context):
return results
n += 1
def load_data(file_path,is_training=False):
task_type='抽取任务'
subtask_type='抽取式阅读理解'
with open(file_path, 'r', encoding='utf8') as f:
lines = json.loads(''.join(f.readlines()))
result=[]
lines = lines['data']
for line in tqdm(lines):
if line['paragraphs']==[]:
continue
data = line['paragraphs'][0]
context=data['context'].strip()
for qa in data['qas']:
question=qa['question'].strip()
rcv=[]
for a in qa['answers']:
if a not in rcv:
rcv.append(a)
split=stride_split(question, context, a['text'], a['answer_start'])
for sp in split:
choices = []
choice = {}
choice['id']=qa['id']
choice['entity_type'] = qa['question']
choice['label']=0
entity_list=[]
if sp[3]>=0 and sp[4]>=0:
entity_list.append({'entity_name':sp[2],'entity_type':'','entity_idx':[[sp[3],sp[4]]]})
choice['entity_list']=entity_list
choices.append(choice)
if choices==[]:
print(data)
continue
result.append({ 'task_type':task_type,
'subtask_type':subtask_type,
'text':sp[1],
'choices':choices,
'id':0})
return result
def save_data(data,file_path):
with open(file_path, 'w', encoding='utf8') as f:
for line in data:
json_data=json.dumps(line,ensure_ascii=False)
f.write(json_data+'\n')
if __name__=="__main__":
parser = argparse.ArgumentParser(description="train")
parser.add_argument("--data_path", type=str,default="")
parser.add_argument("--save_path", type=str,default="")
args = parser.parse_args()
data_path = args.data_path
save_path = args.save_path
if not os.path.exists(save_path):
os.makedirs(save_path)
file_list=['dev','train','trial','test']
train_data = []
dev_data = []
for file in file_list:
file_path = os.path.join(data_path,file+'.json')
data=load_data(file_path=file_path)
if 'train' in file or 'trial' in file:
train_data.extend(data)
else:
output_path = os.path.join(save_path,file+'.json')
save_data(data,output_path)
output_path = os.path.join(save_path,'train.json')
save_data(train_data,output_path)