# File size: 3,961 Bytes
# 95a3ca6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import re
import os
import sys
from tqdm import tqdm


def remove_bpe(line, bpe_symbol="@@ "):
    """Undo BPE segmentation on a single line of text.

    All newline characters in ``line`` are dropped, every occurrence of
    ``bpe_symbol`` is deleted, trailing whitespace is stripped and exactly one
    newline is appended.

    Args:
        line: the text to clean up.
        bpe_symbol: the continuation marker to delete (default ``"@@ "``).

    Returns:
        The cleaned line, always terminated by a single ``"\\n"``.
    """
    without_newlines = line.replace("\n", "")
    # Pad with one space so a marker sitting at the very end of the line
    # (e.g. "abc@@") still matches the "@@ " pattern.
    padded = without_newlines + " "
    merged = padded.replace(bpe_symbol, "")
    return merged.rstrip() + "\n"


def remove_bpe_fn(i=sys.stdin, o=sys.stdout, bpe="@@ "):
    """Stream lines from ``i`` to ``o`` with BPE markers removed.

    Args:
        i: iterable of input lines (defaults to standard input).
        o: object with a ``write`` method receiving the cleaned lines
            (defaults to standard output).
        bpe: the BPE continuation marker passed on to ``remove_bpe``.
    """
    # tqdm wraps the input iterable to report progress while lines are read.
    for raw_line in tqdm(i):
        o.write(remove_bpe(raw_line, bpe))


def reprocess(fle):
    """Parse a file of generate.py translation output.

    Only lines that *start* with one of the prefixes ``S-<id>``, ``T-<id>``,
    ``H-<id>`` or ``P-<id>`` are interpreted; every other line is skipped.

    Args:
        fle: path of the text file holding the generation output.

    Returns:
        A tuple ``(source_dict, hypothesis_dict, score_dict, target_dict,
        pos_score_dict)``. All dictionaries are keyed by the integer id taken
        from the line prefix:

        * ``source_dict[id]`` -- text of the ``S`` line (newline-terminated).
        * ``hypothesis_dict[id]`` -- list of hypothesis strings, one per ``H``
          line, in file order (each newline-terminated).
        * ``score_dict[id]`` -- list of float scores parallel to
          ``hypothesis_dict[id]``.
        * ``target_dict`` -- always empty; ``T`` lines are recognised but not
          stored.
        * ``pos_score_dict[id]`` -- list of lists of floats, one inner list
          per ``P`` line.

    Raises:
        AssertionError: if an ``H`` line has no score right after its prefix.
        ValueError: if a ``P`` line contains a token that is not a float.
    """
    # ``[^\S\n]`` means "whitespace except newline": separators are consumed
    # but the terminating newline is kept, so an empty source or hypothesis
    # is returned as "\n" instead of "" (which would merge output lines).
    prefix_re = re.compile(r"([STHP])-(\d+)[^\S\n]*")
    # Accepts plain integers, decimals, exponent notation, inf and nan; the
    # lookahead makes sure the score is a whole whitespace-delimited token.
    score_re = re.compile(
        r"[^\S\n]*(-?(?:\d+(?:\.\d*)?|\.\d+)(?:e[+-]?\d+)?|-?inf|nan)(?=\s)[^\S\n]*",
        re.IGNORECASE,
    )

    source_dict = {}
    hypothesis_dict = {}
    score_dict = {}
    target_dict = {}
    pos_score_dict = {}

    with open(fle, 'r') as f:
        for raw in f:
            # Guarantee a trailing newline even on an unterminated last line.
            line = raw.rstrip("\n") + "\n"

            # Anchored at the start of the line so that text such as "P-51"
            # inside an unrelated line is never mistaken for a prefix.
            prefix = prefix_re.match(line)
            if prefix is None:
                continue

            line_type = prefix.group(1)
            id_num = int(prefix.group(2))
            rest = line[prefix.end():]

            if line_type == "H":
                # The score must directly follow the prefix; anchoring stops
                # a number inside the hypothesis text from being picked up.
                scored = score_re.match(rest)
                assert scored is not None, ("regular expression failed to find the hypothesis scoring")
                hypothesis_dict.setdefault(id_num, []).append(rest[scored.end():])
                score_dict.setdefault(id_num, []).append(float(scored.group(1)))
            elif line_type == "S":
                source_dict[id_num] = rest
            elif line_type == "P":
                pos_scores = [float(x) for x in rest.split()]
                pos_score_dict.setdefault(id_num, []).append(pos_scores)
            # "T" (target) lines are recognised but deliberately not stored.

    return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict


def get_hypo_and_ref(fle, hyp_file, ref_input, ref_file, rank=0):
    """Write line-aligned hypothesis and reference files.

    The generation output ``fle`` is parsed with ``reprocess``. For every
    sentence id that has at least one hypothesis, in increasing id order, the
    hypothesis at position ``rank`` is written to ``hyp_file`` and line
    ``id`` of ``ref_input`` is written to ``ref_file``. If no hypotheses are
    found, both output files are created empty.

    Args:
        fle: path of the generate.py output file.
        hyp_file: path of the hypothesis file to create.
        ref_input: path of the reference file, indexed by sentence id.
        ref_file: path of the reference file to create.
        rank: index of the hypothesis to select for each sentence.

    Raises:
        AssertionError: if ``rank`` is out of range for any sentence.
        IndexError: if ``ref_input`` has no line for a sentence id.
    """
    with open(ref_input, 'r') as f:
        refs = f.readlines()
    _, hypo_dict, _, _, _ = reprocess(fle)

    # Check every sentence (not only id 0) before touching the output files,
    # so a bad rank never leaves half-written outputs behind.
    too_short = sorted(idx for idx, hyps in hypo_dict.items() if rank >= len(hyps))
    assert not too_short, (
        "rank {} is out of range for sentence ids {}".format(rank, too_short)
    )

    # Context managers close both files even if a write fails.
    with open(hyp_file, "w") as f_hyp, open(ref_file, "w") as f_ref:
        for idx in sorted(hypo_dict):
            f_hyp.write(hypo_dict[idx][rank])
            f_ref.write(refs[idx])


def recover_bpe(hyp_file):
    """Write a copy of ``hyp_file`` with BPE markers removed.

    The result is written next to the input as ``hyp_file + ".nobpe"``.

    Args:
        hyp_file: path of the BPE-segmented hypothesis file.
    """
    # The files are passed directly (no eval-based variable lookup), and the
    # context manager closes both handles even if processing fails.
    with open(hyp_file, "r") as f_hyp, open(hyp_file + ".nobpe", "w") as f_hyp_out:
        remove_bpe_fn(i=f_hyp, o=f_hyp_out)


if __name__ == "__main__":
    # Usage: <script> <generate.py output file> <reference file>
    # "hypo.out", "ref.out" and "hypo.out.nobpe" are written next to the
    # generation output file.
    filename, ref_in = sys.argv[1], sys.argv[2]
    out_dir = os.path.dirname(filename)
    hypo_file = os.path.join(out_dir, "hypo.out")
    ref_out = os.path.join(out_dir, "ref.out")

    get_hypo_and_ref(filename, hypo_file, ref_in, ref_out)
    recover_bpe(hypo_file)