import torch
import pickle
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


##### https://github.com/githubharald/CTCDecoder/blob/master/src/BeamSearch.py
class BeamEntry:
    "information about one single beam at specific time-step"
    def __init__(self):
        self.prTotal = 0        # blank and non-blank
        self.prNonBlank = 0     # non-blank
        self.prBlank = 0        # blank
        self.prText = 1         # LM score
        self.lmApplied = False  # flag if LM was already applied to this beam
        self.labeling = ()      # beam-labeling


class BeamState:
    "information about the beams at specific time-step"
    def __init__(self):
        self.entries = {}

    def norm(self):
        "length-normalise LM score"
        for (k, _) in self.entries.items():
            labelingLen = len(self.entries[k].labeling)
            self.entries[k].prText = self.entries[k].prText ** (1.0 / (labelingLen if labelingLen else 1.0))

    def sort(self):
        "return beam-labelings, sorted by probability"
        beams = [v for (_, v) in self.entries.items()]
        sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal * x.prText)
        return [x.labeling for x in sortedBeams]

    def wordsearch(self, classes, ignore_idx, beamWidth, dict_list):
        "return the text of the best beam found in dict_list, falling back to the overall best beam"
        beams = [v for (_, v) in self.entries.items()]
        sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal * x.prText)[:beamWidth]

        best_text = ''  # fallback so the return value is defined even without any beams
        for j, candidate in enumerate(sortedBeams):
            idx_list = candidate.labeling
            text = ''
            for i, l in enumerate(idx_list):
                # removing repeated characters and blank
                if l not in ignore_idx and (not (i > 0 and idx_list[i - 1] == idx_list[i])):
                    text += classes[l]

            if j == 0:
                best_text = text
            if text in dict_list:
                print('found text: ', text)
                best_text = text
                break
            else:
                print('not in dict: ', text)
        return best_text


def applyLM(parentBeam, childBeam, classes, lm):
    "calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars"
    if lm and not childBeam.lmApplied:
        c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')]  # first char
        c2 = classes[childBeam.labeling[-1]]  # second char
        lmFactor = 0.01  # influence of language model
        bigramProb = lm.getCharBigram(c1, c2) ** lmFactor  # probability of seeing first and second char next to each other
        childBeam.prText = parentBeam.prText * bigramProb  # probability of char sequence
        childBeam.lmApplied = True  # only apply LM once per beam entry


def addBeam(beamState, labeling):
    "add beam if it does not yet exist"
    if labeling not in beamState.entries:
        beamState.entries[labeling] = BeamEntry()
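# --- Illustrative sketch (not part of the original module) ---
# applyLM() only requires an object exposing getCharBigram(c1, c2) -> float,
# passed as the `lm` argument of ctcBeamSearch() below. The toy class here is
# an assumption made for illustration: it returns a uniform bigram probability,
# which exercises the LM code path without changing the beam ranking.
class _UniformBigramLM:
    "toy LM: every character pair gets the same bigram probability"
    def __init__(self, classes):
        self.prob = 1.0 / max(len(classes), 1)

    def getCharBigram(self, c1, c2):
        return self.prob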
def ctcBeamSearch(mat, classes, ignore_idx, lm, beamWidth=25, dict_list=[]):
    "beam search as described by the paper of Hwang et al. and the paper of Graves et al."
    # NOTE: this port keeps the blank at index 0 (the upstream implementation
    # used blankIdx = len(classes), i.e. blank at the last index)
    blankIdx = 0
    maxT, maxC = mat.shape

    # initialise beam state
    last = BeamState()
    labeling = ()
    last.entries[labeling] = BeamEntry()
    last.entries[labeling].prBlank = 1
    last.entries[labeling].prTotal = 1

    # go over all time-steps
    for t in range(maxT):
        curr = BeamState()

        # get beam-labelings of best beams
        bestLabelings = last.sort()[0:beamWidth]

        # go over best beams
        for labeling in bestLabelings:
            # probability of paths ending with a non-blank
            prNonBlank = 0
            # in case of non-empty beam
            if labeling:
                # probability of paths with repeated last char at the end
                prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]]

            # probability of paths ending with a blank
            prBlank = last.entries[labeling].prTotal * mat[t, blankIdx]

            # add beam at current time-step if needed
            addBeam(curr, labeling)

            # fill in data
            curr.entries[labeling].labeling = labeling
            curr.entries[labeling].prNonBlank += prNonBlank
            curr.entries[labeling].prBlank += prBlank
            curr.entries[labeling].prTotal += prBlank + prNonBlank
            curr.entries[labeling].prText = last.entries[labeling].prText  # beam-labeling not changed, therefore also LM score unchanged
            curr.entries[labeling].lmApplied = True  # LM already applied at previous time-step for this beam-labeling

            # extend current beam-labeling; with the blank at index 0, only the
            # non-blank classes 1..maxC-1 may extend a labeling
            for c in range(1, maxC):
                # add new char to current beam-labeling
                newLabeling = labeling + (c,)

                # if new labeling contains duplicate char at the end, only consider paths ending with a blank
                if labeling and labeling[-1] == c:
                    prNonBlank = mat[t, c] * last.entries[labeling].prBlank
                else:
                    prNonBlank = mat[t, c] * last.entries[labeling].prTotal

                # add beam at current time-step if needed
                addBeam(curr, newLabeling)

                # fill in data
                curr.entries[newLabeling].labeling = newLabeling
                curr.entries[newLabeling].prNonBlank += prNonBlank
                curr.entries[newLabeling].prTotal += prNonBlank

                # apply LM (a no-op when lm is None)
                applyLM(curr.entries[labeling], curr.entries[newLabeling], classes, lm)

        # set new beam state
        last = curr

    # normalise LM scores according to beam-labeling-length
    last.norm()

    if dict_list == []:
        bestLabeling = last.sort()[0]  # get most probable labeling
        # map labels to chars, removing repeated characters and blank
        res = ''
        for i, l in enumerate(bestLabeling):
            if l not in ignore_idx and (not (i > 0 and bestLabeling[i - 1] == bestLabeling[i])):
                res += classes[l]
    else:
        res = last.wordsearch(classes, ignore_idx, beamWidth, dict_list)
    return res
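# --- Illustrative sketch (not part of the original module) ---
# Minimal example of how ctcBeamSearch consumes a (T x C) matrix of
# per-time-step class probabilities. The alphabet, the matrix values and the
# helper name are invented for illustration only.
def _demo_ctc_beam_search():
    "decode a tiny 3-step matrix over [blank], 'a' and 'b'"
    classes = ['[blank]', 'a', 'b']
    mat = np.array([[0.1, 0.7, 0.2],   # t=0: 'a' most likely
                    [0.8, 0.1, 0.1],   # t=1: blank most likely
                    [0.1, 0.2, 0.7]])  # t=2: 'b' most likely
    # ignore_idx=[0] drops the blank; no LM, no dictionary
    return ctcBeamSearch(mat, classes, ignore_idx=[0], lm=None, beamWidth=5)
    # expected decoding: 'ab'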
#####

def consecutive(data, mode='first', stepsize=1):
    "split data into runs of consecutive values and keep the first or last element of each run"
    group = np.split(data, np.where(np.diff(data) != stepsize)[0] + 1)
    group = [item for item in group if len(item) > 0]

    if mode == 'first':
        result = [l[0] for l in group]
    elif mode == 'last':
        result = [l[-1] for l in group]
    return result


def word_segmentation(mat, separator_idx={'th': [1, 2], 'en': [3, 4]}, separator_idx_list=[1, 2, 3, 4]):
    result = []
    sep_list = []
    start_idx = 0
    sep_lang = ''  # initialised so an end separator without a preceding start separator cannot crash
    for sep_idx in separator_idx_list:
        if sep_idx % 2 == 0:
            mode = 'first'
        else:
            mode = 'last'
        a = consecutive(np.argwhere(mat == sep_idx).flatten(), mode)
        new_sep = [[item, sep_idx] for item in a]
        sep_list += new_sep
    sep_list = sorted(sep_list, key=lambda x: x[0])

    for sep in sep_list:
        for lang in separator_idx.keys():
            if sep[1] == separator_idx[lang][0]:  # start lang
                sep_lang = lang
                sep_start_idx = sep[0]
            elif sep[1] == separator_idx[lang][1]:  # end lang
                if sep_lang == lang:  # check if the last entry has the same start lang
                    new_sep_pair = [lang, [sep_start_idx + 1, sep[0] - 1]]
                    if sep_start_idx > start_idx:
                        result.append(['', [start_idx, sep_start_idx - 1]])
                    start_idx = sep[0] + 1
                    result.append(new_sep_pair)
                else:  # reset
                    sep_lang = ''

    if start_idx <= len(mat) - 1:
        result.append(['', [start_idx, len(mat) - 1]])
    return result
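# --- Illustrative sketch (not part of the original module) ---
# Example of word_segmentation on an invented greedy-decoded index sequence.
# With the default separators, indices 3/4 open and close an 'en' segment;
# 0 is blank and 5+ stand for ordinary characters. All values are made up.
def _demo_word_segmentation():
    "segment a toy index sequence into language-tagged spans"
    seq = np.array([5, 5, 3, 6, 6, 4, 5, 5])
    return word_segmentation(seq)
    # expected: [['', [0, 1]], ['en', [3, 4]], ['', [6, 7]]]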
class CTCLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character, separator_list={}, dict_pathlist={}):
        # character (str): set of the possible characters.
        dict_character = list(character)

        self.dict = {}
        for i, char in enumerate(dict_character):
            # NOTE: 0 is reserved for 'blank' token required by CTCLoss
            self.dict[char] = i + 1

        self.character = ['[blank]'] + dict_character  # dummy '[blank]' token for CTCLoss (index 0)

        self.separator_list = separator_list
        separator_char = []
        for lang, sep in separator_list.items():
            separator_char += sep
        # blank and separator tokens occupy the first indices and are dropped during decoding
        self.ignore_idx = [0] + [i + 1 for i, item in enumerate(separator_char)]

        dict_list = {}
        for lang, dict_path in dict_pathlist.items():
            with open(dict_path, "rb") as input_file:
                word_count = pickle.load(input_file)
            dict_list[lang] = word_count
        self.dict_list = dict_list

    def encode(self, text, batch_max_length=25):
        """convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
        output:
            text: concatenated text index for CTCLoss.
                  [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]
        text = ''.join(text)
        text = [self.dict[char] for char in text]

        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode_greedy(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]

            char_list = []
            for i in range(l):
                # removing repeated characters, blank and separator
                if t[i] not in self.ignore_idx and (not (i > 0 and t[i - 1] == t[i])):
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)

            texts.append(text)
            index += l
        return texts

    def decode_beamsearch(self, mat, beamWidth=5):
        texts = []
        for i in range(mat.shape[0]):
            t = ctcBeamSearch(mat[i], self.character, self.ignore_idx, None, beamWidth=beamWidth)
            texts.append(t)
        return texts

    def decode_wordbeamsearch(self, mat, beamWidth=5):
        texts = []
        argmax = np.argmax(mat, axis=2)

        for i in range(mat.shape[0]):
            words = word_segmentation(argmax[i])
            string = ''
            for word in words:
                matrix = mat[i, word[1][0]:word[1][1] + 1, :]
                if word[0] == '':
                    dict_list = []
                else:
                    dict_list = self.dict_list[word[0]]
                t = ctcBeamSearch(matrix, self.character, self.ignore_idx, None,
                                  beamWidth=beamWidth, dict_list=dict_list)
                string += t
            texts.append(string)
        return texts


class AttnLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character):
        # character (str): set of the possible characters.
        # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
        list_token = ['[GO]', '[s]']  # ['[s]', '[UNK]', '[PAD]', '[GO]']
        list_character = list(character)
        self.character = list_token + list_character

        self.dict = {}
        for i, char in enumerate(self.character):
            self.dict[char] = i

    def encode(self, text, batch_max_length=25):
        """ convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default
        output:
            text : the input of the attention decoder. [batch_size x (max_length+2)]
                   +1 for the [GO] token and +1 for the [s] token.
                   text[:, 0] is the [GO] token and text is padded with [GO] tokens after the [s] token.
            length : the length of the attention decoder output, which also counts the [s] token,
                     e.g. [3, 7, ....]. [batch_size]
        """
        length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        # batch_max_length = max(length) # this is not allowed for multi-gpu setting
        batch_max_length += 1  # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for i, t in enumerate(text):
            txt = list(t)
            txt.append('[s]')
            txt = [self.dict[char] for char in txt]
            batch_text[i][1:1 + len(txt)] = torch.LongTensor(txt)  # batch_text[:, 0] = [GO] token
        return (batch_text.to(device), torch.IntTensor(length).to(device))

    def decode(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        for index, l in enumerate(length):
            text = ''.join([self.character[i] for i in text_index[index, :]])
            texts.append(text)
        return texts


class Averager(object):
    """Compute average for torch.Tensor, used for loss average."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res
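# --- Illustrative sketch (not part of the original module) ---
# A minimal round trip through CTCLabelConverter. The alphabet and labels are
# invented; in practice decode_greedy receives the raw per-time-step argmax
# indices, so feeding the encoded indices back in only demonstrates the
# index-to-text mapping.
def _demo_ctc_label_converter():
    "encode two labels, then map the indices straight back to text"
    converter = CTCLabelConverter('abc')
    text, length = converter.encode(['ab', 'cab'])
    # text   -> tensor([1, 2, 3, 1, 2]) (indices are 1-based; 0 is blank)
    # length -> tensor([2, 3])
    return converter.decode_greedy(text, length)
    # expected: ['ab', 'cab']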