import torch
from .networks import *
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-time-step linear projection.

    Args:
        nIn: Feature size of each input time step.
        nHidden: Hidden size of the LSTM, per direction.
        nOut: Feature size of each output time step.
    """

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # The LSTM concatenates the forward and backward hidden states, so
        # the projection sees 2 * nHidden features per time step.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Run the LSTM over ``input`` and project every time step.

        Args:
            input: Sequence-first tensor ``[T, b, nIn]`` (``nn.LSTM`` is
                built with its default ``batch_first=False``).

        Returns:
            Tensor ``[T, b, nOut]``.
        """
        recurrent, _ = self.rnn(input)
        # nn.Linear operates on the last dimension and keeps all leading
        # ones, so flattening to [T * b, h] and reshaping back is not needed.
        return self.embedding(recurrent)
class CRNN(nn.Module):
    """Convolutional text recogniser producing per-column class scores.

    A seven-layer CNN reduces a single-channel image to a feature map of
    height 1. Each remaining column is then classified by a linear layer
    (or by two stacked ``BidirectionalLSTM`` modules when ``use_rnn`` is
    switched on).

    Args:
        args: Configuration object; this class reads ``args.resolution``
            and ``args.vocab_size``.
        leakyRelu: Use ``LeakyReLU(0.2)`` instead of ``ReLU`` after every
            convolution.
    """

    def __init__(self, args, leakyRelu=False):
        super(CRNN, self).__init__()
        # Keep the caller's config; it must not be overwritten, because
        # ``resolution`` and ``vocab_size`` are read from it below.
        self.args = args
        # NOTE(review): the 'OCR' literal used to be chained into the args
        # assignment; it is presumably a model name tag -- confirm.
        self.name = 'OCR'
        self.add_noise = False
        self.noise_fac = torch.distributions.Normal(
            loc=torch.tensor([0.]), scale=torch.tensor([0.2]))
        #assert opt.imgH % 16 == 0, 'imgH has to be a multiple of 16'

        # Per-layer kernel size, padding, stride and output channels.
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()
        nh = 256
        dealwith_lossnone = False  # whether to zero NaN gradients via a hook

        def convRelu(i, batchNormalization=False):
            """Append conv layer ``i`` (+ optional batch norm) and its activation."""
            nIn = 1 if i == 0 else nm[i - 1]
            nOut = nm[i]
            # NOTE(review): the 'conv{i}' module name is inferred from the
            # 'batchnorm{i}' / 'relu{i}' naming pattern -- confirm against
            # existing checkpoints.
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # The channel x height x width notes are carried over from the
        # original comments and appear to assume a 32x128 input -- confirm.
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        # Stride (2, 1): halve the height but keep the width resolution.
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        if self.args.resolution == 63:
            # Extra height reduction for the taller input resolution.
            cnn.add_module('pooling{0}'.format(3),
                           nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(5)
        cnn.add_module('pooling{0}'.format(4),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.use_rnn = False
        if self.use_rnn:
            self.rnn = nn.Sequential(
                BidirectionalLSTM(512, nh, nh),
                BidirectionalLSTM(nh, nh, self.args.vocab_size))
        else:
            self.linear = nn.Linear(512, self.args.vocab_size)

        # Optionally sanitise gradients flowing back through this module.
        if dealwith_lossnone:
            self.register_backward_hook(self.backward_hook)

        self.device = torch.device('cuda:{}'.format(0))
        self.init = 'N02'
        # Initialize weights. ``init_weights`` comes from ``.networks``; its
        # return value was only ever bound to the local name ``self``, which
        # has no effect, so it is simply called here.
        init_weights(self, self.init)

    def forward(self, input):
        """Compute per-column class scores for a batch of images.

        Args:
            input: Image batch ``[b, 1, H, W]``; ``H`` must be such that the
                CNN output has height 1.

        Returns:
            Tensor ``[w, b, vocab_size]`` where ``w`` is the width of the
            CNN feature map.

        Raises:
            AssertionError: If the CNN output height is not 1.
        """
        # conv features
        if self.add_noise:
            # Sampling with sample_shape=input.size() yields a trailing
            # singleton dim (the distribution has batch shape [1]); drop it
            # and move the noise to wherever the input lives.
            noise = self.noise_fac.sample(input.size()).squeeze(-1)
            input = input + noise.to(input.device)
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        if self.use_rnn:
            # rnn features
            output = self.rnn(conv)
        else:
            output = self.linear(conv)
        return output

    def backward_hook(self, module, grad_input, grad_output):
        """Zero NaN entries of ``grad_input`` in place.

        Only NaN is handled (``g != g`` is true for NaN alone); infinite
        values are left untouched.
        """
        for g in grad_input:
            if g is None:
                # Inputs that do not require grad show up as None.
                continue
            g[g != g] = 0
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        # NOTE: 0 is reserved for 'blank' required by wrap_ctc, so the
        # first alphabet character maps to label 1.
        self.dict = {char: i + 1 for i, char in enumerate(alphabet)}

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str, bytes, or iterable of str/bytes): texts to convert.
                ``bytes`` items are decoded as UTF-8.

        Returns:
            torch.LongTensor [n, max_length]: encoded texts, right-padded
                with 0 (the blank label).
            torch.IntTensor [n]: length of each text.
            None: placeholder third element kept for existing callers.

        Raises:
            KeyError: when a character is not part of the alphabet.
        """
        if isinstance(text, (str, bytes)):
            # A bare string is one text, not a batch of one-character texts.
            text = [text]

        lengths = []
        encoded = []
        for item in text:
            if isinstance(item, bytes):
                item = item.decode('utf-8', 'strict')
            if self._ignore_case:
                # The alphabet was lower-cased in __init__; match it here.
                item = item.lower()
            lengths.append(len(item))
            encoded.append(torch.LongTensor([self.dict[char] for char in item]))

        padded = torch.nn.utils.rnn.pad_sequence(encoded, batch_first=True)
        return padded, torch.IntTensor(lengths), None

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t (torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]): encoded texts.
            length (torch.IntTensor [n]): length of each text.
            raw (bool): if True, map every label to its character without
                collapsing repeats or dropping blanks (blank renders as '-').

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): a single str when ``length`` holds
                exactly one entry, otherwise a list with one str per entry.
        """
        # Work on plain Python ints so indexing and slicing do not depend on
        # 0-dim tensor conversions.
        labels = t.reshape(-1).tolist()
        lengths = length.reshape(-1).tolist()

        if len(lengths) == 1:
            assert len(labels) == lengths[0], "text with length: {} does not match declared length: {}".format(
                len(labels), lengths[0])
            return self._decode_labels(labels, raw)

        # batch mode
        total = sum(lengths)
        assert len(labels) == total, "texts with length: {} does not match declared length: {}".format(
            len(labels), total)
        texts = []
        index = 0
        for l in lengths:
            texts.append(self._decode_labels(labels[index:index + l], raw))
            index += l
        return texts

    def _decode_labels(self, labels, raw):
        """Turn one sequence of int labels into a string."""
        if raw:
            # Label 0 indexes alphabet[-1], i.e. the '-' appended in __init__.
            return ''.join(self.alphabet[i - 1] for i in labels)

        # CTC-style collapse: drop blanks and immediate repeats.
        chars = []
        prev = None
        for cur in labels:
            if cur != 0 and cur != prev:
                chars.append(self.alphabet[cur - 1])
            prev = cur
        return ''.join(chars)