|
import torch |
|
from kornia.color import rgb_to_grayscale |
|
from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori |
|
|
|
from .sift import SIFT |
|
|
|
|
|
class DoGHardNet(SIFT):
    """DoG (SIFT) keypoint detector paired with HardNet descriptors.

    Keypoint detection is inherited from :class:`SIFT`; the SIFT descriptors
    are replaced by 128-D HardNet descriptors computed on local affine frames
    (LAFs) built around each detected keypoint.
    """

    required_data_keys = ["image"]

    def __init__(self, **conf):
        super().__init__(**conf)
        # Pretrained HardNet (True => load pretrained weights), wrapped so it
        # can be evaluated on LAFs; kept in eval mode since it is frozen.
        self.laf_desc = LAFDescriptor(HardNet(True)).eval()

    def forward(self, data: dict) -> dict:
        """Detect keypoints and compute HardNet descriptors.

        Args:
            data: dict with key "image" (B x C x H x W tensor; RGB is
                converted to grayscale) and optionally "image_size"
                (B x 2 tensor of (width, height) per image, used to crop
                away batch padding before extraction).

        Returns:
            dict of batched tensors: the per-image outputs of
            ``extract_single_image`` plus "descriptors" of shape
            B x N x 128, all moved to the input device.
        """
        image = data["image"]
        if image.shape[1] == 3:
            image = rgb_to_grayscale(image)
        device = image.device
        self.laf_desc = self.laf_desc.to(device)
        self.laf_desc.descriptor = self.laf_desc.descriptor.eval()

        # Optional per-image (width, height); None means no padding to crop.
        # NOTE: previously this was eagerly converted with .long() but the
        # converted tensor was never used — the raw tensor suffices here.
        im_size = data.get("image_size")

        pred = []
        for k in range(len(image)):
            img = image[k]
            if im_size is not None:
                # Crop padding introduced when batching variable-size images.
                w, h = im_size[k]
                img = img[:, : h.to(torch.int32), : w.to(torch.int32)]
            p = self.extract_single_image(img)
            # Build LAFs from keypoint centers, scales (x6, the conventional
            # measurement-region multiplier), and orientations in degrees.
            lafs = laf_from_center_scale_ori(
                p["keypoints"].reshape(1, -1, 2),
                6.0 * p["scales"].reshape(1, -1, 1, 1),
                torch.rad2deg(p["oris"]).reshape(1, -1, 1),
            ).to(device)
            p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
            pred.append(p)

        # Stack per-image results into batched tensors (assumes each image
        # yields the same number of keypoints, e.g. a fixed max_num_keypoints).
        return {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
|
|