from utils.finetune import Graph2TextModule
from typing import Dict, List, Tuple, Union, Optional
import torch
import re

if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
    print('CUDA NOT AVAILABLE')

CHECKPOINT = 'base/t5-base_13881_val_avg_bleu=68.1000-step_count=5.ckpt'
MAX_LENGTH = 384
SEED = 42


class VerbModule():
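    '''
    Wraps a fine-tuned T5 graph-to-text checkpoint (Graph2TextModule) for verbalising
    subject/predicate/object triples, and post-processes the generated text by replacing
    <unk> tokens with the original label spans collected via add_label_to_unk_replacer.
    '''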
    
    def __init__(self, override_args: Optional[Dict[str, str]] = None):
        # Model
        if not override_args:
            override_args = {}
        self.g2t_module = Graph2TextModule.load_from_checkpoint(CHECKPOINT, strict=False, **override_args)
        # Move the model to the same device the encoded inputs are sent to during generation
        self.g2t_module.to(DEVICE)
        self.tokenizer = self.g2t_module.tokenizer
        # Unk replacer
        self.vocab = self.tokenizer.get_vocab()
        self.convert_some_japanese_characters = True
        self.unk_char_replace_sliding_window_size = 2
        self.unknowns = []

    def __generate_verbalisations_from_inputs(self, inputs: Union[str, List[str]]):
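        '''Encode the prepared input string(s) and generate output token ids with beam search on the fine-tuned model.'''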
        try:
            inputs_encoding = self.tokenizer.prepare_seq2seq_batch(
                inputs, truncation=True, max_length=MAX_LENGTH, return_tensors='pt'
            )
            inputs_encoding = {k: v.to(DEVICE) for k, v in inputs_encoding.items()}
            
            self.g2t_module.model.eval()
            with torch.no_grad():
                gen_output = self.g2t_module.model.generate(
                    inputs_encoding['input_ids'],
                    attention_mask=inputs_encoding['attention_mask'],
                    use_cache=True,
                    decoder_start_token_id=self.g2t_module.decoder_start_token_id,
                    num_beams=self.g2t_module.eval_beams,
                    max_length=self.g2t_module.eval_max_length,
                    length_penalty=1.0
                )
        except Exception:
            print(inputs)
            raise

        return gen_output
    
    def __decode_ids_to_string_custom(
        self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
    ) -> str:
        '''
        Adapted from 'tokenizer.decode' in HuggingFace Transformers
        (https://github.com/huggingface/transformers/blob/198c335d219a5eb4d3f124fdd1ce1a9cd9f78a9b/src/transformers/tokenization_utils_fast.py#L537),
        mainly because the official 'tokenizer.decode' treats all special tokens the same, while we want to drop
        all special tokens from the decoded sentence EXCEPT for the <unk> token, which we will replace later on.
        '''
        filtered_tokens = self.tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
        # Do not remove special tokens yet

        # To avoid mixing byte-level and unicode for byte-level BPE
        # we need to build the string separately for added tokens and byte-level tokens
        # cf. https://github.com/huggingface/transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        for token in filtered_tokens:
            if skip_special_tokens and\
                token != self.tokenizer.unk_token and\
                token in self.tokenizer.all_special_tokens:

                continue
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.tokenizer.convert_tokens_to_string(current_sub_text))
        text = " ".join(sub_texts)

        if clean_up_tokenization_spaces:
            clean_text = self.tokenizer.clean_up_tokenization(text)
            return clean_text
        else:
            return text

    def __decode_sentences(self, encoded_sentences: Union[str, List[str]]):
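        '''Decode generated token ids into sentences, dropping special tokens except <unk>, which is kept for later replacement.'''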
        if type(encoded_sentences) == str:
            encoded_sentences = [encoded_sentences]
        decoded_sentences = [self.__decode_ids_to_string_custom(i, skip_special_tokens=True) for i in encoded_sentences]
        return decoded_sentences
        
    def verbalise_sentence(self, inputs: Union[str, List[str]]):
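        '''Run generation and decoding for one or more prepared inputs; returns a single string when only one input is given, otherwise a list.'''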
        if type(inputs) == str:
            inputs = [inputs]
        
        gen_output = self.__generate_verbalisations_from_inputs(inputs)
        
        decoded_sentences = self.__decode_sentences(gen_output)

        if len(decoded_sentences) == 1:
            return decoded_sentences[0]
        else:
            return decoded_sentences

    def verbalise_triples(self, input_triples: Union[Dict[str, str], List[Dict[str, str]], List[List[Dict[str, str]]]]):
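        '''
        Verbalise triples given as a single dict, a list of dicts (one sentence per triple),
        or a list of lists of dicts (each inner list is verbalised as one combined sentence).
        Each triple dict must contain 'subject', 'predicate' and 'object' keys.
        '''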
        if type(input_triples) == dict:
            input_triples = [input_triples]

        verbalisation_inputs = []
        for triple in input_triples:
            if type(triple) == dict:
                assert 'subject' in triple
                assert 'predicate' in triple
                assert 'object' in triple
                verbalisation_inputs.append(
                    f'translate Graph to English: <H> {triple["subject"]} <R> {triple["predicate"]} <T> {triple["object"]}'
                )
            elif type(triple) == list:
                input_sentence = ['translate Graph to English:']
                for subtriple in triple:
                    assert 'subject' in subtriple
                    assert 'predicate' in subtriple
                    assert 'object' in subtriple
                    input_sentence.append(f'<H> {subtriple["subject"]}')
                    input_sentence.append(f'<R> {subtriple["predicate"]}')
                    input_sentence.append(f'<T> {subtriple["object"]}')
                verbalisation_inputs.append(
                    ' '.join(input_sentence)
                )

        return self.verbalise_sentence(verbalisation_inputs)
        
    def verbalise(self, input: Union[str, List, Dict]):
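        '''Dispatch to verbalise_sentence for string input(s), otherwise to verbalise_triples.'''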
        try:
            if (type(input) == str) or (type(input) == list and type(input[0]) == str):
                return self.verbalise_sentence(input)
            elif (type(input) == dict) or (type(input) == list and type(input[0]) == dict):
                return self.verbalise_triples(input)
            else:
                return self.verbalise_triples(input)
        except Exception:
            print(f'ERROR VERBALISING {input}')
            raise
                
    def add_label_to_unk_replacer(self, label: str):
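        '''
        Tokenise an entity label and, for every <unk> produced, store a regex-derived mapping
        from the surrounding token span back to the original label text, to be used later by
        replace_unks_on_sentence.
        '''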
        N = self.unk_char_replace_sliding_window_size
        self.unknowns.append({})
        
        # Some pre-processing of labels to normalise some characters
        if self.convert_some_japanese_characters:
            label = label.replace('(','(')
            label = label.replace(')',')')
            label = label.replace('〈','<')
            label = label.replace('/','/')
            label = label.replace('〉','>')        
        
        label_encoded = self.tokenizer.encode(label)
        label_tokens = self.tokenizer.convert_ids_to_tokens(label_encoded)
        
        # Here, we also remove </s> (eos) and <pad> tokens in the replacing key, because:
        # 1) When the whole label is all unk:
        #   label_token_to_string would be '<unk></s>', meaning the replacing key (which is the same) only replaces
        #   the <unk> if it appears at the end of the sentence, which is not the desired effect.
        #   But since this means ANY <unk> will be replaced by this, it would be good to only replace keys that are <unk>
        #   on the last replacing pass.
        # 2) On other cases, then the unk is in the label but not in its entirety, like in the start/end, it might
        #   involve the starting <pad> token or the ending <eos> token on the replacing key, again forcing the replacement
        #   to only happen if the label appears in the end of the sentence.
        label_tokens = [t for t in label_tokens if t not in [
            self.tokenizer.eos_token, self.tokenizer.pad_token
        ]]

        label_token_to_string = self.tokenizer.convert_tokens_to_string(label_tokens)
        unk_token_to_string = self.tokenizer.convert_tokens_to_string([self.tokenizer.unk_token])
                
        #print(label_encoded,label_tokens,label_token_to_string)
        
        match_unks_in_label = re.findall('(?:(?: )*<unk>(?: )*)+', label_token_to_string)
        if len(match_unks_in_label) > 0:
            # If the whole label is made of UNK
            if (match_unks_in_label[0]) == label_token_to_string:
                #print('Label is all unks')                    
                self.unknowns[-1][label_token_to_string.strip()] = label
            # Else, there should be non-UNK characters in the label
            else:
                #print('Label is NOT all unks')
                # Analyse the label with a sliding window of size N (N before, N ahead)
                for idx, token in enumerate(label_tokens):
                    idx_before = max(0,idx-N)
                    idx_ahead = min(len(label_tokens), idx+N+1)

                    # Found a UNK
                    if token == self.tokenizer.unk_token:
                        
                        # In case multiple UNK, exclude UNKs seen after this one, expand window to other side if possible
                        if len(match_unks_in_label) > 1:
                            #print(idx)
                            #print(label_tokens)
                            #print(label_tokens[idx_before:idx_ahead])
                            #print('HERE!')
                            # Reduce on the right, expanding on the left
                            while self.tokenizer.unk_token in label_tokens[idx+1:idx_ahead]:
                                idx_before = max(0,idx_before-1)
                                idx_ahead = min(idx+2, idx_ahead-1)
                                #print(label_tokens[idx_before:idx_ahead])
                            # Now just reduce on the left
                            while self.tokenizer.unk_token in label_tokens[idx_before:idx]:
                                idx_before = min(idx-1,idx_before+2)
                                #print(label_tokens[idx_before:idx_ahead])

                        span = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx_ahead])        
                        # First token of the label is UNK                        
                        if idx == 1 and label_tokens[0] == '▁':
                            #print('Label begins with unks')
                            to_replace = '^' + re.escape(span).replace(
                                    re.escape(unk_token_to_string),
                                    '.+?'
                                )
                            
                            replaced_span = re.search(
                                to_replace,
                                label
                            )[0]
                            self.unknowns[-1][span.strip()] = replaced_span
                        # Last token of the label is UNK (eos/pad tokens were already stripped from label_tokens above)
                        elif idx == len(label_tokens)-1:
                            #print('Label ends with unks')
                            pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])
                            pre_idx_unk_counts = pre_idx.count(unk_token_to_string)
                            to_replace = re.escape(span).replace(
                                    re.escape(unk_token_to_string),
                                    f'[^{re.escape(pre_idx)}]+?'
                                ) + '$'
                            
                            if pre_idx.strip() == '':
                                to_replace = to_replace.replace('[^]', r'(?<=\s)[^a-zA-Z0-9]')
                            
                            replaced_span = re.search(
                                to_replace,
                                label
                            )[0]
                            self.unknowns[-1][span.strip()] = replaced_span
                            
                        # A token in-between the label is UNK                            
                        else:
                            #print('Label has unks in the middle')
                            pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])

                            to_replace = re.escape(span).replace(
                                re.escape(unk_token_to_string),
                                f'[^{re.escape(pre_idx)}]+?'
                            )
                            #If there is nothing behind the ??, because it is in the middle but the previous token is also
                            #a ??, then we would end up with to_replace beginning with [^], which we can't have
                            if pre_idx.strip() == '':
                                to_replace = to_replace.replace('[^]', r'(?<=\s)[^a-zA-Z0-9]')
        
                            replaced_span = re.search(
                                to_replace,
                                label
                            )
                            
                            if replaced_span:
                                span = re.sub(r'\s([?.!",](?:\s|$))', r'\1', span.strip())
                                self.unknowns[-1][span] = replaced_span[0]  

    def replace_unks_on_sentence(self, sentence: str, loop_n: int = 3, empty_after: bool = False):
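        '''Replace <unk> tokens in a generated sentence using the span mappings collected by add_label_to_unk_replacer; optionally clear the stored mappings afterwards.'''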
        # Loop while the sentence still contains <unk>, at most loop_n times (3 by default), in case labels repeat
        while '<unk>' in sentence and loop_n > 0:
            loop_n -= 1
            for unknowns in self.unknowns:
                for k,v in unknowns.items():
                    # Leave to replace all-unk labels at the last pass
                    if k == '<unk>' and loop_n > 0:
                        continue
                    # In case it is because the first letter of the sentence has been uppercased
                    if not k in sentence and k[0] == k[0].lower() and k[0].upper() == sentence[0]:
                        k = k[0].upper() + k[1:]
                        v = v[0].upper() + v[1:]
                    # In case it is because a double space is found where it should not be
                    elif not k in sentence and len(re.findall(r'\s{2,}',k))>0:
                        k = re.sub(r'\s+', ' ', k)
                    #print(k,'/',v,'/',sentence)
                    sentence = sentence.replace(k.strip(),v.strip(),1)
                    #sentence = re.sub(k, v, sentence)
            # Removing final doublespaces
            sentence = re.sub(r'\s+', ' ', sentence).strip()
            # Removing spaces before punctuation
            sentence = re.sub(r'\s([?.!",](?:\s|$))', r'\1', sentence)
        if empty_after:
            self.unknowns = []
        return sentence

if __name__ == '__main__':

    verb_module = VerbModule()
    verbs = verb_module.verbalise('translate Graph to English: <H> World Trade Center <R> height <T> 200 meter <H> World Trade Center <R> is a <T> tower')
    print(verbs)
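
    # The same pair of triples can also be passed in structured form; this is a minimal
    # illustrative example and builds exactly the input string shown above:
    triple_verbs = verb_module.verbalise_triples([[
        {'subject': 'World Trade Center', 'predicate': 'height', 'object': '200 meter'},
        {'subject': 'World Trade Center', 'predicate': 'is a', 'object': 'tower'},
    ]])
    print(triple_verbs)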