# manga_translator/ocr/model_32px.py
import math
from typing import List
from collections import defaultdict
import os
import shutil
import cv2
import numpy as np
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from .common import OfflineOCR
from ..utils import TextBlock, Quadrilateral, chunks
from ..utils.bubble import is_ignore
class Model32pxOCR(OfflineOCR):
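    """OCR model operating on text lines resized to 32px height.

    Downloads ocr.ckpt and its alphabet file on first use, recognizes the
    text inside each detected region, and also predicts average foreground
    and background colors per region.
    """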
_MODEL_MAPPING = {
'model': {
'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr.zip',
'hash': '47405638b96fa2540a5ee841a4cd792f25062c09d9458a973362d40785f95d7a',
'archive': {
'ocr.ckpt': '.',
'alphabet-all-v5.txt': '.',
},
},
}
def __init__(self, *args, **kwargs):
os.makedirs(self.model_dir, exist_ok=True)
if os.path.exists('ocr.ckpt'):
shutil.move('ocr.ckpt', self._get_file_path('ocr.ckpt'))
if os.path.exists('alphabet-all-v5.txt'):
shutil.move('alphabet-all-v5.txt', self._get_file_path('alphabet-all-v5.txt'))
super().__init__(*args, **kwargs)
async def _load(self, device: str):
with open(self._get_file_path('alphabet-all-v5.txt'), 'r', encoding = 'utf-8') as fp:
dictionary = [s[:-1] for s in fp.readlines()]
self.model = OCR(dictionary, 768)
sd = torch.load(self._get_file_path('ocr.ckpt'), map_location = 'cpu')
self.model.load_state_dict(sd['model'] if 'model' in sd else sd)
self.model.eval()
self.device = device
        self.use_gpu = device in ('cuda', 'mps')
if self.use_gpu:
self.model = self.model.to(device)
async def _unload(self):
del self.model
async def _infer(self, image: np.ndarray, textlines: List[Quadrilateral], args: dict, verbose: bool = False) -> List[TextBlock]:
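        # Pipeline: warp each region to a 32px-high strip, batch up to 16
        # strips sorted by width (zero-padded to a shared width), beam-search
        # decode each batch, then keep predictions with prob >= 0.7 and write
        # the text plus averaged fg/bg colors back onto the regions.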
text_height = 32
max_chunk_size = 16
ignore_bubble = args.get('ignore_bubble', 0)
quadrilaterals = list(self._generate_text_direction(textlines))
region_imgs = [q.get_transformed_region(image, d, text_height) for q, d in quadrilaterals]
out_regions = []
perm = range(len(region_imgs))
is_quadrilaterals = False
if len(quadrilaterals) > 0 and isinstance(quadrilaterals[0][0], Quadrilateral):
perm = sorted(range(len(region_imgs)), key = lambda x: region_imgs[x].shape[1])
is_quadrilaterals = True
ix = 0
for indices in chunks(perm, max_chunk_size):
N = len(indices)
widths = [region_imgs[i].shape[1] for i in indices]
            max_width = 4 * ((max(widths) + 7) // 4)  # pad width up to a multiple of 4
region = np.zeros((N, text_height, max_width, 3), dtype = np.uint8)
for i, idx in enumerate(indices):
W = region_imgs[idx].shape[1]
tmp = region_imgs[idx]
                # Skip this text block when bubble filtering is enabled
                # (is_ignore returns True for blocks that should be skipped).
                if 1 <= ignore_bubble <= 50 and is_ignore(region_imgs[idx], ignore_bubble):
                    ix += 1
                    continue
                region[i, :, :W, :] = tmp
if verbose:
os.makedirs('result/ocrs/', exist_ok=True)
if quadrilaterals[idx][1] == 'v':
cv2.imwrite(f'result/ocrs/{ix}.png', cv2.rotate(cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR), cv2.ROTATE_90_CLOCKWISE))
else:
cv2.imwrite(f'result/ocrs/{ix}.png', cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR))
ix += 1
image_tensor = (torch.from_numpy(region).float() - 127.5) / 127.5
image_tensor = einops.rearrange(image_tensor, 'N H W C -> N C H W')
if self.use_gpu:
image_tensor = image_tensor.to(self.device)
with torch.no_grad():
ret = self.model.infer_beam_batch(image_tensor, widths, beams_k = 5, max_seq_length = 255)
for i, (pred_chars_index, prob, fr, fg, fb, br, bg, bb) in enumerate(ret):
if prob < 0.7:
continue
fr = (torch.clip(fr.view(-1), 0, 1).mean() * 255).long().item()
fg = (torch.clip(fg.view(-1), 0, 1).mean() * 255).long().item()
fb = (torch.clip(fb.view(-1), 0, 1).mean() * 255).long().item()
br = (torch.clip(br.view(-1), 0, 1).mean() * 255).long().item()
bg = (torch.clip(bg.view(-1), 0, 1).mean() * 255).long().item()
bb = (torch.clip(bb.view(-1), 0, 1).mean() * 255).long().item()
seq = []
for chid in pred_chars_index:
ch = self.model.dictionary[chid]
if ch == '<S>':
continue
if ch == '</S>':
break
if ch == '<SP>':
ch = ' '
seq.append(ch)
txt = ''.join(seq)
self.logger.info(f'prob: {prob} {txt} fg: ({fr}, {fg}, {fb}) bg: ({br}, {bg}, {bb})')
cur_region = quadrilaterals[indices[i]][0]
if isinstance(cur_region, Quadrilateral):
cur_region.text = txt
cur_region.prob = prob
cur_region.fg_r = fr
cur_region.fg_g = fg
cur_region.fg_b = fb
cur_region.bg_r = br
cur_region.bg_g = bg
cur_region.bg_b = bb
else:
cur_region.text.append(txt)
cur_region.update_font_colors(np.array([fr, fg, fb]), np.array([br, bg, bb]))
out_regions.append(cur_region)
if is_quadrilaterals:
return out_regions
return textlines
class ResNet(nn.Module):
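    """ResNet-style backbone for 32px-high text line images: collapses a
    (N, C, 32, W) input down to a (N, output_channel, 1, ~W/4) feature map,
    which the OCR model flattens into a sequence along the width axis.
    """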
def __init__(self, input_channel, output_channel, block, layers):
super(ResNet, self).__init__()
self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
self.inplanes = int(output_channel / 8)
self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 8),
kernel_size=3, stride=1, padding=1, bias=False)
self.bn0_1 = nn.BatchNorm2d(int(output_channel / 8))
self.conv0_2 = nn.Conv2d(int(output_channel / 8), self.inplanes,
kernel_size=3, stride=1, padding=1, bias=False)
self.maxpool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
        self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[0],
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.maxpool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
        self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
        self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[1],
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.maxpool3 = nn.AvgPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
        self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
        self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
        self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[2],
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
        self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
        self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[3],
                                 kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
        self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
        self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[3],
                                 kernel_size=2, stride=1, padding=0, bias=False)
        self.bn4_3 = nn.BatchNorm2d(self.output_channel_block[3])
self.bn4_3 = nn.BatchNorm2d(self.output_channel_block[3])
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.BatchNorm2d(self.inplanes),
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv0_1(x)
x = self.bn0_1(x)
x = F.relu(x)
x = self.conv0_2(x)
x = self.maxpool1(x)
x = self.layer1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv1(x)
x = self.maxpool2(x)
x = self.layer2(x)
x = self.bn2(x)
x = F.relu(x)
x = self.conv2(x)
x = self.maxpool3(x)
x = self.layer3(x)
x = self.bn3(x)
x = F.relu(x)
x = self.conv3(x)
x = self.layer4(x)
x = self.bn4_1(x)
x = F.relu(x)
x = self.conv4_1(x)
x = self.bn4_2(x)
x = F.relu(x)
x = self.conv4_2(x)
x = self.bn4_3(x)
return x
class BasicBlock(nn.Module):
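    """Pre-activation residual block: BN -> ReLU -> conv applied twice, with
    an optional downsample path on the skip connection."""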
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(inplanes)
self.conv1 = self._conv3x3(inplanes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = self._conv3x3(planes, planes)
self.downsample = downsample
self.stride = stride
def _conv3x3(self, in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
def forward(self, x):
residual = x
out = self.bn1(x)
out = F.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = F.relu(out)
out = self.conv2(out)
if self.downsample is not None:
residual = self.downsample(residual)
return out + residual
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class ResNet_FeatureExtractor(nn.Module):
""" FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """
def __init__(self, input_channel, output_channel=128):
super(ResNet_FeatureExtractor, self).__init__()
self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [3, 6, 7, 5])
def forward(self, input):
return self.ConvNet(input)
class PositionalEncoding(nn.Module):
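    """Sinusoidal positional encoding (Vaswani et al., 2017) stored as a
    (max_len, 1, d_model) buffer; `offset` lets incremental decoding fetch
    the encoding of a single later position."""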
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x, offset = 0):
x = x + self.pe[offset: offset + x.size(0), :]
        return x  # note: self.dropout(x) is bypassed here
def generate_square_subsequent_mask(sz):
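    """Causal attention mask: 0.0 on and below the diagonal, -inf above it,
    so position i can only attend to positions <= i."""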
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
class AddCoords(nn.Module):
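    """CoordConv-style layer (Liu et al., 2018): concatenates normalized x/y
    coordinate channels, and optionally a radius channel, to the input.
    Not referenced elsewhere in this module."""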
def __init__(self, with_r=False):
super().__init__()
self.with_r = with_r
def forward(self, input_tensor):
"""
Args:
input_tensor: shape(batch, channel, x_dim, y_dim)
"""
batch_size, _, x_dim, y_dim = input_tensor.size()
xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
xx_channel = xx_channel.float() / (x_dim - 1)
yy_channel = yy_channel.float() / (y_dim - 1)
xx_channel = xx_channel * 2 - 1
yy_channel = yy_channel * 2 - 1
xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
ret = torch.cat([
input_tensor,
xx_channel.type_as(input_tensor),
yy_channel.type_as(input_tensor)], dim=1)
if self.with_r:
rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2))
ret = torch.cat([ret, rr], dim=1)
return ret
class Beam:
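    """A single hypothesis for `OCR.infer_beam`: decoded token ids and their
    log-probs, ranked by mean log-prob (higher is better)."""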
    def __init__(self, char_seq = None, logprobs = None):
        # chars and logprobs both hold one entry per decoded token (shape L)
        if char_seq is None:
            char_seq = []
        if logprobs is None:
            logprobs = []
        if isinstance(char_seq, list):
            self.chars = torch.tensor(char_seq, dtype=torch.long)
            self.logprobs = torch.tensor(logprobs, dtype=torch.float32)
        else:
            self.chars = char_seq.clone()
            self.logprobs = logprobs.clone()
def avg_logprob(self):
return self.logprobs.mean().item()
def sort_key(self):
return -self.avg_logprob()
def seq_end(self, end_tok):
return self.chars.view(-1)[-1] == end_tok
def extend(self, idx, logprob):
return Beam(
torch.cat([self.chars, idx.unsqueeze(0)], dim = -1),
torch.cat([self.logprobs, logprob.unsqueeze(0)], dim = -1),
)
DECODE_BLOCK_LENGTH = 8
class Hypothesis:
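    """A beam-search hypothesis for `OCR.infer_beam_batch` that caches its
    prefix's per-layer decoder activations, so each step only runs attention
    for the newest token (see `next_token_batch`)."""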
def __init__(self, device, start_tok: int, end_tok: int, padding_tok: int, memory_idx: int, num_layers: int, embd_dim: int):
self.device = device
self.start_tok = start_tok
self.end_tok = end_tok
self.padding_tok = padding_tok
self.memory_idx = memory_idx
self.embd_size = embd_dim
self.num_layers = num_layers
# L, 1, E
self.cached_activations = [torch.zeros(0, 1, self.embd_size).to(self.device)] * (num_layers + 1)
self.out_idx = torch.LongTensor([start_tok]).to(self.device)
self.out_logprobs = torch.FloatTensor([0]).to(self.device)
self.length = 0
def seq_end(self):
return self.out_idx.view(-1)[-1] == self.end_tok
def logprob(self):
return self.out_logprobs.mean().item()
def sort_key(self):
return -self.logprob()
def prob(self):
return self.out_logprobs.mean().exp().item()
def __len__(self):
return self.length
def extend(self, idx, logprob):
ret = Hypothesis(self.device, self.start_tok, self.end_tok, self.padding_tok, self.memory_idx, self.num_layers, self.embd_size)
ret.cached_activations = [item.clone() for item in self.cached_activations]
ret.length = self.length + 1
ret.out_idx = torch.cat([self.out_idx, torch.LongTensor([idx]).to(self.device)], dim = 0)
ret.out_logprobs = torch.cat([self.out_logprobs, torch.FloatTensor([logprob]).to(self.device)], dim = 0)
return ret
def output(self):
return self.cached_activations[-1]
def next_token_batch(
hyps: List[Hypothesis],
memory: torch.Tensor, # S, K, E
memory_mask: torch.BoolTensor,
decoders: nn.TransformerDecoder,
pe: PositionalEncoding,
embd: nn.Embedding
):
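    # Advance every hypothesis by one token in a single batched decoder pass.
    # Cached per-layer activations serve as the self-attention keys/values
    # for the prefix, so only the newest position is computed per layer.
    # Returns the (N, E) hidden state of the new token.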
layer: nn.TransformerDecoderLayer
N = len(hyps)
# N
last_toks = torch.stack([item.out_idx[-1] for item in hyps], dim = 0)
# 1, N, E
tgt: torch.FloatTensor = pe(embd(last_toks).unsqueeze_(0), offset = len(hyps[0]))
# S, N, E
memory = torch.stack([memory[:, idx, :] for idx in [item.memory_idx for item in hyps]], dim = 1)
for l, layer in enumerate(decoders.layers):
# TODO: keys and values are recomputed every time
# L - 1, N, E
combined_activations = torch.cat([item.cached_activations[l] for item in hyps], dim = 1)
# L, N, E
combined_activations = torch.cat([combined_activations, tgt], dim = 0)
for i in range(N):
hyps[i].cached_activations[l] = combined_activations[:, i: i + 1, :]
tgt2 = layer.self_attn(tgt, combined_activations, combined_activations)[0]
tgt = tgt + layer.dropout1(tgt2)
tgt = layer.norm1(tgt)
tgt2 = layer.multihead_attn(tgt, memory, memory, key_padding_mask = memory_mask)[0]
tgt = tgt + layer.dropout2(tgt2)
tgt = layer.norm2(tgt)
tgt2 = layer.linear2(layer.dropout(layer.activation(layer.linear1(tgt))))
tgt = tgt + layer.dropout3(tgt2)
# 1, N, E
tgt = layer.norm3(tgt)
for i in range(N):
hyps[i].cached_activations[decoders.num_layers] = torch.cat([hyps[i].cached_activations[decoders.num_layers], tgt[:, i: i + 1, :]], dim = 0)
# N, E
return tgt.squeeze_(0)
class OCR(nn.Module):
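    """Text recognizer: ResNet backbone -> 3-layer transformer encoder ->
    2-layer autoregressive transformer decoder. The output projection is
    tied to the input embedding, and small MLP heads predict per-character
    foreground/background RGB values."""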
def __init__(self, dictionary, max_len):
super(OCR, self).__init__()
self.max_len = max_len
self.dictionary = dictionary
self.dict_size = len(dictionary)
self.backbone = ResNet_FeatureExtractor(3, 320)
encoder = nn.TransformerEncoderLayer(320, 4, dropout = 0.0)
decoder = nn.TransformerDecoderLayer(320, 4, dropout = 0.0)
self.encoders = nn.TransformerEncoder(encoder, 3)
self.decoders = nn.TransformerDecoder(decoder, 2)
self.pe = PositionalEncoding(320, max_len = max_len)
self.embd = nn.Embedding(self.dict_size, 320)
self.pred1 = nn.Sequential(nn.Linear(320, 320), nn.ReLU(), nn.Dropout(0.1))
self.pred = nn.Linear(320, self.dict_size)
self.pred.weight = self.embd.weight
self.color_pred1 = nn.Sequential(nn.Linear(320, 64), nn.ReLU())
self.fg_r_pred = nn.Linear(64, 1)
self.fg_g_pred = nn.Linear(64, 1)
self.fg_b_pred = nn.Linear(64, 1)
self.bg_r_pred = nn.Linear(64, 1)
self.bg_g_pred = nn.Linear(64, 1)
self.bg_b_pred = nn.Linear(64, 1)
def forward(self,
img: torch.FloatTensor,
char_idx: torch.LongTensor,
mask: torch.BoolTensor,
source_mask: torch.BoolTensor
):
feats = self.backbone(img)
feats = torch.einsum('n e h s -> s n e', feats)
feats = self.pe(feats)
memory = self.encoders(feats, src_key_padding_mask = source_mask)
N, L = char_idx.shape
char_embd = self.embd(char_idx)
char_embd = torch.einsum('n t e -> t n e', char_embd)
char_embd = self.pe(char_embd)
        causal_mask = generate_square_subsequent_mask(L).to(img.device)
        decoded = self.decoders(char_embd, memory, tgt_mask = causal_mask, tgt_key_padding_mask = mask, memory_key_padding_mask = source_mask)
decoded = decoded.permute(1, 0, 2)
pred_char_logits = self.pred(self.pred1(decoded))
color_feats = self.color_pred1(decoded)
return pred_char_logits, \
self.fg_r_pred(color_feats), \
self.fg_g_pred(color_feats), \
self.fg_b_pred(color_feats), \
self.bg_r_pred(color_feats), \
self.bg_g_pred(color_feats), \
self.bg_b_pred(color_feats)
def infer_beam_batch(self, img: torch.FloatTensor, img_widths: List[int], beams_k: int = 5, start_tok = 1, end_tok = 2, pad_tok = 0, max_finished_hypos: int = 2, max_seq_length = 384):
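        # Batched beam search (width `beams_k`) over variable-width images.
        # `img_widths` drives the encoder padding mask; a sample is considered
        # done once `max_finished_hypos` of its hypotheses emit the end token.
        # Returns one (token ids, prob, fg/bg color tensors) tuple per sample.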
N, C, H, W = img.shape
assert H == 32 and C == 3
feats = self.backbone(img)
feats = torch.einsum('n e h s -> s n e', feats)
valid_feats_length = [(x + 3) // 4 + 2 for x in img_widths]
input_mask = torch.zeros(N, feats.size(0), dtype = torch.bool).to(img.device)
for i, l in enumerate(valid_feats_length):
input_mask[i, l:] = True
feats = self.pe(feats)
memory = self.encoders(feats, src_key_padding_mask = input_mask)
hypos = [Hypothesis(img.device, start_tok, end_tok, pad_tok, i, self.decoders.num_layers, 320) for i in range(N)]
# N, E
decoded = next_token_batch(hypos, memory, input_mask, self.decoders, self.pe, self.embd)
# N, n_chars
pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)
# N, k
pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1)
new_hypos = []
finished_hypos = defaultdict(list)
for i in range(N):
for k in range(beams_k):
new_hypos.append(hypos[i].extend(pred_chars_index[i, k], pred_chars_values[i, k]))
hypos = new_hypos
for _ in range(max_seq_length):
# N * k, E
            decoded = next_token_batch(hypos, memory, torch.stack([input_mask[hyp.memory_idx] for hyp in hypos]), self.decoders, self.pe, self.embd)
# N * k, n_chars
pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)
# N * k, k
pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1)
hypos_per_sample = defaultdict(list)
h: Hypothesis
for i, h in enumerate(hypos):
for k in range(beams_k):
hypos_per_sample[h.memory_idx].append(h.extend(pred_chars_index[i, k], pred_chars_values[i, k]))
hypos = []
# hypos_per_sample now contains N * k^2 hypos
for i in hypos_per_sample.keys():
cur_hypos: List[Hypothesis] = hypos_per_sample[i]
cur_hypos = sorted(cur_hypos, key = lambda a: a.sort_key())[: beams_k + 1]
to_added_hypos = []
sample_done = False
for h in cur_hypos:
if h.seq_end():
finished_hypos[i].append(h)
if len(finished_hypos[i]) >= max_finished_hypos:
sample_done = True
break
else:
if len(to_added_hypos) < beams_k:
to_added_hypos.append(h)
if not sample_done:
hypos.extend(to_added_hypos)
if len(hypos) == 0:
break
# add remaining hypos to finished
for i in range(N):
if i not in finished_hypos:
cur_hypos: List[Hypothesis] = hypos_per_sample[i]
cur_hypo = sorted(cur_hypos, key = lambda a: a.sort_key())[0]
finished_hypos[i].append(cur_hypo)
assert len(finished_hypos) == N
result = []
for i in range(N):
cur_hypos = finished_hypos[i]
cur_hypo = sorted(cur_hypos, key = lambda a: a.sort_key())[0]
decoded = cur_hypo.output()
color_feats = self.color_pred1(decoded)
fg_r, fg_g, fg_b, bg_r, bg_g, bg_b = self.fg_r_pred(color_feats), \
self.fg_g_pred(color_feats), \
self.fg_b_pred(color_feats), \
self.bg_r_pred(color_feats), \
self.bg_g_pred(color_feats), \
self.bg_b_pred(color_feats)
result.append((cur_hypo.out_idx, cur_hypo.prob(), fg_r, fg_g, fg_b, bg_r, bg_g, bg_b))
return result
def infer_beam(self, img: torch.FloatTensor, beams_k: int = 5, start_tok = 1, end_tok = 2, pad_tok = 0, max_seq_length = 384):
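        # Single-image beam search; unlike infer_beam_batch, this re-runs the
        # decoder over the full prefix at every step (no activation caching).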
N, C, H, W = img.shape
assert H == 32 and N == 1 and C == 3
feats = self.backbone(img)
feats = torch.einsum('n e h s -> s n e', feats)
feats = self.pe(feats)
memory = self.encoders(feats)
def run(tokens, add_start_tok = True, char_only = True):
if add_start_tok:
if isinstance(tokens, list):
# N(=1), L
tokens = torch.tensor([start_tok] + tokens, dtype = torch.long, device = img.device).unsqueeze_(0)
else:
# N, L
tokens = torch.cat([torch.tensor([start_tok], dtype = torch.long, device = img.device), tokens], dim = -1).unsqueeze_(0)
N, L = tokens.shape
embd = self.embd(tokens)
embd = torch.einsum('n t e -> t n e', embd)
embd = self.pe(embd)
            causal_mask = generate_square_subsequent_mask(L).to(img.device)
            decoded = self.decoders(embd, memory, tgt_mask = causal_mask)
decoded = decoded.permute(1, 0, 2)
pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)
if char_only:
return pred_char_logprob
else:
color_feats = self.color_pred1(decoded)
return pred_char_logprob, \
self.fg_r_pred(color_feats), \
self.fg_g_pred(color_feats), \
self.fg_b_pred(color_feats), \
self.bg_r_pred(color_feats), \
self.bg_g_pred(color_feats), \
self.bg_b_pred(color_feats)
# N, L, embd_size
initial_char_logprob = run([])
# N, L
initial_pred_chars_values, initial_pred_chars_index = torch.topk(initial_char_logprob, beams_k, dim = 2)
# beams_k, L
initial_pred_chars_values = initial_pred_chars_values.squeeze(0).permute(1, 0)
initial_pred_chars_index = initial_pred_chars_index.squeeze(0).permute(1, 0)
beams = sorted([Beam(tok, logprob) for tok, logprob in zip(initial_pred_chars_index, initial_pred_chars_values)], key = lambda a: a.sort_key())
for _ in range(max_seq_length):
new_beams = []
all_ended = True
for beam in beams:
if not beam.seq_end(end_tok):
logprobs = run(beam.chars)
pred_chars_values, pred_chars_index = torch.topk(logprobs, beams_k, dim = 2)
# beams_k, L
pred_chars_values = pred_chars_values.squeeze(0).permute(1, 0)
pred_chars_index = pred_chars_index.squeeze(0).permute(1, 0)
new_beams.extend([beam.extend(tok[-1], logprob[-1]) for tok, logprob in zip(pred_chars_index, pred_chars_values)])
all_ended = False
else:
new_beams.append(beam) # seq ended, add back to queue
beams = sorted(new_beams, key = lambda a: a.sort_key())[: beams_k] # keep top k
if all_ended:
break
final_tokens = beams[0].chars[:-1]
return run(final_tokens, char_only = False), beams[0].logprobs.mean().exp().item()
def test():
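    # Smoke test: push a random batch through the feature extractor.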
with open('../SynthText/alphabet-all-v2.txt', 'r') as fp:
dictionary = [s[:-1] for s in fp.readlines()]
img = torch.randn(4, 3, 32, 1224)
idx = torch.zeros(4, 32).long()
mask = torch.zeros(4, 32).bool()
model = ResNet_FeatureExtractor(3, 256)
out = model(img)
def test_inference():
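    # End-to-end greedy decoding check against a developer-local checkpoint;
    # the paths below are not shipped with the repository.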
with torch.no_grad():
with open('../SynthText/alphabet-all-v3.txt', 'r') as fp:
dictionary = [s[:-1] for s in fp.readlines()]
img = torch.zeros(1, 3, 32, 128)
model = OCR(dictionary, 32)
m = torch.load("ocr_ar_v2-3-test.ckpt", map_location='cpu')
model.load_state_dict(m['model'])
model.eval()
        (char_probs, _, _, _, _, _, _), _ = model.infer_beam(img, max_seq_length = 20)
_, pred_chars_index = char_probs.max(2)
pred_chars_index = pred_chars_index.squeeze_(0)
seq = []
for chid in pred_chars_index:
ch = dictionary[chid]
if ch == '<SP>':
                ch = ' '
seq.append(ch)
print(''.join(seq))
if __name__ == "__main__":
test()