# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
#
# Compares the DCNv3 CUDA operator (DCNv3Function) against the pure-PyTorch
# reference (dcnv3_core_pytorch) for forward outputs and backward gradients,
# in float64 and float32, then times the CUDA forward pass.
# Results are printed; nothing is asserted. A CUDA device is required.

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck

from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch

H_in, W_in = 8, 8
N, M, D = 2, 4, 16
Kh, Kw = 3, 3
remove_center = False
P = Kh * Kw - remove_center
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1


def _out_size(size, kernel):
    """Return the output extent for an input extent `size` and kernel extent `kernel`.

    Uses the module-level `pad`, `dilation` and `stride`.
    """
    return (size + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1


H_out = _out_size(H_in, Kh)
W_out = _out_size(W_in, Kw)

torch.manual_seed(3)


def _random_inputs(batch, height, width, groups, group_channels):
    """Draw the random (input, offset, mask) CUDA tensors used by every check.

    Shapes:
        input:  (batch, height, width, groups * group_channels), scaled by 0.01
        offset: (batch, out_h, out_w, groups * P * 2), scaled by 10
        mask:   (batch, out_h, out_w, groups * P), normalised so the P values
                of each group sum to 1

    The draw order (input, offset, mask) is fixed so results stay reproducible
    under the module-level seed.
    """
    out_h = _out_size(height, Kh)
    out_w = _out_size(width, Kw)
    feat = torch.rand(batch, height, width, groups * group_channels).cuda() * 0.01
    offset = torch.rand(batch, out_h, out_w, groups * P * 2).cuda() * 10
    mask = torch.rand(batch, out_h, out_w, groups, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(batch, out_h, out_w, groups * P)
    return feat, offset, mask


def _maybe_double(tensors, use_double):
    """Return `tensors` converted to float64 when `use_double`, else unchanged."""
    if use_double:
        return tuple(t.double() for t in tensors)
    return tuple(tensors)


def _run_reference(feat, offset, mask, groups, group_channels):
    """Call the pure-PyTorch reference with the module-level kernel settings."""
    return dcnv3_core_pytorch(
        feat, offset, mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
        groups, group_channels, offset_scale, remove_center)


def _run_cuda(feat, offset, mask, groups, group_channels, scale, im2col_step):
    """Call the CUDA operator with the module-level kernel settings."""
    return DCNv3Function.apply(
        feat, offset, mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
        groups, group_channels, scale, im2col_step, remove_center)


def _error_stats(expected, actual):
    """Return (max absolute error, max relative error) of `actual` vs `expected`.

    The denominator is clamped to the smallest positive normal value of the
    dtype so that elements where both tensors are exactly zero contribute a
    relative error of 0 instead of NaN (0/0), which would otherwise take over
    the reported maximum.
    """
    abs_err = (expected - actual).abs()
    floor = torch.finfo(expected.dtype).tiny
    rel_err = abs_err / expected.abs().clamp_min(floor)
    return abs_err.max(), rel_err.max()


def _check_forward(name, label, use_double, **tolerances):
    """Print a forward comparison between the CUDA operator and the reference.

    Args:
        name: function name to show in the report line.
        label: precision label for the header ('double' or 'float').
        use_double: run both implementations in float64 when true.
        **tolerances: optional `rtol` / `atol` forwarded to `torch.allclose`.
    """
    feat, offset, mask = _maybe_double(_random_inputs(N, H_in, W_in, M, D), use_double)

    output_pytorch = _run_reference(feat, offset, mask, M, D).detach().cpu()
    output_cuda = _run_cuda(feat, offset, mask, M, D, offset_scale, 2).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch, **tolerances)
    max_abs_err, max_rel_err = _error_stats(output_pytorch, output_cuda)
    print(f'>>> forward {label}')
    print(f'* {fwdok} {name}: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def _check_backward(name, label, use_double, channels, grad_input, grad_offset, grad_mask):
    """Print a gradient comparison between the CUDA operator and the reference.

    Only tensors whose flag is true are given `requires_grad` and compared;
    if no flag is set there is nothing to differentiate and the check is
    skipped instead of failing inside `backward()`.

    Args:
        name: function name to show in the report lines.
        label: precision label for the header ('double' or 'float').
        use_double: run both implementations in float64 when true.
        channels: channels per group.
        grad_input, grad_offset, grad_mask: which gradients to check.
    """
    flags = (grad_input, grad_offset, grad_mask)
    if not any(flags):
        print(f'>>> backward {label}: channels {channels}')
        print(f'* skipped {name}: no gradient requested')
        return

    batch, groups = 2, 2

    # Leaves for the reference implementation.
    ref_leaves = _random_inputs(batch, H_in, W_in, groups, channels)
    for leaf, flag in zip(ref_leaves, flags):
        leaf.requires_grad = flag
    output_pytorch = _run_reference(
        *_maybe_double(ref_leaves, use_double), groups, channels)
    output_pytorch.sum().backward()

    # Independent leaves holding the same values for the CUDA operator, so
    # each implementation accumulates into its own `.grad`.
    cuda_leaves = tuple(leaf.detach() for leaf in ref_leaves)
    for leaf, flag in zip(cuda_leaves, flags):
        leaf.requires_grad = flag
    output_cuda = _run_cuda(
        *_maybe_double(cuda_leaves, use_double), groups, channels, offset_scale, 2)
    output_cuda.sum().backward()

    print(f'>>> backward {label}: channels {channels}')
    kinds = ('input', 'offset', 'mask')
    for kind, ref, cuda, flag in zip(kinds, ref_leaves, cuda_leaves, flags):
        if not flag:
            # `.grad` is None for tensors that did not require grad.
            continue
        bwdok = torch.allclose(ref.grad, cuda.grad, rtol=1e-2, atol=1e-3)
        max_abs_err, max_rel_err = _error_stats(ref.grad, cuda.grad)
        print(
            f'* {bwdok} {kind}_grad {name}: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    """Compare forward outputs in float64 using `torch.allclose` defaults."""
    _check_forward('check_forward_equal_with_pytorch_double', 'double', True)


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    """Compare forward outputs in float32 with rtol=1e-2, atol=1e-3."""
    _check_forward('check_forward_equal_with_pytorch_float', 'float', False,
                   rtol=1e-2, atol=1e-3)


def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_offset=True, grad_mask=True):
    """Compare gradients in float64 for `channels` channels per group.

    Args:
        channels: channels per group.
        grad_input, grad_offset, grad_mask: which gradients to compute and
            compare; disabled ones are skipped.
    """
    _check_backward('check_backward_equal_with_pytorch_double', 'double', True,
                    channels, grad_input, grad_offset, grad_mask)


def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_offset=True, grad_mask=True):
    """Compare gradients in float32 for `channels` channels per group.

    Args:
        channels: channels per group.
        grad_input, grad_offset, grad_mask: which gradients to compute and
            compare; disabled ones are skipped.
    """
    _check_backward('check_backward_equal_with_pytorch_float', 'float', False,
                    channels, grad_input, grad_offset, grad_mask)


@torch.no_grad()
def check_time_cost(im2col_step=128):
    """Print the average CUDA forward time for a 512 x 64 x 64 input.

    Runs `repeat` untimed warm-up calls, synchronises the device, then times
    `repeat` further calls and prints the mean seconds per call.

    Args:
        im2col_step: value forwarded to the CUDA operator.
    """
    batch = 512
    height, width = 64, 64
    feat, offset, mask = _random_inputs(batch, height, width, M, D)
    print(
        f'>>> time cost: im2col_step {im2col_step}; input {feat.shape}; points {P} ')

    repeat = 100
    # Warm-up so one-time start-up costs are not included in the measurement.
    for _ in range(repeat):
        _run_cuda(feat, offset, mask, M, D, 1.0, im2col_step)
    torch.cuda.synchronize()

    # perf_counter is monotonic and intended for measuring short intervals.
    start = time.perf_counter()
    for _ in range(repeat):
        _run_cuda(feat, offset, mask, M, D, 1.0, im2col_step)
    torch.cuda.synchronize()
    print(f'foward time cost: {(time.perf_counter() - start) / repeat}')


if __name__ == '__main__':
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()

    channel_sweep = [1, 16, 30, 32, 64, 71, 1025]
    for channels in channel_sweep:
        check_backward_equal_with_pytorch_double(channels, True, True, True)
    for channels in channel_sweep:
        check_backward_equal_with_pytorch_float(channels, True, True, True)

    for i in range(3):
        im2col_step = 128 * (2 ** i)
        check_time_cost(im2col_step)