# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/backbones/table_resnet_extra.py
"""
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class BasicBlock(nn.Layer):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 gcb_config=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2D(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias_attr=False)
        self.bn1 = nn.BatchNorm2D(planes, momentum=0.9)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(
            planes, planes, kernel_size=3, stride=1, padding=1,
            bias_attr=False)
        self.bn2 = nn.BatchNorm2D(planes, momentum=0.9)
        self.downsample = downsample
        self.stride = stride
        self.gcb_config = gcb_config

        if self.gcb_config is not None:
            gcb_ratio = gcb_config['ratio']
            gcb_headers = gcb_config['headers']
            att_scale = gcb_config['att_scale']
            fusion_type = gcb_config['fusion_type']
            self.context_block = MultiAspectGCAttention(
                inplanes=planes,
                ratio=gcb_ratio,
                headers=gcb_headers,
                att_scale=att_scale,
                fusion_type=fusion_type)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.gcb_config is not None:
            out = self.context_block(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

def get_gcb_config(gcb_config, layer):
    if gcb_config is None or not gcb_config['layers'][layer]:
        return None
    else:
        return gcb_config
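
# gcb_config is a plain dict consumed here and in BasicBlock: it carries
# 'ratio', 'headers', 'att_scale' and 'fusion_type' (forwarded to
# MultiAspectGCAttention) plus 'layers', a list of four flags selecting
# which residual stages get a context block. A concrete, illustrative
# example appears in the smoke test at the end of this file.
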
class TableResNetExtra(nn.Layer):
    def __init__(self, layers, in_channels=3, gcb_config=None):
        assert len(layers) >= 4

        super(TableResNetExtra, self).__init__()
        self.inplanes = 128
        self.conv1 = nn.Conv2D(
            in_channels,
            64,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn1 = nn.BatchNorm2D(64)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2D(
            64, 128, kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn2 = nn.BatchNorm2D(128)
        self.relu2 = nn.ReLU()

        self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2)

        self.layer1 = self._make_layer(
            BasicBlock,
            256,
            layers[0],
            stride=1,
            gcb_config=get_gcb_config(gcb_config, 0))

        self.conv3 = nn.Conv2D(
            256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn3 = nn.BatchNorm2D(256)
        self.relu3 = nn.ReLU()

        self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2)

        self.layer2 = self._make_layer(
            BasicBlock,
            256,
            layers[1],
            stride=1,
            gcb_config=get_gcb_config(gcb_config, 1))

        self.conv4 = nn.Conv2D(
            256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn4 = nn.BatchNorm2D(256)
        self.relu4 = nn.ReLU()

        self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=2)

        self.layer3 = self._make_layer(
            BasicBlock,
            512,
            layers[2],
            stride=1,
            gcb_config=get_gcb_config(gcb_config, 2))

        self.conv5 = nn.Conv2D(
            512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn5 = nn.BatchNorm2D(512)
        self.relu5 = nn.ReLU()

        self.layer4 = self._make_layer(
            BasicBlock,
            512,
            layers[3],
            stride=1,
            gcb_config=get_gcb_config(gcb_config, 3))

        self.conv6 = nn.Conv2D(
            512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn6 = nn.BatchNorm2D(512)
        self.relu6 = nn.ReLU()

        self.out_channels = [256, 256, 512]

    def _make_layer(self, block, planes, blocks, stride=1, gcb_config=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias_attr=False),
                nn.BatchNorm2D(planes * block.expansion), )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                gcb_config=gcb_config))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    def forward(self, x):
        f = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.maxpool1(x)
        x = self.layer1(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        f.append(x)

        x = self.maxpool2(x)
        x = self.layer2(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu4(x)
        f.append(x)

        x = self.maxpool3(x)
        x = self.layer3(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu5(x)

        x = self.layer4(x)
        x = self.conv6(x)
        x = self.bn6(x)
        x = self.relu6(x)
        f.append(x)

        return f
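
# forward() above returns a three-level feature pyramid: with the three
# 2x2 max-pools, the maps are at 1/2, 1/4 and 1/8 of the input resolution
# with 256, 256 and 512 channels, matching self.out_channels.
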
class MultiAspectGCAttention(nn.Layer):
    def __init__(self,
                 inplanes,
                 ratio,
                 headers,
                 pooling_type='att',
                 att_scale=False,
                 fusion_type='channel_add'):
        super(MultiAspectGCAttention, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
        # inplanes must be evenly divisible by headers
        assert inplanes % headers == 0 and inplanes >= 8

        self.headers = headers
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_type = fusion_type
        self.att_scale = att_scale
        self.single_header_inplanes = int(inplanes / headers)
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2D(
                self.single_header_inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(axis=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2D(1)

        if fusion_type == 'channel_add':
            self.channel_add_conv = nn.Sequential(
                nn.Conv2D(
                    self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(),
                nn.Conv2D(
                    self.planes, self.inplanes, kernel_size=1))
        elif fusion_type == 'channel_concat':
            self.channel_concat_conv = nn.Sequential(
                nn.Conv2D(
                    self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(),
                nn.Conv2D(
                    self.planes, self.inplanes, kernel_size=1))
            # for concat
            self.cat_conv = nn.Conv2D(
                2 * self.inplanes, self.inplanes, kernel_size=1)
        elif fusion_type == 'channel_mul':
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2D(
                    self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(),
                nn.Conv2D(
                    self.planes, self.inplanes, kernel_size=1))
    def spatial_pool(self, x):
        batch, channel, height, width = x.shape
        if self.pooling_type == 'att':
            # [N*headers, C', H, W], where C = headers * C'
            x = x.reshape([
                batch * self.headers, self.single_header_inplanes, height,
                width
            ])
            input_x = x

            # [N*headers, C', H * W]
            input_x = input_x.reshape([
                batch * self.headers, self.single_header_inplanes,
                height * width
            ])

            # [N*headers, 1, C', H * W]
            input_x = input_x.unsqueeze(1)
            # [N*headers, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N*headers, 1, H * W]
            context_mask = context_mask.reshape(
                [batch * self.headers, 1, height * width])

            # scale the attention logits to control their variance
            if self.att_scale and self.headers > 1:
                context_mask = context_mask / math.sqrt(
                    self.single_header_inplanes)

            # [N*headers, 1, H * W]
            context_mask = self.softmax(context_mask)

            # [N*headers, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N*headers, 1, C', 1] = [N*headers, 1, C', H * W] x [N*headers, 1, H * W, 1]
            context = paddle.matmul(input_x, context_mask)

            # [N, headers * C', 1, 1]
            context = context.reshape(
                [batch, self.headers * self.single_header_inplanes, 1, 1])
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context
    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        out = x

        if self.fusion_type == 'channel_mul':
            # [N, C, 1, 1]
            channel_mul_term = F.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        elif self.fusion_type == 'channel_add':
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        else:  # channel_concat
            # [N, C, 1, 1]
            channel_concat_term = self.channel_concat_conv(context)

            # broadcast the pooled context over the spatial dims, then fuse
            _, _, H, W = out.shape
            out = paddle.concat(
                [out, channel_concat_term.expand([-1, -1, H, W])], axis=1)
            out = self.cat_conv(out)
            out = F.layer_norm(out, [self.inplanes, H, W])
            out = F.relu(out)

        return out
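

if __name__ == '__main__':
    # Minimal smoke test; a sketch, not a released configuration. The layer
    # counts and gcb_config values below are illustrative assumptions, not
    # values taken from a shipped TableMaster config; they are chosen only
    # to exercise the context-block branches.
    gcb_config = {
        'ratio': 0.0625,
        'headers': 1,
        'att_scale': False,
        'fusion_type': 'channel_add',
        'layers': [False, True, True, True],
    }
    backbone = TableResNetExtra(
        layers=[1, 2, 5, 3], in_channels=3, gcb_config=gcb_config)
    feats = backbone(paddle.rand([1, 3, 480, 480]))
    # Expected: [1, 256, 240, 240], [1, 256, 120, 120], [1, 512, 60, 60]
    for feat in feats:
        print(feat.shape)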