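"""Decode GPT-2 sampled outputs into reference/prediction files for NLG
evaluation (E2E, WebNLG, or DART style).

The script reads a JSONL input file (contexts plus gold completions) and a
JSONL sample file (predicted token ids), groups the references per example,
and writes plain-text reference and prediction files in the layout the
chosen --ref_type expects.
"""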

import argparse
import json
import os
import re

import encoder

parser = argparse.ArgumentParser()

parser.add_argument('--vocab', type=str, default=None, help='vocab path')

parser.add_argument('--sample_file', default=None, type=str, help='ft sample file')
parser.add_argument('--input_file', default=None, type=str, help='ft input file')

parser.add_argument('--output_ref_file', default=None, type=str, help='output reference file')
parser.add_argument('--output_pred_file', default=None, type=str, help='output prediction file')

parser.add_argument('--ref_unique_file', default=None, type=str, help='reference unique id file')

parser.add_argument('--ref_type', default='e2e', choices=['e2e', 'webnlg', 'dart'],
                    help='reference file style: e2e (single multi-reference file), '
                         'or webnlg/dart (one file per reference index).')
parser.add_argument('--ref_num', default=4, type=int, help='number of references.')

parser.add_argument('--tokenize', action='store_true', help='apply standard tokenization to outputs')
parser.add_argument('--lower', action='store_true', help='lowercase outputs')

parser.add_argument('--filter', default='all', choices=['all', 'seen', 'unseen'],
                    help='for webnlg only, filter categories that are seen during training, unseen, or all')

args = parser.parse_args()
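
# Example invocation (script name and file paths are illustrative):
#   python gpt2_decode.py --vocab ./vocab --sample_file samples.jsonl \
#       --input_file test_formatted.jsonl --output_ref_file e2e_ref.txt \
#       --output_pred_file e2e_pred.txt --ref_type e2e --ref_num 4 --tokenize --lower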


def standard_tokenize(sent):
    # Split on non-word characters, keeping each delimiter as its own token,
    # then collapse runs of whitespace into single spaces.
    sent = ' '.join(re.split(r'(\W)', sent))
    sent = sent.split()
    sent = ' '.join(sent)
    return sent
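
# For example: standard_tokenize("Hello, world!") -> "Hello , world !"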


def post_process(sent, is_tokenize, is_lower):
    # Optionally lowercase and/or re-tokenize a sentence before writing it out.
    if is_lower:
        sent = sent.lower()
    if is_tokenize:
        sent = standard_tokenize(sent)

    return sent
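
# For example: post_process("Hello, World!", is_tokenize=True, is_lower=True)
# -> "hello , world !"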


if __name__ == "__main__":
    enc = encoder.get_encoder(args.vocab)

    # Optional mapping from input line index to a unique reference id, used
    # when several input lines share one set of references.
    ref_unique = None

    if args.ref_unique_file is not None:
        print('reading ref_unique_file.')
        ref_unique = []
        uniques = {}
        with open(args.ref_unique_file, 'r') as ref_unique_reader:
            for line in ref_unique_reader:
                _id = int(line.strip())
                ref_unique.append(_id)
                uniques[_id] = 1
        print('len refer dict', len(ref_unique), 'unique', len(uniques))
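
    # Each line of --input_file is JSON with at least 'context' and
    # 'completion' fields ('cate' as well when --filter is seen/unseen),
    # matching the reads below.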

    with open(args.sample_file, 'r') as sample_reader, \
         open(args.input_file, 'r', encoding='utf8') as input_reader, \
         open(args.output_pred_file, 'w', encoding='utf8') as pred_writer:

        # refer_dict maps a key (the raw context string, or the unique id from
        # --ref_unique_file) to the example's references and sampled prediction.
        refer_dict = {}
        context_list = []
        line_id = 0
        for line in input_reader:
            items = json.loads(line.strip())
            context = items['context']
            completion = items['completion']

            context_list.append(context)

            keep = False

            # --filter is meaningful for webnlg only: 'seen'/'unseen' select
            # examples whose category was/was not seen during training.
            if args.filter == 'all':
                keep = True
            if args.filter == 'seen' and items['cate']:
                keep = True
            if args.filter == 'unseen' and not items['cate']:
                keep = True

            if ref_unique is None:
                _key = context
            else:
                _key = ref_unique[line_id]

            if keep:
                if _key not in refer_dict:
                    refer_dict[_key] = {}
                    refer_dict[_key]['references'] = []
                # Keep only the text before the end-of-text marker and the
                # first blank line.
                refer_dict[_key]['references'].append(completion.split('<|endoftext|>')[0].split('\n\n')[0].strip())

            line_id += 1

        print('unique refer dict', len(refer_dict))
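
        # Each line of --sample_file is JSON of the form
        # {"id": <input line index>, "predict": [<token ids>]}; the field
        # names come from the reads below, the values are illustrative.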

        for line in sample_reader:
            items = json.loads(line.strip())
            _id = items['id']
            _pred_tokens = items['predict']

            if ref_unique is None:
                _key = context_list[_id]
            else:
                _key = ref_unique[_id]

            if _key not in refer_dict:
                refer_dict[_key] = {}
            # Decode the predicted token ids, then truncate at the end-of-text
            # marker or the first blank line, mirroring the reference side.
            refer_dict[_key]['sample'] = enc.decode(_pred_tokens).split('<|endoftext|>')[0].split('\n\n')[0].strip()

        references = [refer_dict[s]['references'] for s in refer_dict]
        hypothesis = [refer_dict[s]['sample'] for s in refer_dict]

        if args.ref_type == 'e2e':
            # E2E style: a single reference file where each example's
            # references are grouped and separated by a blank line.
            with open(args.output_ref_file, 'w', encoding='utf8') as ref_writer:
                for ref, hyp in zip(references, hypothesis):
                    for r in ref:
                        ref_writer.write(post_process(r, args.tokenize, args.lower) + '\n')
                    ref_writer.write('\n')
                    pred_writer.write(post_process(hyp, args.tokenize, args.lower) + '\n')

        elif args.ref_type in ['webnlg', 'dart']:
            # WebNLG/DART style: --output_ref_file is a directory holding one
            # reference file per reference index (reference0, reference1, ...).
            if not os.path.exists(args.output_ref_file):
                os.makedirs(args.output_ref_file)

            reference_writers = [
                open(os.path.join(args.output_ref_file, f'reference{fid}'), 'w', encoding='utf8')
                for fid in range(0, args.ref_num)
            ]

            for ref, hyp in zip(references, hypothesis):
                for fid in range(0, args.ref_num):
                    # Pad with the first reference when an example has fewer
                    # than --ref_num references, keeping the files line-aligned.
                    if len(ref) > fid:
                        reference_writers[fid].write(post_process(ref[fid], args.tokenize, args.lower) + '\n')
                    else:
                        reference_writers[fid].write(post_process(ref[0], args.tokenize, args.lower) + '\n')
                pred_writer.write(post_process(hyp, args.tokenize, args.lower) + '\n')

            for writer in reference_writers:
                writer.close()
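
# The reference files and the prediction file are line-aligned: line i of
# each file refers to the same example, which is the layout multi-reference
# NLG metrics typically expect.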