Spaces:
Sleeping
Sleeping
File size: 2,201 Bytes
899c526 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import torch
import cuda_corr
class CorrLayer(torch.autograd.Function):
    """Autograd wrapper around the ``cuda_corr`` correlation kernels.

    ``forward`` evaluates the correlation for every edge ``(ii[k], jj[k])``.
    ``backward`` may randomly subsample those edges (see ``dropout``) before
    computing gradients for ``fmap1`` and ``fmap2``. ``coords``, ``ii``,
    ``jj``, ``radius`` and ``dropout`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, fmap1, fmap2, coords, ii, jj, radius, dropout):
        """ forward correlation

        Args:
            fmap1, fmap2: feature maps passed straight to ``cuda_corr.forward``.
            coords: lookup coordinates; dimension 1 indexes edges (one entry
                per element of ``ii``), as the slicing in ``backward`` shows.
            ii, jj: edge index tensors of equal length.
            radius: lookup radius forwarded to the kernel.
            dropout: keep-probability applied to edges in ``backward`` only;
                values >= 1 disable subsampling.

        Returns:
            The single tensor produced by ``cuda_corr.forward``.
        """
        ctx.save_for_backward(fmap1, fmap2, coords, ii, jj)
        ctx.radius = radius
        ctx.dropout = dropout
        corr, = cuda_corr.forward(fmap1, fmap2, coords, ii, jj, radius)
        return corr

    @staticmethod
    def backward(ctx, grad):
        """ backward correlation

        When ``ctx.dropout < 1``, each edge is kept independently with
        probability ``ctx.dropout``, and gradients flow only through the kept
        edges. They are not rescaled. Returns one value per ``forward`` input;
        only ``fmap1`` and ``fmap2`` receive gradients.
        """
        fmap1, fmap2, coords, ii, jj = ctx.saved_tensors

        if ctx.dropout < 1:
            # Build the mask on the edges' own device rather than the
            # hard-coded current "cuda" device, so it always matches the
            # tensors it indexes (e.g. when they live on a non-default GPU).
            keep = torch.rand(len(ii), device=ii.device) < ctx.dropout
            coords = coords[:, keep]
            grad = grad[:, keep]
            ii = ii[keep]
            jj = jj[keep]

        fmap1_grad, fmap2_grad = \
            cuda_corr.backward(fmap1, fmap2, coords, ii, jj, grad, ctx.radius)

        # forward took 7 inputs: (fmap1, fmap2, coords, ii, jj, radius, dropout)
        return fmap1_grad, fmap2_grad, None, None, None, None, None
class PatchLayer(torch.autograd.Function):
    """Autograd wrapper around the ``cuda_corr`` patch-extraction kernels.

    Only ``net`` receives a gradient. ``coords`` and ``radius`` are treated
    as constants by ``backward``.
    """

    @staticmethod
    def forward(ctx, net, coords, radius):
        """Run ``cuda_corr.patchify_forward`` and keep what backward needs."""
        ctx.save_for_backward(net, coords)
        ctx.radius = radius
        (patches,) = cuda_corr.patchify_forward(net, coords, radius)
        return patches

    @staticmethod
    def backward(ctx, grad):
        """Send the incoming gradient back to ``net`` via ``cuda_corr.patchify_backward``."""
        saved_net, saved_coords = ctx.saved_tensors
        (net_grad,) = cuda_corr.patchify_backward(
            saved_net, saved_coords, grad, ctx.radius)
        # forward took (net, coords, radius); only net is differentiated
        return net_grad, None, None
def patchify(net, coords, radius, mode='bilinear'):
    """Extract patches from ``net`` around each point in ``coords``.

    The raw windows come from ``PatchLayer``. With ``mode='bilinear'``, the
    four ``(2*radius + 1)``-sized sub-windows offset by 0/1 along the last two
    axes are blended using the fractional part of ``coords``. This slicing
    only lines up when the raw windows have side ``2*radius + 2``. Any other
    ``mode`` returns the raw windows unchanged.

    Args:
        net: source tensor passed to ``PatchLayer``.
        coords: point coordinates whose last dimension has two components;
            the first weights shifts along the last patch axis, the second
            along the one before it.
        radius: patch radius.
        mode: ``'bilinear'`` to interpolate; anything else skips it.
    """
    patches = PatchLayer.apply(net, coords, radius)
    if mode != 'bilinear':
        return patches

    frac = (coords - coords.floor()).to(net.device)
    # Add singleton axes so the weights broadcast over the trailing patch dims.
    fx, fy = frac[:, :, None, None, None].unbind(dim=-1)
    gx = 1 - fx
    gy = 1 - fy
    size = 2 * radius + 1

    p00 = gy * gx * patches[..., :size, :size]
    p01 = gy * fx * patches[..., :size, 1:]
    p10 = fy * gx * patches[..., 1:, :size]
    p11 = fy * fx * patches[..., 1:, 1:]
    return p00 + p01 + p10 + p11
def corr(fmap1, fmap2, coords, ii, jj, radius=1, dropout=1):
    """Differentiable correlation lookup backed by ``CorrLayer``.

    ``dropout`` affects only the backward pass. When it is below 1, each edge
    ``(ii[k], jj[k])`` contributes to the gradient with probability
    ``dropout``. The default of 1 keeps every edge.
    """
    inputs = (fmap1, fmap2, coords, ii, jj, radius, dropout)
    return CorrLayer.apply(*inputs)
|