import torch
import pickle
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class AttrDict(dict):
    """Dictionary whose keys are also accessible as attributes."""
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

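# Usage sketch (illustrative only, not part of the original module): AttrDict lets
# configuration values be read either as keys or as attributes.
#   opt = AttrDict(num_class=97, batch_size=192)
#   opt.num_class      # 97
#   opt['batch_size']  # 192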

##### https://github.com/githubharald/CTCDecoder/blob/master/src/BeamSearch.py
class BeamEntry:
    "information about one single beam at specific time-step"
    def __init__(self):
        self.prTotal = 0        # blank and non-blank
        self.prNonBlank = 0     # non-blank
        self.prBlank = 0        # blank
        self.prText = 1         # LM score
        self.lmApplied = False  # flag if LM was already applied to this beam
        self.labeling = ()      # beam-labeling

class BeamState:
    "information about the beams at specific time-step"
    def __init__(self):
        self.entries = {}

    def norm(self):
        "length-normalise LM score"
        for (k, _) in self.entries.items():
            labelingLen = len(self.entries[k].labeling)
            self.entries[k].prText = self.entries[k].prText ** (1.0 / (labelingLen if labelingLen else 1.0))

    def sort(self):
        "return beam-labelings, sorted by probability"
        beams = [v for (_, v) in self.entries.items()]
        sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal * x.prText)
        return [x.labeling for x in sortedBeams]

    def wordsearch(self, classes, ignore_idx, beamWidth, dict_list):
        "return the most probable beam whose decoded text is in dict_list, falling back to the top beam"
        beams = [v for (_, v) in self.entries.items()]
        sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal * x.prText)[:beamWidth]

        best_text = ''  # fallback in case there are no beams at all
        for j, candidate in enumerate(sortedBeams):
            idx_list = candidate.labeling
            text = ''
            for i, l in enumerate(idx_list):
                if l not in ignore_idx and (not (i > 0 and idx_list[i - 1] == idx_list[i])):  # remove repeated characters and blank
                    text += classes[l]

            if j == 0:  # most probable beam is the default result
                best_text = text
            if text in dict_list:
                print('found text: ', text)
                best_text = text
                break
            else:
                print('not in dict: ', text)
        return best_text

def applyLM(parentBeam, childBeam, classes, lm):
    "calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars"
    if lm and not childBeam.lmApplied:
        c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')]  # first char
        c2 = classes[childBeam.labeling[-1]]  # second char
        lmFactor = 0.01  # influence of language model
        bigramProb = lm.getCharBigram(c1, c2) ** lmFactor  # probability of seeing first and second char next to each other
        childBeam.prText = parentBeam.prText * bigramProb  # probability of char sequence
        childBeam.lmApplied = True  # only apply LM once per beam entry

def addBeam(beamState, labeling):
    "add beam if it does not yet exist"
    if labeling not in beamState.entries:
        beamState.entries[labeling] = BeamEntry()

def ctcBeamSearch(mat, classes, ignore_idx, lm, beamWidth=25, dict_list=[]):
    "beam search as described by the paper of Hwang et al. and the paper of Graves et al."
    blankIdx = 0  # index 0 is reserved for the CTC blank
    maxT, maxC = mat.shape

    # initialise beam state
    last = BeamState()
    labeling = ()
    last.entries[labeling] = BeamEntry()
    last.entries[labeling].prBlank = 1
    last.entries[labeling].prTotal = 1

    # go over all time-steps
    for t in range(maxT):
        curr = BeamState()

        # get beam-labelings of best beams
        bestLabelings = last.sort()[0:beamWidth]

        # go over best beams
        for labeling in bestLabelings:
            # probability of paths ending with a non-blank
            prNonBlank = 0
            # in case of non-empty beam
            if labeling:
                # probability of paths with repeated last char at the end
                prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]]

            # probability of paths ending with a blank
            prBlank = last.entries[labeling].prTotal * mat[t, blankIdx]

            # add beam at current time-step if needed
            addBeam(curr, labeling)

            # fill in data
            curr.entries[labeling].labeling = labeling
            curr.entries[labeling].prNonBlank += prNonBlank
            curr.entries[labeling].prBlank += prBlank
            curr.entries[labeling].prTotal += prBlank + prNonBlank
            curr.entries[labeling].prText = last.entries[labeling].prText  # beam-labeling not changed, so LM score is unchanged from the previous time-step
            curr.entries[labeling].lmApplied = True  # LM already applied at previous time-step for this beam-labeling
            # extend current beam-labeling with every non-blank character
            # (index 0 is the blank, so candidate characters are 1 .. maxC-1)
            for c in range(1, maxC):
                # add new char to current beam-labeling
                newLabeling = labeling + (c,)

                # if new labeling contains duplicate char at the end, only consider paths ending with a blank
                if labeling and labeling[-1] == c:
                    prNonBlank = mat[t, c] * last.entries[labeling].prBlank
                else:
                    prNonBlank = mat[t, c] * last.entries[labeling].prTotal

                # add beam at current time-step if needed
                addBeam(curr, newLabeling)

                # fill in data
                curr.entries[newLabeling].labeling = newLabeling
                curr.entries[newLabeling].prNonBlank += prNonBlank
                curr.entries[newLabeling].prTotal += prNonBlank

                # apply LM
                #applyLM(curr.entries[labeling], curr.entries[newLabeling], classes, lm)
        # set new beam state
        last = curr

    # normalise LM scores according to beam-labeling-length
    last.norm()

    if dict_list == []:
        bestLabeling = last.sort()[0]  # get most probable labeling
        # map labels to chars
        res = ''
        for i, l in enumerate(bestLabeling):
            if l not in ignore_idx and (not (i > 0 and bestLabeling[i - 1] == bestLabeling[i])):  # remove repeated characters and blank
                res += classes[l]
    else:
        res = last.wordsearch(classes, ignore_idx, beamWidth, dict_list)
    return res
#####
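
# Usage sketch (illustrative only; the class list and probabilities below are made up).
# `mat` is a [T x C] matrix of per-time-step class probabilities with column 0 holding
# the CTC blank; ignore_idx lists the indices to drop while collapsing repeats.
#   probs = np.array([[0.1, 0.8, 0.1],
#                     [0.7, 0.2, 0.1],
#                     [0.1, 0.1, 0.8]])
#   ctcBeamSearch(probs, ['[blank]', 'a', 'b'], ignore_idx=[0], lm=None, beamWidth=5)
#   # should decode to 'ab'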

def consecutive(data, mode='first', stepsize=1):
    "collapse runs of consecutive values, keeping the first or last index of each run"
    group = np.split(data, np.where(np.diff(data) != stepsize)[0] + 1)
    group = [item for item in group if len(item) > 0]
    if mode == 'first':
        result = [l[0] for l in group]
    elif mode == 'last':
        result = [l[-1] for l in group]
    return result
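
# Example (illustrative): positions [3, 4, 5, 9, 10] form two consecutive runs.
#   consecutive(np.array([3, 4, 5, 9, 10]), mode='first')  # [3, 9]
#   consecutive(np.array([3, 4, 5, 9, 10]), mode='last')   # [5, 10]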

def word_segmentation(mat, separator_idx={'th': [1, 2], 'en': [3, 4]}, separator_idx_list=[1, 2, 3, 4]):
    "split a sequence of predicted class indices into language-tagged segments using separator tokens"
    result = []
    sep_list = []
    start_idx = 0
    sep_lang = ''  # language of the pending start separator, if any
    for sep_idx in separator_idx_list:
        if sep_idx % 2 == 0:
            mode = 'first'
        else:
            mode = 'last'
        a = consecutive(np.argwhere(mat == sep_idx).flatten(), mode)
        new_sep = [[item, sep_idx] for item in a]
        sep_list += new_sep
    sep_list = sorted(sep_list, key=lambda x: x[0])

    for sep in sep_list:
        for lang in separator_idx.keys():
            if sep[1] == separator_idx[lang][0]:  # start separator of this language
                sep_lang = lang
                sep_start_idx = sep[0]
            elif sep[1] == separator_idx[lang][1]:  # end separator of this language
                if sep_lang == lang:  # check that the pending start separator has the same language
                    new_sep_pair = [lang, [sep_start_idx + 1, sep[0] - 1]]
                    if sep_start_idx > start_idx:
                        result.append(['', [start_idx, sep_start_idx - 1]])
                    start_idx = sep[0] + 1
                    result.append(new_sep_pair)
                else:  # reset
                    sep_lang = ''

    if start_idx <= len(mat) - 1:
        result.append(['', [start_idx, len(mat) - 1]])
    return result
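
# Example (illustrative): with the default separators, indices 1/2 open and close a
# Thai segment. For the index sequence below, frames 3-4 are tagged 'th' and the
# remaining frames fall back to the default (empty-string) language.
#   word_segmentation(np.array([5, 5, 1, 6, 7, 2, 8, 8]))
#   # [['', [0, 1]], ['th', [3, 4]], ['', [6, 7]]]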

class CTCLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character, separator_list={}, dict_pathlist={}):
        # character (str): set of the possible characters.
        dict_character = list(character)

        self.dict = {}
        for i, char in enumerate(dict_character):
            # NOTE: 0 is reserved for 'blank' token required by CTCLoss
            self.dict[char] = i + 1

        self.character = ['[blank]'] + dict_character  # dummy '[blank]' token for CTCLoss (index 0)

        self.separator_list = separator_list
        separator_char = []
        for lang, sep in separator_list.items():
            separator_char += sep
        self.ignore_idx = [0] + [i + 1 for i, item in enumerate(separator_char)]

        # load pickled per-language word lists used by word beam search
        dict_list = {}
        for lang, dict_path in dict_pathlist.items():
            with open(dict_path, "rb") as input_file:
                word_count = pickle.load(input_file)
            dict_list[lang] = word_count
        self.dict_list = dict_list

    def encode(self, text, batch_max_length=25):
        """convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
        output:
            text: concatenated text index for CTCLoss.
                  [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]
        text = ''.join(text)
        text = [self.dict[char] for char in text]

        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode_greedy(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]

            char_list = []
            for i in range(l):
                if t[i] not in self.ignore_idx and (not (i > 0 and t[i - 1] == t[i])):  # removing repeated characters, blank and separator
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)

            texts.append(text)
            index += l
        return texts

    def decode_beamsearch(self, mat, beamWidth=5):
        texts = []
        for i in range(mat.shape[0]):
            t = ctcBeamSearch(mat[i], self.character, self.ignore_idx, None, beamWidth=beamWidth)
            texts.append(t)
        return texts

    def decode_wordbeamsearch(self, mat, beamWidth=5):
        texts = []
        argmax = np.argmax(mat, axis=2)

        for i in range(mat.shape[0]):
            words = word_segmentation(argmax[i])
            string = ''
            for word in words:
                matrix = mat[i, word[1][0]:word[1][1] + 1, :]
                if word[0] == '':
                    dict_list = []
                else:
                    dict_list = self.dict_list[word[0]]
                t = ctcBeamSearch(matrix, self.character, self.ignore_idx, None, beamWidth=beamWidth, dict_list=dict_list)
                string += t
            texts.append(string)
        return texts
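
# Usage sketch (illustrative; the character set and `model` are assumptions, not part
# of this module). Greedy decoding expects the flattened per-time-step argmax indices
# plus one length entry per sample.
#   converter = CTCLabelConverter('abc')
#   text, length = converter.encode(['ab', 'c'])   # indices [1, 2, 3], lengths [2, 1]
#   preds = model(image)                            # hypothetical output, shape [batch, T, num_classes]
#   preds_index = preds.argmax(2).view(-1)
#   preds_size = torch.IntTensor([preds.size(1)] * preds.size(0))
#   converter.decode_greedy(preds_index, preds_size)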

class AttnLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character):
        # character (str): set of the possible characters.
        # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
        list_token = ['[GO]', '[s]']  # ['[s]','[UNK]','[PAD]','[GO]']
        list_character = list(character)
        self.character = list_token + list_character

        self.dict = {}
        for i, char in enumerate(self.character):
            self.dict[char] = i

    def encode(self, text, batch_max_length=25):
        """ convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default
        output:
            text: the input of the attention decoder. [batch_size x (max_length+2)]
                  +1 for the [GO] token and +1 for the [s] token.
                  text[:, 0] is the [GO] token and text is padded with [GO] tokens after the [s] token.
            length: the length of the attention decoder output, which also counts the [s] token. [3, 7, ....] [batch_size]
        """
        length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        # batch_max_length = max(length)  # this is not allowed for multi-gpu setting
        batch_max_length += 1
        # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for i, t in enumerate(text):
            token_list = list(t)
            token_list.append('[s]')
            token_list = [self.dict[char] for char in token_list]
            batch_text[i][1:1 + len(token_list)] = torch.LongTensor(token_list)  # batch_text[:, 0] = [GO] token
        return (batch_text.to(device), torch.IntTensor(length).to(device))

    def decode(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        for index, l in enumerate(length):
            text = ''.join([self.character[i] for i in text_index[index, :]])
            texts.append(text)
        return texts
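
# Usage sketch (illustrative character set): '[GO]' maps to index 0, '[s]' to 1, then
# the characters in order, so 'a' -> 2, 'b' -> 3, 'c' -> 4.
#   converter = AttnLabelConverter('abc')
#   text, length = converter.encode(['ab'])   # text[0] starts [0, 2, 3, 1, 0, ...], length [3]
#   converter.decode(text, length)             # raw strings; callers usually cut at the first '[s]'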

class Averager(object):
    """Compute average for torch.Tensor, used for loss average."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res
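
# Minimal smoke test (not part of the original module); the values are made up for
# illustration. Averager accumulates element counts and sums across tensors.
if __name__ == '__main__':
    avg = Averager()
    avg.add(torch.tensor([2.0, 4.0]))
    avg.add(torch.tensor([6.0]))
    print(avg.val())  # tensor(4.)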