import torch
import pickle
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
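
# Illustrative usage sketch (added, not upstream): AttrDict mirrors the dict into
# __dict__, so keys double as attributes. The option names are hypothetical.
def _attrdict_example():
    opt = AttrDict({'imgH': 64, 'beamWidth': 5})
    return opt.imgH == opt['imgH']  # attribute and key access hit the same value
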
##### https://github.com/githubharald/CTCDecoder/blob/master/src/BeamSearch.py
class BeamEntry:
    "information about one single beam at a specific time-step"
    def __init__(self):
        self.prTotal = 0        # blank and non-blank
        self.prNonBlank = 0     # non-blank
        self.prBlank = 0        # blank
        self.prText = 1         # LM score
        self.lmApplied = False  # flag if LM was already applied to this beam
        self.labeling = ()      # beam-labeling
class BeamState:
    "information about the beams at a specific time-step"
    def __init__(self):
        self.entries = {}

    def norm(self):
        "length-normalise LM score"
        for (k, _) in self.entries.items():
            labelingLen = len(self.entries[k].labeling)
            self.entries[k].prText = self.entries[k].prText ** (1.0 / (labelingLen if labelingLen else 1.0))

    def sort(self):
        "return beam-labelings, sorted by probability"
        beams = [v for (_, v) in self.entries.items()]
        sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal*x.prText)
        return [x.labeling for x in sortedBeams]

    def wordsearch(self, classes, ignore_idx, beamWidth, dict_list):
        "return the best beam text, preferring the first candidate found in the dictionary"
        beams = [v for (_, v) in self.entries.items()]
        sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal*x.prText)[:beamWidth]
        for j, candidate in enumerate(sortedBeams):
            idx_list = candidate.labeling
            text = ''
            for i, l in enumerate(idx_list):
                if l not in ignore_idx and (not (i > 0 and idx_list[i - 1] == idx_list[i])):  # remove repeated characters and blank
                    text += classes[l]
            if j == 0: best_text = text
            if text in dict_list:
                print('found text: ', text)
                best_text = text
                break
            else:
                print('not in dict: ', text)
        return best_text
def applyLM(parentBeam, childBeam, classes, lm):
    "calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars"
    if lm and not childBeam.lmApplied:
        c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')]  # first char
        c2 = classes[childBeam.labeling[-1]]  # second char
        lmFactor = 0.01  # influence of language model
        bigramProb = lm.getCharBigram(c1, c2) ** lmFactor  # probability of seeing first and second char next to each other
        childBeam.prText = parentBeam.prText * bigramProb  # probability of char sequence
        childBeam.lmApplied = True  # only apply LM once per beam entry
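
# Minimal sketch (added, not upstream) of a language model that applyLM could
# accept: it only requires an object exposing getCharBigram(c1, c2) -> probability.
# This uniform bigram model is illustrative and deliberately trivial.
class _UniformBigramLM:
    def __init__(self, classes):
        self.classes = classes

    def getCharBigram(self, first, second):
        # every character pair is treated as equally likely
        return 1.0 / max(len(self.classes), 1)
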
def addBeam(beamState, labeling):
    "add beam if it does not yet exist"
    if labeling not in beamState.entries:
        beamState.entries[labeling] = BeamEntry()
def ctcBeamSearch(mat, classes, ignore_idx, lm, beamWidth=25, dict_list=[]):
    "beam search as described by Hwang et al. and Graves et al."
    #blankIdx = len(classes)
    blankIdx = 0
    maxT, maxC = mat.shape

    # initialise beam state
    last = BeamState()
    labeling = ()
    last.entries[labeling] = BeamEntry()
    last.entries[labeling].prBlank = 1
    last.entries[labeling].prTotal = 1

    # go over all time-steps
    for t in range(maxT):
        curr = BeamState()
        # get beam-labelings of best beams
        bestLabelings = last.sort()[0:beamWidth]
        # go over best beams
        for labeling in bestLabelings:
            # probability of paths ending with a non-blank
            prNonBlank = 0
            # in case of non-empty beam
            if labeling:
                # probability of paths with repeated last char at the end
                prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]]

            # probability of paths ending with a blank
            prBlank = (last.entries[labeling].prTotal) * mat[t, blankIdx]

            # add beam at current time-step if needed
            addBeam(curr, labeling)

            # fill in data
            curr.entries[labeling].labeling = labeling
            curr.entries[labeling].prNonBlank += prNonBlank
            curr.entries[labeling].prBlank += prBlank
            curr.entries[labeling].prTotal += prBlank + prNonBlank
            curr.entries[labeling].prText = last.entries[labeling].prText  # beam-labeling not changed, therefore LM score is unchanged from the previous time-step
            curr.entries[labeling].lmApplied = True  # LM already applied at previous time-step for this beam-labeling

            # extend current beam-labeling with every non-blank char (blank is index 0)
            for c in range(1, maxC):
                # add new char to current beam-labeling
                newLabeling = labeling + (c,)

                # if new labeling contains duplicate char at the end, only consider paths ending with a blank
                if labeling and labeling[-1] == c:
                    prNonBlank = mat[t, c] * last.entries[labeling].prBlank
                else:
                    prNonBlank = mat[t, c] * last.entries[labeling].prTotal

                # add beam at current time-step if needed
                addBeam(curr, newLabeling)

                # fill in data
                curr.entries[newLabeling].labeling = newLabeling
                curr.entries[newLabeling].prNonBlank += prNonBlank
                curr.entries[newLabeling].prTotal += prNonBlank

                # apply LM
                #applyLM(curr.entries[labeling], curr.entries[newLabeling], classes, lm)

        # set new beam state
        last = curr

    # normalise LM scores according to beam-labeling-length
    last.norm()

    # sort by probability
    #bestLabeling = last.sort()[0] # get most probable labeling
    # map labels to chars
    #res = ''
    #for idx,l in enumerate(bestLabeling):
    #    if l not in ignore_idx and (not (idx > 0 and bestLabeling[idx - 1] == bestLabeling[idx])): # removing repeated characters and blank.
    #        res += classes[l]

    if dict_list == []:
        bestLabeling = last.sort()[0]  # get most probable labeling
        res = ''
        for i, l in enumerate(bestLabeling):
            if l not in ignore_idx and (not (i > 0 and bestLabeling[i - 1] == bestLabeling[i])):  # remove repeated characters and blank
                res += classes[l]
    else:
        res = last.wordsearch(classes, ignore_idx, beamWidth, dict_list)
    return res
#####
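
# Illustrative call sketch (added for documentation, not part of the upstream
# code): ctcBeamSearch expects a (T, C) matrix of per-time-step class
# probabilities whose column 0 is the CTC blank, and a class list aligned with
# those columns. The toy distribution below is invented purely to show the
# calling convention; no particular output is claimed.
def _ctc_beam_search_example():
    classes = ['[blank]', 'a', 'b']        # index 0 is the blank token
    mat = np.array([[0.2, 0.7, 0.1],       # t=0: 'a' most likely
                    [0.6, 0.3, 0.1],       # t=1: blank most likely
                    [0.2, 0.7, 0.1]])      # t=2: 'a' most likely
    return ctcBeamSearch(mat, classes, ignore_idx=[0], lm=None, beamWidth=5)
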
def consecutive(data, mode='first', stepsize=1):
    "collapse each run of consecutive values (step == stepsize) to its first or last element"
    group = np.split(data, np.where(np.diff(data) != stepsize)[0]+1)
    group = [item for item in group if len(item) > 0]
    if mode == 'first': result = [l[0] for l in group]
    elif mode == 'last': result = [l[-1] for l in group]
    return result
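
# Quick illustration (added, not upstream): consecutive() keeps one value per
# run of consecutive integers, either the first or the last element of each run.
def _consecutive_example():
    data = np.array([2, 3, 4, 9, 10, 15])
    firsts = consecutive(data, mode='first')  # -> [2, 9, 15]
    lasts = consecutive(data, mode='last')    # -> [4, 10, 15]
    return firsts, lasts
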
def word_segmentation(mat, separator_idx={'th': [1, 2], 'en': [3, 4]}, separator_idx_list=[1, 2, 3, 4]):
    result = []
    sep_list = []
    start_idx = 0
    sep_lang = ''
    for sep_idx in separator_idx_list:
        if sep_idx % 2 == 0: mode = 'first'
        else: mode = 'last'
        a = consecutive(np.argwhere(mat == sep_idx).flatten(), mode)
        new_sep = [[item, sep_idx] for item in a]
        sep_list += new_sep
    sep_list = sorted(sep_list, key=lambda x: x[0])

    for sep in sep_list:
        for lang in separator_idx.keys():
            if sep[1] == separator_idx[lang][0]:  # start-of-lang separator
                sep_lang = lang
                sep_start_idx = sep[0]
            elif sep[1] == separator_idx[lang][1]:  # end-of-lang separator
                if sep_lang == lang:  # check that the last start separator belongs to the same lang
                    new_sep_pair = [lang, [sep_start_idx+1, sep[0]-1]]
                    if sep_start_idx > start_idx:
                        result.append(['', [start_idx, sep_start_idx-1]])
                    start_idx = sep[0]+1
                    result.append(new_sep_pair)
                else:  # reset
                    sep_lang = ''

    if start_idx <= len(mat)-1:
        result.append(['', [start_idx, len(mat)-1]])
    return result
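
# Illustrative sketch (added for documentation): word_segmentation() splits a
# greedy (argmax) index sequence into language-tagged spans using the start/end
# separator indices above (1/2 for 'th', 3/4 for 'en' by default). The index
# values below are invented just for the demo.
def _word_segmentation_example():
    # 3 ... 4 brackets an English word; everything outside separators is untagged.
    seq = np.array([3, 7, 8, 9, 4, 0, 5, 6])
    return word_segmentation(seq)  # -> [['en', [1, 3]], ['', [5, 7]]]
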
class CTCLabelConverter(object):
    """ Convert between text-label and text-index """

    #def __init__(self, character, separator = []):
    def __init__(self, character, separator_list={}, dict_pathlist={}):
        # character (str): set of the possible characters.
        dict_character = list(character)
        #special_character = ['\xa2', '\xa3', '\xa4', '\xa5']
        #self.separator_char = special_character[:len(separator)]

        self.dict = {}
        #for i, char in enumerate(self.separator_char + dict_character):
        for i, char in enumerate(dict_character):
            # NOTE: 0 is reserved for 'blank' token required by CTCLoss
            self.dict[char] = i + 1

        self.character = ['[blank]'] + dict_character  # dummy '[blank]' token for CTCLoss (index 0)
        #self.character = ['[blank]'] + self.separator_char + dict_character # dummy '[blank]' token for CTCLoss (index 0)
        self.separator_list = separator_list
        separator_char = []
        for lang, sep in separator_list.items():
            separator_char += sep
        self.ignore_idx = [0] + [i+1 for i, item in enumerate(separator_char)]

        dict_list = {}
        for lang, dict_path in dict_pathlist.items():
            with open(dict_path, "rb") as input_file:
                word_count = pickle.load(input_file)
            dict_list[lang] = word_count
        self.dict_list = dict_list
    def encode(self, text, batch_max_length=25):
        """convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
        output:
            text: concatenated text index for CTCLoss.
                  [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]
        text = ''.join(text)
        text = [self.dict[char] for char in text]

        return (torch.IntTensor(text), torch.IntTensor(length))
    def decode_greedy(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]

            char_list = []
            for i in range(l):
                #if (t[i] != 0) and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank (and separator).
                if t[i] not in self.ignore_idx and (not (i > 0 and t[i - 1] == t[i])):  # removing repeated characters and blank (and separator).
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)

            texts.append(text)
            index += l
        return texts
    def decode_beamsearch(self, mat, beamWidth=5):
        texts = []
        for i in range(mat.shape[0]):
            t = ctcBeamSearch(mat[i], self.character, self.ignore_idx, None, beamWidth=beamWidth)
            texts.append(t)
        return texts
    def decode_wordbeamsearch(self, mat, beamWidth=5):
        texts = []
        argmax = np.argmax(mat, axis=2)

        for i in range(mat.shape[0]):
            words = word_segmentation(argmax[i])
            string = ''
            for word in words:
                matrix = mat[i, word[1][0]:word[1][1]+1, :]
                if word[0] == '': dict_list = []
                else: dict_list = self.dict_list[word[0]]
                t = ctcBeamSearch(matrix, self.character, self.ignore_idx, None, beamWidth=beamWidth, dict_list=dict_list)
                string += t
            texts.append(string)
        return texts
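
# Illustrative round trip (added, not upstream): CTCLabelConverter.encode() maps
# labels to a flat index tensor for CTCLoss, and decode_greedy() collapses
# repeats and blanks back into text. The two-character alphabet is a toy example.
def _ctc_converter_example():
    converter = CTCLabelConverter('ab')
    indices, lengths = converter.encode(['ab', 'ba'])
    # a real greedy prediction would come from argmax over the model output;
    # here the encoded labels are fed back in just to show the decoding call.
    return converter.decode_greedy(indices, lengths)  # -> ['ab', 'ba']
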
class AttnLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character):
        # character (str): set of the possible characters.
        # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
        list_token = ['[GO]', '[s]']  # ['[s]','[UNK]','[PAD]','[GO]']
        list_character = list(character)
        self.character = list_token + list_character

        self.dict = {}
        for i, char in enumerate(self.character):
            # print(i, char)
            self.dict[char] = i
    def encode(self, text, batch_max_length=25):
        """ convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default
        output:
            text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token.
                text[:, 0] is [GO] token and text is padded with [GO] token after [s] token.
            length : the length of output of attention decoder, which also counts the [s] token. [3, 7, ....] [batch_size]
        """
        length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        # batch_max_length = max(length) # this is not allowed for multi-gpu setting
        batch_max_length += 1
        # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for i, t in enumerate(text):
            txt = list(t)  # use a separate name to avoid shadowing the input list while iterating over it
            txt.append('[s]')
            txt = [self.dict[char] for char in txt]
            batch_text[i][1:1 + len(txt)] = torch.LongTensor(txt)  # batch_text[:, 0] = [GO] token
        return (batch_text.to(device), torch.IntTensor(length).to(device))
    def decode(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        for index, l in enumerate(length):
            text = ''.join([self.character[i] for i in text_index[index, :]])
            texts.append(text)
        return texts
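
# Illustrative round trip (added, not upstream): AttnLabelConverter.encode()
# prepends [GO], appends [s] and pads with [GO]; decode() maps indices back to
# raw symbols, so the caller is expected to cut the string at the first '[s]'.
def _attn_converter_example():
    converter = AttnLabelConverter('ab')
    batch_text, lengths = converter.encode(['ab'], batch_max_length=5)
    decoded = converter.decode(batch_text, lengths)[0]
    return decoded.split('[s]')[0].replace('[GO]', '')  # -> 'ab'
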
class Averager(object):
    """Compute average for torch.Tensor, used for loss average."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res
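
# Illustrative sketch (added, not upstream): Averager accumulates element counts
# and sums across batches, e.g. to report a running mean of per-batch losses.
def _averager_example():
    avg = Averager()
    avg.add(torch.tensor([2.0, 4.0]))  # two elements, sum 6
    avg.add(torch.tensor([6.0]))       # one element, sum 6
    return avg.val()                   # -> 4.0 (12 / 3, as a 0-dim tensor)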