from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import time

import torch
from torch.autograd import gradcheck

from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch

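# Consistency and benchmark checks for the DCNv3 CUDA operator against the
# reference implementation dcnv3_core_pytorch. All tensors use the
# channels-last (N, H, W, M*D) layout: M groups of D channels each, with P
# sampling points per group, each point carrying an (x, y) offset and a
# normalized modulation weight.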
H_in, W_in = 8, 8   # input spatial size
N, M, D = 2, 4, 16  # batch size, groups, channels per group
Kh, Kw = 3, 3       # sampling kernel size
remove_center = False
P = Kh * Kw - remove_center  # sampling points per group
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
# Standard convolution output-size formula.
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

torch.manual_seed(3)


@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    # Random positive modulation weights, normalized so the P points of each
    # group sum to 1.
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    output_pytorch = dcnv3_core_pytorch(
        input.double(),
        offset.double(),
        mask.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
        M, D, offset_scale, remove_center).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input.double(),
        offset.double(),
        mask.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
        M, D, offset_scale,
        im2col_step, remove_center).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    print('>>> forward double')
    print(f'* {fwdok} check_forward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    output_pytorch = dcnv3_core_pytorch(
        input,
        offset,
        mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
        M, D, offset_scale, remove_center).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input,
        offset,
        mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
        M, D, offset_scale,
        im2col_step, remove_center).detach().cpu()

    # Looser tolerances: float32 accumulation order differs between the paths.
    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    print('>>> forward float')
    print(f'* {fwdok} check_forward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


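# The backward checks run the same configuration through both implementations,
# backprop a sum() loss, and compare the resulting input/offset/mask
# gradients. N, M, H_out, and W_out are shadowed locally so the per-group
# channel count D can be swept independently of the forward checks.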
def check_backward_equal_with_pytorch_double(channels=4, grad_input=True,
                                             grad_offset=True, grad_mask=True):
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(
        input0.double(),
        offset0.double(),
        mask0.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
        M, D, offset_scale, remove_center)
    output_pytorch.sum().backward()

    # Detached copies share storage but carry an independent autograd graph,
    # so both implementations see identical inputs.
    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input1.double(),
        offset1.double(),
        mask1.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
        M, D, offset_scale,
        im2col_step, remove_center)
    output_cuda.sum().backward()

    print(f'>>> backward double: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() /
                   input0.grad.abs()).max()
    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() /
                   mask0.grad.abs()).max()
    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_float(channels=4, grad_input=True,
                                            grad_offset=True, grad_mask=True):
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(
        input0,
        offset0,
        mask0,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
        M, D, offset_scale, remove_center)
    output_pytorch.sum().backward()

    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input1,
        offset1,
        mask1,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
        M, D, offset_scale,
        im2col_step, remove_center)
    output_cuda.sum().backward()

    print(f'>>> backward float: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() /
                   input0.grad.abs()).max()
    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() /
                   mask0.grad.abs()).max()
    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


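# A numerical-gradient sketch using torch.autograd.gradcheck; the name
# check_gradient_numerical and the tiny shapes below are illustrative
# additions, not part of the original checks. gradcheck perturbs every tensor
# element, so it is only practical at very small sizes; non-tensor entries in
# the input tuple are forwarded to the op unchanged.
def check_gradient_numerical(channels=4, grad_input=True, grad_offset=True,
                             grad_mask=True):
    N, M = 1, 2
    H_in, W_in = 4, 4
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    D = channels
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)
    input.requires_grad = grad_input
    offset.requires_grad = grad_offset
    mask.requires_grad = grad_mask

    im2col_step = 1
    gradok = gradcheck(
        DCNv3Function.apply,
        (input.double(), offset.double(), mask.double(),
         Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
         M, D, offset_scale, im2col_step, remove_center))
    print(f'* {gradok} check_gradient_numerical(D={channels})')
    # Usage, e.g. from __main__: check_gradient_numerical(4)

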
@torch.no_grad()
def check_time_cost(im2col_step=128):
    N = 512
    H_in, W_in = 64, 64
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)
    print(f'>>> time cost: im2col_step {im2col_step}; '
          f'input {input.shape}; points {P}')
    repeat = 100
    # Warmup pass so one-time kernel and allocator costs are excluded.
    for _ in range(repeat):
        output_cuda = DCNv3Function.apply(
            input,
            offset,
            mask,
            Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
            M, D, 1.0,
            im2col_step, remove_center)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        output_cuda = DCNv3Function.apply(
            input,
            offset,
            mask,
            Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
            M, D, 1.0,
            im2col_step, remove_center)
    torch.cuda.synchronize()
    print(f'forward time cost: {(time.time() - start) / repeat}')


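# A hedged alternative to the wall-clock loop above: torch.cuda.Event measures
# elapsed time on the device itself rather than host time around synchronize().
# check_time_cost_cuda_events is an illustrative addition, not part of the
# original benchmark.
@torch.no_grad()
def check_time_cost_cuda_events(im2col_step=128, repeat=100):
    N = 512
    H_in, W_in = 64, 64
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    args = (Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
            M, D, 1.0, im2col_step, remove_center)
    for _ in range(repeat):  # warmup
        DCNv3Function.apply(input, offset, mask, *args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeat):
        DCNv3Function.apply(input, offset, mask, *args)
    end.record()
    torch.cuda.synchronize()
    # elapsed_time returns milliseconds; report seconds per call.
    print(f'forward time cost (cuda events): '
          f'{start.elapsed_time(end) / repeat / 1e3:.6f}s')

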
if __name__ == '__main__':
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_double(channels, True, True, True)
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_float(channels, True, True, True)
    for i in range(3):
        im2col_step = 128 * (2 ** i)
        check_time_cost(im2col_step)