# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import pdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

""" Different samplers, each specifying how to sample pixels for the AP loss.
"""


class FullSampler(nn.Module):
    """ all pixels are selected
        - feats: keypoint descriptors
        - confs: reliability values
    """

    def __init__(self):
        nn.Module.__init__(self)
        self.mode = "bilinear"
        self.padding = "zeros"

    @staticmethod
    def _aflow_to_grid(aflow):
        H, W = aflow.shape[2:]
        grid = aflow.permute(0, 2, 3, 1).clone()
        grid[:, :, :, 0] *= 2 / (W - 1)
        grid[:, :, :, 1] *= 2 / (H - 1)
        grid -= 1
        grid[torch.isnan(grid)] = 9e9  # invalids
        return grid
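
    # Worked example of the normalization above (added comment, not in the
    # original code): a target column x in [0, W-1] is mapped to 2*x/(W-1) - 1,
    # so x = 0 becomes -1 and x = W-1 becomes +1, the normalized coordinate
    # range expected by F.grid_sample. NaN entries of aflow (unknown pixels)
    # are pushed far outside [-1, 1] so that zero padding marks them invalid.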

    def _warp(self, feats, confs, aflow):
        if isinstance(aflow, tuple):
            return aflow  # result was precomputed
        feat1, feat2 = feats
        conf1, conf2 = confs if confs else (None, None)

        B, two, H, W = aflow.shape
        D = feat1.shape[1]
        assert feat1.shape == feat2.shape == (B, D, H, W)  # D = 128, B = batch
        assert conf1.shape == conf2.shape == (B, 1, H, W) if confs else True

        # warp img2 to img1
        grid = self._aflow_to_grid(aflow)
        ones2 = feat2.new_ones(feat2[:, 0:1].shape)
        feat2to1 = F.grid_sample(feat2, grid, mode=self.mode, padding_mode=self.padding)
        mask2to1 = F.grid_sample(ones2, grid, mode="nearest", padding_mode="zeros")
        conf2to1 = (
            F.grid_sample(conf2, grid, mode=self.mode, padding_mode=self.padding)
            if confs
            else None
        )
        return feat2to1, mask2to1.byte(), conf2to1
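
    # Illustrative note (added comment, not in the original code): after _warp,
    # feat2to1[b, :, y, x] is the descriptor of img2 sampled at the position
    # aflow[b, :, y, x], i.e. img2 features expressed in img1's pixel grid,
    # and mask2to1 is 1 wherever that target position fell inside img2.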

    def _warp_positions(self, aflow):
        B, two, H, W = aflow.shape
        assert two == 2

        Y = torch.arange(H, device=aflow.device)
        X = torch.arange(W, device=aflow.device)
        XY = torch.stack(torch.meshgrid(Y, X)[::-1], dim=0)
        XY = XY[None].expand(B, 2, H, W).float()

        grid = self._aflow_to_grid(aflow)
        XY2 = F.grid_sample(XY, grid, mode="bilinear", padding_mode="zeros")
        return XY, XY2


class SubSampler(FullSampler):
    """ pixels are selected on a uniformly spaced grid """

    def __init__(self, border, subq, subd, perimage=False):
        FullSampler.__init__(self)
        assert subq % subd == 0, "subq must be a multiple of subd"
        self.sub_q = subq
        self.sub_d = subd
        self.border = border
        self.perimage = perimage

    def __repr__(self):
        return "SubSampler(border=%d, subq=%d, subd=%d, perimage=%d)" % (
            self.border,
            self.sub_q,
            self.sub_d,
            self.perimage,
        )

    def __call__(self, feats, confs, aflow):
        feat1, conf1 = feats[0], (confs[0] if confs else None)
        # warp img2 to img1 coordinates using the optical flow
        feat2, mask2, conf2 = self._warp(feats, confs, aflow)

        # subsample img1
        slq = slice(self.border, -self.border or None, self.sub_q)
        feat1 = feat1[:, :, slq, slq]
        conf1 = conf1[:, :, slq, slq] if confs else None
        # subsample img2
        sld = slice(self.border, -self.border or None, self.sub_d)
        feat2 = feat2[:, :, sld, sld]
        mask2 = mask2[:, :, sld, sld]
        conf2 = conf2[:, :, sld, sld] if confs else None

        B, D, Hq, Wq = feat1.shape
        B, D, Hd, Wd = feat2.shape

        # compute ground-truth
        if self.perimage or self.sub_q != self.sub_d:
            # compute ground-truth by comparing pixel indices
            f = feats[0][0:1, 0] if self.perimage else feats[0][:, 0]
            idxs = torch.arange(f.numel(), dtype=torch.int64, device=feat1.device).view(
                f.shape
            )
            idxs1 = idxs[:, slq, slq].reshape(-1, Hq * Wq)
            idxs2 = idxs[:, sld, sld].reshape(-1, Hd * Wd)
            if self.perimage:
                gt = idxs1[0].view(-1, 1) == idxs2[0].view(1, -1)
                gt = gt[None, :, :].expand(B, Hq * Wq, Hd * Wd)
            else:
                gt = idxs1.view(-1, 1) == idxs2.view(1, -1)
        else:
            gt = torch.eye(
                feat1[:, 0].numel(), dtype=torch.uint8, device=feat1.device
            )  # always binary for AP loss

        # compute all images together
        queries = feat1.reshape(B, D, -1)  # B x D x (Hq x Wq)
        database = feat2.reshape(B, D, -1)  # B x D x (Hd x Wd)
        if self.perimage:
            queries = queries.transpose(1, 2)  # B x (Hq x Wq) x D
            scores = torch.bmm(queries, database)  # B x (Hq x Wq) x (Hd x Wd)
        else:
            queries = queries.transpose(1, 2).reshape(-1, D)  # (B x Hq x Wq) x D
            database = database.transpose(1, 0).reshape(D, -1)  # D x (B x Hd x Wd)
            scores = torch.matmul(queries, database)  # (B x Hq x Wq) x (B x Hd x Wd)

        # compute reliability
        qconf = (conf1 + conf2) / 2 if confs else None

        assert gt.shape == scores.shape
        return scores, gt, mask2, qconf
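
# Usage sketch for SubSampler (illustrative only; the parameter values below
# are arbitrary assumptions, not taken from the original training configs):
#   sampler = SubSampler(border=16, subq=8, subd=8)
#   scores, gt, mask, qconf = sampler((feat1, feat2), (conf1, conf2), aflow)
# With subq == subd and perimage=False, `gt` is simply the identity matrix:
# each query has exactly one positive, the descriptor warped from img2 at the
# same grid location.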


class NghSampler(FullSampler):
    """ all pixels in a small neighborhood """

    def __init__(self, ngh, subq=1, subd=1, ignore=1, border=None):
        FullSampler.__init__(self)
        assert 0 <= ignore < ngh
        self.ngh = ngh
        self.ignore = ignore
        assert subd <= ngh
        self.sub_q = subq
        self.sub_d = subd
        if border is None:
            border = ngh
        assert border >= ngh, "border has to be at least ngh"
        self.border = border

    def __repr__(self):
        return "NghSampler(ngh=%d, subq=%d, subd=%d, ignore=%d, border=%d)" % (
            self.ngh,
            self.sub_q,
            self.sub_d,
            self.ignore,
            self.border,
        )

    def trans(self, arr, i, j):
        s = lambda i: slice(self.border + i, i - self.border or None, self.sub_q)
        return arr[:, :, s(j), s(i)]

    def __call__(self, feats, confs, aflow):
        feat1, conf1 = feats[0], (confs[0] if confs else None)
        # warp img2 to img1 coordinates using the optical flow
        feat2, mask2, conf2 = self._warp(feats, confs, aflow)

        qfeat = self.trans(feat1, 0, 0)
        qconf = (
            (self.trans(conf1, 0, 0) + self.trans(conf2, 0, 0)) / 2 if confs else None
        )
        mask2 = self.trans(mask2, 0, 0)
        scores_at = lambda i, j: (qfeat * self.trans(feat2, i, j)).sum(dim=1)

        # compute scores for all neighbors
        B, D = feat1.shape[:2]
        min_d = self.ignore ** 2
        max_d = self.ngh ** 2
        rad = (self.ngh // self.sub_d) * self.ngh  # make it an integer multiple
        negs = []
        offsets = []
        for j in range(-rad, rad + 1, self.sub_d):
            for i in range(-rad, rad + 1, self.sub_d):
                if not (min_d < i * i + j * j <= max_d):
                    continue  # out of scope
                offsets.append((i, j))  # Note: this list is just for debugging
                negs.append(scores_at(i, j))

        scores = torch.stack([scores_at(0, 0)] + negs, dim=-1)
        gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
        gt[..., 0] = 1  # only the center point is positive
        return scores, gt, mask2, qconf
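
# Usage sketch for NghSampler (illustrative; the parameter values are assumptions):
#   sampler = NghSampler(ngh=7, subq=4, subd=1, ignore=1, border=16)
#   scores, gt, mask, qconf = sampler((feat1, feat2), (conf1, conf2), aflow)
# Column 0 of `scores` is the similarity of each query with its warped
# counterpart at offset (0, 0); the remaining columns are neighbors at a pixel
# distance d with ignore < d <= ngh, all labeled negative in `gt`.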


class FarNearSampler(FullSampler):
    """ Sample pixels from *both* a small neighborhood *and* far-away pixels.

    How does it work?
    1) Queries are sampled from img1,
        - at least `border` pixels from the borders and
        - on a grid with step = `subq`
    2) Close database pixels
        - from the corresponding image (img2),
        - within a `ngh` distance radius
        - on a grid with step = `subd_ngh`
        - ignored if their distance to the query is > 0 and <= `ignore`
    3) Far-away database pixels
        - from all batch images in `img2`
        - at least `border` pixels from the borders
        - on a grid with step = `subd_far`
    (see the usage sketch after this class)
    """

    def __init__(
        self, subq, ngh, subd_ngh, subd_far, border=None, ignore=1, maxpool_ngh=False
    ):
        FullSampler.__init__(self)
        border = border or ngh
        assert ignore < ngh < subd_far, "the neighborhood needs to be smaller than the far step"
        self.close_sampler = NghSampler(
            ngh=ngh, subq=subq, subd=subd_ngh, ignore=not maxpool_ngh, border=border
        )
        self.faraway_sampler = SubSampler(border=border, subq=subq, subd=subd_far)
        self.maxpool_ngh = maxpool_ngh

    def __repr__(self):
        c, f = self.close_sampler, self.faraway_sampler
        res = "FarNearSampler(subq=%d, ngh=%d" % (c.sub_q, c.ngh)
        res += ", subd_ngh=%d, subd_far=%d" % (c.sub_d, f.sub_d)
        res += ", border=%d, ign=%d" % (f.border, c.ignore)
        res += ", maxpool_ngh=%d" % self.maxpool_ngh
        return res + ")"

    def __call__(self, feats, confs, aflow):
        # precompute the warp of img2 into img1 coordinates once,
        # then share it between both samplers
        aflow = self._warp(feats, confs, aflow)

        # sample neighborhood pixels
        scores1, gt1, msk1, conf1 = self.close_sampler(feats, confs, aflow)
        scores1, gt1 = scores1.view(-1, scores1.shape[-1]), gt1.view(-1, gt1.shape[-1])
        if self.maxpool_ngh:
            # we consider all scores from the neighborhood as potential positives
            scores1, self._cached_maxpool_ngh = scores1.max(dim=1, keepdim=True)
            gt1 = gt1[:, 0:1]

        # sample far-away pixels
        scores2, gt2, msk2, conf2 = self.faraway_sampler(feats, confs, aflow)
        # assert (msk1 == msk2).all()
        # assert (conf1 == conf2).all()
        return (
            torch.cat((scores1, scores2), dim=1),
            torch.cat((gt1, gt2), dim=1),
            msk1,
            conf1 if confs else None,
        )
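
# Usage sketch for FarNearSampler (illustrative; the parameter values are
# assumptions, not taken from the original configs):
#   sampler = FarNearSampler(subq=16, ngh=7, subd_ngh=2, subd_far=16, border=16)
#   scores, gt, mask, qconf = sampler((feat1, feat2), (conf1, conf2), aflow)
# Each row of `scores` concatenates the neighborhood scores from NghSampler
# with the far-away scores from SubSampler for the same query pixel.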


class NghSampler2(nn.Module):
    """ Similar to NghSampler, but it does not warp the second image.

    Distance to GT  =>  0 ... pos_d ... neg_d ... ngh
    Pixel label     =>  + + + + + +   0   0   - - - - - - -

    Subsampling on the query side: if subq > 0, pixels are taken on a regular grid;
                                   if subq < 0, pixels are taken at random.
    In both cases, the number of query points is W*H / subq**2.
    """

    def __init__(
        self,
        ngh,
        subq=1,
        subd=1,
        pos_d=0,
        neg_d=2,
        border=None,
        maxpool_pos=True,
        subd_neg=0,
    ):
        nn.Module.__init__(self)
        assert 0 <= pos_d < neg_d <= (ngh if ngh else 99)
        self.ngh = ngh
        self.pos_d = pos_d
        self.neg_d = neg_d
        assert subd <= ngh or ngh == 0
        assert subq != 0
        self.sub_q = subq
        self.sub_d = subd
        self.sub_d_neg = subd_neg
        if border is None:
            border = ngh
        assert border >= ngh, "border has to be at least ngh"
        self.border = border
        self.maxpool_pos = maxpool_pos
        self.precompute_offsets()

    def precompute_offsets(self):
        pos_d2 = self.pos_d ** 2
        neg_d2 = self.neg_d ** 2
        rad2 = self.ngh ** 2
        rad = (self.ngh // self.sub_d) * self.ngh  # make it an integer multiple
        pos = []
        neg = []
        for j in range(-rad, rad + 1, self.sub_d):
            for i in range(-rad, rad + 1, self.sub_d):
                d2 = i * i + j * j
                if d2 <= pos_d2:
                    pos.append((i, j))
                elif neg_d2 <= d2 <= rad2:
                    neg.append((i, j))
        self.register_buffer("pos_offsets", torch.LongTensor(pos).view(-1, 2).t())
        self.register_buffer("neg_offsets", torch.LongTensor(neg).view(-1, 2).t())

    def gen_grid(self, step, aflow):
        B, two, H, W = aflow.shape
        dev = aflow.device
        b1 = torch.arange(B, device=dev)
        if step > 0:
            # regular grid
            x1 = torch.arange(self.border, W - self.border, step, device=dev)
            y1 = torch.arange(self.border, H - self.border, step, device=dev)
            H1, W1 = len(y1), len(x1)
            x1 = x1[None, None, :].expand(B, H1, W1).reshape(-1)
            y1 = y1[None, :, None].expand(B, H1, W1).reshape(-1)
            b1 = b1[:, None, None].expand(B, H1, W1).reshape(-1)
            shape = (B, H1, W1)
        else:
            # randomly spread
            n = (H - 2 * self.border) * (W - 2 * self.border) // step ** 2
            x1 = torch.randint(self.border, W - self.border, (n,), device=dev)
            y1 = torch.randint(self.border, H - self.border, (n,), device=dev)
            x1 = x1[None, :].expand(B, n).reshape(-1)
            y1 = y1[None, :].expand(B, n).reshape(-1)
            b1 = b1[:, None].expand(B, n).reshape(-1)
            shape = (B, n)
        return b1, y1, x1, shape
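
    # Note (added for clarity, not in the original file): the sign of `step`
    # selects the sampling mode described in the class docstring. A positive
    # step yields a regular grid inside the border; a negative step draws
    # roughly (H-2*border)*(W-2*border)/step**2 random positions, the same
    # random positions being reused for every image of the batch.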

    def forward(self, feats, confs, aflow, **kw):
        B, two, H, W = aflow.shape
        assert two == 2
        feat1, conf1 = feats[0], (confs[0] if confs else None)
        feat2, conf2 = feats[1], (confs[1] if confs else None)

        # positions in the first image
        b1, y1, x1, shape = self.gen_grid(self.sub_q, aflow)

        # sample features from the first image
        feat1 = feat1[b1, :, y1, x1]
        qconf = conf1[b1, :, y1, x1].view(shape) if confs else None

        # sample GT from the second image
        b2 = b1
        xy2 = (aflow[b1, :, y1, x1] + 0.5).long().t()
        mask = (0 <= xy2[0]) * (0 <= xy2[1]) * (xy2[0] < W) * (xy2[1] < H)
        mask = mask.view(shape)

        def clamp(xy):
            torch.clamp(xy[0], 0, W - 1, out=xy[0])
            torch.clamp(xy[1], 0, H - 1, out=xy[1])
            return xy

        # compute positive scores
        xy2p = clamp(xy2[:, None, :] + self.pos_offsets[:, :, None])
        pscores = (feat1[None, :, :] * feat2[b2, :, xy2p[1], xy2p[0]]).sum(dim=-1).t()
        # xy1p = clamp(torch.stack((x1, y1))[:, None, :] + self.pos_offsets[:, :, None])
        # grid = FullSampler._aflow_to_grid(aflow)
        # feat2p = F.grid_sample(feat2, grid, mode='bilinear', padding_mode='border')
        # pscores = (feat1[None, :, :] * feat2p[b1, :, xy1p[1], xy1p[0]]).sum(dim=-1).t()
        if self.maxpool_pos:
            pscores, pos = pscores.max(dim=1, keepdim=True)
            if confs:
                sel = clamp(xy2 + self.pos_offsets[:, pos.view(-1)])
                qconf = (qconf + conf2[b2, :, sel[1], sel[0]].view(shape)) / 2

        # compute negative scores
        xy2n = clamp(xy2[:, None, :] + self.neg_offsets[:, :, None])
        nscores = (feat1[None, :, :] * feat2[b2, :, xy2n[1], xy2n[0]]).sum(dim=-1).t()

        if self.sub_d_neg:
            # add distractors from a grid
            b3, y3, x3, _ = self.gen_grid(self.sub_d_neg, aflow)
            distractors = feat2[b3, :, y3, x3]
            dscores = torch.matmul(feat1, distractors.t())
            del distractors

            # remove scores that correspond to positives or null pixels
            dis2 = (x3 - xy2[0][:, None]) ** 2 + (y3 - xy2[1][:, None]) ** 2
            dis2 += (b3 != b2[:, None]).long() * self.neg_d ** 2
            dscores[dis2 < self.neg_d ** 2] = 0

            scores = torch.cat((pscores, nscores, dscores), dim=1)
        else:
            # concatenate everything
            scores = torch.cat((pscores, nscores), dim=1)

        gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
        gt[:, : pscores.shape[1]] = 1
        return scores, gt, mask, qconf
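

# Minimal smoke test (illustrative sketch, not part of the original library).
# It builds random unit-normalized descriptors and an identity optical flow
# (every pixel of img1 maps to the same location in img2) and runs NghSampler2;
# the parameter values below are arbitrary assumptions.
if __name__ == "__main__":
    B, D, H, W = 2, 128, 64, 64
    feat1 = F.normalize(torch.randn(B, D, H, W), p=2, dim=1)
    feat2 = feat1.clone()
    conf1 = torch.rand(B, 1, H, W)
    conf2 = conf1.clone()

    # identity absolute flow: aflow[b, 0, y, x] = x and aflow[b, 1, y, x] = y
    gy, gx = torch.meshgrid(torch.arange(H), torch.arange(W))
    aflow = torch.stack((gx.float(), gy.float()), dim=0)[None].expand(B, 2, H, W).contiguous()

    sampler = NghSampler2(ngh=7, subq=8, subd=1, pos_d=3, neg_d=5,
                          border=16, maxpool_pos=True, subd_neg=16)
    scores, gt, mask, qconf = sampler((feat1, feat2), (conf1, conf2), aflow)
    print(scores.shape, gt.shape, mask.shape, qconf.shape)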