Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Zhenyu Li
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # from zoedepth.models.layers.swin_layers import G2LFusion | |
| from estimator.models.blocks.swin_layers import G2LFusion | |
| from torchvision.ops import roi_align as torch_roi_align | |
| from estimator.registry import MODELS | |
class DoubleConvWOBN(nn.Module):
    """Two (3x3 convolution => ReLU) stages, without batch normalization.

    When ``mid_channels`` is falsy, the hidden width defaults to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        # Run both conv+ReLU stages back to back.
        return self.double_conv(x)
class DoubleConv(nn.Module):
    """Two (3x3 convolution => BatchNorm => ReLU) stages.

    When ``mid_channels`` is falsy, the hidden width defaults to
    ``out_channels``. Convolutions carry no bias since BatchNorm follows.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Run both conv+BN+ReLU stages back to back.
        return self.double_conv(x)
class Down(nn.Module):
    """Halve spatial resolution with 2x2 max pooling, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        # Pool first (2x downscale), then refine channels.
        return self.maxpool_conv(x)
class Upv1(nn.Module):
    """Upsample ``x1`` to ``x2``'s spatial size, concatenate, then convolve.

    Uses bilinear interpolation (align_corners=True) followed by a
    ``DoubleConvWOBN`` block; the hidden width defaults to ``in_channels``
    when ``mid_channels`` is not given.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = in_channels if mid_channels is None else mid_channels
        self.conv = DoubleConvWOBN(in_channels, out_channels, hidden)

    def forward(self, x1, x2):
        # Resize x1 to match x2 spatially before fusing channel-wise.
        up = F.interpolate(x1, size=x2.shape[-2:], mode='bilinear', align_corners=True)
        fused = torch.cat([x2, up], dim=1)
        return self.conv(fused)
class GuidedFusionPatchFusion(nn.Module):
    """U-Net-style encoder/decoder that fuses patch (fine) features with
    whole-image (coarse) features for PatchFusion.

    The encoder (``inc`` + ``down_conv_list``) builds a 6-level pyramid from
    ``input_tensor``. The decoder walks the pyramid from coarsest to finest;
    at each level the whole-image coarse feature is passed through a
    ``G2LFusion`` attention layer, cropped back to the current patch with
    ``roi_align``, and merged with the decoder feature via a conv block.

    Args:
        n_channels: channel count of the raw input tensor.
        g2l: flag controlling creation of ``self.g2l_att``.
            NOTE(review): ``g2l_att`` is created empty and never used in this
            class, and the ``g2l_list``/``convs`` layers are built regardless
            of the flag — confirm whether ``g2l=False`` is ever intended.
        in_channels: per-level channel widths, finest first.
        depth / num_heads / num_patches: per-level G2LFusion hyperparameters
            (``num_patches`` is listed coarsest-first and re-reversed below).
        patch_process_shape: (H, W) of the whole-image feature at the finest
            level; used to derive the roi_align spatial scale.
    """
    def __init__(
        self,
        n_channels,
        g2l,
        in_channels=[32, 256, 256, 256, 256, 256],
        depth=[2, 2, 3, 3, 4, 4],
        num_heads=[8, 8, 16, 16, 32, 32],
        # num_patches=[12*16, 24*32, 48*64, 96*128, 192*256, 384*512],
        num_patches=[384*512, 192*256, 96*128, 48*64, 24*32, 12*16],
        patch_process_shape=[384, 512]):
        super(GuidedFusionPatchFusion, self).__init__()
        self.n_channels = n_channels

        # Encoder: stem conv plus one Down block per additional level.
        self.inc = DoubleConv(n_channels, in_channels[0])
        self.down_conv_list = nn.ModuleList()
        for idx in range(len(in_channels) - 1):
            lay = Down(in_channels[idx], in_channels[idx+1])
            self.down_conv_list.append(lay)

        # Decoder: Upv1 blocks operating coarsest-first. The input width is
        # (previous decoder output + guide_cat feature) + skip connection,
        # hence the tripled channel term.
        in_channels_inv = in_channels[::-1]
        self.up_conv_list = nn.ModuleList()
        for idx in range(1, len(in_channels)):
            lay = Upv1(in_channels_inv[idx] + in_channels_inv[idx-1] + in_channels_inv[idx-1], in_channels_inv[idx])
            self.up_conv_list.append(lay)

        self.g2l = g2l
        if self.g2l:
            # NOTE(review): populated nowhere in this class — likely vestigial.
            self.g2l_att = nn.ModuleList()

        win = 12  # Swin-style attention window size for G2LFusion.
        self.patch_process_shape = patch_process_shape

        # Reverse hyperparameter lists so index 0 is the coarsest level,
        # matching the decoder's traversal order in forward().
        num_heads_inv = num_heads[::-1]
        depth_inv = depth[::-1]
        num_patches_inv = num_patches[::-1]
        self.g2l_list = nn.ModuleList()
        self.convs = nn.ModuleList()
        for idx in range(len(in_channels_inv)):
            g2l_layer = G2LFusion(input_dim=in_channels_inv[idx], embed_dim=in_channels_inv[idx], window_size=win, num_heads=num_heads_inv[idx], depth=depth_inv[idx], num_patches=num_patches_inv[idx])
            self.g2l_list.append(g2l_layer)
            # Merges the decoder feature with the cropped coarse feature
            # (concatenated channel-wise) back down to the level width.
            layer = DoubleConvWOBN(in_channels_inv[idx] * 2, in_channels_inv[idx], in_channels_inv[idx])
            self.convs.append(layer)

        # Legacy per-level definitions, superseded by the loops above:
        # self.g2l5 = G2LFusion(input_dim=in_channels[5], embed_dim=crf_dims[5], window_size=win, num_heads=32, depth=4, num_patches=num_patches[0])
        # self.g2l4 = G2LFusion(input_dim=in_channels[4], embed_dim=crf_dims[4], window_size=win, num_heads=32, depth=4, num_patches=num_patches[1])
        # self.g2l3 = G2LFusion(input_dim=in_channels[3], embed_dim=crf_dims[3], window_size=win, num_heads=16, depth=3, num_patches=num_patches[2])
        # self.g2l2 = G2LFusion(input_dim=in_channels[2], embed_dim=crf_dims[2], window_size=win, num_heads=16, depth=3, num_patches=num_patches[3])
        # self.g2l1 = G2LFusion(input_dim=in_channels[1], embed_dim=crf_dims[1], window_size=win, num_heads=8, depth=2, num_patches=num_patches[4])
        # self.g2l0 = G2LFusion(input_dim=in_channels[0], embed_dim=crf_dims[0], window_size=win, num_heads=8, depth=2, num_patches=num_patches[5])
        # self.conv5 = DoubleConvWOBN(in_channels[5] * 2, in_channels[5], in_channels[5])
        # self.conv4 = DoubleConvWOBN(in_channels[4] * 2, in_channels[4], in_channels[4])
        # self.conv3 = DoubleConvWOBN(in_channels[3] * 2, in_channels[3], in_channels[3])
        # self.conv2 = DoubleConvWOBN(in_channels[2] * 2, in_channels[2], in_channels[2])
        # self.conv1 = DoubleConvWOBN(in_channels[1] * 2, in_channels[1], in_channels[1])
        # self.conv0 = DoubleConvWOBN(in_channels[0] * 2, in_channels[0], in_channels[0])

    def forward(self,
                input_tensor,
                guide_plus,
                guide_cat,
                bbox=None,
                fine_feat_crop=None,
                coarse_feat_whole=None,
                coarse_feat_whole_hack=None,
                coarse_feat_crop=None):
        """Encode the patch, then decode while fusing whole-image features.

        Args:
            input_tensor: patch input fed to the encoder.
            guide_plus: unused in this implementation (kept for interface
                compatibility — NOTE(review): confirm callers).
            guide_cat: per-level guide features concatenated into the decoder
                input, indexed coarsest-first.
            bbox: roi_align boxes locating this patch inside the whole image.
                NOTE(review): defaults to None but roi_align is called
                unconditionally — a None bbox would fail; confirm callers
                always pass it.
            fine_feat_crop / coarse_feat_crop: unused in this implementation.
            coarse_feat_whole: whole-image feature pyramid, coarsest-first.
            coarse_feat_whole_hack: optional replacement for
                ``coarse_feat_whole`` (see comment below).

        Returns:
            List of fused decoder features, finest level first.
        """
        # apply unscaled feat to swin
        if coarse_feat_whole_hack is not None:
            coarse_feat_whole = coarse_feat_whole_hack

        # Encoder pass: collect one feature per level, finest first.
        feat_list = []
        x = self.inc(input_tensor)
        feat_list.append(x)
        for layer in self.down_conv_list:
            x = layer(x)
            feat_list.append(x)

        output = []
        # Walk coarsest-first so each level can reuse the previous fused
        # feature (temp_feat) in its Upv1 block.
        feat_inv_list = feat_list[::-1]
        for idx, (feat_enc, feat_c) in enumerate(zip(feat_inv_list, coarse_feat_whole)):
            # in case for depth-anything
            # (encoder and coarse pyramids may disagree slightly in size)
            _, _, h, w = feat_enc.shape
            if h != feat_c.shape[-2] or w != feat_c.shape[-1]:
                feat_enc = F.interpolate(feat_enc, size=feat_c.shape[-2:], mode='bilinear', align_corners=True)
            if idx == 0:
                # Coarsest level: no previous decoder output to fuse yet.
                pass
            else:
                # Upsample (prev fused feature + guide) and merge with the
                # skip connection from the encoder.
                feat_enc = self.up_conv_list[idx-1](torch.cat([temp_feat, guide_cat[idx-1]], dim=1), feat_enc)
            _, _, h, w = feat_c.shape
            # Global-to-local attention over the whole-image coarse feature,
            # then crop the patch region back out at this level's resolution.
            # NOTE(review): spatial_scale uses only the height ratio — assumes
            # width scales identically; confirm aspect ratios match.
            feat_c = self.g2l_list[idx](feat_c, None)
            feat_c = torch_roi_align(feat_c, bbox, (h, w), h/self.patch_process_shape[0], aligned=True)
            x = self.convs[idx](torch.cat([feat_enc, feat_c], dim=1))
            temp_feat = x  # carried into the next (finer) level's Upv1
            output.append(x)
        # Reverse so the returned list is finest-first, matching feat_list.
        return output[::-1]