Spaces:
Running
Running
| # A simplified version of the original code - https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition | |
| import math | |
| import torch | |
| import torchvision.transforms as T | |
| import warnings | |
| warnings.filterwarnings("ignore", category=UserWarning) | |
class NormalizePAD(object):
    """Turn an image into a normalized tensor padded on the right to a fixed size.

    ``max_size`` is a ``(channels, height, width)`` triple giving the shape of
    the returned tensor. ``PAD_type`` and the derived ``max_width_half`` are
    stored on the instance, but ``__call__`` does not read them.
    """

    def __init__(self, max_size, PAD_type='right'):
        self.max_size = max_size
        self.PAD_type = PAD_type
        self.max_width_half = math.floor(max_size[2] / 2)
        self.toTensor = T.ToTensor()

    def __call__(self, img):
        """Return ``img`` as a tensor of shape ``max_size``.

        The converted image is rescaled in place with ``(x - 0.5) / 0.5`` and
        copied into the left part of a zero-initialised float tensor. If the
        image is narrower than the target width, the remaining columns are
        filled with copies of the image's last column.
        """
        tensor = self.toTensor(img)
        tensor.sub_(0.5).div_(0.5)
        channels, height, width = tensor.size()
        target_width = self.max_size[2]

        padded = torch.FloatTensor(*self.max_size).fill_(0)
        padded[:, :, :width] = tensor  # image occupies the left side

        missing = target_width - width
        if missing != 0:
            # Repeat the right-most image column across the padding area.
            edge_column = tensor[:, :, width - 1].unsqueeze(2)
            padded[:, :, width:] = edge_column.expand(channels, height, missing)
        return padded
class CTCLabelConverter(object):
    """Convert between text labels and integer index sequences for CTC.

    Index 0 is reserved for the CTC blank token; the characters passed to the
    constructor are numbered from 1 in the order given.
    """

    def __init__(self, character):
        """Build the lookup tables.

        Args:
            character (str): set of the possible characters.
        """
        dict_character = list(character)
        # NOTE: 0 is reserved for the 'CTCblank' token required by CTCLoss,
        # so real characters start at index 1.
        self.dict = {char: i for i, char in enumerate(dict_character, start=1)}
        # Dummy '[CTCblank]' token occupies index 0 of the reverse table.
        self.character = ['[CTCblank]'] + dict_character

    def encode(self, text, batch_max_length=25):
        """Convert text labels into padded index tensors.

        Args:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch.
                25 by default.

        Returns:
            A tuple ``(batch_text, length)``:
            batch_text: LongTensor of text indices for CTCLoss, zero padded.
                [batch_size, batch_max_length]
            length: IntTensor with the length of each label. [batch_size]

        Raises:
            KeyError: if a label contains a character that was not passed to
                the constructor.

        A label longer than ``batch_max_length`` is not truncated; the row
        assignment then fails with torch's size-mismatch error.
        """
        length = [len(s) for s in text]
        # The index used for padding (=0) would not affect the CTC loss
        # calculation.
        batch_text = torch.zeros(len(text), batch_max_length, dtype=torch.long)
        for i, label in enumerate(text):
            # Use a separate local so the ``text`` argument is never rebound.
            indices = [self.dict[char] for char in label]
            if indices:
                batch_text[i][:len(indices)] = torch.tensor(
                    indices, dtype=torch.long
                )
        return (batch_text, torch.IntTensor(length))

    def decode(self, text_index, length):
        """Convert index sequences back into text labels.

        Args:
            text_index: 2-D array of indices, indexed as
                ``text_index[row, :]``, one row per sample.
            length: number of leading positions of each row to decode.

        Returns:
            A list with one decoded string per entry of ``length``. Blanks
            (index 0) are dropped and a run of equal consecutive indices
            yields a single character.

        Raises:
            IndexError: if a length exceeds the number of positions in its
                row.
        """
        texts = []
        for index, l in enumerate(length):
            row = text_index[index, :]
            # Convert the row to plain Python ints once instead of indexing
            # and comparing array elements one at a time.
            ids = row.tolist() if hasattr(row, 'tolist') else [int(v) for v in row]
            count = int(l)
            if count > len(ids):
                raise IndexError(
                    'length %d exceeds sequence width %d' % (count, len(ids))
                )
            char_list = []
            previous = None
            for current in ids[:max(count, 0)]:
                # Skip blanks and any index repeating its predecessor.
                if current != 0 and current != previous:
                    char_list.append(self.character[current])
                previous = current
            texts.append(''.join(char_list))
        return texts