diff --git a/app.py b/app.py
index fa9289672a72747a831206e782646df4e8fa37f9..96c6dac4ed89a6bd1783c81b6c145fb4315aa4c1 100644
--- a/app.py
+++ b/app.py
@@ -9,9 +9,10 @@ from extra_utils.utils import (
     match_features,
     get_model,
     get_feature_model,
-    display_matches
+    display_matches,
 )
 
+
 def run_matching(
     match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
 ):
@@ -277,7 +278,7 @@ def run(config):
                 matcher_info,
             ]
             button_reset.click(fn=ui_reset_state, inputs=inputs, outputs=reset_outputs)
-        
+
     app.launch(share=False)
 
 
diff --git a/third_party/ALIKE/alike.py b/third_party/ALIKE/alike.py
index 303616d52581efce0ae0eb86af70f5ea8984909d..b975f806f3e0f593a3564ae52d9d08187f514b34 100644
--- a/third_party/ALIKE/alike.py
+++ b/third_party/ALIKE/alike.py
@@ -12,46 +12,89 @@ from soft_detect import DKD
 import time
 
 configs = {
-    'alike-t': {'c1': 8, 'c2': 16, 'c3': 32, 'c4': 64, 'dim': 64, 'single_head': True, 'radius': 2,
-                'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-t.pth')},
-    'alike-s': {'c1': 8, 'c2': 16, 'c3': 48, 'c4': 96, 'dim': 96, 'single_head': True, 'radius': 2,
-                'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-s.pth')},
-    'alike-n': {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True, 'radius': 2,
-                'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-n.pth')},
-    'alike-l': {'c1': 32, 'c2': 64, 'c3': 128, 'c4': 128, 'dim': 128, 'single_head': False, 'radius': 2,
-                'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-l.pth')},
+    "alike-t": {
+        "c1": 8,
+        "c2": 16,
+        "c3": 32,
+        "c4": 64,
+        "dim": 64,
+        "single_head": True,
+        "radius": 2,
+        "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-t.pth"),
+    },
+    "alike-s": {
+        "c1": 8,
+        "c2": 16,
+        "c3": 48,
+        "c4": 96,
+        "dim": 96,
+        "single_head": True,
+        "radius": 2,
+        "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-s.pth"),
+    },
+    "alike-n": {
+        "c1": 16,
+        "c2": 32,
+        "c3": 64,
+        "c4": 128,
+        "dim": 128,
+        "single_head": True,
+        "radius": 2,
+        "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-n.pth"),
+    },
+    "alike-l": {
+        "c1": 32,
+        "c2": 64,
+        "c3": 128,
+        "c4": 128,
+        "dim": 128,
+        "single_head": False,
+        "radius": 2,
+        "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-l.pth"),
+    },
 }
 
 
 class ALike(ALNet):
-    def __init__(self,
-                 # ================================== feature encoder
-                 c1: int = 32, c2: int = 64, c3: int = 128, c4: int = 128, dim: int = 128,
-                 single_head: bool = False,
-                 # ================================== detect parameters
-                 radius: int = 2,
-                 top_k: int = 500, scores_th: float = 0.5,
-                 n_limit: int = 5000,
-                 device: str = 'cpu',
-                 model_path: str = ''
-                 ):
+    def __init__(
+        self,
+        # ================================== feature encoder
+        c1: int = 32,
+        c2: int = 64,
+        c3: int = 128,
+        c4: int = 128,
+        dim: int = 128,
+        single_head: bool = False,
+        # ================================== detect parameters
+        radius: int = 2,
+        top_k: int = 500,
+        scores_th: float = 0.5,
+        n_limit: int = 5000,
+        device: str = "cpu",
+        model_path: str = "",
+    ):
         super().__init__(c1, c2, c3, c4, dim, single_head)
         self.radius = radius
         self.top_k = top_k
         self.n_limit = n_limit
         self.scores_th = scores_th
-        self.dkd = DKD(radius=self.radius, top_k=self.top_k,
-                       scores_th=self.scores_th, n_limit=self.n_limit)
+        self.dkd = DKD(
+            radius=self.radius,
+            top_k=self.top_k,
+            scores_th=self.scores_th,
+            n_limit=self.n_limit,
+        )
         self.device = device
 
-        if model_path != '':
+        if model_path != "":
             state_dict = torch.load(model_path, self.device)
             self.load_state_dict(state_dict)
             self.to(self.device)
             self.eval()
-            logging.info(f'Loaded model parameters from {model_path}')
+            logging.info(f"Loaded model parameters from {model_path}")
             logging.info(
-                f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e3}KB")
+                f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e3}KB"
+            )
 
     def extract_dense_map(self, image, ret_dict=False):
         # ====================================================
@@ -81,7 +124,10 @@ class ALike(ALNet):
         descriptor_map = torch.nn.functional.normalize(descriptor_map, p=2, dim=1)
 
         if ret_dict:
-            return {'descriptor_map': descriptor_map, 'scores_map': scores_map, }
+            return {
+                "descriptor_map": descriptor_map,
+                "scores_map": scores_map,
+            }
         else:
             return descriptor_map, scores_map
 
@@ -104,15 +150,22 @@ class ALike(ALNet):
             image = cv2.resize(image, dsize=None, fx=ratio, fy=ratio)
 
         # ==================== convert image to tensor
-        image = torch.from_numpy(image).to(self.device).to(torch.float32).permute(2, 0, 1)[None] / 255.0
+        image = (
+            torch.from_numpy(image)
+            .to(self.device)
+            .to(torch.float32)
+            .permute(2, 0, 1)[None]
+            / 255.0
+        )
 
         # ==================== extract keypoints
         start = time.time()
 
         with torch.no_grad():
             descriptor_map, scores_map = self.extract_dense_map(image)
-            keypoints, descriptors, scores, _ = self.dkd(scores_map, descriptor_map,
-                                                         sub_pixel=sub_pixel)
+            keypoints, descriptors, scores, _ = self.dkd(
+                scores_map, descriptor_map, sub_pixel=sub_pixel
+            )
             keypoints, descriptors, scores = keypoints[0], descriptors[0], scores[0]
             keypoints = (keypoints + 1) / 2 * keypoints.new_tensor([[W - 1, H - 1]])
 
@@ -124,14 +177,16 @@ class ALike(ALNet):
 
         end = time.time()
 
-        return {'keypoints': keypoints.cpu().numpy(),
-                'descriptors': descriptors.cpu().numpy(),
-                'scores': scores.cpu().numpy(),
-                'scores_map': scores_map.cpu().numpy(),
-                'time': end - start, }
+        return {
+            "keypoints": keypoints.cpu().numpy(),
+            "descriptors": descriptors.cpu().numpy(),
+            "scores": scores.cpu().numpy(),
+            "scores_map": scores_map.cpu().numpy(),
+            "time": end - start,
+        }
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import numpy as np
     from thop import profile
 
@@ -139,5 +194,5 @@ if __name__ == '__main__':
 
     image = np.random.random((640, 480, 3)).astype(np.float32)
     flops, params = profile(net, inputs=(image, 9999, False), verbose=False)
-    print('{:<30}  {:<8} GFLops'.format('Computational complexity: ', flops / 1e9))
-    print('{:<30}  {:<8} KB'.format('Number of parameters: ', params / 1e3))
+    print("{:<30}  {:<8} GFLops".format("Computational complexity: ", flops / 1e9))
+    print("{:<30}  {:<8} KB".format("Number of parameters: ", params / 1e3))
diff --git a/third_party/ALIKE/alnet.py b/third_party/ALIKE/alnet.py
index 53127063233660c7b96aa15e89aa4a8a1a340dd1..91cb7ee55e502895e7b0037f2add1a35a613cd40 100644
--- a/third_party/ALIKE/alnet.py
+++ b/third_party/ALIKE/alnet.py
@@ -5,9 +5,13 @@ from typing import Optional, Callable
 
 
 class ConvBlock(nn.Module):
-    def __init__(self, in_channels, out_channels,
-                 gate: Optional[Callable[..., nn.Module]] = None,
-                 norm_layer: Optional[Callable[..., nn.Module]] = None):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        gate: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
         super().__init__()
         if gate is None:
             self.gate = nn.ReLU(inplace=True)
@@ -31,16 +35,16 @@ class ResBlock(nn.Module):
     expansion: int = 1
 
     def __init__(
-            self,
-            inplanes: int,
-            planes: int,
-            stride: int = 1,
-            downsample: Optional[nn.Module] = None,
-            groups: int = 1,
-            base_width: int = 64,
-            dilation: int = 1,
-            gate: Optional[Callable[..., nn.Module]] = None,
-            norm_layer: Optional[Callable[..., nn.Module]] = None
+        self,
+        inplanes: int,
+        planes: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        gate: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         super(ResBlock, self).__init__()
         if gate is None:
@@ -50,7 +54,7 @@ class ResBlock(nn.Module):
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
         if groups != 1 or base_width != 64:
-            raise ValueError('ResBlock only supports groups=1 and base_width=64')
+            raise ValueError("ResBlock only supports groups=1 and base_width=64")
         if dilation > 1:
             raise NotImplementedError("Dilation > 1 not supported in ResBlock")
         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
@@ -81,9 +85,15 @@ class ResBlock(nn.Module):
 
 
 class ALNet(nn.Module):
-    def __init__(self, c1: int = 32, c2: int = 64, c3: int = 128, c4: int = 128, dim: int = 128,
-                 single_head: bool = True,
-                 ):
+    def __init__(
+        self,
+        c1: int = 32,
+        c2: int = 64,
+        c3: int = 128,
+        c4: int = 128,
+        dim: int = 128,
+        single_head: bool = True,
+    ):
         super().__init__()
 
         self.gate = nn.ReLU(inplace=True)
@@ -93,28 +103,48 @@ class ALNet(nn.Module):
 
         self.block1 = ConvBlock(3, c1, self.gate, nn.BatchNorm2d)
 
-        self.block2 = ResBlock(inplanes=c1, planes=c2, stride=1,
-                               downsample=nn.Conv2d(c1, c2, 1),
-                               gate=self.gate,
-                               norm_layer=nn.BatchNorm2d)
-        self.block3 = ResBlock(inplanes=c2, planes=c3, stride=1,
-                               downsample=nn.Conv2d(c2, c3, 1),
-                               gate=self.gate,
-                               norm_layer=nn.BatchNorm2d)
-        self.block4 = ResBlock(inplanes=c3, planes=c4, stride=1,
-                               downsample=nn.Conv2d(c3, c4, 1),
-                               gate=self.gate,
-                               norm_layer=nn.BatchNorm2d)
+        self.block2 = ResBlock(
+            inplanes=c1,
+            planes=c2,
+            stride=1,
+            downsample=nn.Conv2d(c1, c2, 1),
+            gate=self.gate,
+            norm_layer=nn.BatchNorm2d,
+        )
+        self.block3 = ResBlock(
+            inplanes=c2,
+            planes=c3,
+            stride=1,
+            downsample=nn.Conv2d(c2, c3, 1),
+            gate=self.gate,
+            norm_layer=nn.BatchNorm2d,
+        )
+        self.block4 = ResBlock(
+            inplanes=c3,
+            planes=c4,
+            stride=1,
+            downsample=nn.Conv2d(c3, c4, 1),
+            gate=self.gate,
+            norm_layer=nn.BatchNorm2d,
+        )
 
         # ================================== feature aggregation
         self.conv1 = resnet.conv1x1(c1, dim // 4)
         self.conv2 = resnet.conv1x1(c2, dim // 4)
         self.conv3 = resnet.conv1x1(c3, dim // 4)
         self.conv4 = resnet.conv1x1(dim, dim // 4)
-        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
-        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
-        self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
-        self.upsample32 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True)
+        self.upsample2 = nn.Upsample(
+            scale_factor=2, mode="bilinear", align_corners=True
+        )
+        self.upsample4 = nn.Upsample(
+            scale_factor=4, mode="bilinear", align_corners=True
+        )
+        self.upsample8 = nn.Upsample(
+            scale_factor=8, mode="bilinear", align_corners=True
+        )
+        self.upsample32 = nn.Upsample(
+            scale_factor=32, mode="bilinear", align_corners=True
+        )
 
         # ================================== detector and descriptor head
         self.single_head = single_head
@@ -153,12 +183,12 @@ class ALNet(nn.Module):
         return scores_map, descriptor_map
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     from thop import profile
 
     net = ALNet(c1=16, c2=32, c3=64, c4=128, dim=128, single_head=True)
 
     image = torch.randn(1, 3, 640, 480)
     flops, params = profile(net, inputs=(image,), verbose=False)
-    print('{:<30}  {:<8} GFLops'.format('Computational complexity: ', flops / 1e9))
-    print('{:<30}  {:<8} KB'.format('Number of parameters: ', params / 1e3))
+    print("{:<30}  {:<8} GFLops".format("Computational complexity: ", flops / 1e9))
+    print("{:<30}  {:<8} KB".format("Number of parameters: ", params / 1e3))
diff --git a/third_party/ALIKE/demo.py b/third_party/ALIKE/demo.py
index 9bfbefdd26cfeceefc75f90d1c44a7f922c624a5..a3f5130eea283404412b374c678ba3a1ae6d1c04 100644
--- a/third_party/ALIKE/demo.py
+++ b/third_party/ALIKE/demo.py
@@ -12,13 +12,13 @@ from alike import ALike, configs
 class ImageLoader(object):
     def __init__(self, filepath: str):
         self.N = 3000
-        if filepath.startswith('camera'):
+        if filepath.startswith("camera"):
             camera = int(filepath[6:])
             self.cap = cv2.VideoCapture(camera)
             if not self.cap.isOpened():
                 raise IOError(f"Can't open camera {camera}!")
-            logging.info(f'Opened camera {camera}')
-            self.mode = 'camera'
+            logging.info(f"Opened camera {camera}")
+            self.mode = "camera"
         elif os.path.exists(filepath):
             if os.path.isfile(filepath):
                 self.cap = cv2.VideoCapture(filepath)
@@ -27,34 +27,38 @@ class ImageLoader(object):
                 rate = self.cap.get(cv2.CAP_PROP_FPS)
                 self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
                 duration = self.N / rate
-                logging.info(f'Opened video {filepath}')
-                logging.info(f'Frames: {self.N}, FPS: {rate}, Duration: {duration}s')
-                self.mode = 'video'
+                logging.info(f"Opened video {filepath}")
+                logging.info(f"Frames: {self.N}, FPS: {rate}, Duration: {duration}s")
+                self.mode = "video"
             else:
-                self.images = glob.glob(os.path.join(filepath, '*.png')) + \
-                              glob.glob(os.path.join(filepath, '*.jpg')) + \
-                              glob.glob(os.path.join(filepath, '*.ppm'))
+                self.images = (
+                    glob.glob(os.path.join(filepath, "*.png"))
+                    + glob.glob(os.path.join(filepath, "*.jpg"))
+                    + glob.glob(os.path.join(filepath, "*.ppm"))
+                )
                 self.images.sort()
                 self.N = len(self.images)
-                logging.info(f'Loading {self.N} images')
-                self.mode = 'images'
+                logging.info(f"Loading {self.N} images")
+                self.mode = "images"
         else:
-            raise IOError('Error filepath (camerax/path of images/path of videos): ', filepath)
+            raise IOError(
+                "Error filepath (camerax/path of images/path of videos): ", filepath
+            )
 
     def __getitem__(self, item):
-        if self.mode == 'camera' or self.mode == 'video':
+        if self.mode == "camera" or self.mode == "video":
             if item > self.N:
                 return None
             ret, img = self.cap.read()
             if not ret:
                 raise "Can't read image from camera"
-            if self.mode == 'video':
+            if self.mode == "video":
                 self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
-        elif self.mode == 'images':
+        elif self.mode == "images":
             filename = self.images[item]
             img = cv2.imread(filename)
             if img is None:
-                raise Exception('Error reading image %s' % filename)        
+                raise Exception("Error reading image %s" % filename)
         return img
 
     def __len__(self):
@@ -99,38 +103,68 @@ class SimpleTracker(object):
         nn12 = np.argmax(sim, axis=1)
         nn21 = np.argmax(sim, axis=0)
         ids1 = np.arange(0, sim.shape[0])
-        mask = (ids1 == nn21[nn12])
+        mask = ids1 == nn21[nn12]
         matches = np.stack([ids1[mask], nn12[mask]])
         return matches.transpose()
 
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='ALike Demo.')
-    parser.add_argument('input', type=str, default='',
-                        help='Image directory or movie file or "camera0" (for webcam0).')
-    parser.add_argument('--model', choices=['alike-t', 'alike-s', 'alike-n', 'alike-l'], default="alike-t",
-                        help="The model configuration")
-    parser.add_argument('--device', type=str, default='cuda', help="Running device (default: cuda).")
-    parser.add_argument('--top_k', type=int, default=-1,
-                        help='Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)')
-    parser.add_argument('--scores_th', type=float, default=0.2,
-                        help='Detector score threshold (default: 0.2).')
-    parser.add_argument('--n_limit', type=int, default=5000,
-                        help='Maximum number of keypoints to be detected (default: 5000).')
-    parser.add_argument('--no_display', action='store_true',
-                        help='Do not display images to screen. Useful if running remotely (default: False).')
-    parser.add_argument('--no_sub_pixel', action='store_true',
-                        help='Do not detect sub-pixel keypoints (default: False).')
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="ALike Demo.")
+    parser.add_argument(
+        "input",
+        type=str,
+        default="",
+        help='Image directory or movie file or "camera0" (for webcam0).',
+    )
+    parser.add_argument(
+        "--model",
+        choices=["alike-t", "alike-s", "alike-n", "alike-l"],
+        default="alike-t",
+        help="The model configuration",
+    )
+    parser.add_argument(
+        "--device", type=str, default="cuda", help="Running device (default: cuda)."
+    )
+    parser.add_argument(
+        "--top_k",
+        type=int,
+        default=-1,
+        help="Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)",
+    )
+    parser.add_argument(
+        "--scores_th",
+        type=float,
+        default=0.2,
+        help="Detector score threshold (default: 0.2).",
+    )
+    parser.add_argument(
+        "--n_limit",
+        type=int,
+        default=5000,
+        help="Maximum number of keypoints to be detected (default: 5000).",
+    )
+    parser.add_argument(
+        "--no_display",
+        action="store_true",
+        help="Do not display images to screen. Useful if running remotely (default: False).",
+    )
+    parser.add_argument(
+        "--no_sub_pixel",
+        action="store_true",
+        help="Do not detect sub-pixel keypoints (default: False).",
+    )
     args = parser.parse_args()
 
     logging.basicConfig(level=logging.INFO)
 
     image_loader = ImageLoader(args.input)
-    model = ALike(**configs[args.model],
-                  device=args.device,
-                  top_k=args.top_k,
-                  scores_th=args.scores_th,
-                  n_limit=args.n_limit)
+    model = ALike(
+        **configs[args.model],
+        device=args.device,
+        top_k=args.top_k,
+        scores_th=args.scores_th,
+        n_limit=args.n_limit,
+    )
     tracker = SimpleTracker()
 
     if not args.no_display:
@@ -142,26 +176,26 @@ if __name__ == '__main__':
     for img in progress_bar:
         if img is None:
             break
-        
+
         img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         pred = model(img_rgb, sub_pixel=not args.no_sub_pixel)
-        kpts = pred['keypoints']
-        desc = pred['descriptors']
-        runtime.append(pred['time'])
+        kpts = pred["keypoints"]
+        desc = pred["descriptors"]
+        runtime.append(pred["time"])
 
         out, N_matches = tracker.update(img, kpts, desc)
 
-        ave_fps = (1. / np.stack(runtime)).mean()
+        ave_fps = (1.0 / np.stack(runtime)).mean()
         status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}"
         progress_bar.set_description(status)
 
         if not args.no_display:
-            cv2.setWindowTitle(args.model, args.model + ': ' + status)
+            cv2.setWindowTitle(args.model, args.model + ": " + status)
             cv2.imshow(args.model, out)
-            if cv2.waitKey(1) == ord('q'):
+            if cv2.waitKey(1) == ord("q"):
                 break
 
-    logging.info('Finished!')
+    logging.info("Finished!")
     if not args.no_display:
-        logging.info('Press any key to exit!')
+        logging.info("Press any key to exit!")
         cv2.waitKey()
diff --git a/third_party/ALIKE/hseq/eval.py b/third_party/ALIKE/hseq/eval.py
index abca625044013a0cd34a518223c32d3ec8abb8a3..1d91398740e5dee9d2968fb418fcb45febd015ba 100644
--- a/third_party/ALIKE/hseq/eval.py
+++ b/third_party/ALIKE/hseq/eval.py
@@ -6,29 +6,53 @@ import numpy as np
 from extract import extract_method
 
 use_cuda = torch.cuda.is_available()
-device = torch.device('cuda' if use_cuda else 'cpu')
-
-methods = ['d2', 'lfnet', 'superpoint', 'r2d2', 'aslfeat', 'disk',
-           'alike-n', 'alike-l', 'alike-n-ms', 'alike-l-ms']
-names = ['D2-Net(MS)', 'LF-Net(MS)', 'SuperPoint', 'R2D2(MS)', 'ASLFeat(MS)', 'DISK',
-         'ALike-N', 'ALike-L', 'ALike-N(MS)', 'ALike-L(MS)']
+device = torch.device("cuda" if use_cuda else "cpu")
+
+methods = [
+    "d2",
+    "lfnet",
+    "superpoint",
+    "r2d2",
+    "aslfeat",
+    "disk",
+    "alike-n",
+    "alike-l",
+    "alike-n-ms",
+    "alike-l-ms",
+]
+names = [
+    "D2-Net(MS)",
+    "LF-Net(MS)",
+    "SuperPoint",
+    "R2D2(MS)",
+    "ASLFeat(MS)",
+    "DISK",
+    "ALike-N",
+    "ALike-L",
+    "ALike-N(MS)",
+    "ALike-L(MS)",
+]
 
 top_k = None
 n_i = 52
 n_v = 56
-cache_dir = 'hseq/cache'
-dataset_path = 'hseq/hpatches-sequences-release'
+cache_dir = "hseq/cache"
+dataset_path = "hseq/hpatches-sequences-release"
 
 
-def generate_read_function(method, extension='ppm'):
+def generate_read_function(method, extension="ppm"):
     def read_function(seq_name, im_idx):
-        aux = np.load(os.path.join(dataset_path, seq_name, '%d.%s.%s' % (im_idx, extension, method)))
+        aux = np.load(
+            os.path.join(
+                dataset_path, seq_name, "%d.%s.%s" % (im_idx, extension, method)
+            )
+        )
         if top_k is None:
-            return aux['keypoints'], aux['descriptors']
+            return aux["keypoints"], aux["descriptors"]
         else:
-            assert ('scores' in aux)
-            ids = np.argsort(aux['scores'])[-top_k:]
-            return aux['keypoints'][ids, :], aux['descriptors'][ids, :]
+            assert "scores" in aux
+            ids = np.argsort(aux["scores"])[-top_k:]
+            return aux["keypoints"][ids, :], aux["descriptors"][ids, :]
 
     return read_function
 
@@ -39,7 +63,7 @@ def mnn_matcher(descriptors_a, descriptors_b):
     nn12 = torch.max(sim, dim=1)[1]
     nn21 = torch.max(sim, dim=0)[1]
     ids1 = torch.arange(0, sim.shape[0], device=device)
-    mask = (ids1 == nn21[nn12])
+    mask = ids1 == nn21[nn12]
     matches = torch.stack([ids1[mask], nn12[mask]])
     return matches.t().data.cpu().numpy()
 
@@ -73,7 +97,7 @@ def benchmark_features(read_feats):
         n_feats.append(keypoints_a.shape[0])
 
         # =========== compute homography
-        ref_img = cv2.imread(os.path.join(dataset_path, seq_name, '1.ppm'))
+        ref_img = cv2.imread(os.path.join(dataset_path, seq_name, "1.ppm"))
         ref_img_shape = ref_img.shape
 
         for im_idx in range(2, 7):
@@ -82,17 +106,19 @@ def benchmark_features(read_feats):
 
             matches = mnn_matcher(
                 torch.from_numpy(descriptors_a).to(device=device),
-                torch.from_numpy(descriptors_b).to(device=device)
+                torch.from_numpy(descriptors_b).to(device=device),
             )
 
-            homography = np.loadtxt(os.path.join(dataset_path, seq_name, "H_1_" + str(im_idx)))
+            homography = np.loadtxt(
+                os.path.join(dataset_path, seq_name, "H_1_" + str(im_idx))
+            )
 
-            pos_a = keypoints_a[matches[:, 0], : 2]
+            pos_a = keypoints_a[matches[:, 0], :2]
             pos_a_h = np.concatenate([pos_a, np.ones([matches.shape[0], 1])], axis=1)
             pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h)))
-            pos_b_proj = pos_b_proj_h[:, : 2] / pos_b_proj_h[:, 2:]
+            pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:]
 
-            pos_b = keypoints_b[matches[:, 1], : 2]
+            pos_b = keypoints_b[matches[:, 1], :2]
 
             dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1))
 
@@ -103,28 +129,37 @@ def benchmark_features(read_feats):
                 dist = np.array([float("inf")])
 
             for thr in rng:
-                if seq_name[0] == 'i':
+                if seq_name[0] == "i":
                     i_err[thr] += np.mean(dist <= thr)
                 else:
                     v_err[thr] += np.mean(dist <= thr)
 
             # =========== compute homography
             gt_homo = homography
-            pred_homo, _ = cv2.findHomography(keypoints_a[matches[:, 0], : 2], keypoints_b[matches[:, 1], : 2],
-                                              cv2.RANSAC)
+            pred_homo, _ = cv2.findHomography(
+                keypoints_a[matches[:, 0], :2],
+                keypoints_b[matches[:, 1], :2],
+                cv2.RANSAC,
+            )
             if pred_homo is None:
                 homo_dist = np.array([float("inf")])
             else:
-                corners = np.array([[0, 0],
-                                    [ref_img_shape[1] - 1, 0],
-                                    [0, ref_img_shape[0] - 1],
-                                    [ref_img_shape[1] - 1, ref_img_shape[0] - 1]])
+                corners = np.array(
+                    [
+                        [0, 0],
+                        [ref_img_shape[1] - 1, 0],
+                        [0, ref_img_shape[0] - 1],
+                        [ref_img_shape[1] - 1, ref_img_shape[0] - 1],
+                    ]
+                )
                 real_warped_corners = homo_trans(corners, gt_homo)
                 warped_corners = homo_trans(corners, pred_homo)
-                homo_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1))
+                homo_dist = np.mean(
+                    np.linalg.norm(real_warped_corners - warped_corners, axis=1)
+                )
 
             for thr in rng:
-                if seq_name[0] == 'i':
+                if seq_name[0] == "i":
                     i_err_homo[thr] += np.mean(homo_dist <= thr)
                 else:
                     v_err_homo[thr] += np.mean(homo_dist <= thr)
@@ -136,10 +171,10 @@ def benchmark_features(read_feats):
     return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches]
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     errors = {}
     for method in methods:
-        output_file = os.path.join(cache_dir, method + '.npy')
+        output_file = os.path.join(cache_dir, method + ".npy")
         read_function = generate_read_function(method)
         if os.path.exists(output_file):
             errors[method] = np.load(output_file, allow_pickle=True)
@@ -152,11 +187,11 @@ if __name__ == '__main__':
         i_err, v_err, i_err_hom, v_err_hom, _ = errors[method]
 
         print(f"====={name}=====")
-        print(f"MMA@1 MMA@2 MMA@3 MHA@1 MHA@2 MHA@3: ", end='')
+        print(f"MMA@1 MMA@2 MMA@3 MHA@1 MHA@2 MHA@3: ", end="")
         for thr in range(1, 4):
             err = (i_err[thr] + v_err[thr]) / ((n_i + n_v) * 5)
-            print(f"{err * 100:.2f}%", end=' ')
+            print(f"{err * 100:.2f}%", end=" ")
         for thr in range(1, 4):
             err_hom = (i_err_hom[thr] + v_err_hom[thr]) / ((n_i + n_v) * 5)
-            print(f"{err_hom * 100:.2f}%", end=' ')
-        print('')
+            print(f"{err_hom * 100:.2f}%", end=" ")
+        print("")
diff --git a/third_party/ALIKE/hseq/extract.py b/third_party/ALIKE/hseq/extract.py
index 1342e40dd2d0e1d1986e90f995c95b17972ec4e1..df16ae246bf360b529f0640cab5ae79f495e4f61 100644
--- a/third_party/ALIKE/hseq/extract.py
+++ b/third_party/ALIKE/hseq/extract.py
@@ -9,23 +9,23 @@ from tqdm import tqdm
 from copy import deepcopy
 from torchvision.transforms import ToTensor
 
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 from alike import ALike, configs
 
-dataset_root = 'hseq/hpatches-sequences-release'
+dataset_root = "hseq/hpatches-sequences-release"
 use_cuda = torch.cuda.is_available()
-device = 'cuda' if use_cuda else 'cpu'
-methods = ['alike-n', 'alike-l', 'alike-n-ms', 'alike-l-ms']
+device = "cuda" if use_cuda else "cpu"
+methods = ["alike-n", "alike-l", "alike-n-ms", "alike-l-ms"]
 
 
 class HPatchesDataset(data.Dataset):
-    def __init__(self, root: str = dataset_root, alteration: str = 'all'):
+    def __init__(self, root: str = dataset_root, alteration: str = "all"):
         """
         Args:
             root: dataset root path
             alteration: # 'all', 'i' for illumination or 'v' for viewpoint
         """
-        assert (Path(root).exists()), f"Dataset root path {root} dose not exist!"
+        assert Path(root).exists(), f"Dataset root path {root} does not exist!"
         self.root = root
 
         # get all image file name
@@ -35,15 +35,15 @@ class HPatchesDataset(data.Dataset):
         folders = [x for x in Path(self.root).iterdir() if x.is_dir()]
         self.seqs = []
         for folder in folders:
-            if alteration == 'i' and folder.stem[0] != 'i':
+            if alteration == "i" and folder.stem[0] != "i":
                 continue
-            if alteration == 'v' and folder.stem[0] != 'v':
+            if alteration == "v" and folder.stem[0] != "v":
                 continue
 
             self.seqs.append(folder)
 
         self.len = len(self.seqs)
-        assert (self.len > 0), f'Can not find PatchDataset in path {self.root}'
+        assert self.len > 0, f"Cannot find PatchDataset in path {self.root}"
 
     def __getitem__(self, item):
         folder = self.seqs[item]
@@ -51,12 +51,12 @@ class HPatchesDataset(data.Dataset):
         imgs = []
         homos = []
         for i in range(1, 7):
-            img = cv2.imread(str(folder / f'{i}.ppm'), cv2.IMREAD_COLOR)
+            img = cv2.imread(str(folder / f"{i}.ppm"), cv2.IMREAD_COLOR)
             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # HxWxC
             imgs.append(img)
 
             if i != 1:
-                homo = np.loadtxt(str(folder / f'H_1_{i}')).astype('float32')
+                homo = np.loadtxt(str(folder / f"H_1_{i}")).astype("float32")
                 homos.append(homo)
 
         return imgs, homos, folder.stem
@@ -68,11 +68,18 @@ class HPatchesDataset(data.Dataset):
         return self.__class__
 
 
-def extract_multiscale(model, img, scale_f=2 ** 0.5,
-                       min_scale=1., max_scale=1.,
-                       min_size=0., max_size=99999.,
-                       image_size_max=99999,
-                       n_k=0, sort=False):
+def extract_multiscale(
+    model,
+    img,
+    scale_f=2**0.5,
+    min_scale=1.0,
+    max_scale=1.0,
+    min_size=0.0,
+    max_size=99999.0,
+    image_size_max=99999,
+    n_k=0,
+    sort=False,
+):
     H_, W_, three = img.shape
     assert three == 3, "input image shape should be [HxWx3]"
 
@@ -100,7 +107,9 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5,
             # extract descriptors
             with torch.no_grad():
                 descriptor_map, scores_map = model.extract_dense_map(image)
-                keypoints_, descriptors_, scores_, _ = model.dkd(scores_map, descriptor_map)
+                keypoints_, descriptors_, scores_, _ = model.dkd(
+                    scores_map, descriptor_map
+                )
 
             keypoints.append(keypoints_[0])
             descriptors.append(descriptors_[0])
@@ -110,7 +119,9 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5,
 
         # down-scale the image for next iteration
         nh, nw = round(H * s), round(W * s)
-        image = torch.nn.functional.interpolate(image, (nh, nw), mode='bilinear', align_corners=False)
+        image = torch.nn.functional.interpolate(
+            image, (nh, nw), mode="bilinear", align_corners=False
+        )
 
     # restore value
     torch.backends.cudnn.benchmark = old_bm
@@ -131,29 +142,34 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5,
         descriptors = descriptors[0:n_k]
         scores = scores[0:n_k]
 
-    return {'keypoints': keypoints, 'descriptors': descriptors, 'scores': scores}
+    return {"keypoints": keypoints, "descriptors": descriptors, "scores": scores}
 
 
 def extract_method(m):
-    hpatches = HPatchesDataset(root=dataset_root, alteration='all')
+    hpatches = HPatchesDataset(root=dataset_root, alteration="all")
     model = m[:7]
-    min_scale = 0.3 if m[8:] == 'ms' else 1.0
+    min_scale = 0.3 if m[8:] == "ms" else 1.0
 
     model = ALike(**configs[model], device=device, top_k=0, scores_th=0.2, n_limit=5000)
 
-    progbar = tqdm(hpatches, desc='Extracting for {}'.format(m))
+    progbar = tqdm(hpatches, desc="Extracting for {}".format(m))
     for imgs, homos, seq_name in progbar:
         for i in range(1, 7):
             img = imgs[i - 1]
-            pred = extract_multiscale(model, img, min_scale=min_scale, max_scale=1, sort=False, n_k=5000)
-            kpts, descs, scores = pred['keypoints'], pred['descriptors'], pred['scores']
+            pred = extract_multiscale(
+                model, img, min_scale=min_scale, max_scale=1, sort=False, n_k=5000
+            )
+            kpts, descs, scores = pred["keypoints"], pred["descriptors"], pred["scores"]
 
-            with open(os.path.join(dataset_root, seq_name, f'{i}.ppm.{m}'), 'wb') as f:
-                np.savez(f, keypoints=kpts.cpu().numpy(),
-                         scores=scores.cpu().numpy(),
-                         descriptors=descs.cpu().numpy())
+            with open(os.path.join(dataset_root, seq_name, f"{i}.ppm.{m}"), "wb") as f:
+                np.savez(
+                    f,
+                    keypoints=kpts.cpu().numpy(),
+                    scores=scores.cpu().numpy(),
+                    descriptors=descs.cpu().numpy(),
+                )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     for method in methods:
         extract_method(method)
diff --git a/third_party/ALIKE/soft_detect.py b/third_party/ALIKE/soft_detect.py
index 2d23cd13b8a7db9b0398fdc1b235564222d30c90..636ba11d0584c513631fffce31ba2d71be3e6c74 100644
--- a/third_party/ALIKE/soft_detect.py
+++ b/third_party/ALIKE/soft_detect.py
@@ -17,13 +17,15 @@ import torch.nn.functional as F
 #  v
 # [ y: range=-1.0~1.0; h: range=0~H ]
 
+
 def simple_nms(scores, nms_radius: int):
-    """ Fast Non-maximum suppression to remove nearby points """
-    assert (nms_radius >= 0)
+    """Fast Non-maximum suppression to remove nearby points"""
+    assert nms_radius >= 0
 
     def max_pool(x):
         return torch.nn.functional.max_pool2d(
-            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
+            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
+        )
 
     zeros = torch.zeros_like(scores)
     max_mask = scores == max_pool(scores)
@@ -50,8 +52,14 @@ def sample_descriptor(descriptor_map, kpts, bilinear_interp=False):
         kptsi = kpts[index]  # Nx2,(x,y)
 
         if bilinear_interp:
-            descriptors_ = torch.nn.functional.grid_sample(descriptor_map[index].unsqueeze(0), kptsi.view(1, 1, -1, 2),
-                                                           mode='bilinear', align_corners=True)[0, :, 0, :]  # CxN
+            descriptors_ = torch.nn.functional.grid_sample(
+                descriptor_map[index].unsqueeze(0),
+                kptsi.view(1, 1, -1, 2),
+                mode="bilinear",
+                align_corners=True,
+            )[
+                0, :, 0, :
+            ]  # CxN
         else:
             kptsi = (kptsi + 1) / 2 * kptsi.new_tensor([[width - 1, height - 1]])
             kptsi = kptsi.long()
@@ -94,10 +102,10 @@ class DKD(nn.Module):
         nms_scores = simple_nms(scores_nograd, 2)
 
         # remove border
-        nms_scores[:, :, :self.radius + 1, :] = 0
-        nms_scores[:, :, :, :self.radius + 1] = 0
-        nms_scores[:, :, h - self.radius:, :] = 0
-        nms_scores[:, :, :, w - self.radius:] = 0
+        nms_scores[:, :, : self.radius + 1, :] = 0
+        nms_scores[:, :, :, : self.radius + 1] = 0
+        nms_scores[:, :, h - self.radius :, :] = 0
+        nms_scores[:, :, :, w - self.radius :] = 0
 
         # detect keypoints without grad
         if self.top_k > 0:
@@ -121,7 +129,7 @@ class DKD(nn.Module):
                 if len(indices) > self.n_limit:
                     kpts_sc = scores[indices]
                     sort_idx = kpts_sc.sort(descending=True)[1]
-                    sel_idx = sort_idx[:self.n_limit]
+                    sel_idx = sort_idx[: self.n_limit]
                     indices = indices[sel_idx]
                 indices_keypoints.append(indices)
 
@@ -134,42 +142,73 @@ class DKD(nn.Module):
             self.hw_grid = self.hw_grid.to(patches)  # to device
             for b_idx in range(b):
                 patch = patches[b_idx].t()  # (H*W) x (kernel**2)
-                indices_kpt = indices_keypoints[b_idx]  # one dimension vector, say its size is M
+                indices_kpt = indices_keypoints[
+                    b_idx
+                ]  # one dimension vector, say its size is M
                 patch_scores = patch[indices_kpt]  # M x (kernel**2)
 
                 # max is detached to prevent undesired backprop loops in the graph
                 max_v = patch_scores.max(dim=1).values.detach()[:, None]
-                x_exp = ((patch_scores - max_v) / self.temperature).exp()  # M * (kernel**2), in [0, 1]
+                x_exp = (
+                    (patch_scores - max_v) / self.temperature
+                ).exp()  # M * (kernel**2), in [0, 1]
 
                 # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
-                xy_residual = x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]  # Soft-argmax, Mx2
-
-                hw_grid_dist2 = torch.norm((self.hw_grid[None, :, :] - xy_residual[:, None, :]) / self.radius,
-                                           dim=-1) ** 2
+                xy_residual = (
+                    x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
+                )  # Soft-argmax, Mx2
+
+                hw_grid_dist2 = (
+                    torch.norm(
+                        (self.hw_grid[None, :, :] - xy_residual[:, None, :])
+                        / self.radius,
+                        dim=-1,
+                    )
+                    ** 2
+                )
                 scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
 
                 # compute result keypoints
-                keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1)  # Mx2
+                keypoints_xy_nms = torch.stack(
+                    [indices_kpt % w, indices_kpt // w], dim=1
+                )  # Mx2
                 keypoints_xy = keypoints_xy_nms + xy_residual
-                keypoints_xy = keypoints_xy / keypoints_xy.new_tensor(
-                    [w - 1, h - 1]) * 2 - 1  # (w,h) -> (-1~1,-1~1)
-
-                kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0),
-                                                           keypoints_xy.view(1, 1, -1, 2),
-                                                           mode='bilinear', align_corners=True)[0, 0, 0, :]  # CxN
+                keypoints_xy = (
+                    keypoints_xy / keypoints_xy.new_tensor([w - 1, h - 1]) * 2 - 1
+                )  # (w,h) -> (-1~1,-1~1)
+
+                kptscore = torch.nn.functional.grid_sample(
+                    scores_map[b_idx].unsqueeze(0),
+                    keypoints_xy.view(1, 1, -1, 2),
+                    mode="bilinear",
+                    align_corners=True,
+                )[
+                    0, 0, 0, :
+                ]  # CxN
 
                 keypoints.append(keypoints_xy)
                 scoredispersitys.append(scoredispersity)
                 kptscores.append(kptscore)
         else:
             for b_idx in range(b):
-                indices_kpt = indices_keypoints[b_idx]  # one dimension vector, say its size is M
-                keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1)  # Mx2
-                keypoints_xy = keypoints_xy_nms / keypoints_xy_nms.new_tensor(
-                    [w - 1, h - 1]) * 2 - 1  # (w,h) -> (-1~1,-1~1)
-                kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0),
-                                                           keypoints_xy.view(1, 1, -1, 2),
-                                                           mode='bilinear', align_corners=True)[0, 0, 0, :]  # CxN
+                indices_kpt = indices_keypoints[
+                    b_idx
+                ]  # one dimension vector, say its size is M
+                keypoints_xy_nms = torch.stack(
+                    [indices_kpt % w, indices_kpt // w], dim=1
+                )  # Mx2
+                keypoints_xy = (
+                    keypoints_xy_nms / keypoints_xy_nms.new_tensor([w - 1, h - 1]) * 2
+                    - 1
+                )  # (w,h) -> (-1~1,-1~1)
+                kptscore = torch.nn.functional.grid_sample(
+                    scores_map[b_idx].unsqueeze(0),
+                    keypoints_xy.view(1, 1, -1, 2),
+                    mode="bilinear",
+                    align_corners=True,
+                )[
+                    0, 0, 0, :
+                ]  # CxN
                 keypoints.append(keypoints_xy)
                 scoredispersitys.append(None)
                 kptscores.append(kptscore)
@@ -183,8 +222,9 @@ class DKD(nn.Module):
         :param sub_pixel: whether to use sub-pixel keypoint detection
         :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1.0 ~ 1.0
         """
-        keypoints, scoredispersitys, kptscores = self.detect_keypoints(scores_map,
-                                                                       sub_pixel)
+        keypoints, scoredispersitys, kptscores = self.detect_keypoints(
+            scores_map, sub_pixel
+        )
 
         descriptors = sample_descriptor(descriptor_map, keypoints, sub_pixel)
 
diff --git a/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py b/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py
index fc2b44807696ec280672c8f40650fd04fa4d8a36..00ea16cd35dc4362d0d9a294ad8a1762427bc382 100644
--- a/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py
+++ b/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py
@@ -1,10 +1,11 @@
 import sys
 from pathlib import Path
-sys.path.append(str(Path(__file__).parent / '../../../'))
+
+sys.path.append(str(Path(__file__).parent / "../../../"))
 from src.config.default import _CN as cfg
 
-cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
 
 cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
-cfg.ASPAN.COARSE.COARSEST_LEVEL= [15,20]
-cfg.ASPAN.COARSE.TRAIN_RES = [480,640]
+cfg.ASPAN.COARSE.COARSEST_LEVEL = [15, 20]
+cfg.ASPAN.COARSE.TRAIN_RES = [480, 640]
diff --git a/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py b/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py
index 886d10d8f55533c8021bcca8395b5a2897fb8734..854132e8c8af3b3c9c85fa797a79a149aff545ef 100644
--- a/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py
+++ b/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py
@@ -1,10 +1,11 @@
 import sys
 from pathlib import Path
-sys.path.append(str(Path(__file__).parent / '../../../'))
+
+sys.path.append(str(Path(__file__).parent / "../../../"))
 from src.config.default import _CN as cfg
 
-cfg.ASPAN.COARSE.COARSEST_LEVEL= [15,20]
-cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+cfg.ASPAN.COARSE.COARSEST_LEVEL = [15, 20]
+cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
 
 cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
 cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
diff --git a/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py b/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py
index f0b9c04cbf3f466e413b345272afe7d7fe4274ea..e2ff53d7a1943f4149c43cdb6f2547c2290651aa 100644
--- a/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py
+++ b/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py
@@ -1,12 +1,13 @@
 import sys
 from pathlib import Path
-sys.path.append(str(Path(__file__).parent / '../../../'))
+
+sys.path.append(str(Path(__file__).parent / "../../../"))
 from src.config.default import _CN as cfg
 
-cfg.ASPAN.COARSE.COARSEST_LEVEL= [36,36]
-cfg.ASPAN.COARSE.TRAIN_RES = [832,832]
-cfg.ASPAN.COARSE.TEST_RES = [1152,1152]
-cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+cfg.ASPAN.COARSE.COARSEST_LEVEL = [36, 36]
+cfg.ASPAN.COARSE.TRAIN_RES = [832, 832]
+cfg.ASPAN.COARSE.TEST_RES = [1152, 1152]
+cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
 
 cfg.TRAINER.CANONICAL_LR = 8e-3
 cfg.TRAINER.WARMUP_STEP = 1875  # 3 epochs
diff --git a/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py b/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py
index 1202080b234562d8cc65d924d7cccf0336b9f7c0..b226243478579ba2f1d4f45d8c90c02fb347d7ff 100644
--- a/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py
+++ b/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py
@@ -1,10 +1,11 @@
 import sys
 from pathlib import Path
-sys.path.append(str(Path(__file__).parent / '../../../'))
+
+sys.path.append(str(Path(__file__).parent / "../../../"))
 from src.config.default import _CN as cfg
 
-cfg.ASPAN.COARSE.COARSEST_LEVEL= [26,26]
-cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+cfg.ASPAN.COARSE.COARSEST_LEVEL = [26, 26]
+cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
 cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
 
 cfg.TRAINER.CANONICAL_LR = 8e-3
diff --git a/third_party/ASpanFormer/configs/data/base.py b/third_party/ASpanFormer/configs/data/base.py
index 03aab160fa4137ccc04380f94854a56fbb549074..2621621cd3caf2edb11b41a96b11aa6a63afba92 100644
--- a/third_party/ASpanFormer/configs/data/base.py
+++ b/third_party/ASpanFormer/configs/data/base.py
@@ -4,6 +4,7 @@ Setups in data configs will override all existed setups!
 """
 
 from yacs.config import CfgNode as CN
+
 _CN = CN()
 _CN.DATASET = CN()
 _CN.TRAINER = CN()
diff --git a/third_party/ASpanFormer/configs/data/megadepth_test_1500.py b/third_party/ASpanFormer/configs/data/megadepth_test_1500.py
index 9616432f52a693ed84f3f12b9b85470b23410eee..a8d07aafd1944188cec525043c775d268b01be1f 100644
--- a/third_party/ASpanFormer/configs/data/megadepth_test_1500.py
+++ b/third_party/ASpanFormer/configs/data/megadepth_test_1500.py
@@ -8,6 +8,6 @@ cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
 cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt"
 
 cfg.DATASET.MGDPT_IMG_RESIZE = 1152
-cfg.DATASET.MGDPT_IMG_PAD=True
-cfg.DATASET.MGDPT_DF =8
-cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
\ No newline at end of file
+cfg.DATASET.MGDPT_IMG_PAD = True
+cfg.DATASET.MGDPT_DF = 8
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
diff --git a/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py b/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py
index 8f9b01fdaed254e10b3d55980499b88a00060f04..48b9bd095d64c681d0e64ee9416fb63fbd1f27b5 100644
--- a/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py
+++ b/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py
@@ -11,9 +11,13 @@ cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
 TEST_BASE_PATH = "data/megadepth/index"
 cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
 cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
-cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500"
-cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
-cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0   # for both test and val
+cfg.DATASET.VAL_NPZ_ROOT = (
+    cfg.DATASET.TEST_NPZ_ROOT
+) = f"{TEST_BASE_PATH}/scene_info_val_1500"
+cfg.DATASET.VAL_LIST_PATH = (
+    cfg.DATASET.TEST_LIST_PATH
+) = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0  # for both test and val
 
 # 368 scenes in total for MegaDepth
 # (with difficulty balanced (further split each scene to 3 sub-scenes))
diff --git a/third_party/ASpanFormer/configs/data/scannet_trainval.py b/third_party/ASpanFormer/configs/data/scannet_trainval.py
index c38d6440e2b4ec349e5f168909c7f8c367408813..a9a5b8a332e012a2891bbf7ec8842523b67e7599 100644
--- a/third_party/ASpanFormer/configs/data/scannet_trainval.py
+++ b/third_party/ASpanFormer/configs/data/scannet_trainval.py
@@ -12,6 +12,10 @@ TEST_BASE_PATH = "assets/scannet_test_1500"
 cfg.DATASET.TEST_DATA_SOURCE = "ScanNet"
 cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test"
 cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = TEST_BASE_PATH
-cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt"
-cfg.DATASET.VAL_INTRINSIC_PATH = cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz"
-cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0   # for both test and val
+cfg.DATASET.VAL_LIST_PATH = (
+    cfg.DATASET.TEST_LIST_PATH
+) = f"{TEST_BASE_PATH}/scannet_test.txt"
+cfg.DATASET.VAL_INTRINSIC_PATH = (
+    cfg.DATASET.TEST_INTRINSIC_PATH
+) = f"{TEST_BASE_PATH}/intrinsics.npz"
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0  # for both test and val
diff --git a/third_party/ASpanFormer/demo/demo.py b/third_party/ASpanFormer/demo/demo.py
index f3d95b10dc3166c18ad8493be7a3d36a25d8fc3b..dceb13523faec756063b40fd586bcd81f483e274 100644
--- a/third_party/ASpanFormer/demo/demo.py
+++ b/third_party/ASpanFormer/demo/demo.py
@@ -1,63 +1,91 @@
 import os
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.insert(0, ROOT_DIR)
 
-from src.ASpanFormer.aspanformer import ASpanFormer 
+from src.ASpanFormer.aspanformer import ASpanFormer
 from src.config.default import get_cfg_defaults
 from src.utils.misc import lower_config
-import demo_utils 
+import demo_utils
 
 import cv2
 import torch
 import numpy as np
 
 import argparse
+
 parser = argparse.ArgumentParser()
-parser.add_argument('--config_path', type=str, default='../configs/aspan/outdoor/aspan_test.py',
-  help='path for config file.')
-parser.add_argument('--img0_path', type=str, default='../assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg',
-  help='path for image0.')
-parser.add_argument('--img1_path', type=str, default='../assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg',
-  help='path for image1.')
-parser.add_argument('--weights_path', type=str, default='../weights/outdoor.ckpt',
-  help='path for model weights.')
-parser.add_argument('--long_dim0', type=int, default=1024,
-  help='resize for longest dim of image0.')
-parser.add_argument('--long_dim1', type=int, default=1024,
-  help='resize for longest dim of image1.')
+parser.add_argument(
+    "--config_path",
+    type=str,
+    default="../configs/aspan/outdoor/aspan_test.py",
+    help="path for config file.",
+)
+parser.add_argument(
+    "--img0_path",
+    type=str,
+    default="../assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg",
+    help="path for image0.",
+)
+parser.add_argument(
+    "--img1_path",
+    type=str,
+    default="../assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg",
+    help="path for image1.",
+)
+parser.add_argument(
+    "--weights_path",
+    type=str,
+    default="../weights/outdoor.ckpt",
+    help="path for model weights.",
+)
+parser.add_argument(
+    "--long_dim0", type=int, default=1024, help="resize for longest dim of image0."
+)
+parser.add_argument(
+    "--long_dim1", type=int, default=1024, help="resize for longest dim of image1."
+)
 
 args = parser.parse_args()
 
 
-if __name__=='__main__':
+if __name__ == "__main__":
     config = get_cfg_defaults()
     config.merge_from_file(args.config_path)
     _config = lower_config(config)
-    matcher = ASpanFormer(config=_config['aspan'])
-    state_dict = torch.load(args.weights_path, map_location='cpu')['state_dict']
-    matcher.load_state_dict(state_dict,strict=False)
-    matcher.cuda(),matcher.eval()
-
-    img0,img1=cv2.imread(args.img0_path),cv2.imread(args.img1_path)
-    img0_g,img1_g=cv2.imread(args.img0_path,0),cv2.imread(args.img1_path,0)
-    img0,img1=demo_utils.resize(img0,args.long_dim0),demo_utils.resize(img1,args.long_dim1)
-    img0_g,img1_g=demo_utils.resize(img0_g,args.long_dim0),demo_utils.resize(img1_g,args.long_dim1)
-    data={'image0':torch.from_numpy(img0_g/255.)[None,None].cuda().float(),
-          'image1':torch.from_numpy(img1_g/255.)[None,None].cuda().float()} 
-    with torch.no_grad():   
-      matcher(data,online_resize=True)
-      corr0,corr1=data['mkpts0_f'].cpu().numpy(),data['mkpts1_f'].cpu().numpy()
-
-    F_hat,mask_F=cv2.findFundamentalMat(corr0,corr1,method=cv2.FM_RANSAC,ransacReprojThreshold=1)
+    matcher = ASpanFormer(config=_config["aspan"])
+    state_dict = torch.load(args.weights_path, map_location="cpu")["state_dict"]
+    matcher.load_state_dict(state_dict, strict=False)
+    matcher.cuda(), matcher.eval()
+
+    img0, img1 = cv2.imread(args.img0_path), cv2.imread(args.img1_path)
+    img0_g, img1_g = cv2.imread(args.img0_path, 0), cv2.imread(args.img1_path, 0)
+    img0, img1 = demo_utils.resize(img0, args.long_dim0), demo_utils.resize(
+        img1, args.long_dim1
+    )
+    img0_g, img1_g = demo_utils.resize(img0_g, args.long_dim0), demo_utils.resize(
+        img1_g, args.long_dim1
+    )
+    data = {
+        "image0": torch.from_numpy(img0_g / 255.0)[None, None].cuda().float(),
+        "image1": torch.from_numpy(img1_g / 255.0)[None, None].cuda().float(),
+    }
+    with torch.no_grad():
+        matcher(data, online_resize=True)
+        corr0, corr1 = data["mkpts0_f"].cpu().numpy(), data["mkpts1_f"].cpu().numpy()
+
+    F_hat, mask_F = cv2.findFundamentalMat(
+        corr0, corr1, method=cv2.FM_RANSAC, ransacReprojThreshold=1
+    )
     if mask_F is not None:
-      mask_F=mask_F[:,0].astype(bool) 
+        mask_F = mask_F[:, 0].astype(bool)
     else:
-      mask_F=np.zeros_like(corr0[:,0]).astype(bool)
-
-    #visualize match
-    display=demo_utils.draw_match(img0,img1,corr0,corr1)
-    display_ransac=demo_utils.draw_match(img0,img1,corr0[mask_F],corr1[mask_F])
-    cv2.imwrite('match.png',display)
-    cv2.imwrite('match_ransac.png',display_ransac)
-    print(len(corr1),len(corr1[mask_F]))
\ No newline at end of file
+        mask_F = np.zeros_like(corr0[:, 0]).astype(bool)
+
+    # visualize match
+    display = demo_utils.draw_match(img0, img1, corr0, corr1)
+    display_ransac = demo_utils.draw_match(img0, img1, corr0[mask_F], corr1[mask_F])
+    cv2.imwrite("match.png", display)
+    cv2.imwrite("match_ransac.png", display_ransac)
+    print(f"total matches: {len(corr1)}, RANSAC inliers: {len(corr1[mask_F])}")
diff --git a/third_party/ASpanFormer/demo/demo_utils.py b/third_party/ASpanFormer/demo/demo_utils.py
index a104e25d3f5ee8b7efb6cc5fa0dc27378e22c83f..fcc8f71e02406fef4ac97fef2d0fec7c9196ad57 100644
--- a/third_party/ASpanFormer/demo/demo_utils.py
+++ b/third_party/ASpanFormer/demo/demo_utils.py
@@ -1,44 +1,88 @@
 import cv2
 import numpy as np
 
-def resize(image,long_dim):
-    h,w=image.shape[0],image.shape[1]
-    image=cv2.resize(image,(int(w*long_dim/max(h,w)),int(h*long_dim/max(h,w))))
+
+def resize(image, long_dim):
+    h, w = image.shape[0], image.shape[1]
+    image = cv2.resize(
+        image, (int(w * long_dim / max(h, w)), int(h * long_dim / max(h, w)))
+    )
     return image
 
-def draw_points(img,points,color=(0,255,0),radius=3):
+
+def draw_points(img, points, color=(0, 255, 0), radius=3):
     dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
     for i in range(points.shape[0]):
-        cv2.circle(img, dp[i],radius=radius,color=color)
+        cv2.circle(img, dp[i], radius=radius, color=color)
     return img
-    
 
-def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None):
+
+def draw_match(
+    img1,
+    img2,
+    corr1,
+    corr2,
+    inlier=[True],
+    color=None,
+    radius1=1,
+    radius2=1,
+    resize=None,
+):
     if resize is not None:
-        scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]]
-        img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA) 
-        corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis]
-    corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])]
-    corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])]
+        scale1, scale2 = [img1.shape[1] / resize[0], img1.shape[0] / resize[1]], [
+            img2.shape[1] / resize[0],
+            img2.shape[0] / resize[1],
+        ]
+        img1, img2 = cv2.resize(img1, resize, interpolation=cv2.INTER_AREA), cv2.resize(
+            img2, resize, interpolation=cv2.INTER_AREA
+        )
+        corr1, corr2 = (
+            corr1 / np.asarray(scale1)[np.newaxis],
+            corr2 / np.asarray(scale2)[np.newaxis],
+        )
+    corr1_key = [
+        cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])
+    ]
+    corr2_key = [
+        cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])
+    ]
 
     assert len(corr1) == len(corr2)
 
     draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
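+    # a single color uses cv2.drawMatches directly; per-match colors fall back to a manual side-by-side canvas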
     if color is None:
-        color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier]
-    if len(color)==1:
-        display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None,
-                              matchColor=color[0],
-                              singlePointColor=color[0],
-                              flags=4
-                              )
+        color = [(0, 255, 0) if cur_inlier else (0, 0, 255) for cur_inlier in inlier]
+    if len(color) == 1:
+        display = cv2.drawMatches(
+            img1,
+            corr1_key,
+            img2,
+            corr2_key,
+            draw_matches,
+            None,
+            matchColor=color[0],
+            singlePointColor=color[0],
+            flags=4,
+        )
     else:
-        height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1]
-        display=np.zeros([height,width,3],np.uint8)
-        display[:img1.shape[0],:img1.shape[1]]=img1
-        display[:img2.shape[0],img1.shape[1]:]=img2
+        height, width = max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1]
+        display = np.zeros([height, width, 3], np.uint8)
+        display[: img1.shape[0], : img1.shape[1]] = img1
+        display[: img2.shape[0], img1.shape[1] :] = img2
         for i in range(len(corr1)):
-            left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1])
-            cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2]))
-            cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA)
-    return display
\ No newline at end of file
+            left_x, left_y, right_x, right_y = (
+                int(corr1[i][0]),
+                int(corr1[i][1]),
+                int(corr2[i][0] + img1.shape[1]),
+                int(corr2[i][1]),
+            )
+            cur_color = (int(color[i][0]), int(color[i][1]), int(color[i][2]))
+            cv2.line(
+                display,
+                (left_x, left_y),
+                (right_x, right_y),
+                cur_color,
+                1,
+                lineType=cv2.LINE_AA,
+            )
+    return display
diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py
index dff6704976cbe9e916c6de6af9e3b755dfbd20bf..0603d4088cd41dc4669ff60368fd1547000c161f 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py
@@ -1,3 +1,3 @@
 from .transformer import LocalFeatureTransformer_Flow
-from .loftr import LocalFeatureTransformer 
+from .loftr import LocalFeatureTransformer
 from .fine_preprocess import FinePreprocess
diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py
index 632dd22077806d2b53f66a09d0567925a30d1523..984b0df8b6bc8783b6ade4e9dbdf39b8a5673850 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py
@@ -4,39 +4,59 @@ import torch.nn as nn
 from itertools import product
 from torch.nn import functional as F
 
+
 class layernorm2d(nn.Module):
-     
-     def __init__(self,dim) :
-         super().__init__()
-         self.dim=dim
-         self.affine=nn.parameter.Parameter(torch.ones(dim), requires_grad=True)
-         self.bias=nn.parameter.Parameter(torch.zeros(dim), requires_grad=True) 
-    
-     def forward(self,x):
-        #x: B*C*H*W
-        mean,std=x.mean(dim=1,keepdim=True),x.std(dim=1,keepdim=True)
-        return self.affine[None,:,None,None]*(x-mean)/(std+1e-6)+self.bias[None,:,None,None]
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+        self.affine = nn.parameter.Parameter(torch.ones(dim), requires_grad=True)
+        self.bias = nn.parameter.Parameter(torch.zeros(dim), requires_grad=True)
+
+    def forward(self, x):
+        # x: B*C*H*W
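+        # normalize across the channel dimension at each spatial location, then apply the learnable scale and bias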
+        mean, std = x.mean(dim=1, keepdim=True), x.std(dim=1, keepdim=True)
+        return (
+            self.affine[None, :, None, None] * (x - mean) / (std + 1e-6)
+            + self.bias[None, :, None, None]
+        )
 
 
 class HierachicalAttention(Module):
-    def __init__(self,d_model,nhead,nsample,radius_scale,nlevel=3):
+    def __init__(self, d_model, nhead, nsample, radius_scale, nlevel=3):
         super().__init__()
-        self.d_model=d_model
-        self.nhead=nhead
-        self.nsample=nsample
-        self.nlevel=nlevel
-        self.radius_scale=radius_scale
+        self.d_model = d_model
+        self.nhead = nhead
+        self.nsample = nsample
+        self.nlevel = nlevel
+        self.radius_scale = radius_scale
         self.merge_head = nn.Sequential(
-            nn.Conv1d(d_model*3, d_model, kernel_size=1,bias=False),
+            nn.Conv1d(d_model * 3, d_model, kernel_size=1, bias=False),
             nn.ReLU(True),
-            nn.Conv1d(d_model, d_model, kernel_size=1,bias=False),
+            nn.Conv1d(d_model, d_model, kernel_size=1, bias=False),
         )
-        self.fullattention=FullAttention(d_model,nhead)
-        self.temp=nn.parameter.Parameter(torch.tensor(1.),requires_grad=True) 
-        sample_offset=torch.tensor([[pos[0]-nsample[1]/2+0.5, pos[1]-nsample[1]/2+0.5] for pos in product(range(nsample[1]), range(nsample[1]))]) #r^2*2
-        self.sample_offset=nn.parameter.Parameter(sample_offset,requires_grad=False)
+        self.fullattention = FullAttention(d_model, nhead)
+        self.temp = nn.parameter.Parameter(torch.tensor(1.0), requires_grad=True)
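+        # sample_offset is a fixed nsample[1] x nsample[1] grid of offsets centered at zero;
+        # it is later scaled by the predicted span to place sampling locations around each flow estimate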
+        sample_offset = torch.tensor(
+            [
+                [pos[0] - nsample[1] / 2 + 0.5, pos[1] - nsample[1] / 2 + 0.5]
+                for pos in product(range(nsample[1]), range(nsample[1]))
+            ]
+        )  # r^2*2
+        self.sample_offset = nn.parameter.Parameter(sample_offset, requires_grad=False)
 
-    def forward(self,query,key,value,flow,size_q,size_kv,mask0=None, mask1=None,ds0=[4,4],ds1=[4,4]):
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        flow,
+        size_q,
+        size_kv,
+        mask0=None,
+        mask1=None,
+        ds0=[4, 4],
+        ds1=[4, 4],
+    ):
         """
         Args:
             q,k,v (torch.Tensor): [B, C, L]
@@ -45,123 +65,217 @@ class HierachicalAttention(Module):
         Return:
             all_message (torch.Tensor): [B, C, H, W]
         """
-        
-        variance=flow[:,:,:,2:]
-        offset=flow[:,:,:,:2]  #B*H*W*2
-        bs=query.shape[0]
-        h0,w0=size_q[0],size_q[1]
-        h1,w1=size_kv[0],size_kv[1]
-        variance=torch.exp(0.5*variance)*self.radius_scale #b*h*w*2(pixel scale)
-        span_scale=torch.clamp((variance*2/self.nsample[1]),min=1) #b*h*w*2
-
-        sub_sample0,sub_sample1=[ds0,2,1],[ds1,2,1]
-        q_list=[F.avg_pool2d(query.view(bs,-1,h0,w0),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample0]
-        k_list=[F.avg_pool2d(key.view(bs,-1,h1,w1),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1]
-        v_list=[F.avg_pool2d(value.view(bs,-1,h1,w1),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1] #n_level
-        
-        offset_list=[F.avg_pool2d(offset.permute(0,3,1,2),kernel_size=sub_size*self.nsample[0],stride=sub_size*self.nsample[0]).permute(0,2,3,1)/sub_size for sub_size in sub_sample0[1:]] #n_level-1
-        span_list=[F.avg_pool2d(span_scale.permute(0,3,1,2),kernel_size=sub_size*self.nsample[0],stride=sub_size*self.nsample[0]).permute(0,2,3,1) for sub_size in sub_sample0[1:]] #n_level-1
+
+        variance = flow[:, :, :, 2:]
+        offset = flow[:, :, :, :2]  # B*H*W*2
+        bs = query.shape[0]
+        h0, w0 = size_q[0], size_q[1]
+        h1, w1 = size_kv[0], size_kv[1]
+        variance = torch.exp(0.5 * variance) * self.radius_scale  # b*h*w*2(pixel scale)
+        span_scale = torch.clamp((variance * 2 / self.nsample[1]), min=1)  # b*h*w*2
+
+        sub_sample0, sub_sample1 = [ds0, 2, 1], [ds1, 2, 1]
+        q_list = [
+            F.avg_pool2d(
+                query.view(bs, -1, h0, w0), kernel_size=sub_size, stride=sub_size
+            )
+            for sub_size in sub_sample0
+        ]
+        k_list = [
+            F.avg_pool2d(
+                key.view(bs, -1, h1, w1), kernel_size=sub_size, stride=sub_size
+            )
+            for sub_size in sub_sample1
+        ]
+        v_list = [
+            F.avg_pool2d(
+                value.view(bs, -1, h1, w1), kernel_size=sub_size, stride=sub_size
+            )
+            for sub_size in sub_sample1
+        ]  # n_level
+
+        offset_list = [
+            F.avg_pool2d(
+                offset.permute(0, 3, 1, 2),
+                kernel_size=sub_size * self.nsample[0],
+                stride=sub_size * self.nsample[0],
+            ).permute(0, 2, 3, 1)
+            / sub_size
+            for sub_size in sub_sample0[1:]
+        ]  # n_level-1
+        span_list = [
+            F.avg_pool2d(
+                span_scale.permute(0, 3, 1, 2),
+                kernel_size=sub_size * self.nsample[0],
+                stride=sub_size * self.nsample[0],
+            ).permute(0, 2, 3, 1)
+            for sub_size in sub_sample0[1:]
+        ]  # n_level-1
 
         if mask0 is not None:
-            mask0,mask1=mask0.view(bs,1,h0,w0),mask1.view(bs,1,h1,w1)
-            mask0_list=[-F.max_pool2d(-mask0,kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample0]
-            mask1_list=[-F.max_pool2d(-mask1,kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1]
+            mask0, mask1 = mask0.view(bs, 1, h0, w0), mask1.view(bs, 1, h1, w1)
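+            # min-pool (negated max-pool) so a downsampled cell stays valid only if every covered pixel is valid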
+            mask0_list = [
+                -F.max_pool2d(-mask0, kernel_size=sub_size, stride=sub_size)
+                for sub_size in sub_sample0
+            ]
+            mask1_list = [
+                -F.max_pool2d(-mask1, kernel_size=sub_size, stride=sub_size)
+                for sub_size in sub_sample1
+            ]
         else:
-            mask0_list=mask1_list=[None,None,None]
-
-        message_list=[]
-        #full attention at coarse scale
-        mask0_flatten=mask0_list[0].view(bs,-1) if mask0 is not None else None
-        mask1_flatten=mask1_list[0].view(bs,-1) if mask1 is not None else None
-        message_list.append(self.fullattention(q_list[0],k_list[0],v_list[0],mask0_flatten,mask1_flatten,self.temp).view(bs,self.d_model,h0//ds0[0],w0//ds0[1]))
-
-        for index in range(1,self.nlevel):
-            q,k,v=q_list[index],k_list[index],v_list[index]
-            mask0,mask1=mask0_list[index],mask1_list[index]
-            s,o=span_list[index-1],offset_list[index-1] #B*h*w(*2)
-            q,k,v,sample_pixel,mask_sample=self.partition_token(q,k,v,o,s,mask0) #B*Head*D*G*N(G*N=H*W for q)
-            message_list.append(self.group_attention(q,k,v,1,mask_sample).view(bs,self.d_model,h0//sub_sample0[index],w0//sub_sample0[index]))
-        #fuse
-        all_message=torch.cat([F.upsample(message_list[idx],scale_factor=sub_sample0[idx],mode='nearest') \
-                    for idx in range(self.nlevel)],dim=1).view(bs,-1,h0*w0) #b*3d*H*W
-        
-        all_message=self.merge_head(all_message).view(bs,-1,h0,w0) #b*d*H*W
+            mask0_list = mask1_list = [None, None, None]
+
+        message_list = []
+        # full attention at coarse scale
+        mask0_flatten = mask0_list[0].view(bs, -1) if mask0 is not None else None
+        mask1_flatten = mask1_list[0].view(bs, -1) if mask1 is not None else None
+        message_list.append(
+            self.fullattention(
+                q_list[0], k_list[0], v_list[0], mask0_flatten, mask1_flatten, self.temp
+            ).view(bs, self.d_model, h0 // ds0[0], w0 // ds0[1])
+        )
+
+        for index in range(1, self.nlevel):
+            q, k, v = q_list[index], k_list[index], v_list[index]
+            mask0, mask1 = mask0_list[index], mask1_list[index]
+            s, o = span_list[index - 1], offset_list[index - 1]  # B*h*w(*2)
+            q, k, v, sample_pixel, mask_sample = self.partition_token(
+                q, k, v, o, s, mask0
+            )  # B*Head*D*G*N(G*N=H*W for q)
+            message_list.append(
+                self.group_attention(q, k, v, 1, mask_sample).view(
+                    bs, self.d_model, h0 // sub_sample0[index], w0 // sub_sample0[index]
+                )
+            )
+        # fuse
+        all_message = torch.cat(
+            [
+                F.upsample(
+                    message_list[idx], scale_factor=sub_sample0[idx], mode="nearest"
+                )
+                for idx in range(self.nlevel)
+            ],
+            dim=1,
+        ).view(
+            bs, -1, h0 * w0
+        )  # b*3d*H*W
+
+        all_message = self.merge_head(all_message).view(bs, -1, h0, w0)  # b*d*H*W
         return all_message
-      
-    def partition_token(self,q,k,v,offset,span_scale,maskv):
-        #q,k,v: B*C*H*W
-        #o: B*H/2*W/2*2
-        #span_scale:B*H*W
-        bs=q.shape[0]
-        h,w=q.shape[2],q.shape[3]
-        hk,wk=k.shape[2],k.shape[3]
-        offset=offset.view(bs,-1,2)
-        span_scale=span_scale.view(bs,-1,1,2)
-        #B*G*2
-        offset_sample=self.sample_offset[None,None]*span_scale
-        sample_pixel=offset[:,:,None]+offset_sample#B*G*r^2*2
-        sample_norm=sample_pixel/torch.tensor([wk/2,hk/2]).cuda()[None,None,None]-1
-        
-        q = q.view(bs, -1 , h // self.nsample[0], self.nsample[0], w // self.nsample[0], self.nsample[0]).\
-                permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, self.nhead,self.d_model//self.nhead, -1,self.nsample[0]**2)#B*head*D*G*N(G*N=H*W for q)
-        #sample token
-        k=F.grid_sample(k, grid=sample_norm).view(bs, self.nhead,self.d_model//self.nhead,-1, self.nsample[1]**2) #B*head*D*G*r^2
-        v=F.grid_sample(v, grid=sample_norm).view(bs, self.nhead,self.d_model//self.nhead,-1, self.nsample[1]**2) #B*head*D*G*r^2
-        #import pdb;pdb.set_trace()
+
+    def partition_token(self, q, k, v, offset, span_scale, maskv):
+        # q,k,v: B*C*H*W
+        # o: B*H/2*W/2*2
+        # span_scale:B*H*W
+        bs = q.shape[0]
+        h, w = q.shape[2], q.shape[3]
+        hk, wk = k.shape[2], k.shape[3]
+        offset = offset.view(bs, -1, 2)
+        span_scale = span_scale.view(bs, -1, 1, 2)
+        # B*G*2
+        offset_sample = self.sample_offset[None, None] * span_scale
+        sample_pixel = offset[:, :, None] + offset_sample  # B*G*r^2*2
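+        # map pixel coordinates into the [-1, 1] range expected by F.grid_sample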
+        sample_norm = (
+            sample_pixel / torch.tensor([wk / 2, hk / 2]).cuda()[None, None, None] - 1
+        )
+
+        q = (
+            q.view(
+                bs,
+                -1,
+                h // self.nsample[0],
+                self.nsample[0],
+                w // self.nsample[0],
+                self.nsample[0],
+            )
+            .permute(0, 1, 2, 4, 3, 5)
+            .contiguous()
+            .view(bs, self.nhead, self.d_model // self.nhead, -1, self.nsample[0] ** 2)
+        )  # B*head*D*G*N(G*N=H*W for q)
+        # sample token
+        k = F.grid_sample(k, grid=sample_norm).view(
+            bs, self.nhead, self.d_model // self.nhead, -1, self.nsample[1] ** 2
+        )  # B*head*D*G*r^2
+        v = F.grid_sample(v, grid=sample_norm).view(
+            bs, self.nhead, self.d_model // self.nhead, -1, self.nsample[1] ** 2
+        )  # B*head*D*G*r^2
+        # import pdb;pdb.set_trace()
         if maskv is not None:
-            mask_sample=F.grid_sample(maskv.view(bs,-1,h,w).float(),grid=sample_norm,mode='nearest')==1 #B*1*G*r^2
+            mask_sample = (
+                F.grid_sample(
+                    maskv.view(bs, -1, h, w).float(), grid=sample_norm, mode="nearest"
+                )
+                == 1
+            )  # B*1*G*r^2
         else:
-            mask_sample=None
-        return q,k,v,sample_pixel,mask_sample
-
+            mask_sample = None
+        return q, k, v, sample_pixel, mask_sample
 
-    def group_attention(self,query,key,value,temp,mask_sample=None):
-        #q,k,v: B*Head*D*G*N(G*N=H*W for q)
-        bs=query.shape[0]
-        #import pdb;pdb.set_trace()
+    def group_attention(self, query, key, value, temp, mask_sample=None):
+        # q,k,v: B*Head*D*G*N(G*N=H*W for q)
+        bs = query.shape[0]
+        # import pdb;pdb.set_trace()
         QK = torch.einsum("bhdgn,bhdgm->bhgnm", query, key)
         if mask_sample is not None:
-            num_head,number_n=QK.shape[1],QK.shape[3]
-            QK.masked_fill_(~(mask_sample[:,:,:,None]).expand(-1,num_head,-1,number_n,-1).bool(), float(-1e8))
+            num_head, number_n = QK.shape[1], QK.shape[3]
+            QK.masked_fill_(
+                ~(mask_sample[:, :, :, None])
+                .expand(-1, num_head, -1, number_n, -1)
+                .bool(),
+                float(-1e8),
+            )
         # Compute the attention and the weighted average
-        softmax_temp = temp / query.size(2)**.5  # sqrt(D)
+        softmax_temp = temp / query.size(2) ** 0.5  # sqrt(D)
         A = torch.softmax(softmax_temp * QK, dim=-1)
-        queried_values = torch.einsum("bhgnm,bhdgm->bhdgn", A, value).contiguous().view(bs,self.d_model,-1)
+        queried_values = (
+            torch.einsum("bhgnm,bhdgm->bhdgn", A, value)
+            .contiguous()
+            .view(bs, self.d_model, -1)
+        )
         return queried_values
 
-    
 
 class FullAttention(Module):
-    def __init__(self,d_model,nhead):
+    def __init__(self, d_model, nhead):
         super().__init__()
-        self.d_model=d_model
-        self.nhead=nhead
+        self.d_model = d_model
+        self.nhead = nhead
 
-    def forward(self, q, k,v , mask0=None, mask1=None, temp=1):
-        """ Multi-head scaled dot-product attention, a.k.a full attention.
+    def forward(self, q, k, v, mask0=None, mask1=None, temp=1):
+        """Multi-head scaled dot-product attention, a.k.a full attention.
         Args:
             q,k,v: [N, D, L]
             mask: [N, L]
         Returns:
             msg: [N,L]
         """
-        bs=q.shape[0]
-        q,k,v=q.view(bs,self.nhead,self.d_model//self.nhead,-1),k.view(bs,self.nhead,self.d_model//self.nhead,-1),v.view(bs,self.nhead,self.d_model//self.nhead,-1)
+        bs = q.shape[0]
+        q, k, v = (
+            q.view(bs, self.nhead, self.d_model // self.nhead, -1),
+            k.view(bs, self.nhead, self.d_model // self.nhead, -1),
+            v.view(bs, self.nhead, self.d_model // self.nhead, -1),
+        )
         # Compute the unnormalized attention and apply the masks
         QK = torch.einsum("nhdl,nhds->nhls", q, k)
         if mask0 is not None:
-            QK.masked_fill_(~(mask0[:,None, :, None] * mask1[:, None, None]).bool(), float(-1e8))
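+            # padded positions in either image get a large negative logit so the softmax ignores them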
+            QK.masked_fill_(
+                ~(mask0[:, None, :, None] * mask1[:, None, None]).bool(), float(-1e8)
+            )
         # Compute the attention and the weighted average
-        softmax_temp = temp / q.size(2)**.5  # sqrt(D)
+        softmax_temp = temp / q.size(2) ** 0.5  # sqrt(D)
         A = torch.softmax(softmax_temp * QK, dim=-1)
-        queried_values = torch.einsum("nhls,nhds->nhdl", A, v).contiguous().view(bs,self.d_model,-1)
+        queried_values = (
+            torch.einsum("nhls,nhds->nhdl", A, v)
+            .contiguous()
+            .view(bs, self.d_model, -1)
+        )
         return queried_values
- 
-    
+
 
 def elu_feature_map(x):
     return F.elu(x) + 1
 
+
 class LinearAttention(Module):
     def __init__(self, eps=1e-6):
         super().__init__()
@@ -169,7 +283,7 @@ class LinearAttention(Module):
         self.eps = eps
 
     def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
-        """ Multi-Head linear attention proposed in "Transformers are RNNs"
+        """Multi-Head linear attention proposed in "Transformers are RNNs"
         Args:
             queries: [N, L, H, D]
             keys: [N, S, H, D]
@@ -195,4 +309,4 @@ class LinearAttention(Module):
         Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
         queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
 
-        return queried_values.contiguous()
\ No newline at end of file
+        return queried_values.contiguous()
diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py
index 5bb8eefd362240a9901a335f0e6e07770ff04567..6c37f76c3d5735508f950bb1239f5e93039b27ff 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py
@@ -9,15 +9,15 @@ class FinePreprocess(nn.Module):
         super().__init__()
 
         self.config = config
-        self.cat_c_feat = config['fine_concat_coarse_feat']
-        self.W = self.config['fine_window_size']
+        self.cat_c_feat = config["fine_concat_coarse_feat"]
+        self.W = self.config["fine_window_size"]
 
-        d_model_c = self.config['coarse']['d_model']
-        d_model_f = self.config['fine']['d_model']
+        d_model_c = self.config["coarse"]["d_model"]
+        d_model_f = self.config["fine"]["d_model"]
         self.d_model_f = d_model_f
         if self.cat_c_feat:
             self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
-            self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
+            self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True)
 
         self._reset_parameters()
 
@@ -28,32 +28,48 @@ class FinePreprocess(nn.Module):
 
     def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
         W = self.W
-        stride = data['hw0_f'][0] // data['hw0_c'][0]
+        stride = data["hw0_f"][0] // data["hw0_c"][0]
 
-        data.update({'W': W})
-        if data['b_ids'].shape[0] == 0:
+        data.update({"W": W})
+        if data["b_ids"].shape[0] == 0:
             feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
             feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
             return feat0, feat1
 
         # 1. unfold(crop) all local windows
-        feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
-        feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
-        feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
-        feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+        feat_f0_unfold = F.unfold(
+            feat_f0, kernel_size=(W, W), stride=stride, padding=W // 2
+        )
+        feat_f0_unfold = rearrange(feat_f0_unfold, "n (c ww) l -> n l ww c", ww=W**2)
+        feat_f1_unfold = F.unfold(
+            feat_f1, kernel_size=(W, W), stride=stride, padding=W // 2
+        )
+        feat_f1_unfold = rearrange(feat_f1_unfold, "n (c ww) l -> n l ww c", ww=W**2)
 
         # 2. select only the predicted matches
-        feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']]  # [n, ww, cf]
-        feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
+        feat_f0_unfold = feat_f0_unfold[data["b_ids"], data["i_ids"]]  # [n, ww, cf]
+        feat_f1_unfold = feat_f1_unfold[data["b_ids"], data["j_ids"]]
 
         # option: use coarse-level loftr feature as context: concat and linear
         if self.cat_c_feat:
-            feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
-                                                   feat_c1[data['b_ids'], data['j_ids']]], 0))  # [2n, c]
-            feat_cf_win = self.merge_feat(torch.cat([
-                torch.cat([feat_f0_unfold, feat_f1_unfold], 0),  # [2n, ww, cf]
-                repeat(feat_c_win, 'n c -> n ww c', ww=W**2),  # [2n, ww, cf]
-            ], -1))
+            feat_c_win = self.down_proj(
+                torch.cat(
+                    [
+                        feat_c0[data["b_ids"], data["i_ids"]],
+                        feat_c1[data["b_ids"], data["j_ids"]],
+                    ],
+                    0,
+                )
+            )  # [2n, c]
+            feat_cf_win = self.merge_feat(
+                torch.cat(
+                    [
+                        torch.cat([feat_f0_unfold, feat_f1_unfold], 0),  # [2n, ww, cf]
+                        repeat(feat_c_win, "n c -> n ww c", ww=W**2),  # [2n, ww, cf]
+                    ],
+                    -1,
+                )
+            )
             feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
 
         return feat_f0_unfold, feat_f1_unfold
diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py
index 7dcebaa7beee978b9b8abcec8bb1bd2cc6b60870..eaad9fdac1fbfc7a77f2db7c98c67bc41e335945 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py
@@ -3,11 +3,9 @@ import torch
 import torch.nn as nn
 from .attention import LinearAttention
 
+
 class LoFTREncoderLayer(nn.Module):
-    def __init__(self,
-                 d_model,
-                 nhead,
-                 attention='linear'):
+    def __init__(self, d_model, nhead, attention="linear"):
         super(LoFTREncoderLayer, self).__init__()
 
         self.dim = d_model // nhead
@@ -22,9 +20,9 @@ class LoFTREncoderLayer(nn.Module):
 
         # feed-forward network
         self.mlp = nn.Sequential(
-            nn.Linear(d_model*2, d_model*2, bias=False),
+            nn.Linear(d_model * 2, d_model * 2, bias=False),
             nn.ReLU(True),
-            nn.Linear(d_model*2, d_model, bias=False),
+            nn.Linear(d_model * 2, d_model, bias=False),
         )
 
         # norm and dropout
@@ -43,16 +41,14 @@ class LoFTREncoderLayer(nn.Module):
         query, key, value = x, source, source
 
         # multi-head attention
-        query = self.q_proj(query).view(
-            bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
-        key = self.k_proj(key).view(bs, -1, self.nhead,
-                                    self.dim)  # [N, S, (H, D)]
+        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
+        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
         value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
 
         message = self.attention(
-            query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
-        message = self.merge(message.view(
-            bs, -1, self.nhead*self.dim))  # [N, L, C]
+            query, key, value, q_mask=x_mask, kv_mask=source_mask
+        )  # [N, L, (H, D)]
+        message = self.merge(message.view(bs, -1, self.nhead * self.dim))  # [N, L, C]
         message = self.norm1(message)
 
         # feed-forward network
@@ -69,13 +65,15 @@ class LocalFeatureTransformer(nn.Module):
         super(LocalFeatureTransformer, self).__init__()
 
         self.config = config
-        self.d_model = config['d_model']
-        self.nhead = config['nhead']
-        self.layer_names = config['layer_names']
+        self.d_model = config["d_model"]
+        self.nhead = config["nhead"]
+        self.layer_names = config["layer_names"]
         encoder_layer = LoFTREncoderLayer(
-            config['d_model'], config['nhead'], config['attention'])
+            config["d_model"], config["nhead"], config["attention"]
+        )
         self.layers = nn.ModuleList(
-            [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
+            [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]
+        )
         self._reset_parameters()
 
     def _reset_parameters(self):
@@ -93,20 +91,18 @@ class LocalFeatureTransformer(nn.Module):
         """
 
         assert self.d_model == feat0.size(
-            2), "the feature number of src and transformer must be equal"
+            2
+        ), "the feature number of src and transformer must be equal"
 
         index = 0
         for layer, name in zip(self.layers, self.layer_names):
-            if name == 'self':
-                feat0 = layer(feat0, feat0, mask0, mask0,
-                              type='self', index=index)
+            if name == "self":
+                feat0 = layer(feat0, feat0, mask0, mask0, type="self", index=index)
                 feat1 = layer(feat1, feat1, mask1, mask1)
-            elif name == 'cross':
+            elif name == "cross":
                 feat0 = layer(feat0, feat1, mask0, mask1)
-                feat1 = layer(feat1, feat0, mask1, mask0,
-                              type='cross', index=index)
+                feat1 = layer(feat1, feat0, mask1, mask0, type="cross", index=index)
                 index += 1
             else:
                 raise KeyError
         return feat0, feat1
-
diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py
index c398f770833bf2066cda60a7ff546ec29640d433..125f555f93874af74c6e2595a360939f2f3bbce2 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py
@@ -2,44 +2,42 @@ import copy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .attention import FullAttention, HierachicalAttention ,layernorm2d
+from .attention import FullAttention, HierachicalAttention, layernorm2d
 
 
 class messageLayer_ini(nn.Module):
-
-    def __init__(self, d_model, d_flow,d_value, nhead):
+    def __init__(self, d_model, d_flow, d_value, nhead):
         super().__init__()
         super(messageLayer_ini, self).__init__()
 
         self.d_model = d_model
         self.d_flow = d_flow
-        self.d_value=d_value
+        self.d_value = d_value
         self.nhead = nhead
-        self.attention = FullAttention(d_model,nhead)
+        self.attention = FullAttention(d_model, nhead)
 
-        self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
-        self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
-        self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False)
-        self.merge_head=nn.Conv1d(d_model,d_model,kernel_size=1,bias=False)
+        self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
+        self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
+        self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1, bias=False)
+        self.merge_head = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
 
-        self.merge_f= self.merge_f = nn.Sequential(
-            nn.Conv2d(d_model*2, d_model*2, kernel_size=1, bias=False),
+        self.merge_f = nn.Sequential(
+            nn.Conv2d(d_model * 2, d_model * 2, kernel_size=1, bias=False),
             nn.ReLU(True),
-            nn.Conv2d(d_model*2, d_model, kernel_size=1, bias=False),
+            nn.Conv2d(d_model * 2, d_model, kernel_size=1, bias=False),
         )
 
         self.norm1 = layernorm2d(d_model)
         self.norm2 = layernorm2d(d_model)
 
+    def forward(self, x0, x1, pos0, pos1, mask0=None, mask1=None):
+        # x1,x2: b*d*L
+        x0, x1 = self.update(x0, x1, pos1, mask0, mask1), self.update(
+            x1, x0, pos0, mask1, mask0
+        )
+        return x0, x1
 
-    def forward(self, x0, x1,pos0,pos1,mask0=None,mask1=None):
-        #x1,x2: b*d*L
-        x0,x1=self.update(x0,x1,pos1,mask0,mask1),\
-                self.update(x1,x0,pos0,mask1,mask0)
-        return x0,x1
-
-
-    def update(self,f0,f1,pos1,mask0,mask1):
+    def update(self, f0, f1, pos1, mask0, mask1):
         """
         Args:
             f0: [N, D, H, W]
@@ -47,53 +45,77 @@ class messageLayer_ini(nn.Module):
         Returns:
             f0_new: (N, d, h, w)
         """
-        bs,h,w=f0.shape[0],f0.shape[2],f0.shape[3]
+        bs, h, w = f0.shape[0], f0.shape[2], f0.shape[3]
 
-        f0_flatten,f1_flatten=f0.view(bs,self.d_model,-1),f1.view(bs,self.d_model,-1)
-        pos1_flatten=pos1.view(bs,self.d_value-self.d_model,-1)
-        f1_flatten_v=torch.cat([f1_flatten,pos1_flatten],dim=1)
+        f0_flatten, f1_flatten = f0.view(bs, self.d_model, -1), f1.view(
+            bs, self.d_model, -1
+        )
+        pos1_flatten = pos1.view(bs, self.d_value - self.d_model, -1)
+        f1_flatten_v = torch.cat([f1_flatten, pos1_flatten], dim=1)
 
-        queries,keys=self.q_proj(f0_flatten),self.k_proj(f1_flatten)
-        values=self.v_proj(f1_flatten_v).view(bs,self.nhead,self.d_model//self.nhead,-1)
-        
-        queried_values=self.attention(queries,keys,values,mask0,mask1)
-        msg=self.merge_head(queried_values).view(bs,-1,h,w)
-        msg=self.norm2(self.merge_f(torch.cat([f0,self.norm1(msg)],dim=1)))
-        return f0+msg
+        queries, keys = self.q_proj(f0_flatten), self.k_proj(f1_flatten)
+        values = self.v_proj(f1_flatten_v).view(
+            bs, self.nhead, self.d_model // self.nhead, -1
+        )
 
+        queried_values = self.attention(queries, keys, values, mask0, mask1)
+        msg = self.merge_head(queried_values).view(bs, -1, h, w)
+        msg = self.norm2(self.merge_f(torch.cat([f0, self.norm1(msg)], dim=1)))
+        return f0 + msg
 
 
 class messageLayer_gla(nn.Module):
-
-    def __init__(self,d_model,d_flow,d_value,
-                    nhead,radius_scale,nsample,update_flow=True):
+    def __init__(
+        self, d_model, d_flow, d_value, nhead, radius_scale, nsample, update_flow=True
+    ):
         super().__init__()
         self.d_model = d_model
-        self.d_flow=d_flow
-        self.d_value=d_value
+        self.d_flow = d_flow
+        self.d_value = d_value
         self.nhead = nhead
-        self.radius_scale=radius_scale
-        self.update_flow=update_flow
-        self.flow_decoder=nn.Sequential(
-                    nn.Conv1d(d_flow, d_flow//2, kernel_size=1, bias=False),
-                    nn.ReLU(True),
-                    nn.Conv1d(d_flow//2, 4, kernel_size=1, bias=False))
-        self.attention=HierachicalAttention(d_model,nhead,nsample,radius_scale)
-
-        self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
-        self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
-        self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False)
-
-        d_extra=d_flow if update_flow else 0
-        self.merge_f=nn.Sequential(
-                     nn.Conv2d(d_model*2+d_extra, d_model+d_flow, kernel_size=1, bias=False),
-                     nn.ReLU(True),
-                     nn.Conv2d(d_model+d_flow, d_model+d_extra, kernel_size=3,padding=1, bias=False),
-                )
-        self.norm1 = layernorm2d(d_model)
-        self.norm2 = layernorm2d(d_model+d_extra)
+        self.radius_scale = radius_scale
+        self.update_flow = update_flow
+        self.flow_decoder = nn.Sequential(
+            nn.Conv1d(d_flow, d_flow // 2, kernel_size=1, bias=False),
+            nn.ReLU(True),
+            nn.Conv1d(d_flow // 2, 4, kernel_size=1, bias=False),
+        )
+        self.attention = HierachicalAttention(d_model, nhead, nsample, radius_scale)
 
-    def forward(self, x0, x1, flow_feature0,flow_feature1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]):
+        self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
+        self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
+        self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1, bias=False)
+
+        d_extra = d_flow if update_flow else 0
+        self.merge_f = nn.Sequential(
+            nn.Conv2d(
+                d_model * 2 + d_extra, d_model + d_flow, kernel_size=1, bias=False
+            ),
+            nn.ReLU(True),
+            nn.Conv2d(
+                d_model + d_flow,
+                d_model + d_extra,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+        )
+        self.norm1 = layernorm2d(d_model)
+        self.norm2 = layernorm2d(d_model + d_extra)
+
+    def forward(
+        self,
+        x0,
+        x1,
+        flow_feature0,
+        flow_feature1,
+        pos0,
+        pos1,
+        mask0=None,
+        mask1=None,
+        ds0=[4, 4],
+        ds1=[4, 4],
+    ):
         """
         Args:
             x0 (torch.Tensor): [B, C, H, W]
@@ -101,88 +123,135 @@ class messageLayer_gla(nn.Module):
             flow_feature0 (torch.Tensor): [B, C', H, W]
             flow_feature1 (torch.Tensor): [B, C', H, W]
         """
-        flow0,flow1=self.decode_flow(flow_feature0,flow_feature1.shape[2:]),self.decode_flow(flow_feature1,flow_feature0.shape[2:])
-        x0_new,flow_feature0_new=self.update(x0,x1,flow0.detach(),flow_feature0,pos1,mask0,mask1,ds0,ds1)
-        x1_new,flow_feature1_new=self.update(x1,x0,flow1.detach(),flow_feature1,pos0,mask1,mask0,ds1,ds0)
-        return x0_new,x1_new,flow_feature0_new,flow_feature1_new,flow0,flow1
-
-    def update(self,x0,x1,flow0,flow_feature0,pos1,mask0,mask1,ds0,ds1):
-        bs=x0.shape[0]
-        queries,keys=self.q_proj(x0.view(bs,self.d_model,-1)),self.k_proj(x1.view(bs,self.d_model,-1))
-        x1_pos=torch.cat([x1,pos1],dim=1)
-        values=self.v_proj(x1_pos.view(bs,self.d_value,-1))
-        msg=self.attention(queries,keys,values,flow0,x0.shape[2:],x1.shape[2:],mask0,mask1,ds0,ds1)
+        flow0, flow1 = self.decode_flow(
+            flow_feature0, flow_feature1.shape[2:]
+        ), self.decode_flow(flow_feature1, flow_feature0.shape[2:])
+        x0_new, flow_feature0_new = self.update(
+            x0, x1, flow0.detach(), flow_feature0, pos1, mask0, mask1, ds0, ds1
+        )
+        x1_new, flow_feature1_new = self.update(
+            x1, x0, flow1.detach(), flow_feature1, pos0, mask1, mask0, ds1, ds0
+        )
+        return x0_new, x1_new, flow_feature0_new, flow_feature1_new, flow0, flow1
+
+    def update(self, x0, x1, flow0, flow_feature0, pos1, mask0, mask1, ds0, ds1):
+        bs = x0.shape[0]
+        queries, keys = self.q_proj(x0.view(bs, self.d_model, -1)), self.k_proj(
+            x1.view(bs, self.d_model, -1)
+        )
+        x1_pos = torch.cat([x1, pos1], dim=1)
+        values = self.v_proj(x1_pos.view(bs, self.d_value, -1))
+        msg = self.attention(
+            queries,
+            keys,
+            values,
+            flow0,
+            x0.shape[2:],
+            x1.shape[2:],
+            mask0,
+            mask1,
+            ds0,
+            ds1,
+        )
 
         if self.update_flow:
-            update_feature=torch.cat([x0,flow_feature0],dim=1)
+            update_feature = torch.cat([x0, flow_feature0], dim=1)
         else:
-            update_feature=x0
-        msg=self.norm2(self.merge_f(torch.cat([update_feature,self.norm1(msg)],dim=1)))
-        update_feature=update_feature+msg
-
-        x0_new,flow_feature0_new=update_feature[:,:self.d_model],update_feature[:,self.d_model:]
-        return x0_new,flow_feature0_new
-
-    def decode_flow(self,flow_feature,kshape):
-        bs,h,w=flow_feature.shape[0],flow_feature.shape[2],flow_feature.shape[3]
-        scale_factor=torch.tensor([kshape[1],kshape[0]]).cuda()[None,None,None]
-        flow=self.flow_decoder(flow_feature.view(bs,-1,h*w)).permute(0,2,1).view(bs,h,w,4)
-        flow_coordinates=torch.sigmoid(flow[:,:,:,:2])*scale_factor
-        flow_var=flow[:,:,:,2:]
-        flow=torch.cat([flow_coordinates,flow_var],dim=-1) #B*H*W*4
+            update_feature = x0
+        msg = self.norm2(
+            self.merge_f(torch.cat([update_feature, self.norm1(msg)], dim=1))
+        )
+        update_feature = update_feature + msg
+
+        x0_new, flow_feature0_new = (
+            update_feature[:, : self.d_model],
+            update_feature[:, self.d_model :],
+        )
+        return x0_new, flow_feature0_new
+
+    def decode_flow(self, flow_feature, kshape):
+        bs, h, w = flow_feature.shape[0], flow_feature.shape[2], flow_feature.shape[3]
+        scale_factor = torch.tensor([kshape[1], kshape[0]]).cuda()[None, None, None]
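+        # scale_factor converts the sigmoid-normalized coordinates to pixel units of the other coarse feature map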
+        flow = (
+            self.flow_decoder(flow_feature.view(bs, -1, h * w))
+            .permute(0, 2, 1)
+            .view(bs, h, w, 4)
+        )
+        flow_coordinates = torch.sigmoid(flow[:, :, :, :2]) * scale_factor
+        flow_var = flow[:, :, :, 2:]
+        flow = torch.cat([flow_coordinates, flow_var], dim=-1)  # B*H*W*4
         return flow
 
 
 class flow_initializer(nn.Module):
-
     def __init__(self, dim, dim_flow, nhead, layer_num):
         super().__init__()
-        self.layer_num= layer_num
+        self.layer_num = layer_num
         self.dim = dim
         self.dim_flow = dim_flow
 
-        encoder_layer = messageLayer_ini(
-            dim ,dim_flow,dim+dim_flow , nhead)
+        encoder_layer = messageLayer_ini(dim, dim_flow, dim + dim_flow, nhead)
         self.layers_coarse = nn.ModuleList(
-            [copy.deepcopy(encoder_layer) for _ in range(layer_num)])
-        self.decoupler = nn.Conv2d(
-                self.dim, self.dim+self.dim_flow, kernel_size=1)
-        self.up_merge = nn.Conv2d(2*dim, dim, kernel_size=1)
+            [copy.deepcopy(encoder_layer) for _ in range(layer_num)]
+        )
+        self.decoupler = nn.Conv2d(self.dim, self.dim + self.dim_flow, kernel_size=1)
+        self.up_merge = nn.Conv2d(2 * dim, dim, kernel_size=1)
 
-    def forward(self, feat0, feat1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]):
+    def forward(
+        self, feat0, feat1, pos0, pos1, mask0=None, mask1=None, ds0=[4, 4], ds1=[4, 4]
+    ):
         # feat0: [B, C, H0, W0]
         # feat1: [B, C, H1, W1]
         # use low-res MHA to initialize flow feature
         bs = feat0.size(0)
-        h0,w0,h1,w1=feat0.shape[2],feat0.shape[3],feat1.shape[2],feat1.shape[3]
+        h0, w0, h1, w1 = feat0.shape[2], feat0.shape[3], feat1.shape[2], feat1.shape[3]
 
         # coarse level
-        sub_feat0, sub_feat1 = F.avg_pool2d(feat0, ds0, stride=ds0), \
-                            F.avg_pool2d(feat1, ds1, stride=ds1)
+        sub_feat0, sub_feat1 = F.avg_pool2d(feat0, ds0, stride=ds0), F.avg_pool2d(
+            feat1, ds1, stride=ds1
+        )
+
+        sub_pos0, sub_pos1 = F.avg_pool2d(pos0, ds0, stride=ds0), F.avg_pool2d(
+            pos1, ds1, stride=ds1
+        )
 
-        sub_pos0,sub_pos1=F.avg_pool2d(pos0, ds0, stride=ds0), \
-                            F.avg_pool2d(pos1, ds1, stride=ds1)
-    
         if mask0 is not None:
-            mask0,mask1=-F.max_pool2d(-mask0.view(bs,1,h0,w0),ds0,stride=ds0).view(bs,-1),\
-                        -F.max_pool2d(-mask1.view(bs,1,h1,w1),ds1,stride=ds1).view(bs,-1)
-        
+            mask0, mask1 = -F.max_pool2d(
+                -mask0.view(bs, 1, h0, w0), ds0, stride=ds0
+            ).view(bs, -1), -F.max_pool2d(
+                -mask1.view(bs, 1, h1, w1), ds1, stride=ds1
+            ).view(
+                bs, -1
+            )
+
         for layer in self.layers_coarse:
-            sub_feat0, sub_feat1 = layer(sub_feat0, sub_feat1,sub_pos0,sub_pos1,mask0,mask1)
+            sub_feat0, sub_feat1 = layer(
+                sub_feat0, sub_feat1, sub_pos0, sub_pos1, mask0, mask1
+            )
         # decouple flow and visual features
-        decoupled_feature0, decoupled_feature1 = self.decoupler(sub_feat0),self.decoupler(sub_feat1) 
+        decoupled_feature0, decoupled_feature1 = self.decoupler(
+            sub_feat0
+        ), self.decoupler(sub_feat1)
+
+        sub_feat0, sub_flow_feature0 = (
+            decoupled_feature0[:, : self.dim],
+            decoupled_feature0[:, self.dim :],
+        )
+        sub_feat1, sub_flow_feature1 = (
+            decoupled_feature1[:, : self.dim],
+            decoupled_feature1[:, self.dim :],
+        )
+        update_feat0, flow_feature0 = F.upsample(
+            sub_feat0, scale_factor=ds0, mode="bilinear"
+        ), F.upsample(sub_flow_feature0, scale_factor=ds0, mode="bilinear")
+        update_feat1, flow_feature1 = F.upsample(
+            sub_feat1, scale_factor=ds1, mode="bilinear"
+        ), F.upsample(sub_flow_feature1, scale_factor=ds1, mode="bilinear")
 
-        sub_feat0, sub_flow_feature0 = decoupled_feature0[:,:self.dim], decoupled_feature0[:, self.dim:]
-        sub_feat1, sub_flow_feature1 = decoupled_feature1[:,:self.dim], decoupled_feature1[:, self.dim:]
-        update_feat0, flow_feature0 = F.upsample(sub_feat0, scale_factor=ds0, mode='bilinear'),\
-                                        F.upsample(sub_flow_feature0, scale_factor=ds0, mode='bilinear')
-        update_feat1, flow_feature1 = F.upsample(sub_feat1, scale_factor=ds1, mode='bilinear'),\
-                                        F.upsample(sub_flow_feature1, scale_factor=ds1, mode='bilinear')
-        
-        feat0 = feat0+self.up_merge(torch.cat([feat0, update_feat0], dim=1))
-        feat1 = feat1+self.up_merge(torch.cat([feat1, update_feat1], dim=1))
-    
-        return feat0,feat1,flow_feature0,flow_feature1 #b*c*h*w
+        feat0 = feat0 + self.up_merge(torch.cat([feat0, update_feat0], dim=1))
+        feat1 = feat1 + self.up_merge(torch.cat([feat1, update_feat1], dim=1))
+
+        return feat0, feat1, flow_feature0, flow_feature1  # b*c*h*w
 
 
 class LocalFeatureTransformer_Flow(nn.Module):
@@ -192,27 +261,49 @@ class LocalFeatureTransformer_Flow(nn.Module):
         super(LocalFeatureTransformer_Flow, self).__init__()
 
         self.config = config
-        self.d_model = config['d_model']
-        self.nhead = config['nhead']
+        self.d_model = config["d_model"]
+        self.nhead = config["nhead"]
+
+        self.pos_transform = nn.Conv2d(
+            config["d_model"], config["d_flow"], kernel_size=1, bias=False
+        )
+        self.ini_layer = flow_initializer(
+            self.d_model, config["d_flow"], config["nhead"], config["ini_layer_num"]
+        )
 
-        self.pos_transform=nn.Conv2d(config['d_model'],config['d_flow'],kernel_size=1,bias=False)
-        self.ini_layer = flow_initializer(self.d_model, config['d_flow'], config['nhead'],config['ini_layer_num'])
-        
         encoder_layer = messageLayer_gla(
-            config['d_model'], config['d_flow'], config['d_flow']+config['d_model'], config['nhead'],config['radius_scale'],config['nsample'])
-        encoder_layer_last=messageLayer_gla(
-            config['d_model'], config['d_flow'], config['d_flow']+config['d_model'], config['nhead'],config['radius_scale'],config['nsample'],update_flow=False)
-        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(config['layer_num']-1)]+[encoder_layer_last])
+            config["d_model"],
+            config["d_flow"],
+            config["d_flow"] + config["d_model"],
+            config["nhead"],
+            config["radius_scale"],
+            config["nsample"],
+        )
+        encoder_layer_last = messageLayer_gla(
+            config["d_model"],
+            config["d_flow"],
+            config["d_flow"] + config["d_model"],
+            config["nhead"],
+            config["radius_scale"],
+            config["nsample"],
+            update_flow=False,
+        )
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for _ in range(config["layer_num"] - 1)]
+            + [encoder_layer_last]
+        )
         self._reset_parameters()
-        
+
     def _reset_parameters(self):
-        for name,p in self.named_parameters():
-            if 'temp' in name or 'sample_offset' in name:
+        for name, p in self.named_parameters():
+            if "temp" in name or "sample_offset" in name:
                 continue
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    def forward(self, feat0, feat1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]):
+    def forward(
+        self, feat0, feat1, pos0, pos1, mask0=None, mask1=None, ds0=[4, 4], ds1=[4, 4]
+    ):
         """
         Args:
             feat0 (torch.Tensor): [N, C, H, W]
@@ -224,21 +315,37 @@ class LocalFeatureTransformer_Flow(nn.Module):
             flow_list: [L,N,H,W,4]*1(2)
         """
         bs = feat0.size(0)
-        
-        pos0,pos1=self.pos_transform(pos0),self.pos_transform(pos1)
-        pos0,pos1=pos0.expand(bs,-1,-1,-1),pos1.expand(bs,-1,-1,-1)
+
+        pos0, pos1 = self.pos_transform(pos0), self.pos_transform(pos1)
+        pos0, pos1 = pos0.expand(bs, -1, -1, -1), pos1.expand(bs, -1, -1, -1)
         assert self.d_model == feat0.size(
-            1), "the feature number of src and transformer must be equal"
-       
-        flow_list=[[],[]]# [px,py,sx,sy] 
+            1
+        ), "the feature number of src and transformer must be equal"
+
+        flow_list = [[], []]  # [px,py,sx,sy]
         if mask0 is not None:
-            mask0,mask1=mask0[:,None].float(),mask1[:,None].float()
-        feat0,feat1, flow_feature0, flow_feature1 = self.ini_layer(feat0, feat1,pos0,pos1,mask0,mask1,ds0,ds1)
+            mask0, mask1 = mask0[:, None].float(), mask1[:, None].float()
+        feat0, feat1, flow_feature0, flow_feature1 = self.ini_layer(
+            feat0, feat1, pos0, pos1, mask0, mask1, ds0, ds1
+        )
         for layer in self.layers:
-            feat0,feat1,flow_feature0,flow_feature1,flow0,flow1=layer(feat0,feat1,flow_feature0,flow_feature1,pos0,pos1,mask0,mask1,ds0,ds1)
+            feat0, feat1, flow_feature0, flow_feature1, flow0, flow1 = layer(
+                feat0,
+                feat1,
+                flow_feature0,
+                flow_feature1,
+                pos0,
+                pos1,
+                mask0,
+                mask1,
+                ds0,
+                ds1,
+            )
             flow_list[0].append(flow0)
             flow_list[1].append(flow1)
-        flow_list[0]=torch.stack(flow_list[0],dim=0)
-        flow_list[1]=torch.stack(flow_list[1],dim=0)
-        feat0, feat1 = feat0.permute(0, 2, 3, 1).view(bs, -1, self.d_model), feat1.permute(0, 2, 3, 1).view(bs, -1, self.d_model)
-        return feat0, feat1, flow_list
\ No newline at end of file
+        flow_list[0] = torch.stack(flow_list[0], dim=0)
+        flow_list[1] = torch.stack(flow_list[1], dim=0)
+        feat0, feat1 = feat0.permute(0, 2, 3, 1).view(
+            bs, -1, self.d_model
+        ), feat1.permute(0, 2, 3, 1).view(bs, -1, self.d_model)
+        return feat0, feat1, flow_list
diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py b/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py
index 01b797a420cf5ccea5b53fee3ceda8b5e157573f..113e912bf219ff6fcbc7a1642454ac08b455fd0d 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py
@@ -5,7 +5,11 @@ from einops.einops import rearrange
 
 from .backbone import build_backbone
 from .utils.position_encoding import PositionEncodingSine
-from .aspan_module import LocalFeatureTransformer_Flow, LocalFeatureTransformer, FinePreprocess
+from .aspan_module import (
+    LocalFeatureTransformer_Flow,
+    LocalFeatureTransformer,
+    FinePreprocess,
+)
 from .utils.coarse_matching import CoarseMatching
 from .utils.fine_matching import FineMatching
 
@@ -19,16 +23,18 @@ class ASpanFormer(nn.Module):
         # Modules
         self.backbone = build_backbone(config)
         self.pos_encoding = PositionEncodingSine(
-            config['coarse']['d_model'],pre_scaling=[config['coarse']['train_res'],config['coarse']['test_res']])
-        self.loftr_coarse = LocalFeatureTransformer_Flow(config['coarse'])
-        self.coarse_matching = CoarseMatching(config['match_coarse'])
+            config["coarse"]["d_model"],
+            pre_scaling=[config["coarse"]["train_res"], config["coarse"]["test_res"]],
+        )
+        self.loftr_coarse = LocalFeatureTransformer_Flow(config["coarse"])
+        self.coarse_matching = CoarseMatching(config["match_coarse"])
         self.fine_preprocess = FinePreprocess(config)
         self.loftr_fine = LocalFeatureTransformer(config["fine"])
         self.fine_matching = FineMatching()
-        self.coarsest_level=config['coarse']['coarsest_level']
+        self.coarsest_level = config["coarse"]["coarsest_level"]
 
     def forward(self, data, online_resize=False):
-        """ 
+        """
         Update:
             data (dict): {
                 'image0': (torch.Tensor): (N, 1, H, W)
@@ -38,96 +44,135 @@ class ASpanFormer(nn.Module):
             }
         """
         if online_resize:
-            assert data['image0'].shape[0]==1 and data['image1'].shape[1]==1
-            self.resize_input(data,self.config['coarse']['train_res'])
+            assert data["image0"].shape[0] == 1 and data["image1"].shape[1] == 1
+            self.resize_input(data, self.config["coarse"]["train_res"])
         else:
-            data['pos_scale0'],data['pos_scale1']=None,None
+            data["pos_scale0"], data["pos_scale1"] = None, None
 
         # 1. Local Feature CNN
-        data.update({
-            'bs': data['image0'].size(0),
-            'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
-        })
-        
-        if data['hw0_i'] == data['hw1_i']:  # faster & better BN convergence
+        data.update(
+            {
+                "bs": data["image0"].size(0),
+                "hw0_i": data["image0"].shape[2:],
+                "hw1_i": data["image1"].shape[2:],
+            }
+        )
+
+        if data["hw0_i"] == data["hw1_i"]:  # faster & better BN convergence
             feats_c, feats_f = self.backbone(
-                torch.cat([data['image0'], data['image1']], dim=0))
+                torch.cat([data["image0"], data["image1"]], dim=0)
+            )
             (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(
-                data['bs']), feats_f.split(data['bs'])
+                data["bs"]
+            ), feats_f.split(data["bs"])
         else:  # handle different input shapes
             (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(
-                data['image0']), self.backbone(data['image1'])
+                data["image0"]
+            ), self.backbone(data["image1"])
 
-        data.update({
-            'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
-            'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
-        })
+        data.update(
+            {
+                "hw0_c": feat_c0.shape[2:],
+                "hw1_c": feat_c1.shape[2:],
+                "hw0_f": feat_f0.shape[2:],
+                "hw1_f": feat_f1.shape[2:],
+            }
+        )
 
         # 2. coarse-level loftr module
         # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
-        [feat_c0, pos_encoding0], [feat_c1, pos_encoding1] = self.pos_encoding(feat_c0,data['pos_scale0']), self.pos_encoding(feat_c1,data['pos_scale1'])
-        feat_c0 = rearrange(feat_c0, 'n c h w -> n c h w ')
-        feat_c1 = rearrange(feat_c1, 'n c h w -> n c h w ')
+        [feat_c0, pos_encoding0], [feat_c1, pos_encoding1] = self.pos_encoding(
+            feat_c0, data["pos_scale0"]
+        ), self.pos_encoding(feat_c1, data["pos_scale1"])
+        feat_c0 = rearrange(feat_c0, "n c h w -> n c h w ")
+        feat_c1 = rearrange(feat_c1, "n c h w -> n c h w ")
 
-        #TODO:adjust ds 
-        ds0=[int(data['hw0_c'][0]/self.coarsest_level[0]),int(data['hw0_c'][1]/self.coarsest_level[1])]
-        ds1=[int(data['hw1_c'][0]/self.coarsest_level[0]),int(data['hw1_c'][1]/self.coarsest_level[1])]
+        # TODO: adjust ds
+        ds0 = [
+            int(data["hw0_c"][0] / self.coarsest_level[0]),
+            int(data["hw0_c"][1] / self.coarsest_level[1]),
+        ]
+        ds1 = [
+            int(data["hw1_c"][0] / self.coarsest_level[0]),
+            int(data["hw1_c"][1] / self.coarsest_level[1]),
+        ]
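+        # ds0/ds1: per-image integer down-sampling factors from the coarse feature grid to the
+        # fixed coarsest attention level (an interpretation based on how loftr_coarse consumes them)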
         if online_resize:
-            ds0,ds1=[4,4],[4,4]
+            ds0, ds1 = [4, 4], [4, 4]
 
         mask_c0 = mask_c1 = None  # mask is useful in training
-        if 'mask0' in data:
-            mask_c0, mask_c1 = data['mask0'].flatten(
-                -2), data['mask1'].flatten(-2)
+        if "mask0" in data:
+            mask_c0, mask_c1 = data["mask0"].flatten(-2), data["mask1"].flatten(-2)
         feat_c0, feat_c1, flow_list = self.loftr_coarse(
-            feat_c0, feat_c1,pos_encoding0,pos_encoding1,mask_c0,mask_c1,ds0,ds1)
+            feat_c0, feat_c1, pos_encoding0, pos_encoding1, mask_c0, mask_c1, ds0, ds1
+        )
 
         # 3. match coarse-level and register predicted offset
-        self.coarse_matching(feat_c0, feat_c1, flow_list,data,
-                             mask_c0=mask_c0, mask_c1=mask_c1)
+        self.coarse_matching(
+            feat_c0, feat_c1, flow_list, data, mask_c0=mask_c0, mask_c1=mask_c1
+        )
 
         # 4. fine-level refinement
         feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
-            feat_f0, feat_f1, feat_c0, feat_c1, data)
+            feat_f0, feat_f1, feat_c0, feat_c1, data
+        )
         if feat_f0_unfold.size(0) != 0:  # at least one coarse level predicted
             feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
-                feat_f0_unfold, feat_f1_unfold)
+                feat_f0_unfold, feat_f1_unfold
+            )
 
         # 5. match fine-level
         self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
 
         # 6. resize match coordinates back to input resolution
         if online_resize:
-            data['mkpts0_f']*=data['online_resize_scale0']
-            data['mkpts1_f']*=data['online_resize_scale1']
-        
+            data["mkpts0_f"] *= data["online_resize_scale0"]
+            data["mkpts1_f"] *= data["online_resize_scale1"]
+
     def load_state_dict(self, state_dict, *args, **kwargs):
         for k in list(state_dict.keys()):
-            if k.startswith('matcher.'):
-                if 'sample_offset' in k:
+            if k.startswith("matcher."):
+                if "sample_offset" in k:
                     state_dict.pop(k)
                 else:
-                    state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
+                    state_dict[k.replace("matcher.", "", 1)] = state_dict.pop(k)
         return super().load_state_dict(state_dict, *args, **kwargs)
-    
-    def resize_input(self,data,train_res,df=32):
-        h0,w0,h1,w1=data['image0'].shape[2],data['image0'].shape[3],data['image1'].shape[2],data['image1'].shape[3]
-        data['image0'],data['image1']=self.resize_df(data['image0'],df),self.resize_df(data['image1'],df)
-        
-        if len(train_res)==1:
-            train_res_h=train_res_w=train_res
+
+    def resize_input(self, data, train_res, df=32):
+        h0, w0, h1, w1 = (
+            data["image0"].shape[2],
+            data["image0"].shape[3],
+            data["image1"].shape[2],
+            data["image1"].shape[3],
+        )
+        data["image0"], data["image1"] = self.resize_df(
+            data["image0"], df
+        ), self.resize_df(data["image1"], df)
+
+        if len(train_res) == 1:
+            train_res_h = train_res_w = train_res
         else:
-            train_res_h,train_res_w=train_res[0],train_res[1]
-        data['pos_scale0'],data['pos_scale1']=[train_res_h/data['image0'].shape[2],train_res_w/data['image0'].shape[3]],\
-                                  [train_res_h/data['image1'].shape[2],train_res_w/data['image1'].shape[3]] 
-        data['online_resize_scale0'],data['online_resize_scale1']=torch.tensor([w0/data['image0'].shape[3],h0/data['image0'].shape[2]])[None].cuda(),\
-                                                                    torch.tensor([w1/data['image1'].shape[3],h1/data['image1'].shape[2]])[None].cuda()
-
-    def resize_df(self,image,df=32):
-        h,w=image.shape[2],image.shape[3]
-        h_new,w_new=h//df*df,w//df*df
-        if h!=h_new or w!=w_new:
-            img_new=transforms.Resize([h_new,w_new]).forward(image)
+            train_res_h, train_res_w = train_res[0], train_res[1]
+        data["pos_scale0"], data["pos_scale1"] = [
+            train_res_h / data["image0"].shape[2],
+            train_res_w / data["image0"].shape[3],
+        ], [
+            train_res_h / data["image1"].shape[2],
+            train_res_w / data["image1"].shape[3],
+        ]
+        data["online_resize_scale0"], data["online_resize_scale1"] = (
+            torch.tensor([w0 / data["image0"].shape[3], h0 / data["image0"].shape[2]])[
+                None
+            ].cuda(),
+            torch.tensor([w1 / data["image1"].shape[3], h1 / data["image1"].shape[2]])[
+                None
+            ].cuda(),
+        )
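+        # online_resize_scale0/1 map predicted keypoints from the resized images back to the
+        # original input resolution (applied to mkpts*_f at the end of forward)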
+
+    def resize_df(self, image, df=32):
+        h, w = image.shape[2], image.shape[3]
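+        # floor both spatial dims to the nearest multiple of df (assumed to be the backbone's
+        # total stride) so every down-sampling stage divides evenly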
+        h_new, w_new = h // df * df, w // df * df
+        if h != h_new or w != w_new:
+            img_new = transforms.Resize([h_new, w_new]).forward(image)
         else:
-            img_new=image
+            img_new = image
         return img_new
diff --git a/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py b/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py
index b6e731b3f53ab367c89ef0ea8e1cbffb0d990775..ae8593230b281e960ece68c04dcf214769e50f08 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py
@@ -2,10 +2,12 @@ from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4
 
 
 def build_backbone(config):
-    if config['backbone_type'] == 'ResNetFPN':
-        if config['resolution'] == (8, 2):
-            return ResNetFPN_8_2(config['resnetfpn'])
-        elif config['resolution'] == (16, 4):
-            return ResNetFPN_16_4(config['resnetfpn'])
+    if config["backbone_type"] == "ResNetFPN":
+        if config["resolution"] == (8, 2):
+            return ResNetFPN_8_2(config["resnetfpn"])
+        elif config["resolution"] == (16, 4):
+            return ResNetFPN_16_4(config["resnetfpn"])
     else:
-        raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
+        raise ValueError(
+            f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported."
+        )
diff --git a/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py b/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py
index 985e5b3f273a51e51447a8025ca3aadbe46752eb..948c72940ab00e5741e2788eea841d124333c8ed 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py
@@ -4,12 +4,16 @@ import torch.nn.functional as F
 
 def conv1x1(in_planes, out_planes, stride=1):
     """1x1 convolution without padding"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
+    return nn.Conv2d(
+        in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False
+    )
 
 
 def conv3x3(in_planes, out_planes, stride=1):
     """3x3 convolution with padding"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+    return nn.Conv2d(
+        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+    )
 
 
 class BasicBlock(nn.Module):
@@ -25,8 +29,7 @@ class BasicBlock(nn.Module):
             self.downsample = None
         else:
             self.downsample = nn.Sequential(
-                conv1x1(in_planes, planes, stride=stride),
-                nn.BatchNorm2d(planes)
+                conv1x1(in_planes, planes, stride=stride), nn.BatchNorm2d(planes)
             )
 
     def forward(self, x):
@@ -37,7 +40,7 @@ class BasicBlock(nn.Module):
         if self.downsample is not None:
             x = self.downsample(x)
 
-        return self.relu(x+y)
+        return self.relu(x + y)
 
 
 class ResNetFPN_8_2(nn.Module):
@@ -50,14 +53,16 @@ class ResNetFPN_8_2(nn.Module):
         super().__init__()
         # Config
         block = BasicBlock
-        initial_dim = config['initial_dim']
-        block_dims = config['block_dims']
+        initial_dim = config["initial_dim"]
+        block_dims = config["block_dims"]
 
         # Class Variable
         self.in_planes = initial_dim
 
         # Networks
-        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+        self.conv1 = nn.Conv2d(
+            1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False
+        )
         self.bn1 = nn.BatchNorm2d(initial_dim)
         self.relu = nn.ReLU(inplace=True)
 
@@ -84,7 +89,7 @@ class ResNetFPN_8_2(nn.Module):
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
@@ -107,13 +112,17 @@ class ResNetFPN_8_2(nn.Module):
         # FPN
         x3_out = self.layer3_outconv(x3)
 
-        x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+        x3_out_2x = F.interpolate(
+            x3_out, scale_factor=2.0, mode="bilinear", align_corners=True
+        )
         x2_out = self.layer2_outconv(x2)
-        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+        x2_out = self.layer2_outconv2(x2_out + x3_out_2x)
 
-        x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
+        x2_out_2x = F.interpolate(
+            x2_out, scale_factor=2.0, mode="bilinear", align_corners=True
+        )
         x1_out = self.layer1_outconv(x1)
-        x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+        x1_out = self.layer1_outconv2(x1_out + x2_out_2x)
 
         return [x3_out, x1_out]
 
@@ -128,14 +137,16 @@ class ResNetFPN_16_4(nn.Module):
         super().__init__()
         # Config
         block = BasicBlock
-        initial_dim = config['initial_dim']
-        block_dims = config['block_dims']
+        initial_dim = config["initial_dim"]
+        block_dims = config["block_dims"]
 
         # Class Variable
         self.in_planes = initial_dim
 
         # Networks
-        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+        self.conv1 = nn.Conv2d(
+            1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False
+        )
         self.bn1 = nn.BatchNorm2d(initial_dim)
         self.relu = nn.ReLU(inplace=True)
 
@@ -164,7 +175,7 @@ class ResNetFPN_16_4(nn.Module):
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
@@ -188,12 +199,16 @@ class ResNetFPN_16_4(nn.Module):
         # FPN
         x4_out = self.layer4_outconv(x4)
 
-        x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
+        x4_out_2x = F.interpolate(
+            x4_out, scale_factor=2.0, mode="bilinear", align_corners=True
+        )
         x3_out = self.layer3_outconv(x3)
-        x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
+        x3_out = self.layer3_outconv2(x3_out + x4_out_2x)
 
-        x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+        x3_out_2x = F.interpolate(
+            x3_out, scale_factor=2.0, mode="bilinear", align_corners=True
+        )
         x2_out = self.layer2_outconv(x2)
-        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+        x2_out = self.layer2_outconv2(x2_out + x3_out_2x)
 
         return [x4_out, x2_out]
diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py b/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py
index 953ee55a09144a4ce0099e709f3a992d021aa0ab..c506479a978c3ebb20c6736ed30f0ef0a351d4b9 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py
@@ -7,8 +7,9 @@ from time import time
 
 INF = 1e9
 
+
 def mask_border(m, b: int, v):
-    """ Mask borders with value
+    """Mask borders with value
     Args:
         m (torch.Tensor): [N, H0, W0, H1, W1]
         b (int)
@@ -39,22 +40,21 @@ def mask_border_with_padding(m, bd, v, p_m0, p_m1):
     h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
     h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
     for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
-        m[b_idx, h0 - bd:] = v
-        m[b_idx, :, w0 - bd:] = v
-        m[b_idx, :, :, h1 - bd:] = v
-        m[b_idx, :, :, :, w1 - bd:] = v
+        m[b_idx, h0 - bd :] = v
+        m[b_idx, :, w0 - bd :] = v
+        m[b_idx, :, :, h1 - bd :] = v
+        m[b_idx, :, :, :, w1 - bd :] = v
 
 
 def compute_max_candidates(p_m0, p_m1):
     """Compute the max candidates of all pairs within a batch
-    
+
     Args:
         p_m0, p_m1 (torch.Tensor): padded masks
     """
     h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
     h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
-    max_cand = torch.sum(
-        torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
+    max_cand = torch.sum(torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
     return max_cand
 
 
@@ -63,29 +63,32 @@ class CoarseMatching(nn.Module):
         super().__init__()
         self.config = config
         # general config
-        self.thr = config['thr']
-        self.border_rm = config['border_rm']
+        self.thr = config["thr"]
+        self.border_rm = config["border_rm"]
         # -- # for trainig fine-level LoFTR
-        self.train_coarse_percent = config['train_coarse_percent']
-        self.train_pad_num_gt_min = config['train_pad_num_gt_min']
-        
+        self.train_coarse_percent = config["train_coarse_percent"]
+        self.train_pad_num_gt_min = config["train_pad_num_gt_min"]
+
         # we provide 2 options for differentiable matching
-        self.match_type = config['match_type']
-        if self.match_type == 'dual_softmax':
-            self.temperature=nn.parameter.Parameter(torch.tensor(10.), requires_grad=True)
-        elif self.match_type == 'sinkhorn':
+        self.match_type = config["match_type"]
+        if self.match_type == "dual_softmax":
+            self.temperature = nn.parameter.Parameter(
+                torch.tensor(10.0), requires_grad=True
+            )
+        elif self.match_type == "sinkhorn":
             try:
                 from .superglue import log_optimal_transport
             except ImportError:
                 raise ImportError("download superglue.py first!")
             self.log_optimal_transport = log_optimal_transport
             self.bin_score = nn.Parameter(
-                torch.tensor(config['skh_init_bin_score'], requires_grad=True))
-            self.skh_iters = config['skh_iters']
-            self.skh_prefilter = config['skh_prefilter']
+                torch.tensor(config["skh_init_bin_score"], requires_grad=True)
+            )
+            self.skh_iters = config["skh_iters"]
+            self.skh_prefilter = config["skh_prefilter"]
         else:
             raise NotImplementedError()
-     
+
     def forward(self, feat_c0, feat_c1, flow_list, data, mask_c0=None, mask_c1=None):
         """
         Args:
@@ -108,29 +111,32 @@ class CoarseMatching(nn.Module):
         """
         N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
         # normalize
-        feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
-                               [feat_c0, feat_c1])
-
-        if self.match_type == 'dual_softmax':
-            sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
-                                      feat_c1) * self.temperature
+        feat_c0, feat_c1 = map(
+            lambda feat: feat / feat.shape[-1] ** 0.5, [feat_c0, feat_c1]
+        )
+
+        if self.match_type == "dual_softmax":
+            sim_matrix = (
+                torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) * self.temperature
+            )
             if mask_c0 is not None:
                 sim_matrix.masked_fill_(
-                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
-                    -INF)
+                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF
+                )
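+            # dual-softmax: the product of row-wise and column-wise softmaxes keeps only
+            # pairs that score highly in both matching directions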
             conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
-            
-        elif self.match_type == 'sinkhorn':
+
+        elif self.match_type == "sinkhorn":
             # sinkhorn, dustbin included
             sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
             if mask_c0 is not None:
                 sim_matrix[:, :L, :S].masked_fill_(
-                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
-                    -INF)
+                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF
+                )
 
             # build uniform prior & use sinkhorn
             log_assign_matrix = self.log_optimal_transport(
-                sim_matrix, self.bin_score, self.skh_iters)
+                sim_matrix, self.bin_score, self.skh_iters
+            )
             assign_matrix = log_assign_matrix.exp()
             conf_matrix = assign_matrix[:, :-1, :-1]
 
@@ -141,18 +147,21 @@ class CoarseMatching(nn.Module):
                 conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
                 conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
 
-            if self.config['sparse_spvs']:
-                data.update({'conf_matrix_with_bin': assign_matrix.clone()})
+            if self.config["sparse_spvs"]:
+                data.update({"conf_matrix_with_bin": assign_matrix.clone()})
 
-        data.update({'conf_matrix': conf_matrix})
+        data.update({"conf_matrix": conf_matrix})
         # predict coarse matches from conf_matrix
         data.update(**self.get_coarse_match(conf_matrix, data))
 
-        #update predicted offset
-        if flow_list[0].shape[2]==flow_list[1].shape[2] and flow_list[0].shape[3]==flow_list[1].shape[3]:
-            flow_list=torch.stack(flow_list,dim=0)
-        data.update({'predict_flow':flow_list}) #[2*L*B*H*W*4]
-        self.get_offset_match(flow_list,data,mask_c0,mask_c1)
+        # update predicted offset
+        if (
+            flow_list[0].shape[2] == flow_list[1].shape[2]
+            and flow_list[0].shape[3] == flow_list[1].shape[3]
+        ):
+            flow_list = torch.stack(flow_list, dim=0)
+        data.update({"predict_flow": flow_list})  # [2*L*B*H*W*4]
+        self.get_offset_match(flow_list, data, mask_c0, mask_c1)
 
     @torch.no_grad()
     def get_coarse_match(self, conf_matrix, data):
@@ -172,28 +181,33 @@ class CoarseMatching(nn.Module):
                 'mconf' (torch.Tensor): [M]}
         """
         axes_lengths = {
-            'h0c': data['hw0_c'][0],
-            'w0c': data['hw0_c'][1],
-            'h1c': data['hw1_c'][0],
-            'w1c': data['hw1_c'][1]
+            "h0c": data["hw0_c"][0],
+            "w0c": data["hw0_c"][1],
+            "h1c": data["hw1_c"][0],
+            "w1c": data["hw1_c"][1],
         }
         _device = conf_matrix.device
         # 1. confidence thresholding
         mask = conf_matrix > self.thr
-        mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
-                         **axes_lengths)
-        if 'mask0' not in data:
+        mask = rearrange(
+            mask, "b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c", **axes_lengths
+        )
+        if "mask0" not in data:
             mask_border(mask, self.border_rm, False)
         else:
-            mask_border_with_padding(mask, self.border_rm, False,
-                                     data['mask0'], data['mask1'])
-        mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
-                         **axes_lengths)
+            mask_border_with_padding(
+                mask, self.border_rm, False, data["mask0"], data["mask1"]
+            )
+        mask = rearrange(
+            mask, "b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)", **axes_lengths
+        )
 
         # 2. mutual nearest
-        mask = mask \
-            * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
+        mask = (
+            mask
+            * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0])
             * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
+        )
 
         # 3. find all valid coarse matches
         # this only works when at most one `True` in each row
@@ -208,67 +222,79 @@ class CoarseMatching(nn.Module):
             # NOTE:
             # The sampling is performed across all pairs in a batch without manually balancing
             # #samples for fine-level increases w.r.t. batch_size
-            if 'mask0' not in data:
-                num_candidates_max = mask.size(0) * max(
-                    mask.size(1), mask.size(2))
+            if "mask0" not in data:
+                num_candidates_max = mask.size(0) * max(mask.size(1), mask.size(2))
             else:
                 num_candidates_max = compute_max_candidates(
-                    data['mask0'], data['mask1'])
-            num_matches_train = int(num_candidates_max *
-                                    self.train_coarse_percent)
+                    data["mask0"], data["mask1"]
+                )
+            num_matches_train = int(num_candidates_max * self.train_coarse_percent)
             num_matches_pred = len(b_ids)
-            assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
-            
+            assert (
+                self.train_pad_num_gt_min < num_matches_train
+            ), "min-num-gt-pad should be less than num-train-matches"
+
             # pred_indices is to select from prediction
             if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
                 pred_indices = torch.arange(num_matches_pred, device=_device)
             else:
                 pred_indices = torch.randint(
                     num_matches_pred,
-                    (num_matches_train - self.train_pad_num_gt_min, ),
-                    device=_device)
+                    (num_matches_train - self.train_pad_num_gt_min,),
+                    device=_device,
+                )
 
             # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
             gt_pad_indices = torch.randint(
-                    len(data['spv_b_ids']),
-                    (max(num_matches_train - num_matches_pred,
-                        self.train_pad_num_gt_min), ),
-                    device=_device)
-            mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device)  # set conf of gt paddings to all zero
+                len(data["spv_b_ids"]),
+                (max(num_matches_train - num_matches_pred, self.train_pad_num_gt_min),),
+                device=_device,
+            )
+            mconf_gt = torch.zeros(
+                len(data["spv_b_ids"]), device=_device
+            )  # set conf of gt paddings to all zero
 
             b_ids, i_ids, j_ids, mconf = map(
-                lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
-                                       dim=0),
-                *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
-                     [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
+                lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], dim=0),
+                *zip(
+                    [b_ids, data["spv_b_ids"]],
+                    [i_ids, data["spv_i_ids"]],
+                    [j_ids, data["spv_j_ids"]],
+                    [mconf, mconf_gt],
+                )
+            )
 
         # These matches select patches that feed into fine-level network
-        coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
+        coarse_matches = {"b_ids": b_ids, "i_ids": i_ids, "j_ids": j_ids}
 
         # 4. Update with matches in original image resolution
-        scale = data['hw0_i'][0] / data['hw0_c'][0]
-        scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
-        scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
-        mkpts0_c = torch.stack(
-            [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
-            dim=1) * scale0
-        mkpts1_c = torch.stack(
-            [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
-            dim=1) * scale1
+        scale = data["hw0_i"][0] / data["hw0_c"][0]
+        scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale
+        scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale
+        mkpts0_c = (
+            torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1)
+            * scale0
+        )
+        mkpts1_c = (
+            torch.stack([j_ids % data["hw1_c"][1], j_ids // data["hw1_c"][1]], dim=1)
+            * scale1
+        )
 
         # These matches is the current prediction (for visualization)
-        coarse_matches.update({
-            'gt_mask': mconf == 0,
-            'm_bids': b_ids[mconf != 0],  # mconf == 0 => gt matches
-            'mkpts0_c': mkpts0_c[mconf != 0],
-            'mkpts1_c': mkpts1_c[mconf != 0],
-            'mconf': mconf[mconf != 0]
-        })
+        coarse_matches.update(
+            {
+                "gt_mask": mconf == 0,
+                "m_bids": b_ids[mconf != 0],  # mconf == 0 => gt matches
+                "mkpts0_c": mkpts0_c[mconf != 0],
+                "mkpts1_c": mkpts1_c[mconf != 0],
+                "mconf": mconf[mconf != 0],
+            }
+        )
 
         return coarse_matches
 
     @torch.no_grad()
-    def get_offset_match(self, flow_list, data,mask1,mask2):
+    def get_offset_match(self, flow_list, data, mask1, mask2):
         """
         Args:
             offset (torch.Tensor): [L, B, H, W, 2]
@@ -280,52 +306,62 @@ class CoarseMatching(nn.Module):
                 'mkpts1_c' (torch.Tensor): [M, 2],
                 'mconf' (torch.Tensor): [M]}
         """
-        offset1=flow_list[0]
-        bs,layer_num=offset1.shape[1],offset1.shape[0]
-        
-        #left side
-        offset1=offset1.view(layer_num,bs,-1,4)
-        conf1=offset1[:,:,:,2:].mean(dim=-1)
+        offset1 = flow_list[0]
+        bs, layer_num = offset1.shape[1], offset1.shape[0]
+
+        # left side
+        offset1 = offset1.view(layer_num, bs, -1, 4)
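+        # the last two of the four flow channels are read as per-location uncertainty and
+        # averaged into a confidence score (lower = better, cf. the `conf < 2` threshold
+        # in get_offset_match_work)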
+        conf1 = offset1[:, :, :, 2:].mean(dim=-1)
         if mask1 is not None:
-            conf1.masked_fill_(~mask1.bool()[None].expand(layer_num,-1,-1),100)
-        offset1=offset1[:,:,:,:2]
-        self.get_offset_match_work(offset1,conf1,data,'left')
-
-        #rihgt side
-        if len(flow_list)==2:
-            offset2=flow_list[1].view(layer_num,bs,-1,4)
-            conf2=offset2[:,:,:,2:].mean(dim=-1)
+            conf1.masked_fill_(~mask1.bool()[None].expand(layer_num, -1, -1), 100)
+        offset1 = offset1[:, :, :, :2]
+        self.get_offset_match_work(offset1, conf1, data, "left")
+
+        # right side
+        if len(flow_list) == 2:
+            offset2 = flow_list[1].view(layer_num, bs, -1, 4)
+            conf2 = offset2[:, :, :, 2:].mean(dim=-1)
             if mask2 is not None:
-                conf2.masked_fill_(~mask2.bool()[None].expand(layer_num,-1,-1),100)
-            offset2=offset2[:,:,:,:2]
-            self.get_offset_match_work(offset2,conf2,data,'right')
-
+                conf2.masked_fill_(~mask2.bool()[None].expand(layer_num, -1, -1), 100)
+            offset2 = offset2[:, :, :, :2]
+            self.get_offset_match_work(offset2, conf2, data, "right")
 
     @torch.no_grad()
-    def get_offset_match_work(self, offset,conf, data,side):
-        bs,layer_num=offset.shape[1],offset.shape[0]
+    def get_offset_match_work(self, offset, conf, data, side):
+        bs, layer_num = offset.shape[1], offset.shape[0]
         # 1. confidence thresholding
-        mask_conf= conf<2
+        mask_conf = conf < 2
         for index in range(bs):
-            mask_conf[:,index,0]=True #safe guard in case that no match survives
+            mask_conf[:, index, 0] = True  # safeguard in case no match survives
         # 3. find offset matches
-        scale = data['hw0_i'][0] / data['hw0_c'][0]
-        l_ids,b_ids,i_ids = torch.where(mask_conf)
-        j_coor=offset[l_ids,b_ids,i_ids,:2] *scale#[N,2]
-        i_coor=torch.stack([i_ids%data['hw0_c'][1],i_ids//data['hw0_c'][1]],dim=1)*scale
-        #i_coor=torch.as_tensor([[index%data['hw0_c'][1],index//data['hw0_c'][1]] for index in i_ids]).cuda().float()*scale #[N,2]
+        scale = data["hw0_i"][0] / data["hw0_c"][0]
+        l_ids, b_ids, i_ids = torch.where(mask_conf)
+        j_coor = offset[l_ids, b_ids, i_ids, :2] * scale  # [N,2]
+        i_coor = (
+            torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1)
+            * scale
+        )
+        # i_coor=torch.as_tensor([[index%data['hw0_c'][1],index//data['hw0_c'][1]] for index in i_ids]).cuda().float()*scale #[N,2]
         # These matches is the current prediction (for visualization)
-        data.update({
-            'offset_bids_'+side: b_ids,  # mconf == 0 => gt matches
-            'offset_lids_'+side: l_ids,
-            'conf'+side: conf[mask_conf]
-        })
-        
-        if side=='right':
-            data.update({'offset_kpts0_f_'+side: j_coor.detach(),
-            'offset_kpts1_f_'+side: i_coor})
+        data.update(
+            {
+                "offset_bids_" + side: b_ids,  # mconf == 0 => gt matches
+                "offset_lids_" + side: l_ids,
+                "conf" + side: conf[mask_conf],
+            }
+        )
+
+        if side == "right":
+            data.update(
+                {
+                    "offset_kpts0_f_" + side: j_coor.detach(),
+                    "offset_kpts1_f_" + side: i_coor,
+                }
+            )
         else:
-            data.update({'offset_kpts0_f_'+side: i_coor,
-            'offset_kpts1_f_'+side: j_coor.detach()})
-
-    
+            data.update(
+                {
+                    "offset_kpts0_f_" + side: i_coor,
+                    "offset_kpts1_f_" + side: j_coor.detach(),
+                }
+            )
diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py b/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py
index fdc57e84936c805cb387b6239ca4a5ff6154e22e..1ffe9c067b1fb95a75dd102c5947c82d03dbea89 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py
@@ -8,7 +8,7 @@ def lower_config(yacs_cfg):
 
 
 _CN = CN()
-_CN.BACKBONE_TYPE = 'ResNetFPN'
+_CN.BACKBONE_TYPE = "ResNetFPN"
 _CN.RESOLUTION = (8, 2)  # options: [(8, 2), (16, 4)]
 _CN.FINE_WINDOW_SIZE = 5  # window_size in fine_level, must be odd
 _CN.FINE_CONCAT_COARSE_FEAT = True
@@ -23,15 +23,15 @@ _CN.COARSE = CN()
 _CN.COARSE.D_MODEL = 256
 _CN.COARSE.D_FFN = 256
 _CN.COARSE.NHEAD = 8
-_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
-_CN.COARSE.ATTENTION = 'linear'  # options: ['linear', 'full']
+_CN.COARSE.LAYER_NAMES = ["self", "cross"] * 4
+_CN.COARSE.ATTENTION = "linear"  # options: ['linear', 'full']
 _CN.COARSE.TEMP_BUG_FIX = False
 
 # 3. Coarse-Matching config
 _CN.MATCH_COARSE = CN()
 _CN.MATCH_COARSE.THR = 0.1
 _CN.MATCH_COARSE.BORDER_RM = 2
-_CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'  # options: ['dual_softmax, 'sinkhorn']
+_CN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"  # options: ['dual_softmax', 'sinkhorn']
 _CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
 _CN.MATCH_COARSE.SKH_ITERS = 3
 _CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
@@ -44,7 +44,7 @@ _CN.FINE = CN()
 _CN.FINE.D_MODEL = 128
 _CN.FINE.D_FFN = 128
 _CN.FINE.NHEAD = 8
-_CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
-_CN.FINE.ATTENTION = 'linear'
+_CN.FINE.LAYER_NAMES = ["self", "cross"] * 1
+_CN.FINE.ATTENTION = "linear"
 
 default_cfg = lower_config(_CN)
diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py b/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py
index 6e77aded52e1eb5c01e22c2738104f3b09d6922a..3f41b1db96016efb58888381284f86d448839ff0 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py
@@ -26,35 +26,46 @@ class FineMatching(nn.Module):
         """
         M, WW, C = feat_f0.shape
         W = int(math.sqrt(WW))
-        scale = data['hw0_i'][0] / data['hw0_f'][0]
+        scale = data["hw0_i"][0] / data["hw0_f"][0]
         self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
 
         # corner case: if no coarse matches found
         if M == 0:
-            assert self.training == False, "M is always >0, when training, see coarse_matching.py"
+            assert (
+                self.training == False
+            ), "M is always >0, when training, see coarse_matching.py"
             # logger.warning('No matches found in coarse-level.')
-            data.update({
-                'expec_f': torch.empty(0, 3, device=feat_f0.device),
-                'mkpts0_f': data['mkpts0_c'],
-                'mkpts1_f': data['mkpts1_c'],
-            })
+            data.update(
+                {
+                    "expec_f": torch.empty(0, 3, device=feat_f0.device),
+                    "mkpts0_f": data["mkpts0_c"],
+                    "mkpts1_f": data["mkpts1_c"],
+                }
+            )
             return
 
-        feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :]
-        sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
-        softmax_temp = 1. / C**.5
+        feat_f0_picked = feat_f0[:, WW // 2, :]
+        sim_matrix = torch.einsum("mc,mrc->mr", feat_f0_picked, feat_f1)
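+        # scaled dot-product temperature: 1/sqrt(C) keeps the softmax logits in a stable range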
+        softmax_temp = 1.0 / C**0.5
         heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
 
         # compute coordinates from heatmap
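+        # spatial_expectation2d is a differentiable soft-argmax (spatial expectation) over
+        # each W x W heatmap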
         coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]  # [M, 2]
-        grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2)  # [1, WW, 2]
+        grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(
+            1, -1, 2
+        )  # [1, WW, 2]
 
         # compute std over <x, y>
-        var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2  # [M, 2]
-        std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1)  # [M]  clamp needed for numerical stability
-        
+        var = (
+            torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1)
+            - coords_normalized**2
+        )  # [M, 2]
+        std = torch.sum(
+            torch.sqrt(torch.clamp(var, min=1e-10)), -1
+        )  # [M]  clamp needed for numerical stability
+
         # for fine-level supervision
-        data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
+        data.update({"expec_f": torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
 
         # compute absolute kpt coords
         self.get_fine_match(coords_normalized, data)
@@ -64,11 +75,10 @@ class FineMatching(nn.Module):
         W, WW, C, scale = self.W, self.WW, self.C, self.scale
 
         # mkpts0_f and mkpts1_f
-        mkpts0_f = data['mkpts0_c']
-        scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
-        mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
+        mkpts0_f = data["mkpts0_c"]
+        scale1 = scale * data["scale1"][data["b_ids"]] if "scale0" in data else scale
+        mkpts1_f = (
+            data["mkpts1_c"] + (coords_normed * (W // 2) * scale1)[: len(data["mconf"])]
+        )
 
-        data.update({
-            "mkpts0_f": mkpts0_f,
-            "mkpts1_f": mkpts1_f
-        })
+        data.update({"mkpts0_f": mkpts0_f, "mkpts1_f": mkpts1_f})
diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py b/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py
index f95cdb65b48324c4f4ceb20231b1bed992b41116..6101f738f2b2b7ee014fcb53a4032391939ed8cd 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py
@@ -3,10 +3,10 @@ import torch
 
 @torch.no_grad()
 def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
-    """ Warp kpts0 from I0 to I1 with depth, K and Rt
+    """Warp kpts0 from I0 to I1 with depth, K and Rt
     Also check covisibility and depth consistency.
     Depth is consistent if relative error < 0.2 (hard-coded).
-    
+
     Args:
         kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
         depth0 (torch.Tensor): [N, H, W],
@@ -22,33 +22,52 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
 
     # Sample depth, get calculable_mask on depth != 0
     kpts0_depth = torch.stack(
-        [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
+        [
+            depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]]
+            for i in range(kpts0.shape[0])
+        ],
+        dim=0,
     )  # (N, L)
     nonzero_mask = kpts0_depth != 0
 
     # Unproject
-    kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None]  # (N, L, 3)
+    kpts0_h = (
+        torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
+        * kpts0_depth[..., None]
+    )  # (N, L, 3)
     kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)
 
     # Rigid Transform
-    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]    # (N, 3, L)
+    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
     w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
 
     # Project
     w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
-    w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4)  # (N, L, 2), +1e-4 to avoid zero depth
+    w_kpts0 = w_kpts0_h[:, :, :2] / (
+        w_kpts0_h[:, :, [2]] + 1e-4
+    )  # (N, L, 2), +1e-4 to avoid zero depth
 
     # Covisible Check
     h, w = depth1.shape[1:3]
-    covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
-        (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
+    covisible_mask = (
+        (w_kpts0[:, :, 0] > 0)
+        * (w_kpts0[:, :, 0] < w - 1)
+        * (w_kpts0[:, :, 1] > 0)
+        * (w_kpts0[:, :, 1] < h - 1)
+    )
     w_kpts0_long = w_kpts0.long()
     w_kpts0_long[~covisible_mask, :] = 0
 
     w_kpts0_depth = torch.stack(
-        [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
+        [
+            depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]]
+            for i in range(w_kpts0_long.shape[0])
+        ],
+        dim=0,
     )  # (N, L)
-    consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
+    consistent_mask = (
+        (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
+    ).abs() < 0.2
     valid_mask = nonzero_mask * covisible_mask * consistent_mask
 
     return valid_mask, w_kpts0
diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py b/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py
index 07d384ae18370acb99ef00a788f628c967249ace..1da77ecef628e3e263b56fb501b6a6313f05c060 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py
@@ -8,7 +8,7 @@ class PositionEncodingSine(nn.Module):
     This is a sinusoidal position encoding that generalized to 2-dimensional images
     """
 
-    def __init__(self, d_model, max_shape=(256, 256),pre_scaling=None):
+    def __init__(self, d_model, max_shape=(256, 256), pre_scaling=None):
         """
         Args:
             max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
@@ -18,44 +18,63 @@ class PositionEncodingSine(nn.Module):
                 We will remove the buggy impl after re-training all variants of our released models.
         """
         super().__init__()
-        self.d_model=d_model
-        self.max_shape=max_shape
-        self.pre_scaling=pre_scaling
+        self.d_model = d_model
+        self.max_shape = max_shape
+        self.pre_scaling = pre_scaling
 
         pe = torch.zeros((d_model, *max_shape))
         y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
         x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
 
         if pre_scaling[0] is not None and pre_scaling[1] is not None:
-            train_res,test_res=pre_scaling[0],pre_scaling[1]
-            x_position,y_position=x_position*train_res[1]/test_res[1],y_position*train_res[0]/test_res[0]
+            train_res, test_res = pre_scaling[0], pre_scaling[1]
+            x_position, y_position = (
+                x_position * train_res[1] / test_res[1],
+                y_position * train_res[0] / test_res[0],
+            )
 
-        div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
+        div_term = torch.exp(
+            torch.arange(0, d_model // 2, 2).float()
+            * (-math.log(10000.0) / (d_model // 2))
+        )
         div_term = div_term[:, None, None]  # [C//4, 1, 1]
         pe[0::4, :, :] = torch.sin(x_position * div_term)
         pe[1::4, :, :] = torch.cos(x_position * div_term)
         pe[2::4, :, :] = torch.sin(y_position * div_term)
         pe[3::4, :, :] = torch.cos(y_position * div_term)
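+        # channels are interleaved per frequency as [sin(x), cos(x), sin(y), cos(y)]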
 
-        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # [1, C, H, W]
+        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # [1, C, H, W]
 
-    def forward(self, x,scaling=None):
+    def forward(self, x, scaling=None):
         """
         Args:
             x: [N, C, H, W]
         """
-        if scaling is None: #onliner scaling overwrites pre_scaling
-            return x + self.pe[:, :, :x.size(2), :x.size(3)],self.pe[:, :, :x.size(2), :x.size(3)]
+        if scaling is None:  # online scaling overwrites pre_scaling
+            return (
+                x + self.pe[:, :, : x.size(2), : x.size(3)],
+                self.pe[:, :, : x.size(2), : x.size(3)],
+            )
         else:
             pe = torch.zeros((self.d_model, *self.max_shape))
-            y_position = torch.ones(self.max_shape).cumsum(0).float().unsqueeze(0)*scaling[0]
-            x_position = torch.ones(self.max_shape).cumsum(1).float().unsqueeze(0)*scaling[1]
-            
-            div_term = torch.exp(torch.arange(0, self.d_model//2, 2).float() * (-math.log(10000.0) / (self.d_model//2)))
+            y_position = (
+                torch.ones(self.max_shape).cumsum(0).float().unsqueeze(0) * scaling[0]
+            )
+            x_position = (
+                torch.ones(self.max_shape).cumsum(1).float().unsqueeze(0) * scaling[1]
+            )
+
+            div_term = torch.exp(
+                torch.arange(0, self.d_model // 2, 2).float()
+                * (-math.log(10000.0) / (self.d_model // 2))
+            )
             div_term = div_term[:, None, None]  # [C//4, 1, 1]
             pe[0::4, :, :] = torch.sin(x_position * div_term)
             pe[1::4, :, :] = torch.cos(x_position * div_term)
             pe[2::4, :, :] = torch.sin(y_position * div_term)
             pe[3::4, :, :] = torch.cos(y_position * div_term)
-            pe=pe.unsqueeze(0).to(x.device)
-            return x + pe[:, :, :x.size(2), :x.size(3)],pe[:, :, :x.size(2), :x.size(3)]
\ No newline at end of file
+            pe = pe.unsqueeze(0).to(x.device)
+            return (
+                x + pe[:, :, : x.size(2), : x.size(3)],
+                pe[:, :, : x.size(2), : x.size(3)],
+            )
diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py b/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py
index 5cef3a7968413136f6dc9f52b6a1ec87192b006b..16c468d8ee1425be0d4518477263f377bd09873a 100644
--- a/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py
+++ b/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py
@@ -13,7 +13,7 @@ from .geometry import warp_kpts
 @torch.no_grad()
 def mask_pts_at_padded_regions(grid_pt, mask):
     """For megadepth dataset, zero-padding exists in images"""
-    mask = repeat(mask, 'n h w -> n (h w) c', c=2)
+    mask = repeat(mask, "n h w -> n (h w) c", c=2)
     grid_pt[~mask.bool()] = 0
     return grid_pt
 
@@ -30,37 +30,55 @@ def spvs_coarse(data, config):
             'spv_w_pt0_i': [N, hw0, 2], in original image resolution
             'spv_pt1_i': [N, hw1, 2], in original image resolution
         }
-        
+
     NOTE:
         - for scannet dataset, there're 3 kinds of resolution {i, c, f}
         - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
     """
     # 1. misc
-    device = data['image0'].device
-    N, _, H0, W0 = data['image0'].shape
-    _, _, H1, W1 = data['image1'].shape
-    scale = config['ASPAN']['RESOLUTION'][0]
-    scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
-    scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
+    device = data["image0"].device
+    N, _, H0, W0 = data["image0"].shape
+    _, _, H1, W1 = data["image1"].shape
+    scale = config["ASPAN"]["RESOLUTION"][0]
+    scale0 = scale * data["scale0"][:, None] if "scale0" in data else scale
+    scale1 = scale * data["scale1"][:, None] if "scale0" in data else scale
     h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
 
     # 2. warp grids
     # create kpts in meshgrid and resize them to image resolution
-    grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1)    # [N, hw, 2]
+    grid_pt0_c = (
+        create_meshgrid(h0, w0, False, device).reshape(1, h0 * w0, 2).repeat(N, 1, 1)
+    )  # [N, hw, 2]
     grid_pt0_i = scale0 * grid_pt0_c
-    grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
+    grid_pt1_c = (
+        create_meshgrid(h1, w1, False, device).reshape(1, h1 * w1, 2).repeat(N, 1, 1)
+    )
     grid_pt1_i = scale1 * grid_pt1_c
 
     # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
-    if 'mask0' in data:
-        grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
-        grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
+    if "mask0" in data:
+        grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data["mask0"])
+        grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data["mask1"])
 
     # warp kpts bi-directionally and resize them to coarse-level resolution
     # (no depth consistency check, since it leads to worse results experimentally)
     # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
-    _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
-    _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
+    _, w_pt0_i = warp_kpts(
+        grid_pt0_i,
+        data["depth0"],
+        data["depth1"],
+        data["T_0to1"],
+        data["K0"],
+        data["K1"],
+    )
+    _, w_pt1_i = warp_kpts(
+        grid_pt1_i,
+        data["depth1"],
+        data["depth0"],
+        data["T_1to0"],
+        data["K1"],
+        data["K0"],
+    )
     w_pt0_c = w_pt0_i / scale1
     w_pt1_c = w_pt1_i / scale0
 
@@ -72,21 +90,26 @@ def spvs_coarse(data, config):
 
     # corner case: out of boundary
     def out_bound_mask(pt, w, h):
-        return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+        return (
+            (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+        )
+
     nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
     nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
 
-    loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
-    correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
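+    # cycle-consistency: a coarse cell in image0 is kept as a gt match only if warping
+    # 0 -> 1 -> 0 lands back on the same cell index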
+    loop_back = torch.stack(
+        [nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0
+    )
+    correct_0to1 = loop_back == torch.arange(h0 * w0, device=device)[None].repeat(N, 1)
     correct_0to1[:, 0] = False  # ignore the top-left corner
 
     # 4. construct a gt conf_matrix
-    conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
+    conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1, device=device)
     b_ids, i_ids = torch.where(correct_0to1 != 0)
     j_ids = nearest_index1[b_ids, i_ids]
 
     conf_matrix_gt[b_ids, i_ids, j_ids] = 1
-    data.update({'conf_matrix_gt': conf_matrix_gt})
+    data.update({"conf_matrix_gt": conf_matrix_gt})
 
     # 5. save coarse matches(gt) for training fine level
     if len(b_ids) == 0:
@@ -96,30 +119,26 @@ def spvs_coarse(data, config):
         i_ids = torch.tensor([0], device=device)
         j_ids = torch.tensor([0], device=device)
 
-    data.update({
-        'spv_b_ids': b_ids,
-        'spv_i_ids': i_ids,
-        'spv_j_ids': j_ids
-    })
+    data.update({"spv_b_ids": b_ids, "spv_i_ids": i_ids, "spv_j_ids": j_ids})
 
     # 6. save intermediate results (for fast fine-level computation)
-    data.update({
-        'spv_w_pt0_i': w_pt0_i,
-        'spv_pt1_i': grid_pt1_i
-    })
+    data.update({"spv_w_pt0_i": w_pt0_i, "spv_pt1_i": grid_pt1_i})
 
 
 def compute_supervision_coarse(data, config):
-    assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
-    data_source = data['dataset_name'][0]
-    if data_source.lower() in ['scannet', 'megadepth']:
+    assert (
+        len(set(data["dataset_name"])) == 1
+    ), "Do not support mixed datasets training!"
+    data_source = data["dataset_name"][0]
+    if data_source.lower() in ["scannet", "megadepth"]:
         spvs_coarse(data, config)
     else:
-        raise ValueError(f'Unknown data source: {data_source}')
+        raise ValueError(f"Unknown data source: {data_source}")
 
 
 ##############  ↓  Fine-Level supervision  ↓  ##############
 
+
 @torch.no_grad()
 def spvs_fine(data, config):
     """
@@ -129,23 +148,25 @@ def spvs_fine(data, config):
     """
     # 1. misc
     # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
-    w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
-    scale = config['ASPAN']['RESOLUTION'][1]
-    radius = config['ASPAN']['FINE_WINDOW_SIZE'] // 2
+    w_pt0_i, pt1_i = data["spv_w_pt0_i"], data["spv_pt1_i"]
+    scale = config["ASPAN"]["RESOLUTION"][1]
+    radius = config["ASPAN"]["FINE_WINDOW_SIZE"] // 2
 
     # 2. get coarse prediction
-    b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
+    b_ids, i_ids, j_ids = data["b_ids"], data["i_ids"], data["j_ids"]
 
     # 3. compute gt
-    scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
+    scale = scale * data["scale1"][b_ids] if "scale0" in data else scale
     # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
-    expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius  # [M, 2]
+    expec_f_gt = (
+        (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius
+    )  # [M, 2]
     data.update({"expec_f_gt": expec_f_gt})
 
 
 def compute_supervision_fine(data, config):
-    data_source = data['dataset_name'][0]
-    if data_source.lower() in ['scannet', 'megadepth']:
+    data_source = data["dataset_name"][0]
+    if data_source.lower() in ["scannet", "megadepth"]:
         spvs_fine(data, config)
     else:
         raise NotImplementedError
diff --git a/third_party/ASpanFormer/src/config/default.py b/third_party/ASpanFormer/src/config/default.py
index 40abd51c3f28ea6dee3c4e9fcee6efac5c080a2f..2850199cfb4d403fe4ec7aa5d61a7de524e4183c 100644
--- a/third_party/ASpanFormer/src/config/default.py
+++ b/third_party/ASpanFormer/src/config/default.py
@@ -1,9 +1,10 @@
 from yacs.config import CfgNode as CN
+
 _CN = CN()
 
 ##############  ↓  ASPAN Pipeline  ↓  ##############
 _CN.ASPAN = CN()
-_CN.ASPAN.BACKBONE_TYPE = 'ResNetFPN'
+_CN.ASPAN.BACKBONE_TYPE = "ResNetFPN"
 _CN.ASPAN.RESOLUTION = (8, 2)  # options: [(8, 2), (16, 4)]
 _CN.ASPAN.FINE_WINDOW_SIZE = 5  # window_size in fine_level, must be odd
 _CN.ASPAN.FINE_CONCAT_COARSE_FEAT = True
@@ -17,14 +18,14 @@ _CN.ASPAN.RESNETFPN.BLOCK_DIMS = [128, 196, 256]  # s1, s2, s3
 _CN.ASPAN.COARSE = CN()
 _CN.ASPAN.COARSE.D_MODEL = 256
 _CN.ASPAN.COARSE.D_FFN = 256
-_CN.ASPAN.COARSE.D_FLOW= 128
+_CN.ASPAN.COARSE.D_FLOW = 128
 _CN.ASPAN.COARSE.NHEAD = 8
-_CN.ASPAN.COARSE.NLEVEL= 3
-_CN.ASPAN.COARSE.INI_LAYER_NUM =  2
-_CN.ASPAN.COARSE.LAYER_NUM =  4
-_CN.ASPAN.COARSE.NSAMPLE = [2,8]
-_CN.ASPAN.COARSE.RADIUS_SCALE= 5
-_CN.ASPAN.COARSE.COARSEST_LEVEL= [26,26]
+_CN.ASPAN.COARSE.NLEVEL = 3
+_CN.ASPAN.COARSE.INI_LAYER_NUM = 2
+_CN.ASPAN.COARSE.LAYER_NUM = 4
+_CN.ASPAN.COARSE.NSAMPLE = [2, 8]
+_CN.ASPAN.COARSE.RADIUS_SCALE = 5
+_CN.ASPAN.COARSE.COARSEST_LEVEL = [26, 26]
 _CN.ASPAN.COARSE.TRAIN_RES = None
 _CN.ASPAN.COARSE.TEST_RES = None
 
@@ -32,7 +33,9 @@ _CN.ASPAN.COARSE.TEST_RES = None
 _CN.ASPAN.MATCH_COARSE = CN()
 _CN.ASPAN.MATCH_COARSE.THR = 0.2
 _CN.ASPAN.MATCH_COARSE.BORDER_RM = 2
-_CN.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'  # options: ['dual_softmax, 'sinkhorn']
+_CN.ASPAN.MATCH_COARSE.MATCH_TYPE = (
+    "dual_softmax"  # options: ['dual_softmax, 'sinkhorn']
+)
 _CN.ASPAN.MATCH_COARSE.SKH_ITERS = 3
 _CN.ASPAN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
 _CN.ASPAN.MATCH_COARSE.SKH_PREFILTER = False
@@ -46,13 +49,13 @@ _CN.ASPAN.FINE = CN()
 _CN.ASPAN.FINE.D_MODEL = 128
 _CN.ASPAN.FINE.D_FFN = 128
 _CN.ASPAN.FINE.NHEAD = 8
-_CN.ASPAN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
-_CN.ASPAN.FINE.ATTENTION = 'linear'
+_CN.ASPAN.FINE.LAYER_NAMES = ["self", "cross"] * 1
+_CN.ASPAN.FINE.ATTENTION = "linear"
 
 # 5. ASPAN Losses
 # -- # coarse-level
 _CN.ASPAN.LOSS = CN()
-_CN.ASPAN.LOSS.COARSE_TYPE = 'focal'  # ['focal', 'cross_entropy']
+_CN.ASPAN.LOSS.COARSE_TYPE = "focal"  # ['focal', 'cross_entropy']
 _CN.ASPAN.LOSS.COARSE_WEIGHT = 1.0
 # _CN.ASPAN.LOSS.SPARSE_SPVS = False
 # -- - -- # focal loss (coarse)
@@ -64,7 +67,7 @@ _CN.ASPAN.LOSS.NEG_WEIGHT = 1.0
 # use `_CN.ASPAN.MATCH_COARSE.MATCH_TYPE`
 
 # -- # fine-level
-_CN.ASPAN.LOSS.FINE_TYPE = 'l2_with_std'  # ['l2_with_std', 'l2']
+_CN.ASPAN.LOSS.FINE_TYPE = "l2_with_std"  # ['l2_with_std', 'l2']
 _CN.ASPAN.LOSS.FINE_WEIGHT = 1.0
 _CN.ASPAN.LOSS.FINE_CORRECT_THR = 1.0  # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
 
@@ -85,24 +88,32 @@ _CN.DATASET.TRAIN_INTRINSIC_PATH = None
 _CN.DATASET.VAL_DATA_ROOT = None
 _CN.DATASET.VAL_POSE_ROOT = None  # (optional directory for poses)
 _CN.DATASET.VAL_NPZ_ROOT = None
-_CN.DATASET.VAL_LIST_PATH = None    # None if val data from all scenes are bundled into a single npz file
+_CN.DATASET.VAL_LIST_PATH = (
+    None  # None if val data from all scenes are bundled into a single npz file
+)
 _CN.DATASET.VAL_INTRINSIC_PATH = None
 # testing
 _CN.DATASET.TEST_DATA_SOURCE = None
 _CN.DATASET.TEST_DATA_ROOT = None
 _CN.DATASET.TEST_POSE_ROOT = None  # (optional directory for poses)
 _CN.DATASET.TEST_NPZ_ROOT = None
-_CN.DATASET.TEST_LIST_PATH = None   # None if test data from all scenes are bundled into a single npz file
+_CN.DATASET.TEST_LIST_PATH = (
+    None  # None if test data from all scenes are bundled into a single npz file
+)
 _CN.DATASET.TEST_INTRINSIC_PATH = None
 
 # 2. dataset config
 # general options
-_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4  # discard data with overlap_score < min_overlap_score
+_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = (
+    0.4  # discard data with overlap_score < min_overlap_score
+)
 _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
 _CN.DATASET.AUGMENTATION_TYPE = None  # options: [None, 'dark', 'mobile']
 
 # MegaDepth options
-_CN.DATASET.MGDPT_IMG_RESIZE = 640  # resize the longer side, zero-pad bottom-right to square.
+_CN.DATASET.MGDPT_IMG_RESIZE = (
+    640  # resize the longer side, zero-pad bottom-right to square.
+)
 _CN.DATASET.MGDPT_IMG_PAD = True  # pad img to square with size = MGDPT_IMG_RESIZE
 _CN.DATASET.MGDPT_DEPTH_PAD = True  # pad depthmap to square with size = 2000
 _CN.DATASET.MGDPT_DF = 8
@@ -118,17 +129,17 @@ _CN.TRAINER.FIND_LR = False  # use learning rate finder from pytorch-lightning
 # optimizer
 _CN.TRAINER.OPTIMIZER = "adamw"  # [adam, adamw]
 _CN.TRAINER.TRUE_LR = None  # this will be calculated automatically at runtime
-_CN.TRAINER.ADAM_DECAY = 0.  # ADAM: for adam
+_CN.TRAINER.ADAM_DECAY = 0.0  # ADAM: for adam
 _CN.TRAINER.ADAMW_DECAY = 0.1
 
 # step-based warm-up
-_CN.TRAINER.WARMUP_TYPE = 'linear'  # [linear, constant]
-_CN.TRAINER.WARMUP_RATIO = 0.
+_CN.TRAINER.WARMUP_TYPE = "linear"  # [linear, constant]
+_CN.TRAINER.WARMUP_RATIO = 0.0
 _CN.TRAINER.WARMUP_STEP = 4800
 
 # learning rate scheduler
-_CN.TRAINER.SCHEDULER = 'MultiStepLR'  # [MultiStepLR, CosineAnnealing, ExponentialLR]
-_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch'    # [epoch, step]
+_CN.TRAINER.SCHEDULER = "MultiStepLR"  # [MultiStepLR, CosineAnnealing, ExponentialLR]
+_CN.TRAINER.SCHEDULER_INTERVAL = "epoch"  # [epoch, step]
 _CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12]  # MSLR: MultiStepLR
 _CN.TRAINER.MSLR_GAMMA = 0.5
 _CN.TRAINER.COSA_TMAX = 30  # COSA: CosineAnnealing
@@ -136,25 +147,33 @@ _CN.TRAINER.ELR_GAMMA = 0.999992  # ELR: ExponentialLR, this value for 'step' in
 
 # plotting related
 _CN.TRAINER.ENABLE_PLOTTING = True
-_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32     # number of val/test paris for plotting
-_CN.TRAINER.PLOT_MODE = 'evaluation'  # ['evaluation', 'confidence']
-_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
+_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32  # number of val/test pairs for plotting
+_CN.TRAINER.PLOT_MODE = "evaluation"  # ['evaluation', 'confidence']
+_CN.TRAINER.PLOT_MATCHES_ALPHA = "dynamic"
 
 # geometric metrics and pose solver
-_CN.TRAINER.EPI_ERR_THR = 5e-4  # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
-_CN.TRAINER.POSE_GEO_MODEL = 'E'  # ['E', 'F', 'H']
-_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC'  # [RANSAC, DEGENSAC, MAGSAC]
+_CN.TRAINER.EPI_ERR_THR = (
+    5e-4  # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
+)
+_CN.TRAINER.POSE_GEO_MODEL = "E"  # ['E', 'F', 'H']
+_CN.TRAINER.POSE_ESTIMATION_METHOD = "RANSAC"  # [RANSAC, DEGENSAC, MAGSAC]
 _CN.TRAINER.RANSAC_PIXEL_THR = 0.5
 _CN.TRAINER.RANSAC_CONF = 0.99999
 _CN.TRAINER.RANSAC_MAX_ITERS = 10000
 _CN.TRAINER.USE_MAGSACPP = False
 
 # data sampler for train_dataloader
-_CN.TRAINER.DATA_SAMPLER = 'scene_balance'  # options: ['scene_balance', 'random', 'normal']
+_CN.TRAINER.DATA_SAMPLER = (
+    "scene_balance"  # options: ['scene_balance', 'random', 'normal']
+)
 # 'scene_balance' config
 _CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
-_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True  # whether sample each scene with replacement or not
-_CN.TRAINER.SB_SUBSET_SHUFFLE = True  # after sampling from scenes, whether shuffle within the epoch or not
+_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = (
+    True  # whether sample each scene with replacement or not
+)
+_CN.TRAINER.SB_SUBSET_SHUFFLE = (
+    True  # after sampling from scenes, whether shuffle within the epoch or not
+)
 _CN.TRAINER.SB_REPEAT = 1  # repeat N times for training the sampled data
 # 'random' config
 _CN.TRAINER.RDM_REPLACEMENT = True
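The config tree above is a standard yacs `CfgNode`. As a hedged usage sketch (the `get_cfg_defaults_sketch` helper below is a stand-in, not the repo's actual accessor, and the tree is reduced to a single knob), overrides are merged on top of the defaults and the result is frozen for the run:

from yacs.config import CfgNode as CN

def get_cfg_defaults_sketch() -> CN:
    # stand-in for the usual `_CN.clone()` accessor; only one knob shown
    cfg = CN()
    cfg.ASPAN = CN()
    cfg.ASPAN.MATCH_COARSE = CN()
    cfg.ASPAN.MATCH_COARSE.THR = 0.2
    return cfg

cfg = get_cfg_defaults_sketch()
cfg.merge_from_list(["ASPAN.MATCH_COARSE.THR", 0.3])  # command-line style override
cfg.freeze()                                          # make the config immutable
print(cfg.ASPAN.MATCH_COARSE.THR)                     # 0.3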
diff --git a/third_party/ASpanFormer/src/datasets/__init__.py b/third_party/ASpanFormer/src/datasets/__init__.py
index 1860e3ae060a26e4625925861cecdc355f2b08b7..4feb648440e6c8db60de3aa475cd82ce460dcc1c 100644
--- a/third_party/ASpanFormer/src/datasets/__init__.py
+++ b/third_party/ASpanFormer/src/datasets/__init__.py
@@ -1,3 +1,2 @@
 from .scannet import ScanNetDataset
 from .megadepth import MegaDepthDataset
-
diff --git a/third_party/ASpanFormer/src/datasets/megadepth.py b/third_party/ASpanFormer/src/datasets/megadepth.py
index a70ac715a3f807e37bc5b87ae9446ddd2aa4fc86..7cbf95962df705c14d11483838f13bfd5e036166 100644
--- a/third_party/ASpanFormer/src/datasets/megadepth.py
+++ b/third_party/ASpanFormer/src/datasets/megadepth.py
@@ -9,20 +9,22 @@ from src.utils.dataset import read_megadepth_gray, read_megadepth_depth
 
 
 class MegaDepthDataset(Dataset):
-    def __init__(self,
-                 root_dir,
-                 npz_path,
-                 mode='train',
-                 min_overlap_score=0.4,
-                 img_resize=None,
-                 df=None,
-                 img_padding=False,
-                 depth_padding=False,
-                 augment_fn=None,
-                 **kwargs):
+    def __init__(
+        self,
+        root_dir,
+        npz_path,
+        mode="train",
+        min_overlap_score=0.4,
+        img_resize=None,
+        df=None,
+        img_padding=False,
+        depth_padding=False,
+        augment_fn=None,
+        **kwargs
+    ):
         """
         Manage one scene(npz_path) of MegaDepth dataset.
-        
+
         Args:
             root_dir (str): megadepth root directory that has `phoenix`.
             npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
@@ -38,28 +40,36 @@ class MegaDepthDataset(Dataset):
         super().__init__()
         self.root_dir = root_dir
         self.mode = mode
-        self.scene_id = npz_path.split('.')[0]
+        self.scene_id = npz_path.split(".")[0]
 
         # prepare scene_info and pair_info
-        if mode == 'test' and min_overlap_score != 0:
-            logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.")
+        if mode == "test" and min_overlap_score != 0:
+            logger.warning(
+                "You are using `min_overlap_score`!=0 in test mode. Set to 0."
+            )
             min_overlap_score = 0
         self.scene_info = np.load(npz_path, allow_pickle=True)
-        self.pair_infos = self.scene_info['pair_infos'].copy()
-        del self.scene_info['pair_infos']
-        self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]
+        self.pair_infos = self.scene_info["pair_infos"].copy()
+        del self.scene_info["pair_infos"]
+        self.pair_infos = [
+            pair_info
+            for pair_info in self.pair_infos
+            if pair_info[1] > min_overlap_score
+        ]
 
         # parameters for image resizing, padding and depthmap padding
-        if mode == 'train':
+        if mode == "train":
             assert img_resize is not None and img_padding and depth_padding
         self.img_resize = img_resize
         self.df = df
         self.img_padding = img_padding
-        self.depth_max_size = 2000 if depth_padding else None  # the upperbound of depthmaps size in megadepth.
+        self.depth_max_size = (
+            2000 if depth_padding else None
+        )  # the upper bound of depth-map size in MegaDepth.
 
         # for training LoFTR
-        self.augment_fn = augment_fn if mode == 'train' else None
-        self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125)
+        self.augment_fn = augment_fn if mode == "train" else None
+        self.coarse_scale = kwargs.get("coarse_scale", 0.125)  # getattr on a dict always returns the default
 
     def __len__(self):
         return len(self.pair_infos)
@@ -68,60 +78,77 @@ class MegaDepthDataset(Dataset):
         (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
 
         # read grayscale image and mask. (1, h, w) and (h, w)
-        img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
-        img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
-        
+        img_name0 = osp.join(self.root_dir, self.scene_info["image_paths"][idx0])
+        img_name1 = osp.join(self.root_dir, self.scene_info["image_paths"][idx1])
+
         # TODO: Support augmentation & handle seeds for each worker correctly.
         image0, mask0, scale0 = read_megadepth_gray(
-            img_name0, self.img_resize, self.df, self.img_padding, None)
-            # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+            img_name0, self.img_resize, self.df, self.img_padding, None
+        )
+        # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
         image1, mask1, scale1 = read_megadepth_gray(
-            img_name1, self.img_resize, self.df, self.img_padding, None)
-            # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+            img_name1, self.img_resize, self.df, self.img_padding, None
+        )
+        # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
 
         # read depth. shape: (h, w)
-        if self.mode in ['train', 'val']:
+        if self.mode in ["train", "val"]:
             depth0 = read_megadepth_depth(
-                osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
+                osp.join(self.root_dir, self.scene_info["depth_paths"][idx0]),
+                pad_to=self.depth_max_size,
+            )
             depth1 = read_megadepth_depth(
-                osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
+                osp.join(self.root_dir, self.scene_info["depth_paths"][idx1]),
+                pad_to=self.depth_max_size,
+            )
         else:
             depth0 = depth1 = torch.tensor([])
 
         # read intrinsics of original size
-        K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
-        K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
+        K_0 = torch.tensor(
+            self.scene_info["intrinsics"][idx0].copy(), dtype=torch.float
+        ).reshape(3, 3)
+        K_1 = torch.tensor(
+            self.scene_info["intrinsics"][idx1].copy(), dtype=torch.float
+        ).reshape(3, 3)
 
         # read and compute relative poses
-        T0 = self.scene_info['poses'][idx0]
-        T1 = self.scene_info['poses'][idx1]
-        T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4]  # (4, 4)
+        T0 = self.scene_info["poses"][idx0]
+        T1 = self.scene_info["poses"][idx1]
+        T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[
+            :4, :4
+        ]  # (4, 4)
         T_1to0 = T_0to1.inverse()
 
         data = {
-            'image0': image0,  # (1, h, w)
-            'depth0': depth0,  # (h, w)
-            'image1': image1,
-            'depth1': depth1,
-            'T_0to1': T_0to1,  # (4, 4)
-            'T_1to0': T_1to0,
-            'K0': K_0,  # (3, 3)
-            'K1': K_1,
-            'scale0': scale0,  # [scale_w, scale_h]
-            'scale1': scale1,
-            'dataset_name': 'MegaDepth',
-            'scene_id': self.scene_id,
-            'pair_id': idx,
-            'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
+            "image0": image0,  # (1, h, w)
+            "depth0": depth0,  # (h, w)
+            "image1": image1,
+            "depth1": depth1,
+            "T_0to1": T_0to1,  # (4, 4)
+            "T_1to0": T_1to0,
+            "K0": K_0,  # (3, 3)
+            "K1": K_1,
+            "scale0": scale0,  # [scale_w, scale_h]
+            "scale1": scale1,
+            "dataset_name": "MegaDepth",
+            "scene_id": self.scene_id,
+            "pair_id": idx,
+            "pair_names": (
+                self.scene_info["image_paths"][idx0],
+                self.scene_info["image_paths"][idx1],
+            ),
         }
 
         # for LoFTR training
         if mask0 is not None:  # img_padding is True
             if self.coarse_scale:
-                [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
-                                                       scale_factor=self.coarse_scale,
-                                                       mode='nearest',
-                                                       recompute_scale_factor=False)[0].bool()
-            data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
+                [ts_mask_0, ts_mask_1] = F.interpolate(
+                    torch.stack([mask0, mask1], dim=0)[None].float(),
+                    scale_factor=self.coarse_scale,
+                    mode="nearest",
+                    recompute_scale_factor=False,
+                )[0].bool()
+            data.update({"mask0": ts_mask_0, "mask1": ts_mask_1})
 
         return data
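As a standalone sketch of the relative-pose arithmetic in `MegaDepthDataset.__getitem__` (toy poses below; the real ones come from `scene_info["poses"]`, assumed to be world-to-camera extrinsics): `T_0to1` maps camera-0 coordinates into camera 1 via `T1 @ inv(T0)`.

import numpy as np
import torch

T0 = np.eye(4)                # world -> camera 0 (identity for the toy example)
T1 = np.eye(4)                # world -> camera 1, translated 0.5 along x
T1[:3, 3] = [0.5, 0.0, 0.0]

T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4]
T_1to0 = T_0to1.inverse()

print(T_0to1[:3, 3])          # tensor([0.5000, 0.0000, 0.0000])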
diff --git a/third_party/ASpanFormer/src/datasets/sampler.py b/third_party/ASpanFormer/src/datasets/sampler.py
index 81b6f435645632a013476f9a665a0861ab7fcb61..131111c4cf69cd8770058dfac2be717aa183978e 100644
--- a/third_party/ASpanFormer/src/datasets/sampler.py
+++ b/third_party/ASpanFormer/src/datasets/sampler.py
@@ -3,10 +3,10 @@ from torch.utils.data import Sampler, ConcatDataset
 
 
 class RandomConcatSampler(Sampler):
-    """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
+    """Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
     in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
     However, it is impossible to sample data without replacement between epochs, unless building a stateful sampler that lives through the entire training phase.
-    
+
     In the current implementation, the randomness of sampling is ensured whether or not the sampler is recreated across epochs and whether or not `torch.manual_seed()` is called.
     Args:
         shuffle (bool): shuffle the randomly sampled indices across all sub-datasets.
@@ -18,16 +18,19 @@ class RandomConcatSampler(Sampler):
     TODO: Add a `set_epoch()` method to fulfill sampling without replacement across epochs.
           ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
     """
-    def __init__(self,
-                 data_source: ConcatDataset,
-                 n_samples_per_subset: int,
-                 subset_replacement: bool=True,
-                 shuffle: bool=True,
-                 repeat: int=1,
-                 seed: int=None):
+
+    def __init__(
+        self,
+        data_source: ConcatDataset,
+        n_samples_per_subset: int,
+        subset_replacement: bool = True,
+        shuffle: bool = True,
+        repeat: int = 1,
+        seed: int = None,
+    ):
         if not isinstance(data_source, ConcatDataset):
             raise TypeError("data_source should be torch.utils.data.ConcatDataset")
-        
+
         self.data_source = data_source
         self.n_subset = len(self.data_source.datasets)
         self.n_samples_per_subset = n_samples_per_subset
@@ -37,27 +40,37 @@ class RandomConcatSampler(Sampler):
         self.shuffle = shuffle
         self.generator = torch.manual_seed(seed)
         assert self.repeat >= 1
-        
+
     def __len__(self):
         return self.n_samples
-    
+
     def __iter__(self):
         indices = []
         # sample from each sub-dataset
         for d_idx in range(self.n_subset):
-            low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
+            low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1]
             high = self.data_source.cumulative_sizes[d_idx]
             if self.subset_replacement:
-                rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
-                                            generator=self.generator, dtype=torch.int64)
+                rand_tensor = torch.randint(
+                    low,
+                    high,
+                    (self.n_samples_per_subset,),
+                    generator=self.generator,
+                    dtype=torch.int64,
+                )
             else:  # sample without replacement
                 len_subset = len(self.data_source.datasets[d_idx])
                 rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
                 if len_subset >= self.n_samples_per_subset:
-                    rand_tensor = rand_tensor[:self.n_samples_per_subset]
-                else: # padding with replacement
-                    rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
-                                                            generator=self.generator, dtype=torch.int64)
+                    rand_tensor = rand_tensor[: self.n_samples_per_subset]
+                else:  # padding with replacement
+                    rand_tensor_replacement = torch.randint(
+                        low,
+                        high,
+                        (self.n_samples_per_subset - len_subset,),
+                        generator=self.generator,
+                        dtype=torch.int64,
+                    )
                     rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
             indices.append(rand_tensor)
         indices = torch.cat(indices)
@@ -72,6 +85,6 @@ class RandomConcatSampler(Sampler):
                 _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
                 repeat_indices = map(_choice, repeat_indices)
             indices = torch.cat([indices, *repeat_indices], 0)
-        
+
         assert indices.shape[0] == self.n_samples
         return iter(indices.tolist())
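A usage sketch for the sampler above: the toy datasets are invented, the import path assumes the ASpanFormer source layout shown in this diff, and the expected length relies on `n_samples = n_subset * n_samples_per_subset * repeat` as in the upstream implementation.

from torch.utils.data import ConcatDataset, DataLoader, Dataset

from src.datasets.sampler import RandomConcatSampler  # path as in this repo

class _Range(Dataset):
    # trivial stand-in dataset that just returns its index
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        return i

concat = ConcatDataset([_Range(10), _Range(3)])
sampler = RandomConcatSampler(
    concat,
    n_samples_per_subset=4,   # 4 indices drawn from each of the two subsets
    subset_replacement=True,
    shuffle=True,
    repeat=1,
    seed=66,
)
loader = DataLoader(concat, sampler=sampler, batch_size=4)
print(len(sampler))           # expected: 8 (2 subsets x 4 samples x repeat=1)
print(next(iter(loader)))     # a batch of 4 sampled indices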
diff --git a/third_party/ASpanFormer/src/datasets/scannet.py b/third_party/ASpanFormer/src/datasets/scannet.py
index 3520d34c0f08a784ddbf923846a7cb2a847b1787..615e98409b92713ab241aa8658c74cf7b2f8baae 100644
--- a/third_party/ASpanFormer/src/datasets/scannet.py
+++ b/third_party/ASpanFormer/src/datasets/scannet.py
@@ -10,20 +10,22 @@ from src.utils.dataset import (
     read_scannet_gray,
     read_scannet_depth,
     read_scannet_pose,
-    read_scannet_intrinsic
+    read_scannet_intrinsic,
 )
 
 
 class ScanNetDataset(utils.data.Dataset):
-    def __init__(self,
-                 root_dir,
-                 npz_path,
-                 intrinsic_path,
-                 mode='train',
-                 min_overlap_score=0.4,
-                 augment_fn=None,
-                 pose_dir=None,
-                 **kwargs):
+    def __init__(
+        self,
+        root_dir,
+        npz_path,
+        intrinsic_path,
+        mode="train",
+        min_overlap_score=0.4,
+        augment_fn=None,
+        pose_dir=None,
+        **kwargs,
+    ):
         """Manage one scene of ScanNet Dataset.
         Args:
             root_dir (str): ScanNet root directory that contains scene folders.
@@ -41,73 +43,81 @@ class ScanNetDataset(utils.data.Dataset):
 
         # prepare data_names, intrinsics and extrinsics(T)
         with np.load(npz_path) as data:
-            self.data_names = data['name']
-            if 'score' in data.keys() and mode not in ['val' or 'test']:
-                kept_mask = data['score'] > min_overlap_score
+            self.data_names = data["name"]
+            if "score" in data.keys() and mode not in ["val" or "test"]:
+                kept_mask = data["score"] > min_overlap_score
                 self.data_names = self.data_names[kept_mask]
         self.intrinsics = dict(np.load(intrinsic_path))
 
         # for training LoFTR
-        self.augment_fn = augment_fn if mode == 'train' else None
+        self.augment_fn = augment_fn if mode == "train" else None
 
     def __len__(self):
         return len(self.data_names)
 
     def _read_abs_pose(self, scene_name, name):
-        pth = osp.join(self.pose_dir,
-                       scene_name,
-                       'pose', f'{name}.txt')
+        pth = osp.join(self.pose_dir, scene_name, "pose", f"{name}.txt")
         return read_scannet_pose(pth)
 
     def _compute_rel_pose(self, scene_name, name0, name1):
         pose0 = self._read_abs_pose(scene_name, name0)
         pose1 = self._read_abs_pose(scene_name, name1)
-        
+
         return np.matmul(pose1, inv(pose0))  # (4, 4)
 
     def __getitem__(self, idx):
         data_name = self.data_names[idx]
         scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
-        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+        scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}"
 
         # read the grayscale image which will be resized to (1, 480, 640)
-        img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
-        img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')
+        img_name0 = osp.join(self.root_dir, scene_name, "color", f"{stem_name_0}.jpg")
+        img_name1 = osp.join(self.root_dir, scene_name, "color", f"{stem_name_1}.jpg")
         # TODO: Support augmentation & handle seeds for each worker correctly.
         image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None)
-                                #    augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+        #    augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
         image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None)
-                                #    augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+        #    augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
 
         # read the depthmap which is stored as (480, 640)
-        if self.mode in ['train', 'val']:
-            depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'))
-            depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'))
+        if self.mode in ["train", "val"]:
+            depth0 = read_scannet_depth(
+                osp.join(self.root_dir, scene_name, "depth", f"{stem_name_0}.png")
+            )
+            depth1 = read_scannet_depth(
+                osp.join(self.root_dir, scene_name, "depth", f"{stem_name_1}.png")
+            )
         else:
             depth0 = depth1 = torch.tensor([])
 
         # read the intrinsic of depthmap
-        K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)
+        K_0 = K_1 = torch.tensor(
+            self.intrinsics[scene_name].copy(), dtype=torch.float
+        ).reshape(3, 3)
 
         # read and compute relative poses
-        T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
-                              dtype=torch.float32)
+        T_0to1 = torch.tensor(
+            self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
+            dtype=torch.float32,
+        )
         T_1to0 = T_0to1.inverse()
 
         data = {
-            'image0': image0,   # (1, h, w)
-            'depth0': depth0,   # (h, w)
-            'image1': image1,
-            'depth1': depth1,
-            'T_0to1': T_0to1,   # (4, 4)
-            'T_1to0': T_1to0,
-            'K0': K_0,  # (3, 3)
-            'K1': K_1,
-            'dataset_name': 'ScanNet',
-            'scene_id': scene_name,
-            'pair_id': idx,
-            'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
-                           osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
+            "image0": image0,  # (1, h, w)
+            "depth0": depth0,  # (h, w)
+            "image1": image1,
+            "depth1": depth1,
+            "T_0to1": T_0to1,  # (4, 4)
+            "T_1to0": T_1to0,
+            "K0": K_0,  # (3, 3)
+            "K1": K_1,
+            "dataset_name": "ScanNet",
+            "scene_id": scene_name,
+            "pair_id": idx,
+            "pair_names": (
+                osp.join(scene_name, "color", f"{stem_name_0}.jpg"),
+                osp.join(scene_name, "color", f"{stem_name_1}.jpg"),
+            ),
         }
 
         return data
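A toy illustration (invented frame indices; the real arrays come from the `.npz` files) of the pair filtering and naming conventions used above: pairs whose overlap score is below `min_overlap_score` are dropped, scene folders follow `scene{xxxx}_{yy}`, and images live under `color/{frame}.jpg`.

import numpy as np

data_names = np.array([[7, 0, 120, 180],   # (scene, sub-scene, frame0, frame1)
                       [7, 0, 300, 360]])
scores = np.array([0.55, 0.10])            # overlap score per pair

min_overlap_score = 0.4
kept = data_names[scores > min_overlap_score]   # the second pair is discarded

scene_name, scene_sub_name, stem_name_0, stem_name_1 = kept[0]
print(f"scene{scene_name:04d}_{scene_sub_name:02d}")            # scene0007_00
print(f"color/{stem_name_0}.jpg", f"color/{stem_name_1}.jpg")   # color/120.jpg color/180.jpg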
diff --git a/third_party/ASpanFormer/src/lightning/data.py b/third_party/ASpanFormer/src/lightning/data.py
index 73db514b8924d647814e6c5def919c23393d3ccf..9877df5980c73e9bfb5a1e6ec301e1a84a97ca56 100644
--- a/third_party/ASpanFormer/src/lightning/data.py
+++ b/third_party/ASpanFormer/src/lightning/data.py
@@ -16,7 +16,7 @@ from torch.utils.data import (
     ConcatDataset,
     DistributedSampler,
     RandomSampler,
-    dataloader
+    dataloader,
 )
 
 from src.utils.augment import build_augmentor
@@ -29,10 +29,11 @@ from src.datasets.sampler import RandomConcatSampler
 
 
 class MultiSceneDataModule(pl.LightningDataModule):
-    """ 
+    """
     For distributed training, each training process is assigned
     only a part of the training scenes to reduce memory overhead.
     """
+
     def __init__(self, args, config):
         super().__init__()
 
@@ -60,47 +61,51 @@ class MultiSceneDataModule(pl.LightningDataModule):
 
         # 2. dataset config
         # general options
-        self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST  # 0.4, omit data with overlap_score < min_overlap_score
+        self.min_overlap_score_test = (
+            config.DATASET.MIN_OVERLAP_SCORE_TEST
+        )  # 0.4, omit data with overlap_score < min_overlap_score
         self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
-        self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE)  # None, options: [None, 'dark', 'mobile']
+        self.augment_fn = build_augmentor(
+            config.DATASET.AUGMENTATION_TYPE
+        )  # None, options: [None, 'dark', 'mobile']
 
         # MegaDepth options
         self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE  # 840
-        self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD   # True
-        self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD   # True
+        self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD  # True
+        self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD  # True
         self.mgdpt_df = config.DATASET.MGDPT_DF  # 8
         self.coarse_scale = 1 / config.ASPAN.RESOLUTION[0]  # 0.125. for training loftr.
 
         # 3.loader parameters
         self.train_loader_params = {
-            'batch_size': args.batch_size,
-            'num_workers': args.num_workers,
-            'pin_memory': getattr(args, 'pin_memory', True)
+            "batch_size": args.batch_size,
+            "num_workers": args.num_workers,
+            "pin_memory": getattr(args, "pin_memory", True),
         }
         self.val_loader_params = {
-            'batch_size': 1,
-            'shuffle': False,
-            'num_workers': args.num_workers,
-            'pin_memory': getattr(args, 'pin_memory', True)
+            "batch_size": 1,
+            "shuffle": False,
+            "num_workers": args.num_workers,
+            "pin_memory": getattr(args, "pin_memory", True),
         }
         self.test_loader_params = {
-            'batch_size': 1,
-            'shuffle': False,
-            'num_workers': args.num_workers,
-            'pin_memory': True
+            "batch_size": 1,
+            "shuffle": False,
+            "num_workers": args.num_workers,
+            "pin_memory": True,
         }
-        
+
         # 4. sampler
         self.data_sampler = config.TRAINER.DATA_SAMPLER
         self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
         self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
         self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
         self.repeat = config.TRAINER.SB_REPEAT
-        
+
         # (optional) RandomSampler for debugging
 
         # misc configurations
-        self.parallel_load_data = getattr(args, 'parallel_load_data', False)
+        self.parallel_load_data = getattr(args, "parallel_load_data", False)
         self.seed = config.TRAINER.SEED  # 66
 
     def setup(self, stage=None):
@@ -110,7 +115,7 @@ class MultiSceneDataModule(pl.LightningDataModule):
             stage (str): 'fit' in training phase, and 'test' in testing phase.
         """
 
-        assert stage in ['fit', 'test'], "stage must be either fit or test"
+        assert stage in ["fit", "test"], "stage must be either fit or test"
 
         try:
             self.world_size = dist.get_world_size()
@@ -121,73 +126,94 @@ class MultiSceneDataModule(pl.LightningDataModule):
             self.rank = 0
             logger.warning(str(ae) + " (set world_size=1 and rank=0)")
 
-        if stage == 'fit':
+        if stage == "fit":
             self.train_dataset = self._setup_dataset(
                 self.train_data_root,
                 self.train_npz_root,
                 self.train_list_path,
                 self.train_intrinsic_path,
-                mode='train',
+                mode="train",
                 min_overlap_score=self.min_overlap_score_train,
-                pose_dir=self.train_pose_root)
+                pose_dir=self.train_pose_root,
+            )
             # setup multiple (optional) validation subsets
             if isinstance(self.val_list_path, (list, tuple)):
                 self.val_dataset = []
                 if not isinstance(self.val_npz_root, (list, tuple)):
-                    self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
+                    self.val_npz_root = [
+                        self.val_npz_root for _ in range(len(self.val_list_path))
+                    ]
                 for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
-                    self.val_dataset.append(self._setup_dataset(
-                        self.val_data_root,
-                        npz_root,
-                        npz_list,
-                        self.val_intrinsic_path,
-                        mode='val',
-                        min_overlap_score=self.min_overlap_score_test,
-                        pose_dir=self.val_pose_root))
+                    self.val_dataset.append(
+                        self._setup_dataset(
+                            self.val_data_root,
+                            npz_root,
+                            npz_list,
+                            self.val_intrinsic_path,
+                            mode="val",
+                            min_overlap_score=self.min_overlap_score_test,
+                            pose_dir=self.val_pose_root,
+                        )
+                    )
             else:
                 self.val_dataset = self._setup_dataset(
                     self.val_data_root,
                     self.val_npz_root,
                     self.val_list_path,
                     self.val_intrinsic_path,
-                    mode='val',
+                    mode="val",
                     min_overlap_score=self.min_overlap_score_test,
-                    pose_dir=self.val_pose_root)
-            logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
+                    pose_dir=self.val_pose_root,
+                )
+            logger.info(f"[rank:{self.rank}] Train & Val Dataset loaded!")
         else:  # stage == 'test'
             self.test_dataset = self._setup_dataset(
                 self.test_data_root,
                 self.test_npz_root,
                 self.test_list_path,
                 self.test_intrinsic_path,
-                mode='test',
+                mode="test",
                 min_overlap_score=self.min_overlap_score_test,
-                pose_dir=self.test_pose_root)
-            logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')
+                pose_dir=self.test_pose_root,
+            )
+            logger.info(f"[rank:{self.rank}]: Test Dataset loaded!")
 
-    def _setup_dataset(self,
-                       data_root,
-                       split_npz_root,
-                       scene_list_path,
-                       intri_path,
-                       mode='train',
-                       min_overlap_score=0.,
-                       pose_dir=None):
-        """ Setup train / val / test set"""
-        with open(scene_list_path, 'r') as f:
+    def _setup_dataset(
+        self,
+        data_root,
+        split_npz_root,
+        scene_list_path,
+        intri_path,
+        mode="train",
+        min_overlap_score=0.0,
+        pose_dir=None,
+    ):
+        """Setup train / val / test set"""
+        with open(scene_list_path, "r") as f:
             npz_names = [name.split()[0] for name in f.readlines()]
 
-        if mode == 'train':
-            local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed)
+        if mode == "train":
+            local_npz_names = get_local_split(
+                npz_names, self.world_size, self.rank, self.seed
+            )
         else:
             local_npz_names = npz_names
-        logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')
-        
-        dataset_builder = self._build_concat_dataset_parallel \
-                            if self.parallel_load_data \
-                            else self._build_concat_dataset
-        return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path,
-                                mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
+        logger.info(f"[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.")
+
+        dataset_builder = (
+            self._build_concat_dataset_parallel
+            if self.parallel_load_data
+            else self._build_concat_dataset
+        )
+        return dataset_builder(
+            data_root,
+            local_npz_names,
+            split_npz_root,
+            intri_path,
+            mode=mode,
+            min_overlap_score=min_overlap_score,
+            pose_dir=pose_dir,
+        )
 
     def _build_concat_dataset(
         self,
@@ -196,49 +222,61 @@ class MultiSceneDataModule(pl.LightningDataModule):
         npz_dir,
         intrinsic_path,
         mode,
-        min_overlap_score=0.,
-        pose_dir=None
+        min_overlap_score=0.0,
+        pose_dir=None,
     ):
         datasets = []
-        augment_fn = self.augment_fn if mode == 'train' else None
-        data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
-        if data_source=='GL3D' and mode=='val':
-            data_source='MegaDepth'
-        if str(data_source).lower() == 'megadepth':
-            npz_names = [f'{n}.npz' for n in npz_names]
-        if str(data_source).lower() == 'gl3d':
-            npz_names = [f'{n}.txt' for n in npz_names] 
-        #npz_names=npz_names[:8]
-        for npz_name in tqdm(npz_names,
-                             desc=f'[rank:{self.rank}] loading {mode} datasets',
-                             disable=int(self.rank) != 0):
+        augment_fn = self.augment_fn if mode == "train" else None
+        data_source = (
+            self.trainval_data_source
+            if mode in ["train", "val"]
+            else self.test_data_source
+        )
+        if data_source == "GL3D" and mode == "val":
+            data_source = "MegaDepth"
+        if str(data_source).lower() == "megadepth":
+            npz_names = [f"{n}.npz" for n in npz_names]
+        if str(data_source).lower() == "gl3d":
+            npz_names = [f"{n}.txt" for n in npz_names]
+        # npz_names=npz_names[:8]
+        for npz_name in tqdm(
+            npz_names,
+            desc=f"[rank:{self.rank}] loading {mode} datasets",
+            disable=int(self.rank) != 0,
+        ):
             # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
             npz_path = osp.join(npz_dir, npz_name)
-            if data_source == 'ScanNet':
+            if data_source == "ScanNet":
                 datasets.append(
-                    ScanNetDataset(data_root,
-                                   npz_path,
-                                   intrinsic_path,
-                                   mode=mode,
-                                   min_overlap_score=min_overlap_score,
-                                   augment_fn=augment_fn,
-                                   pose_dir=pose_dir))
-            elif data_source == 'MegaDepth':
+                    ScanNetDataset(
+                        data_root,
+                        npz_path,
+                        intrinsic_path,
+                        mode=mode,
+                        min_overlap_score=min_overlap_score,
+                        augment_fn=augment_fn,
+                        pose_dir=pose_dir,
+                    )
+                )
+            elif data_source == "MegaDepth":
                 datasets.append(
-                    MegaDepthDataset(data_root,
-                                     npz_path,
-                                     mode=mode,
-                                     min_overlap_score=min_overlap_score,
-                                     img_resize=self.mgdpt_img_resize,
-                                     df=self.mgdpt_df,
-                                     img_padding=self.mgdpt_img_pad,
-                                     depth_padding=self.mgdpt_depth_pad,
-                                     augment_fn=augment_fn,
-                                     coarse_scale=self.coarse_scale))
+                    MegaDepthDataset(
+                        data_root,
+                        npz_path,
+                        mode=mode,
+                        min_overlap_score=min_overlap_score,
+                        img_resize=self.mgdpt_img_resize,
+                        df=self.mgdpt_df,
+                        img_padding=self.mgdpt_img_pad,
+                        depth_padding=self.mgdpt_depth_pad,
+                        augment_fn=augment_fn,
+                        coarse_scale=self.coarse_scale,
+                    )
+                )
             else:
                 raise NotImplementedError()
         return ConcatDataset(datasets)
-    
+
     def _build_concat_dataset_parallel(
         self,
         data_root,
@@ -246,78 +284,119 @@ class MultiSceneDataModule(pl.LightningDataModule):
         npz_dir,
         intrinsic_path,
         mode,
-        min_overlap_score=0.,
+        min_overlap_score=0.0,
         pose_dir=None,
     ):
-        augment_fn = self.augment_fn if mode == 'train' else None
-        data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
-        if str(data_source).lower() == 'megadepth':
-            npz_names = [f'{n}.npz' for n in npz_names]
-        #npz_names=npz_names[:8]
-        with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets',
-                              total=len(npz_names), disable=int(self.rank) != 0)):
-            if data_source == 'ScanNet':
-                datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
-                    delayed(lambda x: _build_dataset(
-                        ScanNetDataset,
-                        data_root,
-                        osp.join(npz_dir, x),
-                        intrinsic_path,
-                        mode=mode,
-                        min_overlap_score=min_overlap_score,
-                        augment_fn=augment_fn,
-                        pose_dir=pose_dir))(name)
-                    for name in npz_names)
-            elif data_source == 'MegaDepth':
+        augment_fn = self.augment_fn if mode == "train" else None
+        data_source = (
+            self.trainval_data_source
+            if mode in ["train", "val"]
+            else self.test_data_source
+        )
+        if str(data_source).lower() == "megadepth":
+            npz_names = [f"{n}.npz" for n in npz_names]
+        # npz_names=npz_names[:8]
+        with tqdm_joblib(
+            tqdm(
+                desc=f"[rank:{self.rank}] loading {mode} datasets",
+                total=len(npz_names),
+                disable=int(self.rank) != 0,
+            )
+        ):
+            if data_source == "ScanNet":
+                datasets = Parallel(
+                    n_jobs=math.floor(
+                        len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()
+                    )
+                )(
+                    delayed(
+                        lambda x: _build_dataset(
+                            ScanNetDataset,
+                            data_root,
+                            osp.join(npz_dir, x),
+                            intrinsic_path,
+                            mode=mode,
+                            min_overlap_score=min_overlap_score,
+                            augment_fn=augment_fn,
+                            pose_dir=pose_dir,
+                        )
+                    )(name)
+                    for name in npz_names
+                )
+            elif data_source == "MegaDepth":
                 # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
                 raise NotImplementedError()
-                datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
-                    delayed(lambda x: _build_dataset(
-                        MegaDepthDataset,
-                        data_root,
-                        osp.join(npz_dir, x),
-                        mode=mode,
-                        min_overlap_score=min_overlap_score,
-                        img_resize=self.mgdpt_img_resize,
-                        df=self.mgdpt_df,
-                        img_padding=self.mgdpt_img_pad,
-                        depth_padding=self.mgdpt_depth_pad,
-                        augment_fn=augment_fn,
-                        coarse_scale=self.coarse_scale))(name)
-                    for name in npz_names)
+                datasets = Parallel(
+                    n_jobs=math.floor(
+                        len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()
+                    )
+                )(
+                    delayed(
+                        lambda x: _build_dataset(
+                            MegaDepthDataset,
+                            data_root,
+                            osp.join(npz_dir, x),
+                            mode=mode,
+                            min_overlap_score=min_overlap_score,
+                            img_resize=self.mgdpt_img_resize,
+                            df=self.mgdpt_df,
+                            img_padding=self.mgdpt_img_pad,
+                            depth_padding=self.mgdpt_depth_pad,
+                            augment_fn=augment_fn,
+                            coarse_scale=self.coarse_scale,
+                        )
+                    )(name)
+                    for name in npz_names
+                )
             else:
-                raise ValueError(f'Unknown dataset: {data_source}')
+                raise ValueError(f"Unknown dataset: {data_source}")
         return ConcatDataset(datasets)
 
     def train_dataloader(self):
-        """ Build training dataloader for ScanNet / MegaDepth. """
-        assert self.data_sampler in ['scene_balance']
-        logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
-        if self.data_sampler == 'scene_balance':
-            sampler = RandomConcatSampler(self.train_dataset,
-                                          self.n_samples_per_subset,
-                                          self.subset_replacement,
-                                          self.shuffle, self.repeat, self.seed)
+        """Build training dataloader for ScanNet / MegaDepth."""
+        assert self.data_sampler in ["scene_balance"]
+        logger.info(
+            f"[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!)."
+        )
+        if self.data_sampler == "scene_balance":
+            sampler = RandomConcatSampler(
+                self.train_dataset,
+                self.n_samples_per_subset,
+                self.subset_replacement,
+                self.shuffle,
+                self.repeat,
+                self.seed,
+            )
         else:
             sampler = None
-        dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
+        dataloader = DataLoader(
+            self.train_dataset, sampler=sampler, **self.train_loader_params
+        )
         return dataloader
-    
+
     def val_dataloader(self):
-        """ Build validation dataloader for ScanNet / MegaDepth. """
-        logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
+        """Build validation dataloader for ScanNet / MegaDepth."""
+        logger.info(
+            f"[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init."
+        )
         if not isinstance(self.val_dataset, abc.Sequence):
             sampler = DistributedSampler(self.val_dataset, shuffle=False)
-            return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
+            return DataLoader(
+                self.val_dataset, sampler=sampler, **self.val_loader_params
+            )
         else:
             dataloaders = []
             for dataset in self.val_dataset:
                 sampler = DistributedSampler(dataset, shuffle=False)
-                dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
+                dataloaders.append(
+                    DataLoader(dataset, sampler=sampler, **self.val_loader_params)
+                )
             return dataloaders
 
     def test_dataloader(self, *args, **kwargs):
-        logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
+        logger.info(
+            f"[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init."
+        )
         sampler = DistributedSampler(self.test_dataset, shuffle=False)
         return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
 
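The data module above assigns each training rank its own subset of scene `.npz` files via `get_local_split`. The helper below is a hedged stand-in, not the repo's implementation: it only illustrates the idea of a seeded shuffle followed by disjoint per-rank slices.

import random

def local_split_sketch(npz_names, world_size, rank, seed):
    # every rank applies the same seeded permutation, then takes a strided slice,
    # so the subsets are disjoint and together cover all scenes
    names = sorted(npz_names)
    random.Random(seed).shuffle(names)
    return names[rank::world_size]

scenes = [f"{i:04d}" for i in range(10)]
for r in range(4):
    print(r, local_split_sketch(scenes, world_size=4, rank=r, seed=66))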
diff --git a/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py b/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py
index ee20cbec4628b73c08358ebf1e1906fb2c0ac13c..9b34b7b7485d4419390614e3fe0174ccc53ac7a9 100644
--- a/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py
+++ b/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py
@@ -1,4 +1,3 @@
-
 from collections import defaultdict
 import pprint
 from loguru import logger
@@ -10,15 +9,19 @@ import pytorch_lightning as pl
 from matplotlib import pyplot as plt
 
 from src.ASpanFormer.aspanformer import ASpanFormer
-from src.ASpanFormer.utils.supervision import compute_supervision_coarse, compute_supervision_fine
+from src.ASpanFormer.utils.supervision import (
+    compute_supervision_coarse,
+    compute_supervision_fine,
+)
 from src.losses.aspan_loss import ASpanLoss
 from src.optimizers import build_optimizer, build_scheduler
 from src.utils.metrics import (
-    compute_symmetrical_epipolar_errors,compute_symmetrical_epipolar_errors_offset_bidirectional,
+    compute_symmetrical_epipolar_errors,
+    compute_symmetrical_epipolar_errors_offset_bidirectional,
     compute_pose_errors,
-    aggregate_metrics
+    aggregate_metrics,
 )
-from src.utils.plotting import make_matching_figures,make_matching_figures_offset
+from src.utils.plotting import make_matching_figures, make_matching_figures_offset
 from src.utils.comm import gather, all_gather
 from src.utils.misc import lower_config, flattenList
 from src.utils.profiler import PassThroughProfiler
@@ -34,200 +37,288 @@ class PL_ASpanFormer(pl.LightningModule):
         # Misc
         self.config = config  # full config
         _config = lower_config(self.config)
-        self.loftr_cfg = lower_config(_config['aspan'])
+        self.loftr_cfg = lower_config(_config["aspan"])
         self.profiler = profiler or PassThroughProfiler()
-        self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
+        self.n_vals_plot = max(
+            config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1
+        )
 
         # Matcher: LoFTR
-        self.matcher = ASpanFormer(config=_config['aspan'])
+        self.matcher = ASpanFormer(config=_config["aspan"])
         self.loss = ASpanLoss(_config)
 
         # Pretrained weights
         print(pretrained_ckpt)
         if pretrained_ckpt:
-            print('load')
-            state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
-            msg=self.matcher.load_state_dict(state_dict, strict=False)
+            print("load")
+            state_dict = torch.load(pretrained_ckpt, map_location="cpu")["state_dict"]
+            msg = self.matcher.load_state_dict(state_dict, strict=False)
             print(msg)
-            logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
-        
+            logger.info(f"Load '{pretrained_ckpt}' as pretrained checkpoint")
+
         # Testing
         self.dump_dir = dump_dir
-        
+
     def configure_optimizers(self):
         # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
         optimizer = build_optimizer(self, self.config)
         scheduler = build_scheduler(self.config, optimizer)
         return [optimizer], [scheduler]
-    
+
     def optimizer_step(
-            self, epoch, batch_idx, optimizer, optimizer_idx,
-            optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+        self,
+        epoch,
+        batch_idx,
+        optimizer,
+        optimizer_idx,
+        optimizer_closure,
+        on_tpu,
+        using_native_amp,
+        using_lbfgs,
+    ):
         # learning rate warm up
         warmup_step = self.config.TRAINER.WARMUP_STEP
         if self.trainer.global_step < warmup_step:
-            if self.config.TRAINER.WARMUP_TYPE == 'linear':
+            if self.config.TRAINER.WARMUP_TYPE == "linear":
                 base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
-                lr = base_lr + \
-                    (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
-                    abs(self.config.TRAINER.TRUE_LR - base_lr)
+                lr = base_lr + (
+                    self.trainer.global_step / self.config.TRAINER.WARMUP_STEP
+                ) * abs(self.config.TRAINER.TRUE_LR - base_lr)
                 for pg in optimizer.param_groups:
-                    pg['lr'] = lr
-            elif self.config.TRAINER.WARMUP_TYPE == 'constant':
+                    pg["lr"] = lr
+            elif self.config.TRAINER.WARMUP_TYPE == "constant":
                 pass
             else:
-                raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
+                raise ValueError(
+                    f"Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}"
+                )
 
         # update params
         optimizer.step(closure=optimizer_closure)
         optimizer.zero_grad()
-    
+
     def _trainval_inference(self, batch):
         with self.profiler.profile("Compute coarse supervision"):
-            compute_supervision_coarse(batch, self.config) 
-        
+            compute_supervision_coarse(batch, self.config)
+
         with self.profiler.profile("LoFTR"):
-            self.matcher(batch) 
-        
+            self.matcher(batch)
+
         with self.profiler.profile("Compute fine supervision"):
-            compute_supervision_fine(batch, self.config) 
-        
+            compute_supervision_fine(batch, self.config)
+
         with self.profiler.profile("Compute losses"):
-            self.loss(batch) 
-    
+            self.loss(batch)
+
     def _compute_metrics(self, batch):
         with self.profiler.profile("Copmute metrics"):
-            compute_symmetrical_epipolar_errors(batch)  # compute epi_errs for each match
-            compute_symmetrical_epipolar_errors_offset_bidirectional(batch) # compute epi_errs for offset match
-            compute_pose_errors(batch, self.config)  # compute R_errs, t_errs, pose_errs for each pair
+            compute_symmetrical_epipolar_errors(
+                batch
+            )  # compute epi_errs for each match
+            compute_symmetrical_epipolar_errors_offset_bidirectional(
+                batch
+            )  # compute epi_errs for offset match
+            compute_pose_errors(
+                batch, self.config
+            )  # compute R_errs, t_errs, pose_errs for each pair
 
-            rel_pair_names = list(zip(*batch['pair_names']))
-            bs = batch['image0'].size(0)
+            rel_pair_names = list(zip(*batch["pair_names"]))
+            bs = batch["image0"].size(0)
             metrics = {
                 # to filter duplicate pairs caused by DistributedSampler
-                'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
-                'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
-                'epi_errs_offset': [batch['epi_errs_offset_left'][batch['offset_bids_left'] == b].cpu().numpy() for b in range(bs)], #only consider left side
-                'R_errs': batch['R_errs'],
-                't_errs': batch['t_errs'],
-                'inliers': batch['inliers']}
-            ret_dict = {'metrics': metrics}
+                "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)],
+                "epi_errs": [
+                    batch["epi_errs"][batch["m_bids"] == b].cpu().numpy()
+                    for b in range(bs)
+                ],
+                "epi_errs_offset": [
+                    batch["epi_errs_offset_left"][batch["offset_bids_left"] == b]
+                    .cpu()
+                    .numpy()
+                    for b in range(bs)
+                ],  # only consider left side
+                "R_errs": batch["R_errs"],
+                "t_errs": batch["t_errs"],
+                "inliers": batch["inliers"],
+            }
+            ret_dict = {"metrics": metrics}
         return ret_dict, rel_pair_names
-    
-   
+
     def training_step(self, batch, batch_idx):
         self._trainval_inference(batch)
-        
+
         # logging
-        if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
+        if (
+            self.trainer.global_rank == 0
+            and self.global_step % self.trainer.log_every_n_steps == 0
+        ):
             # scalars
-            for k, v in batch['loss_scalars'].items():
-                if not k.startswith('loss_flow') and not k.startswith('conf_'):
-                    self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step)
-            
-            #log offset_loss and conf for each layer and level
-            layer_num=self.loftr_cfg['coarse']['layer_num']
+            for k, v in batch["loss_scalars"].items():
+                if not k.startswith("loss_flow") and not k.startswith("conf_"):
+                    self.logger.experiment.add_scalar(f"train/{k}", v, self.global_step)
+
+            # log offset_loss and conf for each layer and level
+            layer_num = self.loftr_cfg["coarse"]["layer_num"]
             for layer_index in range(layer_num):
-                log_title='layer_'+str(layer_index)
-                self.logger.experiment.add_scalar(log_title+'/offset_loss', batch['loss_scalars']['loss_flow_'+str(layer_index)], self.global_step)
-                self.logger.experiment.add_scalar(log_title+'/conf_', batch['loss_scalars']['conf_'+str(layer_index)],self.global_step)
-            
+                log_title = "layer_" + str(layer_index)
+                self.logger.experiment.add_scalar(
+                    log_title + "/offset_loss",
+                    batch["loss_scalars"]["loss_flow_" + str(layer_index)],
+                    self.global_step,
+                )
+                self.logger.experiment.add_scalar(
+                    log_title + "/conf_",
+                    batch["loss_scalars"]["conf_" + str(layer_index)],
+                    self.global_step,
+                )
+
             # net-params
-            if self.config.ASPAN.MATCH_COARSE.MATCH_TYPE == 'sinkhorn':
+            if self.config.ASPAN.MATCH_COARSE.MATCH_TYPE == "sinkhorn":
                 self.logger.experiment.add_scalar(
-                    f'skh_bin_score', self.matcher.coarse_matching.bin_score.clone().detach().cpu().data, self.global_step)
+                    f"skh_bin_score",
+                    self.matcher.coarse_matching.bin_score.clone().detach().cpu().data,
+                    self.global_step,
+                )
 
             # figures
             if self.config.TRAINER.ENABLE_PLOTTING:
-                compute_symmetrical_epipolar_errors(batch)  # compute epi_errs for each match
-                figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
+                compute_symmetrical_epipolar_errors(
+                    batch
+                )  # compute epi_errs for each match
+                figures = make_matching_figures(
+                    batch, self.config, self.config.TRAINER.PLOT_MODE
+                )
                 for k, v in figures.items():
-                    self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step)
+                    self.logger.experiment.add_figure(
+                        f"train_match/{k}", v, self.global_step
+                    )
 
-                #plot offset 
-                if self.global_step%200==0:
+                # plot offset
+                if self.global_step % 200 == 0:
                     compute_symmetrical_epipolar_errors_offset_bidirectional(batch)
-                    figures_left = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_left')
-                    figures_right = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_right')
+                    figures_left = make_matching_figures_offset(
+                        batch, self.config, self.config.TRAINER.PLOT_MODE, side="_left"
+                    )
+                    figures_right = make_matching_figures_offset(
+                        batch, self.config, self.config.TRAINER.PLOT_MODE, side="_right"
+                    )
                     for k, v in figures_left.items():
-                        self.logger.experiment.add_figure(f'train_offset/{k}'+'_left', v, self.global_step)
-                    figures = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_right')
+                        self.logger.experiment.add_figure(
+                            f"train_offset/{k}" + "_left", v, self.global_step
+                        )
+                    figures = make_matching_figures_offset(
+                        batch, self.config, self.config.TRAINER.PLOT_MODE, side="_right"
+                    )
                     for k, v in figures_right.items():
-                        self.logger.experiment.add_figure(f'train_offset/{k}'+'_right', v, self.global_step)
-                
-        return {'loss': batch['loss']}
+                        self.logger.experiment.add_figure(
+                            f"train_offset/{k}" + "_right", v, self.global_step
+                        )
+
+        return {"loss": batch["loss"]}
 
     def training_epoch_end(self, outputs):
-        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
+        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
         if self.trainer.global_rank == 0:
             self.logger.experiment.add_scalar(
-                'train/avg_loss_on_epoch', avg_loss,
-                global_step=self.current_epoch)
-    
+                "train/avg_loss_on_epoch", avg_loss, global_step=self.current_epoch
+            )
+
     def validation_step(self, batch, batch_idx):
         self._trainval_inference(batch)
-         
-        ret_dict, _ = self._compute_metrics(batch) #this func also compute the epi_errors
-        
+
+        ret_dict, _ = self._compute_metrics(
+            batch
+        )  # this function also computes the epi_errs
+
         val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
         figures = {self.config.TRAINER.PLOT_MODE: []}
         figures_offset = {self.config.TRAINER.PLOT_MODE: []}
         if batch_idx % val_plot_interval == 0:
-            figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE)
-            figures_offset=make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,'_left')
+            figures = make_matching_figures(
+                batch, self.config, mode=self.config.TRAINER.PLOT_MODE
+            )
+            figures_offset = make_matching_figures_offset(
+                batch, self.config, self.config.TRAINER.PLOT_MODE, "_left"
+            )
         return {
             **ret_dict,
-            'loss_scalars': batch['loss_scalars'],
-            'figures': figures,
-            'figures_offset_left':figures_offset
+            "loss_scalars": batch["loss_scalars"],
+            "figures": figures,
+            "figures_offset_left": figures_offset,
         }
-        
+
     def validation_epoch_end(self, outputs):
         # handle multiple validation sets
-        multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
+        multi_outputs = (
+            [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
+        )
         multi_val_metrics = defaultdict(list)
-        
+
         for valset_idx, outputs in enumerate(multi_outputs):
             # since pl performs sanity_check at the very begining of the training
             cur_epoch = self.trainer.current_epoch
-            if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
+            if (
+                not self.trainer.resume_from_checkpoint
+                and self.trainer.running_sanity_check
+            ):
                 cur_epoch = -1
 
             # 1. loss_scalars: dict of list, on cpu
-            _loss_scalars = [o['loss_scalars'] for o in outputs]
-            loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
+            _loss_scalars = [o["loss_scalars"] for o in outputs]
+            loss_scalars = {
+                k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars]))
+                for k in _loss_scalars[0]
+            }
 
             # 2. val metrics: dict of list, numpy
-            _metrics = [o['metrics'] for o in outputs]
-            metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
-            # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 
-            val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
+            _metrics = [o["metrics"] for o in outputs]
+            metrics = {
+                k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics])))
+                for k in _metrics[0]
+            }
+            # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0
+            val_metrics_4tb = aggregate_metrics(
+                metrics, self.config.TRAINER.EPI_ERR_THR
+            )
             for thr in [5, 10, 20]:
-                multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
-            
+                multi_val_metrics[f"auc@{thr}"].append(val_metrics_4tb[f"auc@{thr}"])
+
             # 3. figures
-            _figures = [o['figures'] for o in outputs]
-            figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
+            _figures = [o["figures"] for o in outputs]
+            figures = {
+                k: flattenList(gather(flattenList([_me[k] for _me in _figures])))
+                for k in _figures[0]
+            }
 
             # tensorboard records only on rank 0
             if self.trainer.global_rank == 0:
                 for k, v in loss_scalars.items():
                     mean_v = torch.stack(v).mean()
-                    self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
+                    self.logger.experiment.add_scalar(
+                        f"val_{valset_idx}/avg_{k}", mean_v, global_step=cur_epoch
+                    )
 
                 for k, v in val_metrics_4tb.items():
-                    self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
-                
+                    self.logger.experiment.add_scalar(
+                        f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch
+                    )
+
                 for k, v in figures.items():
                     if self.trainer.global_rank == 0:
                         for plot_idx, fig in enumerate(v):
                             self.logger.experiment.add_figure(
-                                f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
-            plt.close('all')
+                                f"val_match_{valset_idx}/{k}/pair-{plot_idx}",
+                                fig,
+                                cur_epoch,
+                                close=True,
+                            )
+            plt.close("all")
 
         for thr in [5, 10, 20]:
             # log on all ranks for ModelCheckpoint callback to work properly
-            self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}'])))  # ckpt monitors on this
+            self.log(
+                f"auc@{thr}", torch.tensor(np.mean(multi_val_metrics[f"auc@{thr}"]))
+            )  # ckpt monitors on this
 
     def test_step(self, batch, batch_idx):
         with self.profiler.profile("LoFTR"):
@@ -238,39 +329,46 @@ class PL_ASpanFormer(pl.LightningModule):
         with self.profiler.profile("dump_results"):
             if self.dump_dir is not None:
                 # dump results for further analysis
-                keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'}
-                pair_names = list(zip(*batch['pair_names']))
-                bs = batch['image0'].shape[0]
+                keys_to_save = {"mkpts0_f", "mkpts1_f", "mconf", "epi_errs"}
+                pair_names = list(zip(*batch["pair_names"]))
+                bs = batch["image0"].shape[0]
                 dumps = []
                 for b_id in range(bs):
                     item = {}
-                    mask = batch['m_bids'] == b_id
-                    item['pair_names'] = pair_names[b_id]
-                    item['identifier'] = '#'.join(rel_pair_names[b_id])
+                    mask = batch["m_bids"] == b_id
+                    item["pair_names"] = pair_names[b_id]
+                    item["identifier"] = "#".join(rel_pair_names[b_id])
                     for key in keys_to_save:
                         item[key] = batch[key][mask].cpu().numpy()
-                    for key in ['R_errs', 't_errs', 'inliers']:
+                    for key in ["R_errs", "t_errs", "inliers"]:
                         item[key] = batch[key][b_id]
                     dumps.append(item)
-                ret_dict['dumps'] = dumps
+                ret_dict["dumps"] = dumps
 
         return ret_dict
 
     def test_epoch_end(self, outputs):
         # metrics: dict of list, numpy
-        _metrics = [o['metrics'] for o in outputs]
-        metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+        _metrics = [o["metrics"] for o in outputs]
+        metrics = {
+            k: flattenList(gather(flattenList([_me[k] for _me in _metrics])))
+            for k in _metrics[0]
+        }
 
         # [{key: [{...}, *#bs]}, *#batch]
         if self.dump_dir is not None:
             Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
-            _dumps = flattenList([o['dumps'] for o in outputs])  # [{...}, #bs*#batch]
+            _dumps = flattenList([o["dumps"] for o in outputs])  # [{...}, #bs*#batch]
             dumps = flattenList(gather(_dumps))  # [{...}, #proc*#bs*#batch]
-            logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')
+            logger.info(
+                f"Prediction and evaluation results will be saved to: {self.dump_dir}"
+            )
 
         if self.trainer.global_rank == 0:
             print(self.profiler.summary())
-            val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
-            logger.info('\n' + pprint.pformat(val_metrics_4tb))
+            val_metrics_4tb = aggregate_metrics(
+                metrics, self.config.TRAINER.EPI_ERR_THR
+            )
+            logger.info("\n" + pprint.pformat(val_metrics_4tb))
             if self.dump_dir is not None:
-                np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps)
+                np.save(Path(self.dump_dir) / "LoFTR_pred_eval", dumps)
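
The linear warm-up in optimizer_step above boils down to one formula: the learning rate ramps from WARMUP_RATIO * TRUE_LR to TRUE_LR over WARMUP_STEP optimizer steps. Below is a minimal standalone sketch of that schedule, separate from the patch itself; the ratio, step count and target LR are illustrative values, not taken from any shipped config.

    import torch

    true_lr = 8e-3        # stands in for config.TRAINER.TRUE_LR
    warmup_ratio = 0.1    # stands in for config.TRAINER.WARMUP_RATIO
    warmup_step = 4800    # stands in for config.TRAINER.WARMUP_STEP

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=true_lr)

    def warmup_lr(global_step: int) -> float:
        # ramp linearly from warmup_ratio * true_lr to true_lr, then hold
        base_lr = warmup_ratio * true_lr
        if global_step >= warmup_step:
            return true_lr
        return base_lr + (global_step / warmup_step) * abs(true_lr - base_lr)

    for step in range(5):
        for pg in optimizer.param_groups:
            pg["lr"] = warmup_lr(step)
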
diff --git a/third_party/ASpanFormer/src/losses/aspan_loss.py b/third_party/ASpanFormer/src/losses/aspan_loss.py
index 0cca52b36fc997415937969f26caba8c41ac2b8e..dc0f33391b95b6f4f39f673ebc07f6991a00491f 100644
--- a/third_party/ASpanFormer/src/losses/aspan_loss.py
+++ b/third_party/ASpanFormer/src/losses/aspan_loss.py
@@ -3,48 +3,55 @@ from loguru import logger
 import torch
 import torch.nn as nn
 
+
 class ASpanLoss(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config  # config under the global namespace
-        self.loss_config = config['aspan']['loss']
-        self.match_type = self.config['aspan']['match_coarse']['match_type']
-        self.sparse_spvs = self.config['aspan']['match_coarse']['sparse_spvs']
-        self.flow_weight=self.config['aspan']['loss']['flow_weight']
+        self.loss_config = config["aspan"]["loss"]
+        self.match_type = self.config["aspan"]["match_coarse"]["match_type"]
+        self.sparse_spvs = self.config["aspan"]["match_coarse"]["sparse_spvs"]
+        self.flow_weight = self.config["aspan"]["loss"]["flow_weight"]
 
         # coarse-level
-        self.correct_thr = self.loss_config['fine_correct_thr']
-        self.c_pos_w = self.loss_config['pos_weight']
-        self.c_neg_w = self.loss_config['neg_weight']
+        self.correct_thr = self.loss_config["fine_correct_thr"]
+        self.c_pos_w = self.loss_config["pos_weight"]
+        self.c_neg_w = self.loss_config["neg_weight"]
         # fine-level
-        self.fine_type = self.loss_config['fine_type']
-
-    def compute_flow_loss(self,coarse_corr_gt,flow_list,h0,w0,h1,w1):
-        #coarse_corr_gt:[[batch_indices],[left_indices],[right_indices]]
-        #flow_list: [L,B,H,W,4]
-        loss1=self.flow_loss_worker(flow_list[0],coarse_corr_gt[0],coarse_corr_gt[1],coarse_corr_gt[2],w1)
-        loss2=self.flow_loss_worker(flow_list[1],coarse_corr_gt[0],coarse_corr_gt[2],coarse_corr_gt[1],w0)
-        total_loss=(loss1+loss2)/2
+        self.fine_type = self.loss_config["fine_type"]
+
+    def compute_flow_loss(self, coarse_corr_gt, flow_list, h0, w0, h1, w1):
+        # coarse_corr_gt:[[batch_indices],[left_indices],[right_indices]]
+        # flow_list: [L,B,H,W,4]
+        loss1 = self.flow_loss_worker(
+            flow_list[0], coarse_corr_gt[0], coarse_corr_gt[1], coarse_corr_gt[2], w1
+        )
+        loss2 = self.flow_loss_worker(
+            flow_list[1], coarse_corr_gt[0], coarse_corr_gt[2], coarse_corr_gt[1], w0
+        )
+        total_loss = (loss1 + loss2) / 2
         return total_loss
 
-    def flow_loss_worker(self,flow,batch_indicies,self_indicies,cross_indicies,w):
-        bs,layer_num=flow.shape[1],flow.shape[0]
-        flow=flow.view(layer_num,bs,-1,4)
-        gt_flow=torch.stack([cross_indicies%w,cross_indicies//w],dim=1)
+    def flow_loss_worker(self, flow, batch_indicies, self_indicies, cross_indicies, w):
+        bs, layer_num = flow.shape[1], flow.shape[0]
+        flow = flow.view(layer_num, bs, -1, 4)
+        gt_flow = torch.stack([cross_indicies % w, cross_indicies // w], dim=1)
 
-        total_loss_list=[]
+        total_loss_list = []
         for layer_index in range(layer_num):
-            cur_flow_list=flow[layer_index]
-            spv_flow=cur_flow_list[batch_indicies,self_indicies][:,:2]
-            spv_conf=cur_flow_list[batch_indicies,self_indicies][:,2:]#[#coarse,2]
-            l2_flow_dis=((gt_flow-spv_flow)**2) #[#coarse,2]
-            total_loss=(spv_conf+torch.exp(-spv_conf)*l2_flow_dis) #[#coarse,2]
+            cur_flow_list = flow[layer_index]
+            spv_flow = cur_flow_list[batch_indicies, self_indicies][:, :2]
+            spv_conf = cur_flow_list[batch_indicies, self_indicies][
+                :, 2:
+            ]  # [#coarse,2]
+            l2_flow_dis = (gt_flow - spv_flow) ** 2  # [#coarse,2]
+            total_loss = spv_conf + torch.exp(-spv_conf) * l2_flow_dis  # [#coarse,2]
             total_loss_list.append(total_loss.mean())
-        total_loss=torch.stack(total_loss_list,dim=-1)*self.flow_weight
+        total_loss = torch.stack(total_loss_list, dim=-1) * self.flow_weight
         return total_loss
-        
+
     def compute_coarse_loss(self, conf, conf_gt, weight=None):
-        """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
+        """Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
         Args:
             conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
             conf_gt (torch.Tensor): (N, HW0, HW1)
@@ -56,38 +63,44 @@ class ASpanLoss(nn.Module):
         if not pos_mask.any():  # assign a wrong gt
             pos_mask[0, 0, 0] = True
             if weight is not None:
-                weight[0, 0, 0] = 0.
-            c_pos_w = 0.
+                weight[0, 0, 0] = 0.0
+            c_pos_w = 0.0
         if not neg_mask.any():
             neg_mask[0, 0, 0] = True
             if weight is not None:
-                weight[0, 0, 0] = 0.
-            c_neg_w = 0.
-
-        if self.loss_config['coarse_type'] == 'cross_entropy':
-            assert not self.sparse_spvs, 'Sparse Supervision for cross-entropy not implemented!'
-            conf = torch.clamp(conf, 1e-6, 1-1e-6)
-            loss_pos = - torch.log(conf[pos_mask])
-            loss_neg = - torch.log(1 - conf[neg_mask])
+                weight[0, 0, 0] = 0.0
+            c_neg_w = 0.0
+
+        if self.loss_config["coarse_type"] == "cross_entropy":
+            assert (
+                not self.sparse_spvs
+            ), "Sparse Supervision for cross-entropy not implemented!"
+            conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
+            loss_pos = -torch.log(conf[pos_mask])
+            loss_neg = -torch.log(1 - conf[neg_mask])
             if weight is not None:
                 loss_pos = loss_pos * weight[pos_mask]
                 loss_neg = loss_neg * weight[neg_mask]
             return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
-        elif self.loss_config['coarse_type'] == 'focal':
-            conf = torch.clamp(conf, 1e-6, 1-1e-6)
-            alpha = self.loss_config['focal_alpha']
-            gamma = self.loss_config['focal_gamma']
-            
+        elif self.loss_config["coarse_type"] == "focal":
+            conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
+            alpha = self.loss_config["focal_alpha"]
+            gamma = self.loss_config["focal_gamma"]
+
             if self.sparse_spvs:
-                pos_conf = conf[:, :-1, :-1][pos_mask] \
-                            if self.match_type == 'sinkhorn' \
-                            else conf[pos_mask]
-                loss_pos = - alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log()
+                pos_conf = (
+                    conf[:, :-1, :-1][pos_mask]
+                    if self.match_type == "sinkhorn"
+                    else conf[pos_mask]
+                )
+                loss_pos = -alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log()
                 # calculate losses for negative samples
-                if self.match_type == 'sinkhorn':
+                if self.match_type == "sinkhorn":
                     neg0, neg1 = conf_gt.sum(-1) == 0, conf_gt.sum(1) == 0
-                    neg_conf = torch.cat([conf[:, :-1, -1][neg0], conf[:, -1, :-1][neg1]], 0)
-                    loss_neg = - alpha * torch.pow(1 - neg_conf, gamma) * neg_conf.log()
+                    neg_conf = torch.cat(
+                        [conf[:, :-1, -1][neg0], conf[:, -1, :-1][neg1]], 0
+                    )
+                    loss_neg = -alpha * torch.pow(1 - neg_conf, gamma) * neg_conf.log()
                 else:
                     # These is no dustbin for dual_softmax, so we left unmatchable patches without supervision.
                     # we could also add 'pseudo negtive-samples'
@@ -97,32 +110,46 @@ class ASpanLoss(nn.Module):
                     # Different from dense-spvs, the loss w.r.t. padded regions aren't directly zeroed out,
                     # but only through manually setting corresponding regions in sim_matrix to '-inf'.
                     loss_pos = loss_pos * weight[pos_mask]
-                    if self.match_type == 'sinkhorn':
+                    if self.match_type == "sinkhorn":
                         neg_w0 = (weight.sum(-1) != 0)[neg0]
                         neg_w1 = (weight.sum(1) != 0)[neg1]
                         neg_mask = torch.cat([neg_w0, neg_w1], 0)
                         loss_neg = loss_neg[neg_mask]
-                
-                loss =  c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() \
-                            if self.match_type == 'sinkhorn' \
-                            else c_pos_w * loss_pos.mean()
+
+                loss = (
+                    c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
+                    if self.match_type == "sinkhorn"
+                    else c_pos_w * loss_pos.mean()
+                )
                 return loss
                 # positive and negative elements occupy similar propotions. => more balanced loss weights needed
             else:  # dense supervision (in the case of match_type=='sinkhorn', the dustbin is not supervised.)
-                loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log()
-                loss_neg = - alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log()
+                loss_pos = (
+                    -alpha
+                    * torch.pow(1 - conf[pos_mask], gamma)
+                    * (conf[pos_mask]).log()
+                )
+                loss_neg = (
+                    -alpha
+                    * torch.pow(conf[neg_mask], gamma)
+                    * (1 - conf[neg_mask]).log()
+                )
                 if weight is not None:
                     loss_pos = loss_pos * weight[pos_mask]
                     loss_neg = loss_neg * weight[neg_mask]
                 return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
                 # each negative element occupy a smaller propotion than positive elements. => higher negative loss weight needed
         else:
-            raise ValueError('Unknown coarse loss: {type}'.format(type=self.loss_config['coarse_type']))
-        
+            raise ValueError(
+                "Unknown coarse loss: {type}".format(
+                    type=self.loss_config["coarse_type"]
+                )
+            )
+
     def compute_fine_loss(self, expec_f, expec_f_gt):
-        if self.fine_type == 'l2_with_std':
+        if self.fine_type == "l2_with_std":
             return self._compute_fine_loss_l2_std(expec_f, expec_f_gt)
-        elif self.fine_type == 'l2':
+        elif self.fine_type == "l2":
             return self._compute_fine_loss_l2(expec_f, expec_f_gt)
         else:
             raise NotImplementedError()
@@ -133,9 +160,13 @@ class ASpanLoss(nn.Module):
             expec_f (torch.Tensor): [M, 2] <x, y>
             expec_f_gt (torch.Tensor): [M, 2] <x, y>
         """
-        correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
+        correct_mask = (
+            torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
+        )
         if correct_mask.sum() == 0:
-            if self.training:  # this seldomly happen when training, since we pad prediction with gt
+            if (
+                self.training
+            ):  # this seldom happens during training, since we pad predictions with gt
                 logger.warning("assign a false supervision to avoid ddp deadlock")
                 correct_mask[0] = True
             else:
@@ -150,20 +181,26 @@ class ASpanLoss(nn.Module):
             expec_f_gt (torch.Tensor): [M, 2] <x, y>
         """
         # correct_mask tells you which pair to compute fine-loss
-        correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
+        correct_mask = (
+            torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
+        )
 
         # use std as weight that measures uncertainty
         std = expec_f[:, 2]
-        inverse_std = 1. / torch.clamp(std, min=1e-10)
-        weight = (inverse_std / torch.mean(inverse_std)).detach()  # avoid minizing loss through increase std
+        inverse_std = 1.0 / torch.clamp(std, min=1e-10)
+        weight = (
+            inverse_std / torch.mean(inverse_std)
+        ).detach()  # avoid minimizing the loss by increasing std
 
         # corner case: no correct coarse match found
         if not correct_mask.any():
-            if self.training:  # this seldomly happen during training, since we pad prediction with gt
-                               # sometimes there is not coarse-level gt at all.
+            if (
+                self.training
+            ):  # this seldom happens during training, since we pad predictions with gt
+                # sometimes there is no coarse-level gt at all.
                 logger.warning("assign a false supervision to avoid ddp deadlock")
                 correct_mask[0] = True
-                weight[0] = 0.
+                weight[0] = 0.0
             else:
                 return None
 
@@ -172,12 +209,15 @@ class ASpanLoss(nn.Module):
         loss = (flow_l2 * weight[correct_mask]).mean()
 
         return loss
-    
+
     @torch.no_grad()
     def compute_c_weight(self, data):
-        """ compute element-wise weights for computing coarse-level loss. """
-        if 'mask0' in data:
-            c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float()
+        """compute element-wise weights for computing coarse-level loss."""
+        if "mask0" in data:
+            c_weight = (
+                data["mask0"].flatten(-2)[..., None]
+                * data["mask1"].flatten(-2)[:, None]
+            ).float()
         else:
             c_weight = None
         return c_weight
@@ -196,36 +236,54 @@ class ASpanLoss(nn.Module):
 
         # 1. coarse-level loss
         loss_c = self.compute_coarse_loss(
-            data['conf_matrix_with_bin'] if self.sparse_spvs and self.match_type == 'sinkhorn' \
-                else data['conf_matrix'],
-            data['conf_matrix_gt'],
-            weight=c_weight)
-        loss = loss_c * self.loss_config['coarse_weight']
+            data["conf_matrix_with_bin"]
+            if self.sparse_spvs and self.match_type == "sinkhorn"
+            else data["conf_matrix"],
+            data["conf_matrix_gt"],
+            weight=c_weight,
+        )
+        loss = loss_c * self.loss_config["coarse_weight"]
         loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
 
         # 2. fine-level loss
-        loss_f = self.compute_fine_loss(data['expec_f'], data['expec_f_gt'])
+        loss_f = self.compute_fine_loss(data["expec_f"], data["expec_f_gt"])
         if loss_f is not None:
-            loss += loss_f * self.loss_config['fine_weight']
-            loss_scalars.update({"loss_f":  loss_f.clone().detach().cpu()})
+            loss += loss_f * self.loss_config["fine_weight"]
+            loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()})
         else:
             assert self.training is False
-            loss_scalars.update({'loss_f': torch.tensor(1.)})  # 1 is the upper bound
-        
+            loss_scalars.update({"loss_f": torch.tensor(1.0)})  # 1 is the upper bound
+
         # 3. flow loss
-        coarse_corr=[data['spv_b_ids'],data['spv_i_ids'],data['spv_j_ids']]
-        loss_flow = self.compute_flow_loss(coarse_corr,data['predict_flow'],\
-                                            data['hw0_c'][0],data['hw0_c'][1],data['hw1_c'][0],data['hw1_c'][1])
-        loss_flow=loss_flow*self.flow_weight
-        for index,loss_off in enumerate(loss_flow):
-            loss_scalars.update({'loss_flow_'+str(index): loss_off.clone().detach().cpu()})  # 1 is the upper bound
-            conf=data['predict_flow'][0][:,:,:,:,2:]
-            layer_num=conf.shape[0]
+        coarse_corr = [data["spv_b_ids"], data["spv_i_ids"], data["spv_j_ids"]]
+        loss_flow = self.compute_flow_loss(
+            coarse_corr,
+            data["predict_flow"],
+            data["hw0_c"][0],
+            data["hw0_c"][1],
+            data["hw1_c"][0],
+            data["hw1_c"][1],
+        )
+        loss_flow = loss_flow * self.flow_weight
+        for index, loss_off in enumerate(loss_flow):
+            loss_scalars.update(
+                {"loss_flow_" + str(index): loss_off.clone().detach().cpu()}
+            )
+            conf = data["predict_flow"][0][:, :, :, :, 2:]
+            layer_num = conf.shape[0]
             for layer_index in range(layer_num):
-                loss_scalars.update({'conf_'+str(layer_index): conf[layer_index].mean().clone().detach().cpu()})  # 1 is the upper bound
-        
-        
-        loss+=loss_flow.sum()
-        #print((loss_c * self.loss_config['coarse_weight']).data,loss_flow.data)
-        loss_scalars.update({'loss': loss.clone().detach().cpu()})
+                loss_scalars.update(
+                    {
+                        "conf_"
+                        + str(layer_index): conf[layer_index]
+                        .mean()
+                        .clone()
+                        .detach()
+                        .cpu()
+                    }
+                )
+
+        loss += loss_flow.sum()
+        # print((loss_c * self.loss_config['coarse_weight']).data,loss_flow.data)
+        loss_scalars.update({"loss": loss.clone().detach().cpu()})
         data.update({"loss": loss, "loss_scalars": loss_scalars})
diff --git a/third_party/ASpanFormer/src/optimizers/__init__.py b/third_party/ASpanFormer/src/optimizers/__init__.py
index e1db2285352586c250912bdd2c4ae5029620ab5f..e4e36c22e00217deccacd589f8924b2f74589456 100644
--- a/third_party/ASpanFormer/src/optimizers/__init__.py
+++ b/third_party/ASpanFormer/src/optimizers/__init__.py
@@ -7,9 +7,13 @@ def build_optimizer(model, config):
     lr = config.TRAINER.TRUE_LR
 
     if name == "adam":
-        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY)
+        return torch.optim.Adam(
+            model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY
+        )
     elif name == "adamw":
-        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY)
+        return torch.optim.AdamW(
+            model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY
+        )
     else:
         raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
 
@@ -24,18 +28,27 @@ def build_scheduler(config, optimizer):
             'frequency': x, (optional)
         }
     """
-    scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
+    scheduler = {"interval": config.TRAINER.SCHEDULER_INTERVAL}
     name = config.TRAINER.SCHEDULER
 
-    if name == 'MultiStepLR':
+    if name == "MultiStepLR":
         scheduler.update(
-            {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
-    elif name == 'CosineAnnealing':
+            {
+                "scheduler": MultiStepLR(
+                    optimizer,
+                    config.TRAINER.MSLR_MILESTONES,
+                    gamma=config.TRAINER.MSLR_GAMMA,
+                )
+            }
+        )
+    elif name == "CosineAnnealing":
         scheduler.update(
-            {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
-    elif name == 'ExponentialLR':
+            {"scheduler": CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}
+        )
+    elif name == "ExponentialLR":
         scheduler.update(
-            {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
+            {"scheduler": ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}
+        )
     else:
         raise NotImplementedError()
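
build_scheduler above returns a plain dict with "scheduler" and "interval" keys, as described in its docstring. A hypothetical sketch of a dict with the same shape, separate from the patch, built directly around a MultiStepLR (the milestones and gamma are illustrative, not the repository's defaults); in a LightningModule such a dict is typically returned next to the optimizer from configure_optimizers.

    import torch
    from torch.optim.lr_scheduler import MultiStepLR

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=1e-3)

    scheduler = {
        "interval": "epoch",  # stands in for config.TRAINER.SCHEDULER_INTERVAL
        "scheduler": MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.5),
    }
    # e.g. return [optimizer], [scheduler] from configure_optimizers
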
 
diff --git a/third_party/ASpanFormer/src/utils/augment.py b/third_party/ASpanFormer/src/utils/augment.py
index d7c5d3e11b6fe083aaeff7555bb7ce3a4bfb755d..068751c6c07091bbaed76debd43a73155f61b9bd 100644
--- a/third_party/ASpanFormer/src/utils/augment.py
+++ b/third_party/ASpanFormer/src/utils/augment.py
@@ -7,16 +7,21 @@ class DarkAug(object):
     """
 
     def __init__(self) -> None:
-        self.augmentor = A.Compose([
-            A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
-            A.Blur(p=0.1, blur_limit=(3, 9)),
-            A.MotionBlur(p=0.2, blur_limit=(3, 25)),
-            A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
-            A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
-        ], p=0.75)
+        self.augmentor = A.Compose(
+            [
+                A.RandomBrightnessContrast(
+                    p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)
+                ),
+                A.Blur(p=0.1, blur_limit=(3, 9)),
+                A.MotionBlur(p=0.2, blur_limit=(3, 25)),
+                A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
+                A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)),
+            ],
+            p=0.75,
+        )
 
     def __call__(self, x):
-        return self.augmentor(image=x)['image']
+        return self.augmentor(image=x)["image"]
 
 
 class MobileAug(object):
@@ -25,31 +30,36 @@ class MobileAug(object):
     """
 
     def __init__(self):
-        self.augmentor = A.Compose([
-            A.MotionBlur(p=0.25),
-            A.ColorJitter(p=0.5),
-            A.RandomRain(p=0.1),  # random occlusion
-            A.RandomSunFlare(p=0.1),
-            A.JpegCompression(p=0.25),
-            A.ISONoise(p=0.25)
-        ], p=1.0)
+        self.augmentor = A.Compose(
+            [
+                A.MotionBlur(p=0.25),
+                A.ColorJitter(p=0.5),
+                A.RandomRain(p=0.1),  # random occlusion
+                A.RandomSunFlare(p=0.1),
+                A.JpegCompression(p=0.25),
+                A.ISONoise(p=0.25),
+            ],
+            p=1.0,
+        )
 
     def __call__(self, x):
-        return self.augmentor(image=x)['image']
+        return self.augmentor(image=x)["image"]
 
 
 def build_augmentor(method=None, **kwargs):
     if method is not None:
-        raise NotImplementedError('Using of augmentation functions are not supported yet!')
-    if method == 'dark':
+        raise NotImplementedError(
+            "Using of augmentation functions are not supported yet!"
+        )
+    if method == "dark":
         return DarkAug()
-    elif method == 'mobile':
+    elif method == "mobile":
         return MobileAug()
     elif method is None:
         return None
     else:
-        raise ValueError(f'Invalid augmentation method: {method}')
+        raise ValueError(f"Invalid augmentation method: {method}")
 
 
-if __name__ == '__main__':
-    augmentor = build_augmentor('FDA')
+if __name__ == "__main__":
+    augmentor = build_augmentor("FDA")
diff --git a/third_party/ASpanFormer/src/utils/comm.py b/third_party/ASpanFormer/src/utils/comm.py
index 26ec9517cc47e224430106d8ae9aa99a3fe49167..9f578cda8933cc358934c645fcf413c63ab4d79d 100644
--- a/third_party/ASpanFormer/src/utils/comm.py
+++ b/third_party/ASpanFormer/src/utils/comm.py
@@ -98,11 +98,11 @@ def _serialize_to_tensor(data, group):
     device = torch.device("cpu" if backend == "gloo" else "cuda")
 
     buffer = pickle.dumps(data)
-    if len(buffer) > 1024 ** 3:
+    if len(buffer) > 1024**3:
         logger = logging.getLogger(__name__)
         logger.warning(
             "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
-                get_rank(), len(buffer) / (1024 ** 3), device
+                get_rank(), len(buffer) / (1024**3), device
             )
         )
     storage = torch.ByteStorage.from_buffer(buffer)
@@ -122,7 +122,8 @@ def _pad_to_largest_tensor(tensor, group):
     ), "comm.gather/all_gather must be called from ranks within the given group!"
     local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
     size_list = [
-        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+        torch.zeros([1], dtype=torch.int64, device=tensor.device)
+        for _ in range(world_size)
     ]
     dist.all_gather(size_list, local_size, group=group)
 
@@ -133,7 +134,9 @@ def _pad_to_largest_tensor(tensor, group):
     # we pad the tensor because torch all_gather does not support
     # gathering tensors of different shapes
     if local_size != max_size:
-        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+        padding = torch.zeros(
+            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
+        )
         tensor = torch.cat((tensor, padding), dim=0)
     return size_list, tensor
 
@@ -164,7 +167,8 @@ def all_gather(data, group=None):
 
     # receiving Tensor from all ranks
     tensor_list = [
-        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
+        for _ in size_list
     ]
     dist.all_gather(tensor_list, tensor, group=group)
 
@@ -205,7 +209,8 @@ def gather(data, dst=0, group=None):
     if rank == dst:
         max_size = max(size_list)
         tensor_list = [
-            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
+            for _ in size_list
         ]
         dist.gather(tensor, tensor_list, dst=dst, group=group)
 
@@ -228,7 +233,7 @@ def shared_random_seed():
 
     All workers must call this function, otherwise it will deadlock.
     """
-    ints = np.random.randint(2 ** 31)
+    ints = np.random.randint(2**31)
     all_ints = all_gather(ints)
     return all_ints[0]
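
_pad_to_largest_tensor above works around the fact that torch.distributed.all_gather only accepts identically shaped tensors: each rank pads its pickled byte buffer to the largest reported size before gathering, and receivers truncate back using the gathered sizes. A toy illustration of just the padding step, separate from the patch, with made-up sizes and no process group:

    import torch

    local = torch.arange(5, dtype=torch.uint8)  # this rank's serialized payload
    size_list = [5, 9]                          # payload sizes reported by all ranks
    max_size = max(size_list)

    if local.numel() != max_size:
        padding = torch.zeros((max_size - local.numel(),), dtype=torch.uint8)
        local = torch.cat((local, padding), dim=0)

    assert local.numel() == max_size  # now safe to all_gather across ranks
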
 
diff --git a/third_party/ASpanFormer/src/utils/dataloader.py b/third_party/ASpanFormer/src/utils/dataloader.py
index 6da37b880a290c2bb3ebb028d0c8dab592acc5c1..b980dfd344714870ecdacd9e7a9742f51c3ee14d 100644
--- a/third_party/ASpanFormer/src/utils/dataloader.py
+++ b/third_party/ASpanFormer/src/utils/dataloader.py
@@ -3,21 +3,22 @@ import numpy as np
 
 # --- PL-DATAMODULE ---
 
+
 def get_local_split(items: list, world_size: int, rank: int, seed: int):
-    """ The local rank only loads a split of the dataset. """
+    """The local rank only loads a split of the dataset."""
     n_items = len(items)
     items_permute = np.random.RandomState(seed).permutation(items)
     if n_items % world_size == 0:
         padded_items = items_permute
     else:
         padding = np.random.RandomState(seed).choice(
-            items,
-            world_size - (n_items % world_size),
-            replace=True)
+            items, world_size - (n_items % world_size), replace=True
+        )
         padded_items = np.concatenate([items_permute, padding])
-        assert len(padded_items) % world_size == 0, \
-            f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}'
+        assert (
+            len(padded_items) % world_size == 0
+        ), f"len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}"
     n_per_rank = len(padded_items) // world_size
-    local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]
+    local_items = padded_items[n_per_rank * rank : n_per_rank * (rank + 1)]
 
     return local_items
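
A hypothetical usage of get_local_split above, assuming the ASpanFormer src/ directory is importable (the scene names and sizes are made up): with 10 items and world_size 4, the list is padded to 12, so each rank receives 3 items.

    from src.utils.dataloader import get_local_split  # assumed import path

    scene_names = [f"scene_{i:04d}" for i in range(10)]
    local_scenes = get_local_split(scene_names, world_size=4, rank=1, seed=66)
    assert len(local_scenes) == 3  # 10 items padded to 12, split 3 per rank
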
diff --git a/third_party/ASpanFormer/src/utils/dataset.py b/third_party/ASpanFormer/src/utils/dataset.py
index 209bf554acc20e33ea89eb9e7024ba68d0b3a30b..1881446fd69aedb520ae669100cd2a3c2d143a18 100644
--- a/third_party/ASpanFormer/src/utils/dataset.py
+++ b/third_party/ASpanFormer/src/utils/dataset.py
@@ -15,8 +15,11 @@ except Exception:
 
 # --- DATA IO ---
 
+
 def load_array_from_s3(
-    path, client, cv_type,
+    path,
+    client,
+    cv_type,
     use_h5py=False,
 ):
     byte_str = client.Get(path)
@@ -26,7 +29,7 @@ def load_array_from_s3(
             data = cv2.imdecode(raw_array, cv_type)
         else:
             f = io.BytesIO(byte_str)
-            data = np.array(h5py.File(f, 'r')['/depth'])
+            data = np.array(h5py.File(f, "r")["/depth"])
     except Exception as ex:
         print(f"==> Data loading failure: {path}")
         raise ex
@@ -36,9 +39,8 @@ def load_array_from_s3(
 
 
 def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
-    cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
-                else cv2.IMREAD_COLOR
-    if str(path).startswith('s3://'):
+    cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR
+    if str(path).startswith("s3://"):
         image = load_array_from_s3(str(path), client, cv_type)
     else:
         image = cv2.imread(str(path), cv_type)
@@ -54,7 +56,7 @@ def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
 def get_resized_wh(w, h, resize=None):
     if resize is not None:  # resize the longer edge
         scale = resize / max(h, w)
-        w_new, h_new = int(round(w*scale)), int(round(h*scale))
+        w_new, h_new = int(round(w * scale)), int(round(h * scale))
     else:
         w_new, h_new = w, h
     return w_new, h_new
@@ -69,20 +71,22 @@ def get_divisible_wh(w, h, df=None):
 
 
 def pad_bottom_right(inp, pad_size, ret_mask=False):
-    assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
+    assert isinstance(pad_size, int) and pad_size >= max(
+        inp.shape[-2:]
+    ), f"{pad_size} < {max(inp.shape[-2:])}"
     mask = None
     if inp.ndim == 2:
         padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
-        padded[:inp.shape[0], :inp.shape[1]] = inp
+        padded[: inp.shape[0], : inp.shape[1]] = inp
         if ret_mask:
             mask = np.zeros((pad_size, pad_size), dtype=bool)
-            mask[:inp.shape[0], :inp.shape[1]] = True
+            mask[: inp.shape[0], : inp.shape[1]] = True
     elif inp.ndim == 3:
         padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
-        padded[:, :inp.shape[1], :inp.shape[2]] = inp
+        padded[:, : inp.shape[1], : inp.shape[2]] = inp
         if ret_mask:
             mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
-            mask[:, :inp.shape[1], :inp.shape[2]] = True
+            mask[:, : inp.shape[1], : inp.shape[2]] = True
     else:
         raise NotImplementedError()
     return padded, mask
@@ -90,6 +94,7 @@ def pad_bottom_right(inp, pad_size, ret_mask=False):
 
 # --- MEGADEPTH ---
 
+
 def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
     """
     Args:
@@ -99,7 +104,7 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
     Returns:
         image (torch.tensor): (1, h, w)
         mask (torch.tensor): (h, w)
-        scale (torch.tensor): [w/w_new, h/h_new]        
+        scale (torch.tensor): [w/w_new, h/h_new]
     """
     # read image
     image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
@@ -110,7 +115,7 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
     w_new, h_new = get_divisible_wh(w_new, h_new, df)
 
     image = cv2.resize(image, (w_new, h_new))
-    scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
+    scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float)
 
     if padding:  # padding
         pad_to = max(h_new, w_new)
@@ -118,7 +123,9 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
     else:
         mask = None
 
-    image = torch.from_numpy(image).float()[None] / 255  # (h, w) -> (1, h, w) and normalized
+    image = (
+        torch.from_numpy(image).float()[None] / 255
+    )  # (h, w) -> (1, h, w) and normalized
     if mask is not None:
         mask = torch.from_numpy(mask)
 
@@ -126,10 +133,10 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
 
 
 def read_megadepth_depth(path, pad_to=None):
-    if str(path).startswith('s3://'):
+    if str(path).startswith("s3://"):
         depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
     else:
-        depth = np.array(h5py.File(path, 'r')['depth'])
+        depth = np.array(h5py.File(path, "r")["depth"])
     if pad_to is not None:
         depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
     depth = torch.from_numpy(depth).float()  # (h, w)
@@ -138,6 +145,7 @@ def read_megadepth_depth(path, pad_to=None):
 
 # --- ScanNet ---
 
+
 def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
     """
     Args:
@@ -146,7 +154,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
     Returns:
         image (torch.tensor): (1, h, w)
         mask (torch.tensor): (h, w)
-        scale (torch.tensor): [w/w_new, h/h_new]        
+        scale (torch.tensor): [w/w_new, h/h_new]
     """
     # read and resize image
     image = imread_gray(path, augment_fn)
@@ -158,7 +166,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
 
 
 def read_scannet_depth(path):
-    if str(path).startswith('s3://'):
+    if str(path).startswith("s3://"):
         depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
     else:
         depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
@@ -168,55 +176,57 @@ def read_scannet_depth(path):
 
 
 def read_scannet_pose(path):
-    """ Read ScanNet's Camera2World pose and transform it to World2Camera.
-    
+    """Read ScanNet's Camera2World pose and transform it to World2Camera.
+
     Returns:
         pose_w2c (np.ndarray): (4, 4)
     """
-    cam2world = np.loadtxt(path, delimiter=' ')
+    cam2world = np.loadtxt(path, delimiter=" ")
     world2cam = inv(cam2world)
     return world2cam
 
 
 def read_scannet_intrinsic(path):
-    """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
-    """
-    intrinsic = np.loadtxt(path, delimiter=' ')
+    """Read ScanNet's intrinsic matrix and return the 3x3 matrix."""
+    intrinsic = np.loadtxt(path, delimiter=" ")
     return intrinsic[:-1, :-1]
 
 
-def read_gl3d_gray(path,resize):
-    img=cv2.resize(cv2.imread(path,cv2.IMREAD_GRAYSCALE),(int(resize),int(resize)))
-    img = torch.from_numpy(img).float()[None] / 255  # (h, w) -> (1, h, w) and normalized
+def read_gl3d_gray(path, resize):
+    img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), (int(resize), int(resize)))
+    img = (
+        torch.from_numpy(img).float()[None] / 255
+    )  # (h, w) -> (1, h, w) and normalized
     return img
 
+
 def read_gl3d_depth(file_path):
-    with open(file_path, 'rb') as fin:
+    with open(file_path, "rb") as fin:
         color = None
         width = None
         height = None
         scale = None
         data_type = None
-        header = str(fin.readline().decode('UTF-8')).rstrip()
-        if header == 'PF':
+        header = str(fin.readline().decode("UTF-8")).rstrip()
+        if header == "PF":
             color = True
-        elif header == 'Pf':
+        elif header == "Pf":
             color = False
         else:
-            raise Exception('Not a PFM file.')
-        dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8'))
+            raise Exception("Not a PFM file.")
+        dim_match = re.match(r"^(\d+)\s(\d+)\s$", fin.readline().decode("UTF-8"))
         if dim_match:
             width, height = map(int, dim_match.groups())
         else:
-            raise Exception('Malformed PFM header.')
-        scale = float((fin.readline().decode('UTF-8')).rstrip())
+            raise Exception("Malformed PFM header.")
+        scale = float((fin.readline().decode("UTF-8")).rstrip())
         if scale < 0:  # little-endian
-            data_type = '<f'
+            data_type = "<f"
         else:
-            data_type = '>f'  # big-endian
+            data_type = ">f"  # big-endian
         data_string = fin.read()
         data = np.fromstring(data_string, data_type)
         shape = (height, width, 3) if color else (height, width)
         data = np.reshape(data, shape)
         data = np.flip(data, 0)
-    return torch.from_numpy(data.copy()).float()
\ No newline at end of file
+    return torch.from_numpy(data.copy()).float()
diff --git a/third_party/ASpanFormer/src/utils/metrics.py b/third_party/ASpanFormer/src/utils/metrics.py
index 6fd6faa5afea3b3c9e2a7b1980a7d7c132def102..fd2c34886f5824c34d9ca19c0419204f5a7e9d2c 100644
--- a/third_party/ASpanFormer/src/utils/metrics.py
+++ b/third_party/ASpanFormer/src/utils/metrics.py
@@ -9,6 +9,7 @@ from kornia.geometry.conversions import convert_points_to_homogeneous
 
 # --- METRICS ---
 
+
 def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
     # angle error between 2 vectors
     t_gt = T_0to1[:3, 3]
@@ -21,7 +22,7 @@ def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
     # angle error between 2 rotation matrices
     R_gt = T_0to1[:3, :3]
     cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
-    cos = np.clip(cos, -1., 1.)  # handle numercial errors
+    cos = np.clip(cos, -1.0, 1.0)  # handle numerical errors
     R_err = np.rad2deg(np.abs(np.arccos(cos)))
 
     return t_err, R_err
@@ -43,93 +44,108 @@ def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
     p1Ep0 = torch.sum(pts1 * Ep0, -1)  # [N,]
     Etp1 = pts1 @ E  # [N, 3]
 
-    d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2))  # N
+    d = p1Ep0**2 * (
+        1.0 / (Ep0[:, 0] ** 2 + Ep0[:, 1] ** 2)
+        + 1.0 / (Etp1[:, 0] ** 2 + Etp1[:, 1] ** 2)
+    )  # N
     return d
 
 
 def compute_symmetrical_epipolar_errors(data):
-    """ 
+    """
     Update:
         data (dict):{"epi_errs": [M]}
     """
-    Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
-    E_mat = Tx @ data['T_0to1'][:, :3, :3]
+    Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3])
+    E_mat = Tx @ data["T_0to1"][:, :3, :3]
 
-    m_bids = data['m_bids']
-    pts0 = data['mkpts0_f']
-    pts1 = data['mkpts1_f']
+    m_bids = data["m_bids"]
+    pts0 = data["mkpts0_f"]
+    pts1 = data["mkpts1_f"]
 
     epi_errs = []
     for bs in range(Tx.size(0)):
         mask = m_bids == bs
         epi_errs.append(
-            symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
+            symmetric_epipolar_distance(
+                pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs]
+            )
+        )
     epi_errs = torch.cat(epi_errs, dim=0)
 
-    data.update({'epi_errs': epi_errs})
+    data.update({"epi_errs": epi_errs})
+
 
 def compute_symmetrical_epipolar_errors_offset(data):
-    """ 
+    """
     Update:
         data (dict):{"epi_errs": [M]}
     """
-    Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
-    E_mat = Tx @ data['T_0to1'][:, :3, :3]
+    Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3])
+    E_mat = Tx @ data["T_0to1"][:, :3, :3]
 
-    m_bids = data['offset_bids']
-    l_ids=data['offset_lids']
-    pts0 = data['offset_kpts0_f']
-    pts1 = data['offset_kpts1_f']
+    m_bids = data["offset_bids"]
+    l_ids = data["offset_lids"]
+    pts0 = data["offset_kpts0_f"]
+    pts1 = data["offset_kpts1_f"]
 
     epi_errs = []
-    layer_num=data['predict_flow'][0].shape[0]
-  
+    layer_num = data["predict_flow"][0].shape[0]
+
     for bs in range(Tx.size(0)):
         for ls in range(layer_num):
             mask_b = m_bids == bs
             mask_l = l_ids == ls
-            mask=mask_b&mask_l
+            mask = mask_b & mask_l
             epi_errs.append(
-                symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
+                symmetric_epipolar_distance(
+                    pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs]
+                )
+            )
     epi_errs = torch.cat(epi_errs, dim=0)
 
-    data.update({'epi_errs_offset': epi_errs}) #[b*l*n]
+    data.update({"epi_errs_offset": epi_errs})  # [b*l*n]
+
 
 def compute_symmetrical_epipolar_errors_offset_bidirectional(data):
-    """ 
+    """
     Update
         data (dict):{"epi_errs": [M]}
     """
-    _compute_symmetrical_epipolar_errors_offset(data,'left')
-    _compute_symmetrical_epipolar_errors_offset(data,'right')
+    _compute_symmetrical_epipolar_errors_offset(data, "left")
+    _compute_symmetrical_epipolar_errors_offset(data, "right")
 
 
-def _compute_symmetrical_epipolar_errors_offset(data,side):
-    """ 
+def _compute_symmetrical_epipolar_errors_offset(data, side):
+    """
     Update
         data (dict):{"epi_errs": [M]}
     """
-    assert side=='left' or side=='right', 'invalid side'
+    assert side == "left" or side == "right", "invalid side"
 
-    Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
-    E_mat = Tx @ data['T_0to1'][:, :3, :3]
+    Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3])
+    E_mat = Tx @ data["T_0to1"][:, :3, :3]
 
-    m_bids = data['offset_bids_'+side]
-    l_ids=data['offset_lids_'+side]
-    pts0 = data['offset_kpts0_f_'+side]
-    pts1 = data['offset_kpts1_f_'+side]
+    m_bids = data["offset_bids_" + side]
+    l_ids = data["offset_lids_" + side]
+    pts0 = data["offset_kpts0_f_" + side]
+    pts1 = data["offset_kpts1_f_" + side]
 
     epi_errs = []
-    layer_num=data['predict_flow'][0].shape[0]
+    layer_num = data["predict_flow"][0].shape[0]
     for bs in range(Tx.size(0)):
         for ls in range(layer_num):
             mask_b = m_bids == bs
             mask_l = l_ids == ls
-            mask=mask_b&mask_l
+            mask = mask_b & mask_l
             epi_errs.append(
-                symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
+                symmetric_epipolar_distance(
+                    pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs]
+                )
+            )
     epi_errs = torch.cat(epi_errs, dim=0)
-    data.update({'epi_errs_offset_'+side: epi_errs}) #[b*l*n]
+    data.update({"epi_errs_offset_" + side: epi_errs})  # [b*l*n]
+
 
 def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
     if len(kpts0) < 5:
@@ -143,7 +159,8 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
 
     # compute pose with cv2
     E, mask = cv2.findEssentialMat(
-        kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC)
+        kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC
+    )
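+    # An identity camera matrix is passed, so kpts0/kpts1 are assumed to already be in
+    # normalized camera coordinates and ransac_thr is expressed in that space.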
     if E is None:
         print("\nE is None while trying to recover pose.\n")
         return None
@@ -161,7 +178,7 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
 
 
 def compute_pose_errors(data, config):
-    """ 
+    """
     Update:
         data (dict):{
             "R_errs" List[float]: [N]
@@ -171,33 +188,36 @@ def compute_pose_errors(data, config):
     """
     pixel_thr = config.TRAINER.RANSAC_PIXEL_THR  # 0.5
     conf = config.TRAINER.RANSAC_CONF  # 0.99999
-    data.update({'R_errs': [], 't_errs': [], 'inliers': []})
+    data.update({"R_errs": [], "t_errs": [], "inliers": []})
 
-    m_bids = data['m_bids'].cpu().numpy()
-    pts0 = data['mkpts0_f'].cpu().numpy()
-    pts1 = data['mkpts1_f'].cpu().numpy()
-    K0 = data['K0'].cpu().numpy()
-    K1 = data['K1'].cpu().numpy()
-    T_0to1 = data['T_0to1'].cpu().numpy()
+    m_bids = data["m_bids"].cpu().numpy()
+    pts0 = data["mkpts0_f"].cpu().numpy()
+    pts1 = data["mkpts1_f"].cpu().numpy()
+    K0 = data["K0"].cpu().numpy()
+    K1 = data["K1"].cpu().numpy()
+    T_0to1 = data["T_0to1"].cpu().numpy()
 
     for bs in range(K0.shape[0]):
         mask = m_bids == bs
-        ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf)
+        ret = estimate_pose(
+            pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf
+        )
 
         if ret is None:
-            data['R_errs'].append(np.inf)
-            data['t_errs'].append(np.inf)
-            data['inliers'].append(np.array([]).astype(np.bool))
+            data["R_errs"].append(np.inf)
+            data["t_errs"].append(np.inf)
+            data["inliers"].append(np.array([]).astype(np.bool))
         else:
             R, t, inliers = ret
             t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
-            data['R_errs'].append(R_err)
-            data['t_errs'].append(t_err)
-            data['inliers'].append(inliers)
+            data["R_errs"].append(R_err)
+            data["t_errs"].append(t_err)
+            data["inliers"].append(inliers)
 
 
 # --- METRIC AGGREGATION ---
 
+
 def error_auc(errors, thresholds):
     """
     Args:
@@ -211,14 +231,14 @@ def error_auc(errors, thresholds):
     thresholds = [5, 10, 20]
     for thr in thresholds:
         last_index = np.searchsorted(errors, thr)
-        y = recall[:last_index] + [recall[last_index-1]]
+        y = recall[:last_index] + [recall[last_index - 1]]
         x = errors[:last_index] + [thr]
         aucs.append(np.trapz(y, x) / thr)
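+        # AUC = area under the recall-vs-error curve up to thr (trapezoidal rule), normalized by thr.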
 
-    return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
+    return {f"auc@{t}": auc for t, auc in zip(thresholds, aucs)}
 
 
-def epidist_prec(errors, thresholds, ret_dict=False,offset=False):
+def epidist_prec(errors, thresholds, ret_dict=False, offset=False):
     precs = []
     for thr in thresholds:
         prec_ = []
@@ -227,34 +247,47 @@ def epidist_prec(errors, thresholds, ret_dict=False,offset=False):
             prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
         precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
     if ret_dict:
-        return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} if not offset else {f'prec_flow@{t:.0e}': prec for t, prec in zip(thresholds, precs)} 
+        return (
+            {f"prec@{t:.0e}": prec for t, prec in zip(thresholds, precs)}
+            if not offset
+            else {f"prec_flow@{t:.0e}": prec for t, prec in zip(thresholds, precs)}
+        )
     else:
         return precs
 
 
 def aggregate_metrics(metrics, epi_err_thr=5e-4):
-    """ Aggregate metrics for the whole dataset:
+    """Aggregate metrics for the whole dataset:
     (This method should be called once per dataset)
     1. AUC of the pose error (angular) at the threshold [5, 10, 20]
     2. Mean matching precision at the threshold 5e-4 (ScanNet), 1e-4 (MegaDepth)
     """
     # filter duplicates
-    unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers']))
+    unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics["identifiers"]))
     unq_ids = list(unq_ids.values())
-    logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...')
+    logger.info(f"Aggregating metrics over {len(unq_ids)} unique items...")
 
     # pose auc
     angular_thresholds = [5, 10, 20]
-    pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids]
+    pose_errors = np.max(np.stack([metrics["R_errs"], metrics["t_errs"]]), axis=0)[
+        unq_ids
+    ]
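+    # Pose error per pair is the maximum of the rotation and translation angular errors.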
     aucs = error_auc(pose_errors, angular_thresholds)  # (auc@5, auc@10, auc@20)
 
     # matching precision
     dist_thresholds = [epi_err_thr]
-    precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True)  # (prec@err_thr)
-    
-    #offset precision
+    precs = epidist_prec(
+        np.array(metrics["epi_errs"], dtype=object)[unq_ids], dist_thresholds, True
+    )  # (prec@err_thr)
+
+    # offset precision
     try:
-        precs_offset = epidist_prec(np.array(metrics['epi_errs_offset'], dtype=object)[unq_ids], [2e-3], True,offset=True) 
-        return {**aucs, **precs,**precs_offset}
+        precs_offset = epidist_prec(
+            np.array(metrics["epi_errs_offset"], dtype=object)[unq_ids],
+            [2e-3],
+            True,
+            offset=True,
+        )
+        return {**aucs, **precs, **precs_offset}
     except:
         return {**aucs, **precs}
diff --git a/third_party/ASpanFormer/src/utils/misc.py b/third_party/ASpanFormer/src/utils/misc.py
index 25e4433f5ffa41adc4c0435cfe2b5696e43b58b3..d9b6a4a5f5920cde89bdecbf2a444aaea8ff51f3 100644
--- a/third_party/ASpanFormer/src/utils/misc.py
+++ b/third_party/ASpanFormer/src/utils/misc.py
@@ -11,6 +11,7 @@ from pytorch_lightning.utilities import rank_zero_only
 import cv2
 import numpy as np
 
+
 def lower_config(yacs_cfg):
     if not isinstance(yacs_cfg, CN):
         return yacs_cfg
@@ -25,7 +26,7 @@ def upper_config(dict_cfg):
 
 def log_on(condition, message, level):
     if condition:
-        assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']
+        assert level in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]
         logger.log(level, message)
 
 
@@ -35,32 +36,35 @@ def get_rank_zero_only_logger(logger: _Logger):
     else:
         for _level in logger._core.levels.keys():
             level = _level.lower()
-            setattr(logger, level,
-                    lambda x: None)
+            setattr(logger, level, lambda x: None)
         logger._log = lambda x: None
     return logger
 
 
 def setup_gpus(gpus: Union[str, int]) -> int:
-    """ A temporary fix for pytorch-lighting 1.3.x """
+    """A temporary fix for pytorch-lighting 1.3.x"""
     gpus = str(gpus)
     gpu_ids = []
-    
-    if ',' not in gpus:
+
+    if "," not in gpus:
         n_gpus = int(gpus)
         return n_gpus if n_gpus != -1 else torch.cuda.device_count()
     else:
-        gpu_ids = [i.strip() for i in gpus.split(',') if i != '']
-    
+        gpu_ids = [i.strip() for i in gpus.split(",") if i != ""]
+
     # setup environment variables
-    visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
     if visible_devices is None:
         os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
-        visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
-        logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
+        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
+        visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
+        logger.warning(
+            f"[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}"
+        )
     else:
-        logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
+        logger.warning(
+            "[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process."
+        )
     return len(gpu_ids)
 
 
@@ -71,11 +75,11 @@ def flattenList(x):
 @contextlib.contextmanager
 def tqdm_joblib(tqdm_object):
     """Context manager to patch joblib to report into tqdm progress bar given as argument
-    
+
     Usage:
         with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
             Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
-            
+
     When iterating over a generator, direct use of tqdm is also a solution (but it monitors task queuing instead of completion)
         ret_vals = Parallel(n_jobs=args.world_size)(
                     delayed(lambda x: _compute_cov_score(pid, *x))(param)
@@ -84,6 +88,7 @@ def tqdm_joblib(tqdm_object):
                                           total=len(image_ids)*(len(image_ids)-1)/2))
     Src: https://stackoverflow.com/a/58936697
     """
+
     class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
@@ -101,39 +106,79 @@ def tqdm_joblib(tqdm_object):
         tqdm_object.close()
 
 
-def draw_points(img,points,color=(0,255,0),radius=3):
+def draw_points(img, points, color=(0, 255, 0), radius=3):
     dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
     for i in range(points.shape[0]):
-        cv2.circle(img, dp[i],radius=radius,color=color)
+        cv2.circle(img, dp[i], radius=radius, color=color)
     return img
-    
 
-def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None):
+
+def draw_match(
+    img1,
+    img2,
+    corr1,
+    corr2,
+    inlier=[True],
+    color=None,
+    radius1=1,
+    radius2=1,
+    resize=None,
+):
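+    # Draw point correspondences between img1 and img2. A single color is rendered via
+    # cv2.drawMatches; per-match colors fall back to concatenating the images and drawing lines manually.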
     if resize is not None:
-        scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]]
-        img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA) 
-        corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis]
-    corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])]
-    corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])]
+        scale1, scale2 = [img1.shape[1] / resize[0], img1.shape[0] / resize[1]], [
+            img2.shape[1] / resize[0],
+            img2.shape[0] / resize[1],
+        ]
+        img1, img2 = cv2.resize(img1, resize, interpolation=cv2.INTER_AREA), cv2.resize(
+            img2, resize, interpolation=cv2.INTER_AREA
+        )
+        corr1, corr2 = (
+            corr1 / np.asarray(scale1)[np.newaxis],
+            corr2 / np.asarray(scale2)[np.newaxis],
+        )
+    corr1_key = [
+        cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])
+    ]
+    corr2_key = [
+        cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])
+    ]
 
     assert len(corr1) == len(corr2)
 
     draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
     if color is None:
-        color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier]
-    if len(color)==1:
-        display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None,
-                              matchColor=color[0],
-                              singlePointColor=color[0],
-                              flags=4
-                              )
+        color = [(0, 255, 0) if cur_inlier else (0, 0, 255) for cur_inlier in inlier]
+    if len(color) == 1:
+        display = cv2.drawMatches(
+            img1,
+            corr1_key,
+            img2,
+            corr2_key,
+            draw_matches,
+            None,
+            matchColor=color[0],
+            singlePointColor=color[0],
+            flags=4,
+        )
     else:
-        height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1]
-        display=np.zeros([height,width,3],np.uint8)
-        display[:img1.shape[0],:img1.shape[1]]=img1
-        display[:img2.shape[0],img1.shape[1]:]=img2
+        height, width = max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1]
+        display = np.zeros([height, width, 3], np.uint8)
+        display[: img1.shape[0], : img1.shape[1]] = img1
+        display[: img2.shape[0], img1.shape[1] :] = img2
         for i in range(len(corr1)):
-            left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1])
-            cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2]))
-            cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA)
+            left_x, left_y, right_x, right_y = (
+                int(corr1[i][0]),
+                int(corr1[i][1]),
+                int(corr2[i][0] + img1.shape[1]),
+                int(corr2[i][1]),
+            )
+            cur_color = (int(color[i][0]), int(color[i][1]), int(color[i][2]))
+            cv2.line(
+                display,
+                (left_x, left_y),
+                (right_x, right_y),
+                cur_color,
+                1,
+                lineType=cv2.LINE_AA,
+            )
     return display
diff --git a/third_party/ASpanFormer/src/utils/plotting.py b/third_party/ASpanFormer/src/utils/plotting.py
index 8696880237b6ad9fe48d3c1fc44ed13b691a6c4d..0ca3ef0a336a652e7ca910a5584227da043ac019 100644
--- a/third_party/ASpanFormer/src/utils/plotting.py
+++ b/third_party/ASpanFormer/src/utils/plotting.py
@@ -4,38 +4,51 @@ import matplotlib.pyplot as plt
 import matplotlib
 from copy import deepcopy
 
+
 def _compute_conf_thresh(data):
-    dataset_name = data['dataset_name'][0].lower()
-    if dataset_name == 'scannet':
+    dataset_name = data["dataset_name"][0].lower()
+    if dataset_name == "scannet":
         thr = 5e-4
-    elif dataset_name == 'megadepth' or dataset_name=='gl3d':
+    elif dataset_name == "megadepth" or dataset_name == "gl3d":
         thr = 1e-4
     else:
-        raise ValueError(f'Unknown dataset: {dataset_name}')
+        raise ValueError(f"Unknown dataset: {dataset_name}")
     return thr
 
 
 # --- VISUALIZATION --- #
 
+
 def make_matching_figure(
-        img0, img1, mkpts0, mkpts1, color,
-        kpts0=None, kpts1=None, text=[], dpi=75, path=None):
+    img0,
+    img1,
+    mkpts0,
+    mkpts1,
+    color,
+    kpts0=None,
+    kpts1=None,
+    text=[],
+    dpi=75,
+    path=None,
+):
     # draw image pair
-    assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
+    assert (
+        mkpts0.shape[0] == mkpts1.shape[0]
+    ), f"mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}"
     fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
-    axes[0].imshow(img0, cmap='gray')
-    axes[1].imshow(img1, cmap='gray')
-    for i in range(2):   # clear all frames
+    axes[0].imshow(img0, cmap="gray")
+    axes[1].imshow(img1, cmap="gray")
+    for i in range(2):  # clear all frames
         axes[i].get_yaxis().set_ticks([])
         axes[i].get_xaxis().set_ticks([])
         for spine in axes[i].spines.values():
             spine.set_visible(False)
     plt.tight_layout(pad=1)
-    
+
     if kpts0 is not None:
         assert kpts1 is not None
-        axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
-        axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
+        axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=2)
+        axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=2)
 
     # draw matches
     if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
@@ -43,164 +56,181 @@ def make_matching_figure(
         transFigure = fig.transFigure.inverted()
         fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
         fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
-        fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
-                                            (fkpts0[i, 1], fkpts1[i, 1]),
-                                            transform=fig.transFigure, c=color[i], linewidth=1)
-                                        for i in range(len(mkpts0))]
-        
+        fig.lines = [
+            matplotlib.lines.Line2D(
+                (fkpts0[i, 0], fkpts1[i, 0]),
+                (fkpts0[i, 1], fkpts1[i, 1]),
+                transform=fig.transFigure,
+                c=color[i],
+                linewidth=1,
+            )
+            for i in range(len(mkpts0))
+        ]
+
         axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
         axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
 
     # put txts
-    txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
+    txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
     fig.text(
-        0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
-        fontsize=15, va='top', ha='left', color=txt_color)
+        0.01,
+        0.99,
+        "\n".join(text),
+        transform=fig.axes[0].transAxes,
+        fontsize=15,
+        va="top",
+        ha="left",
+        color=txt_color,
+    )
 
     # save or return figure
     if path:
-        plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+        plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
         plt.close()
     else:
         return fig
 
 
-def _make_evaluation_figure(data, b_id, alpha='dynamic'):
-    b_mask = data['m_bids'] == b_id
+def _make_evaluation_figure(data, b_id, alpha="dynamic"):
+    b_mask = data["m_bids"] == b_id
     conf_thr = _compute_conf_thresh(data)
-    
-    img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
-    img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
-    kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
-    kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
-    
+
+    img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
+    kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
+
     # for megadepth, we visualize matches on the resized image
-    if 'scale0' in data:
-        kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
-        kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
-    epi_errs = data['epi_errs'][b_mask].cpu().numpy()
+    if "scale0" in data:
+        kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]]
+        kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]]
+    epi_errs = data["epi_errs"][b_mask].cpu().numpy()
     correct_mask = epi_errs < conf_thr
     precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
     n_correct = np.sum(correct_mask)
-    n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
+    n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
     recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
     # recall might be larger than 1, since the calculation of conf_matrix_gt
     # uses groundtruth depths and camera poses, but epipolar distance is used here.
 
     # matching info
-    if alpha == 'dynamic':
+    if alpha == "dynamic":
         alpha = dynamic_alpha(len(correct_mask))
     color = error_colormap(epi_errs, conf_thr, alpha=alpha)
-    
+
     text = [
-        f'#Matches {len(kpts0)}',
-        f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
-        f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
+        f"#Matches {len(kpts0)}",
+        f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
+        f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
     ]
-    
+
     # make the figure
-    figure = make_matching_figure(img0, img1, kpts0, kpts1,
-                                  color, text=text)
+    figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
     return figure
 
-def _make_evaluation_figure_offset(data, b_id, alpha='dynamic',side=''):
-    layer_num=data['predict_flow'][0].shape[0]
 
-    b_mask = data['offset_bids'+side] == b_id
-    conf_thr = 2e-3 #hardcode for scannet(coarse level)
-    img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
-    img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
-    
-    figure_list=[]
-    #draw offset matches in different layers
+def _make_evaluation_figure_offset(data, b_id, alpha="dynamic", side=""):
+    layer_num = data["predict_flow"][0].shape[0]
+
+    b_mask = data["offset_bids" + side] == b_id
+    conf_thr = 2e-3  # hardcoded for ScanNet (coarse level)
+    img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+
+    figure_list = []
+    # draw offset matches in different layers
     for layer_index in range(layer_num):
-        l_mask=data['offset_lids'+side]==layer_index
-        mask=l_mask&b_mask
-        kpts0 = data['offset_kpts0_f'+side][mask].cpu().numpy()
-        kpts1 = data['offset_kpts1_f'+side][mask].cpu().numpy()
-        
-        epi_errs = data['epi_errs_offset'+side][mask].cpu().numpy()
+        l_mask = data["offset_lids" + side] == layer_index
+        mask = l_mask & b_mask
+        kpts0 = data["offset_kpts0_f" + side][mask].cpu().numpy()
+        kpts1 = data["offset_kpts1_f" + side][mask].cpu().numpy()
+
+        epi_errs = data["epi_errs_offset" + side][mask].cpu().numpy()
         correct_mask = epi_errs < conf_thr
-        
+
         precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
         n_correct = np.sum(correct_mask)
-        n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
+        n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
         recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
         # recall might be larger than 1, since the calculation of conf_matrix_gt
         # uses groundtruth depths and camera poses, but epipolar distance is used here.
 
         # matching info
-        if alpha == 'dynamic':
+        if alpha == "dynamic":
             alpha = dynamic_alpha(len(correct_mask))
         color = error_colormap(epi_errs, conf_thr, alpha=alpha)
-        
+
         text = [
-            f'#Matches {len(kpts0)}',
-            f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
-            f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
+            f"#Matches {len(kpts0)}",
+            f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
+            f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
         ]
-        
+
         # make the figure
-        #import pdb;pdb.set_trace()
-        figure = make_matching_figure(deepcopy(img0), deepcopy(img1) , kpts0, kpts1,
-                                    color, text=text)
+        # import pdb;pdb.set_trace()
+        figure = make_matching_figure(
+            deepcopy(img0), deepcopy(img1), kpts0, kpts1, color, text=text
+        )
         figure_list.append(figure)
     return figure
 
+
 def _make_confidence_figure(data, b_id):
     # TODO: Implement confidence figure
     raise NotImplementedError()
 
 
-def make_matching_figures(data, config, mode='evaluation'):
-    """ Make matching figures for a batch.
-    
+def make_matching_figures(data, config, mode="evaluation"):
+    """Make matching figures for a batch.
+
     Args:
         data (Dict): a batch updated by PL_LoFTR.
         config (Dict): matcher config
     Returns:
         figures (Dict[str, List[plt.figure]])
     """
-    assert mode in ['evaluation', 'confidence']  # 'confidence'
+    assert mode in ["evaluation", "confidence"]  # 'confidence'
     figures = {mode: []}
-    for b_id in range(data['image0'].size(0)):
-        if mode == 'evaluation':
+    for b_id in range(data["image0"].size(0)):
+        if mode == "evaluation":
             fig = _make_evaluation_figure(
-                data, b_id,
-                alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
-        elif mode == 'confidence':
+                data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA
+            )
+        elif mode == "confidence":
             fig = _make_confidence_figure(data, b_id)
         else:
-            raise ValueError(f'Unknown plot mode: {mode}')
+            raise ValueError(f"Unknown plot mode: {mode}")
     figures[mode].append(fig)
     return figures
 
-def make_matching_figures_offset(data, config, mode='evaluation',side=''):
-    """ Make matching figures for a batch.
-    
+
+def make_matching_figures_offset(data, config, mode="evaluation", side=""):
+    """Make matching figures for a batch.
+
     Args:
         data (Dict): a batch updated by PL_LoFTR.
         config (Dict): matcher config
     Returns:
         figures (Dict[str, List[plt.figure]])
     """
-    assert mode in ['evaluation', 'confidence']  # 'confidence'
+    assert mode in ["evaluation", "confidence"]  # 'confidence'
     figures = {mode: []}
-    for b_id in range(data['image0'].size(0)):
-        if mode == 'evaluation':
+    for b_id in range(data["image0"].size(0)):
+        if mode == "evaluation":
             fig = _make_evaluation_figure_offset(
-                data, b_id,
-                alpha=config.TRAINER.PLOT_MATCHES_ALPHA,side=side)
-        elif mode == 'confidence':
+                data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA, side=side
+            )
+        elif mode == "confidence":
             fig = _make_evaluation_figure_offset(data, b_id)
         else:
-            raise ValueError(f'Unknown plot mode: {mode}')
+            raise ValueError(f"Unknown plot mode: {mode}")
         figures[mode].append(fig)
     return figures
 
-def dynamic_alpha(n_matches,
-                  milestones=[0, 300, 1000, 2000],
-                  alphas=[1.0, 0.8, 0.4, 0.2]):
+
+def dynamic_alpha(
+    n_matches, milestones=[0, 300, 1000, 2000], alphas=[1.0, 0.8, 0.4, 0.2]
+):
     if n_matches == 0:
         return 1.0
     ranges = list(zip(alphas, alphas[1:] + [None]))
@@ -209,11 +239,15 @@ def dynamic_alpha(n_matches,
     if _range[1] is None:
         return _range[0]
     return _range[1] + (milestones[loc + 1] - n_matches) / (
-        milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
+        milestones[loc + 1] - milestones[loc]
+    ) * (_range[0] - _range[1])
 
 
 def error_colormap(err, thr, alpha=1.0):
     assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
     x = 1 - np.clip(err / (thr * 2), 0, 1)
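+    # Color ramp: green at zero error, yellow around thr, red for err >= 2 * thr; alpha controls opacity.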
     return np.clip(
-        np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
+        np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1),
+        0,
+        1,
+    )
diff --git a/third_party/ASpanFormer/src/utils/profiler.py b/third_party/ASpanFormer/src/utils/profiler.py
index 6d21ed79fb506ef09c75483355402c48a195aaa9..0275ea34e3eb9cceb4ed809bebeda209749f5bc5 100644
--- a/third_party/ASpanFormer/src/utils/profiler.py
+++ b/third_party/ASpanFormer/src/utils/profiler.py
@@ -7,7 +7,7 @@ from pytorch_lightning.utilities import rank_zero_only
 class InferenceProfiler(SimpleProfiler):
     """
     This profiler records duration of actions with cuda.synchronize()
-    Use this in test time. 
+    Use this in test time.
     """
 
     def __init__(self):
@@ -28,12 +28,13 @@ class InferenceProfiler(SimpleProfiler):
 
 
 def build_profiler(name):
-    if name == 'inference':
+    if name == "inference":
         return InferenceProfiler()
-    elif name == 'pytorch':
+    elif name == "pytorch":
         from pytorch_lightning.profiler import PyTorchProfiler
+
         return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
     elif name is None:
         return PassThroughProfiler()
     else:
-        raise ValueError(f'Invalid profiler: {name}')
+        raise ValueError(f"Invalid profiler: {name}")
diff --git a/third_party/ASpanFormer/test.py b/third_party/ASpanFormer/test.py
index 541ce84662ab4888c6fece30403c5c9983118637..bed3060d931d2f9e5d60ef3b0eb6a9016322fa0f 100644
--- a/third_party/ASpanFormer/test.py
+++ b/third_party/ASpanFormer/test.py
@@ -10,33 +10,52 @@ from src.lightning.data import MultiSceneDataModule
 from src.lightning.lightning_aspanformer import PL_ASpanFormer
 import torch
 
+
 def parse_args():
     # init a custom parser which will be added into the pl.Trainer parser
     # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument(
-        'data_cfg_path', type=str, help='data config path')
-    parser.add_argument(
-        'main_cfg_path', type=str, help='main config path')
-    parser.add_argument(
-        '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint')
-    parser.add_argument(
-        '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir")
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("data_cfg_path", type=str, help="data config path")
+    parser.add_argument("main_cfg_path", type=str, help="main config path")
     parser.add_argument(
-        '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset')
+        "--ckpt_path",
+        type=str,
+        default="weights/indoor_ds.ckpt",
+        help="path to the checkpoint",
+    )
     parser.add_argument(
-        '--batch_size', type=int, default=1, help='batch_size per gpu')
+        "--dump_dir",
+        type=str,
+        default=None,
+        help="if set, the matching results will be dump to dump_dir",
+    )
     parser.add_argument(
-        '--num_workers', type=int, default=2)
+        "--profiler_name",
+        type=str,
+        default=None,
+        help="options: [inference, pytorch], or leave it unset",
+    )
+    parser.add_argument("--batch_size", type=int, default=1, help="batch_size per gpu")
+    parser.add_argument("--num_workers", type=int, default=2)
     parser.add_argument(
-        '--thr', type=float, default=None, help='modify the coarse-level matching threshold.')
+        "--thr",
+        type=float,
+        default=None,
+        help="modify the coarse-level matching threshold.",
+    )
     parser.add_argument(
-        '--mode', type=str, default='vanilla', help='modify the coarse-level matching threshold.')
+        "--mode",
+        type=str,
+        default="vanilla",
+        help="modify the coarse-level matching threshold.",
+    )
     parser = pl.Trainer.add_argparse_args(parser)
     return parser.parse_args()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # parse arguments
     args = parse_args()
     pprint.pprint(vars(args))
@@ -55,7 +74,12 @@ if __name__ == '__main__':
 
     # lightning module
     profiler = build_profiler(args.profiler_name)
-    model = PL_ASpanFormer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir)
+    model = PL_ASpanFormer(
+        config,
+        pretrained_ckpt=args.ckpt_path,
+        profiler=profiler,
+        dump_dir=args.dump_dir,
+    )
     loguru_logger.info(f"ASpanFormer-lightning initialized!")
 
     # lightning data
@@ -63,7 +87,9 @@ if __name__ == '__main__':
     loguru_logger.info(f"DataModule initialized!")
 
     # lightning trainer
-    trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False)
+    trainer = pl.Trainer.from_argparse_args(
+        args, replace_sampler_ddp=False, logger=False
+    )
 
     loguru_logger.info(f"Start testing!")
     trainer.test(model, datamodule=data_module, verbose=False)
diff --git a/third_party/ASpanFormer/tools/extract.py b/third_party/ASpanFormer/tools/extract.py
index 12f55e2f94120d5765f124f8eec867f1d82e0aa7..b3dea56a14f6c100b2c53978678bab69a656cdeb 100644
--- a/third_party/ASpanFormer/tools/extract.py
+++ b/third_party/ASpanFormer/tools/extract.py
@@ -5,43 +5,77 @@ from tqdm import tqdm
 from multiprocessing import Pool
 from functools import partial
 
-scannet_dir='/root/data/ScanNet-v2-1.0.0/data/raw'
-dump_dir='/root/data/scannet_dump'
-num_process=32
-
-def extract(seq,scannet_dir,split,dump_dir):
-    assert split=='train' or split=='test'
-    if not os.path.exists(os.path.join(dump_dir,split,seq)):
-            os.mkdir(os.path.join(dump_dir,split,seq))
-    cmd='python reader.py --filename '+os.path.join(scannet_dir,'scans' if split=='train' else 'scans_test',seq,seq+'.sens')+' --output_path '+os.path.join(dump_dir,split,seq)+\
-            ' --export_depth_images --export_color_images --export_poses --export_intrinsics'
+scannet_dir = "/root/data/ScanNet-v2-1.0.0/data/raw"
+dump_dir = "/root/data/scannet_dump"
+num_process = 32
+
+
+def extract(seq, scannet_dir, split, dump_dir):
+    assert split == "train" or split == "test"
+    if not os.path.exists(os.path.join(dump_dir, split, seq)):
+        os.mkdir(os.path.join(dump_dir, split, seq))
+    cmd = (
+        "python reader.py --filename "
+        + os.path.join(
+            scannet_dir,
+            "scans" if split == "train" else "scans_test",
+            seq,
+            seq + ".sens",
+        )
+        + " --output_path "
+        + os.path.join(dump_dir, split, seq)
+        + " --export_depth_images --export_color_images --export_poses --export_intrinsics"
+    )
     os.system(cmd)
 
-if __name__=='__main__':
+
+if __name__ == "__main__":
     if not os.path.exists(dump_dir):
         os.mkdir(dump_dir)
-        os.mkdir(os.path.join(dump_dir,'train'))
-        os.mkdir(os.path.join(dump_dir,'test'))
+        os.mkdir(os.path.join(dump_dir, "train"))
+        os.mkdir(os.path.join(dump_dir, "test"))
 
-    train_seq_list=[seq.split('/')[-1] for seq in glob.glob(os.path.join(scannet_dir,'scans','scene*'))]
-    test_seq_list=[seq.split('/')[-1] for seq in glob.glob(os.path.join(scannet_dir,'scans_test','scene*'))]
+    train_seq_list = [
+        seq.split("/")[-1]
+        for seq in glob.glob(os.path.join(scannet_dir, "scans", "scene*"))
+    ]
+    test_seq_list = [
+        seq.split("/")[-1]
+        for seq in glob.glob(os.path.join(scannet_dir, "scans_test", "scene*"))
+    ]
 
-    extract_train=partial(extract,scannet_dir=scannet_dir,split='train',dump_dir=dump_dir)
-    extract_test=partial(extract,scannet_dir=scannet_dir,split='test',dump_dir=dump_dir)
+    extract_train = partial(
+        extract, scannet_dir=scannet_dir, split="train", dump_dir=dump_dir
+    )
+    extract_test = partial(
+        extract, scannet_dir=scannet_dir, split="test", dump_dir=dump_dir
+    )
 
-    num_train_iter=len(train_seq_list)//num_process if len(train_seq_list)%num_process==0 else len(train_seq_list)//num_process+1
-    num_test_iter=len(test_seq_list)//num_process if len(test_seq_list)%num_process==0 else len(test_seq_list)//num_process+1
+    num_train_iter = (
+        len(train_seq_list) // num_process
+        if len(train_seq_list) % num_process == 0
+        else len(train_seq_list) // num_process + 1
+    )
+    num_test_iter = (
+        len(test_seq_list) // num_process
+        if len(test_seq_list) % num_process == 0
+        else len(test_seq_list) // num_process + 1
+    )
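+    # Both are ceil(len(seq_list) / num_process): the number of batches dispatched to the worker pool.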
 
     pool = Pool(num_process)
     for index in tqdm(range(num_train_iter)):
-        seq_list=train_seq_list[index*num_process:min((index+1)*num_process,len(train_seq_list))]
-        pool.map(extract_train,seq_list)
+        seq_list = train_seq_list[
+            index * num_process : min((index + 1) * num_process, len(train_seq_list))
+        ]
+        pool.map(extract_train, seq_list)
     pool.close()
     pool.join()
-    
+
     pool = Pool(num_process)
     for index in tqdm(range(num_test_iter)):
-        seq_list=test_seq_list[index*num_process:min((index+1)*num_process,len(test_seq_list))]
-        pool.map(extract_test,seq_list)
+        seq_list = test_seq_list[
+            index * num_process : min((index + 1) * num_process, len(test_seq_list))
+        ]
+        pool.map(extract_test, seq_list)
     pool.close()
-    pool.join()
\ No newline at end of file
+    pool.join()
diff --git a/third_party/ASpanFormer/tools/preprocess_scene.py b/third_party/ASpanFormer/tools/preprocess_scene.py
index d20c0d070243519d67bbd25668ff5eb1657474be..5364058829b7e45eabd61a32a591711645fc1ded 100644
--- a/third_party/ASpanFormer/tools/preprocess_scene.py
+++ b/third_party/ASpanFormer/tools/preprocess_scene.py
@@ -6,78 +6,63 @@ import numpy as np
 
 import os
 
-parser = argparse.ArgumentParser(description='MegaDepth preprocessing script')
+parser = argparse.ArgumentParser(description="MegaDepth preprocessing script")
 
-parser.add_argument(
-    '--base_path', type=str, required=True,
-    help='path to MegaDepth'
-)
-parser.add_argument(
-    '--scene_id', type=str, required=True,
-    help='scene ID'
-)
+parser.add_argument("--base_path", type=str, required=True, help="path to MegaDepth")
+parser.add_argument("--scene_id", type=str, required=True, help="scene ID")
 
 parser.add_argument(
-    '--output_path', type=str, required=True,
-    help='path to the output directory'
+    "--output_path", type=str, required=True, help="path to the output directory"
 )
 
 args = parser.parse_args()
 
 base_path = args.base_path
 # Remove the trailing / if need be.
-if base_path[-1] in ['/', '\\']:
-    base_path = base_path[: - 1]
+if base_path[-1] in ["/", "\\"]:
+    base_path = base_path[:-1]
 scene_id = args.scene_id
 
-base_depth_path = os.path.join(
-    base_path, 'phoenix/S6/zl548/MegaDepth_v1'
-)
-base_undistorted_sfm_path = os.path.join(
-    base_path, 'Undistorted_SfM'
-)
+base_depth_path = os.path.join(base_path, "phoenix/S6/zl548/MegaDepth_v1")
+base_undistorted_sfm_path = os.path.join(base_path, "Undistorted_SfM")
 
 undistorted_sparse_path = os.path.join(
-    base_undistorted_sfm_path, scene_id, 'sparse-txt'
+    base_undistorted_sfm_path, scene_id, "sparse-txt"
 )
 if not os.path.exists(undistorted_sparse_path):
     exit()
 
-depths_path = os.path.join(
-    base_depth_path, scene_id, 'dense0', 'depths'
-)
+depths_path = os.path.join(base_depth_path, scene_id, "dense0", "depths")
 if not os.path.exists(depths_path):
     exit()
 
-images_path = os.path.join(
-    base_undistorted_sfm_path, scene_id, 'images'
-)
+images_path = os.path.join(base_undistorted_sfm_path, scene_id, "images")
 if not os.path.exists(images_path):
     exit()
 
 # Process cameras.txt
-with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f:
-    raw = f.readlines()[3 :]  # skip the header
+with open(os.path.join(undistorted_sparse_path, "cameras.txt"), "r") as f:
+    raw = f.readlines()[3:]  # skip the header
 
 camera_intrinsics = {}
 for camera in raw:
-    camera = camera.split(' ')
-    camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]]
+    camera = camera.split(" ")
+    camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2:]]
 
 # Process points3D.txt
-with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f:
-    raw = f.readlines()[3 :]  # skip the header
+with open(os.path.join(undistorted_sparse_path, "points3D.txt"), "r") as f:
+    raw = f.readlines()[3:]  # skip the header
 
 points3D = {}
 for point3D in raw:
-    point3D = point3D.split(' ')
-    points3D[int(point3D[0])] = np.array([
-        float(point3D[1]), float(point3D[2]), float(point3D[3])
-    ])
-    
+    point3D = point3D.split(" ")
+    points3D[int(point3D[0])] = np.array(
+        [float(point3D[1]), float(point3D[2]), float(point3D[3])]
+    )
+
 # Process images.txt
-with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f:
-    raw = f.readlines()[4 :]  # skip the header
+with open(os.path.join(undistorted_sparse_path, "images.txt"), "r") as f:
+    raw = f.readlines()[4:]  # skip the header
 
 image_id_to_idx = {}
 image_names = []
@@ -85,19 +70,19 @@ raw_pose = []
 camera = []
 points3D_id_to_2D = []
 n_points3D = []
-for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])):
-    image = image.split(' ')
-    points = points.split(' ')
+for idx, (image, points) in enumerate(zip(raw[::2], raw[1::2])):
+    image = image.split(" ")
+    points = points.split(" ")
 
     image_id_to_idx[int(image[0])] = idx
 
-    image_name = image[-1].strip('\n')
+    image_name = image[-1].strip("\n")
     image_names.append(image_name)
 
-    raw_pose.append([float(elem) for elem in image[1 : -2]])
+    raw_pose.append([float(elem) for elem in image[1:-2]])
     camera.append(int(image[-2]))
     current_points3D_id_to_2D = {}
-    for x, y, point3D_id in zip(points[:: 3], points[1 :: 3], points[2 :: 3]):
+    for x, y, point3D_id in zip(points[::3], points[1::3], points[2::3]):
         if int(point3D_id) == -1:
             continue
         current_points3D_id_to_2D[int(point3D_id)] = [float(x), float(y)]
@@ -110,12 +95,10 @@ image_paths = []
 depth_paths = []
 for image_name in image_names:
     image_path = os.path.join(images_path, image_name)
-   
+
     # Path to the depth file
-    depth_path = os.path.join(
-        depths_path, '%s.h5' % os.path.splitext(image_name)[0]
-    )
-    
+    depth_path = os.path.join(depths_path, "%s.h5" % os.path.splitext(image_name)[0])
+
     if os.path.exists(depth_path):
         # Check if depth map or background / foreground mask
         file_size = os.stat(depth_path).st_size
@@ -152,32 +135,22 @@ for idx, image_name in enumerate(image_names):
     intrinsics.append(K)
 
     image_pose = raw_pose[idx]
-    qvec = image_pose[: 4]
+    qvec = image_pose[:4]
     qvec = qvec / np.linalg.norm(qvec)
     w, x, y, z = qvec
-    R = np.array([
-        [
-            1 - 2 * y * y - 2 * z * z,
-            2 * x * y - 2 * z * w,
-            2 * x * z + 2 * y * w
-        ],
+    R = np.array(
         [
-            2 * x * y + 2 * z * w,
-            1 - 2 * x * x - 2 * z * z,
-            2 * y * z - 2 * x * w
-        ],
-        [
-            2 * x * z - 2 * y * w,
-            2 * y * z + 2 * x * w,
-            1 - 2 * x * x - 2 * y * y
+            [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w],
+            [2 * x * y + 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * x * w],
+            [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y],
         ]
-    ])
+    )
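+    # Rotation matrix from the unit quaternion (w, x, y, z) parsed from images.txt.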
     principal_axis.append(R[2, :])
-    t = image_pose[4 : 7]
+    t = image_pose[4:7]
     # World-to-Camera pose
     current_pose = np.zeros([4, 4])
-    current_pose[: 3, : 3] = R
-    current_pose[: 3, 3] = t
+    current_pose[:3, :3] = R
+    current_pose[:3, 3] = t
     current_pose[3, 3] = 1
     # Camera-to-World pose
     # pose = np.zeros([4, 4])
@@ -185,38 +158,38 @@ for idx, image_name in enumerate(image_names):
     # pose[: 3, 3] = -np.matmul(np.transpose(R), t)
     # pose[3, 3] = 1
     poses.append(current_pose)
-    
+
     current_points3D_id_to_ndepth = {}
     for point3D_id in points3D_id_to_2D[idx].keys():
         p3d = points3D[point3D_id]
-        current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (.5 * (K[0, 0] + K[1, 1])) 
+        current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (
+            0.5 * (K[0, 0] + K[1, 1])
+        )
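+        # Normalized depth: z of the 3D point in this camera, divided by the mean focal length.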
     points3D_id_to_ndepth.append(current_points3D_id_to_ndepth)
 principal_axis = np.array(principal_axis)
-angles = np.rad2deg(np.arccos(
-    np.clip(
-        np.dot(principal_axis, np.transpose(principal_axis)),
-        -1, 1
-    )
-))
+angles = np.rad2deg(
+    np.arccos(np.clip(np.dot(principal_axis, np.transpose(principal_axis)), -1, 1))
+)
 
 # Compute overlap score
-overlap_matrix = np.full([n_images, n_images], -1.)
-scale_ratio_matrix = np.full([n_images, n_images], -1.)
+overlap_matrix = np.full([n_images, n_images], -1.0)
+scale_ratio_matrix = np.full([n_images, n_images], -1.0)
 for idx1 in range(n_images):
     if image_paths[idx1] is None or depth_paths[idx1] is None:
         continue
     for idx2 in range(idx1 + 1, n_images):
         if image_paths[idx2] is None or depth_paths[idx2] is None:
             continue
-        matches = (
-            points3D_id_to_2D[idx1].keys() &
-            points3D_id_to_2D[idx2].keys()
-        )
+        matches = points3D_id_to_2D[idx1].keys() & points3D_id_to_2D[idx2].keys()
         min_num_points3D = min(
             len(points3D_id_to_2D[idx1]), len(points3D_id_to_2D[idx2])
         )
-        overlap_matrix[idx1, idx2] = len(matches) / len(points3D_id_to_2D[idx1])  # min_num_points3D
-        overlap_matrix[idx2, idx1] = len(matches) / len(points3D_id_to_2D[idx2])  # min_num_points3D
+        overlap_matrix[idx1, idx2] = len(matches) / len(
+            points3D_id_to_2D[idx1]
+        )  # min_num_points3D
+        overlap_matrix[idx2, idx1] = len(matches) / len(
+            points3D_id_to_2D[idx2]
+        )  # min_num_points3D
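+        # Overlap score: fraction of one image's observed 3D points that are also visible in the other.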
         if len(matches) == 0:
             continue
         points3D_id_to_ndepth1 = points3D_id_to_ndepth[idx1]
@@ -228,7 +201,7 @@ for idx1 in range(n_images):
         scale_ratio_matrix[idx2, idx1] = min_scale_ratio
 
 np.savez(
-    os.path.join(args.output_path, '%s.npz' % scene_id),
+    os.path.join(args.output_path, "%s.npz" % scene_id),
     image_paths=image_paths,
     depth_paths=depth_paths,
     intrinsics=intrinsics,
@@ -238,5 +211,5 @@ np.savez(
     angles=angles,
     n_points3D=n_points3D,
     points3D_id_to_2D=points3D_id_to_2D,
-    points3D_id_to_ndepth=points3D_id_to_ndepth
-)
\ No newline at end of file
+    points3D_id_to_ndepth=points3D_id_to_ndepth,
+)
diff --git a/third_party/ASpanFormer/tools/reader.py b/third_party/ASpanFormer/tools/reader.py
index f419fbaa8a099fcfede1cea51fcf95a2c1589160..2734a7796ef8235bdbc1be317b6618f3d3185319 100644
--- a/third_party/ASpanFormer/tools/reader.py
+++ b/third_party/ASpanFormer/tools/reader.py
@@ -6,34 +6,45 @@ from SensorData import SensorData
 # params
 parser = argparse.ArgumentParser()
 # data paths
-parser.add_argument('--filename', required=True, help='path to sens file to read')
-parser.add_argument('--output_path', required=True, help='path to output folder')
-parser.add_argument('--export_depth_images', dest='export_depth_images', action='store_true')
-parser.add_argument('--export_color_images', dest='export_color_images', action='store_true')
-parser.add_argument('--export_poses', dest='export_poses', action='store_true')
-parser.add_argument('--export_intrinsics', dest='export_intrinsics', action='store_true')
-parser.set_defaults(export_depth_images=False, export_color_images=False, export_poses=False, export_intrinsics=False)
+parser.add_argument("--filename", required=True, help="path to sens file to read")
+parser.add_argument("--output_path", required=True, help="path to output folder")
+parser.add_argument(
+    "--export_depth_images", dest="export_depth_images", action="store_true"
+)
+parser.add_argument(
+    "--export_color_images", dest="export_color_images", action="store_true"
+)
+parser.add_argument("--export_poses", dest="export_poses", action="store_true")
+parser.add_argument(
+    "--export_intrinsics", dest="export_intrinsics", action="store_true"
+)
+parser.set_defaults(
+    export_depth_images=False,
+    export_color_images=False,
+    export_poses=False,
+    export_intrinsics=False,
+)
 
 opt = parser.parse_args()
 print(opt)
 
 
 def main():
-  if not os.path.exists(opt.output_path):
-    os.makedirs(opt.output_path)
-  # load the data
-  sys.stdout.write('loading %s...' % opt.filename)
-  sd = SensorData(opt.filename)
-  sys.stdout.write('loaded!\n')
-  if opt.export_depth_images:
-    sd.export_depth_images(os.path.join(opt.output_path, 'depth'))
-  if opt.export_color_images:
-    sd.export_color_images(os.path.join(opt.output_path, 'color'))
-  if opt.export_poses:
-    sd.export_poses(os.path.join(opt.output_path, 'pose'))
-  if opt.export_intrinsics:
-    sd.export_intrinsics(os.path.join(opt.output_path, 'intrinsic'))
+    if not os.path.exists(opt.output_path):
+        os.makedirs(opt.output_path)
+    # load the data
+    sys.stdout.write("loading %s..." % opt.filename)
+    sd = SensorData(opt.filename)
+    sys.stdout.write("loaded!\n")
+    if opt.export_depth_images:
+        sd.export_depth_images(os.path.join(opt.output_path, "depth"))
+    if opt.export_color_images:
+        sd.export_color_images(os.path.join(opt.output_path, "color"))
+    if opt.export_poses:
+        sd.export_poses(os.path.join(opt.output_path, "pose"))
+    if opt.export_intrinsics:
+        sd.export_intrinsics(os.path.join(opt.output_path, "intrinsic"))
 
 
-if __name__ == '__main__':
-    main()
\ No newline at end of file
+if __name__ == "__main__":
+    main()
diff --git a/third_party/ASpanFormer/tools/undistort_mega.py b/third_party/ASpanFormer/tools/undistort_mega.py
index 68798ff30e6afa37a0f98571ecfd3f05751868c8..fcd5ff2d77cd45dc9e5cebc48d7a173e31e68caf 100644
--- a/third_party/ASpanFormer/tools/undistort_mega.py
+++ b/third_party/ASpanFormer/tools/undistort_mega.py
@@ -6,28 +6,20 @@ import os
 
 import subprocess
 
-parser = argparse.ArgumentParser(description='MegaDepth Undistortion')
+parser = argparse.ArgumentParser(description="MegaDepth Undistortion")
 
 parser.add_argument(
-    '--colmap_path', type=str,default='/usr/bin/',
-    help='path to colmap executable'
+    "--colmap_path", type=str, default="/usr/bin/", help="path to colmap executable"
 )
 parser.add_argument(
-    '--base_path', type=str,default='/root/MegaDepth',
-    help='path to MegaDepth'
+    "--base_path", type=str, default="/root/MegaDepth", help="path to MegaDepth"
 )
 
 args = parser.parse_args()
 
-sfm_path = os.path.join(
-    args.base_path, 'MegaDepth_v1_SfM'
-)
-base_depth_path = os.path.join(
-    args.base_path, 'phoenix/S6/zl548/MegaDepth_v1'
-)
-output_path = os.path.join(
-    args.base_path, 'Undistorted_SfM'
-)
+sfm_path = os.path.join(args.base_path, "MegaDepth_v1_SfM")
+base_depth_path = os.path.join(args.base_path, "phoenix/S6/zl548/MegaDepth_v1")
+output_path = os.path.join(args.base_path, "Undistorted_SfM")
 
 os.mkdir(output_path)
 
@@ -35,35 +27,45 @@ for scene_name in os.listdir(base_depth_path):
     current_output_path = os.path.join(output_path, scene_name)
     os.mkdir(current_output_path)
 
-    image_path = os.path.join(
-        base_depth_path, scene_name, 'dense0', 'imgs'
-    )
+    image_path = os.path.join(base_depth_path, scene_name, "dense0", "imgs")
     if not os.path.exists(image_path):
         continue
-    
+
     # Find the maximum image size in scene.
     max_image_size = 0
     for image_name in os.listdir(image_path):
         max_image_size = max(
-            max_image_size,
-            max(imagesize.get(os.path.join(image_path, image_name)))
+            max_image_size, max(imagesize.get(os.path.join(image_path, image_name)))
         )
 
     # Undistort the images and update the reconstruction.
-    subprocess.call([
-        os.path.join(args.colmap_path, 'colmap'), 'image_undistorter', 
-        '--image_path', os.path.join(sfm_path, scene_name, 'images'),
-        '--input_path', os.path.join(sfm_path, scene_name, 'sparse', 'manhattan', '0'),
-        '--output_path',  current_output_path,
-        '--max_image_size', str(max_image_size)
-    ])
+    subprocess.call(
+        [
+            os.path.join(args.colmap_path, "colmap"),
+            "image_undistorter",
+            "--image_path",
+            os.path.join(sfm_path, scene_name, "images"),
+            "--input_path",
+            os.path.join(sfm_path, scene_name, "sparse", "manhattan", "0"),
+            "--output_path",
+            current_output_path,
+            "--max_image_size",
+            str(max_image_size),
+        ]
+    )
 
     # Transform the reconstruction to raw text format.
-    sparse_txt_path = os.path.join(current_output_path, 'sparse-txt')
+    sparse_txt_path = os.path.join(current_output_path, "sparse-txt")
     os.mkdir(sparse_txt_path)
-    subprocess.call([
-        os.path.join(args.colmap_path, 'colmap'), 'model_converter',
-        '--input_path', os.path.join(current_output_path, 'sparse'),
-        '--output_path', sparse_txt_path, 
-        '--output_type', 'TXT'
-    ])
\ No newline at end of file
+    subprocess.call(
+        [
+            os.path.join(args.colmap_path, "colmap"),
+            "model_converter",
+            "--input_path",
+            os.path.join(current_output_path, "sparse"),
+            "--output_path",
+            sparse_txt_path,
+            "--output_type",
+            "TXT",
+        ]
+    )
diff --git a/third_party/ASpanFormer/train.py b/third_party/ASpanFormer/train.py
index 21f644763711481e84863ed5d861ec57d95f2d5c..f1aeb79f630932b539500544d4249b1237d06605 100644
--- a/third_party/ASpanFormer/train.py
+++ b/third_party/ASpanFormer/train.py
@@ -23,41 +23,58 @@ loguru_logger = get_rank_zero_only_logger(loguru_logger)
 def parse_args():
     def str2bool(v):
         return v.lower() in ("true", "1")
+
     # init a custom parser which will be added into the pl.Trainer parser
     # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
     parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument(
-        'data_cfg_path', type=str, help='data config path')
-    parser.add_argument(
-        'main_cfg_path', type=str, help='main config path')
-    parser.add_argument(
-        '--exp_name', type=str, default='default_exp_name')
-    parser.add_argument(
-        '--batch_size', type=int, default=4, help='batch_size per gpu')
-    parser.add_argument(
-        '--num_workers', type=int, default=4)
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("data_cfg_path", type=str, help="data config path")
+    parser.add_argument("main_cfg_path", type=str, help="main config path")
+    parser.add_argument("--exp_name", type=str, default="default_exp_name")
+    parser.add_argument("--batch_size", type=int, default=4, help="batch_size per gpu")
+    parser.add_argument("--num_workers", type=int, default=4)
     parser.add_argument(
-        '--pin_memory', type=lambda x: bool(strtobool(x)),
-        nargs='?', default=True, help='whether loading data to pinned memory or not')
+        "--pin_memory",
+        type=lambda x: bool(strtobool(x)),
+        nargs="?",
+        default=True,
+        help="whether loading data to pinned memory or not",
+    )
     parser.add_argument(
-        '--ckpt_path', type=str, default=None,
-        help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer')
+        "--ckpt_path",
+        type=str,
+        default=None,
+        help="pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer",
+    )
     parser.add_argument(
-        '--disable_ckpt', action='store_true',
-        help='disable checkpoint saving (useful for debugging).')
+        "--disable_ckpt",
+        action="store_true",
+        help="disable checkpoint saving (useful for debugging).",
+    )
     parser.add_argument(
-        '--profiler_name', type=str, default=None,
-        help='options: [inference, pytorch], or leave it unset')
+        "--profiler_name",
+        type=str,
+        default=None,
+        help="options: [inference, pytorch], or leave it unset",
+    )
     parser.add_argument(
-        '--parallel_load_data', action='store_true',
-        help='load datasets in with multiple processes.')
+        "--parallel_load_data",
+        action="store_true",
+        help="load datasets in with multiple processes.",
+    )
     parser.add_argument(
-        '--mode', type=str, default='vanilla',
-        help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer')
+        "--mode",
+        type=str,
+        default="vanilla",
+        help="pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer",
+    )
     parser.add_argument(
-        '--ini', type=str2bool, default=False,
-        help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer')
+        "--ini",
+        type=str2bool,
+        default=False,
+        help="pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer",
+    )
 
     parser = pl.Trainer.add_argparse_args(parser)
     return parser.parse_args()
@@ -83,8 +100,7 @@ def main():
     _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
     config.TRAINER.SCALING = _scaling
     config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
-    config.TRAINER.WARMUP_STEP = math.floor(
-        config.TRAINER.WARMUP_STEP / _scaling)
+    config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)
 
     # lightning module
     profiler = build_profiler(args.profiler_name)
@@ -97,16 +113,22 @@ def main():
 
     # TensorBoard Logger
     logger = TensorBoardLogger(
-        save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)
-    ckpt_dir = Path(logger.log_dir) / 'checkpoints'
+        save_dir="logs/tb_logs", name=args.exp_name, default_hp_metric=False
+    )
+    ckpt_dir = Path(logger.log_dir) / "checkpoints"
 
     # Callbacks
     # TODO: update ModelCheckpoint to monitor multiple metrics
-    ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max',
-                                    save_last=True,
-                                    dirpath=str(ckpt_dir),
-                                    filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
-    lr_monitor = LearningRateMonitor(logging_interval='step')
+    ckpt_callback = ModelCheckpoint(
+        monitor="auc@10",
+        verbose=True,
+        save_top_k=5,
+        mode="max",
+        save_last=True,
+        dirpath=str(ckpt_dir),
+        filename="{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}",
+    )
+    lr_monitor = LearningRateMonitor(logging_interval="step")
     callbacks = [lr_monitor]
     if not args.disable_ckpt:
         callbacks.append(ckpt_callback)
@@ -114,21 +136,24 @@ def main():
     # Lightning Trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        plugins=DDPPlugin(find_unused_parameters=False,
-                          num_nodes=args.num_nodes,
-                          sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
+        plugins=DDPPlugin(
+            find_unused_parameters=False,
+            num_nodes=args.num_nodes,
+            sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
+        ),
         gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
         callbacks=callbacks,
         logger=logger,
         sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
         replace_sampler_ddp=False,  # use custom sampler
         reload_dataloaders_every_epoch=False,  # avoid repeated samples!
-        weights_summary='full',
-        profiler=profiler)
+        weights_summary="full",
+        profiler=profiler,
+    )
     loguru_logger.info(f"Trainer initialized!")
     loguru_logger.info(f"Start training!")
     trainer.fit(model, datamodule=data_module)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
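
The learning-rate setup in this script follows the linear-scaling rule: the canonical learning rate is multiplied by the ratio of the effective batch size to the canonical batch size, and the warmup length is divided by the same ratio. A tiny hedged example of that arithmetic, with made-up config values, is below.

import math

canonical_bs, canonical_lr = 64, 8e-3      # assumed values from the training config
world_size, batch_per_gpu = 8, 4           # assumed launch setup
warmup_step = 4800                         # assumed config value

true_batch_size = world_size * batch_per_gpu      # 32
scaling = true_batch_size / canonical_bs          # 0.5
true_lr = canonical_lr * scaling                  # 4e-3
warmup_step = math.floor(warmup_step / scaling)   # 9600
print(true_lr, warmup_step)
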
diff --git a/third_party/DKM/demo/demo_fundamental.py b/third_party/DKM/demo/demo_fundamental.py
index e19766d5d3ce1abf0d18483cbbce71b2696983be..643ae3d62d3d4a09d1eb6f7b351ea23f2095b725 100644
--- a/third_party/DKM/demo/demo_fundamental.py
+++ b/third_party/DKM/demo/demo_fundamental.py
@@ -6,11 +6,12 @@ from dkm.utils.utils import tensor_to_pil
 import cv2
 from dkm import DKMv3_outdoor
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 if __name__ == "__main__":
     from argparse import ArgumentParser
+
     parser = ArgumentParser()
     parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
     parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
@@ -22,7 +23,6 @@ if __name__ == "__main__":
     # Create model
     dkm_model = DKMv3_outdoor(device=device)
 
-
     W_A, H_A = Image.open(im1_path).size
     W_B, H_B = Image.open(im2_path).size
 
@@ -30,8 +30,13 @@ if __name__ == "__main__":
     warp, certainty = dkm_model.match(im1_path, im2_path, device=device)
     # Sample matches for estimation
     matches, certainty = dkm_model.sample(warp, certainty)
-    kpts1, kpts2 = dkm_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)    
+    kpts1, kpts2 = dkm_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
     F, mask = cv2.findFundamentalMat(
-        kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+        kpts1.cpu().numpy(),
+        kpts2.cpu().numpy(),
+        ransacReprojThreshold=0.2,
+        method=cv2.USAC_MAGSAC,
+        confidence=0.999999,
+        maxIters=10000,
     )
-    # TODO: some better visualization    
\ No newline at end of file
+    # TODO: some better visualization
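
The demo stops after estimating F and leaves visualization as a TODO. A hedged sketch of one way to finish it, drawing the MAGSAC inliers on a side-by-side canvas with plain OpenCV calls, is below; it reuses the demo's kpts1, kpts2, mask, im1_path and im2_path, assumes both images share the same height (resize first otherwise), and the output filename is made up.

import cv2

im_A = cv2.imread(im1_path)
im_B = cv2.imread(im2_path)
canvas = cv2.hconcat([im_A, im_B])          # requires equal image heights
offset = im_A.shape[1]                      # width of the left image

pts_A = kpts1.cpu().numpy()
pts_B = kpts2.cpu().numpy()
inliers = mask.ravel().astype(bool)

# Draw a subset of the inlier correspondences as green lines.
for (xa, ya), (xb, yb) in zip(pts_A[inliers][:200], pts_B[inliers][:200]):
    pa = (int(round(xa)), int(round(ya)))
    pb = (int(round(xb)) + offset, int(round(yb)))
    cv2.line(canvas, pa, pb, (0, 255, 0), 1)

cv2.imwrite("demo/dkmv3_fundamental_inliers.jpg", canvas)
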
diff --git a/third_party/DKM/demo/demo_match.py b/third_party/DKM/demo/demo_match.py
index fb901894d8654a884819162d3b9bb8094529e034..aef324e1b19a76498dc0476714149534546e0218 100644
--- a/third_party/DKM/demo/demo_match.py
+++ b/third_party/DKM/demo/demo_match.py
@@ -6,15 +6,18 @@ from dkm.utils.utils import tensor_to_pil
 
 from dkm import DKMv3_outdoor
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 if __name__ == "__main__":
     from argparse import ArgumentParser
+
     parser = ArgumentParser()
     parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
     parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
-    parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str)
+    parser.add_argument(
+        "--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str
+    )
 
     args, _ = parser.parse_known_args()
     im1_path = args.im_A_path
@@ -37,12 +40,12 @@ if __name__ == "__main__":
     x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
 
     im2_transfer_rgb = F.grid_sample(
-    x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+        x2[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
     )[0]
     im1_transfer_rgb = F.grid_sample(
-    x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+        x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
     )[0]
-    warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
-    white_im = torch.ones((H,2*W),device=device)
+    warp_im = torch.cat((im2_transfer_rgb, im1_transfer_rgb), dim=2)
+    white_im = torch.ones((H, 2 * W), device=device)
     vis_im = certainty * warp_im + (1 - certainty) * white_im
     tensor_to_pil(vis_im, unnormalize=False).save(save_path)
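
The warp visualization above relies on the grid_sample coordinate convention: the warp stores (x, y) sampling locations in [-1, 1] for the other image. A tiny self-contained example of that convention, using an identity grid and align_corners=True with made-up shapes, is shown here.

import torch
import torch.nn.functional as F

img = torch.arange(16.0).reshape(1, 1, 4, 4)        # (B, C, H, W)
ys, xs = torch.meshgrid(
    torch.linspace(-1, 1, 4), torch.linspace(-1, 1, 4), indexing="ij"
)
grid = torch.stack((xs, ys), dim=-1)[None]          # (B, H, W, 2) in (x, y) order
out = F.grid_sample(img, grid, mode="bilinear", align_corners=True)
assert torch.allclose(out, img)                     # identity grid reproduces the image
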
diff --git a/third_party/DKM/dkm/__init__.py b/third_party/DKM/dkm/__init__.py
index a9b47632780acc7762bcccc348e2025fe99f3726..27099047d713e61a103bd0f439f292245ad720a3 100644
--- a/third_party/DKM/dkm/__init__.py
+++ b/third_party/DKM/dkm/__init__.py
@@ -1,4 +1,4 @@
 from .models import (
     DKMv3_outdoor,
     DKMv3_indoor,
-    )
+)
diff --git a/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py b/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py
index 9c3febe5ca9e3a683bc7122cec635c4f54b66f7c..719e298726528754c3f826d6d2f2fe2ce9b3b903 100644
--- a/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py
+++ b/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py
@@ -53,7 +53,7 @@ class HpatchesHomogBenchmark:
         )
         return query_coords, query_to_support
 
-    def benchmark(self, model, model_name = None):
+    def benchmark(self, model, model_name=None):
         n_matches = []
         homog_dists = []
         for seq_idx, seq_name in tqdm(
@@ -71,9 +71,7 @@ class HpatchesHomogBenchmark:
                 H = np.loadtxt(
                     os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
                 )
-                dense_matches, dense_certainty = model.match(
-                    im1_path, im2_path
-                )
+                dense_matches, dense_certainty = model.match(im1_path, im2_path)
                 good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
                 pos_a, pos_b = self.convert_coordinates(
                     good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
@@ -82,9 +80,9 @@ class HpatchesHomogBenchmark:
                     H_pred, inliers = cv2.findHomography(
                         pos_a,
                         pos_b,
-                        method = cv2.RANSAC,
-                        confidence = 0.99999,
-                        ransacReprojThreshold = 3 * min(w2, h2) / 480,
+                        method=cv2.RANSAC,
+                        confidence=0.99999,
+                        ransacReprojThreshold=3 * min(w2, h2) / 480,
                     )
                 except:
                     H_pred = None
diff --git a/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py b/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py
index 6b1193745ff18d239165aeb3376642fb17033874..d9499f1e92fd4df3ad6fe59c37b6c881d5322a51 100644
--- a/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py
+++ b/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py
@@ -5,8 +5,9 @@ from PIL import Image
 from tqdm import tqdm
 import torch.nn.functional as F
 
+
 class Megadepth1500Benchmark:
-    def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+    def __init__(self, data_root="data/megadepth", scene_names=None) -> None:
         if scene_names is None:
             self.scene_names = [
                 "0015_0.1_0.3.npz",
@@ -56,28 +57,24 @@ class Megadepth1500Benchmark:
                     K1[:2] = K1[:2] * scale1
                     K2[:2] = K2[:2] * scale2
                     dense_matches, dense_certainty = model.match(im1_path, im2_path)
-                    sparse_matches,_ = model.sample(
+                    sparse_matches, _ = model.sample(
                         dense_matches, dense_certainty, 5000
                     )
                     kpts1 = sparse_matches[:, :2]
-                    kpts1 = (
-                        torch.stack(
-                            (
-                                w1 * (kpts1[:, 0] + 1) / 2,
-                                h1 * (kpts1[:, 1] + 1) / 2,
-                            ),
-                            axis=-1,
-                        )
+                    kpts1 = torch.stack(
+                        (
+                            w1 * (kpts1[:, 0] + 1) / 2,
+                            h1 * (kpts1[:, 1] + 1) / 2,
+                        ),
+                        axis=-1,
                     )
                     kpts2 = sparse_matches[:, 2:]
-                    kpts2 = (
-                        torch.stack(
-                            (
-                                w2 * (kpts2[:, 0] + 1) / 2,
-                                h2 * (kpts2[:, 1] + 1) / 2,
-                            ),
-                            axis=-1,
-                        )
+                    kpts2 = torch.stack(
+                        (
+                            w2 * (kpts2[:, 0] + 1) / 2,
+                            h2 * (kpts2[:, 1] + 1) / 2,
+                        ),
+                        axis=-1,
                     )
                     for _ in range(5):
                         shuffling = np.random.permutation(np.arange(len(kpts1)))
@@ -85,7 +82,9 @@ class Megadepth1500Benchmark:
                         kpts2 = kpts2[shuffling]
                         try:
                             norm_threshold = 0.5 / (
-                            np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                                np.mean(np.abs(K1[:2, :2]))
+                                + np.mean(np.abs(K2[:2, :2]))
+                            )
                             R_est, t_est, mask = estimate_pose(
                                 kpts1.cpu().numpy(),
                                 kpts2.cpu().numpy(),
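
Both MegaDepth benchmarks map the sampled matches from the normalized [-1, 1] range back to pixel coordinates before pose estimation. A hedged helper capturing that convention (the function name and shapes are illustrative, not part of DKM's API) is sketched below.

import torch


def normalized_to_pixel(kpts: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """kpts: (N, 2) with (x, y) in [-1, 1]; returns pixel coordinates."""
    x = w * (kpts[:, 0] + 1) / 2
    y = h * (kpts[:, 1] + 1) / 2
    return torch.stack((x, y), dim=-1)


# Usage mirroring the benchmark: image-1 coords in the first two columns,
# image-2 coords in the last two.
sparse_matches = torch.rand(100, 4) * 2 - 1
kpts1 = normalized_to_pixel(sparse_matches[:, :2], h=480, w=640)
kpts2 = normalized_to_pixel(sparse_matches[:, 2:], h=480, w=640)
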
diff --git a/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py b/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py
index 0b370644497efd62563105e68e692e10ff339669..5e8d597760a82349d043055f5ca867f1f79fc55a 100644
--- a/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py
+++ b/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py
@@ -7,14 +7,16 @@ from torch.utils.data import ConcatDataset
 
 
 class MegadepthDenseBenchmark:
-    def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000, device=None) -> None:
+    def __init__(
+        self, data_root="data/megadepth", h=384, w=512, num_samples=2000, device=None
+    ) -> None:
         mega = MegadepthBuilder(data_root=data_root)
         self.dataset = ConcatDataset(
             mega.build_scenes(split="test_loftr", ht=h, wt=w)
         )  # fixed resolution of 384,512
         self.num_samples = num_samples
         if device is None:
-            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
 
     def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
@@ -54,7 +56,9 @@ class MegadepthDenseBenchmark:
             pck_3_tot = 0.0
             pck_5_tot = 0.0
             sampler = torch.utils.data.WeightedRandomSampler(
-                torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
+                torch.ones(len(self.dataset)),
+                replacement=False,
+                num_samples=self.num_samples,
             )
             dataloader = torch.utils.data.DataLoader(
                 self.dataset, batch_size=8, num_workers=batch_size, sampler=sampler
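
The dense benchmark evaluates on a fixed-size random subset of a large ConcatDataset by pairing a uniform-weight WeightedRandomSampler (without replacement) with a DataLoader. A minimal hedged sketch of the same pattern on a stand-in dataset follows.

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

dataset = TensorDataset(torch.arange(100_000))   # stand-in for the MegaDepth split
num_samples = 2000

sampler = WeightedRandomSampler(
    torch.ones(len(dataset)), num_samples=num_samples, replacement=False
)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
print(sum(len(batch[0]) for batch in loader))    # 2000 samples in total
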
diff --git a/third_party/DKM/dkm/benchmarks/scannet_benchmark.py b/third_party/DKM/dkm/benchmarks/scannet_benchmark.py
index ca938cb462c351845ce035f8be0714cf81214452..1ad659f887d3863812a368dcb210fbd7bbadb04e 100644
--- a/third_party/DKM/dkm/benchmarks/scannet_benchmark.py
+++ b/third_party/DKM/dkm/benchmarks/scannet_benchmark.py
@@ -10,7 +10,7 @@ class ScanNetBenchmark:
     def __init__(self, data_root="data/scannet") -> None:
         self.data_root = data_root
 
-    def benchmark(self, model, model_name = None):
+    def benchmark(self, model, model_name=None):
         model.train(False)
         with torch.no_grad():
             data_root = self.data_root
@@ -24,20 +24,20 @@ class ScanNetBenchmark:
                 scene = pairs[pairind]
                 scene_name = f"scene0{scene[0]}_00"
                 im1_path = osp.join(
-                        self.data_root,
-                        "scans_test",
-                        scene_name,
-                        "color",
-                        f"{scene[2]}.jpg",
-                    )
+                    self.data_root,
+                    "scans_test",
+                    scene_name,
+                    "color",
+                    f"{scene[2]}.jpg",
+                )
                 im1 = Image.open(im1_path)
                 im2_path = osp.join(
-                        self.data_root,
-                        "scans_test",
-                        scene_name,
-                        "color",
-                        f"{scene[3]}.jpg",
-                    )
+                    self.data_root,
+                    "scans_test",
+                    scene_name,
+                    "color",
+                    f"{scene[3]}.jpg",
+                )
                 im2 = Image.open(im2_path)
                 T_gt = rel_pose[pairind].reshape(3, 4)
                 R, t = T_gt[:3, :3], T_gt[:3, 3]
@@ -76,24 +76,20 @@ class ScanNetBenchmark:
 
                 offset = 0.5
                 kpts1 = sparse_matches[:, :2]
-                kpts1 = (
-                    np.stack(
-                        (
-                            w1 * (kpts1[:, 0] + 1) / 2 - offset,
-                            h1 * (kpts1[:, 1] + 1) / 2 - offset,
-                        ),
-                        axis=-1,
-                    )
+                kpts1 = np.stack(
+                    (
+                        w1 * (kpts1[:, 0] + 1) / 2 - offset,
+                        h1 * (kpts1[:, 1] + 1) / 2 - offset,
+                    ),
+                    axis=-1,
                 )
                 kpts2 = sparse_matches[:, 2:]
-                kpts2 = (
-                    np.stack(
-                        (
-                            w2 * (kpts2[:, 0] + 1) / 2 - offset,
-                            h2 * (kpts2[:, 1] + 1) / 2 - offset,
-                        ),
-                        axis=-1,
-                    )
+                kpts2 = np.stack(
+                    (
+                        w2 * (kpts2[:, 0] + 1) / 2 - offset,
+                        h2 * (kpts2[:, 1] + 1) / 2 - offset,
+                    ),
+                    axis=-1,
                 )
                 for _ in range(5):
                     shuffling = np.random.permutation(np.arange(len(kpts1)))
@@ -101,7 +97,8 @@ class ScanNetBenchmark:
                     kpts2 = kpts2[shuffling]
                     try:
                         norm_threshold = 0.5 / (
-                        np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                            np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))
+                        )
                         R_est, t_est, mask = estimate_pose(
                             kpts1,
                             kpts2,
diff --git a/third_party/DKM/dkm/datasets/scannet.py b/third_party/DKM/dkm/datasets/scannet.py
index 6ac39b41480f7585c4755cc30e0677ef74ed5e0c..fc24263c771f5fbb5d1e676257e9ad484a03ae31 100644
--- a/third_party/DKM/dkm/datasets/scannet.py
+++ b/third_party/DKM/dkm/datasets/scannet.py
@@ -5,10 +5,7 @@ import cv2
 import h5py
 import numpy as np
 import torch
-from torch.utils.data import (
-    Dataset,
-    DataLoader,
-    ConcatDataset)
+from torch.utils.data import Dataset, DataLoader, ConcatDataset
 
 import torchvision.transforms.functional as tvf
 import kornia.augmentation as K
@@ -19,21 +16,35 @@ from dkm.utils.transforms import GeometricSequential
 
 from tqdm import tqdm
 
+
 class ScanNetScene:
-    def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.) -> None:
-        self.scene_root = osp.join(data_root,"scans","scans_train")
-        self.data_names = scene_info['name']
-        self.overlaps = scene_info['score']
+    def __init__(
+        self,
+        data_root,
+        scene_info,
+        ht=384,
+        wt=512,
+        min_overlap=0.0,
+        shake_t=0,
+        rot_prob=0.0,
+    ) -> None:
+        self.scene_root = osp.join(data_root, "scans", "scans_train")
+        self.data_names = scene_info["name"]
+        self.overlaps = scene_info["score"]
         # Only sample 10s
-        valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
+        valid = (self.data_names[:, -2:] % 10).sum(axis=-1) == 0
         self.overlaps = self.overlaps[valid]
         self.data_names = self.data_names[valid]
         if len(self.data_names) > 10000:
-            pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
+            pairinds = np.random.choice(
+                np.arange(0, len(self.data_names)), 10000, replace=False
+            )
             self.data_names = self.data_names[pairinds]
             self.overlaps = self.overlaps[pairinds]
         self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
-        self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
+        self.depth_transform_ops = get_depth_tuple_transform_ops(
+            resize=(ht, wt), normalize=False
+        )
         self.wt, self.ht = wt, ht
         self.shake_t = shake_t
         self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
@@ -41,7 +52,7 @@ class ScanNetScene:
     def load_im(self, im_ref, crop=None):
         im = Image.open(im_ref)
         return im
-    
+
     def load_depth(self, depth_ref, crop=None):
         depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
         depth = depth / 1000
@@ -50,55 +61,61 @@ class ScanNetScene:
 
     def __len__(self):
         return len(self.data_names)
-    
+
     def scale_intrinsic(self, K, wi, hi):
-        sx, sy = self.wt / wi, self.ht /  hi
-        sK = torch.tensor([[sx, 0, 0],
-                        [0, sy, 0],
-                        [0, 0, 1]])
-        return sK@K
-
-    def read_scannet_pose(self,path):
-        """ Read ScanNet's Camera2World pose and transform it to World2Camera.
-        
+        sx, sy = self.wt / wi, self.ht / hi
+        sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
+        return sK @ K
+
+    def read_scannet_pose(self, path):
+        """Read ScanNet's Camera2World pose and transform it to World2Camera.
+
         Returns:
             pose_w2c (np.ndarray): (4, 4)
         """
-        cam2world = np.loadtxt(path, delimiter=' ')
+        cam2world = np.loadtxt(path, delimiter=" ")
         world2cam = np.linalg.inv(cam2world)
         return world2cam
 
-
-    def read_scannet_intrinsic(self,path):
-        """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
-        """
-        intrinsic = np.loadtxt(path, delimiter=' ')
+    def read_scannet_intrinsic(self, path):
+        """Read ScanNet's intrinsic matrix and return the 3x3 matrix."""
+        intrinsic = np.loadtxt(path, delimiter=" ")
         return intrinsic[:-1, :-1]
 
     def __getitem__(self, pair_idx):
         # read intrinsics of original size
         data_name = self.data_names[pair_idx]
         scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
-        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
-        
+        scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}"
+
         # read the intrinsic of depthmap
-        K1 = K2 =  self.read_scannet_intrinsic(osp.join(self.scene_root,
-                       scene_name,
-                       'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
+        K1 = K2 = self.read_scannet_intrinsic(
+            osp.join(self.scene_root, scene_name, "intrinsic", "intrinsic_color.txt")
+        )  # the depth K is not the same, but doesn't really matter
         # read and compute relative poses
-        T1 =  self.read_scannet_pose(osp.join(self.scene_root,
-                       scene_name,
-                       'pose', f'{stem_name_1}.txt'))
-        T2 =  self.read_scannet_pose(osp.join(self.scene_root,
-                       scene_name,
-                       'pose', f'{stem_name_2}.txt'))
-        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4]  # (4, 4)
+        T1 = self.read_scannet_pose(
+            osp.join(self.scene_root, scene_name, "pose", f"{stem_name_1}.txt")
+        )
+        T2 = self.read_scannet_pose(
+            osp.join(self.scene_root, scene_name, "pose", f"{stem_name_2}.txt")
+        )
+        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
+            :4, :4
+        ]  # (4, 4)
 
         # Load positive pair data
-        im_src_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
-        im_pos_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
-        depth_src_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
-        depth_pos_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
+        im_src_ref = os.path.join(
+            self.scene_root, scene_name, "color", f"{stem_name_1}.jpg"
+        )
+        im_pos_ref = os.path.join(
+            self.scene_root, scene_name, "color", f"{stem_name_2}.jpg"
+        )
+        depth_src_ref = os.path.join(
+            self.scene_root, scene_name, "depth", f"{stem_name_1}.png"
+        )
+        depth_pos_ref = os.path.join(
+            self.scene_root, scene_name, "depth", f"{stem_name_2}.png"
+        )
 
         im_src = self.load_im(im_src_ref)
         im_pos = self.load_im(im_pos_ref)
@@ -110,42 +127,53 @@ class ScanNetScene:
         K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
         # Process images
         im_src, im_pos = self.im_transform_ops((im_src, im_pos))
-        depth_src, depth_pos = self.depth_transform_ops((depth_src[None,None], depth_pos[None,None]))
-
-        data_dict = {'query': im_src,
-                    'support': im_pos,
-                    'query_depth': depth_src[0,0],
-                    'support_depth': depth_pos[0,0],
-                    'K1': K1,
-                    'K2': K2,
-                    'T_1to2':T_1to2,
-                    }
+        depth_src, depth_pos = self.depth_transform_ops(
+            (depth_src[None, None], depth_pos[None, None])
+        )
+
+        data_dict = {
+            "query": im_src,
+            "support": im_pos,
+            "query_depth": depth_src[0, 0],
+            "support_depth": depth_pos[0, 0],
+            "K1": K1,
+            "K2": K2,
+            "T_1to2": T_1to2,
+        }
         return data_dict
 
 
 class ScanNetBuilder:
-    def __init__(self, data_root = 'data/scannet') -> None:
+    def __init__(self, data_root="data/scannet") -> None:
         self.data_root = data_root
-        self.scene_info_root = os.path.join(data_root,'scannet_indices')
+        self.scene_info_root = os.path.join(data_root, "scannet_indices")
         self.all_scenes = os.listdir(self.scene_info_root)
-        
-    def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
+
+    def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
         # Note: split doesn't matter here as we always use same scannet_train scenes
         scene_names = self.all_scenes
         scenes = []
         for scene_name in tqdm(scene_names):
-            scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
-            scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
+            scene_info = np.load(
+                os.path.join(self.scene_info_root, scene_name), allow_pickle=True
+            )
+            scenes.append(
+                ScanNetScene(
+                    self.data_root, scene_info, min_overlap=min_overlap, **kwargs
+                )
+            )
         return scenes
-    
-    def weight_scenes(self, concat_dataset, alpha=.5):
+
+    def weight_scenes(self, concat_dataset, alpha=0.5):
         ns = []
         for d in concat_dataset.datasets:
             ns.append(len(d))
-        ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
+        ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
         return ws
 
 
 if __name__ == "__main__":
-    mega_test = ConcatDataset(ScanNetBuilder("data/scannet").build_scenes(split='train'))
-    mega_test[0]
\ No newline at end of file
+    mega_test = ConcatDataset(
+        ScanNetBuilder("data/scannet").build_scenes(split="train")
+    )
+    mega_test[0]
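
ScanNetScene.__getitem__ combines two pieces of bookkeeping: the camera-to-world poses on disk are inverted to world-to-camera and composed into a frame-1-to-frame-2 relative pose, and the color intrinsics are rescaled to the resized training resolution. A small hedged numpy sketch of both steps, with made-up values, is given here.

import numpy as np

cam2world_1 = np.eye(4)                          # pose of frame 1 (assumed)
cam2world_2 = np.eye(4)
cam2world_2[:3, 3] = [0.1, 0.0, 0.0]             # frame 2 translated 10 cm (assumed)

T1 = np.linalg.inv(cam2world_1)                  # world -> camera 1
T2 = np.linalg.inv(cam2world_2)                  # world -> camera 2
T_1to2 = T2 @ np.linalg.inv(T1)                  # camera 1 -> camera 2

K = np.array([[578.0, 0.0, 319.5], [0.0, 578.0, 239.5], [0.0, 0.0, 1.0]])
wi, hi, wt, ht = 640, 480, 512, 384              # original vs. target resolution
sK = np.diag([wt / wi, ht / hi, 1.0])
K_resized = sK @ K
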
diff --git a/third_party/DKM/dkm/models/deprecated/build_model.py b/third_party/DKM/dkm/models/deprecated/build_model.py
index dd28335f3e348ab6c90b26ba91b95e864b0bbbb9..6b4f6608296c21387b19242681e6e49160c0887e 100644
--- a/third_party/DKM/dkm/models/deprecated/build_model.py
+++ b/third_party/DKM/dkm/models/deprecated/build_model.py
@@ -10,16 +10,16 @@ dkm_pretrained_urls = {
         "mega_synthetic": "https://github.com/Parskatt/storage/releases/download/dkm_mega_synthetic/dkm_mega_synthetic.pth",
         "mega": "https://github.com/Parskatt/storage/releases/download/dkm_mega/dkm_mega.pth",
     },
-    "DKMv2":{
+    "DKMv2": {
         "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_outdoor.pth",
         "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_indoor.pth",
-    }
+    },
 }
 
 
 def DKM(pretrained=True, version="mega_synthetic", device=None):
     if device is None:
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     gp_dim = 256
     dfn_dim = 384
     feat_dim = 256
@@ -150,7 +150,8 @@ def DKM(pretrained=True, version="mega_synthetic", device=None):
         matcher.load_state_dict(weights)
     return matcher
 
-def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs):
+
+def DKMv2(pretrained=True, version="outdoor", resolution="low", **kwargs):
     gp_dim = 256
     dfn_dim = 384
     feat_dim = 256
@@ -200,8 +201,8 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs):
     conv_refiner = nn.ModuleDict(
         {
             "16": ConvRefiner(
-                2 * 512+128,
-                1024+128,
+                2 * 512 + 128,
+                1024 + 128,
                 3,
                 kernel_size=kernel_size,
                 dw=dw,
@@ -210,8 +211,8 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs):
                 displacement_emb_dim=128,
             ),
             "8": ConvRefiner(
-                2 * 512+64,
-                1024+64,
+                2 * 512 + 64,
+                1024 + 64,
                 3,
                 kernel_size=kernel_size,
                 dw=dw,
@@ -220,8 +221,8 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs):
                 displacement_emb_dim=64,
             ),
             "4": ConvRefiner(
-                2 * 256+32,
-                512+32,
+                2 * 256 + 32,
+                512 + 32,
                 3,
                 kernel_size=kernel_size,
                 dw=dw,
@@ -230,8 +231,8 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs):
                 displacement_emb_dim=32,
             ),
             "2": ConvRefiner(
-                2 * 64+16,
-                128+16,
+                2 * 64 + 16,
+                128 + 16,
                 3,
                 kernel_size=kernel_size,
                 dw=dw,
@@ -240,7 +241,7 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs):
                 displacement_emb_dim=16,
             ),
             "1": ConvRefiner(
-                2 * 3+6,
+                2 * 3 + 6,
                 24,
                 3,
                 kernel_size=kernel_size,
@@ -287,16 +288,14 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs):
     encoder = Encoder(
         tv_resnet.resnet50(pretrained=not pretrained),
     )  # only load pretrained weights if not loading a pretrained matcher ;)
-    matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs).to(device)
+    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, **kwargs).to(device)
     if pretrained:
         try:
             weights = torch.hub.load_state_dict_from_url(
                 dkm_pretrained_urls["DKMv2"][version]
             )
         except:
-            weights = torch.load(
-                dkm_pretrained_urls["DKMv2"][version]
-            )
+            weights = torch.load(dkm_pretrained_urls["DKMv2"][version])
         matcher.load_state_dict(weights)
     return matcher
 
diff --git a/third_party/DKM/dkm/models/deprecated/local_corr.py b/third_party/DKM/dkm/models/deprecated/local_corr.py
index 681fe4c0079561fa7a4c44e82a8879a4a27273a1..227d73b00be7efd7f64c32936b3dcdd7e5b4d123 100644
--- a/third_party/DKM/dkm/models/deprecated/local_corr.py
+++ b/third_party/DKM/dkm/models/deprecated/local_corr.py
@@ -10,8 +10,8 @@ from ..dkm import ConvRefiner
 
 
 class Stream:
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    if device == 'cuda':
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if device == "cuda":
         stream = torch.cuda.current_stream(device=device).cuda_stream
     else:
         stream = None
@@ -622,7 +622,7 @@ class LocalCorr(ConvRefiner):
 
 
 if __name__ == "__main__":
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     x = torch.randn(2, 128, 32, 32).to(device)
     y = torch.randn(2, 128, 32, 32).to(device)
     local_corr = LocalCorr(in_dim=81, hidden_dim=81 * 4)
diff --git a/third_party/DKM/dkm/models/dkm.py b/third_party/DKM/dkm/models/dkm.py
index 27c3f6d59ad3a8e976e3d719868908ddf443883e..58462e5d14cf9cac6e1fa551298f9fc82f93fcab 100644
--- a/third_party/DKM/dkm/models/dkm.py
+++ b/third_party/DKM/dkm/models/dkm.py
@@ -19,11 +19,11 @@ class ConvRefiner(nn.Module):
         dw=False,
         kernel_size=5,
         hidden_blocks=3,
-        displacement_emb = None,
-        displacement_emb_dim = None,
-        local_corr_radius = None,
-        corr_in_other = None,
-        no_support_fm = False,
+        displacement_emb=None,
+        displacement_emb_dim=None,
+        local_corr_radius=None,
+        corr_in_other=None,
+        no_support_fm=False,
     ):
         super().__init__()
         self.block1 = self.create_block(
@@ -43,12 +43,13 @@ class ConvRefiner(nn.Module):
         self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
         if displacement_emb:
             self.has_displacement_emb = True
-            self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
+            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
         else:
             self.has_displacement_emb = False
         self.local_corr_radius = local_corr_radius
         self.corr_in_other = corr_in_other
         self.no_support_fm = no_support_fm
+
     def create_block(
         self,
         in_dim,
@@ -86,29 +87,35 @@ class ConvRefiner(nn.Module):
             [type]: [description]
         """
         device = x.device
-        b,c,hs,ws = x.shape
+        b, c, hs, ws = x.shape
         with torch.no_grad():
             x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)
         if self.has_displacement_emb:
             query_coords = torch.meshgrid(
-            (
-                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
-                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
-            )
+                (
+                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
+                )
             )
             query_coords = torch.stack((query_coords[1], query_coords[0]))
             query_coords = query_coords[None].expand(b, 2, hs, ws)
-            in_displacement = flow-query_coords
+            in_displacement = flow - query_coords
             emb_in_displacement = self.disp_emb(in_displacement)
             if self.local_corr_radius:
-                #TODO: should corr have gradient?
+                # TODO: should corr have gradient?
                 if self.corr_in_other:
                     # Corr in other means take a kxk grid around the predicted coordinate in other image
-                    local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow)
+                    local_corr = local_correlation(
+                        x, y, local_radius=self.local_corr_radius, flow=flow
+                    )
                 else:
                     # Otherwise we use the warp to sample in the first image
                     # This is actually different operations, especially for large viewpoint changes
-                    local_corr = local_correlation(x, x_hat, local_radius=self.local_corr_radius,)
+                    local_corr = local_correlation(
+                        x,
+                        x_hat,
+                        local_radius=self.local_corr_radius,
+                    )
                 if self.no_support_fm:
                     x_hat = torch.zeros_like(x)
                 d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
@@ -269,7 +276,7 @@ class GP(nn.Module):
         only_nearest_neighbour=False,
         sigma_noise=0.1,
         no_cov=False,
-        predict_features = False,
+        predict_features=False,
     ):
         super().__init__()
         self.K = kernel(T=T, learn_temperature=learn_temperature)
@@ -344,9 +351,9 @@ class GP(nn.Module):
         b, c, h2, w2 = y.shape
         f = self.get_pos_enc(y)
         if self.predict_features:
-            f = f + y[:,:self.dim] # Stupid way to predict features
+            f = f + y[:, : self.dim]  # Stupid way to predict features
         b, d, h2, w2 = f.shape
-        #assert x.shape == y.shape
+        # assert x.shape == y.shape
         x, y, f = self.reshape(x), self.reshape(y), self.reshape(f)
         K_xx = self.K(x, x)
         K_yy = self.K(y, y)
@@ -355,7 +362,12 @@ class GP(nn.Module):
         sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
         # Due to https://github.com/pytorch/pytorch/issues/16963 annoying warnings, remove batch if N large
         if len(K_yy[0]) > 2000:
-            K_yy_inv = torch.cat([torch.linalg.inv(K_yy[k:k+1] + sigma_noise[k:k+1]) for k in range(b)])
+            K_yy_inv = torch.cat(
+                [
+                    torch.linalg.inv(K_yy[k : k + 1] + sigma_noise[k : k + 1])
+                    for k in range(b)
+                ]
+            )
         else:
             K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
 
@@ -363,7 +375,9 @@ class GP(nn.Module):
         mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
         if not self.no_cov:
             cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
-            cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
+            cov_x = rearrange(
+                cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1
+            )
             local_cov_x = self.get_local_cov(cov_x)
             local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
             gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
@@ -376,6 +390,7 @@ class Encoder(nn.Module):
     def __init__(self, resnet):
         super().__init__()
         self.resnet = resnet
+
     def forward(self, x):
         x0 = x
         b, c, h, w = x.shape
@@ -404,7 +419,15 @@ class Encoder(nn.Module):
 
 class Decoder(nn.Module):
     def __init__(
-        self, embedding_decoder, gps, proj, conv_refiner, transformers = None, detach=False, scales="all", pos_embeddings = None,
+        self,
+        embedding_decoder,
+        gps,
+        proj,
+        conv_refiner,
+        transformers=None,
+        detach=False,
+        scales="all",
+        pos_embeddings=None,
     ):
         super().__init__()
         self.embedding_decoder = embedding_decoder
@@ -424,17 +447,15 @@ class Decoder(nn.Module):
         certainty = F.interpolate(
             certainty, size=(h, w), align_corners=False, mode="bilinear"
         )
-        flow = F.interpolate(
-            flow, size=(h, w), align_corners=False, mode="bilinear"
-        )
+        flow = F.interpolate(flow, size=(h, w), align_corners=False, mode="bilinear")
         delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow)
         flow = torch.stack(
-                (
-                    flow[:, 0] + delta_flow[:, 0] / (4 * w),
-                    flow[:, 1] + delta_flow[:, 1] / (4 * h),
-                ),
-                dim=1,
-            )
+            (
+                flow[:, 0] + delta_flow[:, 0] / (4 * w),
+                flow[:, 1] + delta_flow[:, 1] / (4 * h),
+            ),
+            dim=1,
+        )
         flow = flow.permute(0, 2, 3, 1)
         certainty = certainty + delta_certainty
         return flow, certainty
@@ -452,8 +473,7 @@ class Decoder(nn.Module):
         coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
         return coarse_coords
 
-
-    def forward(self, f1, f2, upsample = False, dense_flow = None, dense_certainty = None):
+    def forward(self, f1, f2, upsample=False, dense_flow=None, dense_certainty=None):
         coarse_scales = self.embedding_decoder.scales()
         all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
         sizes = {scale: f1[scale].shape[-2:] for scale in f1}
@@ -462,7 +482,10 @@ class Decoder(nn.Module):
         device = f1[1].device
         coarsest_scale = int(all_scales[0])
         old_stuff = torch.zeros(
-            b, self.embedding_decoder.internal_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
+            b,
+            self.embedding_decoder.internal_dim,
+            *sizes[coarsest_scale],
+            device=f1[coarsest_scale].device
         )
         dense_corresps = {}
         if not upsample:
@@ -470,17 +493,17 @@ class Decoder(nn.Module):
             dense_certainty = 0.0
         else:
             dense_flow = F.interpolate(
-                    dense_flow,
-                    size=sizes[coarsest_scale],
-                    align_corners=False,
-                    mode="bilinear",
-                )
+                dense_flow,
+                size=sizes[coarsest_scale],
+                align_corners=False,
+                mode="bilinear",
+            )
             dense_certainty = F.interpolate(
-                    dense_certainty,
-                    size=sizes[coarsest_scale],
-                    align_corners=False,
-                    mode="bilinear",
-                )
+                dense_certainty,
+                size=sizes[coarsest_scale],
+                align_corners=False,
+                mode="bilinear",
+            )
         for new_scale in all_scales:
             ins = int(new_scale)
             f1_s, f2_s = f1[ins], f2[ins]
@@ -543,14 +566,14 @@ class RegressionMatcher(nn.Module):
         decoder,
         h=384,
         w=512,
-        use_contrastive_loss = False,
-        alpha = 1,
-        beta = 0,
-        sample_mode = "threshold",
-        upsample_preds = False,
-        symmetric = False,
-        name = None,
-        use_soft_mutual_nearest_neighbours = False,
+        use_contrastive_loss=False,
+        alpha=1,
+        beta=0,
+        sample_mode="threshold",
+        upsample_preds=False,
+        symmetric=False,
+        name=None,
+        use_soft_mutual_nearest_neighbours=False,
     ):
         super().__init__()
         self.encoder = encoder
@@ -566,13 +589,13 @@ class RegressionMatcher(nn.Module):
         self.symmetric = symmetric
         self.name = name
         self.sample_thresh = 0.05
-        self.upsample_res = (864,1152)
+        self.upsample_res = (864, 1152)
         if use_soft_mutual_nearest_neighbours:
             assert symmetric, "MNS requires symmetric inference"
         self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours
-        
-    def extract_backbone_features(self, batch, batched = True, upsample = True):
-        #TODO: only extract stride [1,2,4,8] for upsample = True
+
+    def extract_backbone_features(self, batch, batched=True, upsample=True):
+        # TODO: only extract stride [1,2,4,8] for upsample = True
         x_q = batch["query"]
         x_s = batch["support"]
         if batched:
@@ -593,7 +616,7 @@ class RegressionMatcher(nn.Module):
             dense_certainty = dense_certainty.clone()
             dense_certainty[dense_certainty > upper_thresh] = 1
         elif "pow" in self.sample_mode:
-            dense_certainty = dense_certainty**(1/3)
+            dense_certainty = dense_certainty ** (1 / 3)
         elif "naive" in self.sample_mode:
             dense_certainty = torch.ones_like(dense_certainty)
         matches, certainty = (
@@ -601,23 +624,28 @@ class RegressionMatcher(nn.Module):
             dense_certainty.reshape(-1),
         )
         expansion_factor = 4 if "balanced" in self.sample_mode else 1
-        good_samples = torch.multinomial(certainty, 
-                          num_samples = min(expansion_factor*num, len(certainty)), 
-                          replacement=False)
+        good_samples = torch.multinomial(
+            certainty,
+            num_samples=min(expansion_factor * num, len(certainty)),
+            replacement=False,
+        )
         good_matches, good_certainty = matches[good_samples], certainty[good_samples]
         if "balanced" not in self.sample_mode:
             return good_matches, good_certainty
 
         from ..utils.kde import kde
+
         density = kde(good_matches, std=0.1)
-        p = 1 / (density+1)
-        p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
-        balanced_samples = torch.multinomial(p, 
-                          num_samples = min(num,len(good_certainty)), 
-                          replacement=False)
+        p = 1 / (density + 1)
+        p[
+            density < 10
+        ] = 1e-7  # Basically should have at least 10 perfect neighbours, or around 100 ok ones
+        balanced_samples = torch.multinomial(
+            p, num_samples=min(num, len(good_certainty)), replacement=False
+        )
         return good_matches[balanced_samples], good_certainty[balanced_samples]
 
-    def forward(self, batch, batched = True):
+    def forward(self, batch, batched=True):
         feature_pyramid = self.extract_backbone_features(batch, batched=batched)
         if batched:
             f_q_pyramid = {
@@ -634,37 +662,43 @@ class RegressionMatcher(nn.Module):
         else:
             return dense_corresps
 
-    def forward_symmetric(self, batch, upsample = False, batched = True):
-        feature_pyramid = self.extract_backbone_features(batch, upsample = upsample, batched = batched)
+    def forward_symmetric(self, batch, upsample=False, batched=True):
+        feature_pyramid = self.extract_backbone_features(
+            batch, upsample=upsample, batched=batched
+        )
         f_q_pyramid = feature_pyramid
         f_s_pyramid = {
             scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]))
             for scale, f_scale in feature_pyramid.items()
         }
-        dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid, upsample = upsample, **(batch["corresps"] if "corresps" in batch else {}))
+        dense_corresps = self.decoder(
+            f_q_pyramid,
+            f_s_pyramid,
+            upsample=upsample,
+            **(batch["corresps"] if "corresps" in batch else {})
+        )
         return dense_corresps
-    
+
     def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
-        kpts_A, kpts_B = matches[...,:2], matches[...,2:]
-        kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
-        kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
+        kpts_A, kpts_B = matches[..., :2], matches[..., 2:]
+        kpts_A = torch.stack(
+            (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), axis=-1
+        )
+        kpts_B = torch.stack(
+            (W_B / 2 * (kpts_B[..., 0] + 1), H_B / 2 * (kpts_B[..., 1] + 1)), axis=-1
+        )
         return kpts_A, kpts_B
-    
-    def match(
-        self,
-        im1_path,
-        im2_path,
-        *args,
-        batched=False,
-        device = None
-    ):
-        assert not (batched and self.upsample_preds), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False "
+
+    def match(self, im1_path, im2_path, *args, batched=False, device=None):
+        assert not (
+            batched and self.upsample_preds
+        ), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False "
         if isinstance(im1_path, (str, os.PathLike)):
             im1, im2 = Image.open(im1_path), Image.open(im2_path)
-        else: # assume it is a PIL Image
+        else:  # assume it is a PIL Image
             im1, im2 = im1_path, im2_path
         if device is None:
-            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         symmetric = self.symmetric
         self.train(False)
         with torch.no_grad():
@@ -680,7 +714,10 @@ class RegressionMatcher(nn.Module):
                     resize=(hs, ws), normalize=True
                 )
                 query, support = test_transform((im1, im2))
-                batch = {"query": query[None].to(device), "support": support[None].to(device)}
+                batch = {
+                    "query": query[None].to(device),
+                    "support": support[None].to(device),
+                }
             else:
                 b, c, h, w = im1.shape
                 b, c, h2, w2 = im2.shape
@@ -690,38 +727,47 @@ class RegressionMatcher(nn.Module):
             finest_scale = 1
             # Run matcher
             if symmetric:
-                dense_corresps  = self.forward_symmetric(batch, batched = True)
+                dense_corresps = self.forward_symmetric(batch, batched=True)
             else:
-                dense_corresps = self.forward(batch, batched = True)
-            
+                dense_corresps = self.forward(batch, batched=True)
+
             if self.upsample_preds:
                 hs, ws = self.upsample_res
             low_res_certainty = F.interpolate(
-            dense_corresps[16]["dense_certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
+                dense_corresps[16]["dense_certainty"],
+                size=(hs, ws),
+                align_corners=False,
+                mode="bilinear",
             )
             cert_clamp = 0
             factor = 0.5
-            low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
+            low_res_certainty = (
+                factor * low_res_certainty * (low_res_certainty < cert_clamp)
+            )
 
-            if self.upsample_preds: 
+            if self.upsample_preds:
                 test_transform = get_tuple_transform_ops(
                     resize=(hs, ws), normalize=True
                 )
                 query, support = test_transform((im1, im2))
                 query, support = query[None].to(device), support[None].to(device)
-                batch = {"query": query, "support": support, "corresps": dense_corresps[finest_scale]}
+                batch = {
+                    "query": query,
+                    "support": support,
+                    "corresps": dense_corresps[finest_scale],
+                }
                 if symmetric:
-                    dense_corresps = self.forward_symmetric(batch, upsample = True, batched=True)
+                    dense_corresps = self.forward_symmetric(
+                        batch, upsample=True, batched=True
+                    )
                 else:
-                    dense_corresps = self.forward(batch, batched = True, upsample=True)
+                    dense_corresps = self.forward(batch, batched=True, upsample=True)
             query_to_support = dense_corresps[finest_scale]["dense_flow"]
             dense_certainty = dense_corresps[finest_scale]["dense_certainty"]
-            
+
             # Get certainty interpolation
             dense_certainty = dense_certainty - low_res_certainty
-            query_to_support = query_to_support.permute(
-                0, 2, 3, 1
-                )
+            query_to_support = query_to_support.permute(0, 2, 3, 1)
             # Create im1 meshgrid
             query_coords = torch.meshgrid(
                 (
@@ -735,23 +781,20 @@ class RegressionMatcher(nn.Module):
             query_coords = query_coords.permute(0, 2, 3, 1)
             if (query_to_support.abs() > 1).any() and True:
                 wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0
-                dense_certainty[wrong[:,None]] = 0
-                
+                dense_certainty[wrong[:, None]] = 0
+
             query_to_support = torch.clamp(query_to_support, -1, 1)
             if symmetric:
                 support_coords = query_coords
-                qts, stq = query_to_support.chunk(2)                    
+                qts, stq = query_to_support.chunk(2)
                 q_warp = torch.cat((query_coords, qts), dim=-1)
                 s_warp = torch.cat((stq, support_coords), dim=-1)
-                warp = torch.cat((q_warp, s_warp),dim=2)
-                dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:,0]
+                warp = torch.cat((q_warp, s_warp), dim=2)
+                dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:, 0]
             else:
                 warp = torch.cat((query_coords, query_to_support), dim=-1)
             if batched:
-                return (
-                    warp,
-                    dense_certainty
-                )
+                return (warp, dense_certainty)
             else:
                 return (
                     warp[0],
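
RegressionMatcher.sample implements the "threshold" sampling mode reformatted above: certainties above a threshold are clamped to 1 so that confident matches are treated equally, and matches are then drawn with a certainty-weighted multinomial without replacement. A hedged sketch with fake tensors illustrates the idea; the threshold and sample count mirror the values used above.

import torch

dense_matches = torch.rand(10_000, 4) * 2 - 1    # fake (x1, y1, x2, y2) in [-1, 1]
dense_certainty = torch.rand(10_000)             # fake per-match confidence
sample_thresh, num = 0.05, 5000

certainty = dense_certainty.clone()
certainty[certainty > sample_thresh] = 1         # clamp confident matches to weight 1

good = torch.multinomial(
    certainty, num_samples=min(num, len(certainty)), replacement=False
)
sparse_matches, sparse_certainty = dense_matches[good], certainty[good]
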
diff --git a/third_party/DKM/dkm/models/encoders.py b/third_party/DKM/dkm/models/encoders.py
index 29077e1797196611e9b59a753130a5b153e0aa05..29fe93443933cf7bbf5c542d8732aabc8c771604 100644
--- a/third_party/DKM/dkm/models/encoders.py
+++ b/third_party/DKM/dkm/models/encoders.py
@@ -3,10 +3,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models as tvm
 
+
 class ResNet18(nn.Module):
     def __init__(self, pretrained=False) -> None:
         super().__init__()
         self.net = tvm.resnet18(pretrained=pretrained)
+
     def forward(self, x):
         self = self.net
         x1 = x
@@ -18,7 +20,7 @@ class ResNet18(nn.Module):
         x8 = self.layer2(x4)
         x16 = self.layer3(x8)
         x32 = self.layer4(x16)
-        return {32:x32,16:x16,8:x8,4:x4,2:x2,1:x1}
+        return {32: x32, 16: x16, 8: x8, 4: x4, 2: x2, 1: x1}
 
     def train(self, mode=True):
         super().train(mode)
@@ -27,33 +29,47 @@ class ResNet18(nn.Module):
                 m.eval()
             pass
 
+
 class ResNet50(nn.Module):
-    def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None:
+    def __init__(
+        self,
+        pretrained=False,
+        high_res=False,
+        weights=None,
+        dilation=None,
+        freeze_bn=True,
+        anti_aliased=False,
+    ) -> None:
         super().__init__()
         if dilation is None:
-            dilation = [False,False,False]
+            dilation = [False, False, False]
         if anti_aliased:
             pass
         else:
             if weights is not None:
-                self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
+                self.net = tvm.resnet50(
+                    weights=weights, replace_stride_with_dilation=dilation
+                )
             else:
-                self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
-            
+                self.net = tvm.resnet50(
+                    pretrained=pretrained, replace_stride_with_dilation=dilation
+                )
+
         self.high_res = high_res
         self.freeze_bn = freeze_bn
+
     def forward(self, x):
         net = self.net
-        feats = {1:x}
+        feats = {1: x}
         x = net.conv1(x)
         x = net.bn1(x)
         x = net.relu(x)
-        feats[2] = x 
+        feats[2] = x
         x = net.maxpool(x)
         x = net.layer1(x)
-        feats[4] = x 
+        feats[4] = x
         x = net.layer2(x)
-        feats[8] = x  
+        feats[8] = x
         x = net.layer3(x)
         feats[16] = x
         x = net.layer4(x)
@@ -69,36 +85,65 @@ class ResNet50(nn.Module):
                 pass
 
 
-
-
 class ResNet101(nn.Module):
-    def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
+    def __init__(self, pretrained=False, high_res=False, weights=None) -> None:
         super().__init__()
         if weights is not None:
-            self.net = tvm.resnet101(weights = weights)
+            self.net = tvm.resnet101(weights=weights)
         else:
             self.net = tvm.resnet101(pretrained=pretrained)
         self.high_res = high_res
         self.scale_factor = 1 if not high_res else 1.5
+
     def forward(self, x):
         net = self.net
-        feats = {1:x}
+        feats = {1: x}
         sf = self.scale_factor
         if self.high_res:
             x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
         x = net.conv1(x)
         x = net.bn1(x)
         x = net.relu(x)
-        feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
+        feats[2] = (
+            x
+            if not self.high_res
+            else F.interpolate(
+                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
+            )
+        )
         x = net.maxpool(x)
         x = net.layer1(x)
-        feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
+        feats[4] = (
+            x
+            if not self.high_res
+            else F.interpolate(
+                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
+            )
+        )
         x = net.layer2(x)
-        feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
+        feats[8] = (
+            x
+            if not self.high_res
+            else F.interpolate(
+                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
+            )
+        )
         x = net.layer3(x)
-        feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
+        feats[16] = (
+            x
+            if not self.high_res
+            else F.interpolate(
+                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
+            )
+        )
         x = net.layer4(x)
-        feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
+        feats[32] = (
+            x
+            if not self.high_res
+            else F.interpolate(
+                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
+            )
+        )
         return feats
 
     def train(self, mode=True):
@@ -110,33 +155,64 @@ class ResNet101(nn.Module):
 
 
 class WideResNet50(nn.Module):
-    def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
+    def __init__(self, pretrained=False, high_res=False, weights=None) -> None:
         super().__init__()
         if weights is not None:
-            self.net = tvm.wide_resnet50_2(weights = weights)
+            self.net = tvm.wide_resnet50_2(weights=weights)
         else:
             self.net = tvm.wide_resnet50_2(pretrained=pretrained)
         self.high_res = high_res
         self.scale_factor = 1 if not high_res else 1.5
+
     def forward(self, x):
         net = self.net
-        feats = {1:x}
+        feats = {1: x}
         sf = self.scale_factor
         if self.high_res:
             x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
         x = net.conv1(x)
         x = net.bn1(x)
         x = net.relu(x)
-        feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
+        feats[2] = (
+            x
+            if not self.high_res
+            else F.interpolate(
+                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
+            )
+        )
         x = net.maxpool(x)
         x = net.layer1(x)
-        feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
+        feats[4] = (
+            x
+            if not self.high_res
+            else F.interpolate(
+                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
+            )
+        )
         x = net.layer2(x)
-        feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
+        feats[8] = (
+            x
+            if not self.high_res
+            else F.interpolate(
+                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
+            )
+        )
         x = net.layer3(x)
-        feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
+        feats[16] = (
+            x
+            if not self.high_res
+            else F.interpolate(
+                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
+            )
+        )
         x = net.layer4(x)
-        feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
+        feats[32] = (
+            x
+            if not self.high_res
+            else F.interpolate(
+                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
+            )
+        )
         return feats
 
     def train(self, mode=True):
@@ -144,4 +220,4 @@ class WideResNet50(nn.Module):
         for m in self.modules():
             if isinstance(m, nn.BatchNorm2d):
                 m.eval()
-            pass
\ No newline at end of file
+            pass
diff --git a/third_party/DKM/dkm/models/model_zoo/DKMv3.py b/third_party/DKM/dkm/models/model_zoo/DKMv3.py
index 6f4c9ede3863d778f679a033d8d2287b8776e894..fe41ab8b6400a4e57b8b08aab556bcba535e384a 100644
--- a/third_party/DKM/dkm/models/model_zoo/DKMv3.py
+++ b/third_party/DKM/dkm/models/model_zoo/DKMv3.py
@@ -5,9 +5,17 @@ from ..dkm import *
 from ..encoders import *
 
 
-def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", device = None, **kwargs):
+def DKMv3(
+    weights,
+    h,
+    w,
+    symmetric=True,
+    sample_mode="threshold_balanced",
+    device=None,
+    **kwargs
+):
     if device is None:
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     gp_dim = 256
     dfn_dim = 384
     feat_dim = 256
@@ -57,44 +65,44 @@ def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", de
     conv_refiner = nn.ModuleDict(
         {
             "16": ConvRefiner(
-                2 * 512+128+(2*7+1)**2,
-                2 * 512+128+(2*7+1)**2,
+                2 * 512 + 128 + (2 * 7 + 1) ** 2,
+                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                 3,
                 kernel_size=kernel_size,
                 dw=dw,
                 hidden_blocks=hidden_blocks,
                 displacement_emb=displacement_emb,
                 displacement_emb_dim=128,
-                local_corr_radius = 7,
-                corr_in_other = True,
+                local_corr_radius=7,
+                corr_in_other=True,
             ),
             "8": ConvRefiner(
-                2 * 512+64+(2*3+1)**2,
-                2 * 512+64+(2*3+1)**2,
+                2 * 512 + 64 + (2 * 3 + 1) ** 2,
+                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                 3,
                 kernel_size=kernel_size,
                 dw=dw,
                 hidden_blocks=hidden_blocks,
                 displacement_emb=displacement_emb,
                 displacement_emb_dim=64,
-                local_corr_radius = 3,
-                corr_in_other = True,
+                local_corr_radius=3,
+                corr_in_other=True,
             ),
             "4": ConvRefiner(
-                2 * 256+32+(2*2+1)**2,
-                2 * 256+32+(2*2+1)**2,
+                2 * 256 + 32 + (2 * 2 + 1) ** 2,
+                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                 3,
                 kernel_size=kernel_size,
                 dw=dw,
                 hidden_blocks=hidden_blocks,
                 displacement_emb=displacement_emb,
                 displacement_emb_dim=32,
-                local_corr_radius = 2,
-                corr_in_other = True,
+                local_corr_radius=2,
+                corr_in_other=True,
             ),
             "2": ConvRefiner(
-                2 * 64+16,
-                128+16,
+                2 * 64 + 16,
+                128 + 16,
                 3,
                 kernel_size=kernel_size,
                 dw=dw,
@@ -103,7 +111,7 @@ def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", de
                 displacement_emb_dim=16,
             ),
             "1": ConvRefiner(
-                2 * 3+6,
+                2 * 3 + 6,
                 24,
                 3,
                 kernel_size=kernel_size,
@@ -144,7 +152,16 @@ def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", de
     )
     decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
 
-    encoder = ResNet50(pretrained = False, high_res = False, freeze_bn=False)
-    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, name = "DKMv3", sample_mode=sample_mode, symmetric = symmetric, **kwargs).to(device)
+    encoder = ResNet50(pretrained=False, high_res=False, freeze_bn=False)
+    matcher = RegressionMatcher(
+        encoder,
+        decoder,
+        h=h,
+        w=w,
+        name="DKMv3",
+        sample_mode=sample_mode,
+        symmetric=symmetric,
+        **kwargs
+    ).to(device)
     res = matcher.load_state_dict(weights)
     return matcher
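
# Editor's note, not part of the patch: a minimal sanity check of the ConvRefiner
# channel arithmetic reformatted above, assuming in_dim = 2 * feat_dim
# + displacement_emb_dim + (2 * local_corr_radius + 1) ** 2 for the scales that
# use a local correlation volume.
for feat, emb, r in [(512, 128, 7), (512, 64, 3), (256, 32, 2)]:
    in_dim = 2 * feat + emb + (2 * r + 1) ** 2
    print(f"feat={feat} emb={emb} r={r} -> in_dim={in_dim}")  # 1377, 1137, 569
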
diff --git a/third_party/DKM/dkm/models/model_zoo/__init__.py b/third_party/DKM/dkm/models/model_zoo/__init__.py
index c85da2920c1acfac140ada2d87623203607d42ca..78901ad4f67e152933af8bb56c5478e3d561f30d 100644
--- a/third_party/DKM/dkm/models/model_zoo/__init__.py
+++ b/third_party/DKM/dkm/models/model_zoo/__init__.py
@@ -8,7 +8,7 @@ import torch
 from .DKMv3 import DKMv3
 
 
-def DKMv3_outdoor(path_to_weights = None, device=None):
+def DKMv3_outdoor(path_to_weights=None, device=None):
     """
     Loads DKMv3 outdoor weights, uses internal resolution of (540, 720) by default
     resolution can be changed by setting model.h_resized, model.w_resized later.
@@ -16,24 +16,27 @@ def DKMv3_outdoor(path_to_weights = None, device=None):
     can be turned off by model.upsample_preds = False
     """
     if device is None:
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if path_to_weights is not None:
-        weights = torch.load(path_to_weights, map_location='cpu')
+        weights = torch.load(path_to_weights, map_location="cpu")
     else:
-        weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["outdoor"],
-                                                     map_location='cpu')
-    return DKMv3(weights, 540, 720, upsample_preds = True, device=device)
+        weights = torch.hub.load_state_dict_from_url(
+            weight_urls["DKMv3"]["outdoor"], map_location="cpu"
+        )
+    return DKMv3(weights, 540, 720, upsample_preds=True, device=device)
 
-def DKMv3_indoor(path_to_weights = None, device=None):
+
+def DKMv3_indoor(path_to_weights=None, device=None):
     """
     Loads DKMv3 indoor weights, uses internal resolution of (480, 640) by default
     Resolution can be changed by setting model.h_resized, model.w_resized later.
     """
     if device is None:
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if path_to_weights is not None:
         weights = torch.load(path_to_weights, map_location=device)
     else:
-        weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["indoor"],
-                                                     map_location=device)
-    return DKMv3(weights, 480, 640, upsample_preds = False, device=device)
+        weights = torch.hub.load_state_dict_from_url(
+            weight_urls["DKMv3"]["indoor"], map_location=device
+        )
+    return DKMv3(weights, 480, 640, upsample_preds=False, device=device)
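
# Editor's usage sketch, not part of the patch: loading the matcher as the
# docstrings above describe. The import path is assumed; h_resized, w_resized and
# upsample_preds are the attributes the docstrings themselves mention.
import torch
from dkm import DKMv3_outdoor  # assumed import path

model = DKMv3_outdoor(device=torch.device("cpu"))
model.h_resized, model.w_resized = 660, 880  # change the internal resolution
model.upsample_preds = False  # disable upsampling of the final predictions
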
diff --git a/third_party/DKM/dkm/utils/kde.py b/third_party/DKM/dkm/utils/kde.py
index fa392455e70fda4c9c77c28bda76bcb7ef9045b0..286a531cede3fe1b46fbb8915bb8ad140b2cb79a 100644
--- a/third_party/DKM/dkm/utils/kde.py
+++ b/third_party/DKM/dkm/utils/kde.py
@@ -2,25 +2,28 @@ import torch
 import torch.nn.functional as F
 import numpy as np
 
-def fast_kde(x, std = 0.1, kernel_size = 9, dilation = 3, padding = 9//2, stride = 1):
+
+def fast_kde(x, std=0.1, kernel_size=9, dilation=3, padding=9 // 2, stride=1):
     raise NotImplementedError("WIP, use at your own risk.")
     # Note: when doing symmetric matching this might not be very exact, since we only check neighbours on the grid
-    x = x.permute(0,3,1,2)
-    B,C,H,W = x.shape
-    K = kernel_size ** 2
-    unfolded_x = F.unfold(x,kernel_size=kernel_size, dilation = dilation, padding = padding, stride = stride).reshape(B, C, K, H, W)
-    scores = (-(unfolded_x - x[:,:,None]).sum(dim=1)**2/(2*std**2)).exp()
+    x = x.permute(0, 3, 1, 2)
+    B, C, H, W = x.shape
+    K = kernel_size**2
+    unfolded_x = F.unfold(
+        x, kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride
+    ).reshape(B, C, K, H, W)
+    scores = (-(unfolded_x - x[:, :, None]).sum(dim=1) ** 2 / (2 * std**2)).exp()
     density = scores.sum(dim=1)
     return density
 
 
-def kde(x, std = 0.1, device=None):
+def kde(x, std=0.1, device=None):
     if device is None:
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if isinstance(x, np.ndarray):
         x = torch.from_numpy(x)
     # use a gaussian kernel to estimate density
     x = x.to(device)
-    scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
+    scores = (-torch.cdist(x, x) ** 2 / (2 * std**2)).exp()
     density = scores.sum(dim=-1)
     return density
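
# Editor's sketch, not part of the patch: kde() above scores each point with the
# Gaussian kernel sum  density_i = sum_j exp(-||x_i - x_j||^2 / (2 * std^2)),
# so N input points yield N densities. Import path assumed.
import torch
from dkm.utils.kde import kde

pts = torch.rand(100, 2)  # e.g. 100 sampled match coordinates
density = kde(pts, std=0.1, device=torch.device("cpu"))  # shape: (100,)
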
diff --git a/third_party/DKM/dkm/utils/local_correlation.py b/third_party/DKM/dkm/utils/local_correlation.py
index c0c1c06291d0b760376a2b2162bcf49d6eb1303c..08f7f04881bb9610edf3bd8bdcbda4e32d6e4c54 100644
--- a/third_party/DKM/dkm/utils/local_correlation.py
+++ b/third_party/DKM/dkm/utils/local_correlation.py
@@ -3,38 +3,42 @@ import torch.nn.functional as F
 
 
 def local_correlation(
-    feature0,
-    feature1,
-    local_radius,
-    padding_mode="zeros",
-    flow = None
+    feature0, feature1, local_radius, padding_mode="zeros", flow=None
 ):
     device = feature0.device
     b, c, h, w = feature0.size()
     if flow is None:
         # If flow is None, assume feature0 and feature1 are aligned
         coords = torch.meshgrid(
-                (
-                    torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
-                    torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
-                ))
-        coords = torch.stack((coords[1], coords[0]), dim=-1)[
-            None
-        ].expand(b, h, w, 2)
+            (
+                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
+            )
+        )
+        coords = torch.stack((coords[1], coords[0]), dim=-1)[None].expand(b, h, w, 2)
     else:
-        coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
+        coords = flow.permute(0, 2, 3, 1)  # If using flow, sample around flow target.
     r = local_radius
     local_window = torch.meshgrid(
-                (
-                    torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=device),
-                    torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=device),
-                ))
-    local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
-            None
-        ].expand(b, 2*r+1, 2*r+1, 2).reshape(b, (2*r+1)**2, 2)
-    coords = (coords[:,:,:,None]+local_window[:,None,None]).reshape(b,h,w*(2*r+1)**2,2)
+        (
+            torch.linspace(
+                -2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device=device
+            ),
+            torch.linspace(
+                -2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device=device
+            ),
+        )
+    )
+    local_window = (
+        torch.stack((local_window[1], local_window[0]), dim=-1)[None]
+        .expand(b, 2 * r + 1, 2 * r + 1, 2)
+        .reshape(b, (2 * r + 1) ** 2, 2)
+    )
+    coords = (coords[:, :, :, None] + local_window[:, None, None]).reshape(
+        b, h, w * (2 * r + 1) ** 2, 2
+    )
     window_feature = F.grid_sample(
         feature1, coords, padding_mode=padding_mode, align_corners=False
-    )[...,None].reshape(b,c,h,w,(2*r+1)**2)
-    corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature)/(c**.5)
+    )[..., None].reshape(b, c, h, w, (2 * r + 1) ** 2)
+    corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature) / (c**0.5)
     return corr
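
# Editor's shape check, not part of the patch: local_correlation() above compares
# feature0 against a (2r+1) x (2r+1) window of feature1 around each location
# (optionally centred on the flow), so the output carries one channel per offset.
import torch

b, c, h, w, r = 1, 64, 32, 32, 3
f0, f1 = torch.randn(b, c, h, w), torch.randn(b, c, h, w)
corr = local_correlation(f0, f1, local_radius=r)  # function defined above
assert corr.shape == (b, (2 * r + 1) ** 2, h, w)  # (1, 49, 32, 32)
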
diff --git a/third_party/DKM/dkm/utils/utils.py b/third_party/DKM/dkm/utils/utils.py
index 46bbe60260930aed184c6fa5907c837c0177b304..ca5ca11da35d2c201d3351d33798a04cd7781b4f 100644
--- a/third_party/DKM/dkm/utils/utils.py
+++ b/third_party/DKM/dkm/utils/utils.py
@@ -6,18 +6,18 @@ from torchvision.transforms.functional import InterpolationMode
 import torch.nn.functional as F
 from PIL import Image
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
 # --- GEOMETRY ---
 def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
     if len(kpts0) < 5:
         return None
-    K0inv = np.linalg.inv(K0[:2,:2])
-    K1inv = np.linalg.inv(K1[:2,:2])
+    K0inv = np.linalg.inv(K0[:2, :2])
+    K1inv = np.linalg.inv(K1[:2, :2])
 
-    kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
-    kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+    kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
+    kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
 
     E, mask = cv2.findEssentialMat(
         kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC
diff --git a/third_party/DarkFeat/darkfeat.py b/third_party/DarkFeat/darkfeat.py
index e78ad2604aafb759a6241365ac93fd1ef38f76f3..710962a2a8853689b5b0b764ce817d23aa0537ac 100644
--- a/third_party/DarkFeat/darkfeat.py
+++ b/third_party/DarkFeat/darkfeat.py
@@ -16,11 +16,11 @@ def gather_nd(params, indices):
         out_shape = orig_shape[:-1] + list(params.shape)[m:]
     else:
         raise ValueError(
-            f'the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}'
+            f"the last dimension of indices must be less than or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}"
         )
 
     indices = indices.reshape((num_samples, m)).transpose(0, 1).tolist()
-    output = params[indices]    # (num_samples, ...)
+    output = params[indices]  # (num_samples, ...)
     return output.reshape(out_shape).contiguous()
 
 
@@ -59,11 +59,13 @@ def interpolate(pos, inputs, nd=True):
         w_bottom_right = w_bottom_right[..., None]
 
     interpolated_val = (
-        w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) +
-        w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) +
-        w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) +
-        w_bottom_right *
-        gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1))
+        w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1))
+        + w_top_right
+        * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1))
+        + w_bottom_left
+        * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1))
+        + w_bottom_right
+        * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1))
     )
 
     return interpolated_val
@@ -73,24 +75,29 @@ def edge_mask(inputs, n_channel, dilation=1, edge_thld=5):
     b, c, h, w = inputs.size()
     device = inputs.device
 
-    dii_filter = torch.tensor(
-        [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
-    ).view(1, 1, 3, 3)
+    dii_filter = torch.tensor([[0, 1.0, 0], [0, -2.0, 0], [0, 1.0, 0]]).view(1, 1, 3, 3)
     dij_filter = 0.25 * torch.tensor(
-        [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
-    ).view(1, 1, 3, 3)
-    djj_filter = torch.tensor(
-        [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
+        [[1.0, 0, -1.0], [0, 0.0, 0], [-1.0, 0, 1.0]]
     ).view(1, 1, 3, 3)
+    djj_filter = torch.tensor([[0, 0, 0], [1.0, -2.0, 1.0], [0, 0, 0]]).view(1, 1, 3, 3)
 
     dii = F.conv2d(
-        inputs.view(-1, 1, h, w), dii_filter.to(device), padding=dilation, dilation=dilation
+        inputs.view(-1, 1, h, w),
+        dii_filter.to(device),
+        padding=dilation,
+        dilation=dilation,
     ).view(b, c, h, w)
     dij = F.conv2d(
-        inputs.view(-1, 1, h, w), dij_filter.to(device), padding=dilation, dilation=dilation
+        inputs.view(-1, 1, h, w),
+        dij_filter.to(device),
+        padding=dilation,
+        dilation=dilation,
     ).view(b, c, h, w)
     djj = F.conv2d(
-        inputs.view(-1, 1, h, w), djj_filter.to(device), padding=dilation, dilation=dilation
+        inputs.view(-1, 1, h, w),
+        djj_filter.to(device),
+        padding=dilation,
+        dilation=dilation,
     ).view(b, c, h, w)
 
     det = dii * djj - dij * dij
@@ -111,11 +118,17 @@ def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_si
 
     mask = score_map > score_thld
     if nms_size > 0:
-        nms_mask = F.max_pool2d(score_map, kernel_size=nms_size, stride=1, padding=nms_size//2)
+        nms_mask = F.max_pool2d(
+            score_map, kernel_size=nms_size, stride=1, padding=nms_size // 2
+        )
         nms_mask = torch.eq(score_map, nms_mask)
         mask = torch.logical_and(nms_mask, mask)
     if eof_size > 0:
-        eof_mask = torch.ones((1, 1, h - 2 * eof_size, w - 2 * eof_size), dtype=torch.float32, device=score_map.device)
+        eof_mask = torch.ones(
+            (1, 1, h - 2 * eof_size, w - 2 * eof_size),
+            dtype=torch.float32,
+            device=score_map.device,
+        )
         eof_mask = F.pad(eof_mask, [eof_size] * 4, value=0)
         eof_mask = eof_mask.bool()
         mask = torch.logical_and(eof_mask, mask)
@@ -157,23 +170,20 @@ def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_si
 # output: [batch_size, C, H, W], [batch_size, C, H, W]
 def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1):
     inputs = inputs / moving_instance_max
-    
+
     batch_size, C, H, W = inputs.shape
 
     pad_size = ksize // 2 + (dilation - 1)
     kernel = torch.ones([C, 1, ksize, ksize], device=inputs.device) / (ksize * ksize)
-    
-    pad_inputs = F.pad(inputs, [pad_size] * 4, mode='reflect')
+
+    pad_inputs = F.pad(inputs, [pad_size] * 4, mode="reflect")
 
     avg_spatial_inputs = F.conv2d(
-        pad_inputs,
-        kernel,
-        stride=1,
-        dilation=dilation,
-        padding=0,
-        groups=C
+        pad_inputs, kernel, stride=1, dilation=dilation, padding=0, groups=C
     )
-    avg_channel_inputs = torch.mean(inputs, axis=1, keepdim=True) # channel dimension is 1
+    avg_channel_inputs = torch.mean(
+        inputs, axis=1, keepdim=True
+    )  # channel dimension is 1
     # print(avg_spatial_inputs.shape)
 
     alpha = F.softplus(inputs - avg_spatial_inputs)
@@ -184,23 +194,36 @@ def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1):
 
 class DarkFeat(nn.Module):
     default_config = {
-        'model_path': '',
-        'input_type': 'raw-demosaic',
-        'kpt_n': 5000,
-        'kpt_refinement': True,
-        'score_thld': 0.5,
-        'edge_thld': 10,
-        'multi_scale': False,
-        'multi_level': True,
-        'nms_size': 3,
-        'eof_size': 5,
-        'need_norm': True,
-        'use_peakiness': True
+        "model_path": "",
+        "input_type": "raw-demosaic",
+        "kpt_n": 5000,
+        "kpt_refinement": True,
+        "score_thld": 0.5,
+        "edge_thld": 10,
+        "multi_scale": False,
+        "multi_level": True,
+        "nms_size": 3,
+        "eof_size": 5,
+        "need_norm": True,
+        "use_peakiness": True,
     }
 
-    def __init__(self, model_path='', inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
+    def __init__(
+        self,
+        model_path="",
+        inchan=3,
+        dilated=True,
+        dilation=1,
+        bn=True,
+        bn_affine=False,
+    ):
         super(DarkFeat, self).__init__()
-        inchan = 3 if self.default_config['input_type'] == 'rgb' or self.default_config['input_type'] == 'raw-demosaic' else 1
+        inchan = (
+            3
+            if self.default_config["input_type"] == "rgb"
+            or self.default_config["input_type"] == "raw-demosaic"
+            else 1
+        )
         self.config = {**self.default_config}
 
         self.inchan = inchan
@@ -209,60 +232,81 @@ class DarkFeat(nn.Module):
         self.dilation = dilation
         self.bn = bn
         self.bn_affine = bn_affine
-        self.config['model_path'] = model_path
+        self.config["model_path"] = model_path
 
         dim = 128
         mchan = 4
 
-        self.conv0 = self._add_conv(  8*mchan)
-        self.conv1 = self._add_conv(  8*mchan, bn=False)
-        self.bn1 = self._make_bn(8*mchan)
-        self.conv2 = self._add_conv( 16*mchan, stride=2)
-        self.conv3 = self._add_conv( 16*mchan, bn=False)
-        self.bn3 = self._make_bn(16*mchan)
-        self.conv4 = self._add_conv( 32*mchan, stride=2)
-        self.conv5 = self._add_conv( 32*mchan)
+        self.conv0 = self._add_conv(8 * mchan)
+        self.conv1 = self._add_conv(8 * mchan, bn=False)
+        self.bn1 = self._make_bn(8 * mchan)
+        self.conv2 = self._add_conv(16 * mchan, stride=2)
+        self.conv3 = self._add_conv(16 * mchan, bn=False)
+        self.bn3 = self._make_bn(16 * mchan)
+        self.conv4 = self._add_conv(32 * mchan, stride=2)
+        self.conv5 = self._add_conv(32 * mchan)
         # replace last 8x8 convolution with 3 3x3 convolutions
-        self.conv6_0 = self._add_conv( 32*mchan)
-        self.conv6_1 = self._add_conv( 32*mchan)
+        self.conv6_0 = self._add_conv(32 * mchan)
+        self.conv6_1 = self._add_conv(32 * mchan)
         self.conv6_2 = self._add_conv(dim, bn=False, relu=False)
         self.out_dim = dim
 
-        self.moving_avg_params = nn.ParameterList([
-            Parameter(torch.tensor(1.), requires_grad=False),
-            Parameter(torch.tensor(1.), requires_grad=False),
-            Parameter(torch.tensor(1.), requires_grad=False)
-        ])
+        self.moving_avg_params = nn.ParameterList(
+            [
+                Parameter(torch.tensor(1.0), requires_grad=False),
+                Parameter(torch.tensor(1.0), requires_grad=False),
+                Parameter(torch.tensor(1.0), requires_grad=False),
+            ]
+        )
         self.clf = nn.Conv2d(128, 2, kernel_size=1)
 
         state_dict = torch.load(self.config["model_path"])
         new_state_dict = {}
-        
+
         for key in state_dict:
-            if 'running_mean' not in key and 'running_var' not in key and 'num_batches_tracked' not in key:
+            if (
+                "running_mean" not in key
+                and "running_var" not in key
+                and "num_batches_tracked" not in key
+            ):
                 new_state_dict[key] = state_dict[key]
 
         self.load_state_dict(new_state_dict)
-        print('Loaded DarkFeat model')
-        
+        print("Loaded DarkFeat model")
+
     def _make_bn(self, outd):
         return nn.BatchNorm2d(outd, affine=self.bn_affine, track_running_stats=False)
 
-    def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max', bias=False):
+    def _add_conv(
+        self,
+        outd,
+        k=3,
+        stride=1,
+        dilation=1,
+        bn=True,
+        relu=True,
+        k_pool=1,
+        pool_type="max",
+        bias=False,
+    ):
         d = self.dilation * dilation
-        conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride, bias=bias)
+        conv_params = dict(
+            padding=((k - 1) * d) // 2, dilation=d, stride=stride, bias=bias
+        )
 
         ops = nn.ModuleList([])
 
-        ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) )
-        if bn and self.bn: ops.append( self._make_bn(outd) )
-        if relu: ops.append( nn.ReLU(inplace=True) )
+        ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params))
+        if bn and self.bn:
+            ops.append(self._make_bn(outd))
+        if relu:
+            ops.append(nn.ReLU(inplace=True))
         self.curchan = outd
-        
+
         if k_pool > 1:
-            if pool_type == 'avg':
+            if pool_type == "avg":
                 ops.append(torch.nn.AvgPool2d(kernel_size=k_pool))
-            elif pool_type == 'max':
+            elif pool_type == "max":
                 ops.append(torch.nn.MaxPool2d(kernel_size=k_pool))
             else:
                 print(f"Error, unknown pooling type {pool_type}...")
@@ -270,32 +314,32 @@ class DarkFeat(nn.Module):
         return nn.Sequential(*ops)
 
     def forward(self, input):
-        """ Compute keypoints, scores, descriptors for image """
-        data = input['image']
+        """Compute keypoints, scores, descriptors for image"""
+        data = input["image"]
         H, W = data.shape[2:]
 
-        if self.config['input_type'] == 'rgb':
+        if self.config["input_type"] == "rgb":
             # 3-channel rgb
             RGB_mean = [0.485, 0.456, 0.406]
-            RGB_std  = [0.229, 0.224, 0.225]
+            RGB_std = [0.229, 0.224, 0.225]
             norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std)
             data = norm_RGB(data)
 
-        elif self.config['input_type'] == 'gray':
+        elif self.config["input_type"] == "gray":
             # 1-channel
             data = torch.mean(data, dim=1, keepdim=True)
             norm_gray0 = tvf.Normalize(mean=data.mean(), std=data.std())
             data = norm_gray0(data)
 
-        elif self.config['input_type'] == 'raw':
+        elif self.config["input_type"] == "raw":
             # 4-channel
             pass
-        elif self.config['input_type'] == 'raw-demosaic':
+        elif self.config["input_type"] == "raw-demosaic":
             # 3-channel
             pass
         else:
             raise NotImplementedError()
-        
+
         # x: [N, C, H, W]
         x0 = self.conv0(data)
         x1 = self.conv1(x0)
@@ -309,16 +353,20 @@ class DarkFeat(nn.Module):
         x6_1 = self.conv6_1(x6_0)
         x6_2 = self.conv6_2(x6_1)
 
-        comb_weights = torch.tensor([1., 2., 3.], device=data.device)
+        comb_weights = torch.tensor([1.0, 2.0, 3.0], device=data.device)
         comb_weights /= torch.sum(comb_weights)
         ksize = [3, 2, 1]
         det_score_maps = []
 
         for idx, xx in enumerate([x1, x3, x6_2]):
-            alpha, beta = peakiness_score(xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx])
+            alpha, beta = peakiness_score(
+                xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx]
+            )
             score_vol = alpha * beta
             det_score_map = torch.max(score_vol, dim=1, keepdim=True)[0]
-            det_score_map = F.interpolate(det_score_map, size=data.shape[2:], mode='bilinear', align_corners=True)
+            det_score_map = F.interpolate(
+                det_score_map, size=data.shape[2:], mode="bilinear", align_corners=True
+            )
             det_score_map = comb_weights[idx] * det_score_map
             det_score_maps.append(det_score_map)
 
@@ -326,34 +374,42 @@ class DarkFeat(nn.Module):
 
         desc = x6_2
         score_map = det_score_map
-        conf = F.softmax(self.clf((desc)**2), dim=1)[:,1:2]
-        score_map = score_map * F.interpolate(conf, size=score_map.shape[2:], mode='bilinear', align_corners=True)
+        conf = F.softmax(self.clf((desc) ** 2), dim=1)[:, 1:2]
+        score_map = score_map * F.interpolate(
+            conf, size=score_map.shape[2:], mode="bilinear", align_corners=True
+        )
 
         kpt_inds, kpt_score = extract_kpts(
             score_map,
-            k=self.config['kpt_n'],
-            score_thld=self.config['score_thld'],
-            nms_size=self.config['nms_size'],
-            eof_size=self.config['eof_size'],
-            edge_thld=self.config['edge_thld']
+            k=self.config["kpt_n"],
+            score_thld=self.config["score_thld"],
+            nms_size=self.config["nms_size"],
+            eof_size=self.config["eof_size"],
+            edge_thld=self.config["edge_thld"],
         )
 
-        descs = F.normalize(
-                    interpolate(kpt_inds.squeeze(0) / 4, desc.squeeze(0).permute(1, 2, 0)),
-                    p=2,
-                    dim=-1
-                ).detach().cpu().numpy(),
-        kpts = np.squeeze(torch.stack([kpt_inds[:, :, 1], kpt_inds[:, :, 0]], dim=-1).cpu(), axis=0) \
-                * np.array([W / data.shape[3], H / data.shape[2]], dtype=np.float32)
+        descs = (
+            F.normalize(
+                interpolate(kpt_inds.squeeze(0) / 4, desc.squeeze(0).permute(1, 2, 0)),
+                p=2,
+                dim=-1,
+            )
+            .detach()
+            .cpu()
+            .numpy(),
+        )
+        kpts = np.squeeze(
+            torch.stack([kpt_inds[:, :, 1], kpt_inds[:, :, 0]], dim=-1).cpu(), axis=0
+        ) * np.array([W / data.shape[3], H / data.shape[2]], dtype=np.float32)
         scores = np.squeeze(kpt_score.detach().cpu().numpy(), axis=0)
 
-        idxs = np.negative(scores).argsort()[0:self.config['kpt_n']]
+        idxs = np.negative(scores).argsort()[0 : self.config["kpt_n"]]
         descs = descs[0][idxs]
         kpts = kpts[idxs]
         scores = scores[idxs]
 
         return {
-            'keypoints': kpts,
-            'scores': torch.from_numpy(scores),
-            'descriptors': torch.from_numpy(descs.T),
+            "keypoints": kpts,
+            "scores": torch.from_numpy(scores),
+            "descriptors": torch.from_numpy(descs.T),
         }
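
# Editor's usage sketch, not part of the patch: forward() above expects a dict with
# an "image" tensor of shape (N, C, H, W) and returns keypoints, scores and
# descriptors. The weights path is a placeholder.
import torch

net = DarkFeat(model_path="/path/to/darkfeat_weights.pth").eval()
with torch.no_grad():
    out = net({"image": torch.rand(1, 3, 480, 640)})
print(out["keypoints"].shape, out["scores"].shape, out["descriptors"].shape)
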
diff --git a/third_party/DarkFeat/datasets/InvISP/cal_metrics.py b/third_party/DarkFeat/datasets/InvISP/cal_metrics.py
index cc3e501664487de4c08ab8c89328dd266fba2868..28811368c5be5a362e8907ec4963a1de7aaa260b 100644
--- a/third_party/DarkFeat/datasets/InvISP/cal_metrics.py
+++ b/third_party/DarkFeat/datasets/InvISP/cal_metrics.py
@@ -1,8 +1,9 @@
 import cv2
 import numpy as np
 import math
+
 # from skimage.metrics import structural_similarity as ssim
-from skimage.measure import compare_ssim 
+from skimage.measure import compare_ssim
 from scipy.misc import imread
 from glob import glob
 
@@ -14,30 +15,34 @@ parser.add_argument("--path", type=str, help="Path to evaluate images.")
 
 args = parser.parse_args()
 
+
 def psnr(img1, img2):
-   mse = np.mean( (img1/255. - img2/255.) ** 2 )
-   if mse < 1.0e-10:
-      return 100
-   PIXEL_MAX = 1
-   return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
+    mse = np.mean((img1 / 255.0 - img2 / 255.0) ** 2)
+    if mse < 1.0e-10:
+        return 100
+    PIXEL_MAX = 1
+    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
+
 
 def psnr_raw(img1, img2):
-   mse = np.mean( (img1 - img2) ** 2 )
-   if mse < 1.0e-10:
-      return 100
-   PIXEL_MAX = 1
-   return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
+    mse = np.mean((img1 - img2) ** 2)
+    if mse < 1.0e-10:
+        return 100
+    PIXEL_MAX = 1
+    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
 
 
 def my_ssim(img1, img2):
-    return compare_ssim(img1, img2, data_range=img1.max() - img1.min(), multichannel=True)
+    return compare_ssim(
+        img1, img2, data_range=img1.max() - img1.min(), multichannel=True
+    )
 
 
 def quan_eval(path, suffix="jpg"):
     # path: /disk2/yazhou/projects/IISP/exps/test_final_unet_globalEDV2/
     # ours
-    gt_imgs = sorted(glob(path+"tar*.%s"%suffix))
-    pred_imgs = sorted(glob(path+"pred*.%s"%suffix))
+    gt_imgs = sorted(glob(path + "tar*.%s" % suffix))
+    pred_imgs = sorted(glob(path + "pred*.%s" % suffix))
 
     # with open(split_path + "test_gt.txt", 'r') as f_gt, open(split_path+"test_rgb.txt","r") as f_rgb:
     #     gt_imgs = [line.rstrip() for line in f_gt.readlines()]
@@ -45,8 +50,8 @@ def quan_eval(path, suffix="jpg"):
 
     assert len(gt_imgs) == len(pred_imgs)
 
-    psnr_avg = 0.
-    ssim_avg = 0.
+    psnr_avg = 0.0
+    ssim_avg = 0.0
     for i in range(len(gt_imgs)):
         gt = imread(gt_imgs[i])
         pred = imread(pred_imgs[i])
@@ -66,21 +71,23 @@ def quan_eval(path, suffix="jpg"):
 
     return psnr_avg, ssim_avg
 
+
 def mse(gt, pred):
-    return np.mean((gt-pred)**2)
+    return np.mean((gt - pred) ** 2)
+
 
 def mse_raw(path, suffix="npy"):
-    gt_imgs = sorted(glob(path+"raw_tar*.%s"%suffix))
-    pred_imgs = sorted(glob(path+"raw_pred*.%s"%suffix))
+    gt_imgs = sorted(glob(path + "raw_tar*.%s" % suffix))
+    pred_imgs = sorted(glob(path + "raw_pred*.%s" % suffix))
 
     # with open(split_path + "test_gt.txt", 'r') as f_gt, open(split_path+"test_rgb.txt","r") as f_rgb:
     #     gt_imgs = [line.rstrip() for line in f_gt.readlines()]
     #     pred_imgs = [line.rstrip() for line in f_rgb.readlines()]
-    
+
     assert len(gt_imgs) == len(pred_imgs)
 
-    mse_avg = 0.
-    psnr_avg = 0.
+    mse_avg = 0.0
+    psnr_avg = 0.0
     for i in range(len(gt_imgs)):
         gt = np.load(gt_imgs[i])
         pred = np.load(pred_imgs[i])
@@ -100,6 +107,7 @@ def mse_raw(path, suffix="npy"):
 
     return mse_avg, psnr_avg
 
+
 test_full = False
 
 # if test_full:
@@ -107,8 +115,10 @@ test_full = False
 #     mse_avg, psnr_avg_raw = mse_raw(ROOT_PATH+"%s/vis_%s_full/"%(args.task, args.ckpt))
 # else:
 psnr_avg, ssim_avg = quan_eval(args.path, "jpg")
-mse_avg, psnr_avg_raw = mse_raw(args.path)    
-
-print("pnsr: {}, ssim: {}, mse: {}, psnr raw: {}".format(psnr_avg, ssim_avg, mse_avg, psnr_avg_raw))
-
+mse_avg, psnr_avg_raw = mse_raw(args.path)
 
+print(
+    "psnr: {}, ssim: {}, mse: {}, psnr raw: {}".format(
+        psnr_avg, ssim_avg, mse_avg, psnr_avg_raw
+    )
+)
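
# Editor's worked example, not part of the patch: with PIXEL_MAX = 1 the formula in
# psnr() above reduces to PSNR = -10 * log10(mse); an MSE of 1e-4 gives 40 dB.
import math

mse = 1e-4
print(20 * math.log10(1 / math.sqrt(mse)))  # 40.0
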
diff --git a/third_party/DarkFeat/datasets/InvISP/config/config.py b/third_party/DarkFeat/datasets/InvISP/config/config.py
index dc42182ecf7464cc85ed5c77b7aeb9ee4e3ecd74..d0b041cd724db5d8edf629fd56dfba10b83ea6c0 100644
--- a/third_party/DarkFeat/datasets/InvISP/config/config.py
+++ b/third_party/DarkFeat/datasets/InvISP/config/config.py
@@ -5,17 +5,37 @@ BATCH_SIZE = 1
 DATA_PATH = "./data/"
 
 
-
 def get_arguments():
     parser = argparse.ArgumentParser(description="training codes")
-    
+
     parser.add_argument("--task", type=str, help="Name of this training")
-    parser.add_argument("--data_path", type=str, default=DATA_PATH, help="Dataset root path.")
-    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Batch size for training. ")       
-    parser.add_argument("--debug_mode", dest='debug_mode', action='store_true',  help="If debug mode, load less data.")    
-    parser.add_argument("--gamma", dest='gamma', action='store_true', help="Use gamma compression for raw data.")     
-    parser.add_argument("--camera", type=str, default="NIKON_D700", choices=["NIKON_D700", "Canon_EOS_5D"], help="Choose which camera to use. ")    
-    parser.add_argument("--rgb_weight", type=float, default=1, help="Weight for rgb loss. ")                 
-    
-    
+    parser.add_argument(
+        "--data_path", type=str, default=DATA_PATH, help="Dataset root path."
+    )
+    parser.add_argument(
+        "--batch_size", type=int, default=BATCH_SIZE, help="Batch size for training. "
+    )
+    parser.add_argument(
+        "--debug_mode",
+        dest="debug_mode",
+        action="store_true",
+        help="If debug mode, load less data.",
+    )
+    parser.add_argument(
+        "--gamma",
+        dest="gamma",
+        action="store_true",
+        help="Use gamma compression for raw data.",
+    )
+    parser.add_argument(
+        "--camera",
+        type=str,
+        default="NIKON_D700",
+        choices=["NIKON_D700", "Canon_EOS_5D"],
+        help="Choose which camera to use. ",
+    )
+    parser.add_argument(
+        "--rgb_weight", type=float, default=1, help="Weight for rgb loss. "
+    )
+
     return parser
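
# Editor's usage sketch, not part of the patch: get_arguments() returns the parser,
# so the caller finishes the parsing; the argument values here are placeholders.
parser = get_arguments()
args = parser.parse_args(["--task", "demo", "--camera", "NIKON_D700"])
print(args.task, args.camera, args.batch_size, args.gamma)
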
diff --git a/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py b/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py
index 62271771a17a4863b730136d49f2a23aed0e49b2..3445a409b756b5f2ae6f0f4d1da2c589268635e1 100644
--- a/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py
+++ b/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py
@@ -10,22 +10,27 @@ import scipy.io as scio
 parser = argparse.ArgumentParser(description="data preprocess")
 
 parser.add_argument("--camera", type=str, default="NIKON_D700", help="Camera Name")
-parser.add_argument("--Bayer_Pattern", type=str, default="RGGB", help="Bayer Pattern of RAW")
-parser.add_argument("--JPEG_Quality", type=int, default=90, help="Jpeg Quality of the ground truth.")
+parser.add_argument(
+    "--Bayer_Pattern", type=str, default="RGGB", help="Bayer Pattern of RAW"
+)
+parser.add_argument(
+    "--JPEG_Quality", type=int, default=90, help="Jpeg Quality of the ground truth."
+)
 
 args = parser.parse_args()
 camera_name = args.camera
 Bayer_Pattern = args.Bayer_Pattern
 JPEG_Quality = args.JPEG_Quality
 
-dng_path = sorted(glob.glob('/mnt/nvme2n1/hyz/data/' + camera_name + '/DNG/*.cr2'))
-rgb_target_path = '/mnt/nvme2n1/hyz/data/'+ camera_name + '/RGB/'
-raw_input_path = '/mnt/nvme2n1/hyz/data/' + camera_name + '/RAW/'
+dng_path = sorted(glob.glob("/mnt/nvme2n1/hyz/data/" + camera_name + "/DNG/*.cr2"))
+rgb_target_path = "/mnt/nvme2n1/hyz/data/" + camera_name + "/RGB/"
+raw_input_path = "/mnt/nvme2n1/hyz/data/" + camera_name + "/RAW/"
 if not os.path.isdir(rgb_target_path):
     os.mkdir(rgb_target_path)
 if not os.path.isdir(raw_input_path):
     os.mkdir(raw_input_path)
-    
+
+
 def flip(raw_img, flip):
     if flip == 3:
         raw_img = np.rot90(raw_img, k=2)
@@ -38,19 +43,19 @@ def flip(raw_img, flip):
     return raw_img
 
 
-
 for path in dng_path:
     print("Start Processing %s" % os.path.basename(path))
     raw = rawpy.imread(path)
-    file_name = path.split('/')[-1].split('.')[0]
-    im = raw.postprocess(use_camera_wb=True,no_auto_bright=True)
+    file_name = path.split("/")[-1].split(".")[0]
+    im = raw.postprocess(use_camera_wb=True, no_auto_bright=True)
     flip_val = raw.sizes.flip
     cwb = raw.camera_whitebalance
     raw_img = raw.raw_image_visible
-    if camera_name == 'Canon_EOS_5D':
+    if camera_name == "Canon_EOS_5D":
         raw_img = np.maximum(raw_img - 127.0, 0)
     de_raw = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw_img, Bayer_Pattern)
     de_raw = flip(de_raw, flip_val)
-    rgb_img = PILImage.fromarray(im).save(rgb_target_path + file_name + '.jpg', quality = JPEG_Quality, subsampling = 1)
-    np.savez(raw_input_path + file_name + '.npz', raw=de_raw, wb=cwb)
-    
+    rgb_img = PILImage.fromarray(im).save(
+        rgb_target_path + file_name + ".jpg", quality=JPEG_Quality, subsampling=1
+    )
+    np.savez(raw_input_path + file_name + ".npz", raw=de_raw, wb=cwb)
diff --git a/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py b/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py
index 4c71bd3b4162bd21761983deef6b94fa46a364f6..9f0106b9f5175c8cd003cbdcab21f6c9c71e262d 100644
--- a/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py
+++ b/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py
@@ -14,119 +14,147 @@ from .base_dataset import BaseDataset
 
 class FiveKDatasetTrain(BaseDataset):
     def __init__(self, opt):
-        super().__init__(opt=opt) 
+        super().__init__(opt=opt)
         self.patch_size = 256
         input_RAWs_WBs, target_RGBs = self.load(is_train=True)
-        assert len(input_RAWs_WBs) == len(target_RGBs)        
-        self.data = {'input_RAWs_WBs':input_RAWs_WBs, 'target_RGBs':target_RGBs} 
+        assert len(input_RAWs_WBs) == len(target_RGBs)
+        self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs}
 
     def random_flip(self, input_raw, target_rgb):
         idx = np.random.randint(2)
-        input_raw = np.flip(input_raw,axis=idx).copy()
-        target_rgb = np.flip(target_rgb,axis=idx).copy()
-        
+        input_raw = np.flip(input_raw, axis=idx).copy()
+        target_rgb = np.flip(target_rgb, axis=idx).copy()
+
         return input_raw, target_rgb
 
     def random_rotate(self, input_raw, target_rgb):
         idx = np.random.randint(4)
-        input_raw = np.rot90(input_raw,k=idx)
-        target_rgb = np.rot90(target_rgb,k=idx)
+        input_raw = np.rot90(input_raw, k=idx)
+        target_rgb = np.rot90(target_rgb, k=idx)
 
         return input_raw, target_rgb
 
-    def random_crop(self, patch_size, input_raw, target_rgb,flow=False,demos=False):
+    def random_crop(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
         H, W, _ = input_raw.shape
         rnd_h = random.randint(0, max(0, H - patch_size))
         rnd_w = random.randint(0, max(0, W - patch_size))
 
-        patch_input_raw = input_raw[rnd_h:rnd_h + patch_size, rnd_w:rnd_w + patch_size, :]
+        patch_input_raw = input_raw[
+            rnd_h : rnd_h + patch_size, rnd_w : rnd_w + patch_size, :
+        ]
         if flow or demos:
-            patch_target_rgb = target_rgb[rnd_h:rnd_h + patch_size, rnd_w:rnd_w + patch_size, :]
+            patch_target_rgb = target_rgb[
+                rnd_h : rnd_h + patch_size, rnd_w : rnd_w + patch_size, :
+            ]
         else:
-            patch_target_rgb = target_rgb[rnd_h*2:rnd_h*2 + patch_size*2, rnd_w*2:rnd_w*2 + patch_size*2, :]
+            patch_target_rgb = target_rgb[
+                rnd_h * 2 : rnd_h * 2 + patch_size * 2,
+                rnd_w * 2 : rnd_w * 2 + patch_size * 2,
+                :,
+            ]
 
         return patch_input_raw, patch_target_rgb
-        
+
     def aug(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
-        input_raw, target_rgb = self.random_crop(patch_size, input_raw,target_rgb,flow=flow, demos=demos)
-        input_raw, target_rgb = self.random_rotate(input_raw,target_rgb)
-        input_raw, target_rgb = self.random_flip(input_raw,target_rgb)
-        
+        input_raw, target_rgb = self.random_crop(
+            patch_size, input_raw, target_rgb, flow=flow, demos=demos
+        )
+        input_raw, target_rgb = self.random_rotate(input_raw, target_rgb)
+        input_raw, target_rgb = self.random_flip(input_raw, target_rgb)
+
         return input_raw, target_rgb
 
     def __len__(self):
-        return len(self.data['input_RAWs_WBs'])
+        return len(self.data["input_RAWs_WBs"])
+
+    def __getitem__(self, idx):
+        input_raw_wb_path = self.data["input_RAWs_WBs"][idx]
+        target_rgb_path = self.data["target_RGBs"][idx]
 
-    def __getitem__(self, idx):    
-        input_raw_wb_path = self.data['input_RAWs_WBs'][idx]
-        target_rgb_path = self.data['target_RGBs'][idx]
-        
         target_rgb_img = imread(target_rgb_path)
         input_raw_wb = np.load(input_raw_wb_path)
-        input_raw_img = input_raw_wb['raw']
-        wb = input_raw_wb['wb']
-        wb = wb / wb.max() 
-        input_raw_img = input_raw_img * wb[:-1]   
+        input_raw_img = input_raw_wb["raw"]
+        wb = input_raw_wb["wb"]
+        wb = wb / wb.max()
+        input_raw_img = input_raw_img * wb[:-1]
 
         self.patch_size = 256
-        input_raw_img, target_rgb_img = self.aug(self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True)  
-
-        if self.gamma:            
-            norm_value = np.power(4095, 1/2.2) if self.camera_name=='Canon_EOS_5D' else np.power(16383, 1/2.2)            
-            input_raw_img = np.power(input_raw_img, 1/2.2)             
+        input_raw_img, target_rgb_img = self.aug(
+            self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True
+        )
+
+        if self.gamma:
+            norm_value = (
+                np.power(4095, 1 / 2.2)
+                if self.camera_name == "Canon_EOS_5D"
+                else np.power(16383, 1 / 2.2)
+            )
+            input_raw_img = np.power(input_raw_img, 1 / 2.2)
         else:
-            norm_value = 4095 if self.camera_name=='Canon_EOS_5D' else 16383
+            norm_value = 4095 if self.camera_name == "Canon_EOS_5D" else 16383
 
         target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
-        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)   
+        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
         target_raw_img = input_raw_img.copy()
 
         input_raw_img = self.np2tensor(input_raw_img).float()
         target_rgb_img = self.np2tensor(target_rgb_img).float()
         target_raw_img = self.np2tensor(target_raw_img).float()
-        
-        sample = {'input_raw':input_raw_img, 'target_rgb':target_rgb_img, 'target_raw':target_raw_img,
-                    'file_name':input_raw_wb_path.split("/")[-1].split(".")[0]}
+
+        sample = {
+            "input_raw": input_raw_img,
+            "target_rgb": target_rgb_img,
+            "target_raw": target_raw_img,
+            "file_name": input_raw_wb_path.split("/")[-1].split(".")[0],
+        }
         return sample
 
+
 class FiveKDatasetTest(BaseDataset):
     def __init__(self, opt):
         super().__init__(opt=opt)
         self.patch_size = 256
-        
+
         input_RAWs_WBs, target_RGBs = self.load(is_train=False)
-        assert len(input_RAWs_WBs) == len(target_RGBs)        
-        self.data = {'input_RAWs_WBs':input_RAWs_WBs, 'target_RGBs':target_RGBs} 
+        assert len(input_RAWs_WBs) == len(target_RGBs)
+        self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs}
 
     def __len__(self):
-        return len(self.data['input_RAWs_WBs'])
+        return len(self.data["input_RAWs_WBs"])
+
+    def __getitem__(self, idx):
+        input_raw_wb_path = self.data["input_RAWs_WBs"][idx]
+        target_rgb_path = self.data["target_RGBs"][idx]
 
-    def __getitem__(self, idx):    
-        input_raw_wb_path = self.data['input_RAWs_WBs'][idx]
-        target_rgb_path = self.data['target_RGBs'][idx]
-        
         target_rgb_img = imread(target_rgb_path)
         input_raw_wb = np.load(input_raw_wb_path)
-        input_raw_img = input_raw_wb['raw']
-        wb = input_raw_wb['wb']
-        wb = wb / wb.max() 
-        input_raw_img = input_raw_img * wb[:-1]   
-
-        if self.gamma:            
-            norm_value = np.power(4095, 1/2.2) if self.camera_name=='Canon_EOS_5D' else np.power(16383, 1/2.2)            
-            input_raw_img = np.power(input_raw_img, 1/2.2)             
+        input_raw_img = input_raw_wb["raw"]
+        wb = input_raw_wb["wb"]
+        wb = wb / wb.max()
+        input_raw_img = input_raw_img * wb[:-1]
+
+        if self.gamma:
+            norm_value = (
+                np.power(4095, 1 / 2.2)
+                if self.camera_name == "Canon_EOS_5D"
+                else np.power(16383, 1 / 2.2)
+            )
+            input_raw_img = np.power(input_raw_img, 1 / 2.2)
         else:
-            norm_value = 4095 if self.camera_name=='Canon_EOS_5D' else 16383
+            norm_value = 4095 if self.camera_name == "Canon_EOS_5D" else 16383
 
         target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
-        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)   
+        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
         target_raw_img = input_raw_img.copy()
 
         input_raw_img = self.np2tensor(input_raw_img).float()
         target_rgb_img = self.np2tensor(target_rgb_img).float()
         target_raw_img = self.np2tensor(target_raw_img).float()
-        
-        sample = {'input_raw':input_raw_img, 'target_rgb':target_rgb_img, 'target_raw':target_raw_img,
-                    'file_name':input_raw_wb_path.split("/")[-1].split(".")[0]}
-        return sample
 
+        sample = {
+            "input_raw": input_raw_img,
+            "target_rgb": target_rgb_img,
+            "target_raw": target_raw_img,
+            "file_name": input_raw_wb_path.split("/")[-1].split(".")[0],
+        }
+        return sample
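
# Editor's note, not part of the patch: with --gamma the raw values above are
# compressed as raw ** (1 / 2.2) and divided by norm_value = 16383 ** (1 / 2.2)
# (NIKON D700) or 4095 ** (1 / 2.2) (Canon EOS 5D), keeping the result in [0, 1].
import numpy as np

raw = np.full((4, 4, 3), 16383.0)
normed = raw ** (1 / 2.2) / np.power(16383, 1 / 2.2)
print(normed.max())  # ~1.0
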
diff --git a/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py b/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py
index 34c5de9f75dbfb5323c2cdad532cb0a42c09df22..1ec55b4edd7663c8323a9b197e938083c6ed2497 100644
--- a/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py
+++ b/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py
@@ -3,16 +3,17 @@ import numpy as np
 from torch.utils.data import Dataset
 import torch
 
+
 class BaseDataset(Dataset):
     def __init__(self, opt):
         self.crop_size = 512
         self.debug_mode = opt.debug_mode
-        self.data_path = opt.data_path # dataset path. e.g., ./data/
-        self.camera_name = opt.camera 
+        self.data_path = opt.data_path  # dataset path. e.g., ./data/
+        self.camera_name = opt.camera
         self.gamma = opt.gamma
 
     def norm_img(self, img, max_value):
-        img = img / float(max_value)        
+        img = img / float(max_value)
         return img
 
     def pack_raw(self, raw):
@@ -20,15 +21,20 @@ class BaseDataset(Dataset):
         im = np.expand_dims(raw, axis=2)
         H, W = raw.shape[0], raw.shape[1]
         # RGBG
-        out = np.concatenate((im[0:H:2, 0:W:2, :],
-                              im[0:H:2, 1:W:2, :],
-                              im[1:H:2, 1:W:2, :],
-                              im[1:H:2, 0:W:2, :]), axis=2)
+        out = np.concatenate(
+            (
+                im[0:H:2, 0:W:2, :],
+                im[0:H:2, 1:W:2, :],
+                im[1:H:2, 1:W:2, :],
+                im[1:H:2, 0:W:2, :],
+            ),
+            axis=2,
+        )
         return out
-    
+
     def np2tensor(self, array):
-        return torch.Tensor(array).permute(2,0,1)
-    
+        return torch.Tensor(array).permute(2, 0, 1)
+
     def center_crop(self, img, crop_size=None):
         H = img.shape[0]
         W = img.shape[1]
@@ -37,44 +43,43 @@ class BaseDataset(Dataset):
             th, tw = crop_size[0], crop_size[1]
         else:
             th, tw = self.crop_size, self.crop_size
-        x1_img = int(round((W - tw) / 2.))
-        y1_img = int(round((H - th) / 2.))
+        x1_img = int(round((W - tw) / 2.0))
+        y1_img = int(round((H - th) / 2.0))
         if img.ndim == 3:
-            input_patch = img[y1_img:y1_img + th, x1_img:x1_img + tw, :]
+            input_patch = img[y1_img : y1_img + th, x1_img : x1_img + tw, :]
         else:
-            input_patch = img[y1_img:y1_img + th, x1_img:x1_img + tw]
+            input_patch = img[y1_img : y1_img + th, x1_img : x1_img + tw]
 
         return input_patch
 
     def load(self, is_train=True):
         # ./data
-        # ./data/NIKON D700/RAW, ./data/NIKON D700/RGB 
-        # ./data/Canon EOS 5D/RAW,  ./data/Canon EOS 5D/RGB 
-        # ./data/NIKON D700_train.txt, ./data/NIKON D700_test.txt 
-        # ./data/NIKON D700_train.txt: a0016, ... 
-        input_RAWs_WBs = [] 
-        target_RGBs = []        
-        
-        data_path = self.data_path # ./data/ 
+        # ./data/NIKON D700/RAW, ./data/NIKON D700/RGB
+        # ./data/Canon EOS 5D/RAW,  ./data/Canon EOS 5D/RGB
+        # ./data/NIKON D700_train.txt, ./data/NIKON D700_test.txt
+        # ./data/NIKON D700_train.txt: a0016, ...
+        input_RAWs_WBs = []
+        target_RGBs = []
+
+        data_path = self.data_path  # ./data/
         if is_train:
             txt_path = data_path + self.camera_name + "_train.txt"
         else:
             txt_path = data_path + self.camera_name + "_test.txt"
 
         with open(txt_path, "r") as f_read:
-            # valid_camera_list = [os.path.basename(line.strip()).split('.')[0] for line in f_read.readlines()] 
-            valid_camera_list = [line.strip() for line in f_read.readlines()] 
-        
+            # valid_camera_list = [os.path.basename(line.strip()).split('.')[0] for line in f_read.readlines()]
+            valid_camera_list = [line.strip() for line in f_read.readlines()]
+
         if self.debug_mode:
             valid_camera_list = valid_camera_list[:10]
-        
-        for i,name in enumerate(valid_camera_list): 
-            full_name = data_path + self.camera_name 
-            input_RAWs_WBs.append(full_name + "/RAW/" + name + ".npz") 
-            target_RGBs.append(full_name + "/RGB/" + name + ".jpg") 
-            
-        return input_RAWs_WBs, target_RGBs 
 
+        for i, name in enumerate(valid_camera_list):
+            full_name = data_path + self.camera_name
+            input_RAWs_WBs.append(full_name + "/RAW/" + name + ".npz")
+            target_RGBs.append(full_name + "/RGB/" + name + ".jpg")
+
+        return input_RAWs_WBs, target_RGBs
 
     def __len__(self):
         return 0
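# Illustrative sketch (not part of the patch): what pack_raw above does to a
# single-channel Bayer mosaic, assuming an RGGB layout: every 2x2 cell is
# rearranged into one pixel of a half-resolution, 4-channel (R, G, B, G) image.
import numpy as np

raw = np.arange(16, dtype=np.float32).reshape(4, 4)   # toy 4x4 mosaic
im = raw[..., None]                                    # (H, W, 1)
H, W = raw.shape
packed = np.concatenate(
    (
        im[0:H:2, 0:W:2, :],   # R (even rows, even cols)
        im[0:H:2, 1:W:2, :],   # G (even rows, odd cols)
        im[1:H:2, 1:W:2, :],   # B (odd rows, odd cols)
        im[1:H:2, 0:W:2, :],   # G (odd rows, even cols)
    ),
    axis=2,
)
print(packed.shape)   # (2, 2, 4): half the spatial size, four channels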
diff --git a/third_party/DarkFeat/datasets/InvISP/model/loss.py b/third_party/DarkFeat/datasets/InvISP/model/loss.py
index abe8b599d5402c367bb7c84b7e370964d8273518..62a028ec26a8d7f8ef857e0582ac74800dac212e 100644
--- a/third_party/DarkFeat/datasets/InvISP/model/loss.py
+++ b/third_party/DarkFeat/datasets/InvISP/model/loss.py
@@ -2,14 +2,15 @@ import torch.nn.functional as F
 import torch
 
 
-def l1_loss(output, target_rgb, target_raw, weight=1.):
-    raw_loss = F.l1_loss(output['reconstruct_raw'], target_raw)
-    rgb_loss = F.l1_loss(output['reconstruct_rgb'], target_rgb)
+def l1_loss(output, target_rgb, target_raw, weight=1.0):
+    raw_loss = F.l1_loss(output["reconstruct_raw"], target_raw)
+    rgb_loss = F.l1_loss(output["reconstruct_rgb"], target_rgb)
     total_loss = raw_loss + weight * rgb_loss
     return total_loss, raw_loss, rgb_loss
 
-def l2_loss(output, target_rgb, target_raw, weight=1.):
-    raw_loss = F.mse_loss(output['reconstruct_raw'], target_raw)
-    rgb_loss = F.mse_loss(output['reconstruct_rgb'], target_rgb)
+
+def l2_loss(output, target_rgb, target_raw, weight=1.0):
+    raw_loss = F.mse_loss(output["reconstruct_raw"], target_raw)
+    rgb_loss = F.mse_loss(output["reconstruct_rgb"], target_rgb)
     total_loss = raw_loss + weight * rgb_loss
-    return total_loss, raw_loss, rgb_loss
\ No newline at end of file
+    return total_loss, raw_loss, rgb_loss
diff --git a/third_party/DarkFeat/datasets/InvISP/model/model.py b/third_party/DarkFeat/datasets/InvISP/model/model.py
index 9dd0e33cee8ebb26d621ece84622bd2611b33a60..52938290b7ca895a7c71173d40f90df5cd51b0d0 100644
--- a/third_party/DarkFeat/datasets/InvISP/model/model.py
+++ b/third_party/DarkFeat/datasets/InvISP/model/model.py
@@ -14,12 +14,12 @@ def initialize_weights(net_l, scale=1):
     for net in net_l:
         for m in net.modules():
             if isinstance(m, nn.Conv2d):
-                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                init.kaiming_normal_(m.weight, a=0, mode="fan_in")
                 m.weight.data *= scale  # for residual block
                 if m.bias is not None:
                     m.bias.data.zero_()
             elif isinstance(m, nn.Linear):
-                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                init.kaiming_normal_(m.weight, a=0, mode="fan_in")
                 m.weight.data *= scale
                 if m.bias is not None:
                     m.bias.data.zero_()
@@ -49,7 +49,7 @@ def initialize_weights_xavier(net_l, scale=1):
 
 
 class DenseBlock(nn.Module):
-    def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True):
+    def __init__(self, channel_in, channel_out, init="xavier", gc=32, bias=True):
         super(DenseBlock, self).__init__()
         self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
         self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
@@ -58,12 +58,14 @@ class DenseBlock(nn.Module):
         self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 
-        if init == 'xavier':
-            initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
+        if init == "xavier":
+            initialize_weights_xavier(
+                [self.conv1, self.conv2, self.conv3, self.conv4], 0.1
+            )
         else:
             initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
         initialize_weights(self.conv5, 0)
-    
+
     def forward(self, x):
         x1 = self.lrelu(self.conv1(x))
         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
@@ -73,10 +75,11 @@ class DenseBlock(nn.Module):
 
         return x5
 
-def subnet(net_structure, init='xavier'):
+
+def subnet(net_structure, init="xavier"):
     def constructor(channel_in, channel_out):
-        if net_structure == 'DBNet':
-            if init == 'xavier':
+        if net_structure == "DBNet":
+            if init == "xavier":
                 return DenseBlock(channel_in, channel_out, init)
             else:
                 return DenseBlock(channel_in, channel_out)
@@ -93,8 +96,8 @@ class InvBlock(nn.Module):
         # channel_num: 3
         # channel_split_num: 1
 
-        self.split_len1 = channel_split_num # 1
-        self.split_len2 = channel_num - channel_split_num # 2 
+        self.split_len1 = channel_split_num  # 1
+        self.split_len2 = channel_num - channel_split_num  # 2
 
         self.clamp = clamp
 
@@ -102,38 +105,51 @@ class InvBlock(nn.Module):
         self.G = subnet_constructor(self.split_len1, self.split_len2)
         self.H = subnet_constructor(self.split_len1, self.split_len2)
 
-        in_channels = 3        
+        in_channels = 3
         self.invconv = InvertibleConv1x1(in_channels, LU_decomposed=True)
         self.flow_permutation = lambda z, logdet, rev: self.invconv(z, logdet, rev)
-        
+
     def forward(self, x, rev=False):
-        if not rev:            
-            # invert1x1conv 
-            x, logdet = self.flow_permutation(x, logdet=0, rev=False) 
-            
-            # split to 1 channel and 2 channel. 
-            x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2)) 
-
-            y1 = x1 + self.F(x2) # 1 channel 
+        if not rev:
+            # invert1x1conv
+            x, logdet = self.flow_permutation(x, logdet=0, rev=False)
+
+            # split to 1 channel and 2 channel.
+            x1, x2 = (
+                x.narrow(1, 0, self.split_len1),
+                x.narrow(1, self.split_len1, self.split_len2),
+            )
+
+            y1 = x1 + self.F(x2)  # 1 channel
             self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1)
-            y2 = x2.mul(torch.exp(self.s)) + self.G(y1) # 2 channel 
+            y2 = x2.mul(torch.exp(self.s)) + self.G(y1)  # 2 channel
             out = torch.cat((y1, y2), 1)
         else:
-            # split. 
-            x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2)) 
+            # split.
+            x1, x2 = (
+                x.narrow(1, 0, self.split_len1),
+                x.narrow(1, self.split_len1, self.split_len2),
+            )
             self.s = self.clamp * (torch.sigmoid(self.H(x1)) * 2 - 1)
-            y2 = (x2 - self.G(x1)).div(torch.exp(self.s)) 
-            y1 = x1 - self.F(y2) 
+            y2 = (x2 - self.G(x1)).div(torch.exp(self.s))
+            y1 = x1 - self.F(y2)
 
-            x = torch.cat((y1, y2), 1)            
+            x = torch.cat((y1, y2), 1)
 
-            # inv permutation 
+            # inv permutation
             out, logdet = self.flow_permutation(x, logdet=0, rev=True)
 
         return out
 
+
 class InvISPNet(nn.Module):
-    def __init__(self, channel_in=3, channel_out=3, subnet_constructor=subnet('DBNet'), block_num=8):
+    def __init__(
+        self,
+        channel_in=3,
+        channel_out=3,
+        subnet_constructor=subnet("DBNet"),
+        block_num=8,
+    ):
         super(InvISPNet, self).__init__()
         operations = []
 
@@ -141,10 +157,12 @@ class InvISPNet(nn.Module):
         channel_num = channel_in
         channel_split_num = 1
 
-        for j in range(block_num): 
-            b = InvBlock(subnet_constructor, channel_num, channel_split_num) # one block is one flow step. 
+        for j in range(block_num):
+            b = InvBlock(
+                subnet_constructor, channel_num, channel_split_num
+            )  # one block is one flow step.
             operations.append(b)
-        
+
         self.operations = nn.ModuleList(operations)
 
         self.initialize()
@@ -153,27 +171,26 @@ class InvISPNet(nn.Module):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 init.xavier_normal_(m.weight)
-                m.weight.data *= 1.  # for residual block
+                m.weight.data *= 1.0  # for residual block
                 if m.bias is not None:
-                    m.bias.data.zero_() 
+                    m.bias.data.zero_()
             elif isinstance(m, nn.Linear):
                 init.xavier_normal_(m.weight)
-                m.weight.data *= 1.
+                m.weight.data *= 1.0
                 if m.bias is not None:
                     m.bias.data.zero_()
             elif isinstance(m, nn.BatchNorm2d):
                 init.constant_(m.weight, 1)
                 init.constant_(m.bias.data, 0.0)
-    
+
     def forward(self, x, rev=False):
-        out = x # x: [N,3,H,W] 
-        
-        if not rev: 
+        out = x  # x: [N,3,H,W]
+
+        if not rev:
             for op in self.operations:
                 out = op.forward(out, rev)
         else:
             for op in reversed(self.operations):
                 out = op.forward(out, rev)
-        
-        return out
 
+        return out
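# Illustrative sketch (not part of the patch): the coupling used in InvBlock
# above is exactly invertible for arbitrary sub-networks F, G, H, because the
# reverse branch recomputes the same scale s from y1 and undoes each step in
# order. A toy check with random linear layers standing in for the DenseBlocks:
import torch

torch.manual_seed(0)
F_net = torch.nn.Linear(2, 1)   # maps the x2 split to the x1 split
G_net = torch.nn.Linear(1, 2)   # maps the x1 split to the x2 split
H_net = torch.nn.Linear(1, 2)
clamp = 1.0
x1, x2 = torch.randn(4, 1), torch.randn(4, 2)

# forward
y1 = x1 + F_net(x2)
s = clamp * (torch.sigmoid(H_net(y1)) * 2 - 1)
y2 = x2 * torch.exp(s) + G_net(y1)

# reverse
s_rev = clamp * (torch.sigmoid(H_net(y1)) * 2 - 1)
x2_rec = (y2 - G_net(y1)) / torch.exp(s_rev)
x1_rec = y1 - F_net(x2_rec)
assert torch.allclose(x1, x1_rec, atol=1e-5) and torch.allclose(x2, x2_rec, atol=1e-5)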
diff --git a/third_party/DarkFeat/datasets/InvISP/model/modules.py b/third_party/DarkFeat/datasets/InvISP/model/modules.py
index 88244c0b211860d97be78ba4f60f4743228171a7..b32c312d13284bc5a4837df756ed58c505b60768 100644
--- a/third_party/DarkFeat/datasets/InvISP/model/modules.py
+++ b/third_party/DarkFeat/datasets/InvISP/model/modules.py
@@ -47,7 +47,7 @@ def unsqueeze2d(input, factor):
     if factor == 1:
         return input
 
-    factor2 = factor ** 2
+    factor2 = factor**2
 
     B, C, H, W = input.size()
 
diff --git a/third_party/DarkFeat/datasets/InvISP/model/utils.py b/third_party/DarkFeat/datasets/InvISP/model/utils.py
index d1bef31afd7d61d4c942ffd895c818b90571b4b7..a1ab33bf1ba26ee027e1c051f63b0a29fefe6706 100644
--- a/third_party/DarkFeat/datasets/InvISP/model/utils.py
+++ b/third_party/DarkFeat/datasets/InvISP/model/utils.py
@@ -27,7 +27,7 @@ def uniform_binning_correction(x, n_bits=8):
         objective: Equivalent to -q(x)*log(q(x)).
     """
     b, c, h, w = x.size()
-    n_bins = 2 ** n_bits
+    n_bins = 2**n_bits
     chw = c * h * w
     x += torch.zeros_like(x).uniform_(0, 1.0 / n_bins)
 
@@ -42,11 +42,7 @@ def split_feature(tensor, type="split"):
     C = tensor.size(1)
     if type == "split":
         # return tensor[:, : C // 2, ...], tensor[:, C // 2 :, ...]
-        return tensor[:, :1, ...], tensor[:,1:, ...]
+        return tensor[:, :1, ...], tensor[:, 1:, ...]
     elif type == "cross":
         # return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
-        return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 
-
-
-
-
+        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
diff --git a/third_party/DarkFeat/datasets/InvISP/test_raw.py b/third_party/DarkFeat/datasets/InvISP/test_raw.py
index 37610f8268e4586864e0275236c5bb1932f894df..8c3c30faf6662b04fe34f63de0d729ebcec86517 100644
--- a/third_party/DarkFeat/datasets/InvISP/test_raw.py
+++ b/third_party/DarkFeat/datasets/InvISP/test_raw.py
@@ -18,101 +18,145 @@ from utils.JPEG import DiffJPEG
 from utils.commons import denorm, preprocess_test_patch
 
 
-os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
-os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()]))
+os.system("nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp")
+os.environ["CUDA_VISIBLE_DEVICES"] = str(
+    np.argmax([int(x.split()[2]) for x in open("tmp", "r").readlines()])
+)
 # os.environ['CUDA_VISIBLE_DEVICES'] = '7'
-os.system('rm tmp')
+os.system("rm tmp")
 
 DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda()
 
 parser = get_arguments()
-parser.add_argument("--ckpt", type=str, help="Checkpoint path.") 
-parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save checkpoint. ")
-parser.add_argument("--split_to_patch", dest='split_to_patch', action='store_true', help="Test on patch. ")
+parser.add_argument("--ckpt", type=str, help="Checkpoint path.")
+parser.add_argument(
+    "--out_path", type=str, default="./exps/", help="Path to save checkpoint. "
+)
+parser.add_argument(
+    "--split_to_patch",
+    dest="split_to_patch",
+    action="store_true",
+    help="Test on patch. ",
+)
 args = parser.parse_args()
 print("Parsed arguments: {}".format(args))
 
 
 ckpt_name = args.ckpt.split("/")[-1].split(".")[0]
 if args.split_to_patch:
-    os.makedirs(args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name), exist_ok=True)
-    out_path = args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name)
+    os.makedirs(
+        args.out_path + "%s/results_metric_%s/" % (args.task, ckpt_name), exist_ok=True
+    )
+    out_path = args.out_path + "%s/results_metric_%s/" % (args.task, ckpt_name)
 else:
-    os.makedirs(args.out_path+"%s/results_%s/"%(args.task, ckpt_name), exist_ok=True)
-    out_path = args.out_path+"%s/results_%s/"%(args.task, ckpt_name)
+    os.makedirs(
+        args.out_path + "%s/results_%s/" % (args.task, ckpt_name), exist_ok=True
+    )
+    out_path = args.out_path + "%s/results_%s/" % (args.task, ckpt_name)
 
 
 def main(args):
     # ======================================define the model============================================
     net = InvISPNet(channel_in=3, channel_out=3, block_num=8)
     device = torch.device("cuda:0")
-    
+
     net.to(device)
     net.eval()
     # load the pretrained weight if there exists one
     if os.path.isfile(args.ckpt):
         net.load_state_dict(torch.load(args.ckpt), strict=False)
         print("[INFO] Loaded checkpoint: {}".format(args.ckpt))
-    
-    print("[INFO] Start data load and preprocessing") 
-    RAWDataset = FiveKDatasetTest(opt=args) 
-    dataloader = DataLoader(RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True) 
-    
-    input_RGBs = sorted(glob(out_path+"pred*jpg"))
-    input_RGBs_names = [path.split("/")[-1].split(".")[0][5:] for path in input_RGBs]    
-
-    print("[INFO] Start test...") 
+
+    print("[INFO] Start data load and preprocessing")
+    RAWDataset = FiveKDatasetTest(opt=args)
+    dataloader = DataLoader(
+        RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True
+    )
+
+    input_RGBs = sorted(glob(out_path + "pred*jpg"))
+    input_RGBs_names = [path.split("/")[-1].split(".")[0][5:] for path in input_RGBs]
+
+    print("[INFO] Start test...")
     for i_batch, sample_batched in enumerate(tqdm(dataloader)):
         step_time = time.time()
-        
-        input, target_rgb, target_raw = sample_batched['input_raw'].to(device), sample_batched['target_rgb'].to(device), \
-                            sample_batched['target_raw'].to(device)
-        file_name = sample_batched['file_name'][0]
+
+        input, target_rgb, target_raw = (
+            sample_batched["input_raw"].to(device),
+            sample_batched["target_rgb"].to(device),
+            sample_batched["target_raw"].to(device),
+        )
+        file_name = sample_batched["file_name"][0]
 
         if args.split_to_patch:
-            input_list, target_rgb_list, target_raw_list = preprocess_test_patch(input, target_rgb, target_raw)
+            input_list, target_rgb_list, target_raw_list = preprocess_test_patch(
+                input, target_rgb, target_raw
+            )
         else:
-            # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution 
-            input_list, target_rgb_list, target_raw_list = [input[:,:,::2,::2]], [target_rgb[:,:,::2,::2]], [target_raw[:,:,::2,::2]]
-        
+            # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution
+            input_list, target_rgb_list, target_raw_list = (
+                [input[:, :, ::2, ::2]],
+                [target_rgb[:, :, ::2, ::2]],
+                [target_raw[:, :, ::2, ::2]],
+            )
+
         for i_patch in range(len(input_list)):
-            file_name_patch = file_name + "_%05d"%i_patch
+            file_name_patch = file_name + "_%05d" % i_patch
             idx = input_RGBs_names.index(file_name_patch)
             input_RGB_path = input_RGBs[idx]
-            input_RGB = torch.from_numpy(np.array(PILImage.open(input_RGB_path))/255.0).unsqueeze(0).permute(0,3,1,2).float().to(device)
-            
-            target_raw_patch = target_raw_list[i_patch] 
-            
+            input_RGB = (
+                torch.from_numpy(np.array(PILImage.open(input_RGB_path)) / 255.0)
+                .unsqueeze(0)
+                .permute(0, 3, 1, 2)
+                .float()
+                .to(device)
+            )
+
+            target_raw_patch = target_raw_list[i_patch]
+
             with torch.no_grad():
                 reconstruct_raw = net(input_RGB, rev=True)
-            
-            pred_raw = reconstruct_raw.detach().permute(0,2,3,1)
+
+            pred_raw = reconstruct_raw.detach().permute(0, 2, 3, 1)
             pred_raw = torch.clamp(pred_raw, 0, 1)
-            
-            target_raw_patch = target_raw_patch.permute(0,2,3,1)
+
+            target_raw_patch = target_raw_patch.permute(0, 2, 3, 1)
             pred_raw = denorm(pred_raw, 255)
             target_raw_patch = denorm(target_raw_patch, 255)
 
             pred_raw = pred_raw.cpu().numpy()
             target_raw_patch = target_raw_patch.cpu().numpy().astype(np.float32)
 
-            raw_pred = PILImage.fromarray(np.uint8(pred_raw[0,:,:,0]))
-            raw_tar_pred = PILImage.fromarray(np.hstack((np.uint8(target_raw_patch[0,:,:,0]), np.uint8(pred_raw[0,:,:,0]))))
-            
-            raw_tar = PILImage.fromarray(np.uint8(target_raw_patch[0,:,:,0]))
+            raw_pred = PILImage.fromarray(np.uint8(pred_raw[0, :, :, 0]))
+            raw_tar_pred = PILImage.fromarray(
+                np.hstack(
+                    (
+                        np.uint8(target_raw_patch[0, :, :, 0]),
+                        np.uint8(pred_raw[0, :, :, 0]),
+                    )
+                )
+            )
 
-            raw_pred.save(out_path+"raw_pred_%s_%05d.jpg"%(file_name, i_patch))
-            raw_tar.save(out_path+"raw_tar_%s_%05d.jpg"%(file_name, i_patch))
-            raw_tar_pred.save(out_path+"raw_gt_pred_%s_%05d.jpg"%(file_name, i_patch))
-            
-            np.save(out_path+"raw_pred_%s_%05d.npy"%(file_name, i_patch), pred_raw[0,:,:,:]/255.0)
-            np.save(out_path+"raw_tar_%s_%05d.npy"%(file_name, i_patch), target_raw_patch[0,:,:,:]/255.0)
+            raw_tar = PILImage.fromarray(np.uint8(target_raw_patch[0, :, :, 0]))
 
-            del reconstruct_raw            
+            raw_pred.save(out_path + "raw_pred_%s_%05d.jpg" % (file_name, i_patch))
+            raw_tar.save(out_path + "raw_tar_%s_%05d.jpg" % (file_name, i_patch))
+            raw_tar_pred.save(
+                out_path + "raw_gt_pred_%s_%05d.jpg" % (file_name, i_patch)
+            )
 
+            np.save(
+                out_path + "raw_pred_%s_%05d.npy" % (file_name, i_patch),
+                pred_raw[0, :, :, :] / 255.0,
+            )
+            np.save(
+                out_path + "raw_tar_%s_%05d.npy" % (file_name, i_patch),
+                target_raw_patch[0, :, :, :] / 255.0,
+            )
 
-if __name__ == '__main__':
+            del reconstruct_raw
+
+
+if __name__ == "__main__":
 
     torch.set_num_threads(4)
     main(args)
-
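# Illustrative sketch (not part of the patch): the nvidia-smi snippet at the
# top of this script picks the GPU with the most free memory before CUDA is
# initialized. The same selection without the temporary file, via subprocess
# (assumes nvidia-smi is on PATH):
import os
import subprocess


def pick_freest_gpu():
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    free = [int(v) for v in out.split()]
    return int(max(range(len(free)), key=free.__getitem__))


os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(pick_freest_gpu()))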
diff --git a/third_party/DarkFeat/datasets/InvISP/test_rgb.py b/third_party/DarkFeat/datasets/InvISP/test_rgb.py
index d1e054b899d9142609e3f90f4a12d367a45aeac0..5c1c9f1839acd58e71b4dc244b0ce3132d09b8c7 100644
--- a/third_party/DarkFeat/datasets/InvISP/test_rgb.py
+++ b/third_party/DarkFeat/datasets/InvISP/test_rgb.py
@@ -16,90 +16,133 @@ from utils.JPEG import DiffJPEG
 from utils.commons import denorm, preprocess_test_patch
 from tqdm import tqdm
 
-os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
-os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()]))
+os.system("nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp")
+os.environ["CUDA_VISIBLE_DEVICES"] = str(
+    np.argmax([int(x.split()[2]) for x in open("tmp", "r").readlines()])
+)
 # os.environ['CUDA_VISIBLE_DEVICES'] = '7'
-os.system('rm tmp')
+os.system("rm tmp")
 
 DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda()
 
 parser = get_arguments()
-parser.add_argument("--ckpt", type=str, help="Checkpoint path.") 
-parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save results. ")
-parser.add_argument("--split_to_patch", dest='split_to_patch', action='store_true', help="Test on patch. ")
+parser.add_argument("--ckpt", type=str, help="Checkpoint path.")
+parser.add_argument(
+    "--out_path", type=str, default="./exps/", help="Path to save results. "
+)
+parser.add_argument(
+    "--split_to_patch",
+    dest="split_to_patch",
+    action="store_true",
+    help="Test on patch. ",
+)
 args = parser.parse_args()
 print("Parsed arguments: {}".format(args))
 
 
 ckpt_name = args.ckpt.split("/")[-1].split(".")[0]
 if args.split_to_patch:
-    os.makedirs(args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name), exist_ok=True)
-    out_path = args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name)
+    os.makedirs(
+        args.out_path + "%s/results_metric_%s/" % (args.task, ckpt_name), exist_ok=True
+    )
+    out_path = args.out_path + "%s/results_metric_%s/" % (args.task, ckpt_name)
 else:
-    os.makedirs(args.out_path+"%s/results_%s/"%(args.task, ckpt_name), exist_ok=True)
-    out_path = args.out_path+"%s/results_%s/"%(args.task, ckpt_name)
+    os.makedirs(
+        args.out_path + "%s/results_%s/" % (args.task, ckpt_name), exist_ok=True
+    )
+    out_path = args.out_path + "%s/results_%s/" % (args.task, ckpt_name)
 
 
 def main(args):
     # ======================================define the model============================================
     net = InvISPNet(channel_in=3, channel_out=3, block_num=8)
     device = torch.device("cuda:0")
-    
+
     net.to(device)
     net.eval()
     # load the pretrained weight if there exists one
     if os.path.isfile(args.ckpt):
         net.load_state_dict(torch.load(args.ckpt), strict=False)
         print("[INFO] Loaded checkpoint: {}".format(args.ckpt))
-    
-    print("[INFO] Start data load and preprocessing") 
-    RAWDataset = FiveKDatasetTest(opt=args) 
-    dataloader = DataLoader(RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True) 
-    
-    print("[INFO] Start test...") 
+
+    print("[INFO] Start data load and preprocessing")
+    RAWDataset = FiveKDatasetTest(opt=args)
+    dataloader = DataLoader(
+        RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True
+    )
+
+    print("[INFO] Start test...")
     for i_batch, sample_batched in enumerate(tqdm(dataloader)):
-        step_time = time.time() 
-        
-        input, target_rgb, target_raw = sample_batched['input_raw'].to(device), sample_batched['target_rgb'].to(device), \
-                            sample_batched['target_raw'].to(device)
-        file_name = sample_batched['file_name'][0]
-        
+        step_time = time.time()
+
+        input, target_rgb, target_raw = (
+            sample_batched["input_raw"].to(device),
+            sample_batched["target_rgb"].to(device),
+            sample_batched["target_raw"].to(device),
+        )
+        file_name = sample_batched["file_name"][0]
+
         if args.split_to_patch:
-            input_list, target_rgb_list, target_raw_list = preprocess_test_patch(input, target_rgb, target_raw)
+            input_list, target_rgb_list, target_raw_list = preprocess_test_patch(
+                input, target_rgb, target_raw
+            )
         else:
-            # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution 
-            input_list, target_rgb_list, target_raw_list = [input[:,:,::2,::2]], [target_rgb[:,:,::2,::2]], [target_raw[:,:,::2,::2]]
-        
+            # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution
+            input_list, target_rgb_list, target_raw_list = (
+                [input[:, :, ::2, ::2]],
+                [target_rgb[:, :, ::2, ::2]],
+                [target_raw[:, :, ::2, ::2]],
+            )
+
         for i_patch in range(len(input_list)):
             input_patch = input_list[i_patch]
             target_rgb_patch = target_rgb_list[i_patch]
-            target_raw_patch = target_raw_list[i_patch] 
-            
+            target_raw_patch = target_raw_list[i_patch]
+
             with torch.no_grad():
                 reconstruct_rgb = net(input_patch)
                 reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1)
-            
-            pred_rgb = reconstruct_rgb.detach().permute(0,2,3,1)
-            target_rgb_patch = target_rgb_patch.permute(0,2,3,1)
-            
+
+            pred_rgb = reconstruct_rgb.detach().permute(0, 2, 3, 1)
+            target_rgb_patch = target_rgb_patch.permute(0, 2, 3, 1)
+
             pred_rgb = denorm(pred_rgb, 255)
             target_rgb_patch = denorm(target_rgb_patch, 255)
             pred_rgb = pred_rgb.cpu().numpy()
             target_rgb_patch = target_rgb_patch.cpu().numpy().astype(np.float32)
-            
+
             # print(type(pred_rgb))
-            pred = PILImage.fromarray(np.uint8(pred_rgb[0,:,:,:]))
-            tar_pred = PILImage.fromarray(np.hstack((np.uint8(target_rgb_patch[0,:,:,:]), np.uint8(pred_rgb[0,:,:,:]))))
-            
-            tar = PILImage.fromarray(np.uint8(target_rgb_patch[0,:,:,:]))
-            
-            pred.save(out_path+"pred_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1)
-            tar.save(out_path+"tar_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1)
-            tar_pred.save(out_path+"gt_pred_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1)
-            
+            pred = PILImage.fromarray(np.uint8(pred_rgb[0, :, :, :]))
+            tar_pred = PILImage.fromarray(
+                np.hstack(
+                    (
+                        np.uint8(target_rgb_patch[0, :, :, :]),
+                        np.uint8(pred_rgb[0, :, :, :]),
+                    )
+                )
+            )
+
+            tar = PILImage.fromarray(np.uint8(target_rgb_patch[0, :, :, :]))
+
+            pred.save(
+                out_path + "pred_%s_%05d.jpg" % (file_name, i_patch),
+                quality=90,
+                subsampling=1,
+            )
+            tar.save(
+                out_path + "tar_%s_%05d.jpg" % (file_name, i_patch),
+                quality=90,
+                subsampling=1,
+            )
+            tar_pred.save(
+                out_path + "gt_pred_%s_%05d.jpg" % (file_name, i_patch),
+                quality=90,
+                subsampling=1,
+            )
+
             del reconstruct_rgb
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     torch.set_num_threads(4)
     main(args)
-
diff --git a/third_party/DarkFeat/datasets/InvISP/train.py b/third_party/DarkFeat/datasets/InvISP/train.py
index 16186cb38d825ac1299e5c4164799d35bfa79907..4022c4a8f523b97ffeb928263b14a79bd8b54a20 100644
--- a/third_party/DarkFeat/datasets/InvISP/train.py
+++ b/third_party/DarkFeat/datasets/InvISP/train.py
@@ -14,85 +14,130 @@ from config.config import get_arguments
 
 from utils.JPEG import DiffJPEG
 
-os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
-os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()]))
+os.system("nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp")
+os.environ["CUDA_VISIBLE_DEVICES"] = str(
+    np.argmax([int(x.split()[2]) for x in open("tmp", "r").readlines()])
+)
 # os.environ['CUDA_VISIBLE_DEVICES'] = "1"
-os.system('rm tmp')
+os.system("rm tmp")
 
 DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda()
 
 parser = get_arguments()
-parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save checkpoint. ")
-parser.add_argument("--resume", dest='resume', action='store_true',  help="Resume training. ")
-parser.add_argument("--loss", type=str, default="L1", choices=["L1", "L2"], help="Choose which loss function to use. ")
+parser.add_argument(
+    "--out_path", type=str, default="./exps/", help="Path to save checkpoint. "
+)
+parser.add_argument(
+    "--resume", dest="resume", action="store_true", help="Resume training. "
+)
+parser.add_argument(
+    "--loss",
+    type=str,
+    default="L1",
+    choices=["L1", "L2"],
+    help="Choose which loss function to use. ",
+)
 parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate")
-parser.add_argument("--aug", dest='aug', action='store_true', help="Use data augmentation.")
+parser.add_argument(
+    "--aug", dest="aug", action="store_true", help="Use data augmentation."
+)
 args = parser.parse_args()
 print("Parsed arguments: {}".format(args))
 
 os.makedirs(args.out_path, exist_ok=True)
-os.makedirs(args.out_path+"%s"%args.task, exist_ok=True)
-os.makedirs(args.out_path+"%s/checkpoint"%args.task, exist_ok=True)
+os.makedirs(args.out_path + "%s" % args.task, exist_ok=True)
+os.makedirs(args.out_path + "%s/checkpoint" % args.task, exist_ok=True)
 
-with open(args.out_path+"%s/commandline_args.yaml"%args.task , 'w') as f:
+with open(args.out_path + "%s/commandline_args.yaml" % args.task, "w") as f:
     json.dump(args.__dict__, f, indent=2)
 
+
 def main(args):
     # ======================================define the model======================================
     net = InvISPNet(channel_in=3, channel_out=3, block_num=8)
     net.cuda()
     # load the pretrained weight if there exists one
     if args.resume:
-        net.load_state_dict(torch.load(args.out_path+"%s/checkpoint/latest.pth"%args.task))
-        print("[INFO] loaded " + args.out_path+"%s/checkpoint/latest.pth"%args.task)
+        net.load_state_dict(
+            torch.load(args.out_path + "%s/checkpoint/latest.pth" % args.task)
+        )
+        print("[INFO] loaded " + args.out_path + "%s/checkpoint/latest.pth" % args.task)
 
     optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
-    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.5)    
-    
+    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.5)
+
     print("[INFO] Start data loading and preprocessing")
-    RAWDataset = FiveKDatasetTrain(opt=args)        
-    dataloader = DataLoader(RAWDataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
+    RAWDataset = FiveKDatasetTrain(opt=args)
+    dataloader = DataLoader(
+        RAWDataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=0,
+        drop_last=True,
+    )
 
     print("[INFO] Start to train")
     step = 0
     for epoch in range(0, 300):
-        epoch_time = time.time()             
-        
+        epoch_time = time.time()
+
         for i_batch, sample_batched in enumerate(dataloader):
-            step_time = time.time() 
+            step_time = time.time()
 
-            input, target_rgb, target_raw = sample_batched['input_raw'].cuda(), sample_batched['target_rgb'].cuda(), \
-                                        sample_batched['target_raw'].cuda()
-            
-            reconstruct_rgb = net(input) 
+            input, target_rgb, target_raw = (
+                sample_batched["input_raw"].cuda(),
+                sample_batched["target_rgb"].cuda(),
+                sample_batched["target_raw"].cuda(),
+            )
+
+            reconstruct_rgb = net(input)
             reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1)
             rgb_loss = F.l1_loss(reconstruct_rgb, target_rgb)
             reconstruct_rgb = DiffJPEG(reconstruct_rgb)
             reconstruct_raw = net(reconstruct_rgb, rev=True)
             raw_loss = F.l1_loss(reconstruct_raw, target_raw)
-            
+
             loss = args.rgb_weight * rgb_loss + raw_loss
-            
+
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-            
-            print("task: %s Epoch: %d Step: %d || loss: %.5f raw_loss: %.5f rgb_loss: %.5f || lr: %f time: %f"%(
-                args.task, epoch, step, loss.detach().cpu().numpy(), raw_loss.detach().cpu().numpy(), 
-                rgb_loss.detach().cpu().numpy(), optimizer.param_groups[0]['lr'], time.time()-step_time
-            )) 
-            step += 1 
-        
-        torch.save(net.state_dict(), args.out_path+"%s/checkpoint/latest.pth"%args.task)
-        if (epoch+1) % 10 == 0:
+
+            print(
+                "task: %s Epoch: %d Step: %d || loss: %.5f raw_loss: %.5f rgb_loss: %.5f || lr: %f time: %f"
+                % (
+                    args.task,
+                    epoch,
+                    step,
+                    loss.detach().cpu().numpy(),
+                    raw_loss.detach().cpu().numpy(),
+                    rgb_loss.detach().cpu().numpy(),
+                    optimizer.param_groups[0]["lr"],
+                    time.time() - step_time,
+                )
+            )
+            step += 1
+
+        torch.save(
+            net.state_dict(), args.out_path + "%s/checkpoint/latest.pth" % args.task
+        )
+        if (epoch + 1) % 10 == 0:
             # os.makedirs(args.out_path+"%s/checkpoint/%04d"%(args.task,epoch), exist_ok=True)
-            torch.save(net.state_dict(), args.out_path+"%s/checkpoint/%04d.pth"%(args.task,epoch))
-            print("[INFO] Successfully saved "+args.out_path+"%s/checkpoint/%04d.pth"%(args.task,epoch))
-        scheduler.step()   
-        
-        print("[INFO] Epoch time: ", time.time()-epoch_time, "task: ", args.task)    
+            torch.save(
+                net.state_dict(),
+                args.out_path + "%s/checkpoint/%04d.pth" % (args.task, epoch),
+            )
+            print(
+                "[INFO] Successfully saved "
+                + args.out_path
+                + "%s/checkpoint/%04d.pth" % (args.task, epoch)
+            )
+        scheduler.step()
+
+        print("[INFO] Epoch time: ", time.time() - epoch_time, "task: ", args.task)
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 
     torch.set_num_threads(4)
     main(args)
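# Illustrative sketch (not part of the patch): one training step as wired up in
# main() above. The forward pass maps RAW to RGB (L1 RGB loss), the prediction
# is passed through differentiable JPEG, and the inverse pass maps it back to
# RAW (L1 RAW loss); the two terms are combined as rgb_weight * rgb_loss +
# raw_loss. `net` and `jpeg` stand in for InvISPNet and DiffJPEG.
import torch
import torch.nn.functional as F


def train_step(net, jpeg, optimizer, raw, target_rgb, target_raw, rgb_weight=1.0):
    pred_rgb = torch.clamp(net(raw), 0, 1)        # forward: RAW -> RGB
    rgb_loss = F.l1_loss(pred_rgb, target_rgb)
    pred_raw = net(jpeg(pred_rgb), rev=True)      # inverse: JPEG'd RGB -> RAW
    raw_loss = F.l1_loss(pred_raw, target_raw)
    loss = rgb_weight * rgb_loss + raw_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), rgb_loss.item(), raw_loss.item()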
diff --git a/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py b/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py
index 8997ee98a41668b4737a9b2acc2341032f173bd3..7cdd7fa91ee424250f241ecc7de63d868795aaa7 100644
--- a/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py
+++ b/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py
@@ -1,5 +1,3 @@
-
-
 import torch
 import torch.nn as nn
 
@@ -8,16 +6,16 @@ from .compression import compress_jpeg
 from .decompression import decompress_jpeg
 
 
-class DiffJPEG(nn.Module):    
+class DiffJPEG(nn.Module):
     def __init__(self, differentiable=True, quality=75):
-        ''' Initialize the DiffJPEG layer
+        """Initialize the DiffJPEG layer
         Inputs:
             height(int): Original image height
             width(int): Original image width
             differentiable(bool): If true uses custom differentiable
                 rounding function, if false uses standrard torch.round
-            quality(float): Quality factor for jpeg compression scheme. 
-        '''
+            quality(float): Quality factor for jpeg compression scheme.
+        """
         super(DiffJPEG, self).__init__()
         if differentiable:
             rounding = diff_round
@@ -31,13 +29,10 @@ class DiffJPEG(nn.Module):
         self.decompress = decompress_jpeg(rounding=rounding, factor=factor)
 
     def forward(self, x):
-        '''
-        '''
+        """ """
         org_height = x.shape[2]
         org_width = x.shape[3]
         y, cb, cr = self.compress(x)
 
         recovered = self.decompress(y, cb, cr, org_height, org_width)
         return recovered
-
-
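# Illustrative usage sketch (not part of the patch): DiffJPEG is an nn.Module,
# so the compress/decompress round trip stays differentiable end to end. The
# train/test scripts instantiate it with quality=90 and move it to the GPU;
# CPU is used here only for illustration (spatial sizes divisible by 16).
import torch
from utils.JPEG import DiffJPEG

jpeg = DiffJPEG(differentiable=True, quality=90)
x = torch.rand(1, 3, 64, 64, requires_grad=True)   # values in [0, 1]
y = jpeg(x)                                         # JPEG-degraded image, same shape
y.mean().backward()                                 # gradients flow back to x
print(y.shape, x.grad is not None)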
diff --git a/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py b/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py
index e2ebd9bdc184e869ade58eea1c6763baa1d9fc91..4ef225505d21728f63d34cec55e5335a50130e17 100644
--- a/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py
+++ b/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py
@@ -1,58 +1,65 @@
 # Standard libraries
 import numpy as np
+
 # PyTorch
 import torch
 import torch.nn as nn
 import math
 
 y_table = np.array(
-    [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60,
-                                        55], [14, 13, 16, 24, 40, 57, 69, 56],
-     [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103,
-                                        77], [24, 35, 55, 64, 81, 104, 113, 92],
-     [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
-    dtype=np.float32).T
+    [
+        [16, 11, 10, 16, 24, 40, 51, 61],
+        [12, 12, 14, 19, 26, 58, 60, 55],
+        [14, 13, 16, 24, 40, 57, 69, 56],
+        [14, 17, 22, 29, 51, 87, 80, 62],
+        [18, 22, 37, 56, 68, 109, 103, 77],
+        [24, 35, 55, 64, 81, 104, 113, 92],
+        [49, 64, 78, 87, 103, 121, 120, 101],
+        [72, 92, 95, 98, 112, 100, 103, 99],
+    ],
+    dtype=np.float32,
+).T
 
 y_table = nn.Parameter(torch.from_numpy(y_table))
 #
 c_table = np.empty((8, 8), dtype=np.float32)
 c_table.fill(99)
-c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66],
-                            [24, 26, 56, 99], [47, 66, 99, 99]]).T
+c_table[:4, :4] = np.array(
+    [[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]
+).T
 c_table = nn.Parameter(torch.from_numpy(c_table))
 
 
 def diff_round_back(x):
-    """ Differentiable rounding function
+    """Differentiable rounding function
     Input:
         x(tensor)
     Output:
         x(tensor)
     """
-    return torch.round(x) + (x - torch.round(x))**3
-
+    return torch.round(x) + (x - torch.round(x)) ** 3
 
 
 def diff_round(input_tensor):
     test = 0
     for n in range(1, 10):
-        test += math.pow(-1, n+1) / n * torch.sin(2 * math.pi * n * input_tensor)
+        test += math.pow(-1, n + 1) / n * torch.sin(2 * math.pi * n * input_tensor)
     final_tensor = input_tensor - 1 / math.pi * test
     return final_tensor
 
 
 class Quant(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, input):
         input = torch.clamp(input, 0, 1)
-        output = (input * 255.).round() / 255.
+        output = (input * 255.0).round() / 255.0
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
         return grad_output
 
+
 class Quantization(nn.Module):
     def __init__(self):
         super(Quantization, self).__init__()
@@ -62,14 +69,14 @@ class Quantization(nn.Module):
 
 
 def quality_to_factor(quality):
-    """ Calculate factor corresponding to quality
+    """Calculate factor corresponding to quality
     Input:
         quality(float): Quality for jpeg compression
     Output:
         factor(float): Compression factor
     """
     if quality < 50:
-        quality = 5000. / quality
+        quality = 5000.0 / quality
     else:
-        quality = 200. - quality*2
-    return quality / 100.
\ No newline at end of file
+        quality = 200.0 - quality * 2
+    return quality / 100.0
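# Illustrative sketch (not part of the patch): diff_round above approximates
# hard rounding with a truncated Fourier series of x - round(x), so it is exact
# at integer inputs and smooth in between, which keeps JPEG quantization
# differentiable; quality_to_factor maps a JPEG quality setting to the scale
# applied to the quantization tables. Assumes a flat import next to JPEG_utils.py.
import torch
from JPEG_utils import diff_round, quality_to_factor

x = torch.tensor([0.2, 0.8, 1.0, 2.3])
print(torch.round(x))    # tensor([0., 1., 1., 2.])
print(diff_round(x))     # close to the values above, but differentiable

print(quality_to_factor(90))   # (200 - 2 * 90) / 100 = 0.2
print(quality_to_factor(25))   # 5000 / 25 / 100 = 2.0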
diff --git a/third_party/DarkFeat/datasets/InvISP/utils/commons.py b/third_party/DarkFeat/datasets/InvISP/utils/commons.py
index e594e0597bac601edc2015d9cae670799f981495..ea546a3fa517304e97652f00c5cc65a8a2b512d6 100644
--- a/third_party/DarkFeat/datasets/InvISP/utils/commons.py
+++ b/third_party/DarkFeat/datasets/InvISP/utils/commons.py
@@ -5,6 +5,7 @@ def denorm(img, max_value):
     img = img * float(max_value)
     return img
 
+
 def preprocess_test_patch(input_image, target_image, gt_image):
     input_patch_list = []
     target_patch_list = []
@@ -13,11 +14,26 @@ def preprocess_test_patch(input_image, target_image, gt_image):
     W = input_image.shape[3]
     for i in range(3):
         for j in range(3):
-            input_patch = input_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)]
-            target_patch = target_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)]
-            gt_patch = gt_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)]
+            input_patch = input_image[
+                :,
+                :,
+                int(i * H / 3) : int((i + 1) * H / 3),
+                int(j * W / 3) : int((j + 1) * W / 3),
+            ]
+            target_patch = target_image[
+                :,
+                :,
+                int(i * H / 3) : int((i + 1) * H / 3),
+                int(j * W / 3) : int((j + 1) * W / 3),
+            ]
+            gt_patch = gt_image[
+                :,
+                :,
+                int(i * H / 3) : int((i + 1) * H / 3),
+                int(j * W / 3) : int((j + 1) * W / 3),
+            ]
             input_patch_list.append(input_patch)
             target_patch_list.append(target_patch)
             gt_patch_list.append(gt_patch)
-            
+
     return input_patch_list, target_patch_list, gt_patch_list
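# Illustrative usage sketch (not part of the patch): preprocess_test_patch
# above cuts each NCHW tensor into a 3x3 grid and returns three lists of nine
# patches in row-major order, which is how the test scripts evaluate on patches.
import torch
from utils.commons import preprocess_test_patch

img = torch.rand(1, 3, 96, 96)
inputs, targets, gts = preprocess_test_patch(img, img.clone(), img.clone())
print(len(inputs), inputs[0].shape)   # 9 torch.Size([1, 3, 32, 32])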
diff --git a/third_party/DarkFeat/datasets/InvISP/utils/compression.py b/third_party/DarkFeat/datasets/InvISP/utils/compression.py
index 3ae22f8839517bfd7e3c774528943e8fff59dce7..9519bb99cedd1cf64efc3dacc07d59603d9e7508 100644
--- a/third_party/DarkFeat/datasets/InvISP/utils/compression.py
+++ b/third_party/DarkFeat/datasets/InvISP/utils/compression.py
@@ -1,40 +1,47 @@
 # Standard libraries
 import itertools
 import numpy as np
+
 # PyTorch
 import torch
 import torch.nn as nn
+
 # Local
 from . import JPEG_utils
 
 
 class rgb_to_ycbcr_jpeg(nn.Module):
-    """ Converts RGB image to YCbCr
+    """Converts RGB image to YCbCr
     Input:
         image(tensor): batch x 3 x height x width
     Outpput:
         result(tensor): batch x height x width x 3
     """
+
     def __init__(self):
         super(rgb_to_ycbcr_jpeg, self).__init__()
         matrix = np.array(
-            [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5],
-             [0.5, -0.418688, -0.081312]], dtype=np.float32).T
-        self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
+            [
+                [0.299, 0.587, 0.114],
+                [-0.168736, -0.331264, 0.5],
+                [0.5, -0.418688, -0.081312],
+            ],
+            dtype=np.float32,
+        ).T
+        self.shift = nn.Parameter(torch.tensor([0.0, 128.0, 128.0]))
         #
         self.matrix = nn.Parameter(torch.from_numpy(matrix))
 
     def forward(self, image):
         image = image.permute(0, 2, 3, 1)
         result = torch.tensordot(image, self.matrix, dims=1) + self.shift
-    #    result = torch.from_numpy(result)
+        #    result = torch.from_numpy(result)
         result.view(image.shape)
         return result
 
 
-
 class chroma_subsampling(nn.Module):
-    """ Chroma subsampling on CbCv channels
+    """Chroma subsampling on CbCv channels
     Input:
         image(tensor): batch x height x width x 3
     Output:
@@ -42,27 +49,28 @@ class chroma_subsampling(nn.Module):
         cb(tensor): batch x height/2 x width/2
         cr(tensor): batch x height/2 x width/2
     """
+
     def __init__(self):
         super(chroma_subsampling, self).__init__()
 
     def forward(self, image):
         image_2 = image.permute(0, 3, 1, 2).clone()
-        avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2),
-                                count_include_pad=False)
+        avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2), count_include_pad=False)
         cb = avg_pool(image_2[:, 1, :, :].unsqueeze(1))
         cr = avg_pool(image_2[:, 2, :, :].unsqueeze(1))
         cb = cb.permute(0, 2, 3, 1)
         cr = cr.permute(0, 2, 3, 1)
         return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
-        
+
 
 class block_splitting(nn.Module):
-    """ Splitting image into patches
+    """Splitting image into patches
     Input:
         image(tensor): batch x height x width
-    Output: 
+    Output:
         patch(tensor):  batch x h*w/64 x h x w
     """
+
     def __init__(self):
         super(block_splitting, self).__init__()
         self.k = 8
@@ -75,26 +83,30 @@ class block_splitting(nn.Module):
         image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
         image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
         return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
-    
+
 
 class dct_8x8(nn.Module):
-    """ Discrete Cosine Transformation
+    """Discrete Cosine Transformation
     Input:
         image(tensor): batch x height x width
     Output:
         dcp(tensor): batch x height x width
     """
+
     def __init__(self):
         super(dct_8x8, self).__init__()
         tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
         for x, y, u, v in itertools.product(range(8), repeat=4):
             tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos(
-                (2 * y + 1) * v * np.pi / 16)
-        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+                (2 * y + 1) * v * np.pi / 16
+            )
+        alpha = np.array([1.0 / np.sqrt(2)] + [1] * 7)
         #
-        self.tensor =  nn.Parameter(torch.from_numpy(tensor).float())
-        self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float() )
-        
+        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+        self.scale = nn.Parameter(
+            torch.from_numpy(np.outer(alpha, alpha) * 0.25).float()
+        )
+
     def forward(self, image):
         image = image - 128
         result = self.scale * torch.tensordot(image, self.tensor, dims=2)
@@ -103,7 +115,7 @@ class dct_8x8(nn.Module):
 
 
 class y_quantize(nn.Module):
-    """ JPEG Quantization for Y channel
+    """JPEG Quantization for Y channel
     Input:
         image(tensor): batch x height x width
         rounding(function): rounding function to use
@@ -111,6 +123,7 @@ class y_quantize(nn.Module):
     Output:
         image(tensor): batch x height x width
     """
+
     def __init__(self, rounding, factor=1):
         super(y_quantize, self).__init__()
         self.rounding = rounding
@@ -124,7 +137,7 @@ class y_quantize(nn.Module):
 
 
 class c_quantize(nn.Module):
-    """ JPEG Quantization for CrCb channels
+    """JPEG Quantization for CrCb channels
     Input:
         image(tensor): batch x height x width
         rounding(function): rounding function to use
@@ -132,6 +145,7 @@ class c_quantize(nn.Module):
     Output:
         image(tensor): batch x height x width
     """
+
     def __init__(self, rounding, factor=1):
         super(c_quantize, self).__init__()
         self.rounding = rounding
@@ -145,41 +159,39 @@ class c_quantize(nn.Module):
 
 
 class compress_jpeg(nn.Module):
-    """ Full JPEG compression algortihm
+    """Full JPEG compression algortihm
     Input:
-        imgs(tensor): batch x 3 x height x width 
+        imgs(tensor): batch x 3 x height x width
         rounding(function): rounding function to use
         factor(float): Compression factor
     Ouput:
         compressed(dict(tensor)): batch x h*w/64 x 8 x 8
     """
+
     def __init__(self, rounding=torch.round, factor=1):
         super(compress_jpeg, self).__init__()
         self.l1 = nn.Sequential(
             rgb_to_ycbcr_jpeg(),
-            # comment this line if no subsampling 
-            chroma_subsampling()
-        )
-        self.l2 = nn.Sequential(
-            block_splitting(),
-            dct_8x8()
+            # comment this line if no subsampling
+            chroma_subsampling(),
         )
+        self.l2 = nn.Sequential(block_splitting(), dct_8x8())
         self.c_quantize = c_quantize(rounding=rounding, factor=factor)
         self.y_quantize = y_quantize(rounding=rounding, factor=factor)
-        
+
     def forward(self, image):
-        y, cb, cr = self.l1(image*255) # modify 
+        y, cb, cr = self.l1(image * 255)  # modify
 
         # y, cb, cr = result[:,:,:,0], result[:,:,:,1], result[:,:,:,2]
-        components = {'y': y, 'cb': cb, 'cr': cr}
+        components = {"y": y, "cb": cb, "cr": cr}
         for k in components.keys():
             comp = self.l2(components[k])
             # print(comp.shape)
-            if k in ('cb', 'cr'):
+            if k in ("cb", "cr"):
                 comp = self.c_quantize(comp)
             else:
                 comp = self.y_quantize(comp)
 
             components[k] = comp
 
-        return components['y'], components['cb'], components['cr']
\ No newline at end of file
+        return components["y"], components["cb"], components["cr"]
diff --git a/third_party/DarkFeat/datasets/InvISP/utils/decompression.py b/third_party/DarkFeat/datasets/InvISP/utils/decompression.py
index b73ff96d5f6818e1d0464b9c4133f559a3b23fba..8a006442522b8b39261c78be85fcf16e7400fe7e 100644
--- a/third_party/DarkFeat/datasets/InvISP/utils/decompression.py
+++ b/third_party/DarkFeat/datasets/InvISP/utils/decompression.py
@@ -1,21 +1,24 @@
 # Standard libraries
 import itertools
 import numpy as np
+
 # PyTorch
 import torch
 import torch.nn as nn
+
 # Local
 from . import JPEG_utils as utils
 
 
 class y_dequantize(nn.Module):
-    """ Dequantize Y channel
+    """Dequantize Y channel
     Inputs:
         image(tensor): batch x height x width
         factor(float): compression factor
     Outputs:
         image(tensor): batch x height x width
     """
+
     def __init__(self, factor=1):
         super(y_dequantize, self).__init__()
         self.y_table = utils.y_table
@@ -26,13 +29,14 @@ class y_dequantize(nn.Module):
 
 
 class c_dequantize(nn.Module):
-    """ Dequantize CbCr channel
+    """Dequantize CbCr channel
     Inputs:
         image(tensor): batch x height x width
         factor(float): compression factor
     Outputs:
         image(tensor): batch x height x width
     """
+
     def __init__(self, factor=1):
         super(c_dequantize, self).__init__()
         self.factor = factor
@@ -43,24 +47,26 @@ class c_dequantize(nn.Module):
 
 
 class idct_8x8(nn.Module):
-    """ Inverse discrete Cosine Transformation
+    """Inverse discrete Cosine Transformation
     Input:
         dcp(tensor): batch x height x width
     Output:
         image(tensor): batch x height x width
     """
+
     def __init__(self):
         super(idct_8x8, self).__init__()
-        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+        alpha = np.array([1.0 / np.sqrt(2)] + [1] * 7)
         self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
         tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
         for x, y, u, v in itertools.product(range(8), repeat=4):
             tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos(
-                (2 * v + 1) * y * np.pi / 16)
+                (2 * v + 1) * y * np.pi / 16
+            )
         self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
 
     def forward(self, image):
-        
+
         image = image * self.alpha
         result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
         result.view(image.shape)
@@ -68,7 +74,7 @@ class idct_8x8(nn.Module):
 
 
 class block_merging(nn.Module):
-    """ Merge pathces into image
+    """Merge pathces into image
     Inputs:
         patches(tensor) batch x height*width/64, height x width
         height(int)
@@ -76,30 +82,32 @@ class block_merging(nn.Module):
     Output:
         image(tensor): batch x height x width
     """
+
     def __init__(self):
         super(block_merging, self).__init__()
-        
+
     def forward(self, patches, height, width):
         k = 8
         batch_size = patches.shape[0]
-        # print(patches.shape) # (1,1024,8,8) 
-        image_reshaped = patches.view(batch_size, height//k, width//k, k, k)
+        # print(patches.shape) # (1,1024,8,8)
+        image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
         image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
         return image_transposed.contiguous().view(batch_size, height, width)
 
 
 class chroma_upsampling(nn.Module):
-    """ Upsample chroma layers
-    Input: 
+    """Upsample chroma layers
+    Input:
         y(tensor): y channel image
         cb(tensor): cb channel
         cr(tensor): cr channel
     Ouput:
         image(tensor): batch x height x width x 3
     """
+
     def __init__(self):
         super(chroma_upsampling, self).__init__()
-    
+
     def forward(self, y, cb, cr):
         def repeat(x, k=2):
             height, width = x.shape[1:3]
@@ -110,35 +118,37 @@ class chroma_upsampling(nn.Module):
 
         cb = repeat(cb)
         cr = repeat(cr)
-        
+
         return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
 
 
 class ycbcr_to_rgb_jpeg(nn.Module):
-    """ Converts YCbCr image to RGB JPEG
+    """Converts YCbCr image to RGB JPEG
     Input:
         image(tensor): batch x height x width x 3
     Outpput:
         result(tensor): batch x 3 x height x width
     """
+
     def __init__(self):
         super(ycbcr_to_rgb_jpeg, self).__init__()
 
         matrix = np.array(
-            [[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
-            dtype=np.float32).T
-        self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
+            [[1.0, 0.0, 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
+            dtype=np.float32,
+        ).T
+        self.shift = nn.Parameter(torch.tensor([0, -128.0, -128.0]))
         self.matrix = nn.Parameter(torch.from_numpy(matrix))
 
     def forward(self, image):
         result = torch.tensordot(image + self.shift, self.matrix, dims=1)
-        #result = torch.from_numpy(result)
+        # result = torch.from_numpy(result)
         result.view(image.shape)
         return result.permute(0, 3, 1, 2)
 
 
 class decompress_jpeg(nn.Module):
-    """ Full JPEG decompression algortihm
+    """Full JPEG decompression algortihm
     Input:
         compressed(dict(tensor)): batch x h*w/64 x 8 x 8
         rounding(function): rounding function to use
@@ -146,6 +156,7 @@ class decompress_jpeg(nn.Module):
     Output:
         image(tensor): batch x 3 x height x width
     """
+
     # def __init__(self, height, width, rounding=torch.round, factor=1):
     def __init__(self, rounding=torch.round, factor=1):
         super(decompress_jpeg, self).__init__()
@@ -156,35 +167,35 @@ class decompress_jpeg(nn.Module):
         # comment this line if no subsampling
         self.chroma = chroma_upsampling()
         self.colors = ycbcr_to_rgb_jpeg()
-        
+
         # self.height, self.width = height, width
-        
+
     def forward(self, y, cb, cr, height, width):
-        components = {'y': y, 'cb': cb, 'cr': cr}
+        components = {"y": y, "cb": cb, "cr": cr}
         # height = y.shape[0]
         # width = y.shape[1]
         self.height = height
         self.width = width
         for k in components.keys():
-            if k in ('cb', 'cr'):
+            if k in ("cb", "cr"):
                 comp = self.c_dequantize(components[k])
                 # comment this line if no subsampling
-                height, width = int(self.height/2), int(self.width/2)
+                height, width = int(self.height / 2), int(self.width / 2)
                 # height, width = int(self.height), int(self.width)
-                
+
             else:
-                comp = self.y_dequantize(components[k]) 
-                # comment this line if no subsampling 
-                height, width = self.height, self.width 
-            comp = self.idct(comp) 
-            components[k] = self.merging(comp, height, width) 
-            # 
-        # comment this line if no subsampling 
-        image = self.chroma(components['y'], components['cb'], components['cr']) 
-        # image = torch.cat([components['y'].unsqueeze(3), components['cb'].unsqueeze(3), components['cr'].unsqueeze(3)], dim=3) 
+                comp = self.y_dequantize(components[k])
+                # comment this line if no subsampling
+                height, width = self.height, self.width
+            comp = self.idct(comp)
+            components[k] = self.merging(comp, height, width)
+            #
+        # comment this line if no subsampling
+        image = self.chroma(components["y"], components["cb"], components["cr"])
+        # image = torch.cat([components['y'].unsqueeze(3), components['cb'].unsqueeze(3), components['cr'].unsqueeze(3)], dim=3)
         image = self.colors(image)
 
-        image = torch.min(255*torch.ones_like(image),
-                          torch.max(torch.zeros_like(image), image))
-        return image/255
-
+        image = torch.min(
+            255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image)
+        )
+        return image / 255
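
A minimal usage sketch for the decompression pipeline above (an editorial illustration, not part of the patch), assuming the modules are imported from this file; the block counts follow the docstring shapes, with Cb/Cr stored at half resolution before chroma_upsampling.

import torch

# Hedged sketch: compose dequantize -> IDCT -> block merging -> chroma upsampling
# -> YCbCr-to-RGB exactly as decompress_jpeg.forward does above.
decoder = decompress_jpeg(rounding=torch.round, factor=1)

B, H, W = 1, 64, 64
y = torch.randn(B, (H // 8) * (W // 8), 8, 8)      # luma blocks at full resolution
cb = torch.randn(B, (H // 16) * (W // 16), 8, 8)   # chroma blocks at half resolution
cr = torch.randn(B, (H // 16) * (W // 16), 8, 8)

rgb = decoder(y, cb, cr, H, W)
print(rgb.shape)  # expected torch.Size([1, 3, 64, 64]), values clamped to [0, 1]
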
diff --git a/third_party/DarkFeat/datasets/gl3d/io.py b/third_party/DarkFeat/datasets/gl3d/io.py
index 9e5b4b0459d6814ef6af17a0a322b59202037d4f..9b48a2be61ba799d567b7df45c9b9b011cbef4be 100644
--- a/third_party/DarkFeat/datasets/gl3d/io.py
+++ b/third_party/DarkFeat/datasets/gl3d/io.py
@@ -5,42 +5,42 @@ import numpy as np
 
 from ..utils.common import Notify
 
+
 def read_list(list_path):
     """Read list."""
     if list_path is None or not os.path.exists(list_path):
-        print(Notify.FAIL, 'Not exist', list_path, Notify.ENDC)
+        print(Notify.FAIL, "Not exist", list_path, Notify.ENDC)
         exit(-1)
     content = open(list_path).read().splitlines()
     return content
 
 
 def load_pfm(pfm_path):
-    with open(pfm_path, 'rb') as fin:
+    with open(pfm_path, "rb") as fin:
         color = None
         width = None
         height = None
         scale = None
         data_type = None
-        header = str(fin.readline().decode('UTF-8')).rstrip()
+        header = str(fin.readline().decode("UTF-8")).rstrip()
 
-        if header == 'PF':
+        if header == "PF":
             color = True
-        elif header == 'Pf':
+        elif header == "Pf":
             color = False
         else:
-            raise Exception('Not a PFM file.')
+            raise Exception("Not a PFM file.")
 
-        dim_match = re.match(r'^(\d+)\s(\d+)\s$',
-                             fin.readline().decode('UTF-8'))
+        dim_match = re.match(r"^(\d+)\s(\d+)\s$", fin.readline().decode("UTF-8"))
         if dim_match:
             width, height = map(int, dim_match.groups())
         else:
-            raise Exception('Malformed PFM header.')
-        scale = float((fin.readline().decode('UTF-8')).rstrip())
+            raise Exception("Malformed PFM header.")
+        scale = float((fin.readline().decode("UTF-8")).rstrip())
         if scale < 0:  # little-endian
-            data_type = '<f'
+            data_type = "<f"
         else:
-            data_type = '>f'  # big-endian
+            data_type = ">f"  # big-endian
         data_string = fin.read()
         data = np.fromstring(data_string, data_type)
         shape = (height, width, 3) if color else (height, width)
@@ -52,25 +52,24 @@ def load_pfm(pfm_path):
 def _parse_img(img_paths, idx, config):
     img_path = img_paths[idx]
     img = cv2.imread(img_path)[:, :, ::-1]
-    if config['resize'] > 0:
-        img = cv2.resize(
-            img, (config['resize'], config['resize']))
+    if config["resize"] > 0:
+        img = cv2.resize(img, (config["resize"], config["resize"]))
     return img
 
 
 def _parse_depth(depth_paths, idx, config):
     depth = load_pfm(depth_paths[idx])
 
-    if config['resize'] > 0:
-        target_size = config['resize']
-    if config['input_type'] == 'raw':
-        depth = cv2.resize(depth, (int(target_size/2), int(target_size/2)))
+    if config["resize"] > 0:
+        target_size = config["resize"]
+    if config["input_type"] == "raw":
+        depth = cv2.resize(depth, (int(target_size / 2), int(target_size / 2)))
     else:
         depth = cv2.resize(depth, (target_size, target_size))
     return depth
 
 
 def _parse_kpts(kpts_paths, idx, config):
-    kpts = np.load(kpts_paths[idx])['pts']
+    kpts = np.load(kpts_paths[idx])["pts"]
     # output: [N, 2] (W first H last)
     return kpts
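
For reference, the PFM layout that `load_pfm` above expects is a type line ("PF" for color, "Pf" for grayscale), a "width height " line, and a scale line whose sign encodes endianness, followed by raw float32 samples. A hedged sketch that writes a file the loader could parse (row ordering and scale magnitude are simplified here, and the toy file name is arbitrary):

import numpy as np

h, w = 4, 6
data = np.random.rand(h, w, 3).astype("<f4")  # little-endian float32, 3 channels
with open("toy.pfm", "wb") as f:
    f.write(b"PF\n")                 # 'PF' = color, 'Pf' = grayscale
    f.write(f"{w} {h} \n".encode())  # trailing space matches the regex used above
    f.write(b"-1.0\n")               # negative scale -> little-endian data
    f.write(data.tobytes())
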
diff --git a/third_party/DarkFeat/datasets/gl3d_dataset.py b/third_party/DarkFeat/datasets/gl3d_dataset.py
index db3d2db646ae7fce81424f5f72cdff7e6e34ba60..0dd9ea77f44bcc065a895c05a66cdc843632ddee 100644
--- a/third_party/DarkFeat/datasets/gl3d_dataset.py
+++ b/third_party/DarkFeat/datasets/gl3d_dataset.py
@@ -15,17 +15,18 @@ class GL3DDataset(Dataset):
         self.config = config
         self.is_training = is_training
         self.data_split = data_split
-        
-        self.match_set_list, self.global_img_list, \
-            self.global_depth_list = self.prepare_match_sets()
 
-        pass
+        (
+            self.match_set_list,
+            self.global_img_list,
+            self.global_depth_list,
+        ) = self.prepare_match_sets()
 
+        pass
 
     def __len__(self):
         return len(self.match_set_list)
 
-
     def __getitem__(self, idx):
         match_set_path = self.match_set_list[idx]
         decoded = np.fromfile(match_set_path, dtype=np.float32)
@@ -50,26 +51,24 @@ class GL3DDataset(Dataset):
         img1 = photaug(img1)
 
         return {
-            'img0': img0 / 255.,
-            'img1': img1 / 255.,
-            'depth0': depth0,
-            'depth1': depth1,
-            'ori_img_size0': ori_img_size0,
-            'ori_img_size1': ori_img_size1,
-            'K0': K0,
-            'K1': K1,
-            'rel_pose': rel_pose,
-            'inlier_num': inlier_num
+            "img0": img0 / 255.0,
+            "img1": img1 / 255.0,
+            "depth0": depth0,
+            "depth1": depth1,
+            "ori_img_size0": ori_img_size0,
+            "ori_img_size1": ori_img_size1,
+            "K0": K0,
+            "K1": K1,
+            "rel_pose": rel_pose,
+            "inlier_num": inlier_num,
         }
 
-
     def points_to_2D(self, pnts, H, W):
         labels = np.zeros((H, W))
         pnts = pnts.astype(int)
         labels[pnts[:, 1], pnts[:, 0]] = 1
         return labels
 
-
     def prepare_match_sets(self, q_diff_thld=3, rot_diff_thld=60):
         """Get match sets.
         Args:
@@ -81,20 +80,29 @@ class GL3DDataset(Dataset):
             global_context_feat_list:
         """
         # get necessary lists.
-        gl3d_list_folder = os.path.join(self.dataset_dir, 'list', self.data_split)
-        global_info = read_list(os.path.join(
-            gl3d_list_folder, 'image_index_offset.txt'))
-        global_img_list = [os.path.join(self.dataset_dir, i) for i in read_list(
-            os.path.join(gl3d_list_folder, 'image_list.txt'))]
-        global_depth_list = [os.path.join(self.dataset_dir, i) for i in read_list(
-            os.path.join(gl3d_list_folder, 'depth_list.txt'))]
-
-        imageset_list_name = 'imageset_train.txt' if self.is_training else 'imageset_test.txt'
-        match_set_list = self.get_match_set_list(os.path.join(
-            gl3d_list_folder, imageset_list_name), q_diff_thld, rot_diff_thld)
+        gl3d_list_folder = os.path.join(self.dataset_dir, "list", self.data_split)
+        global_info = read_list(
+            os.path.join(gl3d_list_folder, "image_index_offset.txt")
+        )
+        global_img_list = [
+            os.path.join(self.dataset_dir, i)
+            for i in read_list(os.path.join(gl3d_list_folder, "image_list.txt"))
+        ]
+        global_depth_list = [
+            os.path.join(self.dataset_dir, i)
+            for i in read_list(os.path.join(gl3d_list_folder, "depth_list.txt"))
+        ]
+
+        imageset_list_name = (
+            "imageset_train.txt" if self.is_training else "imageset_test.txt"
+        )
+        match_set_list = self.get_match_set_list(
+            os.path.join(gl3d_list_folder, imageset_list_name),
+            q_diff_thld,
+            rot_diff_thld,
+        )
         return match_set_list, global_img_list, global_depth_list
 
-
     def get_match_set_list(self, imageset_list_path, q_diff_thld, rot_diff_thld):
         """Get the path list of match sets.
         Args:
@@ -103,25 +111,25 @@ class GL3DDataset(Dataset):
         Returns:
             match_set_list: List of match set path.
         """
-        imageset_list = [os.path.join(self.dataset_dir, 'data', i)
-                        for i in read_list(imageset_list_path)]
-        print(Notify.INFO, 'Use # imageset', len(imageset_list), Notify.ENDC)
+        imageset_list = [
+            os.path.join(self.dataset_dir, "data", i)
+            for i in read_list(imageset_list_path)
+        ]
+        print(Notify.INFO, "Use # imageset", len(imageset_list), Notify.ENDC)
         match_set_list = []
         # discard image pairs whose image similarity is beyond the threshold.
         for i in imageset_list:
-            match_set_folder = os.path.join(i, 'match_sets')
+            match_set_folder = os.path.join(i, "match_sets")
             if os.path.exists(match_set_folder):
                 match_set_files = os.listdir(match_set_folder)
                 for val in match_set_files:
                     name, ext = os.path.splitext(val)
-                    if ext == '.match_set':
-                        splits = name.split('_')
+                    if ext == ".match_set":
+                        splits = name.split("_")
                         q_diff = int(splits[2])
                         rot_diff = int(splits[3])
                         if q_diff >= q_diff_thld and rot_diff <= rot_diff_thld:
-                            match_set_list.append(
-                                os.path.join(match_set_folder, val))
+                            match_set_list.append(os.path.join(match_set_folder, val))
 
-        print(Notify.INFO, 'Get # match sets', len(match_set_list), Notify.ENDC)
+        print(Notify.INFO, "Get # match sets", len(match_set_list), Notify.ENDC)
         return match_set_list
-        
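
A standalone restatement (hedged, inferred only from the split indices in the code above) of the filename rule that `get_match_set_list` applies: field 2 of the `.match_set` name is treated as the quality difference and field 3 as the rotation difference, and a pair is kept when the former is large enough and the latter small enough.

import os

def keep_match_set(filename, q_diff_thld=3, rot_diff_thld=60):
    # Assumed name layout: <a>_<b>_<q_diff>_<rot_diff>.match_set
    name, ext = os.path.splitext(filename)
    if ext != ".match_set":
        return False
    splits = name.split("_")
    q_diff, rot_diff = int(splits[2]), int(splits[3])
    return q_diff >= q_diff_thld and rot_diff <= rot_diff_thld

print(keep_match_set("12_34_5_30.match_set"))  # True: q_diff=5 >= 3 and rot_diff=30 <= 60
print(keep_match_set("12_34_2_30.match_set"))  # False: q_diff below threshold
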
diff --git a/third_party/DarkFeat/datasets/noise.py b/third_party/DarkFeat/datasets/noise.py
index aa68c98183186e9e9185e78e1a3e7335ac8d5bb1..a44c6a902c653f6c829a2536a49e5a3c9790e5de 100644
--- a/third_party/DarkFeat/datasets/noise.py
+++ b/third_party/DarkFeat/datasets/noise.py
@@ -3,31 +3,49 @@ import random
 from scipy.stats import tukeylambda
 
 camera_params = {
-    'Kmin': 0.2181895124454343,
-    'Kmax': 3.0,
-    'G_shape': np.array([0.15714286, 0.14285714, 0.08571429, 0.08571429, 0.2       ,
-                         0.2       , 0.1       , 0.08571429, 0.05714286, 0.07142857,
-                         0.02857143, 0.02857143, 0.01428571, 0.02857143, 0.08571429,
-                         0.07142857, 0.11428571, 0.11428571]),
-    'Profile-1': {
-        'R_scale': {
-            'slope': 0.4712797750747537,
-            'bias': -0.8078958947116487,
-            'sigma': 0.2436176299944695
+    "Kmin": 0.2181895124454343,
+    "Kmax": 3.0,
+    "G_shape": np.array(
+        [
+            0.15714286,
+            0.14285714,
+            0.08571429,
+            0.08571429,
+            0.2,
+            0.2,
+            0.1,
+            0.08571429,
+            0.05714286,
+            0.07142857,
+            0.02857143,
+            0.02857143,
+            0.01428571,
+            0.02857143,
+            0.08571429,
+            0.07142857,
+            0.11428571,
+            0.11428571,
+        ]
+    ),
+    "Profile-1": {
+        "R_scale": {
+            "slope": 0.4712797750747537,
+            "bias": -0.8078958947116487,
+            "sigma": 0.2436176299944695,
         },
-        'g_scale': {
-            'slope': 0.6771267783987617,
-            'bias': 1.5121876510805845,
-            'sigma': 0.24641096601611254
+        "g_scale": {
+            "slope": 0.6771267783987617,
+            "bias": 1.5121876510805845,
+            "sigma": 0.24641096601611254,
+        },
+        "G_scale": {
+            "slope": 0.6558756156508007,
+            "bias": 1.09268679594838,
+            "sigma": 0.28604721742277756,
         },
-        'G_scale': {
-            'slope': 0.6558756156508007,
-            'bias': 1.09268679594838,
-            'sigma': 0.28604721742277756
-        }
     },
-    'black_level': 2048,
-    'max_value': 16383
+    "black_level": 2048,
+    "max_value": 16383,
 }
 
 
@@ -46,15 +64,18 @@ def addGStarNoise(img, K, G_shape, G_scale_param):
 
     rand_num = random.uniform(0, 1)
     idx = np.sum(np.cumsum(a) < rand_num)
-    lam = random.uniform(b[idx], b[idx+1])
+    lam = random.uniform(b[idx], b[idx + 1])
 
     # calculate scale parameter [G_scale]
     log_K = np.log(K)
-    log_G_scale = np.random.standard_normal() * G_scale_param['sigma'] * 1 +\
-             G_scale_param['slope'] * log_K + G_scale_param['bias']
+    log_G_scale = (
+        np.random.standard_normal() * G_scale_param["sigma"] * 1
+        + G_scale_param["slope"] * log_K
+        + G_scale_param["bias"]
+    )
     G_scale = np.exp(log_G_scale)
     # print(f'G_scale: {G_scale}')
-    
+
     return img + tukeylambda.rvs(lam, scale=G_scale, size=img.shape).astype(np.float32)
 
 
@@ -63,11 +84,14 @@ def addGStarNoise(img, K, G_shape, G_scale_param):
 def addRowNoise(img, K, R_scale_param):
     # calculate scale parameter [R_scale]
     log_K = np.log(K)
-    log_R_scale = np.random.standard_normal() * R_scale_param['sigma'] * 1 +\
-             R_scale_param['slope'] * log_K + R_scale_param['bias']
+    log_R_scale = (
+        np.random.standard_normal() * R_scale_param["sigma"] * 1
+        + R_scale_param["slope"] * log_K
+        + R_scale_param["bias"]
+    )
     R_scale = np.exp(log_R_scale)
     # print(f'R_scale: {R_scale}')
-    
+
     row_noise = np.random.randn(img.shape[0], 1).astype(np.float32) * R_scale
     return img + np.tile(row_noise, (1, img.shape[1]))
 
@@ -75,7 +99,7 @@ def addRowNoise(img, K, R_scale_param):
 # quantization noise
 # uniform distribution
 def addQuantNoise(img, q):
-    return img + np.random.uniform(low=-0.5*q, high=0.5*q, size=img.shape)
+    return img + np.random.uniform(low=-0.5 * q, high=0.5 * q, size=img.shape)
 
 
 def sampleK(Kmin, Kmax):
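
The row and G* noise terms above share one model: their scale is log-normal around a line in log(K), i.e. log(scale) = slope * log(K) + bias + N(0, sigma). A hedged sketch using the Profile-1 row-noise parameters defined above (assumes `camera_params` from this file is in scope; image size is arbitrary):

import numpy as np

def sample_scale(K, scale_param):
    # log(scale) ~ Normal(slope * log(K) + bias, sigma), as in addRowNoise/addGStarNoise
    log_scale = (
        np.random.standard_normal() * scale_param["sigma"]
        + scale_param["slope"] * np.log(K)
        + scale_param["bias"]
    )
    return np.exp(log_scale)

R_scale = sample_scale(1.0, camera_params["Profile-1"]["R_scale"])
row_noise = np.random.randn(480, 1).astype(np.float32) * R_scale  # one offset per row
noisy = np.zeros((480, 640), dtype=np.float32) + np.tile(row_noise, (1, 640))
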
diff --git a/third_party/DarkFeat/datasets/noise_simulator.py b/third_party/DarkFeat/datasets/noise_simulator.py
index 17e21d3b3443aaa3585ae8460709f60b05835a84..8d7ff4ad00583b1a0879160d725a5de4dade4892 100644
--- a/third_party/DarkFeat/datasets/noise_simulator.py
+++ b/third_party/DarkFeat/datasets/noise_simulator.py
@@ -14,17 +14,28 @@ import colour_demosaicing
 
 from .InvISP.model.model import InvISPNet
 from .utils.common import Notify
-from datasets.noise import camera_params, addGStarNoise, addPStarNoise, addQuantNoise, addRowNoise, sampleK
+from datasets.noise import (
+    camera_params,
+    addGStarNoise,
+    addPStarNoise,
+    addQuantNoise,
+    addRowNoise,
+    sampleK,
+)
 
 
 class NoiseSimulator:
-    def __init__(self, device, ckpt_path='./datasets/InvISP/pretrained/canon.pth'):
+    def __init__(self, device, ckpt_path="./datasets/InvISP/pretrained/canon.pth"):
         self.device = device
 
         # load Invertible ISP Network
-        self.net = InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval()
+        self.net = (
+            InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval()
+        )
         self.net.load_state_dict(torch.load(ckpt_path), strict=False)
-        print(Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC)
+        print(
+            Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC
+        )
 
         # white balance parameters
         self.wb = np.array([2020.0, 1024.0, 1458.0, 1024.0])
@@ -75,11 +86,11 @@ class NoiseSimulator:
     # input: [H, W]
     # output: [H, W, 3]
     def demosaic(self, img):
-        return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, 'RGGB')
+        return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, "RGGB")
 
     # load rgb image
     def path2rgb(self, path):
-        return torch.from_numpy(np.array(PILImage.open(path))/255.0)
+        return torch.from_numpy(np.array(PILImage.open(path)) / 255.0)
 
     # InvISP
     # input: rgb image [H, W, 3]
@@ -89,21 +100,21 @@ class NoiseSimulator:
         if not batched:
             rgb = rgb.unsqueeze(0)
 
-        rgb = rgb.permute(0,3,1,2).float().to(self.device)
+        rgb = rgb.permute(0, 3, 1, 2).float().to(self.device)
         with torch.no_grad():
             reconstruct_raw = self.net(rgb, rev=True)
 
-        pred_raw = reconstruct_raw.detach().permute(0,2,3,1)
+        pred_raw = reconstruct_raw.detach().permute(0, 2, 3, 1)
         pred_raw = torch.clamp(pred_raw, 0, 1)
 
         if not batched:
             pred_raw = pred_raw[0, ...]
-            
+
         pred_raw = pred_raw.cpu().numpy()
 
         # 2. -> inv gamma
-        norm_value = np.power(16383, 1/2.2)
-        pred_raw *= norm_value          
+        norm_value = np.power(16383, 1 / 2.2)
+        pred_raw *= norm_value
         pred_raw = np.power(pred_raw, 2.2)
 
         # 3. -> inv white balance
@@ -111,7 +122,7 @@ class NoiseSimulator:
         pred_raw = pred_raw / wb[:-1]
 
         # 4. -> add black level
-        pred_raw += self.camera_params['black_level']
+        pred_raw += self.camera_params["black_level"]
 
         # 5. -> inv demosaic
         if not batched:
@@ -124,18 +135,24 @@ class NoiseSimulator:
 
         return pred_raw
 
-    
     def raw2noisyRaw(self, raw, ratio_dec=1, batched=False):
         if not batched:
             ratio = (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1
             raw = raw.copy() / ratio
 
-            K = sampleK(self.camera_params['Kmin'], self.camera_params['Kmax'])
-            q = 1 / (self.camera_params['max_value'] - self.camera_params['black_level'])
+            K = sampleK(self.camera_params["Kmin"], self.camera_params["Kmax"])
+            q = 1 / (
+                self.camera_params["max_value"] - self.camera_params["black_level"]
+            )
 
             raw = addPStarNoise(raw, K)
-            raw = addGStarNoise(raw, K, self.camera_params['G_shape'], self.camera_params['Profile-1']['G_scale'])
-            raw = addRowNoise(raw, K, self.camera_params['Profile-1']['R_scale'])
+            raw = addGStarNoise(
+                raw,
+                K,
+                self.camera_params["G_shape"],
+                self.camera_params["Profile-1"]["G_scale"],
+            )
+            raw = addRowNoise(raw, K, self.camera_params["Profile-1"]["R_scale"])
             raw = addQuantNoise(raw, q)
             raw *= ratio
             return raw
@@ -146,12 +163,21 @@ class NoiseSimulator:
                 ratio = random.uniform(self.ratio_min, self.ratio_max)
                 raw[i] /= ratio
 
-                K = sampleK(self.camera_params['Kmin'], self.camera_params['Kmax'])
-                q = 1 / (self.camera_params['max_value'] - self.camera_params['black_level'])
+                K = sampleK(self.camera_params["Kmin"], self.camera_params["Kmax"])
+                q = 1 / (
+                    self.camera_params["max_value"] - self.camera_params["black_level"]
+                )
 
                 raw[i] = addPStarNoise(raw[i], K)
-                raw[i] = addGStarNoise(raw[i], K, self.camera_params['G_shape'], self.camera_params['Profile-1']['G_scale'])
-                raw[i] = addRowNoise(raw[i], K, self.camera_params['Profile-1']['R_scale'])
+                raw[i] = addGStarNoise(
+                    raw[i],
+                    K,
+                    self.camera_params["G_shape"],
+                    self.camera_params["Profile-1"]["G_scale"],
+                )
+                raw[i] = addRowNoise(
+                    raw[i], K, self.camera_params["Profile-1"]["R_scale"]
+                )
                 raw[i] = addQuantNoise(raw[i], q)
                 raw[i] *= ratio
             return raw
@@ -167,29 +193,38 @@ class NoiseSimulator:
             raw = np.stack(raws, axis=0)
 
         # 2. -> substract black level
-        raw -= self.camera_params['black_level']
-        raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level'])
+        raw -= self.camera_params["black_level"]
+        raw = np.clip(
+            raw, 0, self.camera_params["max_value"] - self.camera_params["black_level"]
+        )
 
         # 3. -> white balance
         wb = self.wb / self.wb.max()
         raw = raw * wb[:-1]
 
         # 4. -> gamma
-        norm_value = np.power(16383, 1/2.2)            
-        raw = np.power(raw, 1/2.2)
+        norm_value = np.power(16383, 1 / 2.2)
+        raw = np.power(raw, 1 / 2.2)
         raw /= norm_value
 
         # 5. -> ispnet
         if not batched:
-            input_raw_img = torch.Tensor(raw).permute(2,0,1).float().to(self.device)[np.newaxis, ...]
+            input_raw_img = (
+                torch.Tensor(raw)
+                .permute(2, 0, 1)
+                .float()
+                .to(self.device)[np.newaxis, ...]
+            )
         else:
-            input_raw_img = torch.Tensor(raw).permute(0,3,1,2).float().to(self.device)
+            input_raw_img = (
+                torch.Tensor(raw).permute(0, 3, 1, 2).float().to(self.device)
+            )
 
         with torch.no_grad():
             reconstruct_rgb = self.net(input_raw_img)
             reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1)
 
-        pred_rgb = reconstruct_rgb.detach().permute(0,2,3,1)
+        pred_rgb = reconstruct_rgb.detach().permute(0, 2, 3, 1)
 
         if not batched:
             pred_rgb = pred_rgb[0, ...]
@@ -197,12 +232,13 @@ class NoiseSimulator:
 
         return pred_rgb
 
-
     def raw2packedRaw(self, raw, batched=False):
         # 1. -> substract black level
-        raw -= self.camera_params['black_level']
-        raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level'])
-        raw /= self.camera_params['max_value']
+        raw -= self.camera_params["black_level"]
+        raw = np.clip(
+            raw, 0, self.camera_params["max_value"] - self.camera_params["black_level"]
+        )
+        raw /= self.camera_params["max_value"]
 
         # 2. pack
         if not batched:
@@ -211,20 +247,30 @@ class NoiseSimulator:
             H = img_shape[0]
             W = img_shape[1]
 
-            out = np.concatenate((im[0:H:2, 0:W:2, :],
-                                  im[0:H:2, 1:W:2, :],
-                                  im[1:H:2, 1:W:2, :],
-                                  im[1:H:2, 0:W:2, :]), axis=2)
+            out = np.concatenate(
+                (
+                    im[0:H:2, 0:W:2, :],
+                    im[0:H:2, 1:W:2, :],
+                    im[1:H:2, 1:W:2, :],
+                    im[1:H:2, 0:W:2, :],
+                ),
+                axis=2,
+            )
         else:
             im = np.expand_dims(raw, axis=3)
             img_shape = im.shape
             H = img_shape[1]
             W = img_shape[2]
 
-            out = np.concatenate((im[:, 0:H:2, 0:W:2, :],
-                                  im[:, 0:H:2, 1:W:2, :],
-                                  im[:, 1:H:2, 1:W:2, :],
-                                  im[:, 1:H:2, 0:W:2, :]), axis=3)
+            out = np.concatenate(
+                (
+                    im[:, 0:H:2, 0:W:2, :],
+                    im[:, 0:H:2, 1:W:2, :],
+                    im[:, 1:H:2, 1:W:2, :],
+                    im[:, 1:H:2, 0:W:2, :],
+                ),
+                axis=3,
+            )
         return out
 
     def raw2demosaicRaw(self, raw, batched=False):
@@ -238,7 +284,9 @@ class NoiseSimulator:
             raw = np.stack(raws, axis=0)
 
         # 2. -> substract black level
-        raw -= self.camera_params['black_level']
-        raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level'])
-        raw /= self.camera_params['max_value']
+        raw -= self.camera_params["black_level"]
+        raw = np.clip(
+            raw, 0, self.camera_params["max_value"] - self.camera_params["black_level"]
+        )
+        raw /= self.camera_params["max_value"]
         return raw
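
The packing performed by `raw2packedRaw` above maps an (H, W) RGGB mosaic to an (H/2, W/2, 4) array by stacking the four 2x2 phase offsets; a minimal NumPy sketch of the unbatched case (toy sizes, RGGB layout assumed as in `demosaic`):

import numpy as np

raw = np.arange(4 * 6, dtype=np.float32).reshape(4, 6)  # toy (H=4, W=6) Bayer mosaic
im = raw[..., None]
packed = np.concatenate(
    (im[0::2, 0::2], im[0::2, 1::2], im[1::2, 1::2], im[1::2, 0::2]),  # R, G1, B, G2 phases
    axis=2,
)
print(packed.shape)  # (2, 3, 4): half the spatial size, four Bayer channels
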
diff --git a/third_party/DarkFeat/datasets/utils/common.py b/third_party/DarkFeat/datasets/utils/common.py
index 6433408a39e53fcedb634901268754ed1ba971b3..aa2007b0b31df0325c51f4112a259ab1e1d7f1aa 100644
--- a/third_party/DarkFeat/datasets/utils/common.py
+++ b/third_party/DarkFeat/datasets/utils/common.py
@@ -28,31 +28,30 @@ class Notify(object):
 
     @ClassProperty
     def HEADER(cls):
-        return str(datetime.now()) + ': \033[95m'
+        return str(datetime.now()) + ": \033[95m"
 
     @ClassProperty
     def INFO(cls):
-        return str(datetime.now()) + ': \033[92mI'
+        return str(datetime.now()) + ": \033[92mI"
 
     @ClassProperty
     def OKBLUE(cls):
-        return str(datetime.now()) + ': \033[94m'
+        return str(datetime.now()) + ": \033[94m"
 
     @ClassProperty
     def WARNING(cls):
-        return str(datetime.now()) + ': \033[93mW'
+        return str(datetime.now()) + ": \033[93mW"
 
     @ClassProperty
     def FAIL(cls):
-        return str(datetime.now()) + ': \033[91mF'
+        return str(datetime.now()) + ": \033[91mF"
 
     @ClassProperty
     def BOLD(cls):
-        return str(datetime.now()) + ': \033[1mB'
+        return str(datetime.now()) + ": \033[1mB"
 
     @ClassProperty
     def UNDERLINE(cls):
-        return str(datetime.now()) + ': \033[4mU'
-    ENDC = '\033[0m'
-
+        return str(datetime.now()) + ": \033[4mU"
 
+    ENDC = "\033[0m"
diff --git a/third_party/DarkFeat/datasets/utils/photaug.py b/third_party/DarkFeat/datasets/utils/photaug.py
index 41f2278c720355470f00a881a1516cf1b71d2c4a..29b9130871f8cb96d714228fe22d8c6f4b6526e3 100644
--- a/third_party/DarkFeat/datasets/utils/photaug.py
+++ b/third_party/DarkFeat/datasets/utils/photaug.py
@@ -7,41 +7,45 @@ def random_brightness_np(image, max_abs_change=50):
     delta = random.uniform(-max_abs_change, max_abs_change)
     return np.clip(image + delta, 0, 255)
 
+
 def random_contrast_np(image, strength_range=[0.3, 1.5]):
     delta = random.uniform(*strength_range)
     mean = image.mean()
     return np.clip((image - mean) * delta + mean, 0, 255)
 
+
 def motion_blur_np(img, max_kernel_size=3):
     # Either vertical, horizontal or diagonal blur
-    mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up'])
-    ksize = np.random.randint(
-        0, (max_kernel_size+1)/2)*2 + 1  # make sure is odd
-    center = int((ksize-1)/2)
+    mode = np.random.choice(["h", "v", "diag_down", "diag_up"])
+    ksize = np.random.randint(0, (max_kernel_size + 1) / 2) * 2 + 1  # make sure ksize is odd
+    center = int((ksize - 1) / 2)
     kernel = np.zeros((ksize, ksize))
-    if mode == 'h':
-        kernel[center, :] = 1.
-    elif mode == 'v':
-        kernel[:, center] = 1.
-    elif mode == 'diag_down':
+    if mode == "h":
+        kernel[center, :] = 1.0
+    elif mode == "v":
+        kernel[:, center] = 1.0
+    elif mode == "diag_down":
         kernel = np.eye(ksize)
-    elif mode == 'diag_up':
+    elif mode == "diag_up":
         kernel = np.flip(np.eye(ksize), 0)
-    var = ksize * ksize / 16.
+    var = ksize * ksize / 16.0
     grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1)
-    gaussian = np.exp(-(np.square(grid-center) +
-                        np.square(grid.T-center))/(2.*var))
+    gaussian = np.exp(
+        -(np.square(grid - center) + np.square(grid.T - center)) / (2.0 * var)
+    )
     kernel *= gaussian
     kernel /= np.sum(kernel)
     img = cv2.filter2D(img, -1, kernel)
     return np.clip(img, 0, 255)
 
+
 def additive_gaussian_noise(image, stddev_range=[5, 95]):
     stddev = random.uniform(*stddev_range)
     noise = np.random.normal(size=image.shape, scale=stddev)
     noisy_image = np.clip(image + noise, 0, 255)
     return noisy_image
 
+
 def photaug(img):
     img = random_brightness_np(img)
     img = random_contrast_np(img)
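
A worked example of the kernel that `motion_blur_np` above builds for a horizontal blur with ksize=3: the middle row is set to 1, weighted by an isotropic Gaussian of variance ksize^2/16, then normalized (a hedged, standalone recomputation of the same arithmetic):

import numpy as np

ksize, center = 3, 1
kernel = np.zeros((ksize, ksize))
kernel[center, :] = 1.0                                # horizontal ('h') mode
var = ksize * ksize / 16.0
grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1)
gaussian = np.exp(-(np.square(grid - center) + np.square(grid.T - center)) / (2.0 * var))
kernel = kernel * gaussian / np.sum(kernel * gaussian)
print(kernel.round(3))  # a single row of 3 normalized weights, peaked at the center
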
diff --git a/third_party/DarkFeat/demo_darkfeat.py b/third_party/DarkFeat/demo_darkfeat.py
index ca50ae5b892e7a90e75da7197c33bc0c06e699bf..be9a25c92f7e77da57ca111311dd96d426ba0c36 100644
--- a/third_party/DarkFeat/demo_darkfeat.py
+++ b/third_party/DarkFeat/demo_darkfeat.py
@@ -5,82 +5,106 @@ import matplotlib.cm as cm
 import torch
 import numpy as np
 from utils.nnmatching import NNMatching
-from utils.misc import (AverageTimer, VideoStreamer, make_matching_plot_fast, frame2tensor)
+from utils.misc import (
+    AverageTimer,
+    VideoStreamer,
+    make_matching_plot_fast,
+    frame2tensor,
+)
 
 torch.set_grad_enabled(False)
 
 
 def compute_essential(matched_kp1, matched_kp2, K):
-    pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
-    pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
+    pts1 = cv2.undistortPoints(
+        matched_kp1,
+        cameraMatrix=K,
+        distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0),
+    )
+    pts2 = cv2.undistortPoints(
+        matched_kp2,
+        cameraMatrix=K,
+        distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0),
+    )
     K_1 = np.eye(3)
     # Estimate the homography between the matches using RANSAC
-    ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000)
-    if ransac_inliers is None or ransac_model.shape != (3,3):
+    ransac_model, ransac_inliers = cv2.findEssentialMat(
+        pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000
+    )
+    if ransac_inliers is None or ransac_model.shape != (3, 3):
         ransac_inliers = np.array([])
         ransac_model = None
     return ransac_model, ransac_inliers, pts1, pts2
 
 
 sizer = (960, 640)
-focallength_x = 4.504986436499113e+03/(6744/sizer[0])
-focallength_y = 4.513311442889859e+03/(4502/sizer[1])
+focallength_x = 4.504986436499113e03 / (6744 / sizer[0])
+focallength_y = 4.513311442889859e03 / (4502 / sizer[1])
 K = np.eye(3)
-K[0,0] = focallength_x
-K[1,1] = focallength_y
-K[0,2] = 3.363322177533149e+03/(6744/sizer[0])# * 0.5
-K[1,2] = 2.291824660547715e+03/(4502/sizer[1])# * 0.5
+K[0, 0] = focallength_x
+K[1, 1] = focallength_y
+K[0, 2] = 3.363322177533149e03 / (6744 / sizer[0])  # * 0.5
+K[1, 2] = 2.291824660547715e03 / (4502 / sizer[1])  # * 0.5
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description='DarkFeat demo',
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+        description="DarkFeat demo",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("--input", type=str, help="path to an image directory")
     parser.add_argument(
-        '--input', type=str,
-        help='path to an image directory')
-    parser.add_argument(
-        '--output_dir', type=str, default=None,
-        help='Directory where to write output frames (If None, no output)')
+        "--output_dir",
+        type=str,
+        default=None,
+        help="Directory where to write output frames (If None, no output)",
+    )
 
     parser.add_argument(
-        '--image_glob', type=str, nargs='+', default=['*.ARW'],
-        help='Glob if a directory of images is specified')
+        "--image_glob",
+        type=str,
+        nargs="+",
+        default=["*.ARW"],
+        help="Glob if a directory of images is specified",
+    )
     parser.add_argument(
-        '--resize', type=int, nargs='+', default=[640, 480],
-        help='Resize the input image before running inference. If two numbers, '
-             'resize to the exact dimensions, if one number, resize the max '
-             'dimension, if -1, do not resize')
+        "--resize",
+        type=int,
+        nargs="+",
+        default=[640, 480],
+        help="Resize the input image before running inference. If two numbers, "
+        "resize to the exact dimensions, if one number, resize the max "
+        "dimension, if -1, do not resize",
+    )
     parser.add_argument(
-        '--force_cpu', action='store_true',
-        help='Force pytorch to run in CPU mode.')
-    parser.add_argument('--model_path', type=str,
-                        help='Path to the pretrained model')
+        "--force_cpu", action="store_true", help="Force pytorch to run in CPU mode."
+    )
+    parser.add_argument("--model_path", type=str, help="Path to the pretrained model")
 
     opt = parser.parse_args()
     print(opt)
 
     assert len(opt.resize) == 2
-    print('Will resize to {}x{} (WxH)'.format(opt.resize[0], opt.resize[1]))
+    print("Will resize to {}x{} (WxH)".format(opt.resize[0], opt.resize[1]))
 
-    device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu'
-    print('Running inference on device \"{}\"'.format(device))
+    device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu"
+    print('Running inference on device "{}"'.format(device))
     matching = NNMatching(opt.model_path).eval().to(device)
-    keys = ['keypoints', 'scores', 'descriptors']
+    keys = ["keypoints", "scores", "descriptors"]
 
     vs = VideoStreamer(opt.input, opt.resize, opt.image_glob)
     frame, ret = vs.next_frame()
-    assert ret, 'Error when reading the first frame (try different --input?)'
+    assert ret, "Error when reading the first frame (try different --input?)"
 
     frame_tensor = frame2tensor(frame, device)
-    last_data = matching.darkfeat({'image': frame_tensor})
-    last_data = {k+'0': [last_data[k]] for k in keys}
-    last_data['image0'] = frame_tensor
+    last_data = matching.darkfeat({"image": frame_tensor})
+    last_data = {k + "0": [last_data[k]] for k in keys}
+    last_data["image0"] = frame_tensor
     last_frame = frame
     last_image_id = 0
 
     if opt.output_dir is not None:
-        print('==> Will write outputs to {}'.format(opt.output_dir))
+        print("==> Will write outputs to {}".format(opt.output_dir))
         Path(opt.output_dir).mkdir(exist_ok=True)
 
     timer = AverageTimer()
@@ -88,37 +112,43 @@ if __name__ == '__main__':
     while True:
         frame, ret = vs.next_frame()
         if not ret:
-            print('Finished demo_darkfeat.py')
+            print("Finished demo_darkfeat.py")
             break
-        timer.update('data')
+        timer.update("data")
         stem0, stem1 = last_image_id, vs.i - 1
 
         frame_tensor = frame2tensor(frame, device)
-        pred = matching({**last_data, 'image1': frame_tensor})
-        kpts0 = last_data['keypoints0'][0].cpu().numpy()
-        kpts1 = pred['keypoints1'][0].cpu().numpy()
-        matches = pred['matches0'][0].cpu().numpy()
-        confidence = pred['matching_scores0'][0].cpu().numpy()
-        timer.update('forward')
+        pred = matching({**last_data, "image1": frame_tensor})
+        kpts0 = last_data["keypoints0"][0].cpu().numpy()
+        kpts1 = pred["keypoints1"][0].cpu().numpy()
+        matches = pred["matches0"][0].cpu().numpy()
+        confidence = pred["matching_scores0"][0].cpu().numpy()
+        timer.update("forward")
 
         valid = matches > -1
         mkpts0 = kpts0[valid]
         mkpts1 = kpts1[matches[valid]]
 
         E, inliers, pts1, pts2 = compute_essential(mkpts0, mkpts1, K)
-        color = cm.jet(np.clip(confidence[valid][inliers[:, 0].astype('bool')] * 2 - 1, -1, 1))
+        color = cm.jet(
+            np.clip(confidence[valid][inliers[:, 0].astype("bool")] * 2 - 1, -1, 1)
+        )
 
-        text = [
-            'DarkFeat',
-            'Matches: {}'.format(inliers.sum())
-        ]
+        text = ["DarkFeat", "Matches: {}".format(inliers.sum())]
 
         out = make_matching_plot_fast(
-            last_frame, frame, mkpts0[inliers[:, 0].astype('bool')], mkpts1[inliers[:, 0].astype('bool')], color, text,
-            path=None, small_text=' ')
+            last_frame,
+            frame,
+            mkpts0[inliers[:, 0].astype("bool")],
+            mkpts1[inliers[:, 0].astype("bool")],
+            color,
+            text,
+            path=None,
+            small_text=" ",
+        )
 
         if opt.output_dir is not None:
-            stem = 'matches_{:06}_{:06}'.format(stem0, stem1)
-            out_file = str(Path(opt.output_dir, stem + '.png'))
-            print('Writing image to {}'.format(out_file))
+            stem = "matches_{:06}_{:06}".format(stem0, stem1)
+            out_file = str(Path(opt.output_dir, stem + ".png"))
+            print("Writing image to {}".format(out_file))
             cv2.imwrite(out_file, out)
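
Since `compute_essential` above already returns normalized points (via `cv2.undistortPoints` with K), the relative pose could be recovered from its output with OpenCV's `cv2.recoverPose`; a hedged follow-up sketch, not part of the demo script:

# Hedged sketch: decompose the essential matrix from compute_essential into R, t.
# pts1/pts2 are already normalized, so the default focal=1.0, pp=(0, 0) applies.
E, inliers, pts1, pts2 = compute_essential(mkpts0, mkpts1, K)
if E is not None:
    n_inliers, R, t, _ = cv2.recoverPose(E, pts1, pts2)
    print("rotation:\n", R, "\ntranslation direction:\n", t.ravel())
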
diff --git a/third_party/DarkFeat/export_features.py b/third_party/DarkFeat/export_features.py
index c7caea5e57890948728f84cbb7e68e59d455e171..da54e3dc0a1fed98e832b9cc5d6961e713087b8b 100644
--- a/third_party/DarkFeat/export_features.py
+++ b/third_party/DarkFeat/export_features.py
@@ -11,6 +11,7 @@ import cv2
 from darkfeat import DarkFeat
 from utils import matching
 
+
 def darkfeat_pre(img, cuda):
     H, W = img.shape[0], img.shape[1]
     inp = img.copy()
@@ -21,24 +22,25 @@ def darkfeat_pre(img, cuda):
         inp = inp.cuda()
     return inp
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     # Parse command line arguments.
     parser = argparse.ArgumentParser()
-    parser.add_argument('--H', type=int, default=int(640))
-    parser.add_argument('--W', type=int, default=int(960))
-    parser.add_argument('--histeq', action='store_true')
-    parser.add_argument('--model_path', type=str)
-    parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/')
+    parser.add_argument("--H", type=int, default=int(640))
+    parser.add_argument("--W", type=int, default=int(960))
+    parser.add_argument("--histeq", action="store_true")
+    parser.add_argument("--model_path", type=str)
+    parser.add_argument("--dataset_dir", type=str, default="/data/hyz/MID/")
     opt = parser.parse_args()
 
     sizer = (opt.W, opt.H)
-    focallength_x = 4.504986436499113e+03/(6744/sizer[0])
-    focallength_y = 4.513311442889859e+03/(4502/sizer[1])
+    focallength_x = 4.504986436499113e03 / (6744 / sizer[0])
+    focallength_y = 4.513311442889859e03 / (4502 / sizer[1])
     K = np.eye(3)
-    K[0,0] = focallength_x
-    K[1,1] = focallength_y
-    K[0,2] = 3.363322177533149e+03/(6744/sizer[0])# * 0.5
-    K[1,2] = 2.291824660547715e+03/(4502/sizer[1])# * 0.5
+    K[0, 0] = focallength_x
+    K[1, 1] = focallength_y
+    K[0, 2] = 3.363322177533149e03 / (6744 / sizer[0])  # * 0.5
+    K[1, 2] = 2.291824660547715e03 / (4502 / sizer[1])  # * 0.5
     Kinv = np.linalg.inv(K)
     Kinvt = np.transpose(Kinv)
 
@@ -46,83 +48,111 @@ if __name__ == '__main__':
     if cuda:
         darkfeat = DarkFeat(opt.model_path).cuda().eval()
 
-    for scene in ['Indoor', 'Outdoor']:
-        base_save = './result/' + scene + '/'
-        dir_base = opt.dataset_dir + '/' + scene + '/'
+    for scene in ["Indoor", "Outdoor"]:
+        base_save = "./result/" + scene + "/"
+        dir_base = opt.dataset_dir + "/" + scene + "/"
         pair_list = sorted(os.listdir(dir_base))
 
         for pair in tqdm.tqdm(pair_list):
             opention = 1
-            if scene == 'Outdoor':
+            if scene == "Outdoor":
                 pass
             else:
                 if int(pair[4::]) <= 17:
                     opention = 0
                 else:
                     pass
-            name=[]
-            files = sorted(os.listdir(dir_base+pair))
+            name = []
+            files = sorted(os.listdir(dir_base + pair))
             for file_ in files:
-                if file_.endswith('.cr2'):
+                if file_.endswith(".cr2"):
                     name.append(file_[0:9])
-            ISO = ['00100', '00200', '00400', '00800', '01600', '03200', '06400', '12800']
+            ISO = [
+                "00100",
+                "00200",
+                "00400",
+                "00800",
+                "01600",
+                "03200",
+                "06400",
+                "12800",
+            ]
             if opention == 1:
-                Shutter_speed = ['0.005','0.01','0.025','0.05','0.17','0.5']
+                Shutter_speed = ["0.005", "0.01", "0.025", "0.05", "0.17", "0.5"]
             else:
-                Shutter_speed = ['0.01','0.02','0.05','0.1','0.3','1']
+                Shutter_speed = ["0.01", "0.02", "0.05", "0.1", "0.3", "1"]
 
-            E_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'E_estimated.npy')
-            F_GT = np.dot(np.dot(Kinvt,E_GT),Kinv)
-            R_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'R_GT.npy')
-            t_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'T_GT.npy')
+            E_GT = np.load(dir_base + pair + "/GT_Correspondence/" + "E_estimated.npy")
+            F_GT = np.dot(np.dot(Kinvt, E_GT), Kinv)
+            R_GT = np.load(dir_base + pair + "/GT_Correspondence/" + "R_GT.npy")
+            t_GT = np.load(dir_base + pair + "/GT_Correspondence/" + "T_GT.npy")
 
-            id0, id1 = sorted([ int(i.split('/')[-1]) for i in glob.glob(f'{dir_base+pair}/?????') ])
+            id0, id1 = sorted(
+                [int(i.split("/")[-1]) for i in glob.glob(f"{dir_base+pair}/?????")]
+            )
 
             cnt = 0
 
             for iso in ISO:
                 for ex in Shutter_speed:
-                    dark_name1 = name[0] + iso+'_'+ex+'_'+scene+'.npy'
-                    dark_name2 = name[1] + iso+'_'+ex+'_'+scene+'.npy'
+                    dark_name1 = name[0] + iso + "_" + ex + "_" + scene + ".npy"
+                    dark_name2 = name[1] + iso + "_" + ex + "_" + scene + ".npy"
 
                     if not opt.histeq:
-                        dst_T1_None = f'{dir_base}{pair}/{id0:05d}-npy-nohisteq/{dark_name1}'
-                        dst_T2_None = f'{dir_base}{pair}/{id1:05d}-npy-nohisteq/{dark_name2}'
+                        dst_T1_None = (
+                            f"{dir_base}{pair}/{id0:05d}-npy-nohisteq/{dark_name1}"
+                        )
+                        dst_T2_None = (
+                            f"{dir_base}{pair}/{id1:05d}-npy-nohisteq/{dark_name2}"
+                        )
 
                         img1_orig_None = np.load(dst_T1_None)
                         img2_orig_None = np.load(dst_T2_None)
 
-                        dir_save = base_save + pair + '/None/'
+                        dir_save = base_save + pair + "/None/"
 
-                        img_input1 = darkfeat_pre(img1_orig_None.astype('float32')/255.0, cuda)
-                        img_input2 = darkfeat_pre(img2_orig_None.astype('float32')/255.0, cuda)
+                        img_input1 = darkfeat_pre(
+                            img1_orig_None.astype("float32") / 255.0, cuda
+                        )
+                        img_input2 = darkfeat_pre(
+                            img2_orig_None.astype("float32") / 255.0, cuda
+                        )
 
                     else:
-                        dst_T1_histeq = f'{dir_base}{pair}/{id0:05d}-npy/{dark_name1}'
-                        dst_T2_histeq = f'{dir_base}{pair}/{id1:05d}-npy/{dark_name2}'
+                        dst_T1_histeq = f"{dir_base}{pair}/{id0:05d}-npy/{dark_name1}"
+                        dst_T2_histeq = f"{dir_base}{pair}/{id1:05d}-npy/{dark_name2}"
 
                         img1_orig_histeq = np.load(dst_T1_histeq)
                         img2_orig_histeq = np.load(dst_T2_histeq)
 
-                        dir_save = base_save + pair + '/HistEQ/'
+                        dir_save = base_save + pair + "/HistEQ/"
 
-                        img_input1 = darkfeat_pre(img1_orig_histeq.astype('float32')/255.0, cuda)
-                        img_input2 = darkfeat_pre(img2_orig_histeq.astype('float32')/255.0, cuda)
+                        img_input1 = darkfeat_pre(
+                            img1_orig_histeq.astype("float32") / 255.0, cuda
+                        )
+                        img_input2 = darkfeat_pre(
+                            img2_orig_histeq.astype("float32") / 255.0, cuda
+                        )
 
-                    result1 = darkfeat({'image': img_input1})
-                    result2 = darkfeat({'image': img_input2})
+                    result1 = darkfeat({"image": img_input1})
+                    result2 = darkfeat({"image": img_input2})
 
                     mkpts0, mkpts1, _ = matching.match_descriptors(
-                        cv2.KeyPoint_convert(result1['keypoints'].detach().cpu().float().numpy()), result1['descriptors'].detach().cpu().numpy(),
-                        cv2.KeyPoint_convert(result2['keypoints'].detach().cpu().float().numpy()), result2['descriptors'].detach().cpu().numpy(),
-                        ORB=False
+                        cv2.KeyPoint_convert(
+                            result1["keypoints"].detach().cpu().float().numpy()
+                        ),
+                        result1["descriptors"].detach().cpu().numpy(),
+                        cv2.KeyPoint_convert(
+                            result2["keypoints"].detach().cpu().float().numpy()
+                        ),
+                        result2["descriptors"].detach().cpu().numpy(),
+                        ORB=False,
                     )
 
-                    POINT_1_dir = dir_save+f'DarkFeat/POINT_1/'
-                    POINT_2_dir = dir_save+f'DarkFeat/POINT_2/'
-
-                    subprocess.check_output(['mkdir', '-p', POINT_1_dir])
-                    subprocess.check_output(['mkdir', '-p', POINT_2_dir])
-                    np.save(POINT_1_dir+dark_name1[0:-3]+'npy',mkpts0)
-                    np.save(POINT_2_dir+dark_name2[0:-3]+'npy',mkpts1)
+                    POINT_1_dir = dir_save + f"DarkFeat/POINT_1/"
+                    POINT_2_dir = dir_save + f"DarkFeat/POINT_2/"
 
+                    subprocess.check_output(["mkdir", "-p", POINT_1_dir])
+                    subprocess.check_output(["mkdir", "-p", POINT_2_dir])
+                    np.save(POINT_1_dir + dark_name1[0:-3] + "npy", mkpts0)
+                    np.save(POINT_2_dir + dark_name2[0:-3] + "npy", mkpts1)
diff --git a/third_party/DarkFeat/nets/geom.py b/third_party/DarkFeat/nets/geom.py
index 043ca6e8f5917c56defd6aa17c1ff236a431f8c0..d711ffdbf57aa023caa048adb3e7c8519aef7a3f 100644
--- a/third_party/DarkFeat/nets/geom.py
+++ b/third_party/DarkFeat/nets/geom.py
@@ -14,23 +14,25 @@ def rnd_sample(inputs, n_sample):
 def _grid_positions(h, w, bs):
     x_rng = torch.arange(0, w.int())
     y_rng = torch.arange(0, h.int())
-    xv, yv = torch.meshgrid(x_rng, y_rng, indexing='xy')
-    return torch.reshape(
-        torch.stack((yv, xv), axis=-1),
-        (1, -1, 2)
-    ).repeat(bs, 1, 1).float()
+    xv, yv = torch.meshgrid(x_rng, y_rng, indexing="xy")
+    return (
+        torch.reshape(torch.stack((yv, xv), axis=-1), (1, -1, 2))
+        .repeat(bs, 1, 1)
+        .float()
+    )
 
 
 def getK(ori_img_size, cur_feat_size, K):
     # WARNING: cur_feat_size's order is [h, w]
     r = ori_img_size / cur_feat_size[[1, 0]]
-    r_K0 = torch.stack([K[:, 0] / r[:, 0][..., None], K[:, 1] /
-                        r[:, 1][..., None], K[:, 2]], axis=1)
+    r_K0 = torch.stack(
+        [K[:, 0] / r[:, 0][..., None], K[:, 1] / r[:, 1][..., None], K[:, 2]], axis=1
+    )
     return r_K0
 
 
 def gather_nd(params, indices):
-    """ The same as tf.gather_nd but batched gather is not supported yet.
+    """The same as tf.gather_nd but batched gather is not supported yet.
     indices is a k-dimensional integer tensor, best thought of as a (k-1)-dimensional tensor of indices into params, where each element defines a slice of params:
 
     output[\\(i_0, ..., i_{k-2}\\)] = params[indices[\\(i_0, ..., i_{k-2}\\)]]
@@ -40,7 +42,7 @@ def gather_nd(params, indices):
         indices (Tensor): "k" dimensions. shape: [y_0,y_2,...,y_{k-2}, m]. m <= n.
 
     Returns: gathered Tensor.
-        shape [y_0,y_2,...y_{k-2}] + params.shape[m:] 
+        shape [y_0,y_2,...y_{k-2}] + params.shape[m:]
 
     """
     orig_shape = list(indices.shape)
@@ -52,13 +54,14 @@ def gather_nd(params, indices):
         out_shape = orig_shape[:-1] + list(params.shape)[m:]
     else:
         raise ValueError(
-            f'the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}'
+            f"the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}"
         )
 
     indices = indices.reshape((num_samples, m)).transpose(0, 1).tolist()
-    output = params[indices]    # (num_samples, ...)
+    output = params[indices]  # (num_samples, ...)
     return output.reshape(out_shape).contiguous()
 
+
 # input: pos [kpt_n, 2]; inputs [H, W, 128] / [H, W]
 # output: [kpt_n, 128] / [kpt_n]
 def interpolate(pos, inputs, nd=True):
@@ -94,17 +97,21 @@ def interpolate(pos, inputs, nd=True):
         w_bottom_right = w_bottom_right[..., None]
 
     interpolated_val = (
-        w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) +
-        w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) +
-        w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) +
-        w_bottom_right *
-        gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1))
+        w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1))
+        + w_top_right
+        * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1))
+        + w_bottom_left
+        * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1))
+        + w_bottom_right
+        * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1))
     )
 
     return interpolated_val
 
 
-def validate_and_interpolate(pos, inputs, validate_corner=True, validate_val=None, nd=False):
+def validate_and_interpolate(
+    pos, inputs, validate_corner=True, validate_val=None, nd=False
+):
     if nd:
         h, w, c = inputs.shape
     else:
@@ -135,7 +142,7 @@ def validate_and_interpolate(pos, inputs, validate_corner=True, validate_val=Non
 
         valid_corner = torch.logical_and(
             torch.logical_and(valid_top_left, valid_top_right),
-            torch.logical_and(valid_bottom_left, valid_bottom_right)
+            torch.logical_and(valid_bottom_left, valid_bottom_right),
         )
 
         i_top_left = i_top_left[valid_corner]
@@ -157,12 +164,16 @@ def validate_and_interpolate(pos, inputs, validate_corner=True, validate_val=Non
         valid_depth = torch.logical_and(
             torch.logical_and(
                 gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) > 0,
-                gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) > 0
+                gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) > 0,
             ),
             torch.logical_and(
-                gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) > 0,
-                gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) > 0
-            )
+                gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1))
+                > 0,
+                gather_nd(
+                    inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)
+                )
+                > 0,
+            ),
         )
 
         i_top_left = i_top_left[valid_depth]
@@ -196,10 +207,13 @@ def validate_and_interpolate(pos, inputs, validate_corner=True, validate_val=Non
         w_bottom_right = w_bottom_right[..., None]
 
     interpolated_val = (
-        w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) +
-        w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) +
-        w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) +
-        w_bottom_right * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1))
+        w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1))
+        + w_top_right
+        * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1))
+        + w_bottom_left
+        * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1))
+        + w_bottom_right
+        * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1))
     )
 
     pos = torch.stack([i, j], axis=1)
@@ -218,10 +232,21 @@ def getWarp(pos0, rel_pose, depth0, K0, depth1, K1, bs):
     for i in range(bs):
         z0, new_pos0, ids = validate_and_interpolate(pos0[i], depth0[i], validate_val=0)
 
-        uv0_homo = torch.cat([swap_axis(new_pos0), torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device)], axis=-1)
+        uv0_homo = torch.cat(
+            [
+                swap_axis(new_pos0),
+                torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device),
+            ],
+            axis=-1,
+        )
         xy0_homo = torch.matmul(torch.linalg.inv(K0[i]), uv0_homo.t())
-        xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo,
-                               torch.ones((1, new_pos0.shape[0])).to(z0.device)], axis=0)
+        xyz0_homo = torch.cat(
+            [
+                torch.unsqueeze(z0, 0) * xy0_homo,
+                torch.ones((1, new_pos0.shape[0])).to(z0.device),
+            ],
+            axis=0,
+        )
 
         xyz1 = torch.matmul(rel_pose[i], xyz0_homo)
         xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0)
@@ -229,7 +254,8 @@ def getWarp(pos0, rel_pose, depth0, K0, depth1, K1, bs):
 
         new_pos1 = swap_axis(uv1)
         annotated_depth, new_pos1, new_ids = validate_and_interpolate(
-            new_pos1, depth1[i], validate_val=0)
+            new_pos1, depth1[i], validate_val=0
+        )
 
         ids = ids[new_ids]
         new_pos0 = new_pos0[new_ids]
@@ -256,10 +282,21 @@ def getWarpNoValidate(pos0, rel_pose, depth0, K0, depth1, K1, bs):
     for i in range(bs):
         z0, new_pos0, ids = validate_and_interpolate(pos0[i], depth0[i], validate_val=0)
 
-        uv0_homo = torch.cat([swap_axis(new_pos0), torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device)], axis=-1)
+        uv0_homo = torch.cat(
+            [
+                swap_axis(new_pos0),
+                torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device),
+            ],
+            axis=-1,
+        )
         xy0_homo = torch.matmul(torch.linalg.inv(K0[i]), uv0_homo.t())
-        xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo,
-                               torch.ones((1, new_pos0.shape[0])).to(z0.device)], axis=0)
+        xyz0_homo = torch.cat(
+            [
+                torch.unsqueeze(z0, 0) * xy0_homo,
+                torch.ones((1, new_pos0.shape[0])).to(z0.device),
+            ],
+            axis=0,
+        )
 
         xyz1 = torch.matmul(rel_pose[i], xyz0_homo)
         xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0)
@@ -267,7 +304,8 @@ def getWarpNoValidate(pos0, rel_pose, depth0, K0, depth1, K1, bs):
 
         new_pos1 = swap_axis(uv1)
         _, new_pos1, new_ids = validate_and_interpolate(
-            new_pos1, depth1[i], validate_val=0)
+            new_pos1, depth1[i], validate_val=0
+        )
 
         ids = ids[new_ids]
         new_pos0 = new_pos0[new_ids]
@@ -287,10 +325,17 @@ def getWarpNoValidate2(pos0, rel_pose, depth0, K0, depth1, K1):
 
     z0 = interpolate(pos0, depth0, nd=False)
 
-    uv0_homo = torch.cat([swap_axis(pos0), torch.ones((pos0.shape[0], 1)).to(pos0.device)], axis=-1)
+    uv0_homo = torch.cat(
+        [swap_axis(pos0), torch.ones((pos0.shape[0], 1)).to(pos0.device)], axis=-1
+    )
     xy0_homo = torch.matmul(torch.linalg.inv(K0), uv0_homo.t())
-    xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo,
-                            torch.ones((1, pos0.shape[0])).to(z0.device)], axis=0)
+    xyz0_homo = torch.cat(
+        [
+            torch.unsqueeze(z0, 0) * xy0_homo,
+            torch.ones((1, pos0.shape[0])).to(z0.device),
+        ],
+        axis=0,
+    )
 
     xyz1 = torch.matmul(rel_pose, xyz0_homo)
     xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0)
@@ -301,22 +346,18 @@ def getWarpNoValidate2(pos0, rel_pose, depth0, K0, depth1, K1):
     return new_pos1
 
 
-
 def get_dist_mat(feat1, feat2, dist_type):
     eps = 1e-6
     cos_dist_mat = torch.matmul(feat1, feat2.t())
-    if dist_type == 'cosine_dist':
+    if dist_type == "cosine_dist":
         dist_mat = torch.clamp(cos_dist_mat, -1, 1)
-    elif dist_type == 'euclidean_dist':
+    elif dist_type == "euclidean_dist":
         dist_mat = torch.sqrt(torch.clamp(2 - 2 * cos_dist_mat, min=eps))
-    elif dist_type == 'euclidean_dist_no_norm':
+    elif dist_type == "euclidean_dist_no_norm":
         norm1 = torch.sum(feat1 * feat1, axis=-1, keepdims=True)
         norm2 = torch.sum(feat2 * feat2, axis=-1, keepdims=True)
         dist_mat = torch.sqrt(
-            torch.clamp(
-                norm1 - 2 * cos_dist_mat + norm2.t(),
-                min=0.
-            ) + eps
+            torch.clamp(norm1 - 2 * cos_dist_mat + norm2.t(), min=0.0) + eps
         )
     else:
         raise NotImplementedError()
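
All three dist_type branches in get_dist_mat above come from the same inner-product expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, which for L2-normalized descriptors collapses to 2 - 2*cos. A minimal, self-contained check of that identity on random tensors (independent of the repository code, shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

# Two random descriptor sets: 5 and 7 vectors of dimension 128.
a = torch.randn(5, 128)
b = torch.randn(7, 128)

# "euclidean_dist_no_norm": expand ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.
inner = a @ b.t()
sq = (a * a).sum(-1, keepdim=True) - 2 * inner + (b * b).sum(-1, keepdim=True).t()
dist_no_norm = torch.sqrt(torch.clamp(sq, min=0.0) + 1e-6)

# "euclidean_dist": for unit-length features the expansion collapses to 2 - 2*cos.
an, bn = F.normalize(a, dim=-1), F.normalize(b, dim=-1)
dist_norm = torch.sqrt(torch.clamp(2 - 2 * (an @ bn.t()), min=1e-6))

# Both agree with torch.cdist up to the small eps used for numerical stability.
print(torch.allclose(dist_no_norm, torch.cdist(a, b), atol=1e-2))
print(torch.allclose(dist_norm, torch.cdist(an, bn), atol=1e-2))
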
diff --git a/third_party/DarkFeat/nets/l2net.py b/third_party/DarkFeat/nets/l2net.py
index e1ddfe8919bd4d5fe75215d253525123e1402952..b51dc0e9e983c7795924f75b2a814bea85fd08a0 100644
--- a/third_party/DarkFeat/nets/l2net.py
+++ b/third_party/DarkFeat/nets/l2net.py
@@ -7,9 +7,10 @@ from .score import peakiness_score
 
 
 class BaseNet(nn.Module):
-    """ Helper class to construct a fully-convolutional network that
-        extract a l2-normalized patch descriptor.
+    """Helper class to construct a fully-convolutional network that
+    extracts an l2-normalized patch descriptor.
     """
+
     def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
         super(BaseNet, self).__init__()
         self.inchan = inchan
@@ -22,27 +23,42 @@ class BaseNet(nn.Module):
     def _make_bn(self, outd):
         return nn.BatchNorm2d(outd, affine=self.bn_affine)
 
-    def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max', bias=False):
+    def _add_conv(
+        self,
+        outd,
+        k=3,
+        stride=1,
+        dilation=1,
+        bn=True,
+        relu=True,
+        k_pool=1,
+        pool_type="max",
+        bias=False,
+    ):
         # as in the original implementation, dilation is applied at the end of layer, so it will have impact only from next layer
         d = self.dilation * dilation
-        # if self.dilated: 
+        # if self.dilated:
         #     conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=1)
         #     self.dilation *= stride
         # else:
         #     conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride)
-        conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride, bias=bias)
+        conv_params = dict(
+            padding=((k - 1) * d) // 2, dilation=d, stride=stride, bias=bias
+        )
 
         ops = nn.ModuleList([])
 
-        ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) )
-        if bn and self.bn: ops.append( self._make_bn(outd) )
-        if relu: ops.append( nn.ReLU(inplace=True) )
+        ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params))
+        if bn and self.bn:
+            ops.append(self._make_bn(outd))
+        if relu:
+            ops.append(nn.ReLU(inplace=True))
         self.curchan = outd
-        
+
         if k_pool > 1:
-            if pool_type == 'avg':
+            if pool_type == "avg":
                 ops.append(torch.nn.AvgPool2d(kernel_size=k_pool))
-            elif pool_type == 'max':
+            elif pool_type == "max":
                 ops.append(torch.nn.MaxPool2d(kernel_size=k_pool))
             else:
                 print(f"Error, unknown pooling type {pool_type}...")
@@ -51,29 +67,31 @@ class BaseNet(nn.Module):
 
 
 class Quad_L2Net(BaseNet):
-    """ Same than L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs.
-    """
+    """Same than L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs."""
+
     def __init__(self, dim=128, mchan=4, relu22=False, **kw):
         BaseNet.__init__(self, **kw)
-        self.conv0 = self._add_conv(  8*mchan)
-        self.conv1 = self._add_conv(  8*mchan, bn=False)
-        self.bn1 = self._make_bn(8*mchan)
-        self.conv2 = self._add_conv( 16*mchan, stride=2)
-        self.conv3 = self._add_conv( 16*mchan, bn=False)
-        self.bn3 = self._make_bn(16*mchan)
-        self.conv4 = self._add_conv( 32*mchan, stride=2)
-        self.conv5 = self._add_conv( 32*mchan)
+        self.conv0 = self._add_conv(8 * mchan)
+        self.conv1 = self._add_conv(8 * mchan, bn=False)
+        self.bn1 = self._make_bn(8 * mchan)
+        self.conv2 = self._add_conv(16 * mchan, stride=2)
+        self.conv3 = self._add_conv(16 * mchan, bn=False)
+        self.bn3 = self._make_bn(16 * mchan)
+        self.conv4 = self._add_conv(32 * mchan, stride=2)
+        self.conv5 = self._add_conv(32 * mchan)
         # replace last 8x8 convolution with 3 3x3 convolutions
-        self.conv6_0 = self._add_conv( 32*mchan)
-        self.conv6_1 = self._add_conv( 32*mchan)
+        self.conv6_0 = self._add_conv(32 * mchan)
+        self.conv6_1 = self._add_conv(32 * mchan)
         self.conv6_2 = self._add_conv(dim, bn=False, relu=False)
         self.out_dim = dim
 
-        self.moving_avg_params = nn.ParameterList([
-            Parameter(torch.tensor(1.), requires_grad=False),
-            Parameter(torch.tensor(1.), requires_grad=False),
-            Parameter(torch.tensor(1.), requires_grad=False)
-        ])
+        self.moving_avg_params = nn.ParameterList(
+            [
+                Parameter(torch.tensor(1.0), requires_grad=False),
+                Parameter(torch.tensor(1.0), requires_grad=False),
+                Parameter(torch.tensor(1.0), requires_grad=False),
+            ]
+        )
 
     def forward(self, x):
         # x: [N, C, H, W]
@@ -90,7 +108,7 @@ class Quad_L2Net(BaseNet):
         x6_2 = self.conv6_2(x6_1)
 
         # calculate score map
-        comb_weights = torch.tensor([1., 2., 3.], device=x.device)
+        comb_weights = torch.tensor([1.0, 2.0, 3.0], device=x.device)
         comb_weights /= torch.sum(comb_weights)
         ksize = [3, 2, 1]
         det_score_maps = []
@@ -98,15 +116,21 @@ class Quad_L2Net(BaseNet):
         for idx, xx in enumerate([x1, x3, x6_2]):
             if self.training:
                 instance_max = torch.max(xx)
-                self.moving_avg_params[idx].data = self.moving_avg_params[idx] * 0.99 + instance_max.detach() * 0.01
+                self.moving_avg_params[idx].data = (
+                    self.moving_avg_params[idx] * 0.99 + instance_max.detach() * 0.01
+                )
             else:
                 pass
 
-            alpha, beta = peakiness_score(xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx])
+            alpha, beta = peakiness_score(
+                xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx]
+            )
 
             score_vol = alpha * beta
             det_score_map = torch.max(score_vol, dim=1, keepdim=True)[0]
-            det_score_map = F.interpolate(det_score_map, size=x.shape[2:], mode='bilinear', align_corners=True)
+            det_score_map = F.interpolate(
+                det_score_map, size=x.shape[2:], mode="bilinear", align_corners=True
+            )
             det_score_map = comb_weights[idx] * det_score_map
             det_score_maps.append(det_score_map)
 
diff --git a/third_party/DarkFeat/nets/loss.py b/third_party/DarkFeat/nets/loss.py
index 0dd42b4214d021137ddfe72771ccad0264d2321f..1440ef46f43108db0053cf48369e4014c348f98c 100644
--- a/third_party/DarkFeat/nets/loss.py
+++ b/third_party/DarkFeat/nets/loss.py
@@ -4,10 +4,20 @@ import torch.nn.functional as F
 from .geom import rnd_sample, interpolate, get_dist_mat
 
 
-def make_detector_loss(pos0, pos1, dense_feat_map0, dense_feat_map1,
-               score_map0, score_map1, batch_size, num_corr, loss_type, config):
-    joint_loss = 0.
-    accuracy = 0.
+def make_detector_loss(
+    pos0,
+    pos1,
+    dense_feat_map0,
+    dense_feat_map1,
+    score_map0,
+    score_map1,
+    batch_size,
+    num_corr,
+    loss_type,
+    config,
+):
+    joint_loss = 0.0
+    accuracy = 0.0
     all_valid_pos0 = []
     all_valid_pos1 = []
     all_valid_match = []
@@ -22,36 +32,54 @@ def make_detector_loss(pos0, pos1, dense_feat_map0, dense_feat_map1,
         valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1)
         valid_feat1 = F.normalize(valid_feat1, p=2, dim=-1)
 
-        valid_score0 = interpolate(valid_pos0, torch.squeeze(score_map0[i], dim=-1), nd=False)
-        valid_score1 = interpolate(valid_pos1, torch.squeeze(score_map1[i], dim=-1), nd=False)
-            
-        if config['network']['det']['corr_weight']:
+        valid_score0 = interpolate(
+            valid_pos0, torch.squeeze(score_map0[i], dim=-1), nd=False
+        )
+        valid_score1 = interpolate(
+            valid_pos1, torch.squeeze(score_map1[i], dim=-1), nd=False
+        )
+
+        if config["network"]["det"]["corr_weight"]:
             corr_weight = valid_score0 * valid_score1
         else:
             corr_weight = None
 
-        safe_radius = config['network']['det']['safe_radius']
+        safe_radius = config["network"]["det"]["safe_radius"]
         if safe_radius > 0:
             radius_mask_row = get_dist_mat(
-                valid_pos1, valid_pos1, "euclidean_dist_no_norm")
+                valid_pos1, valid_pos1, "euclidean_dist_no_norm"
+            )
             radius_mask_row = torch.le(radius_mask_row, safe_radius)
             radius_mask_col = get_dist_mat(
-                valid_pos0, valid_pos0, "euclidean_dist_no_norm")
+                valid_pos0, valid_pos0, "euclidean_dist_no_norm"
+            )
             radius_mask_col = torch.le(radius_mask_col, safe_radius)
-            radius_mask_row = radius_mask_row.float() - torch.eye(valid_num, device=radius_mask_row.device)
-            radius_mask_col = radius_mask_col.float() - torch.eye(valid_num, device=radius_mask_col.device)
+            radius_mask_row = radius_mask_row.float() - torch.eye(
+                valid_num, device=radius_mask_row.device
+            )
+            radius_mask_col = radius_mask_col.float() - torch.eye(
+                valid_num, device=radius_mask_col.device
+            )
         else:
             radius_mask_row = None
             radius_mask_col = None
 
         if valid_num < 32:
-            si_loss, si_accuracy, matched_mask = 0., 1., torch.zeros((1, valid_num)).bool()
+            si_loss, si_accuracy, matched_mask = (
+                0.0,
+                1.0,
+                torch.zeros((1, valid_num)).bool(),
+            )
         else:
             si_loss, si_accuracy, matched_mask = make_structured_loss(
-                torch.unsqueeze(valid_feat0, 0), torch.unsqueeze(valid_feat1, 0),
+                torch.unsqueeze(valid_feat0, 0),
+                torch.unsqueeze(valid_feat1, 0),
                 loss_type=loss_type,
-                radius_mask_row=radius_mask_row, radius_mask_col=radius_mask_col,
-                corr_weight=torch.unsqueeze(corr_weight, 0) if corr_weight is not None else None
+                radius_mask_row=radius_mask_row,
+                radius_mask_col=radius_mask_col,
+                corr_weight=torch.unsqueeze(corr_weight, 0)
+                if corr_weight is not None
+                else None,
             )
 
         joint_loss += si_loss / batch_size
@@ -63,10 +91,16 @@ def make_detector_loss(pos0, pos1, dense_feat_map0, dense_feat_map1,
     return joint_loss, accuracy
 
 
-def make_structured_loss(feat_anc, feat_pos,
-                         loss_type='RATIO', inlier_mask=None,
-                         radius_mask_row=None, radius_mask_col=None,
-                         corr_weight=None, dist_mat=None):
+def make_structured_loss(
+    feat_anc,
+    feat_pos,
+    loss_type="RATIO",
+    inlier_mask=None,
+    radius_mask_row=None,
+    radius_mask_col=None,
+    corr_weight=None,
+    dist_mat=None,
+):
     """
     Structured loss construction.
     Args:
@@ -82,23 +116,26 @@ def make_structured_loss(feat_anc, feat_pos,
         inlier_mask = torch.ones((batch_size, num_corr), device=feat_anc.device).bool()
     inlier_num = torch.count_nonzero(inlier_mask.float(), dim=-1)
 
-    if loss_type == 'L2NET' or loss_type == 'CIRCLE':
-        dist_type = 'cosine_dist'
-    elif loss_type.find('HARD') >= 0:
-        dist_type = 'euclidean_dist'
+    if loss_type == "L2NET" or loss_type == "CIRCLE":
+        dist_type = "cosine_dist"
+    elif loss_type.find("HARD") >= 0:
+        dist_type = "euclidean_dist"
     else:
         raise NotImplementedError()
 
     if dist_mat is None:
-        dist_mat = get_dist_mat(feat_anc.squeeze(0), feat_pos.squeeze(0), dist_type).unsqueeze(0)
+        dist_mat = get_dist_mat(
+            feat_anc.squeeze(0), feat_pos.squeeze(0), dist_type
+        ).unsqueeze(0)
     pos_vec = dist_mat[0].diag().unsqueeze(0)
 
-    if loss_type.find('HARD') >= 0:
+    if loss_type.find("HARD") >= 0:
         neg_margin = 1
-        dist_mat_without_min_on_diag = dist_mat + \
-            10 * torch.unsqueeze(torch.eye(num_corr, device=dist_mat.device), dim=0)
+        dist_mat_without_min_on_diag = dist_mat + 10 * torch.unsqueeze(
+            torch.eye(num_corr, device=dist_mat.device), dim=0
+        )
         mask = torch.le(dist_mat_without_min_on_diag, 0.008).float()
-        dist_mat_without_min_on_diag += mask*10
+        dist_mat_without_min_on_diag += mask * 10
 
         if radius_mask_row is not None:
             hard_neg_dist_row = dist_mat_without_min_on_diag + 10 * radius_mask_row
@@ -112,18 +149,18 @@ def make_structured_loss(feat_anc, feat_pos,
         hard_neg_dist_row = torch.min(hard_neg_dist_row, dim=-1)[0]
         hard_neg_dist_col = torch.min(hard_neg_dist_col, dim=-2)[0]
 
-        if loss_type == 'HARD_TRIPLET':
+        if loss_type == "HARD_TRIPLET":
             loss_row = torch.clamp(neg_margin + pos_vec - hard_neg_dist_row, min=0)
             loss_col = torch.clamp(neg_margin + pos_vec - hard_neg_dist_col, min=0)
-        elif loss_type == 'HARD_CONTRASTIVE':
+        elif loss_type == "HARD_CONTRASTIVE":
             pos_margin = 0.2
             pos_loss = torch.clamp(pos_vec - pos_margin, min=0)
             loss_row = pos_loss + torch.clamp(neg_margin - hard_neg_dist_row, min=0)
             loss_col = pos_loss + torch.clamp(neg_margin - hard_neg_dist_col, min=0)
         else:
             raise NotImplementedError()
-    
-    elif loss_type == 'CIRCLE':
+
+    elif loss_type == "CIRCLE":
         log_scale = 512
         m = 0.1
         neg_mask_row = torch.unsqueeze(torch.eye(num_corr, device=feat_anc.device), 0)
@@ -141,14 +178,26 @@ def make_structured_loss(feat_anc, feat_pos,
         neg_mat_row = dist_mat - 128 * neg_mask_row
         neg_mat_col = dist_mat - 128 * neg_mask_col
 
-        lse_positive = torch.logsumexp(-log_scale * (pos_vec[..., None] - pos_margin) * \
-                    torch.clamp(pos_optimal - pos_vec[..., None], min=0).detach(), dim=-1)
-        
-        lse_negative_row = torch.logsumexp(log_scale * (neg_mat_row - neg_margin) * \
-                    torch.clamp(neg_mat_row - neg_optimal, min=0).detach(), dim=-1)
-
-        lse_negative_col = torch.logsumexp(log_scale * (neg_mat_col - neg_margin) * \
-                    torch.clamp(neg_mat_col - neg_optimal, min=0).detach(), dim=-2)
+        lse_positive = torch.logsumexp(
+            -log_scale
+            * (pos_vec[..., None] - pos_margin)
+            * torch.clamp(pos_optimal - pos_vec[..., None], min=0).detach(),
+            dim=-1,
+        )
+
+        lse_negative_row = torch.logsumexp(
+            log_scale
+            * (neg_mat_row - neg_margin)
+            * torch.clamp(neg_mat_row - neg_optimal, min=0).detach(),
+            dim=-1,
+        )
+
+        lse_negative_col = torch.logsumexp(
+            log_scale
+            * (neg_mat_col - neg_margin)
+            * torch.clamp(neg_mat_col - neg_optimal, min=0).detach(),
+            dim=-2,
+        )
 
         loss_row = F.softplus(lse_positive + lse_negative_row) / log_scale
         loss_col = F.softplus(lse_positive + lse_negative_col) / log_scale
@@ -156,10 +205,10 @@ def make_structured_loss(feat_anc, feat_pos,
     else:
         raise NotImplementedError()
 
-    if dist_type == 'cosine_dist':
+    if dist_type == "cosine_dist":
         err_row = dist_mat - torch.unsqueeze(pos_vec, -1)
         err_col = dist_mat - torch.unsqueeze(pos_vec, -2)
-    elif dist_type == 'euclidean_dist' or dist_type == 'euclidean_dist_no_norm':
+    elif dist_type == "euclidean_dist" or dist_type == "euclidean_dist_no_norm":
         err_row = torch.unsqueeze(pos_vec, -1) - dist_mat
         err_col = torch.unsqueeze(pos_vec, -2) - dist_mat
     else:
@@ -180,17 +229,18 @@ def make_structured_loss(feat_anc, feat_pos,
 
     for i in range(batch_size):
         if corr_weight is not None:
-            loss += torch.sum(tot_loss[i][inlier_mask[i]]) / \
-                (torch.sum(corr_weight[i][inlier_mask[i]]) + 1e-6)
+            loss += torch.sum(tot_loss[i][inlier_mask[i]]) / (
+                torch.sum(corr_weight[i][inlier_mask[i]]) + 1e-6
+            )
         else:
             loss += torch.mean(tot_loss[i][inlier_mask[i]])
         cnt_err_row = torch.count_nonzero(err_row[i][inlier_mask[i]]).float()
         cnt_err_col = torch.count_nonzero(err_col[i][inlier_mask[i]]).float()
         tot_err = cnt_err_row + cnt_err_col
         if inlier_num[i] != 0:
-            accuracy += 1. - tot_err / inlier_num[i] / batch_size / 2.
+            accuracy += 1.0 - tot_err / inlier_num[i] / batch_size / 2.0
         else:
-            accuracy += 1.
+            accuracy += 1.0
 
     matched_mask = torch.logical_and(torch.eq(err_row, 0), torch.eq(err_col, 0))
     matched_mask = torch.logical_and(matched_mask, inlier_mask)
@@ -205,11 +255,13 @@ def make_structured_loss(feat_anc, feat_pos,
 # for the rest, the noise image's score should less than normal image
 # input: score_map [batch_size, H, W, 1]; indices [2, k, 2]
 # output: loss [scalar]
-def make_noise_score_map_loss(score_map, noise_score_map, indices, batch_size, thld=0.):
+def make_noise_score_map_loss(
+    score_map, noise_score_map, indices, batch_size, thld=0.0
+):
     H, W = score_map.shape[1:3]
     loss = 0
     for i in range(batch_size):
-        kpts_coords = indices[i].T # (2, num_kpts)
+        kpts_coords = indices[i].T  # (2, num_kpts)
         mask = torch.zeros([H, W], device=score_map.device)
         mask[kpts_coords.cpu().numpy()] = 1
 
@@ -217,8 +269,13 @@ def make_noise_score_map_loss(score_map, noise_score_map, indices, batch_size, t
         kernel = torch.ones([1, 1, 3, 3], device=score_map.device)
         mask = F.conv2d(mask.unsqueeze(0).unsqueeze(0), kernel, padding=1)[0, 0] > 0
 
-        loss1 = torch.sum(torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask) / torch.sum(mask)
-        loss2 = torch.sum(torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze() * torch.logical_not(mask)) / (H * W - torch.sum(mask))
+        loss1 = torch.sum(
+            torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask
+        ) / torch.sum(mask)
+        loss2 = torch.sum(
+            torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze()
+            * torch.logical_not(mask)
+        ) / (H * W - torch.sum(mask))
 
         loss += loss1
         loss += loss2
@@ -229,16 +286,28 @@ def make_noise_score_map_loss(score_map, noise_score_map, indices, batch_size, t
     return loss, first_mask
 
 
-def make_noise_score_map_loss_labelmap(score_map, noise_score_map, labelmap, batch_size, thld=0.):
+def make_noise_score_map_loss_labelmap(
+    score_map, noise_score_map, labelmap, batch_size, thld=0.0
+):
     H, W = score_map.shape[1:3]
     loss = 0
     for i in range(batch_size):
         # using 3x3 kernel to put kpts' neightborhood area into the mask
         kernel = torch.ones([1, 1, 3, 3], device=score_map.device)
-        mask = F.conv2d(labelmap[i].unsqueeze(0).to(score_map.device).float(), kernel, padding=1)[0, 0] > 0
-
-        loss1 = torch.sum(torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask) / torch.sum(mask)
-        loss2 = torch.sum(torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze() * torch.logical_not(mask)) / (H * W - torch.sum(mask))
+        mask = (
+            F.conv2d(
+                labelmap[i].unsqueeze(0).to(score_map.device).float(), kernel, padding=1
+            )[0, 0]
+            > 0
+        )
+
+        loss1 = torch.sum(
+            torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask
+        ) / torch.sum(mask)
+        loss2 = torch.sum(
+            torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze()
+            * torch.logical_not(mask)
+        ) / (H * W - torch.sum(mask))
 
         loss += loss1
         loss += loss2
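
The HARD_TRIPLET branch of make_structured_loss above is in-batch hardest-negative mining on a pairwise distance matrix whose diagonal holds the true matches. A toy, stand-alone version of that idea (random unit descriptors, not the repository API):

import torch
import torch.nn.functional as F

n, d = 64, 128
feat_anc = F.normalize(torch.randn(n, d), dim=-1)  # anchor descriptors
feat_pos = F.normalize(torch.randn(n, d), dim=-1)  # their ground-truth matches

# Pairwise Euclidean distance of unit vectors: sqrt(2 - 2 * cosine).
dist = torch.sqrt(torch.clamp(2 - 2 * feat_anc @ feat_pos.t(), min=1e-6))
pos_vec = dist.diag()  # distance of each true correspondence

# Push the diagonal out of reach so it is never picked as a negative.
masked = dist + 10 * torch.eye(n)
hard_neg_row = torch.min(masked, dim=-1)[0]  # hardest negative per anchor
hard_neg_col = torch.min(masked, dim=-2)[0]  # hardest negative per positive

neg_margin = 1.0
loss = (
    torch.clamp(neg_margin + pos_vec - hard_neg_row, min=0)
    + torch.clamp(neg_margin + pos_vec - hard_neg_col, min=0)
).mean()
print(loss)
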
diff --git a/third_party/DarkFeat/nets/multi_sampler.py b/third_party/DarkFeat/nets/multi_sampler.py
index dc400fb2afeb50575cd81d3c01b605bea6db1121..862a6e9e785f826853021c27d5c0fc2cfa2c2f51 100644
--- a/third_party/DarkFeat/nets/multi_sampler.py
+++ b/third_party/DarkFeat/nets/multi_sampler.py
@@ -5,17 +5,28 @@ import numpy as np
 
 from .geom import rnd_sample, interpolate
 
-class MultiSampler (nn.Module):
-    """ Similar to NghSampler, but doesnt warp the 2nd image.
+
+class MultiSampler(nn.Module):
+    """Similar to NghSampler, but doesnt warp the 2nd image.
     Distance to GT =>  0 ... pos_d ... neg_d ... ngh
     Pixel label    =>  + + + + + + 0 0 - - - - - - -
-    
+
     Subsample on query side: if > 0, regular grid
-                                < 0, random points 
+                                < 0, random points
     In both cases, the number of query points is = W*H/subq**2
     """
-    def __init__(self, ngh, subq=1, subd=1, pos_d=0, neg_d=2, border=None,
-                       maxpool_pos=True, subd_neg=0):
+
+    def __init__(
+        self,
+        ngh,
+        subq=1,
+        subd=1,
+        pos_d=0,
+        neg_d=2,
+        border=None,
+        maxpool_pos=True,
+        subd_neg=0,
+    ):
         nn.Module.__init__(self)
         assert 0 <= pos_d < neg_d <= (ngh if ngh else 99)
         self.ngh = ngh
@@ -26,8 +37,9 @@ class MultiSampler (nn.Module):
         self.sub_q = subq
         self.sub_d = subd
         self.sub_d_neg = subd_neg
-        if border is None: border = ngh
-        assert border >= ngh, 'border has to be larger than ngh'
+        if border is None:
+            border = ngh
+        assert border >= ngh, "border has to be larger than ngh"
         self.border = border
         self.maxpool_pos = maxpool_pos
         self.precompute_offsets()
@@ -36,22 +48,37 @@ class MultiSampler (nn.Module):
         pos_d2 = self.pos_d**2
         neg_d2 = self.neg_d**2
         rad2 = self.ngh**2
-        rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple
+        rad = (self.ngh // self.sub_d) * self.ngh  # make an integer multiple
         pos = []
         neg = []
-        for j in range(-rad, rad+1, self.sub_d):
-          for i in range(-rad, rad+1, self.sub_d):
-            d2 = i*i + j*j
-            if d2 <= pos_d2:
-                pos.append( (i,j) )
-            elif neg_d2 <= d2 <= rad2: 
-                neg.append( (i,j) )
-
-        self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t())
-        self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t())
-
-
-    def forward(self, feat0, feat1, noise_feat0, noise_feat1, conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=2500):
+        for j in range(-rad, rad + 1, self.sub_d):
+            for i in range(-rad, rad + 1, self.sub_d):
+                d2 = i * i + j * j
+                if d2 <= pos_d2:
+                    pos.append((i, j))
+                elif neg_d2 <= d2 <= rad2:
+                    neg.append((i, j))
+
+        self.register_buffer("pos_offsets", torch.LongTensor(pos).view(-1, 2).t())
+        self.register_buffer("neg_offsets", torch.LongTensor(neg).view(-1, 2).t())
+
+    def forward(
+        self,
+        feat0,
+        feat1,
+        noise_feat0,
+        noise_feat1,
+        conf0,
+        conf1,
+        noise_conf0,
+        noise_conf1,
+        pos0,
+        pos1,
+        B,
+        H,
+        W,
+        N=2500,
+    ):
         pscores_ls, nscores_ls, distractors_ls = [], [], []
         valid_feat0_ls = []
         noise_pscores_ls, noise_nscores_ls, noise_distractors_ls = [], [], []
@@ -62,58 +89,103 @@ class MultiSampler (nn.Module):
         mask_ls = []
 
         for i in range(B):
-            tmp_mask = (pos0[i][:, 1] >= self.border) * (pos0[i][:, 1] < W-self.border) \
-                * (pos0[i][:, 0] >= self.border) * (pos0[i][:, 0] < H-self.border)
+            tmp_mask = (
+                (pos0[i][:, 1] >= self.border)
+                * (pos0[i][:, 1] < W - self.border)
+                * (pos0[i][:, 0] >= self.border)
+                * (pos0[i][:, 0] < H - self.border)
+            )
 
             selected_pos0 = pos0[i][tmp_mask]
             selected_pos1 = pos1[i][tmp_mask]
             valid_pos0, valid_pos1 = rnd_sample([selected_pos0, selected_pos1], N)
 
             # sample features from first image
-            valid_feat0 = interpolate(valid_pos0 / 4, feat0[i]) # [N, 128]
-            valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) # [N, 128]
+            valid_feat0 = interpolate(valid_pos0 / 4, feat0[i])  # [N, 128]
+            valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1)  # [N, 128]
             qconf = interpolate(valid_pos0 / 4, conf0[i])
 
-            valid_noise_feat0 = interpolate(valid_pos0 / 4, noise_feat0[i]) # [N, 128]
-            valid_noise_feat0 = F.normalize(valid_noise_feat0, p=2, dim=-1) # [N, 128]
+            valid_noise_feat0 = interpolate(valid_pos0 / 4, noise_feat0[i])  # [N, 128]
+            valid_noise_feat0 = F.normalize(valid_noise_feat0, p=2, dim=-1)  # [N, 128]
             noise_qconf = interpolate(valid_pos0 / 4, noise_conf0[i])
 
             # sample GT from second image
-            mask = (valid_pos1[:, 1] >= 0) * (valid_pos1[:, 1] < W) \
-                * (valid_pos1[:, 0] >= 0) * (valid_pos1[:, 0] < H)
+            mask = (
+                (valid_pos1[:, 1] >= 0)
+                * (valid_pos1[:, 1] < W)
+                * (valid_pos1[:, 0] >= 0)
+                * (valid_pos1[:, 0] < H)
+            )
 
             def clamp(xy):
                 xy = xy
-                torch.clamp(xy[0], 0, H-1, out=xy[0])
-                torch.clamp(xy[1], 0, W-1, out=xy[1])
+                torch.clamp(xy[0], 0, H - 1, out=xy[0])
+                torch.clamp(xy[1], 0, W - 1, out=xy[1])
                 return xy
 
             # compute positive scores
-            valid_pos1p = clamp(valid_pos1.t()[:,None,:] + self.pos_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N]
-            valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2]
-            valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128]
-            valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1) # [29, N, 128]
-            valid_noise_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128]
-            valid_noise_feat1p = F.normalize(valid_noise_feat1p, p=2, dim=-1) # [29, N, 128]
-
-            pscores = (valid_feat0[None,:,:] * valid_feat1p).sum(dim=-1).t() # [N, 29]
+            valid_pos1p = clamp(
+                valid_pos1.t()[:, None, :]
+                + self.pos_offsets[:, :, None].to(valid_pos1.device)
+            )  # [2, 29, N]
+            valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape(
+                -1, 2
+            )  # [29, N, 2] -> [29*N, 2]
+            valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(
+                self.pos_offsets.shape[-1], -1, 128
+            )  # [29, N, 128]
+            valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1)  # [29, N, 128]
+            valid_noise_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(
+                self.pos_offsets.shape[-1], -1, 128
+            )  # [29, N, 128]
+            valid_noise_feat1p = F.normalize(
+                valid_noise_feat1p, p=2, dim=-1
+            )  # [29, N, 128]
+
+            pscores = (
+                (valid_feat0[None, :, :] * valid_feat1p).sum(dim=-1).t()
+            )  # [N, 29]
             pscores, pos = pscores.max(dim=1, keepdim=True)
-            sel = clamp(valid_pos1.t() + self.pos_offsets[:,pos.view(-1)].to(valid_pos1.device))
-            qconf = (qconf + interpolate(sel.t() / 4, conf1[i]))/2
-            noise_pscores = (valid_noise_feat0[None,:,:] * valid_noise_feat1p).sum(dim=-1).t() # [N, 29]
+            sel = clamp(
+                valid_pos1.t() + self.pos_offsets[:, pos.view(-1)].to(valid_pos1.device)
+            )
+            qconf = (qconf + interpolate(sel.t() / 4, conf1[i])) / 2
+            noise_pscores = (
+                (valid_noise_feat0[None, :, :] * valid_noise_feat1p).sum(dim=-1).t()
+            )  # [N, 29]
             noise_pscores, noise_pos = noise_pscores.max(dim=1, keepdim=True)
-            noise_sel = clamp(valid_pos1.t() + self.pos_offsets[:,noise_pos.view(-1)].to(valid_pos1.device))
-            noise_qconf = (noise_qconf + interpolate(noise_sel.t() / 4, noise_conf1[i]))/2
+            noise_sel = clamp(
+                valid_pos1.t()
+                + self.pos_offsets[:, noise_pos.view(-1)].to(valid_pos1.device)
+            )
+            noise_qconf = (
+                noise_qconf + interpolate(noise_sel.t() / 4, noise_conf1[i])
+            ) / 2
 
             # compute negative scores
-            valid_pos1n = clamp(valid_pos1.t()[:,None,:] + self.neg_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N]
-            valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2]
-            valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128]
-            valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1) # [29, N, 128]
-            nscores = (valid_feat0[None,:,:] * valid_feat1n).sum(dim=-1).t() # [N, 29]
-            valid_noise_feat1n = interpolate(valid_pos1n / 4, noise_feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128]
-            valid_noise_feat1n = F.normalize(valid_noise_feat1n, p=2, dim=-1) # [29, N, 128]
-            noise_nscores = (valid_noise_feat0[None,:,:] * valid_noise_feat1n).sum(dim=-1).t() # [N, 29]
+            valid_pos1n = clamp(
+                valid_pos1.t()[:, None, :]
+                + self.neg_offsets[:, :, None].to(valid_pos1.device)
+            )  # [2, 29, N]
+            valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape(
+                -1, 2
+            )  # [29, N, 2] -> [29*N, 2]
+            valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape(
+                self.neg_offsets.shape[-1], -1, 128
+            )  # [29, N, 128]
+            valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1)  # [29, N, 128]
+            nscores = (
+                (valid_feat0[None, :, :] * valid_feat1n).sum(dim=-1).t()
+            )  # [N, 29]
+            valid_noise_feat1n = interpolate(valid_pos1n / 4, noise_feat1[i]).reshape(
+                self.neg_offsets.shape[-1], -1, 128
+            )  # [29, N, 128]
+            valid_noise_feat1n = F.normalize(
+                valid_noise_feat1n, p=2, dim=-1
+            )  # [29, N, 128]
+            noise_nscores = (
+                (valid_noise_feat0[None, :, :] * valid_noise_feat1n).sum(dim=-1).t()
+            )  # [N, 29]
 
             if self.sub_d_neg:
                 valid_pos2 = rnd_sample([selected_pos1], N)[0]
@@ -158,15 +230,17 @@ class MultiSampler (nn.Module):
         dscores = torch.matmul(valid_feat0, distractors.t())
         noise_dscores = torch.matmul(valid_noise_feat0, noise_distractors.t())
 
-        dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:,None])**2 + (valid_pos2[:, 0] - valid_pos1[:, 0][:,None])**2
-        b = torch.arange(B, device=dscores.device)[:,None].expand(B, N).reshape(-1)
-        dis2 += (b != b[:,None]).long() * self.neg_d**2
+        dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:, None]) ** 2 + (
+            valid_pos2[:, 0] - valid_pos1[:, 0][:, None]
+        ) ** 2
+        b = torch.arange(B, device=dscores.device)[:, None].expand(B, N).reshape(-1)
+        dis2 += (b != b[:, None]).long() * self.neg_d**2
         dscores[dis2 < self.neg_d**2] = 0
         noise_dscores[dis2 < self.neg_d**2] = 0
         scores = torch.cat((pscores, nscores, dscores), dim=1)
         noise_scores = torch.cat((noise_pscores, noise_nscores, noise_dscores), dim=1)
 
         gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
-        gt[:, :pscores.shape[1]] = 1
+        gt[:, : pscores.shape[1]] = 1
 
         return scores, noise_scores, gt, mask, qconf, noise_qconf
diff --git a/third_party/DarkFeat/nets/noise_reliability_loss.py b/third_party/DarkFeat/nets/noise_reliability_loss.py
index 9efddae149653c225ee7f2c1eb5fed5f92cef15c..cbd69bba727e38efc3ac356168b4041b30c48e05 100644
--- a/third_party/DarkFeat/nets/noise_reliability_loss.py
+++ b/third_party/DarkFeat/nets/noise_reliability_loss.py
@@ -3,14 +3,15 @@ import torch.nn as nn
 from .reliability_loss import APLoss
 
 
-class MultiPixelAPLoss (nn.Module):
-    """ Computes the pixel-wise AP loss:
-        Given two images and ground-truth optical flow, computes the AP per pixel.
-        
-        feat1:  (B, C, H, W)   pixel-wise features extracted from img1
-        feat2:  (B, C, H, W)   pixel-wise features extracted from img2
-        aflow:  (B, 2, H, W)   absolute flow: aflow[...,y1,x1] = x2,y2
+class MultiPixelAPLoss(nn.Module):
+    """Computes the pixel-wise AP loss:
+    Given two images and ground-truth optical flow, computes the AP per pixel.
+
+    feat1:  (B, C, H, W)   pixel-wise features extracted from img1
+    feat2:  (B, C, H, W)   pixel-wise features extracted from img2
+    aflow:  (B, 2, H, W)   absolute flow: aflow[...,y1,x1] = x2,y2
     """
+
     def __init__(self, sampler, nq=20):
         nn.Module.__init__(self)
         self.aploss = APLoss(nq, min=0, max=1, euc=False)
@@ -20,21 +21,54 @@ class MultiPixelAPLoss (nn.Module):
 
     def loss_from_ap(self, ap, rel, noise_ap, noise_rel):
         dec_ap = torch.clamp(ap - noise_ap, min=0, max=1)
-        return (1 - ap*noise_rel - (1-noise_rel)*self.base), (1. - dec_ap*(1-noise_rel) - noise_rel*self.dec_base)
+        return (1 - ap * noise_rel - (1 - noise_rel) * self.base), (
+            1.0 - dec_ap * (1 - noise_rel) - noise_rel * self.dec_base
+        )
 
-    def forward(self, feat0, feat1, noise_feat0, noise_feat1, conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=1500):
+    def forward(
+        self,
+        feat0,
+        feat1,
+        noise_feat0,
+        noise_feat1,
+        conf0,
+        conf1,
+        noise_conf0,
+        noise_conf1,
+        pos0,
+        pos1,
+        B,
+        H,
+        W,
+        N=1500,
+    ):
         # subsample things
-        scores, noise_scores, gt, msk, qconf, noise_qconf = self.sampler(feat0, feat1, noise_feat0, noise_feat1, \
-            conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=1500)
-        
+        scores, noise_scores, gt, msk, qconf, noise_qconf = self.sampler(
+            feat0,
+            feat1,
+            noise_feat0,
+            noise_feat1,
+            conf0,
+            conf1,
+            noise_conf0,
+            noise_conf1,
+            pos0,
+            pos1,
+            B,
+            H,
+            W,
+            N=1500,
+        )
+
         # compute pixel-wise AP
         n = qconf.numel()
-        if n == 0: return 0, 0
-        scores, noise_scores, gt = scores.view(n,-1), noise_scores, gt.view(n,-1)
+        if n == 0:
+            return 0, 0
+        scores, noise_scores, gt = scores.view(n, -1), noise_scores, gt.view(n, -1)
         ap = self.aploss(scores, gt).view(msk.shape)
         noise_ap = self.aploss(noise_scores, gt).view(msk.shape)
 
         pixel_loss = self.loss_from_ap(ap, qconf, noise_ap, noise_qconf)
-        
+
         loss = pixel_loss[0][msk].mean(), pixel_loss[1][msk].mean()
-        return loss
\ No newline at end of file
+        return loss
diff --git a/third_party/DarkFeat/nets/reliability_loss.py b/third_party/DarkFeat/nets/reliability_loss.py
index 527f9886a2d4785680bac52ff2fa20033b8d8920..bdb3b73f472d915c9fd4c4542cdcab162298de5e 100644
--- a/third_party/DarkFeat/nets/reliability_loss.py
+++ b/third_party/DarkFeat/nets/reliability_loss.py
@@ -3,15 +3,16 @@ import torch.nn as nn
 import numpy as np
 
 
-class APLoss (nn.Module):
-    """ differentiable AP loss, through quantization.
-        
-        Input: (N, M)   values in [min, max]
-        label: (N, M)   values in {0, 1}
-        
-        Returns: list of query AP (for each n in {1..N})
-                 Note: typically, you want to minimize 1 - mean(AP)
+class APLoss(nn.Module):
+    """differentiable AP loss, through quantization.
+
+    Input: (N, M)   values in [min, max]
+    label: (N, M)   values in {0, 1}
+
+    Returns: list of query AP (for each n in {1..N})
+             Note: typically, you want to minimize 1 - mean(AP)
     """
+
     def __init__(self, nq=25, min=0, max=1, euc=False):
         nn.Module.__init__(self)
         assert isinstance(nq, int) and 2 <= nq <= 100
@@ -21,16 +22,20 @@ class APLoss (nn.Module):
         self.euc = euc
         gap = max - min
         assert gap > 0
-        
+
         # init quantizer = non-learnable (fixed) convolution
-        self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True)
-        a = (nq-1) / gap
-        #1st half = lines passing to (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1)
+        self.quantizer = q = nn.Conv1d(1, 2 * nq, kernel_size=1, bias=True)
+        a = (nq - 1) / gap
+        # 1st half = lines passing to (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1)
         q.weight.data[:nq] = -a
-        q.bias.data[:nq] = torch.from_numpy(a*min + np.arange(nq, 0, -1)) # b = 1 + a*(min+x)
-        #2nd half = lines passing to (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1)
+        q.bias.data[:nq] = torch.from_numpy(
+            a * min + np.arange(nq, 0, -1)
+        )  # b = 1 + a*(min+x)
+        # 2nd half = lines passing to (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1)
         q.weight.data[nq:] = a
-        q.bias.data[nq:] = torch.from_numpy(np.arange(2-nq, 2, 1) - a*min) # b = 1 - a*(min+x)
+        q.bias.data[nq:] = torch.from_numpy(
+            np.arange(2 - nq, 2, 1) - a * min
+        )  # b = 1 - a*(min+x)
         # first and last one are special: just horizontal straight line
         q.weight.data[0] = q.weight.data[-1] = 0
         q.bias.data[0] = q.bias.data[-1] = 1
@@ -39,37 +44,42 @@ class APLoss (nn.Module):
         N, M = x.shape
         # print(x.shape, label.shape)
         if self.euc:  # euclidean distance in same range than similarities
-            x = 1 - torch.sqrt(2.001 - 2*x)
+            x = 1 - torch.sqrt(2.001 - 2 * x)
 
         # quantize all predictions
         q = self.quantizer(x.unsqueeze(1))
-        q = torch.min(q[:,:self.nq], q[:,self.nq:]).clamp(min=0) # N x Q x M [1600, 20, 1681]
-
-        nbs = q.sum(dim=-1) # number of samples  N x Q = c
-        rec = (q * label.view(N,1,M).float()).sum(dim=-1) # nb of correct samples = c+ N x Q
-        prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision
-        rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1]
-
-        ap = (prec * rec).sum(dim=-1) # per-image AP
+        q = torch.min(q[:, : self.nq], q[:, self.nq :]).clamp(
+            min=0
+        )  # N x Q x M [1600, 20, 1681]
+
+        nbs = q.sum(dim=-1)  # number of samples  N x Q = c
+        rec = (q * label.view(N, 1, M).float()).sum(
+            dim=-1
+        )  # nb of correct samples = c+ N x Q
+        prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1))  # precision
+        rec /= rec.sum(dim=-1).unsqueeze(1)  # norm in [0,1]
+
+        ap = (prec * rec).sum(dim=-1)  # per-image AP
         return ap
 
     def forward(self, x, label):
-        assert x.shape == label.shape # N x M
+        assert x.shape == label.shape  # N x M
         return self.compute_AP(x, label)
 
 
-class PixelAPLoss (nn.Module):
-    """ Computes the pixel-wise AP loss:
-        Given two images and ground-truth optical flow, computes the AP per pixel.
-        
-        feat1:  (B, C, H, W)   pixel-wise features extracted from img1
-        feat2:  (B, C, H, W)   pixel-wise features extracted from img2
-        aflow:  (B, 2, H, W)   absolute flow: aflow[...,y1,x1] = x2,y2
+class PixelAPLoss(nn.Module):
+    """Computes the pixel-wise AP loss:
+    Given two images and ground-truth optical flow, computes the AP per pixel.
+
+    feat1:  (B, C, H, W)   pixel-wise features extracted from img1
+    feat2:  (B, C, H, W)   pixel-wise features extracted from img2
+    aflow:  (B, 2, H, W)   absolute flow: aflow[...,y1,x1] = x2,y2
     """
+
     def __init__(self, sampler, nq=20):
         nn.Module.__init__(self)
         self.aploss = APLoss(nq, min=0, max=1, euc=False)
-        self.name = 'pixAP'
+        self.name = "pixAP"
         self.sampler = sampler
 
     def loss_from_ap(self, ap, rel):
@@ -77,29 +87,32 @@ class PixelAPLoss (nn.Module):
 
     def forward(self, feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=1200):
         # subsample things
-        scores, gt, msk, qconf = self.sampler(feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=1200)
-        
+        scores, gt, msk, qconf = self.sampler(
+            feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=1200
+        )
+
         # compute pixel-wise AP
         n = qconf.numel()
-        if n == 0: return 0
-        scores, gt = scores.view(n,-1), gt.view(n,-1)
+        if n == 0:
+            return 0
+        scores, gt = scores.view(n, -1), gt.view(n, -1)
         ap = self.aploss(scores, gt).view(msk.shape)
 
         pixel_loss = self.loss_from_ap(ap, qconf)
-        
+
         loss = pixel_loss[msk].mean()
         return loss
 
 
-class ReliabilityLoss (PixelAPLoss):
-    """ same than PixelAPLoss, but also train a pixel-wise confidence
-        that this pixel is going to have a good AP.
+class ReliabilityLoss(PixelAPLoss):
+    """same than PixelAPLoss, but also train a pixel-wise confidence
+    that this pixel is going to have a good AP.
     """
+
     def __init__(self, sampler, base=0.5, **kw):
         PixelAPLoss.__init__(self, sampler, **kw)
         assert 0 <= base < 1
         self.base = base
 
     def loss_from_ap(self, ap, rel):
-        return 1 - ap*rel - (1-rel)*self.base
-
+        return 1 - ap * rel - (1 - rel) * self.base
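
APLoss above implements quantized, differentiable AP by soft-binning similarity scores with a fixed Conv1d and accumulating precision and recall over the bins. A rough plain-tensor sketch of the same histogram idea; the helper name soft_ap and the test data are made up for illustration:

import torch

def soft_ap(scores, labels, nq=20):
    # Bin centers ordered from high to low similarity, so the cumulative sums
    # below sweep the ranking from best candidate downwards.
    centers = torch.linspace(1, 0, nq)
    width = 1.0 / (nq - 1)
    # Triangular soft assignment of every score to the nq bins.
    q = torch.clamp(1 - (scores[:, None, :] - centers[None, :, None]).abs() / width, min=0)

    nbs = q.sum(dim=-1)                                  # soft count per bin
    rec = (q * labels[:, None, :].float()).sum(dim=-1)   # soft positives per bin
    prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1))
    rec = rec / (rec.sum(dim=-1, keepdim=True) + 1e-16)  # normalized recall mass
    return (prec * rec).sum(dim=-1)                      # one AP per query

scores = torch.rand(4, 100)          # similarities in [0, 1] for 4 queries
labels = torch.rand(4, 100) > 0.9    # sparse ground-truth matches
print(soft_ap(scores, labels))       # the training loss would be 1 - mean(AP)
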
diff --git a/third_party/DarkFeat/nets/sampler.py b/third_party/DarkFeat/nets/sampler.py
index b732a3671872d5675be9826f76b0818d3b99d466..7686b24d78eb92b90ee3cafb95ad48966ee0f00f 100644
--- a/third_party/DarkFeat/nets/sampler.py
+++ b/third_party/DarkFeat/nets/sampler.py
@@ -5,17 +5,28 @@ import numpy as np
 
 from .geom import rnd_sample, interpolate
 
-class NghSampler2 (nn.Module):
-    """ Similar to NghSampler, but doesnt warp the 2nd image.
+
+class NghSampler2(nn.Module):
+    """Similar to NghSampler, but doesnt warp the 2nd image.
     Distance to GT =>  0 ... pos_d ... neg_d ... ngh
     Pixel label    =>  + + + + + + 0 0 - - - - - - -
-    
+
     Subsample on query side: if > 0, regular grid
-                                < 0, random points 
+                                < 0, random points
     In both cases, the number of query points is = W*H/subq**2
     """
-    def __init__(self, ngh, subq=1, subd=1, pos_d=0, neg_d=2, border=None,
-                       maxpool_pos=True, subd_neg=0):
+
+    def __init__(
+        self,
+        ngh,
+        subq=1,
+        subd=1,
+        pos_d=0,
+        neg_d=2,
+        border=None,
+        maxpool_pos=True,
+        subd_neg=0,
+    ):
         nn.Module.__init__(self)
         assert 0 <= pos_d < neg_d <= (ngh if ngh else 99)
         self.ngh = ngh
@@ -26,8 +37,9 @@ class NghSampler2 (nn.Module):
         self.sub_q = subq
         self.sub_d = subd
         self.sub_d_neg = subd_neg
-        if border is None: border = ngh
-        assert border >= ngh, 'border has to be larger than ngh'
+        if border is None:
+            border = ngh
+        assert border >= ngh, "border has to be at least ngh"
         self.border = border
         self.maxpool_pos = maxpool_pos
         self.precompute_offsets()
@@ -36,39 +48,39 @@ class NghSampler2 (nn.Module):
         pos_d2 = self.pos_d**2
         neg_d2 = self.neg_d**2
         rad2 = self.ngh**2
-        rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple
+        rad = (self.ngh // self.sub_d) * self.ngh  # make an integer multiple
         pos = []
         neg = []
-        for j in range(-rad, rad+1, self.sub_d):
-          for i in range(-rad, rad+1, self.sub_d):
-            d2 = i*i + j*j
-            if d2 <= pos_d2:
-                pos.append( (i,j) )
-            elif neg_d2 <= d2 <= rad2: 
-                neg.append( (i,j) )
+        for j in range(-rad, rad + 1, self.sub_d):
+            for i in range(-rad, rad + 1, self.sub_d):
+                d2 = i * i + j * j
+                if d2 <= pos_d2:
+                    pos.append((i, j))
+                elif neg_d2 <= d2 <= rad2:
+                    neg.append((i, j))
 
-        self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t())
-        self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t())
+        self.register_buffer("pos_offsets", torch.LongTensor(pos).view(-1, 2).t())
+        self.register_buffer("neg_offsets", torch.LongTensor(neg).view(-1, 2).t())
 
     def gen_grid(self, step, B, H, W, dev):
         b1 = torch.arange(B, device=dev)
         if step > 0:
             # regular grid
-            x1 = torch.arange(self.border, W-self.border, step, device=dev)
-            y1 = torch.arange(self.border, H-self.border, step, device=dev)
+            x1 = torch.arange(self.border, W - self.border, step, device=dev)
+            y1 = torch.arange(self.border, H - self.border, step, device=dev)
             H1, W1 = len(y1), len(x1)
-            x1 = x1[None,None,:].expand(B,H1,W1).reshape(-1)
-            y1 = y1[None,:,None].expand(B,H1,W1).reshape(-1)
-            b1 = b1[:,None,None].expand(B,H1,W1).reshape(-1)
+            x1 = x1[None, None, :].expand(B, H1, W1).reshape(-1)
+            y1 = y1[None, :, None].expand(B, H1, W1).reshape(-1)
+            b1 = b1[:, None, None].expand(B, H1, W1).reshape(-1)
             shape = (B, H1, W1)
         else:
             # randomly spread
-            n = (H - 2*self.border) * (W - 2*self.border) // step**2
-            x1 = torch.randint(self.border, W-self.border, (n,), device=dev)
-            y1 = torch.randint(self.border, H-self.border, (n,), device=dev)
-            x1 = x1[None,:].expand(B,n).reshape(-1)
-            y1 = y1[None,:].expand(B,n).reshape(-1)
-            b1 = b1[:,None].expand(B,n).reshape(-1)
+            n = (H - 2 * self.border) * (W - 2 * self.border) // step**2
+            x1 = torch.randint(self.border, W - self.border, (n,), device=dev)
+            y1 = torch.randint(self.border, H - self.border, (n,), device=dev)
+            x1 = x1[None, :].expand(B, n).reshape(-1)
+            y1 = y1[None, :].expand(B, n).reshape(-1)
+            b1 = b1[:, None].expand(B, n).reshape(-1)
             shape = (B, n)
         return b1, y1, x1, shape
 
@@ -81,45 +93,73 @@ class NghSampler2 (nn.Module):
 
         for i in range(B):
             # positions in the first image
-            tmp_mask = (pos0[i][:, 1] >= self.border) * (pos0[i][:, 1] < W-self.border) \
-                * (pos0[i][:, 0] >= self.border) * (pos0[i][:, 0] < H-self.border)
+            tmp_mask = (
+                (pos0[i][:, 1] >= self.border)
+                * (pos0[i][:, 1] < W - self.border)
+                * (pos0[i][:, 0] >= self.border)
+                * (pos0[i][:, 0] < H - self.border)
+            )
 
             selected_pos0 = pos0[i][tmp_mask]
             selected_pos1 = pos1[i][tmp_mask]
             valid_pos0, valid_pos1 = rnd_sample([selected_pos0, selected_pos1], N)
 
             # sample features from first image
-            valid_feat0 = interpolate(valid_pos0 / 4, feat0[i]) # [N, 128]
-            valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) # [N, 128]
+            valid_feat0 = interpolate(valid_pos0 / 4, feat0[i])  # [N, 128]
+            valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1)  # [N, 128]
             qconf = interpolate(valid_pos0 / 4, conf0[i])
 
             # sample GT from second image
-            mask = (valid_pos1[:, 1] >= 0) * (valid_pos1[:, 1] < W) \
-                * (valid_pos1[:, 0] >= 0) * (valid_pos1[:, 0] < H)
+            mask = (
+                (valid_pos1[:, 1] >= 0)
+                * (valid_pos1[:, 1] < W)
+                * (valid_pos1[:, 0] >= 0)
+                * (valid_pos1[:, 0] < H)
+            )
 
             def clamp(xy):
                 xy = xy
-                torch.clamp(xy[0], 0, H-1, out=xy[0])
-                torch.clamp(xy[1], 0, W-1, out=xy[1])
+                torch.clamp(xy[0], 0, H - 1, out=xy[0])
+                torch.clamp(xy[1], 0, W - 1, out=xy[1])
                 return xy
 
             # compute positive scores
-            valid_pos1p = clamp(valid_pos1.t()[:,None,:] + self.pos_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N]
-            valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2]
-            valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128]
-            valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1) # [29, N, 128]
-
-            pscores = (valid_feat0[None,:,:] * valid_feat1p).sum(dim=-1).t() # [N, 29]
+            valid_pos1p = clamp(
+                valid_pos1.t()[:, None, :]
+                + self.pos_offsets[:, :, None].to(valid_pos1.device)
+            )  # [2, 29, N]
+            valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape(
+                -1, 2
+            )  # [29, N, 2] -> [29*N, 2]
+            valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(
+                self.pos_offsets.shape[-1], -1, 128
+            )  # [29, N, 128]
+            valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1)  # [29, N, 128]
+
+            pscores = (
+                (valid_feat0[None, :, :] * valid_feat1p).sum(dim=-1).t()
+            )  # [N, 29]
             pscores, pos = pscores.max(dim=1, keepdim=True)
-            sel = clamp(valid_pos1.t() + self.pos_offsets[:,pos.view(-1)].to(valid_pos1.device))
-            qconf = (qconf + interpolate(sel.t() / 4, conf1[i]))/2
+            sel = clamp(
+                valid_pos1.t() + self.pos_offsets[:, pos.view(-1)].to(valid_pos1.device)
+            )
+            qconf = (qconf + interpolate(sel.t() / 4, conf1[i])) / 2
 
             # compute negative scores
-            valid_pos1n = clamp(valid_pos1.t()[:,None,:] + self.neg_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N]
-            valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2]
-            valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128]
-            valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1) # [29, N, 128]
-            nscores = (valid_feat0[None,:,:] * valid_feat1n).sum(dim=-1).t() # [N, 29]
+            valid_pos1n = clamp(
+                valid_pos1.t()[:, None, :]
+                + self.neg_offsets[:, :, None].to(valid_pos1.device)
+            )  # [2, 29, N]
+            valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape(
+                -1, 2
+            )  # [29, N, 2] -> [29*N, 2]
+            valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape(
+                self.neg_offsets.shape[-1], -1, 128
+            )  # [29, N, 128]
+            valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1)  # [29, N, 128]
+            nscores = (
+                (valid_feat0[None, :, :] * valid_feat1n).sum(dim=-1).t()
+            )  # [N, 29]
 
             if self.sub_d_neg:
                 valid_pos2 = rnd_sample([selected_pos1], N)[0]
@@ -148,13 +188,15 @@ class NghSampler2 (nn.Module):
         valid_pos2 = torch.cat([i[:N] for i in valid_pos2_ls], dim=0)
 
         dscores = torch.matmul(valid_feat0, distractors.t())
-        dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:,None])**2 + (valid_pos2[:, 0] - valid_pos1[:, 0][:,None])**2
-        b = torch.arange(B, device=dscores.device)[:,None].expand(B, N).reshape(-1)
-        dis2 += (b != b[:,None]).long() * self.neg_d**2
+        dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:, None]) ** 2 + (
+            valid_pos2[:, 0] - valid_pos1[:, 0][:, None]
+        ) ** 2
+        b = torch.arange(B, device=dscores.device)[:, None].expand(B, N).reshape(-1)
+        dis2 += (b != b[:, None]).long() * self.neg_d**2
         dscores[dis2 < self.neg_d**2] = 0
         scores = torch.cat((pscores, nscores, dscores), dim=1)
-        
+
         gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
-        gt[:, :pscores.shape[1]] = 1
+        gt[:, : pscores.shape[1]] = 1
 
         return scores, gt, mask, qconf
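
precompute_offsets in NghSampler2 above simply partitions a discrete disk of radius ngh into a positive core (d^2 <= pos_d^2) and a negative ring (neg_d^2 <= d^2 <= ngh^2). The same construction as a stand-alone snippet with small example values:

import torch

ngh, sub_d, pos_d, neg_d = 3, 1, 1, 2
pos_d2, neg_d2, rad2 = pos_d**2, neg_d**2, ngh**2
rad = (ngh // sub_d) * ngh  # integer multiple of the sampling step, as in the module

pos, neg = [], []
for j in range(-rad, rad + 1, sub_d):
    for i in range(-rad, rad + 1, sub_d):
        d2 = i * i + j * j
        if d2 <= pos_d2:            # inner core: still counts as the same point
            pos.append((i, j))
        elif neg_d2 <= d2 <= rad2:  # ring: guaranteed non-matching offsets
            neg.append((i, j))

pos_offsets = torch.LongTensor(pos).view(-1, 2).t()  # [2, n_pos]
neg_offsets = torch.LongTensor(neg).view(-1, 2).t()  # [2, n_neg]
print(pos_offsets.shape, neg_offsets.shape)
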
diff --git a/third_party/DarkFeat/nets/score.py b/third_party/DarkFeat/nets/score.py
index a78cf1c893bc338c12803697d55e121a75171f2c..60b255b6d2c9572323460500efd89fb414dee29e 100644
--- a/third_party/DarkFeat/nets/score.py
+++ b/third_party/DarkFeat/nets/score.py
@@ -8,23 +8,20 @@ from .geom import gather_nd
 # output: [batch_size, C, H, W], [batch_size, C, H, W]
 def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1):
     inputs = inputs / moving_instance_max
-    
+
     batch_size, C, H, W = inputs.shape
 
     pad_size = ksize // 2 + (dilation - 1)
     kernel = torch.ones([C, 1, ksize, ksize], device=inputs.device) / (ksize * ksize)
-    
-    pad_inputs = F.pad(inputs, [pad_size] * 4, mode='reflect')
+
+    pad_inputs = F.pad(inputs, [pad_size] * 4, mode="reflect")
 
     avg_spatial_inputs = F.conv2d(
-        pad_inputs,
-        kernel,
-        stride=1,
-        dilation=dilation,
-        padding=0,
-        groups=C
+        pad_inputs, kernel, stride=1, dilation=dilation, padding=0, groups=C
     )
-    avg_channel_inputs = torch.mean(inputs, axis=1, keepdim=True) # channel dimension is 1
+    avg_channel_inputs = torch.mean(
+        inputs, axis=1, keepdim=True
+    )  # channel dimension is 1
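+    # alpha/beta peakiness: softplus of each response minus its local spatial
+    # average and its per-pixel channel average, respectively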
 
     alpha = F.softplus(inputs - avg_spatial_inputs)
     beta = F.softplus(inputs - avg_channel_inputs)
@@ -40,11 +37,17 @@ def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_si
 
     mask = score_map > score_thld
     if nms_size > 0:
-        nms_mask = F.max_pool2d(score_map, kernel_size=nms_size, stride=1, padding=nms_size//2)
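+        # non-maximum suppression: keep only scores equal to the local maximum
+        # within an nms_size x nms_size window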
+        nms_mask = F.max_pool2d(
+            score_map, kernel_size=nms_size, stride=1, padding=nms_size // 2
+        )
         nms_mask = torch.eq(score_map, nms_mask)
         mask = torch.logical_and(nms_mask, mask)
     if eof_size > 0:
-        eof_mask = torch.ones((1, 1, h - 2 * eof_size, w - 2 * eof_size), dtype=torch.float32, device=score_map.device)
+        eof_mask = torch.ones(
+            (1, 1, h - 2 * eof_size, w - 2 * eof_size),
+            dtype=torch.float32,
+            device=score_map.device,
+        )
         eof_mask = F.pad(eof_mask, [eof_size] * 4, value=0)
         eof_mask = eof_mask.bool()
         mask = torch.logical_and(eof_mask, mask)
@@ -86,24 +89,29 @@ def edge_mask(inputs, n_channel, dilation=1, edge_thld=5):
     b, c, h, w = inputs.size()
     device = inputs.device
 
-    dii_filter = torch.tensor(
-        [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
-    ).view(1, 1, 3, 3)
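+    # discrete second-derivative (Hessian) filters; the Hessian determinant is
+    # used below to suppress edge-like responses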
+    dii_filter = torch.tensor([[0, 1.0, 0], [0, -2.0, 0], [0, 1.0, 0]]).view(1, 1, 3, 3)
     dij_filter = 0.25 * torch.tensor(
-        [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
-    ).view(1, 1, 3, 3)
-    djj_filter = torch.tensor(
-        [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
+        [[1.0, 0, -1.0], [0, 0.0, 0], [-1.0, 0, 1.0]]
     ).view(1, 1, 3, 3)
+    djj_filter = torch.tensor([[0, 0, 0], [1.0, -2.0, 1.0], [0, 0, 0]]).view(1, 1, 3, 3)
 
     dii = F.conv2d(
-        inputs.view(-1, 1, h, w), dii_filter.to(device), padding=dilation, dilation=dilation
+        inputs.view(-1, 1, h, w),
+        dii_filter.to(device),
+        padding=dilation,
+        dilation=dilation,
     ).view(b, c, h, w)
     dij = F.conv2d(
-        inputs.view(-1, 1, h, w), dij_filter.to(device), padding=dilation, dilation=dilation
+        inputs.view(-1, 1, h, w),
+        dij_filter.to(device),
+        padding=dilation,
+        dilation=dilation,
     ).view(b, c, h, w)
     djj = F.conv2d(
-        inputs.view(-1, 1, h, w), djj_filter.to(device), padding=dilation, dilation=dilation
+        inputs.view(-1, 1, h, w),
+        djj_filter.to(device),
+        padding=dilation,
+        dilation=dilation,
     ).view(b, c, h, w)
 
     det = dii * djj - dij * dij
diff --git a/third_party/DarkFeat/pose_estimation.py b/third_party/DarkFeat/pose_estimation.py
index c87877191e7e31c3bc0a362d7d481dfd5d4b5757..d4ebe66700f895f0d1fac1b21d502b3a7de02325 100644
--- a/third_party/DarkFeat/pose_estimation.py
+++ b/third_party/DarkFeat/pose_estimation.py
@@ -8,18 +8,28 @@ from tqdm import tqdm
 
 
 def compute_essential(matched_kp1, matched_kp2, K):
-    pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
-    pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
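+    # undistort the matched keypoints; the distortion coefficients are fixed,
+    # camera-specific constants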
+    pts1 = cv2.undistortPoints(
+        matched_kp1,
+        cameraMatrix=K,
+        distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0),
+    )
+    pts2 = cv2.undistortPoints(
+        matched_kp2,
+        cameraMatrix=K,
+        distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0),
+    )
     K_1 = np.eye(3)
     # Estimate the essential matrix between the matches using RANSAC
-    ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000)
-    if ransac_inliers is None or ransac_model.shape != (3,3):
+    ransac_model, ransac_inliers = cv2.findEssentialMat(
+        pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000
+    )
+    if ransac_inliers is None or ransac_model.shape != (3, 3):
         ransac_inliers = np.array([])
         ransac_model = None
     return ransac_model, ransac_inliers, pts1, pts2
 
 
-def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers):
+def compute_error(R_GT, t_GT, E, pts1_norm, pts2_norm, inliers):
     """Compute the angular error between two rotation matrices and two translation vectors.
     Keyword arguments:
     R -- 2D numpy array containing an estimated rotation
@@ -30,14 +40,14 @@ def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers):
 
     inliers = inliers.ravel()
     R = np.eye(3)
-    t = np.zeros((3,1))
+    t = np.zeros((3, 1))
     sst = True
     try:
         _, R, t, _ = cv2.recoverPose(E, pts1_norm, pts2_norm, np.eye(3), inliers)
     except:
         sst = False
     # calculate angle between provided rotations
-    # 
+    #
     if sst:
         dR = np.matmul(R, np.transpose(R_GT))
         dR = cv2.Rodrigues(dR)[0]
@@ -48,10 +58,10 @@ def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers):
         dT /= float(np.linalg.norm(t_GT))
 
         if dT > 1 or dT < -1:
-            print("Domain warning! dT:",dT)
-            dT = max(-1,min(1,dT))
+            print("Domain warning! dT:", dT)
+            dT = max(-1, min(1, dT))
         dT = math.acos(dT) * 180 / math.pi
-        dT = np.minimum(dT, 180 - dT) # ambiguity of E estimation
+        dT = np.minimum(dT, 180 - dT)  # ambiguity of E estimation
     else:
         dR, dT = 180.0, 180.0
     return dR, dT
@@ -59,8 +69,8 @@ def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers):
 
 def pose_evaluation(result_base_dir, dark_name1, dark_name2, enhancer, K, R_GT, t_GT):
     try:
-        m_kp1 = np.load(result_base_dir+enhancer+'/DarkFeat/POINT_1/'+dark_name1)
-        m_kp2 = np.load(result_base_dir+enhancer+'/DarkFeat/POINT_2/'+dark_name2)
+        m_kp1 = np.load(result_base_dir + enhancer + "/DarkFeat/POINT_1/" + dark_name1)
+        m_kp2 = np.load(result_base_dir + enhancer + "/DarkFeat/POINT_2/" + dark_name2)
     except:
         return 180.0, 180.0
     try:
@@ -71,37 +81,37 @@ def pose_evaluation(result_base_dir, dark_name1, dark_name2, enhancer, K, R_GT,
     return dR, dT
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--histeq', action='store_true')
-    parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/')
+    parser.add_argument("--histeq", action="store_true")
+    parser.add_argument("--dataset_dir", type=str, default="/data/hyz/MID/")
     opt = parser.parse_args()
-    
+
     sizer = (960, 640)
-    focallength_x = 4.504986436499113e+03/(6744/sizer[0])
-    focallength_y = 4.513311442889859e+03/(4502/sizer[1])
+    focallength_x = 4.504986436499113e03 / (6744 / sizer[0])
+    focallength_y = 4.513311442889859e03 / (4502 / sizer[1])
     K = np.eye(3)
-    K[0,0] = focallength_x
-    K[1,1] = focallength_y
-    K[0,2] = 3.363322177533149e+03/(6744/sizer[0])
-    K[1,2] = 2.291824660547715e+03/(4502/sizer[1])
+    K[0, 0] = focallength_x
+    K[1, 1] = focallength_y
+    K[0, 2] = 3.363322177533149e03 / (6744 / sizer[0])
+    K[1, 2] = 2.291824660547715e03 / (4502 / sizer[1])
     Kinv = np.linalg.inv(K)
     Kinvt = np.transpose(Kinv)
 
     PE_MT = np.zeros((6, 8))
 
-    enhancer = 'None' if not opt.histeq else 'HistEQ'
+    enhancer = "None" if not opt.histeq else "HistEQ"
 
-    for scene in ['Indoor', 'Outdoor']:
-        dir_base = opt.dataset_dir + '/' + scene + '/'
-        base_save = 'result_errors/' + scene + '/'
+    for scene in ["Indoor", "Outdoor"]:
+        dir_base = opt.dataset_dir + "/" + scene + "/"
+        base_save = "result_errors/" + scene + "/"
         pair_list = sorted(os.listdir(dir_base))
 
         os.makedirs(base_save, exist_ok=True)
 
         for pair in tqdm(pair_list):
             opention = 1
-            if scene == 'Outdoor':
+            if scene == "Outdoor":
                 pass
             else:
                 if int(pair[4::]) <= 17:
@@ -109,29 +119,43 @@ if __name__ == '__main__':
                 else:
                     pass
             name = []
-            files = sorted(os.listdir(dir_base+pair))
+            files = sorted(os.listdir(dir_base + pair))
             for file_ in files:
-                if file_.endswith('.cr2'):
+                if file_.endswith(".cr2"):
                     name.append(file_[0:9])
-            ISO = ['00100', '00200', '00400', '00800', '01600', '03200', '06400', '12800']
+            ISO = [
+                "00100",
+                "00200",
+                "00400",
+                "00800",
+                "01600",
+                "03200",
+                "06400",
+                "12800",
+            ]
             if opention == 1:
-                Shutter_speed = ['0.005','0.01','0.025','0.05','0.17','0.5']
+                Shutter_speed = ["0.005", "0.01", "0.025", "0.05", "0.17", "0.5"]
             else:
-                Shutter_speed = ['0.01','0.02','0.05','0.1','0.3','1']
+                Shutter_speed = ["0.01", "0.02", "0.05", "0.1", "0.3", "1"]
 
-            E_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'E_estimated.npy')
-            F_GT = np.dot(np.dot(Kinvt,E_GT),Kinv)
-            R_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'R_GT.npy')
-            t_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'T_GT.npy')
-            result_base_dir ='result/' +scene+'/'+pair+'/'
+            E_GT = np.load(dir_base + pair + "/GT_Correspondence/" + "E_estimated.npy")
+            F_GT = np.dot(np.dot(Kinvt, E_GT), Kinv)
+            R_GT = np.load(dir_base + pair + "/GT_Correspondence/" + "R_GT.npy")
+            t_GT = np.load(dir_base + pair + "/GT_Correspondence/" + "T_GT.npy")
+            result_base_dir = "result/" + scene + "/" + pair + "/"
             for iso in ISO:
                 for ex in Shutter_speed:
-                    dark_name1 = name[0]+iso+'_'+ex+'_'+scene+'.npy'
-                    dark_name2 = name[1]+iso+'_'+ex+'_'+scene+'.npy'
-
-                    dr, dt = pose_evaluation(result_base_dir,dark_name1,dark_name2,enhancer,K,R_GT,t_GT) 
-                    PE_MT[Shutter_speed.index(ex),ISO.index(iso)] = max(dr, dt)
-
-                    subprocess.check_output(['mkdir', '-p', base_save + pair + f'/{enhancer}/'])
-                    np.save(base_save + pair + f'/{enhancer}/Pose_error_DarkFeat.npy', PE_MT)
-          
\ No newline at end of file
+                    dark_name1 = name[0] + iso + "_" + ex + "_" + scene + ".npy"
+                    dark_name2 = name[1] + iso + "_" + ex + "_" + scene + ".npy"
+
+                    dr, dt = pose_evaluation(
+                        result_base_dir, dark_name1, dark_name2, enhancer, K, R_GT, t_GT
+                    )
+                    PE_MT[Shutter_speed.index(ex), ISO.index(iso)] = max(dr, dt)
+
+                    subprocess.check_output(
+                        ["mkdir", "-p", base_save + pair + f"/{enhancer}/"]
+                    )
+                    np.save(
+                        base_save + pair + f"/{enhancer}/Pose_error_DarkFeat.npy", PE_MT
+                    )
diff --git a/third_party/DarkFeat/raw_preprocess.py b/third_party/DarkFeat/raw_preprocess.py
index 226155a84e97f15782d3650f4ef6b3fa1880e07b..6f51bef8ae45114160214fbc22b1c5cc832c7d42 100644
--- a/third_party/DarkFeat/raw_preprocess.py
+++ b/third_party/DarkFeat/raw_preprocess.py
@@ -9,54 +9,78 @@ from tqdm import tqdm
 
 def process_raw(args, path, w_new, h_new):
     raw = rawpy.imread(str(path)).raw_image_visible
-    if '_00200_' in str(path) or '_00100_' in str(path):
-        raw = np.clip(raw.astype('float32') - 512, 0, 65535)
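+    # subtract a fixed black-level offset (smaller for the ISO 100/200 captures)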
+    if "_00200_" in str(path) or "_00100_" in str(path):
+        raw = np.clip(raw.astype("float32") - 512, 0, 65535)
     else:
-        raw = np.clip(raw.astype('float32') - 2048, 0, 65535)
-    img = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, 'RGGB').astype('float32')
+        raw = np.clip(raw.astype("float32") - 2048, 0, 65535)
+    img = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, "RGGB").astype(
+        "float32"
+    )
     img = np.clip(img, 0, 16383)
 
     # HistEQ start
     if args.histeq:
         img2 = np.zeros_like(img)
         for i in range(3):
-            hist,bins = np.histogram(img[..., i].flatten(),16384,[0,16384])
+            hist, bins = np.histogram(img[..., i].flatten(), 16384, [0, 16384])
             cdf = hist.cumsum()
             cdf_normalized = cdf * float(hist.max()) / cdf.max()
-            cdf_m = np.ma.masked_equal(cdf,0)
-            cdf_m = (cdf_m - cdf_m.min())*16383/(cdf_m.max()-cdf_m.min())
-            cdf = np.ma.filled(cdf_m,0).astype('uint16')
-            img2[..., i] = cdf[img[..., i].astype('int16')]
-            img[..., i] = img2[..., i].astype('float32')
+            cdf_m = np.ma.masked_equal(cdf, 0)
+            cdf_m = (cdf_m - cdf_m.min()) * 16383 / (cdf_m.max() - cdf_m.min())
+            cdf = np.ma.filled(cdf_m, 0).astype("uint16")
+            img2[..., i] = cdf[img[..., i].astype("int16")]
+            img[..., i] = img2[..., i].astype("float32")
     # HistEQ end
 
     m = img.mean()
     d = np.abs(img - img.mean()).mean()
-    img = (img - m + 2*d) / 4/d * 255
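+    # contrast-normalize by mean absolute deviation: maps [m - 2d, m + 2d] to [0, 255]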
+    img = (img - m + 2 * d) / 4 / d * 255
     image = np.clip(img, 0, 255)
 
-    image = cv2.resize(image.astype('float32'), (w_new, h_new), interpolation=cv2.INTER_AREA)
+    image = cv2.resize(
+        image.astype("float32"), (w_new, h_new), interpolation=cv2.INTER_AREA
+    )
 
     if args.histeq:
-        path=str(path)
-        os.makedirs('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy']), exist_ok=True)
-        np.save('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy']+[path.split('/')[-1].replace('cr2','npy')]), image)
+        path = str(path)
+        os.makedirs(
+            "/".join(path.split("/")[:-2] + [path.split("/")[-2] + "-npy"]),
+            exist_ok=True,
+        )
+        np.save(
+            "/".join(
+                path.split("/")[:-2]
+                + [path.split("/")[-2] + "-npy"]
+                + [path.split("/")[-1].replace("cr2", "npy")]
+            ),
+            image,
+        )
     else:
-        path=str(path)
-        os.makedirs('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy-nohisteq']), exist_ok=True)
-        np.save('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy-nohisteq']+[path.split('/')[-1].replace('cr2','npy')]), image)
+        path = str(path)
+        os.makedirs(
+            "/".join(path.split("/")[:-2] + [path.split("/")[-2] + "-npy-nohisteq"]),
+            exist_ok=True,
+        )
+        np.save(
+            "/".join(
+                path.split("/")[:-2]
+                + [path.split("/")[-2] + "-npy-nohisteq"]
+                + [path.split("/")[-1].replace("cr2", "npy")]
+            ),
+            image,
+        )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser()
-    parser.add_argument('--H', type=int, default=int(640))
-    parser.add_argument('--W', type=int, default=int(960))
-    parser.add_argument('--histeq', action='store_true')
-    parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/')
+    parser.add_argument("--H", type=int, default=int(640))
+    parser.add_argument("--W", type=int, default=int(960))
+    parser.add_argument("--histeq", action="store_true")
+    parser.add_argument("--dataset_dir", type=str, default="/data/hyz/MID/")
     args = parser.parse_args()
 
-    path_ls = glob.glob(args.dataset_dir + '/*/pair*/?????/*')
+    path_ls = glob.glob(args.dataset_dir + "/*/pair*/?????/*")
     for path in tqdm(path_ls):
         process_raw(args, path, args.W, args.H)
-
diff --git a/third_party/DarkFeat/read_error.py b/third_party/DarkFeat/read_error.py
index 406b92dbd3877a11e51aebc3a705cd8d8d17e173..9015dfd2954b21115458fa25a2fd278c7cd69596 100644
--- a/third_party/DarkFeat/read_error.py
+++ b/third_party/DarkFeat/read_error.py
@@ -1,56 +1,80 @@
-import os 
+import os
 import numpy as np
 import subprocess
 
 # def ratio(losses, thresholds=[1,2,3,4,5,6,7,8,9,10]):
-def ratio(losses, thresholds=[5,10]):
-    return [
-        '{:.3f}'.format(np.mean(losses < threshold))
-        for threshold in thresholds
-    ]
+def ratio(losses, thresholds=[5, 10]):
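+    # fraction of pose errors below each angular threshold (in degrees)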
+    return ["{:.3f}".format(np.mean(losses < threshold)) for threshold in thresholds]
 
-if __name__ == '__main__':
-    scene = 'Indoor'
-    dir_base = 'result_errors/Indoor/'
-    save_pt = 'resultfinal_errors/Indoor/'
 
-    subprocess.check_output(['mkdir', '-p', save_pt])
+if __name__ == "__main__":
+    scene = "Indoor"
+    dir_base = "result_errors/Indoor/"
+    save_pt = "resultfinal_errors/Indoor/"
 
-    with open(save_pt +'ratio_methods_'+scene+'.txt','w') as f:
-        f.write('5deg 10deg'+'\n')
+    subprocess.check_output(["mkdir", "-p", save_pt])
+
+    with open(save_pt + "ratio_methods_" + scene + ".txt", "w") as f:
+        f.write("5deg 10deg" + "\n")
         pair_list = os.listdir(dir_base)
-        enhancer = os.listdir(dir_base+'/pair9/')
+        enhancer = os.listdir(dir_base + "/pair9/")
         for method in enhancer:
-            pose_error_list = sorted(os.listdir(dir_base+'/pair9/'+method))
+            pose_error_list = sorted(os.listdir(dir_base + "/pair9/" + method))
             for pose_error in pose_error_list:
-                error_array = np.expand_dims(np.zeros((6, 8)),axis=2)
+                error_array = np.expand_dims(np.zeros((6, 8)), axis=2)
                 for pair in pair_list:
                     try:
-                        error = np.expand_dims(np.load(dir_base+'/'+pair+'/'+method+'/'+pose_error),axis=2)
+                        error = np.expand_dims(
+                            np.load(
+                                dir_base + "/" + pair + "/" + method + "/" + pose_error
+                            ),
+                            axis=2,
+                        )
                     except:
-                        print('error in', dir_base+'/'+pair+'/'+method+'/'+pose_error)
+                        print(
+                            "error in",
+                            dir_base + "/" + pair + "/" + method + "/" + pose_error,
+                        )
                         continue
-                    error_array = np.concatenate((error_array,error),axis=2)
-                ratio_result = ratio(error_array[:,:,1::].flatten())
-                f.write(method + '_' + pose_error[11:-4] +' '+' '.join([str(i) for i in ratio_result])+"\n")
+                    error_array = np.concatenate((error_array, error), axis=2)
+                ratio_result = ratio(error_array[:, :, 1::].flatten())
+                f.write(
+                    method
+                    + "_"
+                    + pose_error[11:-4]
+                    + " "
+                    + " ".join([str(i) for i in ratio_result])
+                    + "\n"
+                )
 
-    
-    scene = 'Outdoor'
-    dir_base = 'result_errors/Outdoor/'
-    save_pt = 'resultfinal_errors/Outdoor/'
+    scene = "Outdoor"
+    dir_base = "result_errors/Outdoor/"
+    save_pt = "resultfinal_errors/Outdoor/"
 
-    subprocess.check_output(['mkdir', '-p', save_pt])
+    subprocess.check_output(["mkdir", "-p", save_pt])
 
-    with open(save_pt +'ratio_methods_'+scene+'.txt','w') as f:
-        f.write('5deg 10deg'+'\n')
+    with open(save_pt + "ratio_methods_" + scene + ".txt", "w") as f:
+        f.write("5deg 10deg" + "\n")
         pair_list = os.listdir(dir_base)
-        enhancer = os.listdir(dir_base+'/pair9/')
+        enhancer = os.listdir(dir_base + "/pair9/")
         for method in enhancer:
-            pose_error_list = sorted(os.listdir(dir_base+'/pair9/'+method))
+            pose_error_list = sorted(os.listdir(dir_base + "/pair9/" + method))
             for pose_error in pose_error_list:
-                error_array = np.expand_dims(np.zeros((6, 8)),axis=2)
+                error_array = np.expand_dims(np.zeros((6, 8)), axis=2)
                 for pair in pair_list:
-                    error = np.expand_dims(np.load(dir_base+'/'+pair+'/'+method+'/'+pose_error),axis=2)
-                    error_array = np.concatenate((error_array,error),axis=2)
-                ratio_result = ratio(error_array[:,:,1::].flatten())
-                f.write(method + '_' + pose_error[11:-4] +' '+' '.join([str(i) for i in ratio_result])+"\n")
+                    error = np.expand_dims(
+                        np.load(
+                            dir_base + "/" + pair + "/" + method + "/" + pose_error
+                        ),
+                        axis=2,
+                    )
+                    error_array = np.concatenate((error_array, error), axis=2)
+                ratio_result = ratio(error_array[:, :, 1::].flatten())
+                f.write(
+                    method
+                    + "_"
+                    + pose_error[11:-4]
+                    + " "
+                    + " ".join([str(i) for i in ratio_result])
+                    + "\n"
+                )
diff --git a/third_party/DarkFeat/run.py b/third_party/DarkFeat/run.py
index 0e4c87053d2970fc927d8991aa0dab208f3c4917..1cf463d4e0218d66dff0c3637346a12d327d9fda 100644
--- a/third_party/DarkFeat/run.py
+++ b/third_party/DarkFeat/run.py
@@ -10,39 +10,45 @@ from trainer_single_norel import SingleTrainerNoRel
 from trainer_single import SingleTrainer
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # add argument parser
     parser = argparse.ArgumentParser()
-    parser.add_argument('--config', type=str, default='./configs/config.yaml')
-    parser.add_argument('--dataset_dir', type=str, default='/mnt/nvme2n1/hyz/data/GL3D')
-    parser.add_argument('--data_split', type=str, default='comb')
-    parser.add_argument('--is_training', type=bool, default=True)
-    parser.add_argument('--job_name', type=str, default='')
-    parser.add_argument('--gpu', type=str, default='0')
-    parser.add_argument('--start_cnt', type=int, default=0)
-    parser.add_argument('--stage', type=int, default=1)
+    parser.add_argument("--config", type=str, default="./configs/config.yaml")
+    parser.add_argument("--dataset_dir", type=str, default="/mnt/nvme2n1/hyz/data/GL3D")
+    parser.add_argument("--data_split", type=str, default="comb")
+    parser.add_argument("--is_training", type=bool, default=True)
+    parser.add_argument("--job_name", type=str, default="")
+    parser.add_argument("--gpu", type=str, default="0")
+    parser.add_argument("--start_cnt", type=int, default=0)
+    parser.add_argument("--stage", type=int, default=1)
     args = parser.parse_args()
 
     # load global config
-    with open(args.config, 'r') as f:
+    with open(args.config, "r") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
 
     # setup dataloader
-    dataset = GL3DDataset(args.dataset_dir, config['network'], args.data_split, is_training=args.is_training)
+    dataset = GL3DDataset(
+        args.dataset_dir,
+        config["network"],
+        args.data_split,
+        is_training=args.is_training,
+    )
     data_loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
 
-    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-
+    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
 
     if args.stage == 1:
-        trainer = SingleTrainerNoRel(config, f'cuda:0', data_loader, args.job_name, args.start_cnt)
+        trainer = SingleTrainerNoRel(
+            config, f"cuda:0", data_loader, args.job_name, args.start_cnt
+        )
     elif args.stage == 2:
-        trainer = SingleTrainer(config, f'cuda:0', data_loader, args.job_name, args.start_cnt)
+        trainer = SingleTrainer(
+            config, f"cuda:0", data_loader, args.job_name, args.start_cnt
+        )
     elif args.stage == 3:
-        trainer = Trainer(config, f'cuda:0', data_loader, args.job_name, args.start_cnt)
+        trainer = Trainer(config, f"cuda:0", data_loader, args.job_name, args.start_cnt)
     else:
         raise NotImplementedError()
-        
-    trainer.train()
 
-    
\ No newline at end of file
+    trainer.train()
diff --git a/third_party/DarkFeat/trainer.py b/third_party/DarkFeat/trainer.py
index e6ff2af9608e934b6899058d756bb2ab7d0fee2d..1f3bed348f16adf81d3f48ef23563442c7d35fdc 100644
--- a/third_party/DarkFeat/trainer.py
+++ b/third_party/DarkFeat/trainer.py
@@ -23,23 +23,26 @@ class Trainer:
         self.config = config
         self.device = device
         self.loader = loader
-        
+
         # tensorboard writer construction
-        os.makedirs('./runs/', exist_ok=True)
-        if job_name != '':
-            self.log_dir = f'runs/{job_name}'
+        os.makedirs("./runs/", exist_ok=True)
+        if job_name != "":
+            self.log_dir = f"runs/{job_name}"
         else:
             self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}'
 
         self.writer = SummaryWriter(self.log_dir)
-        with open(f'{self.log_dir}/config.yaml', 'w') as f:
+        with open(f"{self.log_dir}/config.yaml", "w") as f:
             yaml.dump(config, f)
 
-        if config['network']['input_type'] == 'gray':
+        if config["network"]["input_type"] == "gray":
             self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device)
-        elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic':
+        elif (
+            config["network"]["input_type"] == "rgb"
+            or config["network"]["input_type"] == "raw-demosaic"
+        ):
             self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device)
-        elif config['network']['input_type'] == 'raw':
+        elif config["network"]["input_type"] == "raw":
             self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device)
         else:
             raise NotImplementedError()
@@ -49,80 +52,104 @@ class Trainer:
 
         # reliability map conv
         self.model.clf = nn.Conv2d(128, 2, kernel_size=1).cuda()
-        
+
         # load model
         self.cnt = 0
         if start_cnt != 0:
-            self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth', map_location=device))
+            self.model.load_state_dict(
+                torch.load(
+                    f"{self.log_dir}/model_{start_cnt:06d}.pth", map_location=device
+                )
+            )
             self.cnt = start_cnt + 1
 
         # sampler
-        sampler = MultiSampler(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16,
-            subd_neg=-8,maxpool_pos=True).to(device)
+        sampler = MultiSampler(
+            ngh=7,
+            subq=-8,
+            subd=1,
+            pos_d=3,
+            neg_d=5,
+            border=16,
+            subd_neg=-8,
+            maxpool_pos=True,
+        ).to(device)
         self.reliability_relitive_loss = MultiPixelAPLoss(sampler, nq=20).to(device)
-        
 
         # optimizer and scheduler
-        if self.config['training']['optimizer'] == 'SGD':
+        if self.config["training"]["optimizer"] == "SGD":
             self.optimizer = torch.optim.SGD(
-                [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
-                lr=self.config['training']['lr'],
-                momentum=self.config['training']['momentum'],
-                weight_decay=self.config['training']['weight_decay'],
+                [
+                    {
+                        "params": self.model.parameters(),
+                        "initial_lr": self.config["training"]["lr"],
+                    }
+                ],
+                lr=self.config["training"]["lr"],
+                momentum=self.config["training"]["momentum"],
+                weight_decay=self.config["training"]["weight_decay"],
             )
-        elif self.config['training']['optimizer'] == 'Adam':
+        elif self.config["training"]["optimizer"] == "Adam":
             self.optimizer = torch.optim.Adam(
-                [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
-                lr=self.config['training']['lr'],
-                weight_decay=self.config['training']['weight_decay']
+                [
+                    {
+                        "params": self.model.parameters(),
+                        "initial_lr": self.config["training"]["lr"],
+                    }
+                ],
+                lr=self.config["training"]["lr"],
+                weight_decay=self.config["training"]["weight_decay"],
             )
         else:
             raise NotImplementedError()
 
         self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
             self.optimizer,
-            step_size=self.config['training']['lr_step'],
-            gamma=self.config['training']['lr_gamma'],
-            last_epoch=start_cnt
+            step_size=self.config["training"]["lr_step"],
+            gamma=self.config["training"]["lr_gamma"],
+            last_epoch=start_cnt,
         )
         for param_tensor in self.model.state_dict():
             print(param_tensor, "\t", self.model.state_dict()[param_tensor].size())
 
-
     def save(self, iter_num):
-        torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth')
+        torch.save(self.model.state_dict(), f"{self.log_dir}/model_{iter_num:06d}.pth")
 
     def load(self, path):
         self.model.load_state_dict(torch.load(path))
 
     def train(self):
         self.model.train()
-        
+
         for epoch in range(2):
             for batch_idx, inputs in enumerate(self.loader):
                 self.optimizer.zero_grad()
                 t = time.time()
 
                 # preprocess and add noise
-                img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt)
-                img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt)
+                img0_ori, noise_img0_ori = self.preprocess_noise_pair(
+                    inputs["img0"], self.cnt
+                )
+                img1_ori, noise_img1_ori = self.preprocess_noise_pair(
+                    inputs["img1"], self.cnt
+                )
 
                 img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device)
                 img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device)
                 noise_img0 = noise_img0_ori.permute(0, 3, 1, 2).float().to(self.device)
                 noise_img1 = noise_img1_ori.permute(0, 3, 1, 2).float().to(self.device)
 
-                if self.config['network']['input_type'] == 'rgb':
+                if self.config["network"]["input_type"] == "rgb":
                     # 3-channel rgb
                     RGB_mean = [0.485, 0.456, 0.406]
-                    RGB_std  = [0.229, 0.224, 0.225]
+                    RGB_std = [0.229, 0.224, 0.225]
                     norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std)
                     img0 = norm_RGB(img0)
                     img1 = norm_RGB(img1)
                     noise_img0 = norm_RGB(noise_img0)
                     noise_img1 = norm_RGB(noise_img1)
 
-                elif self.config['network']['input_type'] == 'gray':
+                elif self.config["network"]["input_type"] == "gray":
                     # 1-channel
                     img0 = torch.mean(img0, dim=1, keepdim=True)
                     img1 = torch.mean(img1, dim=1, keepdim=True)
@@ -135,11 +162,11 @@ class Trainer:
                     noise_img0 = norm_gray0(noise_img0)
                     noise_img1 = norm_gray1(noise_img1)
 
-                elif self.config['network']['input_type'] == 'raw':
+                elif self.config["network"]["input_type"] == "raw":
                     # 4-channel
                     pass
 
-                elif self.config['network']['input_type'] == 'raw-demosaic':
+                elif self.config["network"]["input_type"] == "raw-demosaic":
                     # 3-channel
                     pass
 
@@ -149,14 +176,26 @@ class Trainer:
                 desc0, score_map0, _, _ = self.model(img0)
                 desc1, score_map1, _, _ = self.model(img1)
 
-                conf0 = F.softmax(self.model.clf(torch.abs(desc0)**2.0), dim=1)[:,1:2]
-                conf1 = F.softmax(self.model.clf(torch.abs(desc1)**2.0), dim=1)[:,1:2]
+                conf0 = F.softmax(self.model.clf(torch.abs(desc0) ** 2.0), dim=1)[
+                    :, 1:2
+                ]
+                conf1 = F.softmax(self.model.clf(torch.abs(desc1) ** 2.0), dim=1)[
+                    :, 1:2
+                ]
 
-                noise_desc0, noise_score_map0, noise_at0, noise_att0 = self.model(noise_img0)
-                noise_desc1, noise_score_map1, noise_at1, noise_att1 = self.model(noise_img1)
+                noise_desc0, noise_score_map0, noise_at0, noise_att0 = self.model(
+                    noise_img0
+                )
+                noise_desc1, noise_score_map1, noise_at1, noise_att1 = self.model(
+                    noise_img1
+                )
 
-                noise_conf0 = F.softmax(self.model.clf(torch.abs(noise_desc0)**2.0), dim=1)[:,1:2]
-                noise_conf1 = F.softmax(self.model.clf(torch.abs(noise_desc1)**2.0), dim=1)[:,1:2]
+                noise_conf0 = F.softmax(
+                    self.model.clf(torch.abs(noise_desc0) ** 2.0), dim=1
+                )[:, 1:2]
+                noise_conf1 = F.softmax(
+                    self.model.clf(torch.abs(noise_desc1) ** 2.0), dim=1
+                )[:, 1:2]
 
                 cur_feat_size0 = torch.tensor(score_map0.shape[2:])
                 cur_feat_size1 = torch.tensor(score_map1.shape[2:])
@@ -174,71 +213,128 @@ class Trainer:
                 noise_conf0 = noise_conf0.permute(0, 2, 3, 1)
                 noise_conf1 = noise_conf1.permute(0, 2, 3, 1)
 
-                r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device)
-                r_K1 = getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device)
-                
+                r_K0 = getK(inputs["ori_img_size0"], cur_feat_size0, inputs["K0"]).to(
+                    self.device
+                )
+                r_K1 = getK(inputs["ori_img_size1"], cur_feat_size1, inputs["K1"]).to(
+                    self.device
+                )
+
                 pos0 = _grid_positions(
-                    cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device)
+                    cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]
+                ).to(self.device)
 
                 pos0_for_rel, pos1_for_rel, _ = getWarpNoValidate(
-                    pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device),
-                    r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0])
+                    pos0,
+                    inputs["rel_pose"].to(self.device),
+                    inputs["depth0"].to(self.device),
+                    r_K0,
+                    inputs["depth1"].to(self.device),
+                    r_K1,
+                    img0.shape[0],
+                )
 
                 pos0, pos1, _ = getWarp(
-                    pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device),
-                    r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0])
+                    pos0,
+                    inputs["rel_pose"].to(self.device),
+                    inputs["depth0"].to(self.device),
+                    r_K0,
+                    inputs["depth1"].to(self.device),
+                    r_K1,
+                    img0.shape[0],
+                )
 
-                reliab_loss_relative = self.reliability_relitive_loss(desc0, desc1, noise_desc0, noise_desc1, conf0, conf1, noise_conf0, noise_conf1, pos0_for_rel, pos1_for_rel, img0.shape[0], img0.shape[2], img0.shape[3])
+                reliab_loss_relative = self.reliability_relitive_loss(
+                    desc0,
+                    desc1,
+                    noise_desc0,
+                    noise_desc1,
+                    conf0,
+                    conf1,
+                    noise_conf0,
+                    noise_conf1,
+                    pos0_for_rel,
+                    pos1_for_rel,
+                    img0.shape[0],
+                    img0.shape[2],
+                    img0.shape[3],
+                )
 
                 det_structured_loss, det_accuracy = make_detector_loss(
-                    pos0, pos1, desc0, desc1,
-                    score_map0, score_map1, img0.shape[0],
-                    self.config['network']['use_corr_n'],
-                    self.config['network']['loss_type'],
-                    self.config
+                    pos0,
+                    pos1,
+                    desc0,
+                    desc1,
+                    score_map0,
+                    score_map1,
+                    img0.shape[0],
+                    self.config["network"]["use_corr_n"],
+                    self.config["network"]["loss_type"],
+                    self.config,
                 )
 
                 det_structured_loss_noise, det_accuracy_noise = make_detector_loss(
-                    pos0, pos1, noise_desc0, noise_desc1,
-                    noise_score_map0, noise_score_map1, img0.shape[0],
-                    self.config['network']['use_corr_n'],
-                    self.config['network']['loss_type'],
-                    self.config
+                    pos0,
+                    pos1,
+                    noise_desc0,
+                    noise_desc1,
+                    noise_score_map0,
+                    noise_score_map1,
+                    img0.shape[0],
+                    self.config["network"]["use_corr_n"],
+                    self.config["network"]["loss_type"],
+                    self.config,
                 )
 
                 indices0, scores0 = extract_kpts(
                     score_map0.permute(0, 3, 1, 2),
-                    k=self.config['network']['det']['kpt_n'],
-                    score_thld=self.config['network']['det']['score_thld'],
-                    nms_size=self.config['network']['det']['nms_size'],
-                    eof_size=self.config['network']['det']['eof_size'],
-                    edge_thld=self.config['network']['det']['edge_thld']
+                    k=self.config["network"]["det"]["kpt_n"],
+                    score_thld=self.config["network"]["det"]["score_thld"],
+                    nms_size=self.config["network"]["det"]["nms_size"],
+                    eof_size=self.config["network"]["det"]["eof_size"],
+                    edge_thld=self.config["network"]["det"]["edge_thld"],
                 )
                 indices1, scores1 = extract_kpts(
                     score_map1.permute(0, 3, 1, 2),
-                    k=self.config['network']['det']['kpt_n'],
-                    score_thld=self.config['network']['det']['score_thld'],
-                    nms_size=self.config['network']['det']['nms_size'],
-                    eof_size=self.config['network']['det']['eof_size'],
-                    edge_thld=self.config['network']['det']['edge_thld']
+                    k=self.config["network"]["det"]["kpt_n"],
+                    score_thld=self.config["network"]["det"]["score_thld"],
+                    nms_size=self.config["network"]["det"]["nms_size"],
+                    eof_size=self.config["network"]["det"]["eof_size"],
+                    edge_thld=self.config["network"]["det"]["edge_thld"],
                 )
 
-                noise_score_loss0, mask0 = make_noise_score_map_loss(score_map0, noise_score_map0, indices0, img0.shape[0], thld=0.1)
-                noise_score_loss1, mask1 = make_noise_score_map_loss(score_map1, noise_score_map1, indices1, img1.shape[0], thld=0.1)
+                noise_score_loss0, mask0 = make_noise_score_map_loss(
+                    score_map0, noise_score_map0, indices0, img0.shape[0], thld=0.1
+                )
+                noise_score_loss1, mask1 = make_noise_score_map_loss(
+                    score_map1, noise_score_map1, indices1, img1.shape[0], thld=0.1
+                )
 
                 total_loss = det_structured_loss + det_structured_loss_noise
-                total_loss += noise_score_loss0 / 2. * 1.
-                total_loss += noise_score_loss1 / 2. * 1.
-                total_loss += reliab_loss_relative[0] / 2. * 0.5
-                total_loss += reliab_loss_relative[1] / 2. * 0.5
-                
+                total_loss += noise_score_loss0 / 2.0 * 1.0
+                total_loss += noise_score_loss1 / 2.0 * 1.0
+                total_loss += reliab_loss_relative[0] / 2.0 * 0.5
+                total_loss += reliab_loss_relative[1] / 2.0 * 0.5
+
                 self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt)
                 self.writer.add_scalar("acc/noise_acc", det_accuracy_noise, self.cnt)
                 self.writer.add_scalar("loss/total_loss", total_loss, self.cnt)
-                self.writer.add_scalar("loss/noise_score_loss", (noise_score_loss0 + noise_score_loss1) / 2., self.cnt)
-                self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt)
-                self.writer.add_scalar("loss/det_loss_noise", det_structured_loss_noise, self.cnt)
-                print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t))
+                self.writer.add_scalar(
+                    "loss/noise_score_loss",
+                    (noise_score_loss0 + noise_score_loss1) / 2.0,
+                    self.cnt,
+                )
+                self.writer.add_scalar(
+                    "loss/det_loss_normal", det_structured_loss, self.cnt
+                )
+                self.writer.add_scalar(
+                    "loss/det_loss_noise", det_structured_loss_noise, self.cnt
+                )
+                print(
+                    "iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter".format(
+                        self.cnt, total_loss, det_accuracy, time.time() - t
+                    )
+                )
                 # print(f'normal_loss: {det_structured_loss}, noise_loss: {det_structured_loss_noise}, reliab_loss: {reliab_loss_relative[0]}, {reliab_loss_relative[1]}')
 
                 if det_structured_loss != 0:
@@ -249,100 +345,162 @@ class Trainer:
                 if self.cnt % 100 == 0:
                     noise_indices0, noise_scores0 = extract_kpts(
                         noise_score_map0.permute(0, 3, 1, 2),
-                        k=self.config['network']['det']['kpt_n'],
-                        score_thld=self.config['network']['det']['score_thld'],
-                        nms_size=self.config['network']['det']['nms_size'],
-                        eof_size=self.config['network']['det']['eof_size'],
-                        edge_thld=self.config['network']['det']['edge_thld']
+                        k=self.config["network"]["det"]["kpt_n"],
+                        score_thld=self.config["network"]["det"]["score_thld"],
+                        nms_size=self.config["network"]["det"]["nms_size"],
+                        eof_size=self.config["network"]["det"]["eof_size"],
+                        edge_thld=self.config["network"]["det"]["edge_thld"],
                     )
                     noise_indices1, noise_scores1 = extract_kpts(
                         noise_score_map1.permute(0, 3, 1, 2),
-                        k=self.config['network']['det']['kpt_n'],
-                        score_thld=self.config['network']['det']['score_thld'],
-                        nms_size=self.config['network']['det']['nms_size'],
-                        eof_size=self.config['network']['det']['eof_size'],
-                        edge_thld=self.config['network']['det']['edge_thld']
+                        k=self.config["network"]["det"]["kpt_n"],
+                        score_thld=self.config["network"]["det"]["score_thld"],
+                        nms_size=self.config["network"]["det"]["nms_size"],
+                        eof_size=self.config["network"]["det"]["eof_size"],
+                        edge_thld=self.config["network"]["det"]["edge_thld"],
                     )
-                    if self.config['network']['input_type'] == 'raw':
-                        kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0])
-                        kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0])
-                        noise_kpt_img0 = self.showKeyPoints(noise_img0_ori[0][..., :3] * 255., noise_indices0[0])
-                        noise_kpt_img1 = self.showKeyPoints(noise_img1_ori[0][..., :3] * 255., noise_indices1[0])
+                    if self.config["network"]["input_type"] == "raw":
+                        kpt_img0 = self.showKeyPoints(
+                            img0_ori[0][..., :3] * 255.0, indices0[0]
+                        )
+                        kpt_img1 = self.showKeyPoints(
+                            img1_ori[0][..., :3] * 255.0, indices1[0]
+                        )
+                        noise_kpt_img0 = self.showKeyPoints(
+                            noise_img0_ori[0][..., :3] * 255.0, noise_indices0[0]
+                        )
+                        noise_kpt_img1 = self.showKeyPoints(
+                            noise_img1_ori[0][..., :3] * 255.0, noise_indices1[0]
+                        )
                     else:
-                        kpt_img0 = self.showKeyPoints(img0_ori[0] * 255., indices0[0])
-                        kpt_img1 = self.showKeyPoints(img1_ori[0] * 255., indices1[0])
-                        noise_kpt_img0 = self.showKeyPoints(noise_img0_ori[0] * 255., noise_indices0[0])
-                        noise_kpt_img1 = self.showKeyPoints(noise_img1_ori[0] * 255., noise_indices1[0])
-
-                    self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC')
-                    self.writer.add_image('img0/noise_kpts', noise_kpt_img0, self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/noise_kpts', noise_kpt_img1, self.cnt, dataformats='HWC')
-                    self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC')
-                    self.writer.add_image('img0/noise_score_map', noise_score_map0[0], self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/noise_score_map', noise_score_map1[0], self.cnt, dataformats='HWC')
-                    self.writer.add_image('img0/kpt_mask', mask0.unsqueeze(2), self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/kpt_mask', mask1.unsqueeze(2), self.cnt, dataformats='HWC')
-                    self.writer.add_image('img0/conf', conf0[0], self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/conf', conf1[0], self.cnt, dataformats='HWC')
-                    self.writer.add_image('img0/noise_conf', noise_conf0[0], self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/noise_conf', noise_conf1[0], self.cnt, dataformats='HWC')
+                        kpt_img0 = self.showKeyPoints(img0_ori[0] * 255.0, indices0[0])
+                        kpt_img1 = self.showKeyPoints(img1_ori[0] * 255.0, indices1[0])
+                        noise_kpt_img0 = self.showKeyPoints(
+                            noise_img0_ori[0] * 255.0, noise_indices0[0]
+                        )
+                        noise_kpt_img1 = self.showKeyPoints(
+                            noise_img1_ori[0] * 255.0, noise_indices1[0]
+                        )
+
+                    self.writer.add_image(
+                        "img0/kpts", kpt_img0, self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img1/kpts", kpt_img1, self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img0/noise_kpts", noise_kpt_img0, self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img1/noise_kpts", noise_kpt_img1, self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img0/score_map", score_map0[0], self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img1/score_map", score_map1[0], self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img0/noise_score_map",
+                        noise_score_map0[0],
+                        self.cnt,
+                        dataformats="HWC",
+                    )
+                    self.writer.add_image(
+                        "img1/noise_score_map",
+                        noise_score_map1[0],
+                        self.cnt,
+                        dataformats="HWC",
+                    )
+                    self.writer.add_image(
+                        "img0/kpt_mask", mask0.unsqueeze(2), self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img1/kpt_mask", mask1.unsqueeze(2), self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img0/conf", conf0[0], self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img1/conf", conf1[0], self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img0/noise_conf", noise_conf0[0], self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img1/noise_conf", noise_conf1[0], self.cnt, dataformats="HWC"
+                    )
 
                 if self.cnt % 5000 == 0:
                     self.save(self.cnt)
-                
-                self.cnt += 1
 
+                self.cnt += 1
 
     def showKeyPoints(self, img, indices):
         key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1])
-        img = img.numpy().astype('uint8')
+        img = img.numpy().astype("uint8")
         img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0))
         return img
 
-
     def preprocess(self, img, iter_idx):
-        if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']:
+        if (
+            not self.config["network"]["noise"]
+            and "raw" not in self.config["network"]["input_type"]
+        ):
             return img
 
         raw = self.noise_maker.rgb2raw(img, batched=True)
 
-        if self.config['network']['noise']:
-            ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
+        if self.config["network"]["noise"]:
+            ratio_dec = (
+                min(self.config["network"]["noise_maxstep"], iter_idx)
+                / self.config["network"]["noise_maxstep"]
+            )
             raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
 
-        if self.config['network']['input_type'] == 'raw':
+        if self.config["network"]["input_type"] == "raw":
             return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))
 
-        if self.config['network']['input_type'] == 'raw-demosaic':
+        if self.config["network"]["input_type"] == "raw-demosaic":
             return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))
 
         rgb = self.noise_maker.raw2rgb(raw, batched=True)
-        if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+        if (
+            self.config["network"]["input_type"] == "rgb"
+            or self.config["network"]["input_type"] == "gray"
+        ):
             return torch.tensor(rgb)
 
         raise NotImplementedError()
 
-
     def preprocess_noise_pair(self, img, iter_idx):
-        assert self.config['network']['noise']
+        assert self.config["network"]["noise"]
 
         raw = self.noise_maker.rgb2raw(img, batched=True)
 
-        ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
-        noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
+        ratio_dec = (
+            min(self.config["network"]["noise_maxstep"], iter_idx)
+            / self.config["network"]["noise_maxstep"]
+        )
+        noise_raw = self.noise_maker.raw2noisyRaw(
+            raw, ratio_dec=ratio_dec, batched=True
+        )
 
-        if self.config['network']['input_type'] == 'raw':
-            return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \
-                   torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True))
+        if self.config["network"]["input_type"] == "raw":
+            return torch.tensor(
+                self.noise_maker.raw2packedRaw(raw, batched=True)
+            ), torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True))
 
-        if self.config['network']['input_type'] == 'raw-demosaic':
-            return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \
-                   torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True))
+        if self.config["network"]["input_type"] == "raw-demosaic":
+            return torch.tensor(
+                self.noise_maker.raw2demosaicRaw(raw, batched=True)
+            ), torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True))
 
         noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True)
-        if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+        if (
+            self.config["network"]["input_type"] == "rgb"
+            or self.config["network"]["input_type"] == "gray"
+        ):
             return img, torch.tensor(noise_rgb)
 
         raise NotImplementedError()
diff --git a/third_party/DarkFeat/trainer_single.py b/third_party/DarkFeat/trainer_single.py
index 65566e7e27cfd605eba000d308b6d3610f29e746..0b079d1fc376b3dbd45297902c4d1e195c267156 100644
--- a/third_party/DarkFeat/trainer_single.py
+++ b/third_party/DarkFeat/trainer_single.py
@@ -24,23 +24,29 @@ class SingleTrainer:
         self.config = config
         self.device = device
         self.loader = loader
-        
+
         # tensorboard writer construction
-        os.makedirs('./runs/', exist_ok=True)
-        if job_name != '':
-            self.log_dir = f'runs/{job_name}'
+        os.makedirs("./runs/", exist_ok=True)
+        if job_name != "":
+            self.log_dir = f"runs/{job_name}"
         else:
             self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}'
 
         self.writer = SummaryWriter(self.log_dir)
-        with open(f'{self.log_dir}/config.yaml', 'w') as f:
+        with open(f"{self.log_dir}/config.yaml", "w") as f:
             yaml.dump(config, f)
 
-        if config['network']['input_type'] == 'gray' or config['network']['input_type'] == 'raw-gray':
+        if (
+            config["network"]["input_type"] == "gray"
+            or config["network"]["input_type"] == "raw-gray"
+        ):
             self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device)
-        elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic':
+        elif (
+            config["network"]["input_type"] == "rgb"
+            or config["network"]["input_type"] == "raw-demosaic"
+        ):
             self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device)
-        elif config['network']['input_type'] == 'raw':
+        elif config["network"]["input_type"] == "raw":
             self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device)
         else:
             raise NotImplementedError()
@@ -51,75 +57,98 @@ class SingleTrainer:
         # load model
         self.cnt = 0
         if start_cnt != 0:
-            self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth'))
+            self.model.load_state_dict(
+                torch.load(f"{self.log_dir}/model_{start_cnt:06d}.pth")
+            )
             self.cnt = start_cnt + 1
 
         # sampler
-        sampler = NghSampler2(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16,
-            subd_neg=-8,maxpool_pos=True).to(device)
+        sampler = NghSampler2(
+            ngh=7,
+            subq=-8,
+            subd=1,
+            pos_d=3,
+            neg_d=5,
+            border=16,
+            subd_neg=-8,
+            maxpool_pos=True,
+        ).to(device)
         self.reliability_loss = ReliabilityLoss(sampler, base=0.3, nq=20).to(device)
         # reliability map conv
         self.model.clf = nn.Conv2d(128, 2, kernel_size=1).cuda()
 
         # optimizer and scheduler
-        if self.config['training']['optimizer'] == 'SGD':
+        if self.config["training"]["optimizer"] == "SGD":
             self.optimizer = torch.optim.SGD(
-                [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
-                lr=self.config['training']['lr'],
-                momentum=self.config['training']['momentum'],
-                weight_decay=self.config['training']['weight_decay'],
+                [
+                    {
+                        "params": self.model.parameters(),
+                        "initial_lr": self.config["training"]["lr"],
+                    }
+                ],
+                lr=self.config["training"]["lr"],
+                momentum=self.config["training"]["momentum"],
+                weight_decay=self.config["training"]["weight_decay"],
             )
-        elif self.config['training']['optimizer'] == 'Adam':
+        elif self.config["training"]["optimizer"] == "Adam":
             self.optimizer = torch.optim.Adam(
-                [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
-                lr=self.config['training']['lr'],
-                weight_decay=self.config['training']['weight_decay']
+                [
+                    {
+                        "params": self.model.parameters(),
+                        "initial_lr": self.config["training"]["lr"],
+                    }
+                ],
+                lr=self.config["training"]["lr"],
+                weight_decay=self.config["training"]["weight_decay"],
             )
         else:
             raise NotImplementedError()
 
         self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
             self.optimizer,
-            step_size=self.config['training']['lr_step'],
-            gamma=self.config['training']['lr_gamma'],
-            last_epoch=start_cnt
+            step_size=self.config["training"]["lr_step"],
+            gamma=self.config["training"]["lr_gamma"],
+            last_epoch=start_cnt,
         )
         for param_tensor in self.model.state_dict():
             print(param_tensor, "\t", self.model.state_dict()[param_tensor].size())
 
-
     def save(self, iter_num):
-        torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth')
+        torch.save(self.model.state_dict(), f"{self.log_dir}/model_{iter_num:06d}.pth")
 
     def load(self, path):
         self.model.load_state_dict(torch.load(path))
 
     def train(self):
         self.model.train()
-        
+
         for epoch in range(2):
             for batch_idx, inputs in enumerate(self.loader):
                 self.optimizer.zero_grad()
                 t = time.time()
 
                 # preprocess and add noise
-                img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt)
-                img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt)
+                img0_ori, noise_img0_ori = self.preprocess_noise_pair(
+                    inputs["img0"], self.cnt
+                )
+                img1_ori, noise_img1_ori = self.preprocess_noise_pair(
+                    inputs["img1"], self.cnt
+                )
 
                 img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device)
                 img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device)
 
-                if self.config['network']['input_type'] == 'rgb':
+                if self.config["network"]["input_type"] == "rgb":
                     # 3-channel rgb
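+                    # standard ImageNet channel mean/std used for input normalization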
                     RGB_mean = [0.485, 0.456, 0.406]
-                    RGB_std  = [0.229, 0.224, 0.225]
+                    RGB_std = [0.229, 0.224, 0.225]
                     norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std)
                     img0 = norm_RGB(img0)
                     img1 = norm_RGB(img1)
                     noise_img0 = norm_RGB(noise_img0)
                     noise_img1 = norm_RGB(noise_img1)
 
-                elif self.config['network']['input_type'] == 'gray':
+                elif self.config["network"]["input_type"] == "gray":
                     # 1-channel
                     img0 = torch.mean(img0, dim=1, keepdim=True)
                     img1 = torch.mean(img1, dim=1, keepdim=True)
@@ -132,11 +161,11 @@ class SingleTrainer:
                     noise_img0 = norm_gray0(noise_img0)
                     noise_img1 = norm_gray1(noise_img1)
 
-                elif self.config['network']['input_type'] == 'raw':
+                elif self.config["network"]["input_type"] == "raw":
                     # 4-channel
                     pass
 
-                elif self.config['network']['input_type'] == 'raw-demosaic':
+                elif self.config["network"]["input_type"] == "raw-demosaic":
                     # 3-channel
                     pass
 
@@ -149,8 +178,12 @@ class SingleTrainer:
                 cur_feat_size0 = torch.tensor(score_map0.shape[2:])
                 cur_feat_size1 = torch.tensor(score_map1.shape[2:])
 
-                conf0 = F.softmax(self.model.clf(torch.abs(desc0)**2.0), dim=1)[:,1:2]
-                conf1 = F.softmax(self.model.clf(torch.abs(desc1)**2.0), dim=1)[:,1:2]
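+                # per-pixel reliability: softmax over the two clf output channels, keeping channel 1 as the confidence map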
+                conf0 = F.softmax(self.model.clf(torch.abs(desc0) ** 2.0), dim=1)[
+                    :, 1:2
+                ]
+                conf1 = F.softmax(self.model.clf(torch.abs(desc1) ** 2.0), dim=1)[
+                    :, 1:2
+                ]
 
                 desc0 = desc0.permute(0, 2, 3, 1)
                 desc1 = desc1.permute(0, 2, 3, 1)
@@ -159,39 +192,77 @@ class SingleTrainer:
                 conf0 = conf0.permute(0, 2, 3, 1)
                 conf1 = conf1.permute(0, 2, 3, 1)
 
-                r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device)
-                r_K1 = getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device)
-                
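+                # getK presumably rescales the camera intrinsics from the original image size to the score-map resolution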
+                r_K0 = getK(inputs["ori_img_size0"], cur_feat_size0, inputs["K0"]).to(
+                    self.device
+                )
+                r_K1 = getK(inputs["ori_img_size1"], cur_feat_size1, inputs["K1"]).to(
+                    self.device
+                )
+
                 pos0 = _grid_positions(
-                    cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device)
+                    cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]
+                ).to(self.device)
 
                 pos0_for_rel, pos1_for_rel, _ = getWarpNoValidate(
-                    pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device),
-                    r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0])
+                    pos0,
+                    inputs["rel_pose"].to(self.device),
+                    inputs["depth0"].to(self.device),
+                    r_K0,
+                    inputs["depth1"].to(self.device),
+                    r_K1,
+                    img0.shape[0],
+                )
 
                 pos0, pos1, _ = getWarp(
-                    pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device),
-                    r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0])
+                    pos0,
+                    inputs["rel_pose"].to(self.device),
+                    inputs["depth0"].to(self.device),
+                    r_K0,
+                    inputs["depth1"].to(self.device),
+                    r_K1,
+                    img0.shape[0],
+                )
 
-                reliab_loss = self.reliability_loss(desc0, desc1, conf0, conf1, pos0_for_rel, pos1_for_rel, img0.shape[0], img0.shape[2], img0.shape[3])
+                reliab_loss = self.reliability_loss(
+                    desc0,
+                    desc1,
+                    conf0,
+                    conf1,
+                    pos0_for_rel,
+                    pos1_for_rel,
+                    img0.shape[0],
+                    img0.shape[2],
+                    img0.shape[3],
+                )
 
                 det_structured_loss, det_accuracy = make_detector_loss(
-                    pos0, pos1, desc0, desc1,
-                    score_map0, score_map1, img0.shape[0],
-                    self.config['network']['use_corr_n'],
-                    self.config['network']['loss_type'],
-                    self.config
+                    pos0,
+                    pos1,
+                    desc0,
+                    desc1,
+                    score_map0,
+                    score_map1,
+                    img0.shape[0],
+                    self.config["network"]["use_corr_n"],
+                    self.config["network"]["loss_type"],
+                    self.config,
                 )
 
                 total_loss = det_structured_loss
-                self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt)
-                
+                self.writer.add_scalar(
+                    "loss/det_loss_normal", det_structured_loss, self.cnt
+                )
+
                 total_loss += reliab_loss
-                
+
                 self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt)
                 self.writer.add_scalar("loss/total_loss", total_loss, self.cnt)
                 self.writer.add_scalar("loss/reliab_loss", reliab_loss, self.cnt)
-                print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t))
+                print(
+                    "iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter".format(
+                        self.cnt, total_loss, det_accuracy, time.time() - t
+                    )
+                )
 
                 if det_structured_loss != 0:
                     total_loss.backward()
@@ -201,94 +272,133 @@ class SingleTrainer:
                 if self.cnt % 100 == 0:
                     indices0, scores0 = extract_kpts(
                         score_map0.permute(0, 3, 1, 2),
-                        k=self.config['network']['det']['kpt_n'],
-                        score_thld=self.config['network']['det']['score_thld'],
-                        nms_size=self.config['network']['det']['nms_size'],
-                        eof_size=self.config['network']['det']['eof_size'],
-                        edge_thld=self.config['network']['det']['edge_thld']
+                        k=self.config["network"]["det"]["kpt_n"],
+                        score_thld=self.config["network"]["det"]["score_thld"],
+                        nms_size=self.config["network"]["det"]["nms_size"],
+                        eof_size=self.config["network"]["det"]["eof_size"],
+                        edge_thld=self.config["network"]["det"]["edge_thld"],
                     )
                     indices1, scores1 = extract_kpts(
                         score_map1.permute(0, 3, 1, 2),
-                        k=self.config['network']['det']['kpt_n'],
-                        score_thld=self.config['network']['det']['score_thld'],
-                        nms_size=self.config['network']['det']['nms_size'],
-                        eof_size=self.config['network']['det']['eof_size'],
-                        edge_thld=self.config['network']['det']['edge_thld']
+                        k=self.config["network"]["det"]["kpt_n"],
+                        score_thld=self.config["network"]["det"]["score_thld"],
+                        nms_size=self.config["network"]["det"]["nms_size"],
+                        eof_size=self.config["network"]["det"]["eof_size"],
+                        edge_thld=self.config["network"]["det"]["edge_thld"],
                     )
 
-                    if self.config['network']['input_type'] == 'raw':
-                        kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0])
-                        kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0])
+                    if self.config["network"]["input_type"] == "raw":
+                        kpt_img0 = self.showKeyPoints(
+                            img0_ori[0][..., :3] * 255.0, indices0[0]
+                        )
+                        kpt_img1 = self.showKeyPoints(
+                            img1_ori[0][..., :3] * 255.0, indices1[0]
+                        )
                     else:
-                        kpt_img0 = self.showKeyPoints(img0_ori[0] * 255., indices0[0])
-                        kpt_img1 = self.showKeyPoints(img1_ori[0] * 255., indices1[0])
+                        kpt_img0 = self.showKeyPoints(img0_ori[0] * 255.0, indices0[0])
+                        kpt_img1 = self.showKeyPoints(img1_ori[0] * 255.0, indices1[0])
 
-                    self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC')
-                    self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC')
-                    self.writer.add_image('img0/conf', conf0[0], self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/conf', conf1[0], self.cnt, dataformats='HWC')
+                    self.writer.add_image(
+                        "img0/kpts", kpt_img0, self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img1/kpts", kpt_img1, self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img0/score_map", score_map0[0], self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img1/score_map", score_map1[0], self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img0/conf", conf0[0], self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img1/conf", conf1[0], self.cnt, dataformats="HWC"
+                    )
 
                 if self.cnt % 10000 == 0:
                     self.save(self.cnt)
-                
-                self.cnt += 1
 
+                self.cnt += 1
 
     def showKeyPoints(self, img, indices):
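+        # indices come as (row, col); flip to the (x, y) order expected by cv2.KeyPoint_convert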
         key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1])
-        img = img.numpy().astype('uint8')
+        img = img.numpy().astype("uint8")
         img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0))
         return img
 
-
     def preprocess(self, img, iter_idx):
-        if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']:
+        if (
+            not self.config["network"]["noise"]
+            and "raw" not in self.config["network"]["input_type"]
+        ):
             return img
 
         raw = self.noise_maker.rgb2raw(img, batched=True)
 
-        if self.config['network']['noise']:
-            ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
+        if self.config["network"]["noise"]:
+            ratio_dec = (
+                min(self.config["network"]["noise_maxstep"], iter_idx)
+                / self.config["network"]["noise_maxstep"]
+            )
             raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
 
-        if self.config['network']['input_type'] == 'raw':
+        if self.config["network"]["input_type"] == "raw":
             return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))
 
-        if self.config['network']['input_type'] == 'raw-demosaic':
+        if self.config["network"]["input_type"] == "raw-demosaic":
             return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))
 
         rgb = self.noise_maker.raw2rgb(raw, batched=True)
-        if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+        if (
+            self.config["network"]["input_type"] == "rgb"
+            or self.config["network"]["input_type"] == "gray"
+        ):
             return torch.tensor(rgb)
 
         raise NotImplementedError()
 
-
     def preprocess_noise_pair(self, img, iter_idx):
-        assert self.config['network']['noise']
+        assert self.config["network"]["noise"]
 
         raw = self.noise_maker.rgb2raw(img, batched=True)
 
-        ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
-        noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
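+        # ratio_dec ramps linearly from 0 to 1 over the first noise_maxstep iterations, then stays at 1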
+        ratio_dec = (
+            min(self.config["network"]["noise_maxstep"], iter_idx)
+            / self.config["network"]["noise_maxstep"]
+        )
+        noise_raw = self.noise_maker.raw2noisyRaw(
+            raw, ratio_dec=ratio_dec, batched=True
+        )
 
-        if self.config['network']['input_type'] == 'raw':
-            return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \
-                   torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True))
+        if self.config["network"]["input_type"] == "raw":
+            return torch.tensor(
+                self.noise_maker.raw2packedRaw(raw, batched=True)
+            ), torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True))
 
-        if self.config['network']['input_type'] == 'raw-demosaic':
-            return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \
-                   torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True))
+        if self.config["network"]["input_type"] == "raw-demosaic":
+            return torch.tensor(
+                self.noise_maker.raw2demosaicRaw(raw, batched=True)
+            ), torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True))
 
-        if self.config['network']['input_type'] == 'raw-gray':
+        if self.config["network"]["input_type"] == "raw-gray":
             factor = torch.tensor([0.299, 0.587, 0.114]).double()
-            return torch.matmul(torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), factor).unsqueeze(-1), \
-                   torch.matmul(torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)), factor).unsqueeze(-1)
+            return torch.matmul(
+                torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)),
+                factor,
+            ).unsqueeze(-1), torch.matmul(
+                torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)),
+                factor,
+            ).unsqueeze(
+                -1
+            )
 
         noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True)
-        if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+        if (
+            self.config["network"]["input_type"] == "rgb"
+            or self.config["network"]["input_type"] == "gray"
+        ):
             return img, torch.tensor(noise_rgb)
 
         raise NotImplementedError()
diff --git a/third_party/DarkFeat/trainer_single_norel.py b/third_party/DarkFeat/trainer_single_norel.py
index a572e9c599adc30e5753e11e668d121cd378672a..5447a37dabba339183f4e50ef44381ebc7a34998 100644
--- a/third_party/DarkFeat/trainer_single_norel.py
+++ b/third_party/DarkFeat/trainer_single_norel.py
@@ -23,23 +23,29 @@ class SingleTrainerNoRel:
         self.config = config
         self.device = device
         self.loader = loader
-        
+
         # tensorboard writer construction
-        os.makedirs('./runs/', exist_ok=True)
-        if job_name != '':
-            self.log_dir = f'runs/{job_name}'
+        os.makedirs("./runs/", exist_ok=True)
+        if job_name != "":
+            self.log_dir = f"runs/{job_name}"
         else:
             self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}'
 
         self.writer = SummaryWriter(self.log_dir)
-        with open(f'{self.log_dir}/config.yaml', 'w') as f:
+        with open(f"{self.log_dir}/config.yaml", "w") as f:
             yaml.dump(config, f)
 
-        if config['network']['input_type'] == 'gray' or config['network']['input_type'] == 'raw-gray':
+        if (
+            config["network"]["input_type"] == "gray"
+            or config["network"]["input_type"] == "raw-gray"
+        ):
             self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device)
-        elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic':
+        elif (
+            config["network"]["input_type"] == "rgb"
+            or config["network"]["input_type"] == "raw-demosaic"
+        ):
             self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device)
-        elif config['network']['input_type'] == 'raw':
+        elif config["network"]["input_type"] == "raw":
             self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device)
         else:
             raise NotImplementedError()
@@ -50,68 +56,83 @@ class SingleTrainerNoRel:
         # load model
         self.cnt = 0
         if start_cnt != 0:
-            self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth'))
+            self.model.load_state_dict(
+                torch.load(f"{self.log_dir}/model_{start_cnt:06d}.pth")
+            )
             self.cnt = start_cnt + 1
 
         # optimizer and scheduler
-        if self.config['training']['optimizer'] == 'SGD':
+        if self.config["training"]["optimizer"] == "SGD":
             self.optimizer = torch.optim.SGD(
-                [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
-                lr=self.config['training']['lr'],
-                momentum=self.config['training']['momentum'],
-                weight_decay=self.config['training']['weight_decay'],
+                [
+                    {
+                        "params": self.model.parameters(),
+                        "initial_lr": self.config["training"]["lr"],
+                    }
+                ],
+                lr=self.config["training"]["lr"],
+                momentum=self.config["training"]["momentum"],
+                weight_decay=self.config["training"]["weight_decay"],
             )
-        elif self.config['training']['optimizer'] == 'Adam':
+        elif self.config["training"]["optimizer"] == "Adam":
             self.optimizer = torch.optim.Adam(
-                [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
-                lr=self.config['training']['lr'],
-                weight_decay=self.config['training']['weight_decay']
+                [
+                    {
+                        "params": self.model.parameters(),
+                        "initial_lr": self.config["training"]["lr"],
+                    }
+                ],
+                lr=self.config["training"]["lr"],
+                weight_decay=self.config["training"]["weight_decay"],
             )
         else:
             raise NotImplementedError()
 
         self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
             self.optimizer,
-            step_size=self.config['training']['lr_step'],
-            gamma=self.config['training']['lr_gamma'],
-            last_epoch=start_cnt
+            step_size=self.config["training"]["lr_step"],
+            gamma=self.config["training"]["lr_gamma"],
+            last_epoch=start_cnt,
         )
         for param_tensor in self.model.state_dict():
             print(param_tensor, "\t", self.model.state_dict()[param_tensor].size())
 
-
     def save(self, iter_num):
-        torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth')
+        torch.save(self.model.state_dict(), f"{self.log_dir}/model_{iter_num:06d}.pth")
 
     def load(self, path):
         self.model.load_state_dict(torch.load(path))
 
     def train(self):
         self.model.train()
-        
+
         for epoch in range(2):
             for batch_idx, inputs in enumerate(self.loader):
                 self.optimizer.zero_grad()
                 t = time.time()
 
                 # preprocess and add noise
-                img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt)
-                img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt)
+                img0_ori, noise_img0_ori = self.preprocess_noise_pair(
+                    inputs["img0"], self.cnt
+                )
+                img1_ori, noise_img1_ori = self.preprocess_noise_pair(
+                    inputs["img1"], self.cnt
+                )
 
                 img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device)
                 img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device)
 
-                if self.config['network']['input_type'] == 'rgb':
+                if self.config["network"]["input_type"] == "rgb":
                     # 3-channel rgb
                     RGB_mean = [0.485, 0.456, 0.406]
-                    RGB_std  = [0.229, 0.224, 0.225]
+                    RGB_std = [0.229, 0.224, 0.225]
                     norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std)
                     img0 = norm_RGB(img0)
                     img1 = norm_RGB(img1)
                     noise_img0 = norm_RGB(noise_img0)
                     noise_img1 = norm_RGB(noise_img1)
 
-                elif self.config['network']['input_type'] == 'gray':
+                elif self.config["network"]["input_type"] == "gray":
                     # 1-channel
                     img0 = torch.mean(img0, dim=1, keepdim=True)
                     img1 = torch.mean(img1, dim=1, keepdim=True)
@@ -124,11 +145,11 @@ class SingleTrainerNoRel:
                     noise_img0 = norm_gray0(noise_img0)
                     noise_img1 = norm_gray1(noise_img1)
 
-                elif self.config['network']['input_type'] == 'raw':
+                elif self.config["network"]["input_type"] == "raw":
                     # 4-channel
                     pass
 
-                elif self.config['network']['input_type'] == 'raw-demosaic':
+                elif self.config["network"]["input_type"] == "raw-demosaic":
                     # 3-channel
                     pass
 
@@ -146,30 +167,52 @@ class SingleTrainerNoRel:
                 score_map0 = score_map0.permute(0, 2, 3, 1)
                 score_map1 = score_map1.permute(0, 2, 3, 1)
 
-                r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device)
-                r_K1 = getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device)
-                
+                r_K0 = getK(inputs["ori_img_size0"], cur_feat_size0, inputs["K0"]).to(
+                    self.device
+                )
+                r_K1 = getK(inputs["ori_img_size1"], cur_feat_size1, inputs["K1"]).to(
+                    self.device
+                )
+
                 pos0 = _grid_positions(
-                    cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device)
+                    cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]
+                ).to(self.device)
 
                 pos0, pos1, _ = getWarp(
-                    pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device),
-                    r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0])
+                    pos0,
+                    inputs["rel_pose"].to(self.device),
+                    inputs["depth0"].to(self.device),
+                    r_K0,
+                    inputs["depth1"].to(self.device),
+                    r_K1,
+                    img0.shape[0],
+                )
 
                 det_structured_loss, det_accuracy = make_detector_loss(
-                    pos0, pos1, desc0, desc1,
-                    score_map0, score_map1, img0.shape[0],
-                    self.config['network']['use_corr_n'],
-                    self.config['network']['loss_type'],
-                    self.config
+                    pos0,
+                    pos1,
+                    desc0,
+                    desc1,
+                    score_map0,
+                    score_map1,
+                    img0.shape[0],
+                    self.config["network"]["use_corr_n"],
+                    self.config["network"]["loss_type"],
+                    self.config,
                 )
 
                 total_loss = det_structured_loss
-                
+
                 self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt)
                 self.writer.add_scalar("loss/total_loss", total_loss, self.cnt)
-                self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt)
-                print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t))
+                self.writer.add_scalar(
+                    "loss/det_loss_normal", det_structured_loss, self.cnt
+                )
+                print(
+                    "iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter".format(
+                        self.cnt, total_loss, det_accuracy, time.time() - t
+                    )
+                )
 
                 if det_structured_loss != 0:
                     total_loss.backward()
@@ -179,87 +222,115 @@ class SingleTrainerNoRel:
                 if self.cnt % 100 == 0:
                     indices0, scores0 = extract_kpts(
                         score_map0.permute(0, 3, 1, 2),
-                        k=self.config['network']['det']['kpt_n'],
-                        score_thld=self.config['network']['det']['score_thld'],
-                        nms_size=self.config['network']['det']['nms_size'],
-                        eof_size=self.config['network']['det']['eof_size'],
-                        edge_thld=self.config['network']['det']['edge_thld']
+                        k=self.config["network"]["det"]["kpt_n"],
+                        score_thld=self.config["network"]["det"]["score_thld"],
+                        nms_size=self.config["network"]["det"]["nms_size"],
+                        eof_size=self.config["network"]["det"]["eof_size"],
+                        edge_thld=self.config["network"]["det"]["edge_thld"],
                     )
                     indices1, scores1 = extract_kpts(
                         score_map1.permute(0, 3, 1, 2),
-                        k=self.config['network']['det']['kpt_n'],
-                        score_thld=self.config['network']['det']['score_thld'],
-                        nms_size=self.config['network']['det']['nms_size'],
-                        eof_size=self.config['network']['det']['eof_size'],
-                        edge_thld=self.config['network']['det']['edge_thld']
+                        k=self.config["network"]["det"]["kpt_n"],
+                        score_thld=self.config["network"]["det"]["score_thld"],
+                        nms_size=self.config["network"]["det"]["nms_size"],
+                        eof_size=self.config["network"]["det"]["eof_size"],
+                        edge_thld=self.config["network"]["det"]["edge_thld"],
                     )
 
-                    if self.config['network']['input_type'] == 'raw':
-                        kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0])
-                        kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0])
+                    if self.config["network"]["input_type"] == "raw":
+                        kpt_img0 = self.showKeyPoints(
+                            img0_ori[0][..., :3] * 255.0, indices0[0]
+                        )
+                        kpt_img1 = self.showKeyPoints(
+                            img1_ori[0][..., :3] * 255.0, indices1[0]
+                        )
                     else:
-                        kpt_img0 = self.showKeyPoints(img0_ori[0] * 255., indices0[0])
-                        kpt_img1 = self.showKeyPoints(img1_ori[0] * 255., indices1[0])
+                        kpt_img0 = self.showKeyPoints(img0_ori[0] * 255.0, indices0[0])
+                        kpt_img1 = self.showKeyPoints(img1_ori[0] * 255.0, indices1[0])
 
-                    self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC')
-                    self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC')
-                    self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC')
+                    self.writer.add_image(
+                        "img0/kpts", kpt_img0, self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img1/kpts", kpt_img1, self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img0/score_map", score_map0[0], self.cnt, dataformats="HWC"
+                    )
+                    self.writer.add_image(
+                        "img1/score_map", score_map1[0], self.cnt, dataformats="HWC"
+                    )
 
                 if self.cnt % 10000 == 0:
                     self.save(self.cnt)
-                
-                self.cnt += 1
 
+                self.cnt += 1
 
     def showKeyPoints(self, img, indices):
         key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1])
-        img = img.numpy().astype('uint8')
+        img = img.numpy().astype("uint8")
         img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0))
         return img
 
-
     def preprocess(self, img, iter_idx):
-        if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']:
+        if (
+            not self.config["network"]["noise"]
+            and "raw" not in self.config["network"]["input_type"]
+        ):
             return img
 
         raw = self.noise_maker.rgb2raw(img, batched=True)
 
-        if self.config['network']['noise']:
-            ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
+        if self.config["network"]["noise"]:
+            ratio_dec = (
+                min(self.config["network"]["noise_maxstep"], iter_idx)
+                / self.config["network"]["noise_maxstep"]
+            )
             raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
 
-        if self.config['network']['input_type'] == 'raw':
+        if self.config["network"]["input_type"] == "raw":
             return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))
 
-        if self.config['network']['input_type'] == 'raw-demosaic':
+        if self.config["network"]["input_type"] == "raw-demosaic":
             return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))
 
         rgb = self.noise_maker.raw2rgb(raw, batched=True)
-        if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+        if (
+            self.config["network"]["input_type"] == "rgb"
+            or self.config["network"]["input_type"] == "gray"
+        ):
             return torch.tensor(rgb)
 
         raise NotImplementedError()
 
-
     def preprocess_noise_pair(self, img, iter_idx):
-        assert self.config['network']['noise']
+        assert self.config["network"]["noise"]
 
         raw = self.noise_maker.rgb2raw(img, batched=True)
 
-        ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
-        noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
+        ratio_dec = (
+            min(self.config["network"]["noise_maxstep"], iter_idx)
+            / self.config["network"]["noise_maxstep"]
+        )
+        noise_raw = self.noise_maker.raw2noisyRaw(
+            raw, ratio_dec=ratio_dec, batched=True
+        )
 
-        if self.config['network']['input_type'] == 'raw':
-            return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \
-                   torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True))
+        if self.config["network"]["input_type"] == "raw":
+            return torch.tensor(
+                self.noise_maker.raw2packedRaw(raw, batched=True)
+            ), torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True))
 
-        if self.config['network']['input_type'] == 'raw-demosaic':
-            return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \
-                   torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True))
+        if self.config["network"]["input_type"] == "raw-demosaic":
+            return torch.tensor(
+                self.noise_maker.raw2demosaicRaw(raw, batched=True)
+            ), torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True))
 
         noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True)
-        if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+        if (
+            self.config["network"]["input_type"] == "rgb"
+            or self.config["network"]["input_type"] == "gray"
+        ):
             return img, torch.tensor(noise_rgb)
 
         raise NotImplementedError()
diff --git a/third_party/DarkFeat/utils/matching.py b/third_party/DarkFeat/utils/matching.py
index ca091f418bb4dc4d278611e5126a930aa51e7f3f..78c2415cf54ec3942c94ded3afec381ba63b358a 100644
--- a/third_party/DarkFeat/utils/matching.py
+++ b/third_party/DarkFeat/utils/matching.py
@@ -2,24 +2,26 @@ import math
 import numpy as np
 import cv2
 
+
 def extract_ORB_keypoints_and_descriptors(img):
     # gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
     detector = cv2.ORB_create(nfeatures=1000)
     kp, desc = detector.detectAndCompute(img, None)
     return kp, desc
 
+
 def match_descriptors_NG(kp1, desc1, kp2, desc2):
     bf = cv2.BFMatcher()
     try:
-        matches = bf.knnMatch(desc1, desc2,k=2)
+        matches = bf.knnMatch(desc1, desc2, k=2)
     except:
         matches = []
-    good_matches=[]
+    good_matches = []
     image1_kp = []
     image2_kp = []
     ratios = []
     try:
-        for (m1,m2) in matches:
+        for (m1, m2) in matches:
             if m1.distance < 0.8 * m2.distance:
                 good_matches.append(m1)
                 image2_kp.append(kp2[m1.trainIdx].pt)
@@ -33,41 +35,42 @@ def match_descriptors_NG(kp1, desc1, kp2, desc2):
     ratios = np.expand_dims(ratios, 2)
     return image1_kp, image2_kp, good_matches, ratios
 
+
 def match_descriptors(kp1, desc1, kp2, desc2, ORB):
     if ORB:
         bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
         try:
-            matches = bf.match(desc1,desc2)
-            matches = sorted(matches, key = lambda x:x.distance)
+            matches = bf.match(desc1, desc2)
+            matches = sorted(matches, key=lambda x: x.distance)
         except:
             matches = []
-        good_matches=[]
+        good_matches = []
         image1_kp = []
         image2_kp = []
         count = 0
         try:
             for m in matches:
-                count+=1
+                count += 1
                 if count < 1000:
                     good_matches.append(m)
                     image2_kp.append(kp2[m.trainIdx].pt)
-                    image1_kp.append(kp1[m.queryIdx].pt)  
+                    image1_kp.append(kp1[m.queryIdx].pt)
         except:
             pass
     else:
         # Match the keypoints with the warped_keypoints with nearest neighbor search
         bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
         try:
-            matches = bf.match(desc1.transpose(1,0), desc2.transpose(1,0))  
-            matches = sorted(matches, key = lambda x:x.distance)
+            matches = bf.match(desc1.transpose(1, 0), desc2.transpose(1, 0))
+            matches = sorted(matches, key=lambda x: x.distance)
         except:
             matches = []
-        good_matches=[]
+        good_matches = []
         image1_kp = []
         image2_kp = []
         try:
             for m in matches:
-                good_matches.append(m)              
+                good_matches.append(m)
                 image2_kp.append(kp2[m.trainIdx].pt)
                 image1_kp.append(kp1[m.queryIdx].pt)
         except:
@@ -79,18 +82,28 @@ def match_descriptors(kp1, desc1, kp2, desc2, ORB):
 
 
 def compute_essential(matched_kp1, matched_kp2, K):
-    pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
-    pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
+    pts1 = cv2.undistortPoints(
+        matched_kp1,
+        cameraMatrix=K,
+        distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0),
+    )
+    pts2 = cv2.undistortPoints(
+        matched_kp2,
+        cameraMatrix=K,
+        distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0),
+    )
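+    # undistortPoints already returns normalized image coordinates, so an identity intrinsic matrix is used below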
     K_1 = np.eye(3)
-    # Estimate the homography between the matches using RANSAC
+    # Estimate the essential matrix between the matches using RANSAC
-    ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.FM_RANSAC, prob=0.999, threshold=0.001)
-    if ransac_inliers is None or ransac_model.shape != (3,3):
+    ransac_model, ransac_inliers = cv2.findEssentialMat(
+        pts1, pts2, K_1, method=cv2.FM_RANSAC, prob=0.999, threshold=0.001
+    )
+    if ransac_inliers is None or ransac_model.shape != (3, 3):
         ransac_inliers = np.array([])
         ransac_model = None
     return ransac_model, ransac_inliers, pts1, pts2
 
 
-def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers):
+def compute_error(R_GT, t_GT, E, pts1_norm, pts2_norm, inliers):
     """Compute the angular error between two rotation matrices and two translation vectors.
     Keyword arguments:
     R -- 2D numpy array containing an estimated rotation
@@ -101,14 +114,14 @@ def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers):
 
     inliers = inliers.ravel()
     R = np.eye(3)
-    t = np.zeros((3,1))
+    t = np.zeros((3, 1))
     sst = True
     try:
         cv2.recoverPose(E, pts1_norm, pts2_norm, np.eye(3), R, t, inliers)
     except:
         sst = False
     # calculate angle between provided rotations
-    # 
+    #
     if sst:
         dR = np.matmul(R, np.transpose(R_GT))
         dR = cv2.Rodrigues(dR)[0]
@@ -119,10 +132,10 @@ def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers):
         dT /= float(np.linalg.norm(t_GT))
 
         if dT > 1 or dT < -1:
-            print("Domain warning! dT:",dT)
-            dT = max(-1,min(1,dT))
+            print("Domain warning! dT:", dT)
+            dT = max(-1, min(1, dT))
         dT = math.acos(dT) * 180 / math.pi
-        dT = np.minimum(dT, 180 - dT) # ambiguity of E estimation
+        dT = np.minimum(dT, 180 - dT)  # ambiguity of E estimation
     else:
-        dR,dT = 180.0, 180.0
+        dR, dT = 180.0, 180.0
     return dR, dT
diff --git a/third_party/DarkFeat/utils/misc.py b/third_party/DarkFeat/utils/misc.py
index 1df6fdec97121486dbb94e0b32a2f66c85c48f7d..7d5ac3c8be8f8aacaaf4ec59f19b3278b963f572 100644
--- a/third_party/DarkFeat/utils/misc.py
+++ b/third_party/DarkFeat/utils/misc.py
@@ -9,7 +9,7 @@ import colour_demosaicing
 
 
 class AverageTimer:
-    """ Class to help manage printing simple timing of code execution. """
+    """Class to help manage printing simple timing of code execution."""
 
     def __init__(self, smoothing=0.3, newline=False):
         self.smoothing = smoothing
@@ -25,7 +25,7 @@ class AverageTimer:
         for name in self.will_print:
             self.will_print[name] = False
 
-    def update(self, name='default'):
+    def update(self, name="default"):
         now = time.time()
         dt = now - self.last_time
         if name in self.times:
@@ -34,19 +34,19 @@ class AverageTimer:
         self.will_print[name] = True
         self.last_time = now
 
-    def print(self, text='Timer'):
-        total = 0.
-        print('[{}]'.format(text), end=' ')
+    def print(self, text="Timer"):
+        total = 0.0
+        print("[{}]".format(text), end=" ")
         for key in self.times:
             val = self.times[key]
             if self.will_print[key]:
-                print('%s=%.3f' % (key, val), end=' ')
+                print("%s=%.3f" % (key, val), end=" ")
                 total += val
-        print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ')
+        print("total=%.3f sec {%.1f FPS}" % (total, 1.0 / total), end=" ")
         if self.newline:
             print(flush=True)
         else:
-            print(end='\r', flush=True)
+            print(end="\r", flush=True)
         self.reset()
 
 
@@ -56,32 +56,36 @@ class VideoStreamer:
         self.resize = resize
         self.i = 0
         if Path(basedir).is_dir():
-            print('==> Processing image directory input: {}'.format(basedir))
+            print("==> Processing image directory input: {}".format(basedir))
             self.listing = list(Path(basedir).glob(image_glob[0]))
             for j in range(1, len(image_glob)):
                 image_path = list(Path(basedir).glob(image_glob[j]))
                 self.listing = self.listing + image_path
             self.listing.sort()
             if len(self.listing) == 0:
-                raise IOError('No images found (maybe bad \'image_glob\' ?)')
+                raise IOError("No images found (maybe bad 'image_glob' ?)")
             self.max_length = len(self.listing)
         else:
-            raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir))
+            raise ValueError('VideoStreamer input "{}" not recognized.'.format(basedir))
 
     def load_image(self, impath):
         raw = rawpy.imread(str(impath)).raw_image_visible
-        raw = np.clip(raw.astype('float32') - 512, 0, 65535)
-        img = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, 'RGGB').astype('float32')
+        raw = np.clip(raw.astype("float32") - 512, 0, 65535)
+        img = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, "RGGB").astype(
+            "float32"
+        )
         img = np.clip(img, 0, 16383)
 
         m = img.mean()
         d = np.abs(img - img.mean()).mean()
-        img = (img - m + 2*d) / 4/d * 255
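+        # contrast stretch: map [mean - 2*MAD, mean + 2*MAD] onto [0, 255]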
+        img = (img - m + 2 * d) / 4 / d * 255
         image = np.clip(img, 0, 255)
 
         w_new, h_new = self.resize[0], self.resize[1]
 
-        im = cv2.resize(image.astype('float32'), (w_new, h_new), interpolation=cv2.INTER_AREA)
+        im = cv2.resize(
+            image.astype("float32"), (w_new, h_new), interpolation=cv2.INTER_AREA
+        )
         return im
 
     def next_frame(self):
@@ -95,57 +99,103 @@ class VideoStreamer:
 
 def frame2tensor(frame, device):
     if len(frame.shape) == 2:
-        return torch.from_numpy(frame/255.).float()[None, None].to(device)
+        return torch.from_numpy(frame / 255.0).float()[None, None].to(device)
     else:
-        return torch.from_numpy(frame/255.).float().permute(2, 0, 1)[None].to(device)
-
-
-def make_matching_plot_fast(image0, image1, mkpts0, mkpts1,
-                            color, text, path=None, margin=10,
-                            opencv_display=False, opencv_title='',
-                            small_text=[]):
+        return torch.from_numpy(frame / 255.0).float().permute(2, 0, 1)[None].to(device)
+
+
+def make_matching_plot_fast(
+    image0,
+    image1,
+    mkpts0,
+    mkpts1,
+    color,
+    text,
+    path=None,
+    margin=10,
+    opencv_display=False,
+    opencv_title="",
+    small_text=[],
+):
     H0, W0 = image0.shape[:2]
     H1, W1 = image1.shape[:2]
     H, W = max(H0, H1), W0 + W1 + margin
 
-    out = 255*np.ones((H, W, 3), np.uint8)
+    out = 255 * np.ones((H, W, 3), np.uint8)
     out[:H0, :W0, :] = image0
-    out[:H1, W0+margin:, :] = image1
+    out[:H1, W0 + margin :, :] = image1
 
     # Scale factor for consistent visualization across scales.
-    sc = min(H / 640., 2.0)
+    sc = min(H / 640.0, 2.0)
 
     # Big text.
     Ht = int(30 * sc)  # text height
     txt_color_fg = (255, 255, 255)
     txt_color_bg = (0, 0, 0)
-    
+
     for i, t in enumerate(text):
-        cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
-                    1.0*sc, txt_color_bg, 2, cv2.LINE_AA)
-        cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
-                    1.0*sc, txt_color_fg, 1, cv2.LINE_AA)
+        cv2.putText(
+            out,
+            t,
+            (int(8 * sc), Ht * (i + 1)),
+            cv2.FONT_HERSHEY_DUPLEX,
+            1.0 * sc,
+            txt_color_bg,
+            2,
+            cv2.LINE_AA,
+        )
+        cv2.putText(
+            out,
+            t,
+            (int(8 * sc), Ht * (i + 1)),
+            cv2.FONT_HERSHEY_DUPLEX,
+            1.0 * sc,
+            txt_color_fg,
+            1,
+            cv2.LINE_AA,
+        )
 
     out_backup = out.copy()
 
     mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
-    color = (np.array(color[:, :3])*255).astype(int)[:, ::-1]
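+    # match colors arrive as RGB floats in [0, 1]; convert to 0-255 BGR ints for OpenCV drawing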
+    color = (np.array(color[:, :3]) * 255).astype(int)[:, ::-1]
     for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color):
         c = c.tolist()
-        cv2.line(out, (x0, y0), (x1 + margin + W0, y1),
-                 color=c, thickness=1, lineType=cv2.LINE_AA)
+        cv2.line(
+            out,
+            (x0, y0),
+            (x1 + margin + W0, y1),
+            color=c,
+            thickness=1,
+            lineType=cv2.LINE_AA,
+        )
         # display line end-points as circles
         cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA)
-        cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1,
-                   lineType=cv2.LINE_AA)
+        cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1, lineType=cv2.LINE_AA)
 
     # Small text.
     Ht = int(18 * sc)  # text height
     for i, t in enumerate(reversed(small_text)):
-        cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
-                    0.5*sc, txt_color_bg, 2, cv2.LINE_AA)
-        cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
-                    0.5*sc, txt_color_fg, 1, cv2.LINE_AA)
+        cv2.putText(
+            out,
+            t,
+            (int(8 * sc), int(H - Ht * (i + 0.6))),
+            cv2.FONT_HERSHEY_DUPLEX,
+            0.5 * sc,
+            txt_color_bg,
+            2,
+            cv2.LINE_AA,
+        )
+        cv2.putText(
+            out,
+            t,
+            (int(8 * sc), int(H - Ht * (i + 0.6))),
+            cv2.FONT_HERSHEY_DUPLEX,
+            0.5 * sc,
+            txt_color_fg,
+            1,
+            cv2.LINE_AA,
+        )
 
     if path is not None:
         cv2.imwrite(str(path), out)
@@ -153,6 +203,5 @@ def make_matching_plot_fast(image0, image1, mkpts0, mkpts1,
     if opencv_display:
         cv2.imshow(opencv_title, out)
         cv2.waitKey(1)
-        
-    return out / 2 + out_backup / 2
 
+    return out / 2 + out_backup / 2
diff --git a/third_party/DarkFeat/utils/nn.py b/third_party/DarkFeat/utils/nn.py
index 8a80631d6e12d848cceee3b636baf49deaa7647a..956256aeae1b83700044f8f2df18f8913348ebe7 100644
--- a/third_party/DarkFeat/utils/nn.py
+++ b/third_party/DarkFeat/utils/nn.py
@@ -7,8 +7,8 @@ class NN2(nn.Module):
         super().__init__()
 
     def forward(self, data):
-        desc1, desc2 = data['descriptors0'].cuda(), data['descriptors1'].cuda()
-        kpts1, kpts2 = data['keypoints0'].cuda(), data['keypoints1'].cuda()
+        desc1, desc2 = data["descriptors0"].cuda(), data["descriptors1"].cuda()
+        kpts1, kpts2 = data["keypoints0"].cuda(), data["keypoints1"].cuda()
 
         # torch.cuda.synchronize()
         # t = time.time()
@@ -16,10 +16,10 @@ class NN2(nn.Module):
         if kpts1.shape[1] <= 1 or kpts2.shape[1] <= 1:  # no keypoints
             shape0, shape1 = kpts1.shape[:-1], kpts2.shape[:-1]
             return {
-                'matches0': kpts1.new_full(shape0, -1, dtype=torch.int),
-                'matches1': kpts2.new_full(shape1, -1, dtype=torch.int),
-                'matching_scores0': kpts1.new_zeros(shape0),
-                'matching_scores1': kpts2.new_zeros(shape1),
+                "matches0": kpts1.new_full(shape0, -1, dtype=torch.int),
+                "matches1": kpts2.new_full(shape1, -1, dtype=torch.int),
+                "matching_scores0": kpts1.new_zeros(shape0),
+                "matching_scores1": kpts2.new_zeros(shape1),
             }
 
         sim = torch.matmul(desc1.squeeze().T, desc2.squeeze())
@@ -28,14 +28,16 @@ class NN2(nn.Module):
 
         nn21 = torch.argmax(sim, dim=0)
         mask = torch.eq(ids1, nn21[nn12])
-        matches = torch.stack([torch.masked_select(ids1, mask), torch.masked_select(nn12, mask)])
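+        # keep only mutual nearest neighbours: i matches j only if j's best match is i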
+        matches = torch.stack(
+            [torch.masked_select(ids1, mask), torch.masked_select(nn12, mask)]
+        )
         # matches = torch.stack([ids1, nn12])
         indices0 = torch.ones((1, desc1.shape[-1]), dtype=int) * -1
         mscores0 = torch.ones((1, desc1.shape[-1]), dtype=float) * -1
 
         # torch.cuda.synchronize()
         # print(time.time() - t)
-            
+
         matches_0 = matches[0].cpu().int().numpy()
         matches_1 = matches[1].cpu().int()
         for i in range(matches.shape[-1]):
@@ -43,8 +45,8 @@ class NN2(nn.Module):
             mscores0[0, matches_0[i]] = sim[matches_0[i], matches_1[i]]
 
         return {
-            'matches0': indices0, # use -1 for invalid match
-            'matches1': indices0, # use -1 for invalid match
-            'matching_scores0': mscores0,
-            'matching_scores1': mscores0,
+            "matches0": indices0,  # use -1 for invalid match
+            "matches1": indices0,  # use -1 for invalid match
+            "matching_scores0": mscores0,
+            "matching_scores1": mscores0,
         }
diff --git a/third_party/DarkFeat/utils/nnmatching.py b/third_party/DarkFeat/utils/nnmatching.py
index 7be6f98c050fc2e416ef48e25ca0f293106c1082..6289623c28989dc73dfbeb1763228f301c62831b 100644
--- a/third_party/DarkFeat/utils/nnmatching.py
+++ b/third_party/DarkFeat/utils/nnmatching.py
@@ -3,28 +3,28 @@ import torch
 from .nn import NN2
 from darkfeat import DarkFeat
 
+
 class NNMatching(torch.nn.Module):
-    def __init__(self, model_path=''):
+    def __init__(self, model_path=""):
         super().__init__()
         self.nn = NN2().eval()
         self.darkfeat = DarkFeat(model_path).eval()
 
     def forward(self, data):
-        """ Run DarkFeat and nearest neighborhood matching
+        """Run DarkFeat and nearest neighborhood matching
         Args:
           data: dictionary with minimal keys: ['image0', 'image1']
         """
         pred = {}
 
         # Extract DarkFeat (keypoints, scores, descriptors)
-        if 'keypoints0' not in data:
-            pred0 = self.darkfeat({'image': data['image0']})
+        if "keypoints0" not in data:
+            pred0 = self.darkfeat({"image": data["image0"]})
             # print({k+'0': v[0].shape for k, v in pred0.items()})
-            pred = {**pred, **{k+'0': [v] for k, v in pred0.items()}}
-        if 'keypoints1' not in data:
-            pred1 = self.darkfeat({'image': data['image1']})
-            pred = {**pred, **{k+'1': [v] for k, v in pred1.items()}}
-        
+            pred = {**pred, **{k + "0": [v] for k, v in pred0.items()}}
+        if "keypoints1" not in data:
+            pred1 = self.darkfeat({"image": data["image1"]})
+            pred = {**pred, **{k + "1": [v] for k, v in pred1.items()}}
 
         # Batch all features
         # We should either have i) one image per batch, or
diff --git a/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py b/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py
index 52113027f2e7ddc144453df9f012f84d3b4ba95b..f428121d175af9f9786cfa9cf9c340b94a170521 100644
--- a/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py
+++ b/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py
@@ -1,3 +1,3 @@
 from .num_inliers import NumInliersBenchmark
 from .mega_pose_est import MegaDepthPoseEstimationBenchmark
-from .mega_pose_est_mnn import MegaDepthPoseMNNBenchmark
\ No newline at end of file
+from .mega_pose_est_mnn import MegaDepthPoseMNNBenchmark
diff --git a/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py b/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py
index 2104284b54d5fe339d6f12d9ae14dcdd3c0fb564..66292fe5a6efbdf328e5f27d806479616455cff7 100644
--- a/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py
+++ b/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py
@@ -5,8 +5,9 @@ from PIL import Image
 from tqdm import tqdm
 import torch.nn.functional as F
 
+
 class MegaDepthPoseEstimationBenchmark:
-    def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+    def __init__(self, data_root="data/megadepth", scene_names=None) -> None:
         if scene_names is None:
             self.scene_names = [
                 "0015_0.1_0.3.npz",
@@ -23,14 +24,23 @@ class MegaDepthPoseEstimationBenchmark:
         ]
         self.data_root = data_root
 
-    def benchmark(self, keypoint_model, matching_model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True):
-        H,W = matching_model.get_output_resolution()
+    def benchmark(
+        self,
+        keypoint_model,
+        matching_model,
+        model_name=None,
+        resolution=None,
+        scale_intrinsics=True,
+        calibrated=True,
+    ):
+        H, W = matching_model.get_output_resolution()
         with torch.no_grad():
             data_root = self.data_root
             tot_e_t, tot_e_R, tot_e_pose = [], [], []
             thresholds = [5, 10, 20]
             for scene_ind in range(len(self.scenes)):
                 import os
+
                 scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
                 scene = self.scenes[scene_ind]
                 pairs = scene["pair_infos"]
@@ -47,14 +57,20 @@ class MegaDepthPoseEstimationBenchmark:
                     T2 = poses[idx2].copy()
                     R2, t2 = T2[:3, :3], T2[:3, 3]
                     R, t = compute_relative_pose(R1, t1, R2, t2)
-                    T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+                    T1_to_2 = np.concatenate((R, t[:, None]), axis=-1)
                     im_A_path = f"{data_root}/{im_paths[idx1]}"
                     im_B_path = f"{data_root}/{im_paths[idx2]}"
-                    
-                    keypoints_A = keypoint_model.detect_from_path(im_A_path, num_keypoints = 20_000)["keypoints"][0]
-                    keypoints_B = keypoint_model.detect_from_path(im_B_path, num_keypoints = 20_000)["keypoints"][0]
+
+                    keypoints_A = keypoint_model.detect_from_path(
+                        im_A_path, num_keypoints=20_000
+                    )["keypoints"][0]
+                    keypoints_B = keypoint_model.detect_from_path(
+                        im_B_path, num_keypoints=20_000
+                    )["keypoints"][0]
                     warp, certainty = matching_model.match(im_A_path, im_B_path)
-                    matches = matching_model.match_keypoints(keypoints_A, keypoints_B, warp, certainty, return_tuple = False)                    
+                    matches = matching_model.match_keypoints(
+                        keypoints_A, keypoints_B, warp, certainty, return_tuple=False
+                    )
                     im_A = Image.open(im_A_path)
                     w1, h1 = im_A.size
                     im_B = Image.open(im_B_path)
@@ -67,15 +83,20 @@ class MegaDepthPoseEstimationBenchmark:
                         K1, K2 = K1.copy(), K2.copy()
                         K1[:2] = K1[:2] * scale1
                         K2[:2] = K2[:2] * scale2
-                    kpts1, kpts2 = matching_model.to_pixel_coordinates(matches, h1, w1, h2, w2)
+                    kpts1, kpts2 = matching_model.to_pixel_coordinates(
+                        matches, h1, w1, h2, w2
+                    )
                     for _ in range(1):
                         shuffling = np.random.permutation(np.arange(len(kpts1)))
                         kpts1 = kpts1[shuffling]
                         kpts2 = kpts2[shuffling]
                         try:
-                            threshold = 0.5 
+                            threshold = 0.5
                             if calibrated:
-                                norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                                norm_threshold = threshold / (
+                                    np.mean(np.abs(K1[:2, :2]))
+                                    + np.mean(np.abs(K2[:2, :2]))
+                                )
                                 R_est, t_est, mask = estimate_pose(
                                     kpts1.cpu().numpy(),
                                     kpts2.cpu().numpy(),
@@ -111,4 +132,4 @@ class MegaDepthPoseEstimationBenchmark:
                 "map_5": map_5,
                 "map_10": map_10,
                 "map_20": map_20,
-            }
\ No newline at end of file
+            }
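
# A small illustration of the threshold normalization used in the calibrated branch above:
# the 0.5 px threshold is divided by an average focal scale taken from both intrinsic
# matrices, so the downstream estimate_pose call can operate in normalized camera
# coordinates. The intrinsics below are made-up values.
import numpy as np

K1 = np.array([[1200.0, 0.0, 320.0], [0.0, 1200.0, 240.0], [0.0, 0.0, 1.0]])
K2 = np.array([[800.0, 0.0, 320.0], [0.0, 800.0, 240.0], [0.0, 0.0, 1.0]])

threshold_px = 0.5
norm_threshold = threshold_px / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
print(norm_threshold)  # 0.5 / (600 + 400) = 5e-4
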
diff --git a/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py b/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py
index 15f4cdea05c601173fab765b92d5379e8f0bb349..e979bddfb09ff8760d83442b284662376a074998 100644
--- a/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py
+++ b/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py
@@ -5,8 +5,9 @@ from PIL import Image
 from tqdm import tqdm
 import torch.nn.functional as F
 
+
 class MegaDepthPoseMNNBenchmark:
-    def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+    def __init__(self, data_root="data/megadepth", scene_names=None) -> None:
         if scene_names is None:
             self.scene_names = [
                 "0015_0.1_0.3.npz",
@@ -23,13 +24,23 @@ class MegaDepthPoseMNNBenchmark:
         ]
         self.data_root = data_root
 
-    def benchmark(self, detector_model, descriptor_model, matcher_model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True):
+    def benchmark(
+        self,
+        detector_model,
+        descriptor_model,
+        matcher_model,
+        model_name=None,
+        resolution=None,
+        scale_intrinsics=True,
+        calibrated=True,
+    ):
         with torch.no_grad():
             data_root = self.data_root
             tot_e_t, tot_e_R, tot_e_pose = [], [], []
             thresholds = [5, 10, 20]
             for scene_ind in range(len(self.scenes)):
                 import os
+
                 scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
                 scene = self.scenes[scene_ind]
                 pairs = scene["pair_infos"]
@@ -46,19 +57,36 @@ class MegaDepthPoseMNNBenchmark:
                     T2 = poses[idx2].copy()
                     R2, t2 = T2[:3, :3], T2[:3, 3]
                     R, t = compute_relative_pose(R1, t1, R2, t2)
-                    T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+                    T1_to_2 = np.concatenate((R, t[:, None]), axis=-1)
                     im_A_path = f"{data_root}/{im_paths[idx1]}"
                     im_B_path = f"{data_root}/{im_paths[idx2]}"
                     detections_A = detector_model.detect_from_path(im_A_path)
-                    keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"]
+                    keypoints_A, P_A = (
+                        detections_A["keypoints"],
+                        detections_A["confidence"],
+                    )
                     detections_B = detector_model.detect_from_path(im_B_path)
-                    keypoints_B, P_B = detections_B["keypoints"], detections_B["confidence"]
-                    description_A = descriptor_model.describe_keypoints_from_path(im_A_path, keypoints_A)["descriptions"]
-                    description_B = descriptor_model.describe_keypoints_from_path(im_B_path, keypoints_B)["descriptions"]
-                    matches_A, matches_B, batch_ids = matcher_model.match(keypoints_A, description_A,
-                        keypoints_B, description_B,
-                        P_A = P_A, P_B = P_B,
-                        normalize = True, inv_temp=20, threshold = 0.01)
+                    keypoints_B, P_B = (
+                        detections_B["keypoints"],
+                        detections_B["confidence"],
+                    )
+                    description_A = descriptor_model.describe_keypoints_from_path(
+                        im_A_path, keypoints_A
+                    )["descriptions"]
+                    description_B = descriptor_model.describe_keypoints_from_path(
+                        im_B_path, keypoints_B
+                    )["descriptions"]
+                    matches_A, matches_B, batch_ids = matcher_model.match(
+                        keypoints_A,
+                        description_A,
+                        keypoints_B,
+                        description_B,
+                        P_A=P_A,
+                        P_B=P_B,
+                        normalize=True,
+                        inv_temp=20,
+                        threshold=0.01,
+                    )
 
                     im_A = Image.open(im_A_path)
                     w1, h1 = im_A.size
@@ -72,15 +100,20 @@ class MegaDepthPoseMNNBenchmark:
                         K1, K2 = K1.copy(), K2.copy()
                         K1[:2] = K1[:2] * scale1
                         K2[:2] = K2[:2] * scale2
-                    kpts1, kpts2 = matcher_model.to_pixel_coords(matches_A, matches_B, h1, w1, h2, w2)
+                    kpts1, kpts2 = matcher_model.to_pixel_coords(
+                        matches_A, matches_B, h1, w1, h2, w2
+                    )
                     for _ in range(1):
                         shuffling = np.random.permutation(np.arange(len(kpts1)))
                         kpts1 = kpts1[shuffling]
                         kpts2 = kpts2[shuffling]
                         try:
-                            threshold = 0.5 
+                            threshold = 0.5
                             if calibrated:
-                                norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                                norm_threshold = threshold / (
+                                    np.mean(np.abs(K1[:2, :2]))
+                                    + np.mean(np.abs(K2[:2, :2]))
+                                )
                                 R_est, t_est, mask = estimate_pose(
                                     kpts1.cpu().numpy(),
                                     kpts2.cpu().numpy(),
@@ -116,4 +149,4 @@ class MegaDepthPoseMNNBenchmark:
                 "map_5": map_5,
                 "map_10": map_10,
                 "map_20": map_20,
-            }
\ No newline at end of file
+            }
diff --git a/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py b/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py
index 24be32b2bc54f1d650836e5ab2f540e80fd3d5c0..f2b36f6a2b97b9c7010ef2455352531ffe3e4405 100644
--- a/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py
+++ b/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py
@@ -3,39 +3,56 @@ import torch.nn as nn
 from DeDoDe.utils import *
 import DeDoDe
 
+
 class NumInliersBenchmark(nn.Module):
-    
-    def __init__(self, dataset, num_samples = 1000, batch_size = 8, num_keypoints = 10_000, device = "cuda") -> None:
+    def __init__(
+        self,
+        dataset,
+        num_samples=1000,
+        batch_size=8,
+        num_keypoints=10_000,
+        device="cuda",
+    ) -> None:
         super().__init__()
         sampler = torch.utils.data.WeightedRandomSampler(
-                torch.ones(len(dataset)), replacement=False, num_samples=num_samples
-            )
+            torch.ones(len(dataset)), replacement=False, num_samples=num_samples
+        )
         dataloader = torch.utils.data.DataLoader(
-                dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler
-            )
+            dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler
+        )
         self.dataloader = dataloader
         self.tracked_metrics = {}
         self.batch_size = batch_size
         self.N = len(dataloader)
         self.num_keypoints = num_keypoints
-    
-    def compute_batch_metrics(self, outputs, batch, device = "cuda"):
+
+    def compute_batch_metrics(self, outputs, batch, device="cuda"):
         kpts_A, kpts_B = outputs["keypoints_A"], outputs["keypoints_B"]
         B, K, H, W = batch["im_A"].shape
-        gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp(                
-                    batch["im_A_depth"],
-                    batch["im_B_depth"],
-                    batch["T_1to2"],
-                    batch["K1"],
-                    batch["K2"],
-                    H=H,
-                    W=W,
-                )
-        kpts_A_to_B = F.grid_sample(gt_warp_A_to_B[...,2:].float().permute(0,3,1,2), kpts_A[...,None,:], 
-                                    align_corners=False, mode = 'bilinear')[...,0].mT
-        legit_A_to_B = F.grid_sample(valid_mask_A_to_B.reshape(B,1,H,W), kpts_A[...,None,:], 
-                                    align_corners=False, mode = 'bilinear')[...,0,:,0]
-        dists = (torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0.]).float()
+        gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp(
+            batch["im_A_depth"],
+            batch["im_B_depth"],
+            batch["T_1to2"],
+            batch["K1"],
+            batch["K2"],
+            H=H,
+            W=W,
+        )
+        kpts_A_to_B = F.grid_sample(
+            gt_warp_A_to_B[..., 2:].float().permute(0, 3, 1, 2),
+            kpts_A[..., None, :],
+            align_corners=False,
+            mode="bilinear",
+        )[..., 0].mT
+        legit_A_to_B = F.grid_sample(
+            valid_mask_A_to_B.reshape(B, 1, H, W),
+            kpts_A[..., None, :],
+            align_corners=False,
+            mode="bilinear",
+        )[..., 0, :, 0]
+        dists = (
+            torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0.0]
+        ).float()
         if legit_A_to_B.sum() == 0:
             return
         percent_inliers_at_1 = (dists < 0.02).float().mean()
@@ -44,33 +61,65 @@ class NumInliersBenchmark(nn.Module):
         percent_inliers_at_01 = (dists < 0.002).float().mean()
         percent_inliers_at_005 = (dists < 0.001).float().mean()
 
-        inlier_bins = torch.linspace(0, 0.002, steps = 100, device = device)[None]
-        inlier_counts = (dists[...,None] < inlier_bins).float().mean(dim=0)
-        self.tracked_metrics["inlier_counts"] = self.tracked_metrics.get("inlier_counts", 0) + 1/self.N * inlier_counts
-        self.tracked_metrics["percent_inliers_at_1"] = self.tracked_metrics.get("percent_inliers_at_1", 0) + 1/self.N * percent_inliers_at_1
-        self.tracked_metrics["percent_inliers_at_05"] = self.tracked_metrics.get("percent_inliers_at_05", 0) + 1/self.N * percent_inliers_at_05
-        self.tracked_metrics["percent_inliers_at_025"] = self.tracked_metrics.get("percent_inliers_at_025", 0) + 1/self.N * percent_inliers_at_025
-        self.tracked_metrics["percent_inliers_at_01"] = self.tracked_metrics.get("percent_inliers_at_01", 0) + 1/self.N * percent_inliers_at_01
-        self.tracked_metrics["percent_inliers_at_005"] = self.tracked_metrics.get("percent_inliers_at_005", 0) + 1/self.N * percent_inliers_at_005
+        inlier_bins = torch.linspace(0, 0.002, steps=100, device=device)[None]
+        inlier_counts = (dists[..., None] < inlier_bins).float().mean(dim=0)
+        self.tracked_metrics["inlier_counts"] = (
+            self.tracked_metrics.get("inlier_counts", 0) + 1 / self.N * inlier_counts
+        )
+        self.tracked_metrics["percent_inliers_at_1"] = (
+            self.tracked_metrics.get("percent_inliers_at_1", 0)
+            + 1 / self.N * percent_inliers_at_1
+        )
+        self.tracked_metrics["percent_inliers_at_05"] = (
+            self.tracked_metrics.get("percent_inliers_at_05", 0)
+            + 1 / self.N * percent_inliers_at_05
+        )
+        self.tracked_metrics["percent_inliers_at_025"] = (
+            self.tracked_metrics.get("percent_inliers_at_025", 0)
+            + 1 / self.N * percent_inliers_at_025
+        )
+        self.tracked_metrics["percent_inliers_at_01"] = (
+            self.tracked_metrics.get("percent_inliers_at_01", 0)
+            + 1 / self.N * percent_inliers_at_01
+        )
+        self.tracked_metrics["percent_inliers_at_005"] = (
+            self.tracked_metrics.get("percent_inliers_at_005", 0)
+            + 1 / self.N * percent_inliers_at_005
+        )
 
     def benchmark(self, detector):
         self.tracked_metrics = {}
         from tqdm import tqdm
+
         print("Evaluating percent inliers...")
-        for idx, batch in tqdm(enumerate(self.dataloader), mininterval = 10.):
+        for idx, batch in tqdm(enumerate(self.dataloader), mininterval=10.0):
             batch = to_cuda(batch)
-            outputs = detector.detect(batch, num_keypoints = self.num_keypoints)
-            keypoints_A, keypoints_B = outputs["keypoints"][:self.batch_size], outputs["keypoints"][self.batch_size:] 
+            outputs = detector.detect(batch, num_keypoints=self.num_keypoints)
+            keypoints_A, keypoints_B = (
+                outputs["keypoints"][: self.batch_size],
+                outputs["keypoints"][self.batch_size :],
+            )
             if isinstance(outputs["keypoints"], (tuple, list)):
-                keypoints_A, keypoints_B = torch.stack(keypoints_A), torch.stack(keypoints_B)
+                keypoints_A, keypoints_B = torch.stack(keypoints_A), torch.stack(
+                    keypoints_B
+                )
             outputs = {"keypoints_A": keypoints_A, "keypoints_B": keypoints_B}
             self.compute_batch_metrics(outputs, batch)
         import matplotlib.pyplot as plt
-        plt.plot(torch.linspace(0, 0.002, steps = 100), self.tracked_metrics["inlier_counts"].cpu())
+
+        plt.plot(
+            torch.linspace(0, 0.002, steps=100),
+            self.tracked_metrics["inlier_counts"].cpu(),
+        )
         import numpy as np
-        x = np.linspace(0,0.002, 100)
+
+        x = np.linspace(0, 0.002, 100)
         sigma = 0.52 * 2 / 512
-        F = 1 - np.exp(-x**2 / (2*sigma**2))
+        F = 1 - np.exp(-(x**2) / (2 * sigma**2))
         plt.plot(x, F)
         plt.savefig("vis/inlier_counts")
-        [print(name, metric.item() * self.N / (idx+1)) for name, metric in self.tracked_metrics.items() if "percent" in name]
\ No newline at end of file
+        [
+            print(name, metric.item() * self.N / (idx + 1))
+            for name, metric in self.tracked_metrics.items()
+            if "percent" in name
+        ]
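
# A reduced sketch of the inlier-rate computation in compute_batch_metrics above, assuming
# kpts_A_to_B (A-keypoints warped into image B) and kpts_B are already given as (B, N, 2)
# tensors in normalized [-1, 1] coordinates; the threshold mirrors the 0.02 ... 0.001 ladder
# used by the benchmark. Names here are illustrative.
import torch

def percent_inliers(kpts_A_to_B: torch.Tensor, kpts_B: torch.Tensor, thr: float = 0.01) -> torch.Tensor:
    # distance from every warped A-keypoint to its nearest B-keypoint
    dists = torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values  # (B, N)
    return (dists < thr).float().mean()

kpts = torch.rand(2, 500, 2) * 2 - 1
print(percent_inliers(kpts, kpts + 1e-3 * torch.randn_like(kpts)))  # close to 1.0 for tiny noise
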
diff --git a/third_party/DeDoDe/DeDoDe/checkpoint.py b/third_party/DeDoDe/DeDoDe/checkpoint.py
index 07d6f80ae09acf5702475504a8e8d61f40c21cd3..6429ca8b6999a133455bb9e271618f50be4a0ed8 100644
--- a/third_party/DeDoDe/DeDoDe/checkpoint.py
+++ b/third_party/DeDoDe/DeDoDe/checkpoint.py
@@ -6,6 +6,7 @@ import gc
 
 import DeDoDe
 
+
 class CheckPoint:
     def __init__(self, dir=None, name="tmp"):
         self.name = name
@@ -18,7 +19,7 @@ class CheckPoint:
         optimizer,
         lr_scheduler,
         n,
-        ):
+    ):
         if DeDoDe.RANK == 0:
             assert model is not None
             if isinstance(model, (DataParallel, DistributedDataParallel)):
@@ -31,14 +32,14 @@ class CheckPoint:
             }
             torch.save(states, self.dir + self.name + f"_latest.pth")
             print(f"Saved states {list(states.keys())}, at step {n}")
-    
+
     def load(
         self,
         model,
         optimizer,
         lr_scheduler,
         n,
-        ):
+    ):
         if os.path.exists(self.dir + self.name + f"_latest.pth") and DeDoDe.RANK == 0:
             states = torch.load(self.dir + self.name + f"_latest.pth")
             if "model" in states:
@@ -56,4 +57,4 @@ class CheckPoint:
             del states
             gc.collect()
             torch.cuda.empty_cache()
-        return model, optimizer, lr_scheduler, n
\ No newline at end of file
+        return model, optimizer, lr_scheduler, n
diff --git a/third_party/DeDoDe/DeDoDe/datasets/megadepth.py b/third_party/DeDoDe/DeDoDe/datasets/megadepth.py
index 7de9d9a8e270fb74a6591944878c0e5e70ddf650..70d76d471c0d0bd5b8545e28ea06a7d178a1abf6 100644
--- a/third_party/DeDoDe/DeDoDe/datasets/megadepth.py
+++ b/third_party/DeDoDe/DeDoDe/datasets/megadepth.py
@@ -10,6 +10,7 @@ from DeDoDe.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
 import DeDoDe
 from DeDoDe.utils import *
 
+
 class MegadepthScene:
     def __init__(
         self,
@@ -23,14 +24,16 @@ class MegadepthScene:
         scene_info_detections=None,
         scene_info_detections3D=None,
         normalize=True,
-        max_num_pairs = 100_000,
-        scene_name = None,
-        use_horizontal_flip_aug = False,
-        grayscale = False,
-        clahe = False,
+        max_num_pairs=100_000,
+        scene_name=None,
+        use_horizontal_flip_aug=False,
+        grayscale=False,
+        clahe=False,
     ) -> None:
         self.data_root = data_root
-        self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
+        self.scene_name = (
+            os.path.splitext(scene_name)[0] + f"_{min_overlap}_{max_overlap}"
+        )
         self.image_paths = scene_info["image_paths"]
         self.depth_paths = scene_info["depth_paths"]
         self.intrinsics = scene_info["intrinsics"]
@@ -49,7 +52,9 @@ class MegadepthScene:
             self.pairs = self.pairs[pairinds]
             self.overlaps = self.overlaps[pairinds]
         self.im_transform_ops = get_tuple_transform_ops(
-            resize=(ht, wt), normalize=normalize, clahe = clahe,
+            resize=(ht, wt),
+            normalize=normalize,
+            clahe=clahe,
         )
         self.depth_transform_ops = get_depth_tuple_transform_ops(
             resize=(ht, wt), normalize=False
@@ -62,17 +67,19 @@ class MegadepthScene:
     def load_im(self, im_B, crop=None):
         im = Image.open(im_B)
         return im
-    
-    def horizontal_flip(self, im_A, im_B, depth_A, depth_B,  K_A, K_B):
+
+    def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
         im_A = im_A.flip(-1)
         im_B = im_B.flip(-1)
-        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) 
-        flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
-        K_A = flip_mat@K_A  
-        K_B = flip_mat@K_B  
-        
+        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
+        flip_mat = torch.tensor([[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.0]]).to(
+            K_A.device
+        )
+        K_A = flip_mat @ K_A
+        K_B = flip_mat @ K_B
+
         return im_A, im_B, depth_A, depth_B, K_A, K_B
-        
+
     def load_depth(self, depth_ref, crop=None):
         depth = np.array(h5py.File(depth_ref, "r")["depth"])
         return torch.from_numpy(depth)
@@ -87,8 +94,8 @@ class MegadepthScene:
 
     def scale_detections(self, detections, wi, hi):
         sx, sy = self.wt / wi, self.ht / hi
-        return detections * torch.tensor([[sx,sy]])
-    
+        return detections * torch.tensor([[sx, sy]])
+
     def rand_shake(self, *things):
         t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=(2))
         return [
@@ -99,18 +106,27 @@ class MegadepthScene:
     def tracks_to_detections(self, tracks3D, pose, intrinsics, H, W):
         tracks3D = tracks3D.double()
         intrinsics = intrinsics.double()
-        bearing_vectors = pose[...,:3,:3] @ tracks3D.mT + pose[...,:3,3:]        
+        bearing_vectors = pose[..., :3, :3] @ tracks3D.mT + pose[..., :3, 3:]
         hom_pixel_coords = (intrinsics @ bearing_vectors).mT
-        pixel_coords = hom_pixel_coords[...,:2] / (hom_pixel_coords[...,2:]+1e-12)
-        legit_detections = (pixel_coords > 0).prod(dim = -1) * (pixel_coords[...,0] < W - 1) * (pixel_coords[...,1] < H - 1) * (tracks3D != 0).prod(dim=-1)
+        pixel_coords = hom_pixel_coords[..., :2] / (hom_pixel_coords[..., 2:] + 1e-12)
+        legit_detections = (
+            (pixel_coords > 0).prod(dim=-1)
+            * (pixel_coords[..., 0] < W - 1)
+            * (pixel_coords[..., 1] < H - 1)
+            * (tracks3D != 0).prod(dim=-1)
+        )
         return pixel_coords.float(), legit_detections.bool()
 
     def __getitem__(self, pair_idx):
         try:
             # read intrinsics of original size
             idx1, idx2 = self.pairs[pair_idx]
-            K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
-            K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
+            K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(
+                3, 3
+            )
+            K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(
+                3, 3
+            )
 
             # read and compute relative poses
             T1 = self.poses[idx1]
@@ -138,19 +154,23 @@ class MegadepthScene:
 
             detections2D_A = self.detections[idx1]
             detections2D_B = self.detections[idx2]
-            
+
             K = 10000
-            tracks3D_A = torch.zeros(K,3)
-            tracks3D_B = torch.zeros(K,3)
-            tracks3D_A[:len(detections2D_A)] = torch.tensor(self.tracks3D[detections2D_A[:K,-1].astype(np.int32)])
-            tracks3D_B[:len(detections2D_B)] = torch.tensor(self.tracks3D[detections2D_B[:K,-1].astype(np.int32)])
-            
-            #projs_A, _ = self.tracks_to_detections(tracks3D_A, T1, K1, W_A, H_A)
-            #tracks3D_B = torch.zeros(K,2)
+            tracks3D_A = torch.zeros(K, 3)
+            tracks3D_B = torch.zeros(K, 3)
+            tracks3D_A[: len(detections2D_A)] = torch.tensor(
+                self.tracks3D[detections2D_A[:K, -1].astype(np.int32)]
+            )
+            tracks3D_B[: len(detections2D_B)] = torch.tensor(
+                self.tracks3D[detections2D_B[:K, -1].astype(np.int32)]
+            )
+
+            # projs_A, _ = self.tracks_to_detections(tracks3D_A, T1, K1, W_A, H_A)
+            # tracks3D_B = torch.zeros(K,2)
 
             K1 = self.scale_intrinsic(K1, W_A, H_A)
             K2 = self.scale_intrinsic(K2, W_B, H_B)
-            
+
             # Process images
             im_A, im_B = self.im_transform_ops((im_A, im_B))
             depth_A, depth_B = self.depth_transform_ops(
@@ -159,34 +179,43 @@ class MegadepthScene:
             [im_A, depth_A], t_A = self.rand_shake(im_A, depth_A)
             [im_B, depth_B], t_B = self.rand_shake(im_B, depth_B)
 
-            detections_A = -torch.ones(K,2)
-            detections_B = -torch.ones(K,2)
-            detections_A[:len(self.detections[idx1])] = self.scale_detections(torch.tensor(detections2D_A[:K,:2]), W_A, H_A) + t_A
-            detections_B[:len(self.detections[idx2])] = self.scale_detections(torch.tensor(detections2D_B[:K,:2]), W_B, H_B) + t_B
+            detections_A = -torch.ones(K, 2)
+            detections_B = -torch.ones(K, 2)
+            detections_A[: len(self.detections[idx1])] = (
+                self.scale_detections(torch.tensor(detections2D_A[:K, :2]), W_A, H_A)
+                + t_A
+            )
+            detections_B[: len(self.detections[idx2])] = (
+                self.scale_detections(torch.tensor(detections2D_B[:K, :2]), W_B, H_B)
+                + t_B
+            )
 
-            
             K1[:2, 2] += t_A
             K2[:2, 2] += t_B
-                    
+
             if self.use_horizontal_flip_aug:
                 if np.random.rand() > 0.5:
-                    im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
-                    detections_A[:,0] = W-detections_A
-                    detections_B[:,0] = W-detections_B
-                    
+                    im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
+                        im_A, im_B, depth_A, depth_B, K1, K2
+                    )
+                    detections_A[:, 0] = W - detections_A
+                    detections_B[:, 0] = W - detections_B
+
             if DeDoDe.DEBUG_MODE:
-                tensor_to_pil(im_A[0], unnormalize=True).save(
-                                f"vis/im_A.jpg")
-                tensor_to_pil(im_B[0], unnormalize=True).save(
-                                f"vis/im_B.jpg")
+                tensor_to_pil(im_A[0], unnormalize=True).save(f"vis/im_A.jpg")
+                tensor_to_pil(im_B[0], unnormalize=True).save(f"vis/im_B.jpg")
             if self.grayscale:
-                im_A = im_A.mean(dim=-3,keepdim=True)
-                im_B = im_B.mean(dim=-3,keepdim=True)
+                im_A = im_A.mean(dim=-3, keepdim=True)
+                im_B = im_B.mean(dim=-3, keepdim=True)
             data_dict = {
                 "im_A": im_A,
-                "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
+                "im_A_identifier": self.image_paths[idx1]
+                .split("/")[-1]
+                .split(".jpg")[0],
                 "im_B": im_B,
-                "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0],
+                "im_B_identifier": self.image_paths[idx2]
+                .split("/")[-1]
+                .split(".jpg")[0],
                 "im_A_depth": depth_A[0, 0],
                 "im_B_depth": depth_B[0, 0],
                 "pose_A": T1,
@@ -211,19 +240,48 @@ class MegadepthScene:
 
 
 class MegadepthBuilder:
-    def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
+    def __init__(
+        self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True
+    ) -> None:
         self.data_root = data_root
         self.scene_info_root = os.path.join(data_root, "prep_scene_info")
         self.all_scenes = os.listdir(self.scene_info_root)
         self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
         # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionally ignore those
-        self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
-        self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
+        self.loftr_ignore_scenes = set(
+            [
+                "0121.npy",
+                "0133.npy",
+                "0168.npy",
+                "0178.npy",
+                "0229.npy",
+                "0349.npy",
+                "0412.npy",
+                "0430.npy",
+                "0443.npy",
+                "1001.npy",
+                "5014.npy",
+                "5015.npy",
+                "5016.npy",
+            ]
+        )
+        self.imc21_scenes = set(
+            [
+                "0008.npy",
+                "0019.npy",
+                "0021.npy",
+                "0024.npy",
+                "0025.npy",
+                "0032.npy",
+                "0063.npy",
+                "1589.npy",
+            ]
+        )
         self.test_scenes_loftr = ["0015.npy", "0022.npy"]
         self.loftr_ignore = loftr_ignore
         self.imc21_ignore = imc21_ignore
 
-    def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
+    def build_scenes(self, split="train", min_overlap=0.0, scene_names=None, **kwargs):
         if split == "train":
             scene_names = set(self.all_scenes) - set(self.test_scenes)
         elif split == "train_loftr":
@@ -248,15 +306,27 @@ class MegadepthBuilder:
                 os.path.join(self.scene_info_root, scene_name), allow_pickle=True
             ).item()
             scene_info_detections = np.load(
-                os.path.join(self.scene_info_root, "detections", f"detections_{scene_name}"), allow_pickle=True
+                os.path.join(
+                    self.scene_info_root, "detections", f"detections_{scene_name}"
+                ),
+                allow_pickle=True,
             ).item()
             scene_info_detections3D = np.load(
-                os.path.join(self.scene_info_root, "detections3D", f"detections3D_{scene_name}"), allow_pickle=True
+                os.path.join(
+                    self.scene_info_root, "detections3D", f"detections3D_{scene_name}"
+                ),
+                allow_pickle=True,
             )
 
             scenes.append(
                 MegadepthScene(
-                    self.data_root, scene_info, scene_info_detections = scene_info_detections, scene_info_detections3D = scene_info_detections3D, min_overlap=min_overlap,scene_name = scene_name, **kwargs
+                    self.data_root,
+                    scene_info,
+                    scene_info_detections=scene_info_detections,
+                    scene_info_detections3D=scene_info_detections3D,
+                    min_overlap=min_overlap,
+                    scene_name=scene_name,
+                    **kwargs,
                 )
             )
         return scenes
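
# A minimal sketch of the intrinsic update in MegadepthScene.horizontal_flip above:
# mirroring an image of width W maps x -> W - x, which is equivalent to left-multiplying
# the 3x3 intrinsic matrix by a reflection that also shifts the principal point. The
# numbers below are made-up.
import torch

W = 640
K = torch.tensor([[600.0, 0.0, 300.0], [0.0, 600.0, 240.0], [0.0, 0.0, 1.0]])

flip_mat = torch.tensor([[-1.0, 0.0, W], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
K_flipped = flip_mat @ K
# First row becomes [-fx, 0, W - cx], so a point that projected to pixel x now projects to W - x.
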
diff --git a/third_party/DeDoDe/DeDoDe/decoder.py b/third_party/DeDoDe/DeDoDe/decoder.py
index 4e1b58fcc588e6ee12c591b5f446829a914bc611..76f6c3b86e309e9f18e5525e132128c2de08c747 100644
--- a/third_party/DeDoDe/DeDoDe/decoder.py
+++ b/third_party/DeDoDe/DeDoDe/decoder.py
@@ -4,19 +4,26 @@ import torchvision.models as tvm
 
 
 class Decoder(nn.Module):
-    def __init__(self, layers, *args, super_resolution = False, num_prototypes = 1, **kwargs) -> None:
+    def __init__(
+        self, layers, *args, super_resolution=False, num_prototypes=1, **kwargs
+    ) -> None:
         super().__init__(*args, **kwargs)
         self.layers = layers
         self.scales = self.layers.keys()
         self.super_resolution = super_resolution
         self.num_prototypes = num_prototypes
-    def forward(self, features, context = None, scale = None):
+
+    def forward(self, features, context=None, scale=None):
         if context is not None:
-            features = torch.cat((features, context), dim = 1)
+            features = torch.cat((features, context), dim=1)
         stuff = self.layers[scale](features)
-        logits, context = stuff[:,:self.num_prototypes], stuff[:,self.num_prototypes:]
+        logits, context = (
+            stuff[:, : self.num_prototypes],
+            stuff[:, self.num_prototypes :],
+        )
         return logits, context
 
+
 class ConvRefiner(nn.Module):
     def __init__(
         self,
@@ -26,13 +33,16 @@ class ConvRefiner(nn.Module):
         dw=True,
         kernel_size=5,
         hidden_blocks=5,
-        amp = True,
-        residual = False,
-        amp_dtype = torch.float16,
+        amp=True,
+        residual=False,
+        amp_dtype=torch.float16,
     ):
         super().__init__()
         self.block1 = self.create_block(
-            in_dim, hidden_dim, dw=False, kernel_size=1,
+            in_dim,
+            hidden_dim,
+            dw=False,
+            kernel_size=1,
         )
         self.hidden_blocks = nn.Sequential(
             *[
@@ -50,15 +60,15 @@ class ConvRefiner(nn.Module):
         self.amp = amp
         self.amp_dtype = amp_dtype
         self.residual = residual
-        
+
     def create_block(
         self,
         in_dim,
         out_dim,
         dw=True,
         kernel_size=5,
-        bias = True,
-        norm_type = nn.BatchNorm2d,
+        bias=True,
+        norm_type=nn.BatchNorm2d,
     ):
         num_groups = 1 if not dw else in_dim
         if dw:
@@ -74,17 +84,21 @@ class ConvRefiner(nn.Module):
             groups=num_groups,
             bias=bias,
         )
-        norm = norm_type(out_dim) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
+        norm = (
+            norm_type(out_dim)
+            if norm_type is nn.BatchNorm2d
+            else norm_type(num_channels=out_dim)
+        )
         relu = nn.ReLU(inplace=True)
         conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
         return nn.Sequential(conv1, norm, relu, conv2)
-        
+
     def forward(self, feats):
-        b,c,hs,ws = feats.shape
-        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+        b, c, hs, ws = feats.shape
+        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
             x0 = self.block1(feats)
             x = self.hidden_blocks(x0)
             if self.residual:
-                x = (x + x0)/1.4
+                x = (x + x0) / 1.4
             x = self.out_conv(x)
             return x
diff --git a/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py b/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py
index 6d949a1b8ed2a58140af49e8167eda4e4099d022..0f98368f1ee812275726e306f356fdfbefa1663b 100644
--- a/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py
+++ b/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py
@@ -5,14 +5,18 @@ import torchvision.models as tvm
 import torch.nn.functional as F
 import numpy as np
 
+
 class DeDoDeDescriptor(nn.Module):
     def __init__(self, encoder, decoder, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.encoder = encoder
         self.decoder = decoder
         import torchvision.transforms as transforms
-        self.normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-        
+
+        self.normalizer = transforms.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
     def forward(
         self,
         batch,
@@ -26,24 +30,43 @@ class DeDoDeDescriptor(nn.Module):
         context = None
         scales = self.decoder.scales
         for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
-            delta_descriptor, context = self.decoder(feature_map, scale = scale, context = context)
+            delta_descriptor, context = self.decoder(
+                feature_map, scale=scale, context=context
+            )
             descriptor = descriptor + delta_descriptor
             if idx < len(scales) - 1:
-                size = sizes[-(idx+2)]
-                descriptor = F.interpolate(descriptor, size = size, mode = "bilinear", align_corners = False)
-                context = F.interpolate(context, size = size, mode = "bilinear", align_corners = False)
-        return {"description_grid" : descriptor}
-    
+                size = sizes[-(idx + 2)]
+                descriptor = F.interpolate(
+                    descriptor, size=size, mode="bilinear", align_corners=False
+                )
+                context = F.interpolate(
+                    context, size=size, mode="bilinear", align_corners=False
+                )
+        return {"description_grid": descriptor}
+
     @torch.inference_mode()
     def describe_keypoints(self, batch, keypoints):
         self.train(False)
         description_grid = self.forward(batch)["description_grid"]
-        described_keypoints = F.grid_sample(description_grid.float(), keypoints[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT
+        described_keypoints = F.grid_sample(
+            description_grid.float(),
+            keypoints[:, None],
+            mode="bilinear",
+            align_corners=False,
+        )[:, :, 0].mT
         return {"descriptions": described_keypoints}
-    
-    def read_image(self, im_path, H = 560, W = 560):
-        return self.normalizer(torch.from_numpy(np.array(Image.open(im_path).resize((W,H)))/255.).permute(2,0,1)).cuda().float()[None]
 
-    def describe_keypoints_from_path(self, im_path, keypoints, H = 768, W = 768):
-        batch = {"image": self.read_image(im_path, H = H, W = W)}
-        return self.describe_keypoints(batch, keypoints)
\ No newline at end of file
+    def read_image(self, im_path, H=560, W=560):
+        return (
+            self.normalizer(
+                torch.from_numpy(
+                    np.array(Image.open(im_path).resize((W, H))) / 255.0
+                ).permute(2, 0, 1)
+            )
+            .cuda()
+            .float()[None]
+        )
+
+    def describe_keypoints_from_path(self, im_path, keypoints, H=768, W=768):
+        batch = {"image": self.read_image(im_path, H=H, W=W)}
+        return self.describe_keypoints(batch, keypoints)
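
# A minimal sketch of how describe_keypoints above reads per-keypoint descriptors out of the
# dense description grid: keypoints in normalized [-1, 1] coordinates act as a (B, 1, N, 2)
# sampling grid for F.grid_sample, and the channel axis is then moved last. Shapes are
# illustrative.
import torch
import torch.nn.functional as F

B, C, H, W, N = 1, 256, 72, 72, 1000
description_grid = torch.randn(B, C, H, W)
keypoints = torch.rand(B, N, 2) * 2 - 1          # (x, y) in [-1, 1]

descriptions = F.grid_sample(
    description_grid, keypoints[:, None], mode="bilinear", align_corners=False
)[:, :, 0].mT                                    # (B, N, C): one descriptor per keypoint
print(descriptions.shape)                        # torch.Size([1, 1000, 256])
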
diff --git a/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py b/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py
index 494d39ca124941e7a9f870b427c9d1317c01dafc..343ef0cde0fbccdf981634bbdbd2c6b8948d0ee7 100644
--- a/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py
+++ b/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py
@@ -6,70 +6,107 @@ import torch.nn.functional as F
 from DeDoDe.utils import *
 import DeDoDe
 
+
 class DescriptorLoss(nn.Module):
-    
-    def __init__(self, detector, num_keypoints = 5000, normalize_descriptions = False, inv_temp = 1, device = "cuda") -> None:
+    def __init__(
+        self,
+        detector,
+        num_keypoints=5000,
+        normalize_descriptions=False,
+        inv_temp=1,
+        device="cuda",
+    ) -> None:
         super().__init__()
         self.detector = detector
         self.tracked_metrics = {}
         self.num_keypoints = num_keypoints
         self.normalize_descriptions = normalize_descriptions
         self.inv_temp = inv_temp
-    
+
     def warp_from_depth(self, batch, kpts_A, kpts_B):
-        mask_A_to_B, kpts_A_to_B = warp_kpts(kpts_A, 
-                    batch["im_A_depth"],
-                    batch["im_B_depth"],
-                    batch["T_1to2"],
-                    batch["K1"],
-                    batch["K2"],)
-        mask_B_to_A, kpts_B_to_A = warp_kpts(kpts_B, 
-                    batch["im_B_depth"],
-                    batch["im_A_depth"],
-                    batch["T_1to2"].inverse(),
-                    batch["K2"],
-                    batch["K1"],)
+        mask_A_to_B, kpts_A_to_B = warp_kpts(
+            kpts_A,
+            batch["im_A_depth"],
+            batch["im_B_depth"],
+            batch["T_1to2"],
+            batch["K1"],
+            batch["K2"],
+        )
+        mask_B_to_A, kpts_B_to_A = warp_kpts(
+            kpts_B,
+            batch["im_B_depth"],
+            batch["im_A_depth"],
+            batch["T_1to2"].inverse(),
+            batch["K2"],
+            batch["K1"],
+        )
         return (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, kpts_B_to_A)
-    
+
     def warp_from_homog(self, batch, kpts_A, kpts_B):
         kpts_A_to_B = homog_transform(batch["Homog_A_to_B"], kpts_A)
         kpts_B_to_A = homog_transform(batch["Homog_A_to_B"].inverse(), kpts_B)
         return (None, kpts_A_to_B), (None, kpts_B_to_A)
 
     def supervised_loss(self, outputs, batch):
-        kpts_A, kpts_B = self.detector.detect(batch, num_keypoints = self.num_keypoints)['keypoints'].clone().chunk(2)
+        kpts_A, kpts_B = (
+            self.detector.detect(batch, num_keypoints=self.num_keypoints)["keypoints"]
+            .clone()
+            .chunk(2)
+        )
         desc_grid_A, desc_grid_B = outputs["description_grid"].chunk(2)
-        desc_A = F.grid_sample(desc_grid_A.float(), kpts_A[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT
-        desc_B = F.grid_sample(desc_grid_B.float(), kpts_B[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT
+        desc_A = F.grid_sample(
+            desc_grid_A.float(), kpts_A[:, None], mode="bilinear", align_corners=False
+        )[:, :, 0].mT
+        desc_B = F.grid_sample(
+            desc_grid_B.float(), kpts_B[:, None], mode="bilinear", align_corners=False
+        )[:, :, 0].mT
         if "im_A_depth" in batch:
-            (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, kpts_B_to_A) = self.warp_from_depth(batch, kpts_A, kpts_B)
+            (mask_A_to_B, kpts_A_to_B), (
+                mask_B_to_A,
+                kpts_B_to_A,
+            ) = self.warp_from_depth(batch, kpts_A, kpts_B)
         elif "Homog_A_to_B" in batch:
-            (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, kpts_B_to_A) = self.warp_from_homog(batch, kpts_A, kpts_B)
-            
+            (mask_A_to_B, kpts_A_to_B), (
+                mask_B_to_A,
+                kpts_B_to_A,
+            ) = self.warp_from_homog(batch, kpts_A, kpts_B)
+
         with torch.no_grad():
             D_B = torch.cdist(kpts_A_to_B, kpts_B)
             D_A = torch.cdist(kpts_A, kpts_B_to_A)
-            inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values) 
-                                 * (D_A == D_A.min(dim=-2, keepdim = True).values)
-                                 * (D_B < 0.01)
-                                 * (D_A < 0.01))
-            
-        logP_A_B = dual_log_softmax_matcher(desc_A, desc_B, 
-                                            normalize = self.normalize_descriptions,
-                                            inv_temperature = self.inv_temp)
-        neg_log_likelihood = -logP_A_B[inds[:,0], inds[:,1], inds[:,2]].mean()
+            inds = torch.nonzero(
+                (D_B == D_B.min(dim=-1, keepdim=True).values)
+                * (D_A == D_A.min(dim=-2, keepdim=True).values)
+                * (D_B < 0.01)
+                * (D_A < 0.01)
+            )
+
+        logP_A_B = dual_log_softmax_matcher(
+            desc_A,
+            desc_B,
+            normalize=self.normalize_descriptions,
+            inv_temperature=self.inv_temp,
+        )
+        neg_log_likelihood = -logP_A_B[inds[:, 0], inds[:, 1], inds[:, 2]].mean()
         if False:
             import matplotlib.pyplot as plt
-            inds0 = inds[inds[:,0] == 0]
-            mnn_A = kpts_A[0,inds0[:,1]].detach().cpu()
-            mnn_B = kpts_B[0,inds0[:,2]].detach().cpu()
-            plt.scatter(mnn_A[:,0], -mnn_A[:,1], s = 0.5)
+
+            inds0 = inds[inds[:, 0] == 0]
+            mnn_A = kpts_A[0, inds0[:, 1]].detach().cpu()
+            mnn_B = kpts_B[0, inds0[:, 2]].detach().cpu()
+            plt.scatter(mnn_A[:, 0], -mnn_A[:, 1], s=0.5)
             plt.savefig("vis/mnn_A.jpg")
-        self.tracked_metrics["neg_log_likelihood"] = (0.99 * self.tracked_metrics.get("neg_log_likelihood", neg_log_likelihood.detach().item()) + 0.01 * neg_log_likelihood.detach().item())
+        self.tracked_metrics["neg_log_likelihood"] = (
+            0.99
+            * self.tracked_metrics.get(
+                "neg_log_likelihood", neg_log_likelihood.detach().item()
+            )
+            + 0.01 * neg_log_likelihood.detach().item()
+        )
         if np.random.rand() > 0.99:
             print(self.tracked_metrics["neg_log_likelihood"])
         return neg_log_likelihood
-    
+
     def forward(self, outputs, batch):
         losses = self.supervised_loss(outputs, batch)
-        return losses
\ No newline at end of file
+        return losses
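
# A reduced sketch of how supervised_loss above selects ground-truth correspondences: a pair
# (i, j) is kept when warped A-keypoint i has B-keypoint j as its nearest neighbor, warped
# B-keypoint j has A-keypoint i as its nearest neighbor, and both distances fall below 0.01
# in normalized coordinates. Inputs are assumed to be (B, N, 2) tensors.
import torch

def gt_correspondences(kpts_A, kpts_B, kpts_A_to_B, kpts_B_to_A, thr=0.01):
    D_B = torch.cdist(kpts_A_to_B, kpts_B)   # (B, N_A, N_B) distances measured in image B
    D_A = torch.cdist(kpts_A, kpts_B_to_A)   # (B, N_A, N_B) distances measured in image A
    keep = (
        (D_B == D_B.min(dim=-1, keepdim=True).values)
        * (D_A == D_A.min(dim=-2, keepdim=True).values)
        * (D_B < thr)
        * (D_A < thr)
    )
    return torch.nonzero(keep)               # rows of (batch, index_A, index_B)
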
diff --git a/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py b/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py
index a482d6ddfb8d44de4d00e815b3002f523700390e..dd68212099a2417ca89a562623f670f9f8526b04 100644
--- a/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py
+++ b/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py
@@ -8,15 +8,17 @@ import numpy as np
 from DeDoDe.utils import sample_keypoints, to_pixel_coords, to_normalized_coords
 
 
-
 class DeDoDeDetector(nn.Module):
     def __init__(self, encoder, decoder, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.encoder = encoder
         self.decoder = decoder
         import torchvision.transforms as transforms
-        self.normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-        
+
+        self.normalizer = transforms.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
     def forward(
         self,
         batch,
@@ -30,24 +32,43 @@ class DeDoDeDetector(nn.Module):
         context = None
         scales = ["8", "4", "2", "1"]
         for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
-            delta_logits, context = self.decoder(feature_map, context = context, scale = scale)
-            logits = logits + delta_logits.float() # ensure float (need bf16 doesnt have f.interpolate)
+            delta_logits, context = self.decoder(
+                feature_map, context=context, scale=scale
+            )
+            logits = (
+                logits + delta_logits.float()
+            )  # ensure float (F.interpolate does not support bf16)
             if idx < len(scales) - 1:
-                size = sizes[-(idx+2)]
-                logits = F.interpolate(logits, size = size, mode = "bicubic", align_corners = False)
-                context = F.interpolate(context.float(), size = size, mode = "bilinear", align_corners = False)
-        return {"keypoint_logits" : logits.float()}
-    
+                size = sizes[-(idx + 2)]
+                logits = F.interpolate(
+                    logits, size=size, mode="bicubic", align_corners=False
+                )
+                context = F.interpolate(
+                    context.float(), size=size, mode="bilinear", align_corners=False
+                )
+        return {"keypoint_logits": logits.float()}
+
     @torch.inference_mode()
-    def detect(self, batch, num_keypoints = 10_000):
+    def detect(self, batch, num_keypoints=10_000):
         self.train(False)
         keypoint_logits = self.forward(batch)["keypoint_logits"]
-        B,K,H,W = keypoint_logits.shape
-        keypoint_p = keypoint_logits.reshape(B, K*H*W).softmax(dim=-1).reshape(B, K, H*W).sum(dim=1)
-        keypoints, confidence = sample_keypoints(keypoint_p.reshape(B,H,W), 
-                                  use_nms = False, sample_topk = True, num_samples = num_keypoints, 
-                                  return_scoremap=True, sharpen = False, upsample = False,
-                                  increase_coverage=True)
+        B, K, H, W = keypoint_logits.shape
+        keypoint_p = (
+            keypoint_logits.reshape(B, K * H * W)
+            .softmax(dim=-1)
+            .reshape(B, K, H * W)
+            .sum(dim=1)
+        )
+        keypoints, confidence = sample_keypoints(
+            keypoint_p.reshape(B, H, W),
+            use_nms=False,
+            sample_topk=True,
+            num_samples=num_keypoints,
+            return_scoremap=True,
+            sharpen=False,
+            upsample=False,
+            increase_coverage=True,
+        )
         return {"keypoints": keypoints, "confidence": confidence}
 
     @torch.inference_mode()
@@ -56,20 +77,26 @@ class DeDoDeDetector(nn.Module):
         keypoint_logits = self.forward(batch)["keypoint_logits"]
         return {"dense_keypoint_logits": keypoint_logits}
 
-    def read_image(self, im_path, H = 560, W = 560):
+    def read_image(self, im_path, H=560, W=560):
         pil_im = Image.open(im_path).resize((W, H))
-        standard_im = np.array(pil_im)/255.
-        return self.normalizer(torch.from_numpy(standard_im).permute(2,0,1)).cuda().float()[None]
+        standard_im = np.array(pil_im) / 255.0
+        return (
+            self.normalizer(torch.from_numpy(standard_im).permute(2, 0, 1))
+            .cuda()
+            .float()[None]
+        )
 
-    def detect_from_path(self, im_path, num_keypoints = 30_000, H = 768, W = 768, dense = False):
-        batch = {"image": self.read_image(im_path, H = H, W = W)}
+    def detect_from_path(
+        self, im_path, num_keypoints=30_000, H=768, W=768, dense=False
+    ):
+        batch = {"image": self.read_image(im_path, H=H, W=W)}
         if dense:
             return self.detect_dense(batch)
         else:
-            return self.detect(batch, num_keypoints = num_keypoints)
+            return self.detect(batch, num_keypoints=num_keypoints)
 
     def to_pixel_coords(self, x, H, W):
         return to_pixel_coords(x, H, W)
-    
+
     def to_normalized_coords(self, x, H, W):
-        return to_normalized_coords(x, H, W)
\ No newline at end of file
+        return to_normalized_coords(x, H, W)
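
# A simplified sketch of the detection step above: the per-prototype logits are turned into a
# single per-pixel keypoint probability (softmax over all K*H*W bins, summed over the K
# prototypes), and the highest-probability locations are kept. Plain torch.topk stands in for
# the repo's sample_keypoints helper, which additionally offers NMS and coverage options.
import torch

def topk_keypoints(keypoint_logits: torch.Tensor, num_keypoints: int = 1000):
    B, K, H, W = keypoint_logits.shape
    keypoint_p = (
        keypoint_logits.reshape(B, K * H * W)
        .softmax(dim=-1)
        .reshape(B, K, H * W)
        .sum(dim=1)
    )                                         # (B, H*W) per-pixel probability
    confidence, flat_idx = keypoint_p.topk(num_keypoints, dim=-1)
    x, y = flat_idx % W, flat_idx // W        # pixel coordinates of the selected bins
    return torch.stack([x, y], dim=-1).float(), confidence

kpts, conf = topk_keypoints(torch.randn(1, 1, 96, 96), num_keypoints=500)
print(kpts.shape, conf.shape)                 # torch.Size([1, 500, 2]) torch.Size([1, 500])
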
diff --git a/third_party/DeDoDe/DeDoDe/detectors/loss.py b/third_party/DeDoDe/DeDoDe/detectors/loss.py
index 74d47058c82714729a05ea8a3b8433f352af2f4a..924bb896a66034ef45b11420ca6d48a462092ed1 100644
--- a/third_party/DeDoDe/DeDoDe/detectors/loss.py
+++ b/third_party/DeDoDe/DeDoDe/detectors/loss.py
@@ -5,27 +5,34 @@ import math
 from DeDoDe.utils import *
 import DeDoDe
 
+
 class KeyPointLoss(nn.Module):
-    
-    def __init__(self, smoothing_size = 1, use_max_logit = False, entropy_target = 80, 
-                 num_matches = 1024, jacobian_density_adjustment = False,
-                 matchability_weight = 1, device = "cuda") -> None:
+    def __init__(
+        self,
+        smoothing_size=1,
+        use_max_logit=False,
+        entropy_target=80,
+        num_matches=1024,
+        jacobian_density_adjustment=False,
+        matchability_weight=1,
+        device="cuda",
+    ) -> None:
         super().__init__()
-        X = torch.linspace(-1,1,smoothing_size, device = device)
-        G = (-X**2 / (2 *1/2**2)).exp()
-        G = G/G.sum()
+        X = torch.linspace(-1, 1, smoothing_size, device=device)
+        G = (-(X**2) / (2 * 1 / 2**2)).exp()
+        G = G / G.sum()
         self.use_max_logit = use_max_logit
         self.entropy_target = entropy_target
-        self.smoothing_kernel = G[None, None, None,:]
+        self.smoothing_kernel = G[None, None, None, :]
         self.smoothing_size = smoothing_size
         self.tracked_metrics = {}
         self.center = None
         self.num_matches = num_matches
         self.jacobian_density_adjustment = jacobian_density_adjustment
         self.matchability_weight = matchability_weight
-        
-    def compute_consistency(self, logits_A, logits_B_to_A, mask = None):
-        
+
+    def compute_consistency(self, logits_A, logits_B_to_A, mask=None):
+
         masked_logits_A = torch.full_like(logits_A, -torch.inf)
         masked_logits_A[mask] = logits_A[mask]
 
@@ -36,129 +43,186 @@ class KeyPointLoss(nn.Module):
         log_p_B_to_A = masked_logits_B_to_A.log_softmax(dim=-1)[mask]
 
         return self.compute_jensen_shannon_div(log_p_A, log_p_B_to_A)
-    
-    def compute_joint_neg_log_likelihood(self, logits_A, logits_B_to_A, detections_A = None, detections_B_to_A = None, mask = None, device = "cuda", dtype = torch.float32, num_matches = None):
+
+    def compute_joint_neg_log_likelihood(
+        self,
+        logits_A,
+        logits_B_to_A,
+        detections_A=None,
+        detections_B_to_A=None,
+        mask=None,
+        device="cuda",
+        dtype=torch.float32,
+        num_matches=None,
+    ):
         B, K, HW = logits_A.shape
         logits_A, logits_B_to_A = logits_A.to(dtype), logits_B_to_A.to(dtype)
-        mask = mask[:,None].expand(B, K, HW).reshape(B, K*HW)
-        log_p_B_to_A = self.masked_log_softmax(logits_B_to_A.reshape(B,K*HW), mask = mask)
-        log_p_A = self.masked_log_softmax(logits_A.reshape(B,K*HW), mask = mask)
+        mask = mask[:, None].expand(B, K, HW).reshape(B, K * HW)
+        log_p_B_to_A = self.masked_log_softmax(
+            logits_B_to_A.reshape(B, K * HW), mask=mask
+        )
+        log_p_A = self.masked_log_softmax(logits_A.reshape(B, K * HW), mask=mask)
         log_p = log_p_A + log_p_B_to_A
         if detections_A is None:
             detections_A = torch.zeros_like(log_p_A)
         if detections_B_to_A is None:
             detections_B_to_A = torch.zeros_like(log_p_B_to_A)
         detections_A = detections_A.reshape(B, HW)
-        detections_A[~mask] = 0         
+        detections_A[~mask] = 0
         detections_B_to_A = detections_B_to_A.reshape(B, HW)
         detections_B_to_A[~mask] = 0
-        log_p_target = log_p.detach() + 50*detections_A + 50*detections_B_to_A
+        log_p_target = log_p.detach() + 50 * detections_A + 50 * detections_B_to_A
         num_matches = self.num_matches if num_matches is None else num_matches
-        best_k = -(-log_p_target).flatten().kthvalue(k = B * num_matches, dim=-1).values
-        p_target = (log_p_target > best_k[..., None]).float().reshape(B,K*HW)/num_matches
-        return self.compute_cross_entropy(log_p_A[mask], p_target[mask]) + self.compute_cross_entropy(log_p_B_to_A[mask], p_target[mask])
-                
+        best_k = -(-log_p_target).flatten().kthvalue(k=B * num_matches, dim=-1).values
+        p_target = (log_p_target > best_k[..., None]).float().reshape(
+            B, K * HW
+        ) / num_matches
+        return self.compute_cross_entropy(
+            log_p_A[mask], p_target[mask]
+        ) + self.compute_cross_entropy(log_p_B_to_A[mask], p_target[mask])
+
     def compute_jensen_shannon_div(self, log_p, log_q):
-        return 1/2 * (self.compute_kl_div(log_p, log_q) + self.compute_kl_div(log_q, log_p))
-    
+        return (
+            1
+            / 2
+            * (self.compute_kl_div(log_p, log_q) + self.compute_kl_div(log_q, log_p))
+        )
+
     def compute_kl_div(self, log_p, log_q):
-        return (log_p.exp()*(log_p-log_q)).sum(dim=-1)
-    
+        return (log_p.exp() * (log_p - log_q)).sum(dim=-1)
+
     def masked_log_softmax(self, logits, mask):
         masked_logits = torch.full_like(logits, -torch.inf)
         masked_logits[mask] = logits[mask]
         log_p = masked_logits.log_softmax(dim=-1)
         return log_p
-    
+
     def masked_softmax(self, logits, mask):
         masked_logits = torch.full_like(logits, -torch.inf)
         masked_logits[mask] = logits[mask]
         log_p = masked_logits.softmax(dim=-1)
         return log_p
-    
-    def compute_entropy(self, logits, mask = None):
+
+    def compute_entropy(self, logits, mask=None):
         p = self.masked_softmax(logits, mask)[mask]
         log_p = self.masked_log_softmax(logits, mask)[mask]
-        return -(log_p * p).sum(dim=-1) 
+        return -(log_p * p).sum(dim=-1)
 
-    def compute_detection_img(self, detections, mask, B, H, W, device = "cuda"):
+    def compute_detection_img(self, detections, mask, B, H, W, device="cuda"):
         kernel_size = 5
-        X = torch.linspace(-2,2,kernel_size, device = device)
-        G = (-X**2 / (2 * (1/2)**2)).exp() # half pixel std
-        G = G/G.sum()
-        det_smoothing_kernel = G[None, None, None,:]
-        det_img = torch.zeros((B,1,H,W), device = device) # add small epsilon for later logstuff
+        X = torch.linspace(-2, 2, kernel_size, device=device)
+        G = (-(X**2) / (2 * (1 / 2) ** 2)).exp()  # half pixel std
+        G = G / G.sum()
+        det_smoothing_kernel = G[None, None, None, :]
+        det_img = torch.zeros(
+            (B, 1, H, W), device=device
+        )  # add small epsilon for later logstuff
         for b in range(B):
             valid_detections = (detections[b][mask[b]]).int()
-            det_img[b,0][valid_detections[:,1], valid_detections[:,0]] = 1
-        det_img = F.conv2d(det_img, weight = det_smoothing_kernel, padding = (kernel_size//2, 0))
-        det_img = F.conv2d(det_img, weight = det_smoothing_kernel.mT, padding = (0, kernel_size//2))
+            det_img[b, 0][valid_detections[:, 1], valid_detections[:, 0]] = 1
+        det_img = F.conv2d(
+            det_img, weight=det_smoothing_kernel, padding=(kernel_size // 2, 0)
+        )
+        det_img = F.conv2d(
+            det_img, weight=det_smoothing_kernel.mT, padding=(0, kernel_size // 2)
+        )
         return det_img
 
     def compute_cross_entropy(self, log_p_hat, p):
         return -(log_p_hat * p).sum(dim=-1)
 
-    def compute_matchability(self, keypoint_p, has_depth, B, K, H, W, device = "cuda"):
-        smooth_keypoint_p = F.conv2d(keypoint_p.reshape(B,1,H,W), weight = self.smoothing_kernel, padding = (self.smoothing_size//2,0))
-        smooth_keypoint_p = F.conv2d(smooth_keypoint_p, weight = self.smoothing_kernel.mT, padding = (0,self.smoothing_size//2))
-        log_p_hat = (smooth_keypoint_p+1e-8).log().reshape(B,H*W).log_softmax(dim=-1)
-        smooth_has_depth = F.conv2d(has_depth.reshape(B,1,H,W), weight = self.smoothing_kernel, padding = (0,self.smoothing_size//2))
-        smooth_has_depth = F.conv2d(smooth_has_depth, weight = self.smoothing_kernel.mT, padding = (self.smoothing_size//2,0)).reshape(B,H*W)
-        p = smooth_has_depth/smooth_has_depth.sum(dim=-1,keepdim=True)
-        return self.compute_cross_entropy(log_p_hat, p) - self.compute_cross_entropy((p+1e-12).log(), p)
+    def compute_matchability(self, keypoint_p, has_depth, B, K, H, W, device="cuda"):
+        smooth_keypoint_p = F.conv2d(
+            keypoint_p.reshape(B, 1, H, W),
+            weight=self.smoothing_kernel,
+            padding=(self.smoothing_size // 2, 0),
+        )
+        smooth_keypoint_p = F.conv2d(
+            smooth_keypoint_p,
+            weight=self.smoothing_kernel.mT,
+            padding=(0, self.smoothing_size // 2),
+        )
+        log_p_hat = (
+            (smooth_keypoint_p + 1e-8).log().reshape(B, H * W).log_softmax(dim=-1)
+        )
+        smooth_has_depth = F.conv2d(
+            has_depth.reshape(B, 1, H, W),
+            weight=self.smoothing_kernel,
+            padding=(0, self.smoothing_size // 2),
+        )
+        smooth_has_depth = F.conv2d(
+            smooth_has_depth,
+            weight=self.smoothing_kernel.mT,
+            padding=(self.smoothing_size // 2, 0),
+        ).reshape(B, H * W)
+        p = smooth_has_depth / smooth_has_depth.sum(dim=-1, keepdim=True)
+        return self.compute_cross_entropy(log_p_hat, p) - self.compute_cross_entropy(
+            (p + 1e-12).log(), p
+        )
 
     def tracks_to_detections(self, tracks3D, pose, intrinsics, H, W):
         tracks3D = tracks3D.double()
         intrinsics = intrinsics.double()
-        bearing_vectors = pose[:,:3,:3] @ tracks3D.mT + pose[:,:3,3:]        
+        bearing_vectors = pose[:, :3, :3] @ tracks3D.mT + pose[:, :3, 3:]
         hom_pixel_coords = (intrinsics @ bearing_vectors).mT
-        pixel_coords = hom_pixel_coords[...,:2] / (hom_pixel_coords[...,2:]+1e-12)
-        legit_detections = (pixel_coords > 0).prod(dim = -1) * (pixel_coords[...,0] < W - 1) * (pixel_coords[...,1] < H - 1) * (tracks3D != 0).prod(dim=-1)
+        pixel_coords = hom_pixel_coords[..., :2] / (hom_pixel_coords[..., 2:] + 1e-12)
+        legit_detections = (
+            (pixel_coords > 0).prod(dim=-1)
+            * (pixel_coords[..., 0] < W - 1)
+            * (pixel_coords[..., 1] < H - 1)
+            * (tracks3D != 0).prod(dim=-1)
+        )
         return pixel_coords.float(), legit_detections.bool()
-    
+
     def self_supervised_loss(self, outputs, batch):
         keypoint_logits_A, keypoint_logits_B = outputs["keypoint_logits"].chunk(2)
         B, K, H, W = keypoint_logits_A.shape
-        keypoint_logits_A = keypoint_logits_A.reshape(B, K, H*W)
-        keypoint_logits_B = keypoint_logits_B.reshape(B, K, H*W)
+        keypoint_logits_A = keypoint_logits_A.reshape(B, K, H * W)
+        keypoint_logits_B = keypoint_logits_B.reshape(B, K, H * W)
         keypoint_logits = torch.cat((keypoint_logits_A, keypoint_logits_B))
 
-        warp_A_to_B, mask_A_to_B = get_homog_warp(
-            batch["Homog_A_to_B"], H, W
-        )
+        warp_A_to_B, mask_A_to_B = get_homog_warp(batch["Homog_A_to_B"], H, W)
         warp_B_to_A, mask_B_to_A = get_homog_warp(
             torch.linalg.inv(batch["Homog_A_to_B"]), H, W
         )
-        B = 2*B
-        
-        warp = torch.cat((warp_A_to_B, warp_B_to_A)).reshape(B, H*W, 4)
-        mask = torch.cat((mask_A_to_B, mask_B_to_A)).reshape(B,H*W)
-        
-        keypoint_logits_backwarped = F.grid_sample(torch.cat((keypoint_logits_B, keypoint_logits_A)).reshape(B,K,H,W), 
-            warp[...,-2:].reshape(B,H,W,2).float(), align_corners = False, mode = "bicubic")
-        
-        keypoint_logits_backwarped = keypoint_logits_backwarped.reshape(B,K,H*W)
-        joint_log_likelihood_loss = self.compute_joint_neg_log_likelihood(keypoint_logits, keypoint_logits_backwarped, 
-                                                                          mask = mask.bool(), num_matches = 5_000).mean()
+        B = 2 * B
+
+        warp = torch.cat((warp_A_to_B, warp_B_to_A)).reshape(B, H * W, 4)
+        mask = torch.cat((mask_A_to_B, mask_B_to_A)).reshape(B, H * W)
+
+        keypoint_logits_backwarped = F.grid_sample(
+            torch.cat((keypoint_logits_B, keypoint_logits_A)).reshape(B, K, H, W),
+            warp[..., -2:].reshape(B, H, W, 2).float(),
+            align_corners=False,
+            mode="bicubic",
+        )
+
+        keypoint_logits_backwarped = keypoint_logits_backwarped.reshape(B, K, H * W)
+        joint_log_likelihood_loss = self.compute_joint_neg_log_likelihood(
+            keypoint_logits,
+            keypoint_logits_backwarped,
+            mask=mask.bool(),
+            num_matches=5_000,
+        ).mean()
         return joint_log_likelihood_loss
-    
+
     def supervised_loss(self, outputs, batch):
         keypoint_logits_A, keypoint_logits_B = outputs["keypoint_logits"].chunk(2)
         B, K, H, W = keypoint_logits_A.shape
 
         detections_A, detections_B = batch["detections_A"], batch["detections_B"]
-        
+
         tracks3D_A, tracks3D_B = batch["tracks3D_A"], batch["tracks3D_B"]
-        gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp(                
-                    batch["im_A_depth"],
-                    batch["im_B_depth"],
-                    batch["T_1to2"],
-                    batch["K1"],
-                    batch["K2"],
-                    H=H,
-                    W=W,
-                )
-        gt_warp_B_to_A, valid_mask_B_to_A = get_gt_warp(                
+        gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp(
+            batch["im_A_depth"],
+            batch["im_B_depth"],
+            batch["T_1to2"],
+            batch["K1"],
+            batch["K2"],
+            H=H,
+            W=W,
+        )
+        gt_warp_B_to_A, valid_mask_B_to_A = get_gt_warp(
             batch["im_B_depth"],
             batch["im_A_depth"],
             batch["T_1to2"].inverse(),
@@ -167,103 +231,216 @@ class KeyPointLoss(nn.Module):
             H=H,
             W=W,
         )
-        keypoint_logits_A = keypoint_logits_A.reshape(B, K, H*W)
-        keypoint_logits_B = keypoint_logits_B.reshape(B, K, H*W)
+        keypoint_logits_A = keypoint_logits_A.reshape(B, K, H * W)
+        keypoint_logits_B = keypoint_logits_B.reshape(B, K, H * W)
         keypoint_logits = torch.cat((keypoint_logits_A, keypoint_logits_B))
 
-        B = 2*B
+        B = 2 * B
         gt_warp = torch.cat((gt_warp_A_to_B, gt_warp_B_to_A))
         valid_mask = torch.cat((valid_mask_A_to_B, valid_mask_B_to_A))
-        valid_mask = valid_mask.reshape(B,H*W)
+        valid_mask = valid_mask.reshape(B, H * W)
         binary_mask = valid_mask == 1
         if self.jacobian_density_adjustment:
-            j_logdet = jacobi_determinant(gt_warp.reshape(B,H,W,4), valid_mask.reshape(B,H,W).float())[:,None]
+            j_logdet = jacobi_determinant(
+                gt_warp.reshape(B, H, W, 4), valid_mask.reshape(B, H, W).float()
+            )[:, None]
         else:
             j_logdet = 0
         tracks3D = torch.cat((tracks3D_A, tracks3D_B))
-        
-        #detections, legit_detections = self.tracks_to_detections(tracks3D, torch.cat((batch["pose_A"],batch["pose_B"])), torch.cat((batch["K1"],batch["K2"])), H, W)
-        #detections_backwarped, legit_backwarped_detections = self.tracks_to_detections(torch.cat((tracks3D_B, tracks3D_A)), torch.cat((batch["pose_A"],batch["pose_B"])), torch.cat((batch["K1"],batch["K2"])), H, W)
+
+        # detections, legit_detections = self.tracks_to_detections(tracks3D, torch.cat((batch["pose_A"],batch["pose_B"])), torch.cat((batch["K1"],batch["K2"])), H, W)
+        # detections_backwarped, legit_backwarped_detections = self.tracks_to_detections(torch.cat((tracks3D_B, tracks3D_A)), torch.cat((batch["pose_A"],batch["pose_B"])), torch.cat((batch["K1"],batch["K2"])), H, W)
         detections = torch.cat((detections_A, detections_B))
-        legit_detections = ((detections > 0).prod(dim = -1) * (detections[...,0] < W) * (detections[...,1] < H)).bool()
-        det_imgs_A, det_imgs_B = self.compute_detection_img(detections, legit_detections, B, H, W).chunk(2)
+        legit_detections = (
+            (detections > 0).prod(dim=-1)
+            * (detections[..., 0] < W)
+            * (detections[..., 1] < H)
+        ).bool()
+        det_imgs_A, det_imgs_B = self.compute_detection_img(
+            detections, legit_detections, B, H, W
+        ).chunk(2)
         det_imgs = torch.cat((det_imgs_A, det_imgs_B))
-        #det_imgs_backwarped = self.compute_detection_img(detections_backwarped, legit_backwarped_detections, B, H, W)
-        det_imgs_backwarped = F.grid_sample(torch.cat((det_imgs_B, det_imgs_A)).reshape(B,1,H,W), 
-            gt_warp[...,-2:].reshape(B,H,W,2).float(), align_corners = False, mode = "bicubic")
+        # det_imgs_backwarped = self.compute_detection_img(detections_backwarped, legit_backwarped_detections, B, H, W)
+        det_imgs_backwarped = F.grid_sample(
+            torch.cat((det_imgs_B, det_imgs_A)).reshape(B, 1, H, W),
+            gt_warp[..., -2:].reshape(B, H, W, 2).float(),
+            align_corners=False,
+            mode="bicubic",
+        )
+
+        keypoint_logits_backwarped = F.grid_sample(
+            torch.cat((keypoint_logits_B, keypoint_logits_A)).reshape(B, K, H, W),
+            gt_warp[..., -2:].reshape(B, H, W, 2).float(),
+            align_corners=False,
+            mode="bicubic",
+        )
 
-        keypoint_logits_backwarped = F.grid_sample(torch.cat((keypoint_logits_B, keypoint_logits_A)).reshape(B,K,H,W), 
-            gt_warp[...,-2:].reshape(B,H,W,2).float(), align_corners = False, mode = "bicubic")
-        
         # Note: Below step should be taken, but seems difficult to get it to work well.
-        #keypoint_logits_B_to_A = keypoint_logits_B_to_A + j_logdet_A_to_B # adjust for the viewpoint by log jacobian of warp
-        keypoint_logits_backwarped = (keypoint_logits_backwarped + j_logdet).reshape(B,K,H*W)
-
-
-        depth = F.interpolate(torch.cat((batch["im_A_depth"][:,None],batch["im_B_depth"][:,None]),dim=0), size = (H,W), mode = "bilinear", align_corners=False)
-        has_depth = (depth > 0).float().reshape(B,H*W)
-        
-        joint_log_likelihood_loss = self.compute_joint_neg_log_likelihood(keypoint_logits, keypoint_logits_backwarped, 
-                                                                          mask = binary_mask, detections_A = det_imgs, 
-                                                                          detections_B_to_A = det_imgs_backwarped).mean()
-        keypoint_p = keypoint_logits.reshape(B, K*H*W).softmax(dim=-1).reshape(B, K, H*W).sum(dim=1)
-        matchability_loss = self.compute_matchability(keypoint_p, has_depth, B, K, H, W).mean()
-        
-        #peakiness_loss = self.compute_negative_peakiness(keypoint_logits.reshape(B,H,W), mask = binary_mask)
-        #mnn_loss = self.compute_mnn_loss(keypoint_logits_A, keypoint_logits_B, gt_warp_A_to_B, valid_mask_A_to_B, B, H, W)
-        B = B//2
+        # keypoint_logits_B_to_A = keypoint_logits_B_to_A + j_logdet_A_to_B # adjust for the viewpoint by log jacobian of warp
+        keypoint_logits_backwarped = (keypoint_logits_backwarped + j_logdet).reshape(
+            B, K, H * W
+        )
+
+        depth = F.interpolate(
+            torch.cat(
+                (batch["im_A_depth"][:, None], batch["im_B_depth"][:, None]), dim=0
+            ),
+            size=(H, W),
+            mode="bilinear",
+            align_corners=False,
+        )
+        has_depth = (depth > 0).float().reshape(B, H * W)
+
+        joint_log_likelihood_loss = self.compute_joint_neg_log_likelihood(
+            keypoint_logits,
+            keypoint_logits_backwarped,
+            mask=binary_mask,
+            detections_A=det_imgs,
+            detections_B_to_A=det_imgs_backwarped,
+        ).mean()
+        keypoint_p = (
+            keypoint_logits.reshape(B, K * H * W)
+            .softmax(dim=-1)
+            .reshape(B, K, H * W)
+            .sum(dim=1)
+        )
+        matchability_loss = self.compute_matchability(
+            keypoint_p, has_depth, B, K, H, W
+        ).mean()
+
+        # peakiness_loss = self.compute_negative_peakiness(keypoint_logits.reshape(B,H,W), mask = binary_mask)
+        # mnn_loss = self.compute_mnn_loss(keypoint_logits_A, keypoint_logits_B, gt_warp_A_to_B, valid_mask_A_to_B, B, H, W)
+        B = B // 2
         import matplotlib.pyplot as plt
-        kpts_A = sample_keypoints(keypoint_p[:B].reshape(B,H,W), 
-                                use_nms = False, sample_topk = True, num_samples = 4*2048)
-        kpts_B = sample_keypoints(keypoint_p[B:].reshape(B,H,W),
-                                use_nms = False, sample_topk = True, num_samples = 4*2048)
-        kpts_A_to_B = F.grid_sample(gt_warp_A_to_B[...,2:].float().permute(0,3,1,2), kpts_A[...,None,:], 
-                                    align_corners=False, mode = 'bilinear')[...,0].mT
-        legit_A_to_B = F.grid_sample(valid_mask_A_to_B.reshape(B,1,H,W), kpts_A[...,None,:], 
-                                    align_corners=False, mode = 'bilinear')[...,0,:,0]
-        percent_inliers = (torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0] < 0.01).float().mean()
-        self.tracked_metrics["mega_percent_inliers"] = (0.9 * self.tracked_metrics.get("mega_percent_inliers", percent_inliers) + 0.1 * percent_inliers)
+
+        kpts_A = sample_keypoints(
+            keypoint_p[:B].reshape(B, H, W),
+            use_nms=False,
+            sample_topk=True,
+            num_samples=4 * 2048,
+        )
+        kpts_B = sample_keypoints(
+            keypoint_p[B:].reshape(B, H, W),
+            use_nms=False,
+            sample_topk=True,
+            num_samples=4 * 2048,
+        )
+        kpts_A_to_B = F.grid_sample(
+            gt_warp_A_to_B[..., 2:].float().permute(0, 3, 1, 2),
+            kpts_A[..., None, :],
+            align_corners=False,
+            mode="bilinear",
+        )[..., 0].mT
+        legit_A_to_B = F.grid_sample(
+            valid_mask_A_to_B.reshape(B, 1, H, W),
+            kpts_A[..., None, :],
+            align_corners=False,
+            mode="bilinear",
+        )[..., 0, :, 0]
+        percent_inliers = (
+            (
+                torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0]
+                < 0.01
+            )
+            .float()
+            .mean()
+        )
+        self.tracked_metrics["mega_percent_inliers"] = (
+            0.9 * self.tracked_metrics.get("mega_percent_inliers", percent_inliers)
+            + 0.1 * percent_inliers
+        )
 
         if torch.rand(1) > 0.995:
             keypoint_logits_A_to_B = keypoint_logits_backwarped[:B]
             import matplotlib.pyplot as plt
             import os
-            os.makedirs("vis",exist_ok = True)
+
+            os.makedirs("vis", exist_ok=True)
             for b in range(0, B, 2):
-                #import cv2
-                plt.scatter(kpts_A_to_B[b,:,0].cpu(),-kpts_A_to_B[b,:,1].cpu(), s = 1)
-                plt.scatter(kpts_B[b,:,0].cpu(),-kpts_B[b,:,1].cpu(), s = 1)
-                plt.xlim(-1,1)
-                plt.ylim(-1,1)
+                # import cv2
+                plt.scatter(
+                    kpts_A_to_B[b, :, 0].cpu(), -kpts_A_to_B[b, :, 1].cpu(), s=1
+                )
+                plt.scatter(kpts_B[b, :, 0].cpu(), -kpts_B[b, :, 1].cpu(), s=1)
+                plt.xlim(-1, 1)
+                plt.ylim(-1, 1)
                 plt.savefig(f"vis/keypoints_A_to_B_vs_B_{b}.png")
                 plt.close()
-                tensor_to_pil(keypoint_logits_A[b].reshape(1,H,W).expand(3,H,W).detach().cpu(), 
-                            autoscale = True).save(f"vis/logits_A_{b}.png")
-                tensor_to_pil(keypoint_logits_B[b].reshape(1,H,W).expand(3,H,W).detach().cpu(), 
-                            autoscale = True).save(f"vis/logits_B_{b}.png")
-                tensor_to_pil(keypoint_logits_A_to_B[b].reshape(1,H,W).expand(3,H,W).detach().cpu(), 
-                            autoscale = True).save(f"vis/logits_A_to_B{b}.png")
-                tensor_to_pil(keypoint_logits_A[b].softmax(dim=-1).reshape(1,H,W).expand(3,H,W).detach().cpu(), 
-                            autoscale = True).save(f"vis/keypoint_p_A_{b}.png")
-                tensor_to_pil(keypoint_logits_B[b].softmax(dim=-1).reshape(1,H,W).expand(3,H,W).detach().cpu(), 
-                            autoscale = True).save(f"vis/keypoint_p_B_{b}.png")
-                tensor_to_pil(has_depth[b].reshape(1,H,W).expand(3,H,W).detach().cpu(), autoscale=True).save(f"vis/has_depth_A_{b}.png")                            
-                tensor_to_pil(valid_mask_A_to_B[b].reshape(1,H,W).expand(3,H,W).detach().cpu(), autoscale=True).save(f"vis/valid_mask_A_to_B_{b}.png")                            
-                tensor_to_pil(batch['im_A'][b], unnormalize=True).save(
-                                    f"vis/im_A_{b}.jpg")
-                tensor_to_pil(batch['im_B'][b], unnormalize=True).save(
-                                    f"vis/im_B_{b}.jpg")
+                tensor_to_pil(
+                    keypoint_logits_A[b]
+                    .reshape(1, H, W)
+                    .expand(3, H, W)
+                    .detach()
+                    .cpu(),
+                    autoscale=True,
+                ).save(f"vis/logits_A_{b}.png")
+                tensor_to_pil(
+                    keypoint_logits_B[b]
+                    .reshape(1, H, W)
+                    .expand(3, H, W)
+                    .detach()
+                    .cpu(),
+                    autoscale=True,
+                ).save(f"vis/logits_B_{b}.png")
+                tensor_to_pil(
+                    keypoint_logits_A_to_B[b]
+                    .reshape(1, H, W)
+                    .expand(3, H, W)
+                    .detach()
+                    .cpu(),
+                    autoscale=True,
+                ).save(f"vis/logits_A_to_B{b}.png")
+                tensor_to_pil(
+                    keypoint_logits_A[b]
+                    .softmax(dim=-1)
+                    .reshape(1, H, W)
+                    .expand(3, H, W)
+                    .detach()
+                    .cpu(),
+                    autoscale=True,
+                ).save(f"vis/keypoint_p_A_{b}.png")
+                tensor_to_pil(
+                    keypoint_logits_B[b]
+                    .softmax(dim=-1)
+                    .reshape(1, H, W)
+                    .expand(3, H, W)
+                    .detach()
+                    .cpu(),
+                    autoscale=True,
+                ).save(f"vis/keypoint_p_B_{b}.png")
+                tensor_to_pil(
+                    has_depth[b].reshape(1, H, W).expand(3, H, W).detach().cpu(),
+                    autoscale=True,
+                ).save(f"vis/has_depth_A_{b}.png")
+                tensor_to_pil(
+                    valid_mask_A_to_B[b]
+                    .reshape(1, H, W)
+                    .expand(3, H, W)
+                    .detach()
+                    .cpu(),
+                    autoscale=True,
+                ).save(f"vis/valid_mask_A_to_B_{b}.png")
+                tensor_to_pil(batch["im_A"][b], unnormalize=True).save(
+                    f"vis/im_A_{b}.jpg"
+                )
+                tensor_to_pil(batch["im_B"][b], unnormalize=True).save(
+                    f"vis/im_B_{b}.jpg"
+                )
             plt.close()
-        tot_loss = joint_log_likelihood_loss + self.matchability_weight * matchability_loss# 
-        #tot_loss = tot_loss + (-2*consistency_loss).detach().exp()*compression_loss
+        tot_loss = (
+            joint_log_likelihood_loss + self.matchability_weight * matchability_loss
+        )
+        # tot_loss = tot_loss + (-2*consistency_loss).detach().exp()*compression_loss
         if torch.rand(1) > 1:
-            print(f"Precent Inlier: {self.tracked_metrics.get('mega_percent_inliers', 0)}")
+            print(
+                f"Precent Inlier: {self.tracked_metrics.get('mega_percent_inliers', 0)}"
+            )
             print(f"{joint_log_likelihood_loss=} {matchability_loss=}")
             print(f"Total Loss: {tot_loss.item()}")
-        return  tot_loss
-    
+        return tot_loss
+
     def forward(self, outputs, batch):
-        
+
         if not isinstance(outputs, list):
             outputs = [outputs]
         losses = 0
@@ -272,4 +449,4 @@ class KeyPointLoss(nn.Module):
                 losses = losses + self.self_supervised_loss(output, batch)
             else:
                 losses = losses + self.supervised_loss(output, batch)
-        return losses
\ No newline at end of file
+        return losses
diff --git a/third_party/DeDoDe/DeDoDe/encoder.py b/third_party/DeDoDe/DeDoDe/encoder.py
index faf56c4b6629ce7147b46272ae1f4715e4d10740..2aebb1c5ac890c77d01774ab74caed460c2ff028 100644
--- a/third_party/DeDoDe/DeDoDe/encoder.py
+++ b/third_party/DeDoDe/DeDoDe/encoder.py
@@ -4,7 +4,7 @@ import torchvision.models as tvm
 
 
 class VGG19(nn.Module):
-    def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
+    def __init__(self, pretrained=False, amp=False, amp_dtype=torch.float16) -> None:
         super().__init__()
         self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
         # Maxpool layers: 6, 13, 26, 39
@@ -12,7 +12,7 @@ class VGG19(nn.Module):
         self.amp_dtype = amp_dtype
 
     def forward(self, x, **kwargs):
-        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
             feats = []
             sizes = []
             for layer in self.layers:
@@ -22,21 +22,30 @@ class VGG19(nn.Module):
                 x = layer(x)
             return feats, sizes
 
+
 class VGG(nn.Module):
-    def __init__(self, size = "19", pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
+    def __init__(
+        self, size="19", pretrained=False, amp=False, amp_dtype=torch.float16
+    ) -> None:
         super().__init__()
         if size == "11":
-            self.layers = nn.ModuleList(tvm.vgg11_bn(pretrained=pretrained).features[:22])
-        elif size == "13": 
-            self.layers = nn.ModuleList(tvm.vgg13_bn(pretrained=pretrained).features[:28])
-        elif size == "19": 
-            self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
+            self.layers = nn.ModuleList(
+                tvm.vgg11_bn(pretrained=pretrained).features[:22]
+            )
+        elif size == "13":
+            self.layers = nn.ModuleList(
+                tvm.vgg13_bn(pretrained=pretrained).features[:28]
+            )
+        elif size == "19":
+            self.layers = nn.ModuleList(
+                tvm.vgg19_bn(pretrained=pretrained).features[:40]
+            )
         # Maxpool layers: 6, 13, 26, 39
         self.amp = amp
         self.amp_dtype = amp_dtype
 
     def forward(self, x, **kwargs):
-        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
             feats = []
             sizes = []
             for layer in self.layers:
diff --git a/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py b/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py
index 5cc76cad77ee403d7d5ab729c786982a47fbe6e9..5927cff63be726b842e74647f2beae081d803dca 100644
--- a/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py
+++ b/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py
@@ -6,33 +6,59 @@ import torch.nn.functional as F
 import numpy as np
 from DeDoDe.utils import dual_softmax_matcher, to_pixel_coords, to_normalized_coords
 
-class DualSoftMaxMatcher(nn.Module):        
+
+class DualSoftMaxMatcher(nn.Module):
     @torch.inference_mode()
-    def match(self, keypoints_A, descriptions_A, 
-              keypoints_B, descriptions_B, P_A = None, P_B = None, 
-              normalize = False, inv_temp = 1, threshold = 0.0):
+    def match(
+        self,
+        keypoints_A,
+        descriptions_A,
+        keypoints_B,
+        descriptions_B,
+        P_A=None,
+        P_B=None,
+        normalize=False,
+        inv_temp=1,
+        threshold=0.0,
+    ):
         if isinstance(descriptions_A, list):
-            matches = [self.match(k_A[None], d_A[None], k_B[None], d_B[None], normalize = normalize,
-                               inv_temp = inv_temp, threshold = threshold) 
-                    for k_A,d_A,k_B,d_B in
-                    zip(keypoints_A, descriptions_A, keypoints_B, descriptions_B)]
+            matches = [
+                self.match(
+                    k_A[None],
+                    d_A[None],
+                    k_B[None],
+                    d_B[None],
+                    normalize=normalize,
+                    inv_temp=inv_temp,
+                    threshold=threshold,
+                )
+                for k_A, d_A, k_B, d_B in zip(
+                    keypoints_A, descriptions_A, keypoints_B, descriptions_B
+                )
+            ]
             matches_A = torch.cat([m[0] for m in matches])
             matches_B = torch.cat([m[1] for m in matches])
             inds = torch.cat([m[2] + b for b, m in enumerate(matches)])
             return matches_A, matches_B, inds
-        
-        P = dual_softmax_matcher(descriptions_A, descriptions_B, 
-                                 normalize = normalize, inv_temperature=inv_temp,
-                                 )
-        inds = torch.nonzero((P == P.max(dim=-1, keepdim = True).values) 
-                        * (P == P.max(dim=-2, keepdim = True).values) * (P > threshold))
-        batch_inds = inds[:,0]
-        matches_A = keypoints_A[batch_inds, inds[:,1]]
-        matches_B = keypoints_B[batch_inds, inds[:,2]]
+
+        P = dual_softmax_matcher(
+            descriptions_A,
+            descriptions_B,
+            normalize=normalize,
+            inv_temperature=inv_temp,
+        )
+        inds = torch.nonzero(
+            (P == P.max(dim=-1, keepdim=True).values)
+            * (P == P.max(dim=-2, keepdim=True).values)
+            * (P > threshold)
+        )
+        batch_inds = inds[:, 0]
+        matches_A = keypoints_A[batch_inds, inds[:, 1]]
+        matches_B = keypoints_B[batch_inds, inds[:, 2]]
         return matches_A, matches_B, batch_inds
 
     def to_pixel_coords(self, x_A, x_B, H_A, W_A, H_B, W_B):
         return to_pixel_coords(x_A, H_A, W_A), to_pixel_coords(x_B, H_B, W_B)
-    
+
     def to_normalized_coords(self, x_A, x_B, H_A, W_A, H_B, W_B):
-        return to_normalized_coords(x_A, H_A, W_A), to_normalized_coords(x_B, H_B, W_B)
\ No newline at end of file
+        return to_normalized_coords(x_A, H_A, W_A), to_normalized_coords(x_B, H_B, W_B)
diff --git a/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py b/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py
index b500da585ffd0216e2e434a2179f3045f485dbfb..6296a2833d1dd18c9d52ba45dc6649ff383dfb6f 100644
--- a/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py
+++ b/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py
@@ -1,3 +1 @@
 from .dedode_models import dedode_detector_B, dedode_detector_L, dedode_descriptor_B
-    
-    
\ No newline at end of file
diff --git a/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py b/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py
index f43dd22f0d59dabd18eef4beae4a3637dcd8912b..8c6d93d4b6d3a7c0daaf767fa53cd021f248dacd 100644
--- a/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py
+++ b/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py
@@ -7,8 +7,7 @@ from DeDoDe.decoder import ConvRefiner, Decoder
 from DeDoDe.encoder import VGG19, VGG
 
 
-
-def dedode_detector_B(device = "cuda", weights = None):
+def dedode_detector_B(device="cuda", weights=None):
     residual = True
     hidden_blocks = 5
     amp_dtype = torch.float16
@@ -20,55 +19,55 @@ def dedode_detector_B(device = "cuda", weights = None):
                 512,
                 512,
                 256 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
             "4": ConvRefiner(
-                256+256,
+                256 + 256,
                 256,
                 128 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
-
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
             "2": ConvRefiner(
-                128+128,
+                128 + 128,
                 64,
                 32 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
-
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
             "1": ConvRefiner(
                 64 + 32,
                 32,
                 1 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
         }
     )
-    encoder = VGG19(pretrained = False, amp = amp, amp_dtype = amp_dtype)
+    encoder = VGG19(pretrained=False, amp=amp, amp_dtype=amp_dtype)
     decoder = Decoder(conv_refiner)
-    model = DeDoDeDetector(encoder = encoder, decoder = decoder).to(device)
+    model = DeDoDeDetector(encoder=encoder, decoder=decoder).to(device)
     if weights is not None:
         model.load_state_dict(weights)
     return model
 
 
-def dedode_detector_L(device = "cuda", weights = None):
+def dedode_detector_L(device="cuda", weights=None):
     NUM_PROTOTYPES = 1
     residual = True
     hidden_blocks = 8
-    amp_dtype = torch.float16#torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+    amp_dtype = (
+        torch.float16
+    )  # torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
     amp = True
     conv_refiner = nn.ModuleDict(
         {
@@ -76,56 +75,55 @@ def dedode_detector_L(device = "cuda", weights = None):
                 512,
                 512,
                 256 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
             "4": ConvRefiner(
-                256+256,
+                256 + 256,
                 256,
                 128 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
-
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
             "2": ConvRefiner(
-                128+128,
+                128 + 128,
                 128,
                 64 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
-
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
             "1": ConvRefiner(
                 64 + 64,
                 64,
                 1 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
         }
     )
-    encoder = VGG19(pretrained = False, amp = amp, amp_dtype = amp_dtype)
+    encoder = VGG19(pretrained=False, amp=amp, amp_dtype=amp_dtype)
     decoder = Decoder(conv_refiner)
-    model = DeDoDeDetector(encoder = encoder, decoder = decoder).to(device)
+    model = DeDoDeDetector(encoder=encoder, decoder=decoder).to(device)
     if weights is not None:
         model.load_state_dict(weights)
     return model
 
 
-
-def dedode_descriptor_B(device = "cuda", weights = None):
-    NUM_PROTOTYPES = 256 # == descriptor size
+def dedode_descriptor_B(device="cuda", weights=None):
+    NUM_PROTOTYPES = 256  # == descriptor size
     residual = True
     hidden_blocks = 5
-    amp_dtype = torch.float16#torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+    amp_dtype = (
+        torch.float16
+    )  # torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
     amp = True
     conv_refiner = nn.ModuleDict(
         {
@@ -133,45 +131,43 @@ def dedode_descriptor_B(device = "cuda", weights = None):
                 512,
                 512,
                 256 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
             "4": ConvRefiner(
-                256+256,
+                256 + 256,
                 256,
                 128 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
-
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
             "2": ConvRefiner(
-                128+128,
+                128 + 128,
                 64,
                 32 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
-
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
             "1": ConvRefiner(
                 64 + 32,
                 32,
                 1 + NUM_PROTOTYPES,
-                hidden_blocks = hidden_blocks,
-                residual = residual,
-                amp = amp,
-                amp_dtype = amp_dtype,
+                hidden_blocks=hidden_blocks,
+                residual=residual,
+                amp=amp,
+                amp_dtype=amp_dtype,
             ),
         }
     )
-    encoder = VGG(size = "19", pretrained = False, amp = amp, amp_dtype = amp_dtype)
+    encoder = VGG(size="19", pretrained=False, amp=amp, amp_dtype=amp_dtype)
     decoder = Decoder(conv_refiner, num_prototypes=NUM_PROTOTYPES)
-    model = DeDoDeDescriptor(encoder = encoder, decoder = decoder).to(device)    
+    model = DeDoDeDescriptor(encoder=encoder, decoder=decoder).to(device)
     if weights is not None:
         model.load_state_dict(weights)
     return model
diff --git a/third_party/DeDoDe/DeDoDe/train.py b/third_party/DeDoDe/DeDoDe/train.py
index 348f268d6f7752bdf2ad45ba1851ec13a57825a0..2572e3a726d16ffef1bb734feeba0a7a19f4d354 100644
--- a/third_party/DeDoDe/DeDoDe/train.py
+++ b/third_party/DeDoDe/DeDoDe/train.py
@@ -3,7 +3,7 @@ from tqdm import tqdm
 from DeDoDe.utils import to_cuda
 
 
-def train_step(train_batch, model, objective, optimizer, grad_scaler = None,**kwargs):
+def train_step(train_batch, model, objective, optimizer, grad_scaler=None, **kwargs):
     optimizer.zero_grad()
     out = model(train_batch)
     l = objective(out, train_batch)
@@ -20,9 +20,17 @@ def train_step(train_batch, model, objective, optimizer, grad_scaler = None,**kw
 
 
 def train_k_steps(
-    n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler = None, progress_bar=True
+    n_0,
+    k,
+    dataloader,
+    model,
+    objective,
+    optimizer,
+    lr_scheduler,
+    grad_scaler=None,
+    progress_bar=True,
 ):
-    for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar, mininterval = 10.):
+    for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar, mininterval=10.0):
         batch = next(dataloader)
         model.train(True)
         batch = to_cuda(batch)
@@ -33,7 +41,7 @@ def train_k_steps(
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
             n=n,
-            grad_scaler = grad_scaler,
+            grad_scaler=grad_scaler,
         )
         lr_scheduler.step()
 
diff --git a/third_party/DeDoDe/DeDoDe/utils.py b/third_party/DeDoDe/DeDoDe/utils.py
index 183c35f5606301720adffa2b7b25e7996404e1a1..1076a06b98ac5ce74f847e75fff86d2a913f9348 100644
--- a/third_party/DeDoDe/DeDoDe/utils.py
+++ b/third_party/DeDoDe/DeDoDe/utils.py
@@ -11,13 +11,14 @@ from einops import rearrange
 import torch
 from time import perf_counter
 
+
 def recover_pose(E, kpts0, kpts1, K0, K1, mask):
     best_num_inliers = 0
-    K0inv = np.linalg.inv(K0[:2,:2])
-    K1inv = np.linalg.inv(K1[:2,:2])
+    K0inv = np.linalg.inv(K0[:2, :2])
+    K1inv = np.linalg.inv(K1[:2, :2])
 
-    kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
-    kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+    kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
+    kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
 
     for _E in np.split(E, len(E) / 3):
         n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
@@ -27,17 +28,16 @@ def recover_pose(E, kpts0, kpts1, K0, K1, mask):
     return ret
 
 
-
 # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
 # --- GEOMETRY ---
 def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
     if len(kpts0) < 5:
         return None
-    K0inv = np.linalg.inv(K0[:2,:2])
-    K1inv = np.linalg.inv(K1[:2,:2])
+    K0inv = np.linalg.inv(K0[:2, :2])
+    K1inv = np.linalg.inv(K1[:2, :2])
 
-    kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
-    kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+    kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
+    kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
     E, mask = cv2.findEssentialMat(
         kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
     )
@@ -54,150 +54,213 @@ def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
     return ret
 
 
-def get_grid(B,H,W, device = "cuda"):
+def get_grid(B, H, W, device="cuda"):
     x1_n = torch.meshgrid(
-    *[
-        torch.linspace(
-            -1 + 1 / n, 1 - 1 / n, n, device=device
-        )
-        for n in (B, H, W)
-    ]
+        *[torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) for n in (B, H, W)]
     )
     x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
     return x1_n
 
+
 @torch.no_grad()
-def finite_diff_hessian(f: tuple(["B", "H", "W"]), device = "cuda"):
-    dxx = torch.tensor([[0,0,0],[1,-2,1],[0,0,0]], device = device)[None,None]/2
-    dxy = torch.tensor([[1,0,-1],[0,0,0],[-1,0,1]], device = device)[None,None]/4
+def finite_diff_hessian(f: tuple(["B", "H", "W"]), device="cuda"):
+    dxx = (
+        torch.tensor([[0, 0, 0], [1, -2, 1], [0, 0, 0]], device=device)[None, None] / 2
+    )
+    dxy = (
+        torch.tensor([[1, 0, -1], [0, 0, 0], [-1, 0, 1]], device=device)[None, None] / 4
+    )
     dyy = dxx.mT
-    Hxx = F.conv2d(f[:,None], dxx, padding = 1)[:,0]
-    Hxy = F.conv2d(f[:,None], dxy, padding = 1)[:,0]
-    Hyy = F.conv2d(f[:,None], dyy, padding = 1)[:,0]
-    H = torch.stack((Hxx, Hxy, Hxy, Hyy), dim = -1).reshape(*f.shape,2,2)
+    Hxx = F.conv2d(f[:, None], dxx, padding=1)[:, 0]
+    Hxy = F.conv2d(f[:, None], dxy, padding=1)[:, 0]
+    Hyy = F.conv2d(f[:, None], dyy, padding=1)[:, 0]
+    H = torch.stack((Hxx, Hxy, Hxy, Hyy), dim=-1).reshape(*f.shape, 2, 2)
     return H
 
-def finite_diff_grad(f: tuple(["B", "H", "W"]), device = "cuda"):
-    dx = torch.tensor([[0,0,0],[-1,0,1],[0,0,0]],device = device)[None,None]/2
+
+def finite_diff_grad(f: tuple(["B", "H", "W"]), device="cuda"):
+    dx = torch.tensor([[0, 0, 0], [-1, 0, 1], [0, 0, 0]], device=device)[None, None] / 2
     dy = dx.mT
-    gx = F.conv2d(f[:,None], dx, padding = 1)
-    gy = F.conv2d(f[:,None], dy, padding = 1)
-    g = torch.cat((gx, gy), dim = 1)
+    gx = F.conv2d(f[:, None], dx, padding=1)
+    gy = F.conv2d(f[:, None], dy, padding=1)
+    g = torch.cat((gx, gy), dim=1)
     return g
 
-def fast_inv_2x2(matrix: tuple[...,2,2], eps = 1e-10):
-    return 1/(torch.linalg.det(matrix)[...,None,None]+eps) * torch.stack((matrix[...,1,1],-matrix[...,0,1],
-                                                     -matrix[...,1,0],matrix[...,0,0]),dim=-1).reshape(*matrix.shape)
 
-def newton_step(f:tuple["B","H","W"], inds, device = "cuda"):
-    B,H,W = f.shape
-    Hess = finite_diff_hessian(f).reshape(B,H*W,2,2)
-    Hess = torch.gather(Hess, dim = 1, index = inds[...,None].expand(B,-1,2,2))
-    grad = finite_diff_grad(f).reshape(B,H*W,2)
-    grad = torch.gather(grad, dim = 1, index = inds)
-    Hessinv = fast_inv_2x2(Hess-torch.eye(2, device = device)[None,None])
-    step = (Hessinv @ grad[...,None])
-    return step[...,0]
+def fast_inv_2x2(matrix: tuple[..., 2, 2], eps=1e-10):
+    return (
+        1
+        / (torch.linalg.det(matrix)[..., None, None] + eps)
+        * torch.stack(
+            (
+                matrix[..., 1, 1],
+                -matrix[..., 0, 1],
+                -matrix[..., 1, 0],
+                matrix[..., 0, 0],
+            ),
+            dim=-1,
+        ).reshape(*matrix.shape)
+    )
+
+
+def newton_step(f: tuple["B", "H", "W"], inds, device="cuda"):
+    B, H, W = f.shape
+    Hess = finite_diff_hessian(f).reshape(B, H * W, 2, 2)
+    Hess = torch.gather(Hess, dim=1, index=inds[..., None].expand(B, -1, 2, 2))
+    grad = finite_diff_grad(f).reshape(B, H * W, 2)
+    grad = torch.gather(grad, dim=1, index=inds)
+    Hessinv = fast_inv_2x2(Hess - torch.eye(2, device=device)[None, None])
+    step = Hessinv @ grad[..., None]
+    return step[..., 0]
+
 
 @torch.no_grad()
-def sample_keypoints(scoremap, num_samples = 8192, device = "cuda", use_nms = True, 
-                     sample_topk = False, return_scoremap = False, sharpen = False, upsample = False,
-                     increase_coverage = False,):
-    #scoremap = scoremap**2
-    log_scoremap = (scoremap+1e-10).log()
+def sample_keypoints(
+    scoremap,
+    num_samples=8192,
+    device="cuda",
+    use_nms=True,
+    sample_topk=False,
+    return_scoremap=False,
+    sharpen=False,
+    upsample=False,
+    increase_coverage=False,
+):
+    # scoremap = scoremap**2
+    log_scoremap = (scoremap + 1e-10).log()
     if upsample:
-        log_scoremap = F.interpolate(log_scoremap[:,None], scale_factor = 3, mode = "bicubic", align_corners = False)[:,0]#.clamp(min = 0)
+        log_scoremap = F.interpolate(
+            log_scoremap[:, None], scale_factor=3, mode="bicubic", align_corners=False
+        )[
+            :, 0
+        ]  # .clamp(min = 0)
         scoremap = log_scoremap.exp()
-    B,H,W = scoremap.shape
+    B, H, W = scoremap.shape
     if increase_coverage:
-        weights = (-torch.linspace(-2, 2, steps = 51, device = device)**2).exp()[None,None]
+        weights = (-torch.linspace(-2, 2, steps=51, device=device) ** 2).exp()[
+            None, None
+        ]
         # 10000 is just some number for maybe numerical stability, who knows. :), result is invariant anyway
-        local_density_x = F.conv2d((scoremap[:,None]+1e-6)*10000,weights[...,None,:], padding = (0,51//2))
-        local_density = F.conv2d(local_density_x, weights[...,None], padding = (51//2,0))[:,0]
-        scoremap = scoremap * (local_density+1e-8)**(-1/2)
-    grid = get_grid(B,H,W, device=device).reshape(B,H*W,2)
+        local_density_x = F.conv2d(
+            (scoremap[:, None] + 1e-6) * 10000,
+            weights[..., None, :],
+            padding=(0, 51 // 2),
+        )
+        local_density = F.conv2d(
+            local_density_x, weights[..., None], padding=(51 // 2, 0)
+        )[:, 0]
+        scoremap = scoremap * (local_density + 1e-8) ** (-1 / 2)
+    grid = get_grid(B, H, W, device=device).reshape(B, H * W, 2)
     if sharpen:
-        laplace_operator = torch.tensor([[[[0,1,0],[1,-4,1],[0,1,0]]]], device = device)/4
-        scoremap = scoremap[:,None] - 0.5 * F.conv2d(scoremap[:,None], weight = laplace_operator, padding = 1)
-        scoremap = scoremap[:,0].clamp(min = 0)
+        laplace_operator = (
+            torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]], device=device) / 4
+        )
+        scoremap = scoremap[:, None] - 0.5 * F.conv2d(
+            scoremap[:, None], weight=laplace_operator, padding=1
+        )
+        scoremap = scoremap[:, 0].clamp(min=0)
     if use_nms:
-        scoremap = scoremap * (scoremap == F.max_pool2d(scoremap, (3, 3), stride = 1, padding = 1))
+        scoremap = scoremap * (
+            scoremap == F.max_pool2d(scoremap, (3, 3), stride=1, padding=1)
+        )
     if sample_topk:
-        inds = torch.topk(scoremap.reshape(B,H*W), k = num_samples).indices
+        inds = torch.topk(scoremap.reshape(B, H * W), k=num_samples).indices
     else:
-        inds = torch.multinomial(scoremap.reshape(B,H*W), num_samples = num_samples, replacement=False)
-    kps = torch.gather(grid, dim = 1, index = inds[...,None].expand(B,num_samples,2))
+        inds = torch.multinomial(
+            scoremap.reshape(B, H * W), num_samples=num_samples, replacement=False
+        )
+    kps = torch.gather(grid, dim=1, index=inds[..., None].expand(B, num_samples, 2))
     if return_scoremap:
-        return kps, torch.gather(scoremap.reshape(B,H*W), dim = 1, index = inds)
+        return kps, torch.gather(scoremap.reshape(B, H * W), dim=1, index=inds)
     return kps
 
+
 @torch.no_grad()
-def jacobi_determinant(warp, certainty, R = 3, device = "cuda", dtype = torch.float32):
+def jacobi_determinant(warp, certainty, R=3, device="cuda", dtype=torch.float32):
     t = perf_counter()
     *dims, _ = warp.shape
     warp = warp.to(dtype)
     certainty = certainty.to(dtype)
-    
+
     dtype = warp.dtype
-    match_regions = torch.zeros((*dims, 4, R, R), device = device).to(dtype)
-    match_regions[:,1:-1, 1:-1] = warp.unfold(1,R,1).unfold(2,R,1)
-    match_regions = rearrange(match_regions,"B H W D R1 R2 -> B H W (R1 R2) D") - warp[...,None,:]
-    
-    match_regions_cert = torch.zeros((*dims, R, R), device = device).to(dtype)
-    match_regions_cert[:,1:-1, 1:-1] = certainty.unfold(1,R,1).unfold(2,R,1)
-    match_regions_cert = rearrange(match_regions_cert,"B H W R1 R2 -> B H W (R1 R2)")[..., None]
-
-    #print("Time for unfold", perf_counter()-t)
-    #t = perf_counter()
+    match_regions = torch.zeros((*dims, 4, R, R), device=device).to(dtype)
+    match_regions[:, 1:-1, 1:-1] = warp.unfold(1, R, 1).unfold(2, R, 1)
+    match_regions = (
+        rearrange(match_regions, "B H W D R1 R2 -> B H W (R1 R2) D")
+        - warp[..., None, :]
+    )
+
+    match_regions_cert = torch.zeros((*dims, R, R), device=device).to(dtype)
+    match_regions_cert[:, 1:-1, 1:-1] = certainty.unfold(1, R, 1).unfold(2, R, 1)
+    match_regions_cert = rearrange(match_regions_cert, "B H W R1 R2 -> B H W (R1 R2)")[
+        ..., None
+    ]
+
+    # print("Time for unfold", perf_counter()-t)
+    # t = perf_counter()
     *dims, N, D = match_regions.shape
     # standardize:
-    mu, sigma = match_regions.mean(dim=(-2,-1), keepdim = True), match_regions.std(dim=(-2,-1),keepdim=True)
-    match_regions = (match_regions-mu)/(sigma+1e-6)
-    x_a, x_b = match_regions.chunk(2,-1)
-    
+    mu, sigma = match_regions.mean(dim=(-2, -1), keepdim=True), match_regions.std(
+        dim=(-2, -1), keepdim=True
+    )
+    match_regions = (match_regions - mu) / (sigma + 1e-6)
+    x_a, x_b = match_regions.chunk(2, -1)
 
-    A = torch.zeros((*dims,2*x_a.shape[-2],4), device = device).to(dtype)
-    A[...,::2,:2] = x_a * match_regions_cert
-    A[...,1::2,2:] = x_a * match_regions_cert
+    A = torch.zeros((*dims, 2 * x_a.shape[-2], 4), device=device).to(dtype)
+    A[..., ::2, :2] = x_a * match_regions_cert
+    A[..., 1::2, 2:] = x_a * match_regions_cert
 
-    a_block = A[...,::2,:2]
+    a_block = A[..., ::2, :2]
     ata = a_block.mT @ a_block
-    #print("Time for ata", perf_counter()-t)
-    #t = perf_counter()
+    # print("Time for ata", perf_counter()-t)
+    # t = perf_counter()
 
-    #atainv = torch.linalg.inv(ata+1e-5*torch.eye(2,device=device).to(dtype))
+    # atainv = torch.linalg.inv(ata+1e-5*torch.eye(2,device=device).to(dtype))
     atainv = fast_inv_2x2(ata)
-    ATA_inv = torch.zeros((*dims, 4, 4), device = device, dtype = dtype)
-    ATA_inv[...,:2,:2] = atainv
-    ATA_inv[...,2:,2:] = atainv
-    atb = A.mT @ (match_regions_cert*x_b).reshape(*dims,N*2,1)
-    theta =  ATA_inv @ atb
-    #print("Time for theta", perf_counter()-t)
-    #t = perf_counter()
+    ATA_inv = torch.zeros((*dims, 4, 4), device=device, dtype=dtype)
+    ATA_inv[..., :2, :2] = atainv
+    ATA_inv[..., 2:, 2:] = atainv
+    atb = A.mT @ (match_regions_cert * x_b).reshape(*dims, N * 2, 1)
+    theta = ATA_inv @ atb
+    # print("Time for theta", perf_counter()-t)
+    # t = perf_counter()
 
     J = theta.reshape(*dims, 2, 2)
-    abs_J_det = torch.linalg.det(J+1e-8*torch.eye(2,2,device = device).expand(*dims,2,2)).abs() # Note: This should always be positive for correct warps, but still taking abs here
-    abs_J_logdet = (abs_J_det+1e-12).log()
+    abs_J_det = torch.linalg.det(
+        J + 1e-8 * torch.eye(2, 2, device=device).expand(*dims, 2, 2)
+    ).abs()  # Note: This should always be positive for correct warps, but still taking abs here
+    abs_J_logdet = (abs_J_det + 1e-12).log()
     B = certainty.shape[0]
     # Handle outliers
-    robust_abs_J_logdet = abs_J_logdet.clamp(-3, 3) # Shouldn't be more that exp(3) \approx 8 times zoom
-    #print("Time for logdet", perf_counter()-t)
-    #t = perf_counter()
+    robust_abs_J_logdet = abs_J_logdet.clamp(
+        -3, 3
+    )  # Shouldn't be more than exp(3) \approx 8 times zoom
+    # print("Time for logdet", perf_counter()-t)
+    # t = perf_counter()
 
     return robust_abs_J_logdet
 
-def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
-    
+
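+# Build a dense ground-truth warp between two views from depth maps, relative pose
+# and intrinsics: returns a (B, H, W, 4) grid stacking normalized coordinates in
+# image 1 with their warped positions in image 2, plus a (B, H, W) covisibility /
+# depth-consistency probability produced by warp_kpts below.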
+def get_gt_warp(
+    depth1,
+    depth2,
+    T_1to2,
+    K1,
+    K2,
+    depth_interpolation_mode="bilinear",
+    relative_depth_error_threshold=0.05,
+    H=None,
+    W=None,
+):
+
     if H is None:
-        B,H,W = depth1.shape
+        B, H, W = depth1.shape
     else:
         B = depth1.shape[0]
     with torch.no_grad():
         x1_n = torch.meshgrid(
             *[
-                torch.linspace(
-                    -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
-                )
+                torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=depth1.device)
                 for n in (B, H, W)
             ]
         )
@@ -209,20 +272,21 @@ def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bili
             T_1to2.double(),
             K1.double(),
             K2.double(),
-            depth_interpolation_mode = depth_interpolation_mode,
-            relative_depth_error_threshold = relative_depth_error_threshold,
+            depth_interpolation_mode=depth_interpolation_mode,
+            relative_depth_error_threshold=relative_depth_error_threshold,
         )
         prob = mask.float().reshape(B, H, W)
         x2 = x2.reshape(B, H, W, 2)
-        return torch.cat((x1_n.reshape(B,H,W,2),x2),dim=-1), prob
+        return torch.cat((x1_n.reshape(B, H, W, 2), x2), dim=-1), prob
+
 
 def recover_pose(E, kpts0, kpts1, K0, K1, mask):
     best_num_inliers = 0
-    K0inv = np.linalg.inv(K0[:2,:2])
-    K1inv = np.linalg.inv(K1[:2,:2])
+    K0inv = np.linalg.inv(K0[:2, :2])
+    K1inv = np.linalg.inv(K1[:2, :2])
 
-    kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
-    kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+    kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
+    kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
 
     for _E in np.split(E, len(E) / 3):
         n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
@@ -232,17 +296,23 @@ def recover_pose(E, kpts0, kpts1, K0, K1, mask):
     return ret
 
 
-
 # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
 # --- GEOMETRY ---
-def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999, ):
+def estimate_pose(
+    kpts0,
+    kpts1,
+    K0,
+    K1,
+    norm_thresh,
+    conf=0.99999,
+):
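+    # Normalize the correspondences with the intrinsics, fit an essential matrix
+    # with USAC, then keep the (R, t, inlier mask) decomposition with most inliers.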
     if len(kpts0) < 5:
         return None
-    K0inv = np.linalg.inv(K0[:2,:2])
-    K1inv = np.linalg.inv(K1[:2,:2])
+    K0inv = np.linalg.inv(K0[:2, :2])
+    K1inv = np.linalg.inv(K1[:2, :2])
 
-    kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
-    kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+    kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
+    kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
     method = cv2.USAC_ACCURATE
     E, mask = cv2.findEssentialMat(
         kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=method
@@ -259,31 +329,40 @@ def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999, ):
                 ret = (R, t, mask.ravel() > 0)
     return ret
 
+
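+# Variant of estimate_pose for approximate intrinsics: fit a fundamental matrix
+# with USAC, lift it to an essential matrix via E = K1^T F K0, and keep the pose
+# decomposition with the highest inlier count.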
 def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
     if len(kpts0) < 5:
         return None
     method = cv2.USAC_ACCURATE
     F, mask = cv2.findFundamentalMat(
-        kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000
+        kpts0,
+        kpts1,
+        ransacReprojThreshold=norm_thresh,
+        confidence=conf,
+        method=method,
+        maxIters=10000,
     )
-    E = K1.T@F@K0
+    E = K1.T @ F @ K0
     ret = None
     if E is not None:
         best_num_inliers = 0
-        K0inv = np.linalg.inv(K0[:2,:2])
-        K1inv = np.linalg.inv(K1[:2,:2])
+        K0inv = np.linalg.inv(K0[:2, :2])
+        K1inv = np.linalg.inv(K1[:2, :2])
+
+        kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
+        kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
 
-        kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
-        kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
- 
         for _E in np.split(E, len(E) / 3):
-            n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
+            n, R, t, _ = cv2.recoverPose(
+                _E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask
+            )
             if n > best_num_inliers:
                 best_num_inliers = n
                 ret = (R, t, mask.ravel() > 0)
     return ret
 
-def unnormalize_coords(x_n,h,w):
+
+def unnormalize_coords(x_n, h, w):
     x = torch.stack(
         (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
     )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
@@ -316,6 +395,7 @@ def scale_intrinsics(K, scales):
     scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
     return np.dot(scales, K)
 
+
 def angle_error_mat(R1, R2):
     cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
     cos = np.clip(cos, -1.0, 1.0)  # numercial errors can make it out of bounds
@@ -355,14 +435,16 @@ def pose_auc(errors, thresholds):
 def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
     ops = []
     if resize:
-        ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR, antialias = False))
+        ops.append(
+            TupleResize(resize, mode=InterpolationMode.BILINEAR, antialias=False)
+        )
     return TupleCompose(ops)
 
 
-def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False):
+def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe=False):
     ops = []
     if resize:
-        ops.append(TupleResize(resize, antialias = True))
+        ops.append(TupleResize(resize, antialias=True))
     if clahe:
         ops.append(TupleClahe())
     if normalize:
@@ -377,22 +459,27 @@ def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe =
             ops.append(TupleToTensorScaled())
     return TupleCompose(ops)
 
+
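+# CLAHE (contrast-limited adaptive histogram equalization) applied to the V channel
+# of the image in HSV space; TupleClahe below applies it to every image in a tuple.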
 class Clahe:
-    def __init__(self, cliplimit = 2, blocksize = 8) -> None:
-        self.clahe = cv2.createCLAHE(cliplimit,(blocksize,blocksize))
+    def __init__(self, cliplimit=2, blocksize=8) -> None:
+        self.clahe = cv2.createCLAHE(cliplimit, (blocksize, blocksize))
+
     def __call__(self, im):
-        im_hsv = cv2.cvtColor(np.array(im),cv2.COLOR_RGB2HSV)
-        im_v = self.clahe.apply(im_hsv[:,:,2])
-        im_hsv[...,2] = im_v
-        im_clahe = cv2.cvtColor(im_hsv,cv2.COLOR_HSV2RGB)
+        im_hsv = cv2.cvtColor(np.array(im), cv2.COLOR_RGB2HSV)
+        im_v = self.clahe.apply(im_hsv[:, :, 2])
+        im_hsv[..., 2] = im_v
+        im_clahe = cv2.cvtColor(im_hsv, cv2.COLOR_HSV2RGB)
         return Image.fromarray(im_clahe)
 
+
 class TupleClahe:
-    def __init__(self, cliplimit = 8, blocksize = 8) -> None:
-        self.clahe = Clahe(cliplimit,blocksize)
+    def __init__(self, cliplimit=8, blocksize=8) -> None:
+        self.clahe = Clahe(cliplimit, blocksize)
+
     def __call__(self, ims):
         return [self.clahe(im) for im in ims]
 
+
 class ToTensorScaled(object):
     """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
 
@@ -443,9 +530,9 @@ class TupleToTensorUnscaled(object):
 
 
 class TupleResize(object):
-    def __init__(self, size, mode=InterpolationMode.BICUBIC, antialias = None):
+    def __init__(self, size, mode=InterpolationMode.BICUBIC, antialias=None):
         self.size = size
-        self.resize = transforms.Resize(size, mode, antialias = antialias)
+        self.resize = transforms.Resize(size, mode, antialias=antialias)
 
     def __call__(self, im_tuple):
         return [self.resize(im) for im in im_tuple]
@@ -453,11 +540,12 @@ class TupleResize(object):
     def __repr__(self):
         return "TupleResize(size={})".format(self.size)
 
+
 class Normalize:
-    def __call__(self,im):
-        mean = im.mean(dim=(1,2), keepdims=True)
-        std = im.std(dim=(1,2), keepdims=True)
-        return (im-mean)/std
+    def __call__(self, im):
+        mean = im.mean(dim=(1, 2), keepdims=True)
+        std = im.std(dim=(1, 2), keepdims=True)
+        return (im - mean) / std
 
 
 class TupleNormalize(object):
@@ -467,7 +555,7 @@ class TupleNormalize(object):
         self.normalize = transforms.Normalize(mean=mean, std=std)
 
     def __call__(self, im_tuple):
-        c,h,w = im_tuple[0].shape
+        c, h, w = im_tuple[0].shape
         if c > 3:
             warnings.warn(f"Number of channels {c=} > 3, assuming first 3 are rgb")
         return [self.normalize(im[:3]) for im in im_tuple]
@@ -495,7 +583,18 @@ class TupleCompose(object):
 
 
 @torch.no_grad()
-def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
+def warp_kpts(
+    kpts0,
+    depth0,
+    depth1,
+    T_0to1,
+    K0,
+    K1,
+    smooth_mask=False,
+    return_relative_depth_error=False,
+    depth_interpolation_mode="bilinear",
+    relative_depth_error_threshold=0.05,
+):
     """Warp kpts0 from I0 to I1 with depth, K and Rt
     Also check covisibility and depth consistency.
     Depth is consistent if relative error < 0.2 (hard-coded).
@@ -520,26 +619,44 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return
         # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
         if smooth_mask:
             raise NotImplementedError("Combined bilinear and NN warp not implemented")
-        valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, 
-                  smooth_mask = smooth_mask, 
-                  return_relative_depth_error = return_relative_depth_error, 
-                  depth_interpolation_mode = "bilinear",
-                  relative_depth_error_threshold = relative_depth_error_threshold)
-        valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, 
-                  smooth_mask = smooth_mask, 
-                  return_relative_depth_error = return_relative_depth_error, 
-                  depth_interpolation_mode = "nearest-exact",
-                  relative_depth_error_threshold = relative_depth_error_threshold)
-        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) 
+        valid_bilinear, warp_bilinear = warp_kpts(
+            kpts0,
+            depth0,
+            depth1,
+            T_0to1,
+            K0,
+            K1,
+            smooth_mask=smooth_mask,
+            return_relative_depth_error=return_relative_depth_error,
+            depth_interpolation_mode="bilinear",
+            relative_depth_error_threshold=relative_depth_error_threshold,
+        )
+        valid_nearest, warp_nearest = warp_kpts(
+            kpts0,
+            depth0,
+            depth1,
+            T_0to1,
+            K0,
+            K1,
+            smooth_mask=smooth_mask,
+            return_relative_depth_error=return_relative_depth_error,
+            depth_interpolation_mode="nearest-exact",
+            relative_depth_error_threshold=relative_depth_error_threshold,
+        )
+        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
         warp = warp_bilinear.clone()
-        warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
+        warp[nearest_valid_bilinear_invalid] = warp_nearest[
+            nearest_valid_bilinear_invalid
+        ]
         valid = valid_bilinear | valid_nearest
         return valid, warp
-        
-        
-    kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
-        :, 0, :, 0
-    ]
+
+    kpts0_depth = F.grid_sample(
+        depth0[:, None],
+        kpts0[:, :, None],
+        mode=depth_interpolation_mode,
+        align_corners=False,
+    )[:, 0, :, 0]
     kpts0 = torch.stack(
         (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
     )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
@@ -578,22 +695,26 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return
     # w_kpts0[~covisible_mask, :] = -5 # xd
 
     w_kpts0_depth = F.grid_sample(
-        depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
+        depth1[:, None],
+        w_kpts0[:, :, None],
+        mode=depth_interpolation_mode,
+        align_corners=False,
     )[:, 0, :, 0]
-    
+
     relative_depth_error = (
         (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
     ).abs()
     if not smooth_mask:
         consistent_mask = relative_depth_error < relative_depth_error_threshold
     else:
-        consistent_mask = (-relative_depth_error/smooth_mask).exp()
+        consistent_mask = (-relative_depth_error / smooth_mask).exp()
     valid_mask = nonzero_mask * covisible_mask * consistent_mask
     if return_relative_depth_error:
         return relative_depth_error, w_kpts0
     else:
         return valid_mask, w_kpts0
 
+
 imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
 imagenet_std = torch.tensor([0.229, 0.224, 0.225])
 
@@ -611,15 +732,17 @@ def numpy_to_pil(x: np.ndarray):
     return Image.fromarray(x)
 
 
-def tensor_to_pil(x, unnormalize=False, autoscale = False):
+def tensor_to_pil(x, unnormalize=False, autoscale=False):
     if unnormalize:
-        x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device))
+        x = x * (imagenet_std[:, None, None].to(x.device)) + (
+            imagenet_mean[:, None, None].to(x.device)
+        )
     if autoscale:
         if x.max() == x.min():
             warnings.warn("x max == x min, cant autoscale")
         else:
-            x = (x-x.min())/(x.max()-x.min())
-        
+            x = (x - x.min()) / (x.max() - x.min())
+
     x = x.detach().permute(1, 2, 0).cpu().numpy()
     x = np.clip(x, 0.0, 1.0)
     return numpy_to_pil(x)
@@ -649,61 +772,57 @@ def compute_relative_pose(R1, t1, R2, t2):
     trans = -rots @ t1 + t2
     return rots, trans
 
+
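+# Convert between the normalized [-1, 1] coordinate convention used internally and
+# pixel coordinates of an (h1, w1) image; to_normalized_coords is the inverse, and
+# warp_to_pixel_coords below applies the same mapping to both halves of a stacked warp.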
 def to_pixel_coords(flow, h1, w1):
-    flow = (
-        torch.stack(
-            (
-                w1 * (flow[..., 0] + 1) / 2,
-                h1 * (flow[..., 1] + 1) / 2,
-            ),
-            axis=-1,
-        )
+    flow = torch.stack(
+        (
+            w1 * (flow[..., 0] + 1) / 2,
+            h1 * (flow[..., 1] + 1) / 2,
+        ),
+        axis=-1,
     )
     return flow
 
+
 def to_normalized_coords(flow, h1, w1):
-    flow = (
-        torch.stack(
-            (
-                2 * (flow[..., 0]) / w1 - 1,
-                2 * (flow[..., 1]) / h1 - 1,
-            ),
-            axis=-1,
-        )
+    flow = torch.stack(
+        (
+            2 * (flow[..., 0]) / w1 - 1,
+            2 * (flow[..., 1]) / h1 - 1,
+        ),
+        axis=-1,
     )
     return flow
 
 
 def warp_to_pixel_coords(warp, h1, w1, h2, w2):
     warp1 = warp[..., :2]
-    warp1 = (
-        torch.stack(
-            (
-                w1 * (warp1[..., 0] + 1) / 2,
-                h1 * (warp1[..., 1] + 1) / 2,
-            ),
-            axis=-1,
-        )
+    warp1 = torch.stack(
+        (
+            w1 * (warp1[..., 0] + 1) / 2,
+            h1 * (warp1[..., 1] + 1) / 2,
+        ),
+        axis=-1,
     )
     warp2 = warp[..., 2:]
-    warp2 = (
-        torch.stack(
-            (
-                w2 * (warp2[..., 0] + 1) / 2,
-                h2 * (warp2[..., 1] + 1) / 2,
-            ),
-            axis=-1,
-        )
+    warp2 = torch.stack(
+        (
+            w2 * (warp2[..., 0] + 1) / 2,
+            h2 * (warp2[..., 1] + 1) / 2,
+        ),
+        axis=-1,
     )
-    return torch.cat((warp1,warp2), dim=-1)
+    return torch.cat((warp1, warp2), dim=-1)
 
 
 def to_homogeneous(x):
-    ones = torch.ones_like(x[...,-1:])
-    return torch.cat((x, ones), dim = -1)
+    ones = torch.ones_like(x[..., -1:])
+    return torch.cat((x, ones), dim=-1)
+
+
+def from_homogeneous(xh, eps=1e-12):
+    return xh[..., :-1] / (xh[..., -1:] + eps)
 
-def from_homogeneous(xh, eps = 1e-12):
-    return xh[...,:-1] / (xh[...,-1:]+eps)
 
 def homog_transform(Homog, x):
     xh = to_homogeneous(x)
@@ -711,49 +830,71 @@ def homog_transform(Homog, x):
     y = from_homogeneous(yh)
     return y
 
-def get_homog_warp(Homog, H, W, device = "cuda"):
-    grid = torch.meshgrid(torch.linspace(-1+1/H,1-1/H,H, device = device), torch.linspace(-1+1/W,1-1/W,W, device = device))
-    
-    x_A = torch.stack((grid[1], grid[0]), dim = -1)[None]
+
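+# Build a dense warp from a homography: sample a normalized pixel grid over image A,
+# map it with homog_transform, and mask out points that fall outside [-1, 1] in image B.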
+def get_homog_warp(Homog, H, W, device="cuda"):
+    grid = torch.meshgrid(
+        torch.linspace(-1 + 1 / H, 1 - 1 / H, H, device=device),
+        torch.linspace(-1 + 1 / W, 1 - 1 / W, W, device=device),
+    )
+
+    x_A = torch.stack((grid[1], grid[0]), dim=-1)[None]
     x_A_to_B = homog_transform(Homog, x_A)
     mask = ((x_A_to_B > -1) * (x_A_to_B < 1)).prod(dim=-1).float()
-    return torch.cat((x_A.expand(*x_A_to_B.shape), x_A_to_B),dim=-1), mask
+    return torch.cat((x_A.expand(*x_A_to_B.shape), x_A_to_B), dim=-1), mask
 
-def dual_log_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False):
+
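+# Log-domain dual softmax: logP = log_softmax(corr, dim=-2) + log_softmax(corr, dim=-1),
+# i.e. the log of the mutual matching probability between descriptor sets A and B.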
+def dual_log_softmax_matcher(
+    desc_A: tuple["B", "N", "C"],
+    desc_B: tuple["B", "M", "C"],
+    inv_temperature=1,
+    normalize=False,
+):
     B, N, C = desc_A.shape
     if normalize:
-        desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True)
-        desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True)
+        desc_A = desc_A / desc_A.norm(dim=-1, keepdim=True)
+        desc_B = desc_B / desc_B.norm(dim=-1, keepdim=True)
         corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
     else:
         corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
-    logP = corr.log_softmax(dim = -2) + corr.log_softmax(dim= -1)
+    logP = corr.log_softmax(dim=-2) + corr.log_softmax(dim=-1)
     return logP
 
-def dual_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False):
+
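+# Mutual (dual) softmax matcher: P = softmax(corr, dim=-2) * softmax(corr, dim=-1),
+# so a pair scores highly only if each descriptor prefers the other.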
+def dual_softmax_matcher(
+    desc_A: tuple["B", "N", "C"],
+    desc_B: tuple["B", "M", "C"],
+    inv_temperature=1,
+    normalize=False,
+):
     if len(desc_A.shape) < 3:
         desc_A, desc_B = desc_A[None], desc_B[None]
     B, N, C = desc_A.shape
     if normalize:
-        desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True)
-        desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True)
+        desc_A = desc_A / desc_A.norm(dim=-1, keepdim=True)
+        desc_B = desc_B / desc_B.norm(dim=-1, keepdim=True)
         corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
     else:
         corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
-    P = corr.softmax(dim = -2) * corr.softmax(dim= -1)
+    P = corr.softmax(dim=-2) * corr.softmax(dim=-1)
     return P
 
-def conditional_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False):
+
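+# Conditional variant: returns the two one-directional softmax assignments
+# P(A | B) and P(B | A) separately instead of their product.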
+def conditional_softmax_matcher(
+    desc_A: tuple["B", "N", "C"],
+    desc_B: tuple["B", "M", "C"],
+    inv_temperature=1,
+    normalize=False,
+):
     if len(desc_A.shape) < 3:
         desc_A, desc_B = desc_A[None], desc_B[None]
     B, N, C = desc_A.shape
     if normalize:
-        desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True)
-        desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True)
+        desc_A = desc_A / desc_A.norm(dim=-1, keepdim=True)
+        desc_B = desc_B / desc_B.norm(dim=-1, keepdim=True)
         corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
     else:
         corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
-    P_B_cond_A = corr.softmax(dim = -1)
-    P_A_cond_B = corr.softmax(dim = -2)
-    
-    return P_A_cond_B, P_B_cond_A 
\ No newline at end of file
+    P_B_cond_A = corr.softmax(dim=-1)
+    P_A_cond_B = corr.softmax(dim=-2)
+
+    return P_A_cond_B, P_B_cond_A
diff --git a/third_party/DeDoDe/data_prep/prep_keypoints.py b/third_party/DeDoDe/data_prep/prep_keypoints.py
index 25713ed7573babadc3a42daa544d85052fc37421..616f91b875879f726218efdfe4bb6dc95297b33a 100644
--- a/third_party/DeDoDe/data_prep/prep_keypoints.py
+++ b/third_party/DeDoDe/data_prep/prep_keypoints.py
@@ -9,70 +9,64 @@ import os
 
 base_path = "data/megadepth"
 # Remove the trailing / if need be.
-if base_path[-1] in ['/', '\\']:
-    base_path = base_path[: - 1]
+if base_path[-1] in ["/", "\\"]:
+    base_path = base_path[:-1]
 
 
-base_depth_path = os.path.join(
-    base_path, 'phoenix/S6/zl548/MegaDepth_v1'
-)
-base_undistorted_sfm_path = os.path.join(
-    base_path, 'Undistorted_SfM'
-)
+base_depth_path = os.path.join(base_path, "phoenix/S6/zl548/MegaDepth_v1")
+base_undistorted_sfm_path = os.path.join(base_path, "Undistorted_SfM")
 
 scene_ids = os.listdir(base_undistorted_sfm_path)
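+# For every MegaDepth scene, parse the COLMAP text model (cameras.txt, points3D.txt,
+# images.txt) and cache the 2D detections visible in each image to a .npy file,
+# skipping scenes that have already been processed.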
 for scene_id in scene_ids:
-    if os.path.exists(f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy"):
+    if os.path.exists(
+        f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy"
+    ):
         print(f"skipping {scene_id} as it exists")
         continue
     undistorted_sparse_path = os.path.join(
-        base_undistorted_sfm_path, scene_id, 'sparse-txt'
+        base_undistorted_sfm_path, scene_id, "sparse-txt"
     )
     if not os.path.exists(undistorted_sparse_path):
         print("sparse path doesnt exist")
         continue
 
-    depths_path = os.path.join(
-        base_depth_path, scene_id, 'dense0', 'depths'
-    )
+    depths_path = os.path.join(base_depth_path, scene_id, "dense0", "depths")
     if not os.path.exists(depths_path):
         print("depths doesnt exist")
-        
+
         continue
 
-    images_path = os.path.join(
-        base_undistorted_sfm_path, scene_id, 'images'
-    )
+    images_path = os.path.join(base_undistorted_sfm_path, scene_id, "images")
     if not os.path.exists(images_path):
         print("images path doesnt exist")
         continue
 
     # Process cameras.txt
-    if not os.path.exists(os.path.join(undistorted_sparse_path, 'cameras.txt')):
+    if not os.path.exists(os.path.join(undistorted_sparse_path, "cameras.txt")):
         print("no cameras")
         continue
-    with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f:
-        raw = f.readlines()[3 :]  # skip the header
+    with open(os.path.join(undistorted_sparse_path, "cameras.txt"), "r") as f:
+        raw = f.readlines()[3:]  # skip the header
 
     camera_intrinsics = {}
     for camera in raw:
-        camera = camera.split(' ')
-        camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]]
+        camera = camera.split(" ")
+        camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2:]]
 
     # Process points3D.txt
-    with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f:
-        raw = f.readlines()[3 :]  # skip the header
+    with open(os.path.join(undistorted_sparse_path, "points3D.txt"), "r") as f:
+        raw = f.readlines()[3:]  # skip the header
 
     points3D = {}
     for point3D in raw:
-        point3D = point3D.split(' ')
-        points3D[int(point3D[0])] = np.array([
-            float(point3D[1]), float(point3D[2]), float(point3D[3])
-        ])
-        
+        point3D = point3D.split(" ")
+        points3D[int(point3D[0])] = np.array(
+            [float(point3D[1]), float(point3D[2]), float(point3D[3])]
+        )
+
     # Process images.txt
-    with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f:
-        raw = f.readlines()[4 :]  # skip the header
+    with open(os.path.join(undistorted_sparse_path, "images.txt"), "r") as f:
+        raw = f.readlines()[4:]  # skip the header
 
     image_id_to_idx = {}
     image_names = []
@@ -81,20 +75,22 @@ for scene_id in scene_ids:
     points3D_id_to_2D = []
     n_points3D = []
     id_to_detections = {}
-    for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])):
-        image = image.split(' ')
-        points = points.split(' ')
+    for idx, (image, points) in enumerate(zip(raw[::2], raw[1::2])):
+        image = image.split(" ")
+        points = points.split(" ")
 
         image_id_to_idx[int(image[0])] = idx
 
-        image_name = image[-1].strip('\n')
+        image_name = image[-1].strip("\n")
         image_names.append(image_name)
 
-        raw_pose.append([float(elem) for elem in image[1 : -2]])
+        raw_pose.append([float(elem) for elem in image[1:-2]])
         camera.append(int(image[-2]))
-        points_np = np.array(points).astype(np.float32).reshape(len(points)//3, 3)
-        visible_points = points_np[points_np[:,2] != -1]
+        points_np = np.array(points).astype(np.float32).reshape(len(points) // 3, 3)
+        visible_points = points_np[points_np[:, 2] != -1]
         id_to_detections[idx] = visible_points
-    np.save(f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy",
-            id_to_detections)
-    print(f"{scene_id} done")
\ No newline at end of file
+    np.save(
+        f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy",
+        id_to_detections,
+    )
+    print(f"{scene_id} done")
diff --git a/third_party/DeDoDe/demo/demo_kpts.py b/third_party/DeDoDe/demo/demo_kpts.py
index 270a23b12e2148ce7a438a68ab3ef1135a93a9e6..f0ae36aa4bbe3439e96d7b45bfa809c48b6ebf45 100644
--- a/third_party/DeDoDe/demo/demo_kpts.py
+++ b/third_party/DeDoDe/demo/demo_kpts.py
@@ -4,17 +4,19 @@ import numpy as np
 from PIL import Image
 from DeDoDe import dedode_detector_L
 
-def draw_kpts(im, kpts):    
-    kpts = [cv2.KeyPoint(x,y,1.) for x,y in kpts.cpu().numpy()]
+
+def draw_kpts(im, kpts):
+    kpts = [cv2.KeyPoint(x, y, 1.0) for x, y in kpts.cpu().numpy()]
     im = np.array(im)
     ret = cv2.drawKeypoints(im, kpts, None)
     return ret
 
-detector = dedode_detector_L(weights = torch.load("dedode_detector_l.pth"))
+
+detector = dedode_detector_L(weights=torch.load("dedode_detector_l.pth"))
 im_path = "assets/im_A.jpg"
 im = Image.open(im_path)
-out = detector.detect_from_path(im_path, num_keypoints = 10_000)
-W,H = im.size
+out = detector.detect_from_path(im_path, num_keypoints=10_000)
+W, H = im.size
 kps = out["keypoints"]
 kps = detector.to_pixel_coords(kps, H, W)
-Image.fromarray(draw_kpts(im, kps[0])).save("demo/keypoints.png")
\ No newline at end of file
+Image.fromarray(draw_kpts(im, kps[0])).save("demo/keypoints.png")
diff --git a/third_party/DeDoDe/demo/demo_match.py b/third_party/DeDoDe/demo/demo_match.py
index 6492392d07a49fcdb7e287b619b404df84521ca8..2ddecc453e1e3d0beb5e832819833209ad431048 100644
--- a/third_party/DeDoDe/demo/demo_match.py
+++ b/third_party/DeDoDe/demo/demo_match.py
@@ -5,17 +5,18 @@ from DeDoDe.utils import *
 from PIL import Image
 import cv2
 
-def draw_matches(im_A, kpts_A, im_B, kpts_B):    
-    kpts_A = [cv2.KeyPoint(x,y,1.) for x,y in kpts_A.cpu().numpy()]
-    kpts_B = [cv2.KeyPoint(x,y,1.) for x,y in kpts_B.cpu().numpy()]
-    matches_A_to_B = [cv2.DMatch(idx, idx, 0.) for idx in range(len(kpts_A))]
+
+def draw_matches(im_A, kpts_A, im_B, kpts_B):
+    kpts_A = [cv2.KeyPoint(x, y, 1.0) for x, y in kpts_A.cpu().numpy()]
+    kpts_B = [cv2.KeyPoint(x, y, 1.0) for x, y in kpts_B.cpu().numpy()]
+    matches_A_to_B = [cv2.DMatch(idx, idx, 0.0) for idx in range(len(kpts_A))]
     im_A, im_B = np.array(im_A), np.array(im_B)
-    ret = cv2.drawMatches(im_A, kpts_A, im_B, kpts_B, 
-                    matches_A_to_B, None)
+    ret = cv2.drawMatches(im_A, kpts_A, im_B, kpts_B, matches_A_to_B, None)
     return ret
 
-detector = dedode_detector_L(weights = torch.load("dedode_detector_L.pth"))
-descriptor = dedode_descriptor_B(weights = torch.load("dedode_descriptor_B.pth"))
+
+detector = dedode_detector_L(weights=torch.load("dedode_detector_L.pth"))
+descriptor = dedode_descriptor_B(weights=torch.load("dedode_descriptor_B.pth"))
 matcher = DualSoftMaxMatcher()
 
 im_A_path = "assets/im_A.jpg"
@@ -26,20 +27,33 @@ W_A, H_A = im_A.size
 W_B, H_B = im_B.size
 
 
-detections_A = detector.detect_from_path(im_A_path, num_keypoints = 10_000)
+detections_A = detector.detect_from_path(im_A_path, num_keypoints=10_000)
 keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"]
-detections_B = detector.detect_from_path(im_B_path, num_keypoints = 10_000)
+detections_B = detector.detect_from_path(im_B_path, num_keypoints=10_000)
 keypoints_B, P_B = detections_B["keypoints"], detections_B["confidence"]
-description_A = descriptor.describe_keypoints_from_path(im_A_path, keypoints_A)["descriptions"]
-description_B = descriptor.describe_keypoints_from_path(im_B_path, keypoints_B)["descriptions"]
-matches_A, matches_B, batch_ids = matcher.match(keypoints_A, description_A,
-    keypoints_B, description_B,
-    P_A = P_A, P_B = P_B,
-    normalize = True, inv_temp=20, threshold = 0.1)#Increasing threshold -> fewer matches, fewer outliers
+description_A = descriptor.describe_keypoints_from_path(im_A_path, keypoints_A)[
+    "descriptions"
+]
+description_B = descriptor.describe_keypoints_from_path(im_B_path, keypoints_B)[
+    "descriptions"
+]
+matches_A, matches_B, batch_ids = matcher.match(
+    keypoints_A,
+    description_A,
+    keypoints_B,
+    description_B,
+    P_A=P_A,
+    P_B=P_B,
+    normalize=True,
+    inv_temp=20,
+    threshold=0.1,
+)  # Increasing threshold -> fewer matches, fewer outliers
 
 matches_A, matches_B = matcher.to_pixel_coords(matches_A, matches_B, H_A, W_A, H_B, W_B)
 
 import cv2
 import numpy as np
 
-Image.fromarray(draw_matches(im_A, matches_A[::5], im_B, matches_B[::5])).save("demo/matches.png")
\ No newline at end of file
+Image.fromarray(draw_matches(im_A, matches_A[::5], im_B, matches_B[::5])).save(
+    "demo/matches.png"
+)
diff --git a/third_party/DeDoDe/demo/demo_scoremap.py b/third_party/DeDoDe/demo/demo_scoremap.py
index 68af499dbb58e275e227bbdc979b4d1923902df0..1a0a2b2470783c69753960725aee1b689b0cb2cc 100644
--- a/third_party/DeDoDe/demo/demo_scoremap.py
+++ b/third_party/DeDoDe/demo/demo_scoremap.py
@@ -5,16 +5,20 @@ import numpy as np
 from DeDoDe import dedode_detector_L
 from DeDoDe.utils import tensor_to_pil
 
-detector = dedode_detector_L(weights = torch.load("dedode_detector_l.pth"))
+detector = dedode_detector_L(weights=torch.load("dedode_detector_l.pth"))
 H, W = 768, 768
 im_path = "assets/im_A.jpg"
 
-out = detector.detect_from_path(im_path, dense = True, H = H, W = W)
+out = detector.detect_from_path(im_path, dense=True, H=H, W=W)
 
 logit_map = out["dense_keypoint_logits"].clone()
 min = logit_map.max() - 3
 logit_map[logit_map < min] = min
-logit_map = (logit_map-min)/(logit_map.max()-min)
-logit_map = logit_map.cpu()[0].expand(3,H,W)
-im_A = torch.tensor(np.array(Image.open(im_path).resize((W,H)))/255.).permute(2,0,1)
-tensor_to_pil(logit_map * logit_map  +  0.15 * (1-logit_map) * im_A).save("demo/dense_logits.png")
+logit_map = (logit_map - min) / (logit_map.max() - min)
+logit_map = logit_map.cpu()[0].expand(3, H, W)
+im_A = torch.tensor(np.array(Image.open(im_path).resize((W, H))) / 255.0).permute(
+    2, 0, 1
+)
+tensor_to_pil(logit_map * logit_map + 0.15 * (1 - logit_map) * im_A).save(
+    "demo/dense_logits.png"
+)
diff --git a/third_party/DeDoDe/setup.py b/third_party/DeDoDe/setup.py
index 18a43e0b69131d3f91229f5a59e9b1d48411d890..94d1fd8ed2e5ac769222afce4f084ac19029a2a4 100644
--- a/third_party/DeDoDe/setup.py
+++ b/third_party/DeDoDe/setup.py
@@ -3,8 +3,8 @@ from setuptools import setup, find_packages
 
 setup(
     name="DeDoDe",
-    packages=find_packages(include= ["DeDoDe*"]),
+    packages=find_packages(include=["DeDoDe*"]),
     install_requires=open("requirements.txt", "r").read().split("\n"),
     version="0.0.1",
     author="Johan Edstedt",
-)
\ No newline at end of file
+)
diff --git a/third_party/GlueStick/gluestick/__init__.py b/third_party/GlueStick/gluestick/__init__.py
index d3051821ecfb2e18f4b9b4dfb50f35064106eb57..4eaf01e90440afeb485a4743f181dac348ede63d 100644
--- a/third_party/GlueStick/gluestick/__init__.py
+++ b/third_party/GlueStick/gluestick/__init__.py
@@ -8,11 +8,12 @@ GLUESTICK_ROOT = Path(__file__).parent.parent
 
 def get_class(mod_name, base_path, BaseClass):
     """Get the class object which inherits from BaseClass and is defined in
-       the module named mod_name, child of base_path.
+    the module named mod_name, child of base_path.
     """
     import inspect
-    mod_path = '{}.{}'.format(base_path, mod_name)
-    mod = __import__(mod_path, fromlist=[''])
+
+    mod_path = "{}.{}".format(base_path, mod_name)
+    mod = __import__(mod_path, fromlist=[""])
     classes = inspect.getmembers(mod, inspect.isclass)
     # Filter classes defined in the module
     classes = [c for c in classes if c[1].__module__ == mod_path]
@@ -24,7 +25,8 @@ def get_class(mod_name, base_path, BaseClass):
 
 def get_model(name):
     from .models.base_model import BaseModel
-    return get_class('models.' + name, __name__, BaseModel)
+
+    return get_class("models." + name, __name__, BaseModel)
 
 
 def numpy_image_to_torch(image):
@@ -34,8 +36,8 @@ def numpy_image_to_torch(image):
     elif image.ndim == 2:
         image = image[None]  # add channel axis
     else:
-        raise ValueError(f'Not an image: {image.shape}')
-    return torch.from_numpy(image / 255.).float()
+        raise ValueError(f"Not an image: {image.shape}")
+    return torch.from_numpy(image / 255.0).float()
 
 
 def map_tensor(input_, func):
diff --git a/third_party/GlueStick/gluestick/drawing.py b/third_party/GlueStick/gluestick/drawing.py
index 8e6d24b6bfedc93449142647410057d978d733ef..8365b7e1f91adedcd190c49b2a38cbcd817d84c2 100644
--- a/third_party/GlueStick/gluestick/drawing.py
+++ b/third_party/GlueStick/gluestick/drawing.py
@@ -4,8 +4,7 @@ import numpy as np
 import seaborn as sns
 
 
-def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
-                adaptive=True):
+def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
     """Plot a set of images horizontally.
     Args:
         imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
@@ -23,7 +22,8 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
         ratios = [4 / 3] * n
     figsize = [sum(ratios) * 4.5, 4.5]
     fig, ax = plt.subplots(
-        1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios})
+        1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
+    )
     if n == 1:
         ax = [ax]
     for i in range(n):
@@ -39,7 +39,7 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
     return ax
 
 
-def plot_keypoints(kpts, colors='lime', ps=4, alpha=1):
+def plot_keypoints(kpts, colors="lime", ps=4, alpha=1):
     """Plot keypoints for existing images.
     Args:
         kpts: list of ndarrays of size (N, 2).
@@ -53,7 +53,7 @@ def plot_keypoints(kpts, colors='lime', ps=4, alpha=1):
         a.scatter(k[:, 0], k[:, 1], c=c, s=ps, alpha=alpha, linewidths=0)
 
 
-def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
     """Plot matches for a pair of existing images.
     Args:
         kpts0, kpts1: corresponding keypoints of size (N, 2).
@@ -80,11 +80,18 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
         transFigure = fig.transFigure.inverted()
         fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
         fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
-        fig.lines += [matplotlib.lines.Line2D(
-            (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
-            zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
-            alpha=a)
-            for i in range(len(kpts0))]
+        fig.lines += [
+            matplotlib.lines.Line2D(
+                (fkpts0[i, 0], fkpts1[i, 0]),
+                (fkpts0[i, 1], fkpts1[i, 1]),
+                zorder=1,
+                transform=fig.transFigure,
+                c=color[i],
+                linewidth=lw,
+                alpha=a,
+            )
+            for i in range(len(kpts0))
+        ]
 
     # freeze the axes to prevent the transform to change
     ax0.autoscale(enable=False)
@@ -95,9 +102,16 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
         ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
 
 
-def plot_lines(lines, line_colors='orange', point_colors='cyan',
-               ps=4, lw=2, alpha=1., indices=(0, 1)):
-    """ Plot lines and endpoints for existing images.
+def plot_lines(
+    lines,
+    line_colors="orange",
+    point_colors="cyan",
+    ps=4,
+    lw=2,
+    alpha=1.0,
+    indices=(0, 1),
+):
+    """Plot lines and endpoints for existing images.
     Args:
         lines: list of ndarrays of size (N, 2, 2).
         colors: string, or list of list of tuples (one for each keypoints).
@@ -120,18 +134,20 @@ def plot_lines(lines, line_colors='orange', point_colors='cyan',
     # Plot the lines and junctions
     for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
         for i in range(len(l)):
-            line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]),
-                                           (l[i, 0, 1], l[i, 1, 1]),
-                                           zorder=1, c=lc, linewidth=lw,
-                                           alpha=alpha)
+            line = matplotlib.lines.Line2D(
+                (l[i, 0, 0], l[i, 1, 0]),
+                (l[i, 0, 1], l[i, 1, 1]),
+                zorder=1,
+                c=lc,
+                linewidth=lw,
+                alpha=alpha,
+            )
             a.add_line(line)
         pts = l.reshape(-1, 2)
-        a.scatter(pts[:, 0], pts[:, 1],
-                  c=pc, s=ps, linewidths=0, zorder=2, alpha=alpha)
+        a.scatter(pts[:, 0], pts[:, 1], c=pc, s=ps, linewidths=0, zorder=2, alpha=alpha)
 
 
-def plot_color_line_matches(lines, correct_matches=None,
-                            lw=2, indices=(0, 1)):
+def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
     """Plot line matches for existing images with multiple colors.
     Args:
         lines: list of ndarrays of size (N, 2, 2).
@@ -140,7 +156,7 @@ def plot_color_line_matches(lines, correct_matches=None,
         indices: indices of the images to draw the matches on.
     """
     n_lines = len(lines[0])
-    colors = sns.color_palette('husl', n_colors=n_lines)
+    colors = sns.color_palette("husl", n_colors=n_lines)
     np.random.shuffle(colors)
     alphas = np.ones(n_lines)
     # If correct_matches is not None, display wrong matches with a low alpha
@@ -159,8 +175,15 @@ def plot_color_line_matches(lines, correct_matches=None,
         transFigure = fig.transFigure.inverted()
         endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
         endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
-        fig.lines += [matplotlib.lines.Line2D(
-            (endpoint0[i, 0], endpoint1[i, 0]),
-            (endpoint0[i, 1], endpoint1[i, 1]),
-            zorder=1, transform=fig.transFigure, c=colors[i],
-            alpha=alphas[i], linewidth=lw) for i in range(n_lines)]
+        fig.lines += [
+            matplotlib.lines.Line2D(
+                (endpoint0[i, 0], endpoint1[i, 0]),
+                (endpoint0[i, 1], endpoint1[i, 1]),
+                zorder=1,
+                transform=fig.transFigure,
+                c=colors[i],
+                alpha=alphas[i],
+                linewidth=lw,
+            )
+            for i in range(n_lines)
+        ]
diff --git a/third_party/GlueStick/gluestick/geometry.py b/third_party/GlueStick/gluestick/geometry.py
index 97853c4807d319eb9ea0377db7385e9a72fb400b..0cdd232e74aeda84e1683dcb8e51385cc2497c37 100644
--- a/third_party/GlueStick/gluestick/geometry.py
+++ b/third_party/GlueStick/gluestick/geometry.py
@@ -21,7 +21,7 @@ def to_homogeneous(points):
         raise ValueError
 
 
-def from_homogeneous(points, eps=0.):
+def from_homogeneous(points, eps=0.0):
     """Remove the homogeneous dimension of N-dimensional points.
     Args:
         points: torch.Tensor or numpy.ndarray with size (..., N+1).
@@ -32,14 +32,22 @@ def from_homogeneous(points, eps=0.):
 
 
 def skew_symmetric(v):
-    """Create a skew-symmetric matrix from a (batched) vector of size (..., 3).
-    """
+    """Create a skew-symmetric matrix from a (batched) vector of size (..., 3)."""
     z = torch.zeros_like(v[..., 0])
-    M = torch.stack([
-        z, -v[..., 2], v[..., 1],
-        v[..., 2], z, -v[..., 0],
-        -v[..., 1], v[..., 0], z,
-    ], dim=-1).reshape(v.shape[:-1] + (3, 3))
+    M = torch.stack(
+        [
+            z,
+            -v[..., 2],
+            v[..., 1],
+            v[..., 2],
+            z,
+            -v[..., 0],
+            -v[..., 1],
+            v[..., 0],
+            z,
+        ],
+        dim=-1,
+    ).reshape(v.shape[:-1] + (3, 3))
     return M
 
 
@@ -67,7 +75,7 @@ def warp_points_torch(points, H, inverse=True):
     H_mat = torch.cat([H, torch.ones_like(H[..., :1])], axis=-1).reshape(out_shape)
     if inverse:
         H_mat = torch.inverse(H_mat)
-    warped_points = torch.einsum('...nj,...ji->...ni', points, H_mat.transpose(-2, -1))
+    warped_points = torch.einsum("...nj,...ji->...ni", points, H_mat.transpose(-2, -1))
 
     warped_points = from_homogeneous(warped_points, eps=1e-5)
 
@@ -76,18 +84,27 @@ def warp_points_torch(points, H, inverse=True):
 
 def seg_equation(segs):
     # calculate list of start, end and midpoints points from both lists
-    start_points, end_points = to_homogeneous(segs[..., 0, :]), to_homogeneous(segs[..., 1, :])
+    start_points, end_points = to_homogeneous(segs[..., 0, :]), to_homogeneous(
+        segs[..., 1, :]
+    )
     # Compute the line equations as ax + by + c = 0 , where x^2 + y^2 = 1
     lines = torch.cross(start_points, end_points, dim=-1)
-    lines_norm = (torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None])
-    assert torch.all(lines_norm > 0), 'Error: trying to compute the equation of a line with a single point'
+    lines_norm = torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None]
+    assert torch.all(
+        lines_norm > 0
+    ), "Error: trying to compute the equation of a line with a single point"
     lines = lines / lines_norm
     return lines
 
 
 def is_inside_img(pts: torch.Tensor, img_shape: Tuple[int, int]):
     h, w = img_shape
-    return (pts >= 0).all(dim=-1) & (pts[..., 0] < w) & (pts[..., 1] < h) & (~torch.isinf(pts).any(dim=-1))
+    return (
+        (pts >= 0).all(dim=-1)
+        & (pts[..., 0] < w)
+        & (pts[..., 1] < h)
+        & (~torch.isinf(pts).any(dim=-1))
+    )
 
 
 def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.Tensor:
@@ -102,7 +119,9 @@ def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.
     # Project the segments to the reference image
     segs = segs.clone()
     eqs = seg_equation(segs)
-    x0, y0 = torch.tensor([1., 0, 0.], device=device), torch.tensor([0., 1, 0], device=device)
+    x0, y0 = torch.tensor([1.0, 0, 0.0], device=device), torch.tensor(
+        [0.0, 1, 0], device=device
+    )
     x0 = x0.repeat(eqs.shape[:-1] + (1,))
     y0 = y0.repeat(eqs.shape[:-1] + (1,))
     pt_x0s = torch.cross(eqs, x0, dim=-1)
@@ -112,7 +131,9 @@ def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.
     pt_y0s = pt_y0s[..., :-1] / pt_y0s[..., None, -1]
     pt_y0s_valid = is_inside_img(pt_y0s, img_shape)
 
-    xW, yH = torch.tensor([1., 0, EPS - w], device=device), torch.tensor([0., 1, EPS - h], device=device)
+    xW, yH = torch.tensor([1.0, 0, EPS - w], device=device), torch.tensor(
+        [0.0, 1, EPS - h], device=device
+    )
     xW = xW.repeat(eqs.shape[:-1] + (1,))
     yH = yH.repeat(eqs.shape[:-1] + (1,))
     pt_xWs = torch.cross(eqs, xW, dim=-1)
@@ -143,11 +164,17 @@ def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.
     mask = (segs[..., 1, 1] > (h - 1)) & pt_yHs_valid
     segs[mask, 1, :] = pt_yHs[mask]
 
-    assert torch.all(segs >= 0) and torch.all(segs[..., 0] < w) and torch.all(segs[..., 1] < h)
+    assert (
+        torch.all(segs >= 0)
+        and torch.all(segs[..., 0] < w)
+        and torch.all(segs[..., 1] < h)
+    )
     return segs
 
 
-def warp_lines_torch(lines, H, inverse=True, dst_shape: Tuple[int, int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+def warp_lines_torch(
+    lines, H, inverse=True, dst_shape: Tuple[int, int] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     :param lines: A tensor of shape (B, N, 2, 2) where B is the batch size, N the number of lines.
     :param H: The homography used to convert the lines. batched or not (shapes (B, 8) and (8,) respectively).
@@ -156,12 +183,16 @@ def warp_lines_torch(lines, H, inverse=True, dst_shape: Tuple[int, int] = None)
     """
     device = lines.device
     batch_size, n = lines.shape[:2]
-    lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(lines.shape)
+    lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(
+        lines.shape
+    )
 
     if dst_shape is None:
         return lines, torch.ones(lines.shape[:-2], dtype=torch.bool, device=device)
 
-    out_img = torch.any((lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1)
+    out_img = torch.any(
+        (lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1
+    )
     valid = ~out_img.all(-1)
     any_out_of_img = out_img.any(-1)
     lines_to_trim = valid & any_out_of_img
diff --git a/third_party/GlueStick/gluestick/models/base_model.py b/third_party/GlueStick/gluestick/models/base_model.py
index 30ca991655a28ca88074b42312c33b360f655fab..ef326bbb9e7deb78ee59d7cf9b2a76a5234106b4 100644
--- a/third_party/GlueStick/gluestick/models/base_model.py
+++ b/third_party/GlueStick/gluestick/models/base_model.py
@@ -13,7 +13,7 @@ class MetaModel(ABCMeta):
     def __prepare__(name, bases, **kwds):
         total_conf = OmegaConf.create()
         for base in bases:
-            for key in ('base_default_conf', 'default_conf'):
+            for key in ("base_default_conf", "default_conf"):
                 update = getattr(base, key, {})
                 if isinstance(update, dict):
                     update = OmegaConf.create(update)
@@ -49,10 +49,11 @@ class BaseModel(nn.Module, metaclass=MetaModel):
         metrics(self, pred, data): method that returns a dictionary of metrics,
         each as a batch of scalars.
     """
+
     default_conf = {
-        'name': None,
-        'trainable': True,  # if false: do not optimize this model parameters
-        'freeze_batch_normalization': False,  # use test-time statistics
+        "name": None,
+        "trainable": True,  # if false: do not optimize this model parameters
+        "freeze_batch_normalization": False,  # use test-time statistics
     }
     required_data_keys = []
     strict_conf = True
@@ -61,15 +62,16 @@ class BaseModel(nn.Module, metaclass=MetaModel):
         """Perform some logic and call the _init method of the child model."""
         super().__init__()
         default_conf = OmegaConf.merge(
-                self.base_default_conf, OmegaConf.create(self.default_conf))
+            self.base_default_conf, OmegaConf.create(self.default_conf)
+        )
         if self.strict_conf:
             OmegaConf.set_struct(default_conf, True)
 
         # fixme: backward compatibility
-        if 'pad' in conf and 'pad' not in default_conf:  # backward compat.
+        if "pad" in conf and "pad" not in default_conf:  # backward compat.
             with omegaconf.read_write(conf):
                 with omegaconf.open_dict(conf):
-                    conf['interpolation'] = {'pad': conf.pop('pad')}
+                    conf["interpolation"] = {"pad": conf.pop("pad")}
 
         if isinstance(conf, dict):
             conf = OmegaConf.create(conf)
@@ -89,6 +91,7 @@ class BaseModel(nn.Module, metaclass=MetaModel):
         def freeze_bn(module):
             if isinstance(module, nn.modules.batchnorm._BatchNorm):
                 module.eval()
+
         if self.conf.freeze_batch_normalization:
             self.apply(freeze_bn)
 
@@ -96,9 +99,10 @@ class BaseModel(nn.Module, metaclass=MetaModel):
 
     def forward(self, data):
         """Check the data and call the _forward method of the child model."""
+
         def recursive_key_check(expected, given):
             for key in expected:
-                assert key in given, f'Missing key {key} in data'
+                assert key in given, f"Missing key {key} in data"
                 if isinstance(expected, dict):
                     recursive_key_check(expected[key], given[key])
 
diff --git a/third_party/GlueStick/gluestick/models/gluestick.py b/third_party/GlueStick/gluestick/models/gluestick.py
index c2a6c477eebecc2c43feea007f99c2115aa7c216..8179f8ff779401f535260b930a3f5e4d957af614 100644
--- a/third_party/GlueStick/gluestick/models/gluestick.py
+++ b/third_party/GlueStick/gluestick/models/gluestick.py
@@ -12,139 +12,178 @@ ETH_EPS = 1e-8
 
 class GlueStick(BaseModel):
     default_conf = {
-        'input_dim': 256,
-        'descriptor_dim': 256,
-        'bottleneck_dim': None,
-        'weights': None,
-        'keypoint_encoder': [32, 64, 128, 256],
-        'GNN_layers': ['self', 'cross'] * 9,
-        'num_line_iterations': 1,
-        'line_attention': False,
-        'filter_threshold': 0.2,
-        'checkpointed': False,
-        'skip_init': False,
-        'inter_supervision': None,
-        'loss': {
-            'nll_weight': 1.,
-            'nll_balancing': 0.5,
-            'reward_weight': 0.,
-            'bottleneck_l2_weight': 0.,
-            'dense_nll_weight': 0.,
-            'inter_supervision': [0.3, 0.6],
+        "input_dim": 256,
+        "descriptor_dim": 256,
+        "bottleneck_dim": None,
+        "weights": None,
+        "keypoint_encoder": [32, 64, 128, 256],
+        "GNN_layers": ["self", "cross"] * 9,
+        "num_line_iterations": 1,
+        "line_attention": False,
+        "filter_threshold": 0.2,
+        "checkpointed": False,
+        "skip_init": False,
+        "inter_supervision": None,
+        "loss": {
+            "nll_weight": 1.0,
+            "nll_balancing": 0.5,
+            "reward_weight": 0.0,
+            "bottleneck_l2_weight": 0.0,
+            "dense_nll_weight": 0.0,
+            "inter_supervision": [0.3, 0.6],
         },
     }
     required_data_keys = [
-        'keypoints0', 'keypoints1',
-        'descriptors0', 'descriptors1',
-        'keypoint_scores0', 'keypoint_scores1']
-
-    DEFAULT_LOSS_CONF = {'nll_weight': 1., 'nll_balancing': 0.5, 'reward_weight': 0., 'bottleneck_l2_weight': 0.}
+        "keypoints0",
+        "keypoints1",
+        "descriptors0",
+        "descriptors1",
+        "keypoint_scores0",
+        "keypoint_scores1",
+    ]
+
+    DEFAULT_LOSS_CONF = {
+        "nll_weight": 1.0,
+        "nll_balancing": 0.5,
+        "reward_weight": 0.0,
+        "bottleneck_l2_weight": 0.0,
+    }
 
     def _init(self, conf):
         if conf.bottleneck_dim is not None:
             self.bottleneck_down = nn.Conv1d(
-                conf.input_dim, conf.bottleneck_dim, kernel_size=1)
+                conf.input_dim, conf.bottleneck_dim, kernel_size=1
+            )
             self.bottleneck_up = nn.Conv1d(
-                conf.bottleneck_dim, conf.input_dim, kernel_size=1)
+                conf.bottleneck_dim, conf.input_dim, kernel_size=1
+            )
             nn.init.constant_(self.bottleneck_down.bias, 0.0)
             nn.init.constant_(self.bottleneck_up.bias, 0.0)
 
         if conf.input_dim != conf.descriptor_dim:
             self.input_proj = nn.Conv1d(
-                conf.input_dim, conf.descriptor_dim, kernel_size=1)
+                conf.input_dim, conf.descriptor_dim, kernel_size=1
+            )
             nn.init.constant_(self.input_proj.bias, 0.0)
 
-        self.kenc = KeypointEncoder(conf.descriptor_dim,
-                                    conf.keypoint_encoder)
+        self.kenc = KeypointEncoder(conf.descriptor_dim, conf.keypoint_encoder)
         self.lenc = EndPtEncoder(conf.descriptor_dim, conf.keypoint_encoder)
-        self.gnn = AttentionalGNN(conf.descriptor_dim, conf.GNN_layers,
-                                  checkpointed=conf.checkpointed,
-                                  inter_supervision=conf.inter_supervision,
-                                  num_line_iterations=conf.num_line_iterations,
-                                  line_attention=conf.line_attention)
-        self.final_proj = nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim,
-                                    kernel_size=1)
+        self.gnn = AttentionalGNN(
+            conf.descriptor_dim,
+            conf.GNN_layers,
+            checkpointed=conf.checkpointed,
+            inter_supervision=conf.inter_supervision,
+            num_line_iterations=conf.num_line_iterations,
+            line_attention=conf.line_attention,
+        )
+        self.final_proj = nn.Conv1d(
+            conf.descriptor_dim, conf.descriptor_dim, kernel_size=1
+        )
         nn.init.constant_(self.final_proj.bias, 0.0)
         nn.init.orthogonal_(self.final_proj.weight, gain=1)
         self.final_line_proj = nn.Conv1d(
-            conf.descriptor_dim, conf.descriptor_dim, kernel_size=1)
+            conf.descriptor_dim, conf.descriptor_dim, kernel_size=1
+        )
         nn.init.constant_(self.final_line_proj.bias, 0.0)
         nn.init.orthogonal_(self.final_line_proj.weight, gain=1)
         if conf.inter_supervision is not None:
             self.inter_line_proj = nn.ModuleList(
-                [nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, kernel_size=1)
-                 for _ in conf.inter_supervision])
+                [
+                    nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, kernel_size=1)
+                    for _ in conf.inter_supervision
+                ]
+            )
             self.layer2idx = {}
             for i, l in enumerate(conf.inter_supervision):
                 nn.init.constant_(self.inter_line_proj[i].bias, 0.0)
                 nn.init.orthogonal_(self.inter_line_proj[i].weight, gain=1)
                 self.layer2idx[l] = i
 
-        bin_score = torch.nn.Parameter(torch.tensor(1.))
-        self.register_parameter('bin_score', bin_score)
-        line_bin_score = torch.nn.Parameter(torch.tensor(1.))
-        self.register_parameter('line_bin_score', line_bin_score)
+        bin_score = torch.nn.Parameter(torch.tensor(1.0))
+        self.register_parameter("bin_score", bin_score)
+        line_bin_score = torch.nn.Parameter(torch.tensor(1.0))
+        self.register_parameter("line_bin_score", line_bin_score)
 
         if conf.weights:
             assert isinstance(conf.weights, str)
-            state_dict = torch.load(conf.weights, map_location='cpu')
-            if 'model' in state_dict:
-                state_dict = {k.replace('matcher.', ''): v for k, v in state_dict['model'].items() if 'matcher.' in k}
-                state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
+            state_dict = torch.load(conf.weights, map_location="cpu")
+            if "model" in state_dict:
+                state_dict = {
+                    k.replace("matcher.", ""): v
+                    for k, v in state_dict["model"].items()
+                    if "matcher." in k
+                }
+                state_dict = {
+                    k.replace("module.", ""): v for k, v in state_dict.items()
+                }
             self.load_state_dict(state_dict)
 
     def _forward(self, data):
-        device = data['keypoints0'].device
-        b_size = len(data['keypoints0'])
-        image_size0 = (data['image_size0'] if 'image_size0' in data
-                       else data['image0'].shape)
-        image_size1 = (data['image_size1'] if 'image_size1' in data
-                       else data['image1'].shape)
+        device = data["keypoints0"].device
+        b_size = len(data["keypoints0"])
+        image_size0 = (
+            data["image_size0"] if "image_size0" in data else data["image0"].shape
+        )
+        image_size1 = (
+            data["image_size1"] if "image_size1" in data else data["image1"].shape
+        )
 
         pred = {}
-        desc0, desc1 = data['descriptors0'], data['descriptors1']
-        kpts0, kpts1 = data['keypoints0'], data['keypoints1']
+        desc0, desc1 = data["descriptors0"], data["descriptors1"]
+        kpts0, kpts1 = data["keypoints0"], data["keypoints1"]
 
         n_kpts0, n_kpts1 = kpts0.shape[1], kpts1.shape[1]
-        n_lines0, n_lines1 = data['lines0'].shape[1], data['lines1'].shape[1]
+        n_lines0, n_lines1 = data["lines0"].shape[1], data["lines1"].shape[1]
         if n_kpts0 == 0 or n_kpts1 == 0:
             # No detected keypoints nor lines
-            pred['log_assignment'] = torch.zeros(
-                b_size, n_kpts0, n_kpts1, dtype=torch.float, device=device)
-            pred['matches0'] = torch.full(
-                (b_size, n_kpts0), -1, device=device, dtype=torch.int64)
-            pred['matches1'] = torch.full(
-                (b_size, n_kpts1), -1, device=device, dtype=torch.int64)
-            pred['match_scores0'] = torch.zeros(
-                (b_size, n_kpts0), device=device, dtype=torch.float32)
-            pred['match_scores1'] = torch.zeros(
-                (b_size, n_kpts1), device=device, dtype=torch.float32)
-            pred['line_log_assignment'] = torch.zeros(b_size, n_lines0, n_lines1,
-                                                      dtype=torch.float, device=device)
-            pred['line_matches0'] = torch.full((b_size, n_lines0), -1,
-                                               device=device, dtype=torch.int64)
-            pred['line_matches1'] = torch.full((b_size, n_lines1), -1,
-                                               device=device, dtype=torch.int64)
-            pred['line_match_scores0'] = torch.zeros(
-                (b_size, n_lines0), device=device, dtype=torch.float32)
-            pred['line_match_scores1'] = torch.zeros(
-                (b_size, n_kpts1), device=device, dtype=torch.float32)
+            pred["log_assignment"] = torch.zeros(
+                b_size, n_kpts0, n_kpts1, dtype=torch.float, device=device
+            )
+            pred["matches0"] = torch.full(
+                (b_size, n_kpts0), -1, device=device, dtype=torch.int64
+            )
+            pred["matches1"] = torch.full(
+                (b_size, n_kpts1), -1, device=device, dtype=torch.int64
+            )
+            pred["match_scores0"] = torch.zeros(
+                (b_size, n_kpts0), device=device, dtype=torch.float32
+            )
+            pred["match_scores1"] = torch.zeros(
+                (b_size, n_kpts1), device=device, dtype=torch.float32
+            )
+            pred["line_log_assignment"] = torch.zeros(
+                b_size, n_lines0, n_lines1, dtype=torch.float, device=device
+            )
+            pred["line_matches0"] = torch.full(
+                (b_size, n_lines0), -1, device=device, dtype=torch.int64
+            )
+            pred["line_matches1"] = torch.full(
+                (b_size, n_lines1), -1, device=device, dtype=torch.int64
+            )
+            pred["line_match_scores0"] = torch.zeros(
+                (b_size, n_lines0), device=device, dtype=torch.float32
+            )
+            pred["line_match_scores1"] = torch.zeros(
+                (b_size, n_lines1), device=device, dtype=torch.float32
+            )
             return pred
 
-        lines0 = data['lines0'].flatten(1, 2)
-        lines1 = data['lines1'].flatten(1, 2)
-        lines_junc_idx0 = data['lines_junc_idx0'].flatten(1, 2)  # [b_size, num_lines * 2]
-        lines_junc_idx1 = data['lines_junc_idx1'].flatten(1, 2)
+        lines0 = data["lines0"].flatten(1, 2)
+        lines1 = data["lines1"].flatten(1, 2)
+        # lines_junc_idx: [b_size, num_lines * 2]
+        lines_junc_idx0 = data["lines_junc_idx0"].flatten(1, 2)
+        lines_junc_idx1 = data["lines_junc_idx1"].flatten(1, 2)
 
         if self.conf.bottleneck_dim is not None:
-            pred['down_descriptors0'] = desc0 = self.bottleneck_down(desc0)
-            pred['down_descriptors1'] = desc1 = self.bottleneck_down(desc1)
+            pred["down_descriptors0"] = desc0 = self.bottleneck_down(desc0)
+            pred["down_descriptors1"] = desc1 = self.bottleneck_down(desc1)
             desc0 = self.bottleneck_up(desc0)
             desc1 = self.bottleneck_up(desc1)
             desc0 = nn.functional.normalize(desc0, p=2, dim=1)
             desc1 = nn.functional.normalize(desc1, p=2, dim=1)
-            pred['bottleneck_descriptors0'] = desc0
-            pred['bottleneck_descriptors1'] = desc1
+            pred["bottleneck_descriptors0"] = desc0
+            pred["bottleneck_descriptors1"] = desc1
             if self.conf.loss.nll_weight == 0:
                 desc0 = desc0.detach()
                 desc1 = desc1.detach()
@@ -158,79 +197,113 @@ class GlueStick(BaseModel):
 
         assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
         assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
-        desc0 = desc0 + self.kenc(kpts0, data['keypoint_scores0'])
-        desc1 = desc1 + self.kenc(kpts1, data['keypoint_scores1'])
+        desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
+        desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])
 
         if n_lines0 != 0 and n_lines1 != 0:
             # Pre-compute the line encodings
             lines0 = normalize_keypoints(lines0, image_size0).reshape(
-                b_size, n_lines0, 2, 2)
+                b_size, n_lines0, 2, 2
+            )
             lines1 = normalize_keypoints(lines1, image_size1).reshape(
-                b_size, n_lines1, 2, 2)
-            line_enc0 = self.lenc(lines0, data['line_scores0'])
-            line_enc1 = self.lenc(lines1, data['line_scores1'])
+                b_size, n_lines1, 2, 2
+            )
+            line_enc0 = self.lenc(lines0, data["line_scores0"])
+            line_enc1 = self.lenc(lines1, data["line_scores1"])
         else:
             line_enc0 = torch.zeros(
-                b_size, self.conf.descriptor_dim, n_lines0 * 2,
-                dtype=torch.float, device=device)
+                b_size,
+                self.conf.descriptor_dim,
+                n_lines0 * 2,
+                dtype=torch.float,
+                device=device,
+            )
             line_enc1 = torch.zeros(
-                b_size, self.conf.descriptor_dim, n_lines1 * 2,
-                dtype=torch.float, device=device)
+                b_size,
+                self.conf.descriptor_dim,
+                n_lines1 * 2,
+                dtype=torch.float,
+                device=device,
+            )
 
-        desc0, desc1 = self.gnn(desc0, desc1, line_enc0, line_enc1,
-                                lines_junc_idx0, lines_junc_idx1)
+        desc0, desc1 = self.gnn(
+            desc0, desc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
+        )
 
         # Match all points (KP and line junctions)
         mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
 
-        kp_scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
-        kp_scores = kp_scores / self.conf.descriptor_dim ** .5
+        kp_scores = torch.einsum("bdn,bdm->bnm", mdesc0, mdesc1)
+        kp_scores = kp_scores / self.conf.descriptor_dim**0.5
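+        # Turn the similarity matrix into log-assignment scores via a double
+        # softmax with a learnable bin score that absorbs unmatched keypoints.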
         kp_scores = log_double_softmax(kp_scores, self.bin_score)
         m0, m1, mscores0, mscores1 = self._get_matches(kp_scores)
-        pred['log_assignment'] = kp_scores
-        pred['matches0'] = m0
-        pred['matches1'] = m1
-        pred['match_scores0'] = mscores0
-        pred['match_scores1'] = mscores1
+        pred["log_assignment"] = kp_scores
+        pred["matches0"] = m0
+        pred["matches1"] = m1
+        pred["match_scores0"] = mscores0
+        pred["match_scores1"] = mscores1
 
         # Match the lines
         if n_lines0 > 0 and n_lines1 > 0:
-            (line_scores, m0_lines, m1_lines, mscores0_lines,
-             mscores1_lines, raw_line_scores) = self._get_line_matches(
-                desc0[:, :, :2 * n_lines0], desc1[:, :, :2 * n_lines1],
-                lines_junc_idx0, lines_junc_idx1, self.final_line_proj)
+            (
+                line_scores,
+                m0_lines,
+                m1_lines,
+                mscores0_lines,
+                mscores1_lines,
+                raw_line_scores,
+            ) = self._get_line_matches(
+                desc0[:, :, : 2 * n_lines0],
+                desc1[:, :, : 2 * n_lines1],
+                lines_junc_idx0,
+                lines_junc_idx1,
+                self.final_line_proj,
+            )
             if self.conf.inter_supervision:
                 for l in self.conf.inter_supervision:
-                    (line_scores_i, m0_lines_i, m1_lines_i, mscores0_lines_i,
-                     mscores1_lines_i) = self._get_line_matches(
-                        self.gnn.inter_layers[l][0][:, :, :2 * n_lines0],
-                        self.gnn.inter_layers[l][1][:, :, :2 * n_lines1],
-                        lines_junc_idx0, lines_junc_idx1,
-                        self.inter_line_proj[self.layer2idx[l]])
-                    pred[f'line_{l}_log_assignment'] = line_scores_i
-                    pred[f'line_{l}_matches0'] = m0_lines_i
-                    pred[f'line_{l}_matches1'] = m1_lines_i
-                    pred[f'line_{l}_match_scores0'] = mscores0_lines_i
-                    pred[f'line_{l}_match_scores1'] = mscores1_lines_i
+                    (
+                        line_scores_i,
+                        m0_lines_i,
+                        m1_lines_i,
+                        mscores0_lines_i,
+                        mscores1_lines_i,
+                        _,  # raw line scores, unused for intermediate supervision
+                    ) = self._get_line_matches(
+                        self.gnn.inter_layers[l][0][:, :, : 2 * n_lines0],
+                        self.gnn.inter_layers[l][1][:, :, : 2 * n_lines1],
+                        lines_junc_idx0,
+                        lines_junc_idx1,
+                        self.inter_line_proj[self.layer2idx[l]],
+                    )
+                    pred[f"line_{l}_log_assignment"] = line_scores_i
+                    pred[f"line_{l}_matches0"] = m0_lines_i
+                    pred[f"line_{l}_matches1"] = m1_lines_i
+                    pred[f"line_{l}_match_scores0"] = mscores0_lines_i
+                    pred[f"line_{l}_match_scores1"] = mscores1_lines_i
         else:
-            line_scores = torch.zeros(b_size, n_lines0, n_lines1,
-                                      dtype=torch.float, device=device)
-            m0_lines = torch.full((b_size, n_lines0), -1,
-                                  device=device, dtype=torch.int64)
-            m1_lines = torch.full((b_size, n_lines1), -1,
-                                  device=device, dtype=torch.int64)
+            line_scores = torch.zeros(
+                b_size, n_lines0, n_lines1, dtype=torch.float, device=device
+            )
+            m0_lines = torch.full(
+                (b_size, n_lines0), -1, device=device, dtype=torch.int64
+            )
+            m1_lines = torch.full(
+                (b_size, n_lines1), -1, device=device, dtype=torch.int64
+            )
             mscores0_lines = torch.zeros(
-                (b_size, n_lines0), device=device, dtype=torch.float32)
+                (b_size, n_lines0), device=device, dtype=torch.float32
+            )
             mscores1_lines = torch.zeros(
-                (b_size, n_lines1), device=device, dtype=torch.float32)
-            raw_line_scores = torch.zeros(b_size, n_lines0, n_lines1,
-                                          dtype=torch.float, device=device)
-        pred['line_log_assignment'] = line_scores
-        pred['line_matches0'] = m0_lines
-        pred['line_matches1'] = m1_lines
-        pred['line_match_scores0'] = mscores0_lines
-        pred['line_match_scores1'] = mscores1_lines
-        pred['raw_line_scores'] = raw_line_scores
+                (b_size, n_lines1), device=device, dtype=torch.float32
+            )
+            raw_line_scores = torch.zeros(
+                b_size, n_lines0, n_lines1, dtype=torch.float, device=device
+            )
+        pred["line_log_assignment"] = line_scores
+        pred["line_matches0"] = m0_lines
+        pred["line_matches1"] = m1_lines
+        pred["line_match_scores0"] = mscores0_lines
+        pred["line_match_scores1"] = mscores1_lines
+        pred["raw_line_scores"] = raw_line_scores
 
         return pred
 
@@ -249,35 +322,47 @@ class GlueStick(BaseModel):
         m1 = torch.where(valid1, m1, m1.new_tensor(-1))
         return m0, m1, mscores0, mscores1
 
-    def _get_line_matches(self, ldesc0, ldesc1, lines_junc_idx0,
-                          lines_junc_idx1, final_proj):
+    def _get_line_matches(
+        self, ldesc0, ldesc1, lines_junc_idx0, lines_junc_idx1, final_proj
+    ):
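+        # Score lines through the projected descriptors of their two endpoints,
+        # then run the same double-softmax assignment as for keypoints.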
         mldesc0 = final_proj(ldesc0)
         mldesc1 = final_proj(ldesc1)
 
-        line_scores = torch.einsum('bdn,bdm->bnm', mldesc0, mldesc1)
-        line_scores = line_scores / self.conf.descriptor_dim ** .5
+        line_scores = torch.einsum("bdn,bdm->bnm", mldesc0, mldesc1)
+        line_scores = line_scores / self.conf.descriptor_dim**0.5
 
         # Get the line representation from the junction descriptors
         n2_lines0 = lines_junc_idx0.shape[1]
         n2_lines1 = lines_junc_idx1.shape[1]
         line_scores = torch.gather(
-            line_scores, dim=2,
-            index=lines_junc_idx1[:, None, :].repeat(1, line_scores.shape[1], 1))
+            line_scores,
+            dim=2,
+            index=lines_junc_idx1[:, None, :].repeat(1, line_scores.shape[1], 1),
+        )
         line_scores = torch.gather(
-            line_scores, dim=1,
-            index=lines_junc_idx0[:, :, None].repeat(1, 1, n2_lines1))
-        line_scores = line_scores.reshape((-1, n2_lines0 // 2, 2,
-                                           n2_lines1 // 2, 2))
+            line_scores,
+            dim=1,
+            index=lines_junc_idx0[:, :, None].repeat(1, 1, n2_lines1),
+        )
+        line_scores = line_scores.reshape((-1, n2_lines0 // 2, 2, n2_lines1 // 2, 2))
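+        # Shape [b, n_lines0, 2, n_lines1, 2]: one score per endpoint pair, so
+        # both endpoint orderings of a line can be compared below.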
 
         # Match either in one direction or the other
         raw_line_scores = 0.5 * torch.maximum(
             line_scores[:, :, 0, :, 0] + line_scores[:, :, 1, :, 1],
-            line_scores[:, :, 0, :, 1] + line_scores[:, :, 1, :, 0])
+            line_scores[:, :, 0, :, 1] + line_scores[:, :, 1, :, 0],
+        )
         line_scores = log_double_softmax(raw_line_scores, self.line_bin_score)
         m0_lines, m1_lines, mscores0_lines, mscores1_lines = self._get_matches(
-            line_scores)
-        return (line_scores, m0_lines, m1_lines, mscores0_lines,
-                mscores1_lines, raw_line_scores)
+            line_scores
+        )
+        return (
+            line_scores,
+            m0_lines,
+            m1_lines,
+            mscores0_lines,
+            mscores1_lines,
+            raw_line_scores,
+        )
 
     def loss(self, pred, data):
         raise NotImplementedError()
@@ -290,8 +375,7 @@ def MLP(channels, do_bn=True):
     n = len(channels)
     layers = []
     for i in range(1, n):
-        layers.append(
-            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
+        layers.append(nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
         if i < (n - 1):
             if do_bn:
                 layers.append(nn.BatchNorm1d(channels[i]))
@@ -338,17 +422,20 @@ class EndPtEncoder(nn.Module):
         endpt_offset = (endpoints[:, :, 1] - endpoints[:, :, 0]).unsqueeze(2)
         endpt_offset = torch.cat([endpt_offset, -endpt_offset], dim=2)
         endpt_offset = endpt_offset.reshape(b_size, 2 * n_pts, 2).transpose(1, 2)
-        inputs = [endpoints.flatten(1, 2).transpose(1, 2),
-                  endpt_offset, scores.repeat(1, 2).unsqueeze(1)]
+        inputs = [
+            endpoints.flatten(1, 2).transpose(1, 2),
+            endpt_offset,
+            scores.repeat(1, 2).unsqueeze(1),
+        ]
         return self.encoder(torch.cat(inputs, dim=1))
 
 
 @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
 def attention(query, key, value):
     dim = query.shape[1]
-    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5
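+    # Scaled dot-product attention: per-head similarities, softmax over source
+    # positions, then a weighted sum of the values.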
+    scores = torch.einsum("bdhn,bdhm->bhnm", query, key) / dim**0.5
     prob = torch.nn.functional.softmax(scores, dim=-1)
-    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
+    return torch.einsum("bhnm,bdhm->bdhn", prob, value), prob
 
 
 class MultiHeadedAttention(nn.Module):
@@ -363,8 +450,10 @@ class MultiHeadedAttention(nn.Module):
 
     def forward(self, query, key, value):
         b = query.size(0)
-        query, key, value = [l(x).view(b, self.dim, self.h, -1)
-                             for l, x in zip(self.proj, (query, key, value))]
+        query, key, value = [
+            l(x).view(b, self.dim, self.h, -1)
+            for l, x in zip(self.proj, (query, key, value))
+        ]
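+        # Project q/k/v and split the channel dimension into self.h heads.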
         x, prob = attention(query, key, value)
         # self.prob.append(prob.mean(dim=1))
         return self.merge(x.contiguous().view(b, self.dim * self.h, -1))
@@ -377,9 +466,9 @@ class AttentionalPropagation(nn.Module):
         self.mlp = MLP([num_dim * 2, num_dim * 2, num_dim], do_bn=True)
         nn.init.constant_(self.mlp[-1].bias, 0.0)
         if skip_init:
-            self.register_parameter('scaling', nn.Parameter(torch.tensor(0.)))
+            self.register_parameter("scaling", nn.Parameter(torch.tensor(0.0)))
         else:
-            self.scaling = 1.
+            self.scaling = 1.0
 
     def forward(self, x, source):
         message = self.attn(x, source, source)
@@ -389,14 +478,14 @@ class AttentionalPropagation(nn.Module):
 class GNNLayer(nn.Module):
     def __init__(self, feature_dim, layer_type, skip_init):
         super().__init__()
-        assert layer_type in ['cross', 'self']
+        assert layer_type in ["cross", "self"]
         self.type = layer_type
         self.update = AttentionalPropagation(feature_dim, 4, skip_init)
 
     def forward(self, desc0, desc1):
-        if self.type == 'cross':
+        if self.type == "cross":
             src0, src1 = desc1, desc0
-        elif self.type == 'self':
+        elif self.type == "self":
             src0, src1 = desc0, desc1
         else:
             raise ValueError("Unknown layer type: " + self.type)
@@ -422,11 +511,19 @@ class LineLayer(nn.Module):
         # Create one message per line endpoint
         b_size = lines_junc_idx.shape[0]
         line_desc = torch.gather(
-            ldesc, 2, lines_junc_idx[:, None].repeat(1, self.dim, 1))
-        message = torch.cat([
-            line_desc,
-            line_desc.reshape(b_size, self.dim, -1, 2).flip([-1]).flatten(2, 3).clone(),
-            line_enc], dim=1)
+            ldesc, 2, lines_junc_idx[:, None].repeat(1, self.dim, 1)
+        )
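+        # Each endpoint message concatenates its own descriptor, the descriptor
+        # of the other endpoint of the same line, and the line encoding.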
+        message = torch.cat(
+            [
+                line_desc,
+                line_desc.reshape(b_size, self.dim, -1, 2)
+                .flip([-1])
+                .flatten(2, 3)
+                .clone(),
+                line_enc,
+            ],
+            dim=1,
+        )
         return self.mlp(message)  # [b_size, D, n_lines * 2]
 
     def get_endpoint_attention(self, ldesc, line_enc, lines_junc_idx):
@@ -442,22 +539,32 @@ class LineLayer(nn.Module):
 
         # Key: combination of neighboring desc and line encodings
         line_desc = torch.gather(ldesc, 2, expanded_lines_junc_idx)
-        key = self.proj_neigh(torch.cat([
-            line_desc.reshape(b_size, self.dim, -1, 2).flip([-1]).flatten(2, 3).clone(),
-            line_enc], dim=1))  # [b_size, D, n_lines * 2]
+        key = self.proj_neigh(
+            torch.cat(
+                [
+                    line_desc.reshape(b_size, self.dim, -1, 2)
+                    .flip([-1])
+                    .flatten(2, 3)
+                    .clone(),
+                    line_enc,
+                ],
+                dim=1,
+            )
+        )  # [b_size, D, n_lines * 2]
 
         # Compute the attention weights with a custom softmax per junction
-        prob = (query * key).sum(dim=1) / self.dim ** .5  # [b_size, n_lines * 2]
+        prob = (query * key).sum(dim=1) / self.dim**0.5  # [b_size, n_lines * 2]
         prob = torch.exp(prob - prob.max())
         denom = torch.zeros_like(ldesc[:, 0]).scatter_reduce_(
-            dim=1, index=lines_junc_idx,
-            src=prob, reduce='sum', include_self=False)  # [b_size, n_junc]
+            dim=1, index=lines_junc_idx, src=prob, reduce="sum", include_self=False
+        )  # [b_size, n_junc]
         denom = torch.gather(denom, 1, lines_junc_idx)  # [b_size, n_lines * 2]
         prob = prob / (denom + ETH_EPS)
         return prob  # [b_size, n_lines * 2]
 
-    def forward(self, ldesc0, ldesc1, line_enc0, line_enc1, lines_junc_idx0,
-                lines_junc_idx1):
+    def forward(
+        self, ldesc0, ldesc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
+    ):
         # Gather the endpoint updates
         lupdate0 = self.get_endpoint_update(ldesc0, line_enc0, lines_junc_idx0)
         lupdate1 = self.get_endpoint_update(ldesc1, line_enc1, lines_junc_idx1)
@@ -466,26 +573,40 @@ class LineLayer(nn.Module):
         dim = ldesc0.shape[1]
         if self.line_attention:
             # Compute an attention for each neighbor and do a weighted average
-            prob0 = self.get_endpoint_attention(ldesc0, line_enc0,
-                                                lines_junc_idx0)
+            prob0 = self.get_endpoint_attention(ldesc0, line_enc0, lines_junc_idx0)
             lupdate0 = lupdate0 * prob0[:, None]
             update0 = update0.scatter_reduce_(
-                dim=2, index=lines_junc_idx0[:, None].repeat(1, dim, 1),
-                src=lupdate0, reduce='sum', include_self=False)
-            prob1 = self.get_endpoint_attention(ldesc1, line_enc1,
-                                                lines_junc_idx1)
+                dim=2,
+                index=lines_junc_idx0[:, None].repeat(1, dim, 1),
+                src=lupdate0,
+                reduce="sum",
+                include_self=False,
+            )
+            prob1 = self.get_endpoint_attention(ldesc1, line_enc1, lines_junc_idx1)
             lupdate1 = lupdate1 * prob1[:, None]
             update1 = update1.scatter_reduce_(
-                dim=2, index=lines_junc_idx1[:, None].repeat(1, dim, 1),
-                src=lupdate1, reduce='sum', include_self=False)
+                dim=2,
+                index=lines_junc_idx1[:, None].repeat(1, dim, 1),
+                src=lupdate1,
+                reduce="sum",
+                include_self=False,
+            )
         else:
             # Average the updates for each junction (requires torch > 1.12)
             update0 = update0.scatter_reduce_(
-                dim=2, index=lines_junc_idx0[:, None].repeat(1, dim, 1),
-                src=lupdate0, reduce='mean', include_self=False)
+                dim=2,
+                index=lines_junc_idx0[:, None].repeat(1, dim, 1),
+                src=lupdate0,
+                reduce="mean",
+                include_self=False,
+            )
             update1 = update1.scatter_reduce_(
-                dim=2, index=lines_junc_idx1[:, None].repeat(1, dim, 1),
-                src=lupdate1, reduce='mean', include_self=False)
+                dim=2,
+                index=lines_junc_idx1[:, None].repeat(1, dim, 1),
+                src=lupdate1,
+                reduce="mean",
+                include_self=False,
+            )
 
         # Update
         ldesc0 = ldesc0 + update0
@@ -495,47 +616,75 @@ class LineLayer(nn.Module):
 
 
 class AttentionalGNN(nn.Module):
-    def __init__(self, feature_dim, layer_types, checkpointed=False,
-                 skip=False, inter_supervision=None, num_line_iterations=1,
-                 line_attention=False):
+    def __init__(
+        self,
+        feature_dim,
+        layer_types,
+        checkpointed=False,
+        skip=False,
+        inter_supervision=None,
+        num_line_iterations=1,
+        line_attention=False,
+    ):
         super().__init__()
         self.checkpointed = checkpointed
         self.inter_supervision = inter_supervision
         self.num_line_iterations = num_line_iterations
         self.inter_layers = {}
-        self.layers = nn.ModuleList([
-            GNNLayer(feature_dim, layer_type, skip)
-            for layer_type in layer_types])
+        self.layers = nn.ModuleList(
+            [GNNLayer(feature_dim, layer_type, skip) for layer_type in layer_types]
+        )
         self.line_layers = nn.ModuleList(
-            [LineLayer(feature_dim, line_attention)
-             for _ in range(len(layer_types) // 2)])
-
-    def forward(self, desc0, desc1, line_enc0, line_enc1,
-                lines_junc_idx0, lines_junc_idx1):
+            [
+                LineLayer(feature_dim, line_attention)
+                for _ in range(len(layer_types) // 2)
+            ]
+        )
+
+    def forward(
+        self, desc0, desc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
+    ):
         for i, layer in enumerate(self.layers):
             if self.checkpointed:
                 desc0, desc1 = torch.utils.checkpoint.checkpoint(
-                    layer, desc0, desc1, preserve_rng_state=False)
+                    layer, desc0, desc1, preserve_rng_state=False
+                )
             else:
                 desc0, desc1 = layer(desc0, desc1)
-            if (layer.type == 'self' and lines_junc_idx0.shape[1] > 0
-                    and lines_junc_idx1.shape[1] > 0):
+            if (
+                layer.type == "self"
+                and lines_junc_idx0.shape[1] > 0
+                and lines_junc_idx1.shape[1] > 0
+            ):
                 # Add line self attention layers after every self layer
                 for _ in range(self.num_line_iterations):
                     if self.checkpointed:
                         desc0, desc1 = torch.utils.checkpoint.checkpoint(
-                            self.line_layers[i // 2], desc0, desc1, line_enc0,
-                            line_enc1, lines_junc_idx0, lines_junc_idx1,
-                            preserve_rng_state=False)
+                            self.line_layers[i // 2],
+                            desc0,
+                            desc1,
+                            line_enc0,
+                            line_enc1,
+                            lines_junc_idx0,
+                            lines_junc_idx1,
+                            preserve_rng_state=False,
+                        )
                     else:
                         desc0, desc1 = self.line_layers[i // 2](
-                            desc0, desc1, line_enc0, line_enc1,
-                            lines_junc_idx0, lines_junc_idx1)
+                            desc0,
+                            desc1,
+                            line_enc0,
+                            line_enc1,
+                            lines_junc_idx0,
+                            lines_junc_idx1,
+                        )
 
             # Optionally store the line descriptor at intermediate layers
-            if (self.inter_supervision is not None
-                    and (i // 2) in self.inter_supervision
-                    and layer.type == 'cross'):
+            if (
+                self.inter_supervision is not None
+                and (i // 2) in self.inter_supervision
+                and layer.type == "cross"
+            ):
                 self.inter_layers[i // 2] = (desc0.clone(), desc1.clone())
         return desc0, desc1
 
diff --git a/third_party/GlueStick/gluestick/models/superpoint.py b/third_party/GlueStick/gluestick/models/superpoint.py
index 0e0948a90cf5c858ddd14cc498231479fa10d6e3..19e66cdba41749a765829cce0ead608afb04964c 100644
--- a/third_party/GlueStick/gluestick/models/superpoint.py
+++ b/third_party/GlueStick/gluestick/models/superpoint.py
@@ -25,7 +25,8 @@ def simple_nms(scores, radius):
 
     def max_pool(x):
         return torch.nn.functional.max_pool2d(
-            x, kernel_size=radius * 2 + 1, stride=1, padding=radius)
+            x, kernel_size=radius * 2 + 1, stride=1, padding=radius
+        )
 
     zeros = torch.zeros_like(scores)
     max_mask = scores == max_pool(scores)
@@ -54,33 +55,35 @@ def top_k_keypoints(keypoints, scores, k):
 def sample_descriptors(keypoints, descriptors, s):
     b, c, h, w = descriptors.shape
     keypoints = keypoints - s / 2 + 0.5
-    keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
-                              ).to(keypoints)[None]
+    keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)]).to(
+        keypoints
+    )[None]
     keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
-    args = {'align_corners': True} if torch.__version__ >= '1.3' else {}
+    args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
     descriptors = torch.nn.functional.grid_sample(
-        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
+        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
+    )
     descriptors = torch.nn.functional.normalize(
-        descriptors.reshape(b, c, -1), p=2, dim=1)
+        descriptors.reshape(b, c, -1), p=2, dim=1
+    )
     return descriptors
 
 
 class SuperPoint(BaseModel):
     default_conf = {
-        'has_detector': True,
-        'has_descriptor': True,
-        'descriptor_dim': 256,
-
+        "has_detector": True,
+        "has_descriptor": True,
+        "descriptor_dim": 256,
         # Inference
-        'return_all': False,
-        'sparse_outputs': True,
-        'nms_radius': 4,
-        'detection_threshold': 0.005,
-        'max_num_keypoints': -1,
-        'force_num_keypoints': False,
-        'remove_borders': 4,
+        "return_all": False,
+        "sparse_outputs": True,
+        "nms_radius": 4,
+        "detection_threshold": 0.005,
+        "max_num_keypoints": -1,
+        "force_num_keypoints": False,
+        "remove_borders": 4,
     }
-    required_data_keys = ['image']
+    required_data_keys = ["image"]
 
     def _init(self, conf):
         self.relu = nn.ReLU(inplace=True)
@@ -103,13 +106,14 @@ class SuperPoint(BaseModel):
         if conf.has_descriptor:
             self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
             self.convDb = nn.Conv2d(
-                c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0)
+                c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0
+            )
 
-        path = GLUESTICK_ROOT / 'resources' / 'weights' / 'superpoint_v1.pth'
+        path = GLUESTICK_ROOT / "resources" / "weights" / "superpoint_v1.pth"
         self.load_state_dict(torch.load(str(path)), strict=False)
 
     def _forward(self, data):
-        image = data['image']
+        image = data["image"]
         if image.shape[1] == 3:  # RGB
             scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
             image = (image * scale).sum(1, keepdim=True)
@@ -136,22 +140,24 @@ class SuperPoint(BaseModel):
             b, c, h, w = scores.shape
             scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
             scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
-            pred['keypoint_scores'] = dense_scores = scores
+            pred["keypoint_scores"] = dense_scores = scores
         if self.conf.has_descriptor:
             # Compute the dense descriptors
             cDa = self.relu(self.convDa(x))
             all_desc = self.convDb(cDa)
             all_desc = torch.nn.functional.normalize(all_desc, p=2, dim=1)
-            pred['descriptors'] = all_desc
+            pred["descriptors"] = all_desc
 
             if self.conf.max_num_keypoints == 0:  # Predict dense descriptors only
                 b_size = len(image)
                 device = image.device
                 return {
-                    'keypoints': torch.empty(b_size, 0, 2, device=device),
-                    'keypoint_scores': torch.empty(b_size, 0, device=device),
-                    'descriptors': torch.empty(b_size, self.conf.descriptor_dim, 0, device=device),
-                    'all_descriptors': all_desc
+                    "keypoints": torch.empty(b_size, 0, 2, device=device),
+                    "keypoint_scores": torch.empty(b_size, 0, device=device),
+                    "descriptors": torch.empty(
+                        b_size, self.conf.descriptor_dim, 0, device=device
+                    ),
+                    "all_descriptors": all_desc,
                 }
 
         if self.conf.sparse_outputs:
@@ -161,26 +167,36 @@ class SuperPoint(BaseModel):
 
             # Extract keypoints
             keypoints = [
-                torch.nonzero(s > self.conf.detection_threshold)
-                for s in scores]
+                torch.nonzero(s > self.conf.detection_threshold) for s in scores
+            ]
             scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
 
             # Discard keypoints near the image borders
-            keypoints, scores = list(zip(*[
-                remove_borders(k, s, self.conf.remove_borders, h * 8, w * 8)
-                for k, s in zip(keypoints, scores)]))
+            keypoints, scores = list(
+                zip(
+                    *[
+                        remove_borders(k, s, self.conf.remove_borders, h * 8, w * 8)
+                        for k, s in zip(keypoints, scores)
+                    ]
+                )
+            )
 
             # Keep the k keypoints with highest score
             if self.conf.max_num_keypoints > 0:
-                keypoints, scores = list(zip(*[
-                    top_k_keypoints(k, s, self.conf.max_num_keypoints)
-                    for k, s in zip(keypoints, scores)]))
+                keypoints, scores = list(
+                    zip(
+                        *[
+                            top_k_keypoints(k, s, self.conf.max_num_keypoints)
+                            for k, s in zip(keypoints, scores)
+                        ]
+                    )
+                )
 
             # Convert (h, w) to (x, y)
             keypoints = [torch.flip(k, [1]).float() for k in keypoints]
 
             if self.conf.force_num_keypoints:
-                _, _, h, w = data['image'].shape
+                _, _, h, w = data["image"].shape
                 assert self.conf.max_num_keypoints > 0
                 scores = list(scores)
                 for i in range(len(keypoints)):
@@ -194,8 +210,10 @@ class SuperPoint(BaseModel):
                         scores[i] = torch.cat([s, new_s], 0)
 
             # Extract descriptors
-            desc = [sample_descriptors(k[None], d[None], 8)[0]
-                    for k, d in zip(keypoints, all_desc)]
+            desc = [
+                sample_descriptors(k[None], d[None], 8)[0]
+                for k, d in zip(keypoints, all_desc)
+            ]
 
             if (len(keypoints) == 1) or self.conf.force_num_keypoints:
                 keypoints = torch.stack(keypoints, 0)
@@ -203,14 +221,14 @@ class SuperPoint(BaseModel):
                 desc = torch.stack(desc, 0)
 
             pred = {
-                'keypoints': keypoints,
-                'keypoint_scores': scores,
-                'descriptors': desc,
+                "keypoints": keypoints,
+                "keypoint_scores": scores,
+                "descriptors": desc,
             }
 
             if self.conf.return_all:
-                pred['all_descriptors'] = all_desc
-                pred['dense_score'] = dense_scores
+                pred["all_descriptors"] = all_desc
+                pred["dense_score"] = dense_scores
             else:
                 del all_desc
                 torch.cuda.empty_cache()
diff --git a/third_party/GlueStick/gluestick/models/two_view_pipeline.py b/third_party/GlueStick/gluestick/models/two_view_pipeline.py
index e0e21c1f62e2bd4ad573ebb87ea5635742b5032e..07a7bf06ea8c7ad2abba5fac2568ebcaffd497b0 100644
--- a/third_party/GlueStick/gluestick/models/two_view_pipeline.py
+++ b/third_party/GlueStick/gluestick/models/two_view_pipeline.py
@@ -22,10 +22,12 @@ def keep_quadrant_kp_subset(keypoints, scores, descs, h, w):
     h2, w2 = h // 2, w // 2
     w_x = np.random.choice([0, w2])
     w_y = np.random.choice([0, h2])
-    valid_mask = ((keypoints[..., 0] >= w_x)
-                  & (keypoints[..., 0] < w_x + w2)
-                  & (keypoints[..., 1] >= w_y)
-                  & (keypoints[..., 1] < w_y + h2))
+    valid_mask = (
+        (keypoints[..., 0] >= w_x)
+        & (keypoints[..., 0] < w_x + w2)
+        & (keypoints[..., 1] >= w_y)
+        & (keypoints[..., 1] < w_y + h2)
+    )
     keypoints = keypoints[valid_mask][None]
     scores = scores[valid_mask][None]
     descs = descs.permute(0, 2, 1)[valid_mask].t()[None]
@@ -46,47 +48,44 @@ def keep_best_kp_subset(keypoints, scores, descs, num_selected):
     """Keep the top num_selected best keypoints."""
     sorted_indices = torch.sort(scores, dim=1)[1]
     selected_kp = sorted_indices[:, -num_selected:]
-    keypoints = torch.gather(keypoints, 1,
-                             selected_kp[:, :, None].repeat(1, 1, 2))
+    keypoints = torch.gather(keypoints, 1, selected_kp[:, :, None].repeat(1, 1, 2))
     scores = torch.gather(scores, 1, selected_kp)
-    descs = torch.gather(descs, 2,
-                         selected_kp[:, None].repeat(1, descs.shape[1], 1))
+    descs = torch.gather(descs, 2, selected_kp[:, None].repeat(1, descs.shape[1], 1))
     return keypoints, scores, descs
 
 
 class TwoViewPipeline(BaseModel):
     default_conf = {
-        'extractor': {
-            'name': 'superpoint',
-            'trainable': False,
+        "extractor": {
+            "name": "superpoint",
+            "trainable": False,
         },
-        'use_lines': False,
-        'use_points': True,
-        'randomize_num_kp': False,
-        'detector': {'name': None},
-        'descriptor': {'name': None},
-        'matcher': {'name': 'nearest_neighbor_matcher'},
-        'filter': {'name': None},
-        'solver': {'name': None},
-        'ground_truth': {
-            'from_pose_depth': False,
-            'from_homography': False,
-            'th_positive': 3,
-            'th_negative': 5,
-            'reward_positive': 1,
-            'reward_negative': -0.25,
-            'is_likelihood_soft': True,
-            'p_random_occluders': 0,
-            'n_line_sampled_pts': 50,
-            'line_perp_dist_th': 5,
-            'overlap_th': 0.2,
-            'min_visibility_th': 0.5
+        "use_lines": False,
+        "use_points": True,
+        "randomize_num_kp": False,
+        "detector": {"name": None},
+        "descriptor": {"name": None},
+        "matcher": {"name": "nearest_neighbor_matcher"},
+        "filter": {"name": None},
+        "solver": {"name": None},
+        "ground_truth": {
+            "from_pose_depth": False,
+            "from_homography": False,
+            "th_positive": 3,
+            "th_negative": 5,
+            "reward_positive": 1,
+            "reward_negative": -0.25,
+            "is_likelihood_soft": True,
+            "p_random_occluders": 0,
+            "n_line_sampled_pts": 50,
+            "line_perp_dist_th": 5,
+            "overlap_th": 0.2,
+            "min_visibility_th": 0.5,
         },
     }
-    required_data_keys = ['image0', 'image1']
+    required_data_keys = ["image0", "image1"]
     strict_conf = False  # need to pass new confs to children models
-    components = [
-        'extractor', 'detector', 'descriptor', 'matcher', 'filter', 'solver']
+    components = ["extractor", "detector", "descriptor", "matcher", "filter", "solver"]
 
     def _init(self, conf):
         if conf.extractor.name:
@@ -95,17 +94,16 @@ class TwoViewPipeline(BaseModel):
             if self.conf.detector.name:
                 self.detector = get_model(conf.detector.name)(conf.detector)
             else:
-                self.required_data_keys += ['keypoints0', 'keypoints1']
+                self.required_data_keys += ["keypoints0", "keypoints1"]
             if self.conf.descriptor.name:
-                self.descriptor = get_model(conf.descriptor.name)(
-                    conf.descriptor)
+                self.descriptor = get_model(conf.descriptor.name)(conf.descriptor)
             else:
-                self.required_data_keys += ['descriptors0', 'descriptors1']
+                self.required_data_keys += ["descriptors0", "descriptors1"]
 
         if conf.matcher.name:
             self.matcher = get_model(conf.matcher.name)(conf.matcher)
         else:
-            self.required_data_keys += ['matches0']
+            self.required_data_keys += ["matches0"]
 
         if conf.filter.name:
             self.filter = get_model(conf.filter.name)(conf.filter)
@@ -114,7 +112,6 @@ class TwoViewPipeline(BaseModel):
             self.solver = get_model(conf.solver.name)(conf.solver)
 
     def _forward(self, data):
-
         def process_siamese(data, i):
             data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i}
             if self.conf.extractor.name:
@@ -124,21 +121,28 @@ class TwoViewPipeline(BaseModel):
                 if self.conf.detector.name:
                     pred_i = self.detector(data_i)
                 else:
-                    for k in ['keypoints', 'keypoint_scores', 'descriptors',
-                              'lines', 'line_scores', 'line_descriptors',
-                              'valid_lines']:
+                    for k in [
+                        "keypoints",
+                        "keypoint_scores",
+                        "descriptors",
+                        "lines",
+                        "line_scores",
+                        "line_descriptors",
+                        "valid_lines",
+                    ]:
                         if k in data_i:
                             pred_i[k] = data_i[k]
                 if self.conf.descriptor.name:
-                    pred_i = {
-                        **pred_i, **self.descriptor({**data_i, **pred_i})}
+                    pred_i = {**pred_i, **self.descriptor({**data_i, **pred_i})}
             return pred_i
 
-        pred0 = process_siamese(data, '0')
-        pred1 = process_siamese(data, '1')
+        pred0 = process_siamese(data, "0")
+        pred1 = process_siamese(data, "1")
 
-        pred = {**{k + '0': v for k, v in pred0.items()},
-                **{k + '1': v for k, v in pred1.items()}}
+        pred = {
+            **{k + "0": v for k, v in pred0.items()},
+            **{k + "1": v for k, v in pred1.items()},
+        }
 
         if self.conf.matcher.name:
             pred = {**pred, **self.matcher({**data, **pred})}
@@ -161,8 +165,8 @@ class TwoViewPipeline(BaseModel):
                 except NotImplementedError:
                     continue
                 losses = {**losses, **losses_}
-                total = losses_['total'] + total
-        return {**losses, 'total': total}
+                total = losses_["total"] + total
+        return {**losses, "total": total}
 
     def metrics(self, pred, data):
         metrics = {}
diff --git a/third_party/GlueStick/gluestick/models/wireframe.py b/third_party/GlueStick/gluestick/models/wireframe.py
index 0e3dd9873c6fdb4edcb4c75a103673ee2cb3b3fa..9da539387c6da8a5a8df6c677af69803ccdb54b4 100644
--- a/third_party/GlueStick/gluestick/models/wireframe.py
+++ b/third_party/GlueStick/gluestick/models/wireframe.py
@@ -9,7 +9,7 @@ from ..geometry import warp_lines_torch
 
 
 def lines_to_wireframe(lines, line_scores, all_descs, conf):
-    """ Given a set of lines, their score and dense descriptors,
+    """Given a set of lines, their score and dense descriptors,
         merge close-by endpoints and compute a wireframe defined by
         its junctions and connectivity.
     Returns:
@@ -26,29 +26,41 @@ def lines_to_wireframe(lines, line_scores, all_descs, conf):
     device = lines.device
     endpoints = lines.reshape(b_size, -1, 2)
 
-    (junctions, junc_scores, junc_descs, connectivity, new_lines,
-     lines_junc_idx, num_true_junctions) = [], [], [], [], [], [], []
+    (
+        junctions,
+        junc_scores,
+        junc_descs,
+        connectivity,
+        new_lines,
+        lines_junc_idx,
+        num_true_junctions,
+    ) = ([], [], [], [], [], [], [])
     for bs in range(b_size):
         # Cluster the junctions that are close-by
-        db = DBSCAN(eps=conf.nms_radius, min_samples=1).fit(
-            endpoints[bs].cpu().numpy())
+        db = DBSCAN(eps=conf.nms_radius, min_samples=1).fit(endpoints[bs].cpu().numpy())
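+        # min_samples=1 ensures every endpoint is assigned to a cluster, even
+        # isolated ones (no noise label).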
         clusters = db.labels_
         n_clusters = len(set(clusters))
         num_true_junctions.append(n_clusters)
 
         # Compute the average junction and score for each cluster
-        clusters = torch.tensor(clusters, dtype=torch.long,
-                                device=device)
-        new_junc = torch.zeros(n_clusters, 2, dtype=torch.float,
-                               device=device)
-        new_junc.scatter_reduce_(0, clusters[:, None].repeat(1, 2),
-                                 endpoints[bs], reduce='mean',
-                                 include_self=False)
+        clusters = torch.tensor(clusters, dtype=torch.long, device=device)
+        new_junc = torch.zeros(n_clusters, 2, dtype=torch.float, device=device)
+        new_junc.scatter_reduce_(
+            0,
+            clusters[:, None].repeat(1, 2),
+            endpoints[bs],
+            reduce="mean",
+            include_self=False,
+        )
         junctions.append(new_junc)
         new_scores = torch.zeros(n_clusters, dtype=torch.float, device=device)
         new_scores.scatter_reduce_(
-            0, clusters, torch.repeat_interleave(line_scores[bs], 2),
-            reduce='mean', include_self=False)
+            0,
+            clusters,
+            torch.repeat_interleave(line_scores[bs], 2),
+            reduce="mean",
+            include_self=False,
+        )
         junc_scores.append(new_scores)
 
         # Compute the new lines
@@ -56,50 +68,56 @@ def lines_to_wireframe(lines, line_scores, all_descs, conf):
         lines_junc_idx.append(clusters.reshape(-1, 2))
 
         # Compute the junction connectivity
-        junc_connect = torch.eye(n_clusters, dtype=torch.bool,
-                                 device=device)
+        junc_connect = torch.eye(n_clusters, dtype=torch.bool, device=device)
         pairs = clusters.reshape(-1, 2)  # these pairs are connected by a line
         junc_connect[pairs[:, 0], pairs[:, 1]] = True
         junc_connect[pairs[:, 1], pairs[:, 0]] = True
         connectivity.append(junc_connect)
 
         # Interpolate the new junction descriptors
-        junc_descs.append(sample_descriptors(
-            junctions[-1][None], all_descs[bs:(bs + 1)], 8)[0])
+        junc_descs.append(
+            sample_descriptors(junctions[-1][None], all_descs[bs : (bs + 1)], 8)[0]
+        )
 
     new_lines = torch.stack(new_lines, dim=0)
     lines_junc_idx = torch.stack(lines_junc_idx, dim=0)
-    return (junctions, junc_scores, junc_descs, connectivity,
-            new_lines, lines_junc_idx, num_true_junctions)
+    return (
+        junctions,
+        junc_scores,
+        junc_descs,
+        connectivity,
+        new_lines,
+        lines_junc_idx,
+        num_true_junctions,
+    )
 
 
 class SPWireframeDescriptor(BaseModel):
     default_conf = {
-        'sp_params': {
-            'has_detector': True,
-            'has_descriptor': True,
-            'descriptor_dim': 256,
-            'trainable': False,
-
+        "sp_params": {
+            "has_detector": True,
+            "has_descriptor": True,
+            "descriptor_dim": 256,
+            "trainable": False,
             # Inference
-            'return_all': True,
-            'sparse_outputs': True,
-            'nms_radius': 4,
-            'detection_threshold': 0.005,
-            'max_num_keypoints': 1000,
-            'force_num_keypoints': True,
-            'remove_borders': 4,
+            "return_all": True,
+            "sparse_outputs": True,
+            "nms_radius": 4,
+            "detection_threshold": 0.005,
+            "max_num_keypoints": 1000,
+            "force_num_keypoints": True,
+            "remove_borders": 4,
         },
-        'wireframe_params': {
-            'merge_points': True,
-            'merge_line_endpoints': True,
-            'nms_radius': 3,
-            'max_n_junctions': 500,
+        "wireframe_params": {
+            "merge_points": True,
+            "merge_line_endpoints": True,
+            "nms_radius": 3,
+            "max_n_junctions": 500,
         },
-        'max_n_lines': 250,
-        'min_length': 15,
+        "max_n_lines": 250,
+        "min_length": 15,
     }
-    required_data_keys = ['image']
+    required_data_keys = ["image"]
 
     def _init(self, conf):
         self.conf = conf
@@ -139,78 +157,108 @@ class SPWireframeDescriptor(BaseModel):
         return lines, scores, valid_lines
 
     def _forward(self, data):
-        b_size, _, h, w = data['image'].shape
-        device = data['image'].device
+        b_size, _, h, w = data["image"].shape
+        device = data["image"].device
 
         if not self.conf.sp_params.force_num_keypoints:
             assert b_size == 1, "Only batch size of 1 accepted for non padded inputs"
 
         # Line detection
-        if 'lines' not in data or 'line_scores' not in data:
-            if 'original_img' in data:
+        if "lines" not in data or "line_scores" not in data:
+            if "original_img" in data:
                 # Detect more lines, because when projecting them to the image most of them will be discarded
                 lines, line_scores, valid_lines = self.detect_lsd_lines(
-                    data['original_img'], self.conf.max_n_lines * 3)
+                    data["original_img"], self.conf.max_n_lines * 3
+                )
                 # Apply the same transformation that is applied in homography_adaptation
-                lines, valid_lines2 = warp_lines_torch(lines, data['H'], False, data['image'].shape[-2:])
+                lines, valid_lines2 = warp_lines_torch(
+                    lines, data["H"], False, data["image"].shape[-2:]
+                )
                 valid_lines = valid_lines & valid_lines2
                 lines[~valid_lines] = -1
                 line_scores[~valid_lines] = 0
                 # Re-sort the line segments to pick the ones that are inside the image and have bigger score
-                sorted_scores, sorting_indices = torch.sort(line_scores, dim=-1, descending=True)
-                line_scores = sorted_scores[:, :self.conf.max_n_lines]
-                sorting_indices = sorting_indices[:, :self.conf.max_n_lines]
+                sorted_scores, sorting_indices = torch.sort(
+                    line_scores, dim=-1, descending=True
+                )
+                line_scores = sorted_scores[:, : self.conf.max_n_lines]
+                sorting_indices = sorting_indices[:, : self.conf.max_n_lines]
                 lines = torch.take_along_dim(lines, sorting_indices[..., None, None], 1)
                 valid_lines = torch.take_along_dim(valid_lines, sorting_indices, 1)
             else:
-                lines, line_scores, valid_lines = self.detect_lsd_lines(data['image'])
+                lines, line_scores, valid_lines = self.detect_lsd_lines(data["image"])
 
         else:
-            lines, line_scores, valid_lines = data['lines'], data['line_scores'], data['valid_lines']
+            lines, line_scores, valid_lines = (
+                data["lines"],
+                data["line_scores"],
+                data["valid_lines"],
+            )
         if line_scores.shape[-1] != 0:
-            line_scores /= (line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None])
+            line_scores /= (
+                line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None]
+            )
 
         # SuperPoint prediction
         pred = self.sp(data)
 
         # Remove keypoints that are too close to line endpoints
         if self.conf.wireframe_params.merge_points:
-            kp = pred['keypoints']
+            kp = pred["keypoints"]
             line_endpts = lines.reshape(b_size, -1, 2)
-            dist_pt_lines = torch.norm(
-                kp[:, :, None] - line_endpts[:, None], dim=-1)
+            dist_pt_lines = torch.norm(kp[:, :, None] - line_endpts[:, None], dim=-1)
             # For each keypoint, mark it as valid or to remove
             pts_to_remove = torch.any(
-                dist_pt_lines < self.conf.sp_params.nms_radius, dim=2)
+                dist_pt_lines < self.conf.sp_params.nms_radius, dim=2
+            )
             # Simply remove them (we assume batch_size = 1 here)
             assert len(kp) == 1
-            pred['keypoints'] = pred['keypoints'][0][~pts_to_remove[0]][None]
-            pred['keypoint_scores'] = pred['keypoint_scores'][0][~pts_to_remove[0]][None]
-            pred['descriptors'] = pred['descriptors'][0].T[~pts_to_remove[0]].T[None]
+            pred["keypoints"] = pred["keypoints"][0][~pts_to_remove[0]][None]
+            pred["keypoint_scores"] = pred["keypoint_scores"][0][~pts_to_remove[0]][
+                None
+            ]
+            pred["descriptors"] = pred["descriptors"][0].T[~pts_to_remove[0]].T[None]
 
         # Connect the lines together to form a wireframe
         orig_lines = lines.clone()
         if self.conf.wireframe_params.merge_line_endpoints and len(lines[0]) > 0:
             # Merge first close-by endpoints to connect lines
-            (line_points, line_pts_scores, line_descs, line_association,
-             lines, lines_junc_idx, num_true_junctions) = lines_to_wireframe(
-                lines, line_scores, pred['all_descriptors'],
-                conf=self.conf.wireframe_params)
+            (
+                line_points,
+                line_pts_scores,
+                line_descs,
+                line_association,
+                lines,
+                lines_junc_idx,
+                num_true_junctions,
+            ) = lines_to_wireframe(
+                lines,
+                line_scores,
+                pred["all_descriptors"],
+                conf=self.conf.wireframe_params,
+            )
 
             # Add the keypoints to the junctions and fill the rest with random keypoints
-            (all_points, all_scores, all_descs,
-             pl_associativity) = [], [], [], []
+            (all_points, all_scores, all_descs, pl_associativity) = [], [], [], []
             for bs in range(b_size):
-                all_points.append(torch.cat(
-                    [line_points[bs], pred['keypoints'][bs]], dim=0))
-                all_scores.append(torch.cat(
-                    [line_pts_scores[bs], pred['keypoint_scores'][bs]], dim=0))
-                all_descs.append(torch.cat(
-                    [line_descs[bs], pred['descriptors'][bs]], dim=1))
-
-                associativity = torch.eye(len(all_points[-1]), dtype=torch.bool, device=device)
-                associativity[:num_true_junctions[bs], :num_true_junctions[bs]] = \
-                    line_association[bs][:num_true_junctions[bs], :num_true_junctions[bs]]
+                all_points.append(
+                    torch.cat([line_points[bs], pred["keypoints"][bs]], dim=0)
+                )
+                all_scores.append(
+                    torch.cat([line_pts_scores[bs], pred["keypoint_scores"][bs]], dim=0)
+                )
+                all_descs.append(
+                    torch.cat([line_descs[bs], pred["descriptors"][bs]], dim=1)
+                )
+
+                associativity = torch.eye(
+                    len(all_points[-1]), dtype=torch.bool, device=device
+                )
+                associativity[
+                    : num_true_junctions[bs], : num_true_junctions[bs]
+                ] = line_association[bs][
+                    : num_true_junctions[bs], : num_true_junctions[bs]
+                ]
                 pl_associativity.append(associativity)
 
             all_points = torch.stack(all_points, dim=0)
@@ -219,38 +267,55 @@ class SPWireframeDescriptor(BaseModel):
             pl_associativity = torch.stack(pl_associativity, dim=0)
         else:
             # Lines are independent
-            all_points = torch.cat([lines.reshape(b_size, -1, 2),
-                                    pred['keypoints']], dim=1)
+            all_points = torch.cat(
+                [lines.reshape(b_size, -1, 2), pred["keypoints"]], dim=1
+            )
             n_pts = all_points.shape[1]
             num_lines = lines.shape[1]
             num_true_junctions = [num_lines * 2] * b_size
-            all_scores = torch.cat([
-                torch.repeat_interleave(line_scores, 2, dim=1),
-                pred['keypoint_scores']], dim=1)
-            pred['line_descriptors'] = self.endpoints_pooling(
-                lines, pred['all_descriptors'], (h, w))
-            all_descs = torch.cat([
-                pred['line_descriptors'].reshape(b_size, self.conf.sp_params.descriptor_dim, -1),
-                pred['descriptors']], dim=2)
-            pl_associativity = torch.eye(
-                n_pts, dtype=torch.bool,
-                device=device)[None].repeat(b_size, 1, 1)
-            lines_junc_idx = torch.arange(
-                num_lines * 2, device=device).reshape(1, -1, 2).repeat(b_size, 1, 1)
-
-        del pred['all_descriptors']  # Remove dense descriptors to save memory
+            all_scores = torch.cat(
+                [
+                    torch.repeat_interleave(line_scores, 2, dim=1),
+                    pred["keypoint_scores"],
+                ],
+                dim=1,
+            )
+            pred["line_descriptors"] = self.endpoints_pooling(
+                lines, pred["all_descriptors"], (h, w)
+            )
+            all_descs = torch.cat(
+                [
+                    pred["line_descriptors"].reshape(
+                        b_size, self.conf.sp_params.descriptor_dim, -1
+                    ),
+                    pred["descriptors"],
+                ],
+                dim=2,
+            )
+            pl_associativity = torch.eye(n_pts, dtype=torch.bool, device=device)[
+                None
+            ].repeat(b_size, 1, 1)
+            lines_junc_idx = (
+                torch.arange(num_lines * 2, device=device)
+                .reshape(1, -1, 2)
+                .repeat(b_size, 1, 1)
+            )
+
+        del pred["all_descriptors"]  # Remove dense descriptors to save memory
         torch.cuda.empty_cache()
 
-        return {'keypoints': all_points,
-                'keypoint_scores': all_scores,
-                'descriptors': all_descs,
-                'pl_associativity': pl_associativity,
-                'num_junctions': torch.tensor(num_true_junctions),
-                'lines': lines,
-                'orig_lines': orig_lines,
-                'lines_junc_idx': lines_junc_idx,
-                'line_scores': line_scores,
-                'valid_lines': valid_lines}
+        return {
+            "keypoints": all_points,
+            "keypoint_scores": all_scores,
+            "descriptors": all_descs,
+            "pl_associativity": pl_associativity,
+            "num_junctions": torch.tensor(num_true_junctions),
+            "lines": lines,
+            "orig_lines": orig_lines,
+            "lines_junc_idx": lines_junc_idx,
+            "line_scores": line_scores,
+            "valid_lines": valid_lines,
+        }
 
     @staticmethod
     def endpoints_pooling(segs, all_descriptors, img_shape):
@@ -259,11 +324,21 @@ class SPWireframeDescriptor(BaseModel):
         scale_x = filter_shape[1] / img_shape[1]
         scale_y = filter_shape[0] / img_shape[0]
 
-        scaled_segs = torch.round(segs * torch.tensor([scale_x, scale_y]).to(segs)).long()
+        scaled_segs = torch.round(
+            segs * torch.tensor([scale_x, scale_y]).to(segs)
+        ).long()
         scaled_segs[..., 0] = torch.clip(scaled_segs[..., 0], 0, filter_shape[1] - 1)
         scaled_segs[..., 1] = torch.clip(scaled_segs[..., 1], 0, filter_shape[0] - 1)
-        line_descriptors = [all_descriptors[None, b, ..., torch.squeeze(b_segs[..., 1]), torch.squeeze(b_segs[..., 0])]
-                            for b, b_segs in enumerate(scaled_segs)]
+        line_descriptors = [
+            all_descriptors[
+                None,
+                b,
+                ...,
+                torch.squeeze(b_segs[..., 1]),
+                torch.squeeze(b_segs[..., 0]),
+            ]
+            for b, b_segs in enumerate(scaled_segs)
+        ]
         line_descriptors = torch.cat(line_descriptors)
         return line_descriptors  # Shape (1, 256, 308, 2)
 
diff --git a/third_party/GlueStick/gluestick/run.py b/third_party/GlueStick/gluestick/run.py
index 6baa88834f0b4dfde769ebe6c671e4ec49d4ed10..89569b878cca84fc48ef0b772f71b07befeb45a6 100644
--- a/third_party/GlueStick/gluestick/run.py
+++ b/third_party/GlueStick/gluestick/run.py
@@ -7,49 +7,58 @@ import torch
 from matplotlib import pyplot as plt
 
 from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
-from .drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches
+from .drawing import (
+    plot_images,
+    plot_lines,
+    plot_color_line_matches,
+    plot_keypoints,
+    plot_matches,
+)
 from .models.two_view_pipeline import TwoViewPipeline
 
 
 def main():
     # Parse input parameters
     parser = argparse.ArgumentParser(
-        prog='GlueStick Demo',
-        description='Demo app to show the point and line matches obtained by GlueStick')
-    parser.add_argument('-img1', default=join('resources' + os.path.sep + 'img1.jpg'))
-    parser.add_argument('-img2', default=join('resources' + os.path.sep + 'img2.jpg'))
-    parser.add_argument('--max_pts', type=int, default=1000)
-    parser.add_argument('--max_lines', type=int, default=300)
-    parser.add_argument('--skip-imshow', default=False, action='store_true')
+        prog="GlueStick Demo",
+        description="Demo app to show the point and line matches obtained by GlueStick",
+    )
+    parser.add_argument("-img1", default=join("resources" + os.path.sep + "img1.jpg"))
+    parser.add_argument("-img2", default=join("resources" + os.path.sep + "img2.jpg"))
+    parser.add_argument("--max_pts", type=int, default=1000)
+    parser.add_argument("--max_lines", type=int, default=300)
+    parser.add_argument("--skip-imshow", default=False, action="store_true")
     args = parser.parse_args()
 
     # Evaluation config
     conf = {
-        'name': 'two_view_pipeline',
-        'use_lines': True,
-        'extractor': {
-            'name': 'wireframe',
-            'sp_params': {
-                'force_num_keypoints': False,
-                'max_num_keypoints': args.max_pts,
+        "name": "two_view_pipeline",
+        "use_lines": True,
+        "extractor": {
+            "name": "wireframe",
+            "sp_params": {
+                "force_num_keypoints": False,
+                "max_num_keypoints": args.max_pts,
             },
-            'wireframe_params': {
-                'merge_points': True,
-                'merge_line_endpoints': True,
+            "wireframe_params": {
+                "merge_points": True,
+                "merge_line_endpoints": True,
             },
-            'max_n_lines': args.max_lines,
+            "max_n_lines": args.max_lines,
         },
-        'matcher': {
-            'name': 'gluestick',
-            'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
-            'trainable': False,
+        "matcher": {
+            "name": "gluestick",
+            "weights": str(
+                GLUESTICK_ROOT / "resources" / "weights" / "checkpoint_GlueStick_MD.tar"
+            ),
+            "trainable": False,
+        },
+        "ground_truth": {
+            "from_pose_depth": False,
         },
-        'ground_truth': {
-            'from_pose_depth': False,
-        }
     }
 
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     pipeline_model = TwoViewPipeline(conf).to(device).eval()
 
@@ -57,8 +66,11 @@ def main():
     gray1 = cv2.imread(args.img2, 0)
 
     torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1)
-    torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None]
-    x = {'image0': torch_gray0, 'image1': torch_gray1}
+    torch_gray0, torch_gray1 = (
+        torch_gray0.to(device)[None],
+        torch_gray1.to(device)[None],
+    )
+    x = {"image0": torch_gray0, "image1": torch_gray1}
     pred = pipeline_model(x)
 
     pred = batch_to_np(pred)
@@ -79,29 +91,51 @@ def main():
     matched_lines1 = line_seg1[match_indices]
 
     # Plot the matches
-    img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)
-    plot_images([img0, img1], ['Image 1 - detected lines', 'Image 2 - detected lines'], dpi=200, pad=2.0)
+    img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(
+        gray1, cv2.COLOR_GRAY2BGR
+    )
+    plot_images(
+        [img0, img1],
+        ["Image 1 - detected lines", "Image 2 - detected lines"],
+        dpi=200,
+        pad=2.0,
+    )
     plot_lines([line_seg0, line_seg1], ps=4, lw=2)
-    plt.gcf().canvas.manager.set_window_title('Detected Lines')
-    plt.savefig('detected_lines.png')
-
-    plot_images([img0, img1], ['Image 1 - detected points', 'Image 2 - detected points'], dpi=200, pad=2.0)
-    plot_keypoints([kp0, kp1], colors='c')
-    plt.gcf().canvas.manager.set_window_title('Detected Points')
-    plt.savefig('detected_points.png')
-
-    plot_images([img0, img1], ['Image 1 - line matches', 'Image 2 - line matches'], dpi=200, pad=2.0)
+    plt.gcf().canvas.manager.set_window_title("Detected Lines")
+    plt.savefig("detected_lines.png")
+
+    plot_images(
+        [img0, img1],
+        ["Image 1 - detected points", "Image 2 - detected points"],
+        dpi=200,
+        pad=2.0,
+    )
+    plot_keypoints([kp0, kp1], colors="c")
+    plt.gcf().canvas.manager.set_window_title("Detected Points")
+    plt.savefig("detected_points.png")
+
+    plot_images(
+        [img0, img1],
+        ["Image 1 - line matches", "Image 2 - line matches"],
+        dpi=200,
+        pad=2.0,
+    )
     plot_color_line_matches([matched_lines0, matched_lines1], lw=2)
-    plt.gcf().canvas.manager.set_window_title('Line Matches')
-    plt.savefig('line_matches.png')
-
-    plot_images([img0, img1], ['Image 1 - point matches', 'Image 2 - point matches'], dpi=200, pad=2.0)
-    plot_matches(matched_kps0, matched_kps1, 'green', lw=1, ps=0)
-    plt.gcf().canvas.manager.set_window_title('Point Matches')
-    plt.savefig('detected_points.png')
+    plt.gcf().canvas.manager.set_window_title("Line Matches")
+    plt.savefig("line_matches.png")
+
+    plot_images(
+        [img0, img1],
+        ["Image 1 - point matches", "Image 2 - point matches"],
+        dpi=200,
+        pad=2.0,
+    )
+    plot_matches(matched_kps0, matched_kps1, "green", lw=1, ps=0)
+    plt.gcf().canvas.manager.set_window_title("Point Matches")
+    plt.savefig("detected_points.png")
     if not args.skip_imshow:
         plt.show()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/third_party/GlueStick/setup.py b/third_party/GlueStick/setup.py
index f0caa063e99cf6d7784fe7d54af08dbb66811627..c1a9df947ac2b788597e3028226f8efbdcd21b94 100644
--- a/third_party/GlueStick/setup.py
+++ b/third_party/GlueStick/setup.py
@@ -1,3 +1,3 @@
 from setuptools import setup
 
-setup(name='gluestick', version="0.0", packages=['gluestick'])
+setup(name="gluestick", version="0.0", packages=["gluestick"])
diff --git a/third_party/LightGlue/lightglue/__init__.py b/third_party/LightGlue/lightglue/__init__.py
index 97ad123d1dd573770da8ce2c4025386b8c70e1a3..aed9fbee8abe8562a5821893e8a219e2f9a38171 100644
--- a/third_party/LightGlue/lightglue/__init__.py
+++ b/third_party/LightGlue/lightglue/__init__.py
@@ -1,4 +1,4 @@
 from .lightglue import LightGlue
 from .superpoint import SuperPoint
 from .disk import DISK
-from .utils import match_pair
\ No newline at end of file
+from .utils import match_pair
diff --git a/third_party/LightGlue/lightglue/disk.py b/third_party/LightGlue/lightglue/disk.py
index 0fd0dec1049299bb53861f359ef63b12578bc0dd..c3e6e63ba76a018709e3332cdf432d06f4cda081 100644
--- a/third_party/LightGlue/lightglue/disk.py
+++ b/third_party/LightGlue/lightglue/disk.py
@@ -7,21 +7,21 @@ from .utils import ImagePreprocessor
 
 class DISK(nn.Module):
     default_conf = {
-        'weights': 'depth',
-        'max_num_keypoints': None,
-        'desc_dim': 128,
-        'nms_window_size': 5,
-        'detection_threshold': 0.0,
-        'pad_if_not_divisible': True,
+        "weights": "depth",
+        "max_num_keypoints": None,
+        "desc_dim": 128,
+        "nms_window_size": 5,
+        "detection_threshold": 0.0,
+        "pad_if_not_divisible": True,
     }
 
     preprocess_conf = {
         **ImagePreprocessor.default_conf,
-        'resize': 1024,
-        'grayscale': False,
+        "resize": 1024,
+        "grayscale": False,
     }
 
-    required_data_keys = ['image']
+    required_data_keys = ["image"]
 
     def __init__(self, **conf) -> None:
         super().__init__()
@@ -30,16 +30,16 @@ class DISK(nn.Module):
         self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)
 
     def forward(self, data: dict) -> dict:
-        """ Compute keypoints, scores, descriptors for image """
+        """Compute keypoints, scores, descriptors for image"""
         for key in self.required_data_keys:
-            assert key in data, f'Missing key {key} in data'
-        image = data['image']
+            assert key in data, f"Missing key {key} in data"
+        image = data["image"]
         features = self.model(
             image,
             n=self.conf.max_num_keypoints,
             window_size=self.conf.nms_window_size,
             score_threshold=self.conf.detection_threshold,
-            pad_if_not_divisible=self.conf.pad_if_not_divisible
+            pad_if_not_divisible=self.conf.pad_if_not_divisible,
         )
         keypoints = [f.keypoints for f in features]
         scores = [f.detection_scores for f in features]
@@ -51,20 +51,19 @@ class DISK(nn.Module):
         descriptors = torch.stack(descriptors, 0)
 
         return {
-            'keypoints': keypoints.to(image),
-            'keypoint_scores': scores.to(image),
-            'descriptors': descriptors.to(image),
+            "keypoints": keypoints.to(image),
+            "keypoint_scores": scores.to(image),
+            "descriptors": descriptors.to(image),
         }
 
     def extract(self, img: torch.Tensor, **conf) -> dict:
-        """ Perform extraction with online resizing"""
+        """Perform extraction with online resizing"""
         if img.dim() == 3:
             img = img[None]  # add batch dim
         assert img.dim() == 4 and img.shape[0] == 1
         shape = img.shape[-2:][::-1]
-        img, scales = ImagePreprocessor(
-            **{**self.preprocess_conf, **conf})(img)
-        feats = self.forward({'image': img})
-        feats['image_size'] = torch.tensor(shape)[None].to(img).float()
-        feats['keypoints'] = (feats['keypoints'] + .5) / scales[None] - .5
+        img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
+        feats = self.forward({"image": img})
+        feats["image_size"] = torch.tensor(shape)[None].to(img).float()
+        feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
         return feats
diff --git a/third_party/LightGlue/lightglue/lightglue.py b/third_party/LightGlue/lightglue/lightglue.py
index 3dc872bdc902bb71f640ae8749c07240924c5540..4b20300bf9068267e7b4d334dc2d3e85114ddd3e 100644
--- a/third_party/LightGlue/lightglue/lightglue.py
+++ b/third_party/LightGlue/lightglue/lightglue.py
@@ -12,7 +12,7 @@ try:
 except ModuleNotFoundError:
     FlashCrossAttention = None
 
-if FlashCrossAttention or hasattr(F, 'scaled_dot_product_attention'):
+if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
     FLASH_AVAILABLE = True
 else:
     FLASH_AVAILABLE = False
@@ -21,9 +21,7 @@ torch.backends.cudnn.deterministic = True
 
 
 @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
-def normalize_keypoints(
-        kpts: torch.Tensor,
-        size: torch.Tensor) -> torch.Tensor:
+def normalize_keypoints(kpts: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
     if isinstance(size, torch.Size):
         size = torch.tensor(size)[None]
     shift = size.float().to(kpts) / 2
@@ -38,22 +36,20 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor:
     return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
 
 
-def apply_cached_rotary_emb(
-        freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
     return (t * freqs[0]) + (rotate_half(t) * freqs[1])
 
 
 class LearnableFourierPositionalEncoding(nn.Module):
-    def __init__(self, M: int, dim: int, F_dim: int = None,
-                 gamma: float = 1.0) -> None:
+    def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
         super().__init__()
         F_dim = F_dim if F_dim is not None else dim
         self.gamma = gamma
         self.Wr = nn.Linear(M, F_dim // 2, bias=False)
-        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
+        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """ encode position vector """
+        """encode position vector"""
         projected = self.Wr(x)
         cosines, sines = torch.cos(projected), torch.sin(projected)
         emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
@@ -63,16 +59,14 @@ class LearnableFourierPositionalEncoding(nn.Module):
 class TokenConfidence(nn.Module):
     def __init__(self, dim: int) -> None:
         super().__init__()
-        self.token = nn.Sequential(
-            nn.Linear(dim, 1),
-            nn.Sigmoid()
-        )
+        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
 
     def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
-        """ get confidence tokens """
+        """get confidence tokens"""
         return (
             self.token(desc0.detach().float()).squeeze(-1),
-            self.token(desc1.detach().float()).squeeze(-1))
+            self.token(desc1.detach().float()).squeeze(-1),
+        )
 
 
 class Attention(nn.Module):
@@ -80,8 +74,8 @@ class Attention(nn.Module):
         super().__init__()
         if allow_flash and not FLASH_AVAILABLE:
             warnings.warn(
-                'FlashAttention is not available. For optimal speed, '
-                'consider installing torch >= 2.0 or flash-attn.',
+                "FlashAttention is not available. For optimal speed, "
+                "consider installing torch >= 2.0 or flash-attn.",
                 stacklevel=2,
             )
         self.enable_flash = allow_flash and FLASH_AVAILABLE
@@ -89,7 +83,7 @@ class Attention(nn.Module):
             self.flash_ = FlashCrossAttention()
 
     def forward(self, q, k, v) -> torch.Tensor:
-        if self.enable_flash and q.device.type == 'cuda':
+        if self.enable_flash and q.device.type == "cuda":
             if FlashCrossAttention:
                 q, k, v = [x.transpose(-2, -3) for x in [q, k, v]]
                 m = self.flash_(q.half(), torch.stack([k, v], 2).half())
@@ -98,35 +92,35 @@ class Attention(nn.Module):
                 args = [x.half().contiguous() for x in [q, k, v]]
                 with torch.backends.cuda.sdp_kernel(enable_flash=True):
                     return F.scaled_dot_product_attention(*args).to(q.dtype)
-        elif hasattr(F, 'scaled_dot_product_attention'):
+        elif hasattr(F, "scaled_dot_product_attention"):
             args = [x.contiguous() for x in [q, k, v]]
             return F.scaled_dot_product_attention(*args).to(q.dtype)
         else:
             s = q.shape[-1] ** -0.5
-            attn = F.softmax(torch.einsum('...id,...jd->...ij', q, k) * s, -1)
-            return torch.einsum('...ij,...jd->...id', attn, v)
+            attn = F.softmax(torch.einsum("...id,...jd->...ij", q, k) * s, -1)
+            return torch.einsum("...ij,...jd->...id", attn, v)
 
 
 class Transformer(nn.Module):
-    def __init__(self, embed_dim: int, num_heads: int,
-                 flash: bool = False, bias: bool = True) -> None:
+    def __init__(
+        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
+    ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         assert self.embed_dim % num_heads == 0
         self.head_dim = self.embed_dim // num_heads
-        self.Wqkv = nn.Linear(embed_dim, 3*embed_dim, bias=bias)
+        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
         self.inner_attn = Attention(flash)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.ffn = nn.Sequential(
-            nn.Linear(2*embed_dim, 2*embed_dim),
-            nn.LayerNorm(2*embed_dim, elementwise_affine=True),
+            nn.Linear(2 * embed_dim, 2 * embed_dim),
+            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
             nn.GELU(),
-            nn.Linear(2*embed_dim, embed_dim)
+            nn.Linear(2 * embed_dim, embed_dim),
         )
 
-    def _forward(self, x: torch.Tensor,
-                 encoding: Optional[torch.Tensor] = None):
+    def _forward(self, x: torch.Tensor, encoding: Optional[torch.Tensor] = None):
         qkv = self.Wqkv(x)
         qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
         q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
@@ -134,8 +128,7 @@ class Transformer(nn.Module):
             q = apply_cached_rotary_emb(encoding, q)
             k = apply_cached_rotary_emb(encoding, k)
         context = self.inner_attn(q, k, v)
-        message = self.out_proj(
-            context.transpose(1, 2).flatten(start_dim=-2))
+        message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
         return x + self.ffn(torch.cat([x, message], -1))
 
     def forward(self, x0, x1, encoding0=None, encoding1=None):
@@ -143,21 +136,22 @@ class Transformer(nn.Module):
 
 
 class CrossTransformer(nn.Module):
-    def __init__(self, embed_dim: int, num_heads: int,
-                 flash: bool = False, bias: bool = True) -> None:
+    def __init__(
+        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
+    ) -> None:
         super().__init__()
         self.heads = num_heads
         dim_head = embed_dim // num_heads
-        self.scale = dim_head ** -0.5
+        self.scale = dim_head**-0.5
         inner_dim = dim_head * num_heads
         self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
         self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
         self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
         self.ffn = nn.Sequential(
-            nn.Linear(2*embed_dim, 2*embed_dim),
-            nn.LayerNorm(2*embed_dim, elementwise_affine=True),
+            nn.Linear(2 * embed_dim, 2 * embed_dim),
+            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
             nn.GELU(),
-            nn.Linear(2*embed_dim, embed_dim)
+            nn.Linear(2 * embed_dim, embed_dim),
         )
 
         if flash and FLASH_AVAILABLE:
@@ -173,19 +167,19 @@ class CrossTransformer(nn.Module):
         v0, v1 = self.map_(self.to_v, x0, x1)
         qk0, qk1, v0, v1 = map(
             lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
-            (qk0, qk1, v0, v1))
+            (qk0, qk1, v0, v1),
+        )
         if self.flash is not None:
             m0 = self.flash(qk0, qk1, v1)
             m1 = self.flash(qk1, qk0, v0)
         else:
             qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
-            sim = torch.einsum('b h i d, b h j d -> b h i j', qk0, qk1)
+            sim = torch.einsum("b h i d, b h j d -> b h i j", qk0, qk1)
             attn01 = F.softmax(sim, dim=-1)
             attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
-            m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1)
-            m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1), v0)
-        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2),
-                           m0, m1)
+            m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
+            m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
+        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
         m0, m1 = self.map_(self.to_out, m0, m1)
         x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
         x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
@@ -193,15 +187,15 @@ class CrossTransformer(nn.Module):
 
 
 def sigmoid_log_double_softmax(
-        sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
-    """ create the log assignment matrix from logits and similarity"""
+    sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
+) -> torch.Tensor:
+    """create the log assignment matrix from logits and similarity"""
     b, m, n = sim.shape
     certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
     scores0 = F.log_softmax(sim, 2)
-    scores1 = F.log_softmax(
-        sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
-    scores = sim.new_full((b, m+1, n+1), 0)
-    scores[:, :m, :n] = (scores0 + scores1 + certainties)
+    scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
+    scores = sim.new_full((b, m + 1, n + 1), 0)
+    scores[:, :m, :n] = scores0 + scores1 + certainties
     scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
     scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
     return scores
@@ -215,11 +209,11 @@ class MatchAssignment(nn.Module):
         self.final_proj = nn.Linear(dim, dim, bias=True)
 
     def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
-        """ build assignment matrix from descriptors """
+        """build assignment matrix from descriptors"""
         mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
         _, _, d = mdesc0.shape
-        mdesc0, mdesc1 = mdesc0 / d**.25, mdesc1 / d**.25
-        sim = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1)
+        mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
+        sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
         z0 = self.matchability(desc0)
         z1 = self.matchability(desc1)
         scores = sigmoid_log_double_softmax(sim, z0, z1)
@@ -232,7 +226,7 @@ class MatchAssignment(nn.Module):
 
 
 def filter_matches(scores: torch.Tensor, th: float):
-    """ obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
+    """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
     max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
     m0, m1 = max0.indices, max1.indices
     mutual0 = torch.arange(m0.shape[1]).to(m0)[None] == m1.gather(1, m0)
@@ -253,42 +247,39 @@ def filter_matches(scores: torch.Tensor, th: float):
 
 class LightGlue(nn.Module):
     default_conf = {
-        'name': 'lightglue',  # just for interfacing
-        'input_dim': 256,  # input descriptor dimension (autoselected from weights)
-        'descriptor_dim': 256,
-        'n_layers': 9,
-        'num_heads': 4,
-        'flash': True,  # enable FlashAttention if available.
-        'mp': False,  # enable mixed precision
-        'depth_confidence': 0.95,  # early stopping, disable with -1
-        'width_confidence': 0.99,  # point pruning, disable with -1
-        'filter_threshold': 0.1,  # match threshold
-        'weights': None,
+        "name": "lightglue",  # just for interfacing
+        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
+        "descriptor_dim": 256,
+        "n_layers": 9,
+        "num_heads": 4,
+        "flash": True,  # enable FlashAttention if available.
+        "mp": False,  # enable mixed precision
+        "depth_confidence": 0.95,  # early stopping, disable with -1
+        "width_confidence": 0.99,  # point pruning, disable with -1
+        "filter_threshold": 0.1,  # match threshold
+        "weights": None,
     }
 
-    required_data_keys = [
-        'image0', 'image1']
+    required_data_keys = ["image0", "image1"]
 
     version = "v0.1_arxiv"
     url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
 
     features = {
-        'superpoint': ('superpoint_lightglue', 256),
-        'disk': ('disk_lightglue', 128)
+        "superpoint": ("superpoint_lightglue", 256),
+        "disk": ("disk_lightglue", 128),
     }
 
-    def __init__(self, features='superpoint', **conf) -> None:
+    def __init__(self, features="superpoint", **conf) -> None:
         super().__init__()
         self.conf = {**self.default_conf, **conf}
         if features is not None:
-            assert (features in list(self.features.keys()))
-            self.conf['weights'], self.conf['input_dim'] = \
-                self.features[features]
+            assert features in list(self.features.keys())
+            self.conf["weights"], self.conf["input_dim"] = self.features[features]
         self.conf = conf = SimpleNamespace(**self.conf)
 
         if conf.input_dim != conf.descriptor_dim:
-            self.input_proj = nn.Linear(
-                conf.input_dim, conf.descriptor_dim, bias=True)
+            self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
         else:
             self.input_proj = nn.Identity()
 
@@ -297,26 +288,29 @@ class LightGlue(nn.Module):
 
         h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
         self.self_attn = nn.ModuleList(
-            [Transformer(d, h, conf.flash) for _ in range(n)])
+            [Transformer(d, h, conf.flash) for _ in range(n)]
+        )
         self.cross_attn = nn.ModuleList(
-            [CrossTransformer(d, h, conf.flash) for _ in range(n)])
-        self.log_assignment = nn.ModuleList(
-            [MatchAssignment(d) for _ in range(n)])
-        self.token_confidence = nn.ModuleList([
-            TokenConfidence(d) for _ in range(n-1)])
+            [CrossTransformer(d, h, conf.flash) for _ in range(n)]
+        )
+        self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
+        self.token_confidence = nn.ModuleList(
+            [TokenConfidence(d) for _ in range(n - 1)]
+        )
 
         if features is not None:
-            fname = f'{conf.weights}_{self.version}.pth'.replace('.', '-')
+            fname = f"{conf.weights}_{self.version}.pth".replace(".", "-")
             state_dict = torch.hub.load_state_dict_from_url(
-                self.url.format(self.version, features), file_name=fname)
+                self.url.format(self.version, features), file_name=fname
+            )
             self.load_state_dict(state_dict, strict=False)
         elif conf.weights is not None:
             path = Path(__file__).parent
-            path = path / 'weights/{}.pth'.format(self.conf.weights)
-            state_dict = torch.load(str(path), map_location='cpu')
+            path = path / "weights/{}.pth".format(self.conf.weights)
+            state_dict = torch.load(str(path), map_location="cpu")
             self.load_state_dict(state_dict, strict=False)
 
-        print('Loaded LightGlue model')
+        print("Loaded LightGlue model")
 
     def forward(self, data: dict) -> dict:
         """
@@ -339,27 +333,27 @@ class LightGlue(nn.Module):
             matching_scores1: [B x N]
             matches: List[[Si x 2]], scores: List[[Si]]
         """
-        with torch.autocast(enabled=self.conf.mp, device_type='cuda'):
+        with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
             return self._forward(data)
 
     def _forward(self, data: dict) -> dict:
         for key in self.required_data_keys:
-            assert key in data, f'Missing key {key} in data'
-        data0, data1 = data['image0'], data['image1']
-        kpts0_, kpts1_ = data0['keypoints'], data1['keypoints']
+            assert key in data, f"Missing key {key} in data"
+        data0, data1 = data["image0"], data["image1"]
+        kpts0_, kpts1_ = data0["keypoints"], data1["keypoints"]
         b, m, _ = kpts0_.shape
         b, n, _ = kpts1_.shape
-        size0, size1 = data0.get('image_size'), data1.get('image_size')
-        size0 = size0 if size0 is not None else data0['image'].shape[-2:][::-1]
-        size1 = size1 if size1 is not None else data1['image'].shape[-2:][::-1]
+        size0, size1 = data0.get("image_size"), data1.get("image_size")
+        size0 = size0 if size0 is not None else data0["image"].shape[-2:][::-1]
+        size1 = size1 if size1 is not None else data1["image"].shape[-2:][::-1]
         kpts0 = normalize_keypoints(kpts0_, size=size0)
         kpts1 = normalize_keypoints(kpts1_, size=size1)
 
         assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
         assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
 
-        desc0 = data0['descriptors'].detach()
-        desc1 = data1['descriptors'].detach()
+        desc0 = data0["descriptors"].detach()
+        desc1 = data1["descriptors"].detach()
 
         assert desc0.shape[-1] == self.conf.input_dim
         assert desc1.shape[-1] == self.conf.input_dim
@@ -384,19 +378,18 @@ class LightGlue(nn.Module):
         token0, token1 = None, None
         for i in range(self.conf.n_layers):
             # self+cross attention
-            desc0, desc1 = self.self_attn[i](
-                desc0, desc1, encoding0, encoding1)
+            desc0, desc1 = self.self_attn[i](desc0, desc1, encoding0, encoding1)
             desc0, desc1 = self.cross_attn[i](desc0, desc1)
             if i == self.conf.n_layers - 1:
                 continue  # no early stopping or adaptive width at last layer
             if dec > 0:  # early stopping
                 token0, token1 = self.token_confidence[i](desc0, desc1)
-                if self.stop(token0, token1, self.conf_th(i), dec, m+n):
+                if self.stop(token0, token1, self.conf_th(i), dec, m + n):
                     break
             if wic > 0:  # point pruning
                 match0, match1 = self.log_assignment[i].scores(desc0, desc1)
-                mask0 = self.get_mask(token0, match0, self.conf_th(i), 1-wic)
-                mask1 = self.get_mask(token1, match1, self.conf_th(i), 1-wic)
+                mask0 = self.get_mask(token0, match0, self.conf_th(i), 1 - wic)
+                mask1 = self.get_mask(token1, match1, self.conf_th(i), 1 - wic)
                 ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
                 desc0, desc1 = desc0[mask0][None], desc1[mask1][None]
                 if desc0.shape[-2] == 0 or desc1.shape[-2] == 0:
@@ -409,17 +402,16 @@ class LightGlue(nn.Module):
         if wic > 0:  # scatter with indices after pruning
             scores_, _ = self.log_assignment[i](desc0, desc1)
             dt, dev = scores_.dtype, scores_.device
-            scores = torch.zeros(b, m+1, n+1, dtype=dt, device=dev)
+            scores = torch.zeros(b, m + 1, n + 1, dtype=dt, device=dev)
             scores[:, :-1, :-1] = -torch.inf
             scores[:, ind0[0], -1] = scores_[:, :-1, -1]
             scores[:, -1, ind1[0]] = scores_[:, -1, :-1]
-            x, y = torch.meshgrid(ind0[0], ind1[0], indexing='ij')
+            x, y = torch.meshgrid(ind0[0], ind1[0], indexing="ij")
             scores[:, x, y] = scores_[:, :-1, :-1]
         else:
             scores, _ = self.log_assignment[i](desc0, desc1)
 
-        m0, m1, mscores0, mscores1 = filter_matches(
-            scores, self.conf.filter_threshold)
+        m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
 
         matches, mscores = [], []
         for k in range(b):
@@ -428,36 +420,48 @@ class LightGlue(nn.Module):
             mscores.append(mscores0[k][valid])
 
         return {
-            'log_assignment': scores,
-            'matches0': m0,
-            'matches1': m1,
-            'matching_scores0': mscores0,
-            'matching_scores1': mscores1,
-            'stop': i+1,
-            'prune0': prune0,
-            'prune1': prune1,
-            'matches': matches,
-            'scores': mscores,
+            "log_assignment": scores,
+            "matches0": m0,
+            "matches1": m1,
+            "matching_scores0": mscores0,
+            "matching_scores1": mscores1,
+            "stop": i + 1,
+            "prune0": prune0,
+            "prune1": prune1,
+            "matches": matches,
+            "scores": mscores,
         }
 
     def conf_th(self, i: int) -> float:
-        """ scaled confidence threshold """
-        return np.clip(
-            0.8 + 0.1 * np.exp(-4.0 * i / self.conf.n_layers), 0, 1)
-
-    def get_mask(self, confidence: torch.Tensor, match: torch.Tensor,
-                 conf_th: float, match_th: float) -> torch.Tensor:
-        """ mask points which should be removed """
+        """scaled confidence threshold"""
+        return np.clip(0.8 + 0.1 * np.exp(-4.0 * i / self.conf.n_layers), 0, 1)
+
+    def get_mask(
+        self,
+        confidence: torch.Tensor,
+        match: torch.Tensor,
+        conf_th: float,
+        match_th: float,
+    ) -> torch.Tensor:
+        """mask points which should be removed"""
         if conf_th and confidence is not None:
-            mask = torch.where(confidence > conf_th, match,
-                               match.new_tensor(1.0)) > match_th
+            mask = (
+                torch.where(confidence > conf_th, match, match.new_tensor(1.0))
+                > match_th
+            )
         else:
             mask = match > match_th
         return mask
 
-    def stop(self, token0: torch.Tensor, token1: torch.Tensor,
-             conf_th: float, inl_th: float, seql: int) -> torch.Tensor:
-        """ evaluate stopping condition"""
+    def stop(
+        self,
+        token0: torch.Tensor,
+        token1: torch.Tensor,
+        conf_th: float,
+        inl_th: float,
+        seql: int,
+    ) -> torch.Tensor:
+        """evaluate stopping condition"""
         tokens = torch.cat([token0, token1], -1)
         if conf_th:
             pos = 1.0 - (tokens < conf_th).float().sum() / seql
diff --git a/third_party/LightGlue/lightglue/superpoint.py b/third_party/LightGlue/lightglue/superpoint.py
index abe7539767c9b2fe788e376e872d2844386b1a4a..1b7ce40f698bda6b2aca34d4ee504bd725933005 100644
--- a/third_party/LightGlue/lightglue/superpoint.py
+++ b/third_party/LightGlue/lightglue/superpoint.py
@@ -48,12 +48,13 @@ from .utils import ImagePreprocessor
 
 
 def simple_nms(scores, nms_radius: int):
-    """ Fast Non-maximum suppression to remove nearby points """
-    assert (nms_radius >= 0)
+    """Fast Non-maximum suppression to remove nearby points"""
+    assert nms_radius >= 0
 
     def max_pool(x):
         return torch.nn.functional.max_pool2d(
-            x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)
+            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
+        )
 
     zeros = torch.zeros_like(scores)
     max_mask = scores == max_pool(scores)
@@ -73,17 +74,20 @@ def top_k_keypoints(keypoints, scores, k):
 
 
 def sample_descriptors(keypoints, descriptors, s: int = 8):
-    """ Interpolate descriptors at keypoint locations """
+    """Interpolate descriptors at keypoint locations"""
     b, c, h, w = descriptors.shape
     keypoints = keypoints - s / 2 + 0.5
-    keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
-                              ).to(keypoints)[None]
-    keypoints = keypoints*2 - 1  # normalize to (-1, 1)
-    args = {'align_corners': True} if torch.__version__ >= '1.3' else {}
+    keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)]).to(
+        keypoints
+    )[None]
+    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
+    args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
     descriptors = torch.nn.functional.grid_sample(
-        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
+        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
+    )
     descriptors = torch.nn.functional.normalize(
-        descriptors.reshape(b, c, -1), p=2, dim=1)
+        descriptors.reshape(b, c, -1), p=2, dim=1
+    )
     return descriptors
 
 
@@ -95,21 +99,22 @@ class SuperPoint(nn.Module):
     Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
 
     """
+
     default_conf = {
-        'descriptor_dim': 256,
-        'nms_radius': 4,
-        'max_num_keypoints': None,
-        'detection_threshold': 0.0005,
-        'remove_borders': 4,
+        "descriptor_dim": 256,
+        "nms_radius": 4,
+        "max_num_keypoints": None,
+        "detection_threshold": 0.0005,
+        "remove_borders": 4,
     }
 
     preprocess_conf = {
         **ImagePreprocessor.default_conf,
-        'resize': 1024,
-        'grayscale': True,
+        "resize": 1024,
+        "grayscale": True,
     }
 
-    required_data_keys = ['image']
+    required_data_keys = ["image"]
 
     def __init__(self, **conf):
         super().__init__()
@@ -133,26 +138,26 @@ class SuperPoint(nn.Module):
 
         self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
         self.convDb = nn.Conv2d(
-            c5, self.conf['descriptor_dim'],
-            kernel_size=1, stride=1, padding=0)
+            c5, self.conf["descriptor_dim"], kernel_size=1, stride=1, padding=0
+        )
 
         url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth"
         self.load_state_dict(torch.hub.load_state_dict_from_url(url))
 
-        mk = self.conf['max_num_keypoints']
+        mk = self.conf["max_num_keypoints"]
         if mk is not None and mk <= 0:
-            raise ValueError('max_num_keypoints must be positive or None')
+            raise ValueError("max_num_keypoints must be positive or None")
 
-        print('Loaded SuperPoint model')
+        print("Loaded SuperPoint model")
 
     def forward(self, data: dict) -> dict:
-        """ Compute keypoints, scores, descriptors for image """
+        """Compute keypoints, scores, descriptors for image"""
         for key in self.required_data_keys:
-            assert key in data, f'Missing key {key} in data'
-        image = data['image']
+            assert key in data, f"Missing key {key} in data"
+        image = data["image"]
         if image.shape[1] == 3:  # RGB
             scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
-            image = (image*scale).sum(1, keepdim=True)
+            image = (image * scale).sum(1, keepdim=True)
         # Shared Encoder
         x = self.relu(self.conv1a(image))
         x = self.relu(self.conv1b(x))
@@ -172,31 +177,37 @@ class SuperPoint(nn.Module):
         scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
         b, _, h, w = scores.shape
         scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
-        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
-        scores = simple_nms(scores, self.conf['nms_radius'])
+        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
+        scores = simple_nms(scores, self.conf["nms_radius"])
 
         # Discard keypoints near the image borders
-        if self.conf['remove_borders']:
-            pad = self.conf['remove_borders']
+        if self.conf["remove_borders"]:
+            pad = self.conf["remove_borders"]
             scores[:, :pad] = -1
             scores[:, :, :pad] = -1
             scores[:, -pad:] = -1
             scores[:, :, -pad:] = -1
 
         # Extract keypoints
-        best_kp = torch.where(scores > self.conf['detection_threshold'])
+        best_kp = torch.where(scores > self.conf["detection_threshold"])
         scores = scores[best_kp]
 
         # Separate into batches
-        keypoints = [torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i]
-                     for i in range(b)]
+        keypoints = [
+            torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
+        ]
         scores = [scores[best_kp[0] == i] for i in range(b)]
 
         # Keep the k keypoints with highest score
-        if self.conf['max_num_keypoints'] is not None:
-            keypoints, scores = list(zip(*[
-                top_k_keypoints(k, s, self.conf['max_num_keypoints'])
-                for k, s in zip(keypoints, scores)]))
+        if self.conf["max_num_keypoints"] is not None:
+            keypoints, scores = list(
+                zip(
+                    *[
+                        top_k_keypoints(k, s, self.conf["max_num_keypoints"])
+                        for k, s in zip(keypoints, scores)
+                    ]
+                )
+            )
 
         # Convert (h, w) to (x, y)
         keypoints = [torch.flip(k, [1]).float() for k in keypoints]
@@ -207,24 +218,25 @@ class SuperPoint(nn.Module):
         descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
 
         # Extract descriptors
-        descriptors = [sample_descriptors(k[None], d[None], 8)[0]
-                       for k, d in zip(keypoints, descriptors)]
+        descriptors = [
+            sample_descriptors(k[None], d[None], 8)[0]
+            for k, d in zip(keypoints, descriptors)
+        ]
 
         return {
-            'keypoints': torch.stack(keypoints, 0),
-            'keypoint_scores': torch.stack(scores, 0),
-            'descriptors': torch.stack(descriptors, 0).transpose(-1, -2),
+            "keypoints": torch.stack(keypoints, 0),
+            "keypoint_scores": torch.stack(scores, 0),
+            "descriptors": torch.stack(descriptors, 0).transpose(-1, -2),
         }
 
     def extract(self, img: torch.Tensor, **conf) -> dict:
-        """ Perform extraction with online resizing"""
+        """Perform extraction with online resizing"""
         if img.dim() == 3:
             img = img[None]  # add batch dim
         assert img.dim() == 4 and img.shape[0] == 1
         shape = img.shape[-2:][::-1]
-        img, scales = ImagePreprocessor(
-            **{**self.preprocess_conf, **conf})(img)
-        feats = self.forward({'image': img})
-        feats['image_size'] = torch.tensor(shape)[None].to(img).float()
-        feats['keypoints'] = (feats['keypoints'] + .5) / scales[None] - .5
+        img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
+        feats = self.forward({"image": img})
+        feats["image_size"] = torch.tensor(shape)[None].to(img).float()
+        feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
         return feats
diff --git a/third_party/LightGlue/lightglue/utils.py b/third_party/LightGlue/lightglue/utils.py
index 3e06184948948670db1425ed22f5cbb86061a332..e8d30803931aad89e16e9b543959f76fda87389e 100644
--- a/third_party/LightGlue/lightglue/utils.py
+++ b/third_party/LightGlue/lightglue/utils.py
@@ -10,12 +10,12 @@ from types import SimpleNamespace
 
 class ImagePreprocessor:
     default_conf = {
-        'resize': None,  # target edge length, None for no resizing
-        'side': 'long',
-        'interpolation': 'bilinear',
-        'align_corners': None,
-        'antialias': True,
-        'grayscale': False,  # convert rgb to grayscale
+        "resize": None,  # target edge length, None for no resizing
+        "side": "long",
+        "interpolation": "bilinear",
+        "align_corners": None,
+        "antialias": True,
+        "grayscale": False,  # convert rgb to grayscale
     }
 
     def __init__(self, **conf) -> None:
@@ -28,9 +28,12 @@ class ImagePreprocessor:
         h, w = img.shape[-2:]
         if self.conf.resize is not None:
             img = kornia.geometry.transform.resize(
-                img, self.conf.resize, side=self.conf.side,
+                img,
+                self.conf.resize,
+                side=self.conf.side,
                 antialias=self.conf.antialias,
-                align_corners=self.conf.align_corners)
+                align_corners=self.conf.align_corners,
+            )
         scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
         if self.conf.grayscale and img.shape[-3] == 3:
             img = kornia.color.rgb_to_grayscale(img)
@@ -53,28 +56,31 @@ def map_tensor(input_, func: Callable):
         return input_
 
 
-def batch_to_device(batch: dict, device: str = 'cpu',
-                    non_blocking: bool = True):
+def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
     """Move batch (dict) to device"""
+
     def _func(tensor):
         return tensor.to(device=device, non_blocking=non_blocking).detach()
+
     return map_tensor(batch, _func)
 
 
 def rbd(data: dict) -> dict:
     """Remove batch dimension from elements in data"""
-    return {k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
-            for k, v in data.items()}
+    return {
+        k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
+        for k, v in data.items()
+    }
 
 
 def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
     """Read an image from path as RGB or grayscale"""
     if not Path(path).exists():
-        raise FileNotFoundError(f'No image at path {path}.')
+        raise FileNotFoundError(f"No image at path {path}.")
     mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
     image = cv2.imread(str(path), mode)
     if image is None:
-        raise IOError(f'Could not read image at {path}.')
+        raise IOError(f"Could not read image at {path}.")
     if not grayscale:
         image = image[..., ::-1]
     return image
@@ -87,31 +93,35 @@ def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
     elif image.ndim == 2:
         image = image[None]  # add channel axis
     else:
-        raise ValueError(f'Not an image: {image.shape}')
-    return torch.tensor(image / 255., dtype=torch.float)
+        raise ValueError(f"Not an image: {image.shape}")
+    return torch.tensor(image / 255.0, dtype=torch.float)
 
 
-def resize_image(image: np.ndarray, size: Union[List[int], int],
-                 fn: str = 'max', interp: Optional[str] = 'area',
-                 ) -> np.ndarray:
+def resize_image(
+    image: np.ndarray,
+    size: Union[List[int], int],
+    fn: str = "max",
+    interp: Optional[str] = "area",
+) -> np.ndarray:
     """Resize an image to a fixed size, or according to max or min edge."""
     h, w = image.shape[:2]
 
-    fn = {'max': max, 'min': min}[fn]
+    fn = {"max": max, "min": min}[fn]
     if isinstance(size, int):
         scale = size / fn(h, w)
-        h_new, w_new = int(round(h*scale)), int(round(w*scale))
+        h_new, w_new = int(round(h * scale)), int(round(w * scale))
         scale = (w_new / w, h_new / h)
     elif isinstance(size, (tuple, list)):
         h_new, w_new = size
         scale = (w_new / w, h_new / h)
     else:
-        raise ValueError(f'Incorrect new size: {size}')
+        raise ValueError(f"Incorrect new size: {size}")
     mode = {
-        'linear': cv2.INTER_LINEAR,
-        'cubic': cv2.INTER_CUBIC,
-        'nearest': cv2.INTER_NEAREST,
-        'area': cv2.INTER_AREA}[interp]
+        "linear": cv2.INTER_LINEAR,
+        "cubic": cv2.INTER_CUBIC,
+        "nearest": cv2.INTER_NEAREST,
+        "area": cv2.INTER_AREA,
+    }[interp]
     return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
 
 
@@ -122,13 +132,18 @@ def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor:
     return numpy_image_to_torch(image)
 
 
-def match_pair(extractor, matcher,
-               image0: torch.Tensor, image1: torch.Tensor,
-               device: str = 'cpu', **preprocess):
+def match_pair(
+    extractor,
+    matcher,
+    image0: torch.Tensor,
+    image1: torch.Tensor,
+    device: str = "cpu",
+    **preprocess,
+):
     """Match a pair of images (image0, image1) with an extractor and matcher"""
     feats0 = extractor.extract(image0, **preprocess)
     feats1 = extractor.extract(image1, **preprocess)
-    matches01 = matcher({'image0': feats0, 'image1': feats1})
+    matches01 = matcher({"image0": feats0, "image1": feats1})
     data = [feats0, feats1, matches01]
     # remove batch dim and move to target device
     feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data]
diff --git a/third_party/LightGlue/lightglue/viz2d.py b/third_party/LightGlue/lightglue/viz2d.py
index 3b8e65b45c8424a0a1747b6f81f6b1d5bb928471..4999a76fd0001b0b7570ba38639fcf0a30b0c915 100644
--- a/third_party/LightGlue/lightglue/viz2d.py
+++ b/third_party/LightGlue/lightglue/viz2d.py
@@ -14,33 +14,32 @@ import torch
 
 def cm_RdGn(x):
     """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
-    x = np.clip(x, 0, 1)[..., None]*2
-    c = x*np.array([[0, 1., 0]]) + (2-x)*np.array([[1., 0, 0]])
+    x = np.clip(x, 0, 1)[..., None] * 2
+    c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
     return np.clip(c, 0, 1)
 
 
 def cm_BlRdGn(x_):
     """Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
-    x = np.clip(x_, 0, 1)[..., None]*2
-    c = x*np.array([[0, 1., 0, 1.]]) + (2-x)*np.array([[1., 0, 0, 1.]])
+    x = np.clip(x_, 0, 1)[..., None] * 2
+    c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
 
-    xn = -np.clip(x_, -1, 0)[..., None]*2
-    cn = xn*np.array([[0, 0.1, 1, 1.]]) + (2-xn)*np.array([[1., 0, 0, 1.]])
+    xn = -np.clip(x_, -1, 0)[..., None] * 2
+    cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
     out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
     return out
 
 
 def cm_prune(x_):
-    """ Custom colormap to visualize pruning """
+    """Custom colormap to visualize pruning"""
     if isinstance(x_, torch.Tensor):
         x_ = x_.cpu().numpy()
     max_i = max(x_)
-    norm_x = np.where(x_ == max_i, -1, (x_-1) / 9)
+    norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
     return cm_BlRdGn(norm_x)
 
 
-def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
-                adaptive=True):
+def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
     """Plot a set of images horizontally.
     Args:
         imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W).
@@ -49,9 +48,12 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
         adaptive: whether the figure size should fit the image aspect ratios.
     """
     # conversion to (H, W, 3) for torch.Tensor
-    imgs = [img.permute(1, 2, 0).cpu().numpy()
-            if (isinstance(img, torch.Tensor) and img.dim() == 3) else img
-            for img in imgs]
+    imgs = [
+        img.permute(1, 2, 0).cpu().numpy()
+        if (isinstance(img, torch.Tensor) and img.dim() == 3)
+        else img
+        for img in imgs
+    ]
 
     n = len(imgs)
     if not isinstance(cmaps, (list, tuple)):
@@ -60,10 +62,11 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
     if adaptive:
         ratios = [i.shape[1] / i.shape[0] for i in imgs]  # W / H
     else:
-        ratios = [4/3] * n
-    figsize = [sum(ratios)*4.5, 4.5]
+        ratios = [4 / 3] * n
+    figsize = [sum(ratios) * 4.5, 4.5]
     fig, ax = plt.subplots(
-        1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios})
+        1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
+    )
     if n == 1:
         ax = [ax]
     for i in range(n):
@@ -78,7 +81,7 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
     fig.tight_layout(pad=pad)
 
 
-def plot_keypoints(kpts, colors='lime', ps=4, axes=None, a=1.0):
+def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
     """Plot keypoints for existing images.
     Args:
         kpts: list of ndarrays of size (N, 2).
@@ -97,8 +100,7 @@ def plot_keypoints(kpts, colors='lime', ps=4, axes=None, a=1.0):
         ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)
 
 
-def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1., labels=None,
-                 axes=None):
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
     """Plot matches for a pair of existing images.
     Args:
         kpts0, kpts1: corresponding keypoints of size (N, 2).
@@ -127,12 +129,20 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1., labels=None,
     if lw > 0:
         for i in range(len(kpts0)):
             line = matplotlib.patches.ConnectionPatch(
-                xyA=(kpts0[i, 0], kpts0[i, 1]), xyB=(kpts1[i, 0], kpts1[i, 1]),
-                coordsA=ax0.transData, coordsB=ax1.transData,
-                axesA=ax0, axesB=ax1,
-                zorder=1, color=color[i], linewidth=lw, clip_on=True,
-                alpha=a, label=None if labels is None else labels[i],
-                picker=5.0)
+                xyA=(kpts0[i, 0], kpts0[i, 1]),
+                xyB=(kpts1[i, 0], kpts1[i, 1]),
+                coordsA=ax0.transData,
+                coordsB=ax1.transData,
+                axesA=ax0,
+                axesB=ax1,
+                zorder=1,
+                color=color[i],
+                linewidth=lw,
+                clip_on=True,
+                alpha=a,
+                label=None if labels is None else labels[i],
+                picker=5.0,
+            )
             line.set_annotation_clip(True)
             fig.add_artist(line)
 
@@ -145,17 +155,30 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1., labels=None,
         ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
 
 
-def add_text(idx, text, pos=(0.01, 0.99), fs=15, color='w',
-             lcolor='k', lwidth=2, ha='left', va='top'):
+def add_text(
+    idx,
+    text,
+    pos=(0.01, 0.99),
+    fs=15,
+    color="w",
+    lcolor="k",
+    lwidth=2,
+    ha="left",
+    va="top",
+):
     ax = plt.gcf().axes[idx]
-    t = ax.text(*pos, text, fontsize=fs, ha=ha, va=va,
-                color=color, transform=ax.transAxes)
+    t = ax.text(
+        *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
+    )
     if lcolor is not None:
-        t.set_path_effects([
-            path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
-            path_effects.Normal()])
+        t.set_path_effects(
+            [
+                path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
+                path_effects.Normal(),
+            ]
+        )
 
 
 def save_plot(path, **kw):
     """Save the current figure without any white margin."""
-    plt.savefig(path, bbox_inches='tight', pad_inches=0, **kw)
+    plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
diff --git a/third_party/LightGlue/setup.py b/third_party/LightGlue/setup.py
index fc349143002bf0860762a0341ceed47667759a10..2b012e92a208d09e4983317c4eb3c1d8093177e8 100644
--- a/third_party/LightGlue/setup.py
+++ b/third_party/LightGlue/setup.py
@@ -1,24 +1,24 @@
 from pathlib import Path
 from setuptools import setup
 
-description = ['LightGlue']
+description = ["LightGlue"]
 
-with open(str(Path(__file__).parent / 'README.md'), 'r', encoding='utf-8') as f:
+with open(str(Path(__file__).parent / "README.md"), "r", encoding="utf-8") as f:
     readme = f.read()
-with open(str(Path(__file__).parent / 'requirements.txt'), 'r') as f:
-    dependencies = f.read().split('\n')
+with open(str(Path(__file__).parent / "requirements.txt"), "r") as f:
+    dependencies = f.read().split("\n")
 
 setup(
-    name='lightglue',
-    version='0.0',
-    packages=['lightglue'],
-    python_requires='>=3.6',
+    name="lightglue",
+    version="0.0",
+    packages=["lightglue"],
+    python_requires=">=3.6",
     install_requires=dependencies,
-    author='Philipp Lindenberger, Paul-Edouard Sarlin',
+    author="Philipp Lindenberger, Paul-Edouard Sarlin",
     description=description,
     long_description=readme,
     long_description_content_type="text/markdown",
-    url='https://github.com/cvg/LightGlue/',
+    url="https://github.com/cvg/LightGlue/",
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: Apache Software License",
diff --git a/third_party/Roma/demo/demo_fundamental.py b/third_party/Roma/demo/demo_fundamental.py
index 31618d4b06cd56fdd4be9065fb00b826a19e10f9..a71fd5532412fb4c65eb109e8e9f83813c11fd85 100644
--- a/third_party/Roma/demo/demo_fundamental.py
+++ b/third_party/Roma/demo/demo_fundamental.py
@@ -3,11 +3,12 @@ import torch
 import cv2
 from roma import roma_outdoor
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 if __name__ == "__main__":
     from argparse import ArgumentParser
+
     parser = ArgumentParser()
     parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
     parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
@@ -19,7 +20,6 @@ if __name__ == "__main__":
     # Create model
     roma_model = roma_outdoor(device=device)
 
-
     W_A, H_A = Image.open(im1_path).size
     W_B, H_B = Image.open(im2_path).size
 
@@ -27,7 +27,12 @@ if __name__ == "__main__":
     warp, certainty = roma_model.match(im1_path, im2_path, device=device)
     # Sample matches for estimation
     matches, certainty = roma_model.sample(warp, certainty)
-    kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)    
+    kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
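+    # robustly fit the fundamental matrix to the sampled correspondences with MAGSAC++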
     F, mask = cv2.findFundamentalMat(
-        kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
-    )
\ No newline at end of file
+        kpts1.cpu().numpy(),
+        kpts2.cpu().numpy(),
+        ransacReprojThreshold=0.2,
+        method=cv2.USAC_MAGSAC,
+        confidence=0.999999,
+        maxIters=10000,
+    )
diff --git a/third_party/Roma/demo/demo_match.py b/third_party/Roma/demo/demo_match.py
index 46413bb2b336e2ef2c0bc48315821e4de0fcb982..69eb07ffb0b480db99252bbb03a9858964e8d5f0 100644
--- a/third_party/Roma/demo/demo_match.py
+++ b/third_party/Roma/demo/demo_match.py
@@ -6,15 +6,18 @@ from roma.utils.utils import tensor_to_pil
 
 from roma import roma_indoor
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 if __name__ == "__main__":
     from argparse import ArgumentParser
+
     parser = ArgumentParser()
     parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
     parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
-    parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str)
+    parser.add_argument(
+        "--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str
+    )
 
     args, _ = parser.parse_known_args()
     im1_path = args.im_A_path
@@ -36,12 +39,12 @@ if __name__ == "__main__":
     x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
 
     im2_transfer_rgb = F.grid_sample(
-    x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+        x2[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
     )[0]
     im1_transfer_rgb = F.grid_sample(
-    x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+        x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
     )[0]
-    warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
-    white_im = torch.ones((H,2*W),device=device)
+    warp_im = torch.cat((im2_transfer_rgb, im1_transfer_rgb), dim=2)
+    white_im = torch.ones((H, 2 * W), device=device)
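+    # blend the warped colours with white according to the predicted certainty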
     vis_im = certainty * warp_im + (1 - certainty) * white_im
-    tensor_to_pil(vis_im, unnormalize=False).save(save_path)
\ No newline at end of file
+    tensor_to_pil(vis_im, unnormalize=False).save(save_path)
diff --git a/third_party/Roma/roma/__init__.py b/third_party/Roma/roma/__init__.py
index a7c96481e0a808b68c7b3054a3e34fa0b5c45ab9..a3c12d5247b93a83882edfb45bd127db794e791f 100644
--- a/third_party/Roma/roma/__init__.py
+++ b/third_party/Roma/roma/__init__.py
@@ -2,7 +2,7 @@ import os
 from .models import roma_outdoor, roma_indoor
 
 DEBUG_MODE = False
-RANK = int(os.environ.get('RANK', default = 0))
+RANK = int(os.environ.get("RANK", default=0))
 GLOBAL_STEP = 0
 STEP_SIZE = 1
-LOCAL_RANK = -1
\ No newline at end of file
+LOCAL_RANK = -1
diff --git a/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py b/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py
index 2154a471c73d9e883c3ba8ed1b90d708f4950a63..6417d4d54798360a027a0d11d50fc65cdfae015a 100644
--- a/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py
+++ b/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py
@@ -53,7 +53,7 @@ class HpatchesHomogBenchmark:
         )
         return im_A_coords, im_A_to_im_B
 
-    def benchmark(self, model, model_name = None):
+    def benchmark(self, model, model_name=None):
         n_matches = []
         homog_dists = []
         for seq_idx, seq_name in tqdm(
@@ -69,9 +69,7 @@ class HpatchesHomogBenchmark:
                 H = np.loadtxt(
                     os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
                 )
-                dense_matches, dense_certainty = model.match(
-                    im_A_path, im_B_path
-                )
+                dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
                 good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
                 pos_a, pos_b = self.convert_coordinates(
                     good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
@@ -80,9 +78,9 @@ class HpatchesHomogBenchmark:
                     H_pred, inliers = cv2.findHomography(
                         pos_a,
                         pos_b,
-                        method = cv2.RANSAC,
-                        confidence = 0.99999,
-                        ransacReprojThreshold = 3 * min(w2, h2) / 480,
+                        method=cv2.RANSAC,
+                        confidence=0.99999,
+                        ransacReprojThreshold=3 * min(w2, h2) / 480,
                     )
                 except:
                     H_pred = None
diff --git a/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py b/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py
index 0600d354b1d0dfa7f8e2b0f8882a4cc08fafeed9..f51a77e15510572b8f594dbc7713a0f348a33fd8 100644
--- a/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py
+++ b/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py
@@ -6,8 +6,11 @@ from roma.utils import warp_kpts
 from torch.utils.data import ConcatDataset
 import roma
 
+
 class MegadepthDenseBenchmark:
-    def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
+    def __init__(
+        self, data_root="data/megadepth", h=384, w=512, num_samples=2000
+    ) -> None:
         mega = MegadepthBuilder(data_root=data_root)
         self.dataset = ConcatDataset(
             mega.build_scenes(split="test_loftr", ht=h, wt=w)
@@ -49,13 +52,15 @@ class MegadepthDenseBenchmark:
             pck_3_tot = 0.0
             pck_5_tot = 0.0
             sampler = torch.utils.data.WeightedRandomSampler(
-                torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
+                torch.ones(len(self.dataset)),
+                replacement=False,
+                num_samples=self.num_samples,
             )
             B = batch_size
             dataloader = torch.utils.data.DataLoader(
                 self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
             )
-            for idx, data in tqdm.tqdm(enumerate(dataloader), disable = roma.RANK > 0):
+            for idx, data in tqdm.tqdm(enumerate(dataloader), disable=roma.RANK > 0):
                 im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
                     data["im_A"],
                     data["im_B"],
@@ -72,25 +77,36 @@ class MegadepthDenseBenchmark:
                 if roma.DEBUG_MODE:
                     from roma.utils.utils import tensor_to_pil
                     import torch.nn.functional as F
+
                     path = "vis"
                     H, W = model.get_output_resolution()
-                    white_im = torch.ones((B,1,H,W),device="cuda")
+                    white_im = torch.ones((B, 1, H, W), device="cuda")
                     im_B_transfer_rgb = F.grid_sample(
-                        im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
+                        im_B.cuda(),
+                        matches[:, :, :W, 2:],
+                        mode="bilinear",
+                        align_corners=False,
                     )
                     warp_im = im_B_transfer_rgb
-                    c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
+                    c_b = certainty[
+                        :, None
+                    ]  # (certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
                     vis_im = c_b * warp_im + (1 - c_b) * white_im
                     for b in range(B):
                         import os
-                        os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
+
+                        os.makedirs(
+                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}", exist_ok=True
+                        )
                         tensor_to_pil(vis_im[b], unnormalize=True).save(
-                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
+                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg"
+                        )
                         tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
-                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
+                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg"
+                        )
                         tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
-                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
-
+                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg"
+                        )
 
                 gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
                     gd_tot + gd.mean(),
diff --git a/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py b/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py
index 8007fe8ecad09c33401450ad6b7af1f3dad043d2..5d936a07d550763d0378a23ea83c79cec5d373fe 100644
--- a/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py
+++ b/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py
@@ -7,8 +7,9 @@ import torch.nn.functional as F
 import roma
 import kornia.geometry.epipolar as kepi
 
+
 class MegaDepthPoseEstimationBenchmark:
-    def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+    def __init__(self, data_root="data/megadepth", scene_names=None) -> None:
         if scene_names is None:
             self.scene_names = [
                 "0015_0.1_0.3.npz",
@@ -25,14 +26,22 @@ class MegaDepthPoseEstimationBenchmark:
         ]
         self.data_root = data_root
 
-    def benchmark(self, model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True):
-        H,W = model.get_output_resolution()
+    def benchmark(
+        self,
+        model,
+        model_name=None,
+        resolution=None,
+        scale_intrinsics=True,
+        calibrated=True,
+    ):
+        H, W = model.get_output_resolution()
         with torch.no_grad():
             data_root = self.data_root
             tot_e_t, tot_e_R, tot_e_pose = [], [], []
             thresholds = [5, 10, 20]
             for scene_ind in range(len(self.scenes)):
                 import os
+
                 scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
                 scene = self.scenes[scene_ind]
                 pairs = scene["pair_infos"]
@@ -49,16 +58,16 @@ class MegaDepthPoseEstimationBenchmark:
                     T2 = poses[idx2].copy()
                     R2, t2 = T2[:3, :3], T2[:3, 3]
                     R, t = compute_relative_pose(R1, t1, R2, t2)
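+                    # assemble the 3x4 relative pose [R | t] from image 1 to image 2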
-                    T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+                    T1_to_2 = np.concatenate((R, t[:, None]), axis=-1)
                     im_A_path = f"{data_root}/{im_paths[idx1]}"
                     im_B_path = f"{data_root}/{im_paths[idx2]}"
                     dense_matches, dense_certainty = model.match(
                         im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
                     )
-                    sparse_matches,_ = model.sample(
+                    sparse_matches, _ = model.sample(
                         dense_matches, dense_certainty, 5000
                     )
-                    
+
                     im_A = Image.open(im_A_path)
                     w1, h1 = im_A.size
                     im_B = Image.open(im_B_path)
@@ -74,24 +83,20 @@ class MegaDepthPoseEstimationBenchmark:
                         K2[:2] = K2[:2] * scale2
 
                     kpts1 = sparse_matches[:, :2]
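+                    # sparse matches are in normalized [-1, 1] coordinates; convert them to pixel coordinates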
-                    kpts1 = (
-                        np.stack(
-                            (
-                                w1 * (kpts1[:, 0] + 1) / 2,
-                                h1 * (kpts1[:, 1] + 1) / 2,
-                            ),
-                            axis=-1,
-                        )
+                    kpts1 = np.stack(
+                        (
+                            w1 * (kpts1[:, 0] + 1) / 2,
+                            h1 * (kpts1[:, 1] + 1) / 2,
+                        ),
+                        axis=-1,
                     )
                     kpts2 = sparse_matches[:, 2:]
-                    kpts2 = (
-                        np.stack(
-                            (
-                                w2 * (kpts2[:, 0] + 1) / 2,
-                                h2 * (kpts2[:, 1] + 1) / 2,
-                            ),
-                            axis=-1,
-                        )
+                    kpts2 = np.stack(
+                        (
+                            w2 * (kpts2[:, 0] + 1) / 2,
+                            h2 * (kpts2[:, 1] + 1) / 2,
+                        ),
+                        axis=-1,
                     )
 
                     for _ in range(5):
@@ -99,9 +104,12 @@ class MegaDepthPoseEstimationBenchmark:
                         kpts1 = kpts1[shuffling]
                         kpts2 = kpts2[shuffling]
                         try:
-                            threshold = 0.5 
+                            threshold = 0.5
                             if calibrated:
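+                                # convert the pixel threshold to normalized camera coordinates using both cameras' focal scales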
-                                norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                                norm_threshold = threshold / (
+                                    np.mean(np.abs(K1[:2, :2]))
+                                    + np.mean(np.abs(K2[:2, :2]))
+                                )
                                 R_est, t_est, mask = estimate_pose(
                                     kpts1,
                                     kpts2,
diff --git a/third_party/Roma/roma/benchmarks/scannet_benchmark.py b/third_party/Roma/roma/benchmarks/scannet_benchmark.py
index 853af0d0ebef4dfefe2632eb49e4156ea791ee76..3187c2acf79f5af8f64397f55f6df40af327945b 100644
--- a/third_party/Roma/roma/benchmarks/scannet_benchmark.py
+++ b/third_party/Roma/roma/benchmarks/scannet_benchmark.py
@@ -10,7 +10,7 @@ class ScanNetBenchmark:
     def __init__(self, data_root="data/scannet") -> None:
         self.data_root = data_root
 
-    def benchmark(self, model, model_name = None):
+    def benchmark(self, model, model_name=None):
         model.train(False)
         with torch.no_grad():
             data_root = self.data_root
@@ -24,20 +24,20 @@ class ScanNetBenchmark:
                 scene = pairs[pairind]
                 scene_name = f"scene0{scene[0]}_00"
                 im_A_path = osp.join(
-                        self.data_root,
-                        "scans_test",
-                        scene_name,
-                        "color",
-                        f"{scene[2]}.jpg",
-                    )
+                    self.data_root,
+                    "scans_test",
+                    scene_name,
+                    "color",
+                    f"{scene[2]}.jpg",
+                )
                 im_A = Image.open(im_A_path)
                 im_B_path = osp.join(
-                        self.data_root,
-                        "scans_test",
-                        scene_name,
-                        "color",
-                        f"{scene[3]}.jpg",
-                    )
+                    self.data_root,
+                    "scans_test",
+                    scene_name,
+                    "color",
+                    f"{scene[3]}.jpg",
+                )
                 im_B = Image.open(im_B_path)
                 T_gt = rel_pose[pairind].reshape(3, 4)
                 R, t = T_gt[:3, :3], T_gt[:3, 3]
@@ -76,24 +76,20 @@ class ScanNetBenchmark:
 
                 offset = 0.5
                 kpts1 = sparse_matches[:, :2]
-                kpts1 = (
-                    np.stack(
-                        (
-                            w1 * (kpts1[:, 0] + 1) / 2 - offset,
-                            h1 * (kpts1[:, 1] + 1) / 2 - offset,
-                        ),
-                        axis=-1,
-                    )
+                kpts1 = np.stack(
+                    (
+                        w1 * (kpts1[:, 0] + 1) / 2 - offset,
+                        h1 * (kpts1[:, 1] + 1) / 2 - offset,
+                    ),
+                    axis=-1,
                 )
                 kpts2 = sparse_matches[:, 2:]
-                kpts2 = (
-                    np.stack(
-                        (
-                            w2 * (kpts2[:, 0] + 1) / 2 - offset,
-                            h2 * (kpts2[:, 1] + 1) / 2 - offset,
-                        ),
-                        axis=-1,
-                    )
+                kpts2 = np.stack(
+                    (
+                        w2 * (kpts2[:, 0] + 1) / 2 - offset,
+                        h2 * (kpts2[:, 1] + 1) / 2 - offset,
+                    ),
+                    axis=-1,
                 )
                 for _ in range(5):
                     shuffling = np.random.permutation(np.arange(len(kpts1)))
@@ -101,7 +97,8 @@ class ScanNetBenchmark:
                     kpts2 = kpts2[shuffling]
                     try:
                         norm_threshold = 0.5 / (
-                        np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                            np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))
+                        )
                         R_est, t_est, mask = estimate_pose(
                             kpts1,
                             kpts2,
diff --git a/third_party/Roma/roma/checkpointing/checkpoint.py b/third_party/Roma/roma/checkpointing/checkpoint.py
index 8995efeb54f4d558127ea63423fa958c64e9088f..6372d89fe86c00c7acedf015886717bfeca7bb1f 100644
--- a/third_party/Roma/roma/checkpointing/checkpoint.py
+++ b/third_party/Roma/roma/checkpointing/checkpoint.py
@@ -7,6 +7,7 @@ import gc
 
 import roma
 
+
 class CheckPoint:
     def __init__(self, dir=None, name="tmp"):
         self.name = name
@@ -19,7 +20,7 @@ class CheckPoint:
         optimizer,
         lr_scheduler,
         n,
-        ):
+    ):
         if roma.RANK == 0:
             assert model is not None
             if isinstance(model, (DataParallel, DistributedDataParallel)):
@@ -32,14 +33,14 @@ class CheckPoint:
             }
             torch.save(states, self.dir + self.name + f"_latest.pth")
             logger.info(f"Saved states {list(states.keys())}, at step {n}")
-    
+
     def load(
         self,
         model,
         optimizer,
         lr_scheduler,
         n,
-        ):
+    ):
         if os.path.exists(self.dir + self.name + f"_latest.pth") and roma.RANK == 0:
             states = torch.load(self.dir + self.name + f"_latest.pth")
             if "model" in states:
@@ -57,4 +58,4 @@ class CheckPoint:
             del states
             gc.collect()
             torch.cuda.empty_cache()
-        return model, optimizer, lr_scheduler, n
\ No newline at end of file
+        return model, optimizer, lr_scheduler, n
diff --git a/third_party/Roma/roma/datasets/__init__.py b/third_party/Roma/roma/datasets/__init__.py
index b60c709926a4a7bd019b73eac10879063a996c90..6a11f122e222f0a9eded4afd3dd0b900826063e8 100644
--- a/third_party/Roma/roma/datasets/__init__.py
+++ b/third_party/Roma/roma/datasets/__init__.py
@@ -1,2 +1,2 @@
 from .megadepth import MegadepthBuilder
-from .scannet import ScanNetBuilder
\ No newline at end of file
+from .scannet import ScanNetBuilder
diff --git a/third_party/Roma/roma/datasets/megadepth.py b/third_party/Roma/roma/datasets/megadepth.py
index 5deee5ac30c439a9f300c0ad2271f141931020c0..75cb72ded02c80d1ad6bce0d0269626ee49a9275 100644
--- a/third_party/Roma/roma/datasets/megadepth.py
+++ b/third_party/Roma/roma/datasets/megadepth.py
@@ -10,6 +10,7 @@ import roma
 from roma.utils import *
 import math
 
+
 class MegadepthScene:
     def __init__(
         self,
@@ -22,18 +23,20 @@ class MegadepthScene:
         shake_t=0,
         rot_prob=0.0,
         normalize=True,
-        max_num_pairs = 100_000,
-        scene_name = None,
-        use_horizontal_flip_aug = False,
-        use_single_horizontal_flip_aug = False,
-        colorjiggle_params = None,
-        random_eraser = None,
-        use_randaug = False,
-        randaug_params = None,
-        randomize_size = False,
+        max_num_pairs=100_000,
+        scene_name=None,
+        use_horizontal_flip_aug=False,
+        use_single_horizontal_flip_aug=False,
+        colorjiggle_params=None,
+        random_eraser=None,
+        use_randaug=False,
+        randaug_params=None,
+        randomize_size=False,
     ) -> None:
         self.data_root = data_root
-        self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
+        self.scene_name = (
+            os.path.splitext(scene_name)[0] + f"_{min_overlap}_{max_overlap}"
+        )
         self.image_paths = scene_info["image_paths"]
         self.depth_paths = scene_info["depth_paths"]
         self.intrinsics = scene_info["intrinsics"]
@@ -51,18 +54,18 @@ class MegadepthScene:
             self.overlaps = self.overlaps[pairinds]
         if randomize_size:
             area = ht * wt
-            s = int(16 * (math.sqrt(area)//16))
-            sizes = ((ht,wt), (s,s), (wt,ht))
+            s = int(16 * (math.sqrt(area) // 16))
+            sizes = ((ht, wt), (s, s), (wt, ht))
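+            # pick landscape, square, or portrait resolution deterministically from the DDP rank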
             choice = roma.RANK % 3
-            ht, wt = sizes[choice] 
+            ht, wt = sizes[choice]
         # counts, bins = np.histogram(self.overlaps,20)
         # print(counts)
         self.im_transform_ops = get_tuple_transform_ops(
-            resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
+            resize=(ht, wt),
+            normalize=normalize,
+            colorjiggle_params=colorjiggle_params,
         )
-        self.depth_transform_ops = get_depth_tuple_transform_ops(
-                resize=(ht, wt)
-            )
+        self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt))
         self.wt, self.ht = wt, ht
         self.shake_t = shake_t
         self.random_eraser = random_eraser
@@ -75,17 +78,19 @@ class MegadepthScene:
     def load_im(self, im_path):
         im = Image.open(im_path)
         return im
-    
-    def horizontal_flip(self, im_A, im_B, depth_A, depth_B,  K_A, K_B):
+
+    def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
         im_A = im_A.flip(-1)
         im_B = im_B.flip(-1)
-        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) 
-        flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
-        K_A = flip_mat@K_A  
-        K_B = flip_mat@K_B  
-        
+        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
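+        # a horizontal flip maps x -> wt - x, so adjust both intrinsics with the mirror matrix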
+        flip_mat = torch.tensor([[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.0]]).to(
+            K_A.device
+        )
+        K_A = flip_mat @ K_A
+        K_B = flip_mat @ K_B
+
         return im_A, im_B, depth_A, depth_B, K_A, K_B
-    
+
     def load_depth(self, depth_ref, crop=None):
         depth = np.array(h5py.File(depth_ref, "r")["depth"])
         return torch.from_numpy(depth)
@@ -140,29 +145,31 @@ class MegadepthScene:
         depth_A, depth_B = self.depth_transform_ops(
             (depth_A[None, None], depth_B[None, None])
         )
-        
-        [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
+
+        [im_A, im_B, depth_A, depth_B], t = self.rand_shake(
+            im_A, im_B, depth_A, depth_B
+        )
         K1[:2, 2] += t
         K2[:2, 2] += t
-        
+
         im_A, im_B = im_A[None], im_B[None]
         if self.random_eraser is not None:
             im_A, depth_A = self.random_eraser(im_A, depth_A)
             im_B, depth_B = self.random_eraser(im_B, depth_B)
-                
+
         if self.use_horizontal_flip_aug:
             if np.random.rand() > 0.5:
-                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
+                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
+                    im_A, im_B, depth_A, depth_B, K1, K2
+                )
         if self.use_single_horizontal_flip_aug:
             if np.random.rand() > 0.5:
                 im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
-        
+
         if roma.DEBUG_MODE:
-            tensor_to_pil(im_A[0], unnormalize=True).save(
-                            f"vis/im_A.jpg")
-            tensor_to_pil(im_B[0], unnormalize=True).save(
-                            f"vis/im_B.jpg")
-            
+            tensor_to_pil(im_A[0], unnormalize=True).save(f"vis/im_A.jpg")
+            tensor_to_pil(im_B[0], unnormalize=True).save(f"vis/im_B.jpg")
+
         data_dict = {
             "im_A": im_A[0],
             "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
@@ -175,25 +182,53 @@ class MegadepthScene:
             "T_1to2": T_1to2,
             "im_A_path": im_A_ref,
             "im_B_path": im_B_ref,
-            
         }
         return data_dict
 
 
 class MegadepthBuilder:
-    def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
+    def __init__(
+        self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True
+    ) -> None:
         self.data_root = data_root
         self.scene_info_root = os.path.join(data_root, "prep_scene_info")
         self.all_scenes = os.listdir(self.scene_info_root)
         self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
         # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes; these can optionally be ignored
-        self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
-        self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
+        self.loftr_ignore_scenes = set(
+            [
+                "0121.npy",
+                "0133.npy",
+                "0168.npy",
+                "0178.npy",
+                "0229.npy",
+                "0349.npy",
+                "0412.npy",
+                "0430.npy",
+                "0443.npy",
+                "1001.npy",
+                "5014.npy",
+                "5015.npy",
+                "5016.npy",
+            ]
+        )
+        self.imc21_scenes = set(
+            [
+                "0008.npy",
+                "0019.npy",
+                "0021.npy",
+                "0024.npy",
+                "0025.npy",
+                "0032.npy",
+                "0063.npy",
+                "1589.npy",
+            ]
+        )
         self.test_scenes_loftr = ["0015.npy", "0022.npy"]
         self.loftr_ignore = loftr_ignore
         self.imc21_ignore = imc21_ignore
 
-    def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
+    def build_scenes(self, split="train", min_overlap=0.0, scene_names=None, **kwargs):
         if split == "train":
             scene_names = set(self.all_scenes) - set(self.test_scenes)
         elif split == "train_loftr":
@@ -217,7 +252,11 @@ class MegadepthBuilder:
             ).item()
             scenes.append(
                 MegadepthScene(
-                    self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
+                    self.data_root,
+                    scene_info,
+                    min_overlap=min_overlap,
+                    scene_name=scene_name,
+                    **kwargs,
                 )
             )
         return scenes
diff --git a/third_party/Roma/roma/datasets/scannet.py b/third_party/Roma/roma/datasets/scannet.py
index 704ea57259afdfbbca627ad143bee97a0a79d41c..91bea57c9d1ae2773c11a9c8d47f31026a2c227b 100644
--- a/third_party/Roma/roma/datasets/scannet.py
+++ b/third_party/Roma/roma/datasets/scannet.py
@@ -5,10 +5,7 @@ import cv2
 import h5py
 import numpy as np
 import torch
-from torch.utils.data import (
-    Dataset,
-    DataLoader,
-    ConcatDataset)
+from torch.utils.data import Dataset, DataLoader, ConcatDataset
 
 import torchvision.transforms.functional as tvf
 import kornia.augmentation as K
@@ -19,22 +16,36 @@ from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
 from roma.utils.transforms import GeometricSequential
 from tqdm import tqdm
 
+
 class ScanNetScene:
-    def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False,
-) -> None:
-        self.scene_root = osp.join(data_root,"scans","scans_train")
-        self.data_names = scene_info['name']
-        self.overlaps = scene_info['score']
+    def __init__(
+        self,
+        data_root,
+        scene_info,
+        ht=384,
+        wt=512,
+        min_overlap=0.0,
+        shake_t=0,
+        rot_prob=0.0,
+        use_horizontal_flip_aug=False,
+    ) -> None:
+        self.scene_root = osp.join(data_root, "scans", "scans_train")
+        self.data_names = scene_info["name"]
+        self.overlaps = scene_info["score"]
         # Only sample 10s
-        valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
+        valid = (self.data_names[:, -2:] % 10).sum(axis=-1) == 0
         self.overlaps = self.overlaps[valid]
         self.data_names = self.data_names[valid]
         if len(self.data_names) > 10000:
-            pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
+            pairinds = np.random.choice(
+                np.arange(0, len(self.data_names)), 10000, replace=False
+            )
             self.data_names = self.data_names[pairinds]
             self.overlaps = self.overlaps[pairinds]
         self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
-        self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
+        self.depth_transform_ops = get_depth_tuple_transform_ops(
+            resize=(ht, wt), normalize=False
+        )
         self.wt, self.ht = wt, ht
         self.shake_t = shake_t
         self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
@@ -43,7 +54,7 @@ class ScanNetScene:
     def load_im(self, im_B, crop=None):
         im = Image.open(im_B)
         return im
-    
+
     def load_depth(self, depth_ref, crop=None):
         depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
         depth = depth / 1000
@@ -52,64 +63,73 @@ class ScanNetScene:
 
     def __len__(self):
         return len(self.data_names)
-    
+
     def scale_intrinsic(self, K, wi, hi):
-        sx, sy = self.wt / wi, self.ht /  hi
-        sK = torch.tensor([[sx, 0, 0],
-                        [0, sy, 0],
-                        [0, 0, 1]])
-        return sK@K
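+        # rescale the intrinsics from the original (wi, hi) image to the (wt, ht) working resolution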
+        sx, sy = self.wt / wi, self.ht / hi
+        sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
+        return sK @ K
 
-    def horizontal_flip(self, im_A, im_B, depth_A, depth_B,  K_A, K_B):
+    def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
         im_A = im_A.flip(-1)
         im_B = im_B.flip(-1)
-        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) 
-        flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
-        K_A = flip_mat@K_A  
-        K_B = flip_mat@K_B  
-        
+        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
+        flip_mat = torch.tensor([[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.0]]).to(
+            K_A.device
+        )
+        K_A = flip_mat @ K_A
+        K_B = flip_mat @ K_B
+
         return im_A, im_B, depth_A, depth_B, K_A, K_B
-    def read_scannet_pose(self,path):
-        """ Read ScanNet's Camera2World pose and transform it to World2Camera.
-        
+
+    def read_scannet_pose(self, path):
+        """Read ScanNet's Camera2World pose and transform it to World2Camera.
+
         Returns:
             pose_w2c (np.ndarray): (4, 4)
         """
-        cam2world = np.loadtxt(path, delimiter=' ')
+        cam2world = np.loadtxt(path, delimiter=" ")
         world2cam = np.linalg.inv(cam2world)
         return world2cam
 
-
-    def read_scannet_intrinsic(self,path):
-        """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
-        """
-        intrinsic = np.loadtxt(path, delimiter=' ')
-        return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float)
+    def read_scannet_intrinsic(self, path):
+        """Read ScanNet's intrinsic matrix and return the 3x3 matrix."""
+        intrinsic = np.loadtxt(path, delimiter=" ")
+        return torch.tensor(intrinsic[:-1, :-1], dtype=torch.float)
 
     def __getitem__(self, pair_idx):
         # read intrinsics of original size
         data_name = self.data_names[pair_idx]
         scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
-        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
-        
+        scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}"
+
         # read the intrinsic of depthmap
-        K1 = K2 =  self.read_scannet_intrinsic(osp.join(self.scene_root,
-                       scene_name,
-                       'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
+        K1 = K2 = self.read_scannet_intrinsic(
+            osp.join(self.scene_root, scene_name, "intrinsic", "intrinsic_color.txt")
+        )  # the depth K is not the same, but it doesn't really matter
         # read and compute relative poses
-        T1 =  self.read_scannet_pose(osp.join(self.scene_root,
-                       scene_name,
-                       'pose', f'{stem_name_1}.txt'))
-        T2 =  self.read_scannet_pose(osp.join(self.scene_root,
-                       scene_name,
-                       'pose', f'{stem_name_2}.txt'))
-        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4]  # (4, 4)
+        T1 = self.read_scannet_pose(
+            osp.join(self.scene_root, scene_name, "pose", f"{stem_name_1}.txt")
+        )
+        T2 = self.read_scannet_pose(
+            osp.join(self.scene_root, scene_name, "pose", f"{stem_name_2}.txt")
+        )
+        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
+            :4, :4
+        ]  # (4, 4)
 
         # Load positive pair data
-        im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
-        im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
-        depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
-        depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
+        im_A_ref = os.path.join(
+            self.scene_root, scene_name, "color", f"{stem_name_1}.jpg"
+        )
+        im_B_ref = os.path.join(
+            self.scene_root, scene_name, "color", f"{stem_name_2}.jpg"
+        )
+        depth_A_ref = os.path.join(
+            self.scene_root, scene_name, "depth", f"{stem_name_1}.png"
+        )
+        depth_B_ref = os.path.join(
+            self.scene_root, scene_name, "depth", f"{stem_name_2}.png"
+        )
 
         im_A = self.load_im(im_A_ref)
         im_B = self.load_im(im_B_ref)
@@ -121,40 +141,51 @@ class ScanNetScene:
         K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
         # Process images
         im_A, im_B = self.im_transform_ops((im_A, im_B))
-        depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None]))
+        depth_A, depth_B = self.depth_transform_ops(
+            (depth_A[None, None], depth_B[None, None])
+        )
         if self.use_horizontal_flip_aug:
             if np.random.rand() > 0.5:
-                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
-
-        data_dict = {'im_A': im_A,
-                    'im_B': im_B,
-                    'im_A_depth': depth_A[0,0],
-                    'im_B_depth': depth_B[0,0],
-                    'K1': K1,
-                    'K2': K2,
-                    'T_1to2':T_1to2,
-                    }
+                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
+                    im_A, im_B, depth_A, depth_B, K1, K2
+                )
+
+        data_dict = {
+            "im_A": im_A,
+            "im_B": im_B,
+            "im_A_depth": depth_A[0, 0],
+            "im_B_depth": depth_B[0, 0],
+            "K1": K1,
+            "K2": K2,
+            "T_1to2": T_1to2,
+        }
         return data_dict
 
 
 class ScanNetBuilder:
-    def __init__(self, data_root = 'data/scannet') -> None:
+    def __init__(self, data_root="data/scannet") -> None:
         self.data_root = data_root
-        self.scene_info_root = os.path.join(data_root,'scannet_indices')
+        self.scene_info_root = os.path.join(data_root, "scannet_indices")
         self.all_scenes = os.listdir(self.scene_info_root)
-        
-    def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
+
+    def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
         # Note: split doesn't matter here as we always use same scannet_train scenes
         scene_names = self.all_scenes
         scenes = []
-        for scene_name in tqdm(scene_names, disable = roma.RANK > 0):
-            scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
-            scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
+        for scene_name in tqdm(scene_names, disable=roma.RANK > 0):
+            scene_info = np.load(
+                os.path.join(self.scene_info_root, scene_name), allow_pickle=True
+            )
+            scenes.append(
+                ScanNetScene(
+                    self.data_root, scene_info, min_overlap=min_overlap, **kwargs
+                )
+            )
         return scenes
-    
-    def weight_scenes(self, concat_dataset, alpha=.5):
+
+    def weight_scenes(self, concat_dataset, alpha=0.5):
         ns = []
         for d in concat_dataset.datasets:
             ns.append(len(d))
-        ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
+        ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
         return ws
diff --git a/third_party/Roma/roma/losses/__init__.py b/third_party/Roma/roma/losses/__init__.py
index 2e08abacfc0f83d7de0f2ddc0583766a80bf53cf..12cb6d40b90ca3ccf712321f78c033401db865fb 100644
--- a/third_party/Roma/roma/losses/__init__.py
+++ b/third_party/Roma/roma/losses/__init__.py
@@ -1 +1 @@
-from .robust_loss import RobustLosses
\ No newline at end of file
+from .robust_loss import RobustLosses
diff --git a/third_party/Roma/roma/losses/robust_loss.py b/third_party/Roma/roma/losses/robust_loss.py
index b932b2706f619c083485e1be0d86eec44ead83ef..cd9fd5bbc9c2d01bb6dd40823e350b588bd598b3 100644
--- a/third_party/Roma/roma/losses/robust_loss.py
+++ b/third_party/Roma/roma/losses/robust_loss.py
@@ -7,6 +7,7 @@ import wandb
 import roma
 import math
 
+
 class RobustLosses(nn.Module):
     def __init__(
         self,
@@ -17,12 +18,12 @@ class RobustLosses(nn.Module):
         local_loss=True,
         local_dist=4.0,
         local_largest_scale=8,
-        smooth_mask = False,
-        depth_interpolation_mode = "bilinear",
-        mask_depth_loss = False,
-        relative_depth_error_threshold = 0.05,
-        alpha = 1.,
-        c = 1e-3,
+        smooth_mask=False,
+        depth_interpolation_mode="bilinear",
+        mask_depth_loss=False,
+        relative_depth_error_threshold=0.05,
+        alpha=1.0,
+        c=1e-3,
     ):
         super().__init__()
         self.robust = robust  # measured in pixels
@@ -45,68 +46,103 @@ class RobustLosses(nn.Module):
             B, C, H, W = scale_gm_cls.shape
             device = x2.device
             cls_res = round(math.sqrt(C))
-            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
-            G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
-            GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
-        cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction  = 'none')[prob > 0.99]
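+            # G: cls_res x cls_res grid of candidate coordinates in [-1, 1]
+            # GT: index of the grid cell nearest to the ground-truth warp x2, used as the classification target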
+            G = torch.meshgrid(
+                *[
+                    torch.linspace(
+                        -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device
+                    )
+                    for _ in range(2)
+                ]
+            )
+            G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2)
+            GT = (
+                (G[None, :, None, None, :] - x2[:, None])
+                .norm(dim=-1)
+                .min(dim=1)
+                .indices
+            )
+        cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction="none")[prob > 0.99]
         if not torch.any(cls_loss):
-            cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+            cls_loss = scale_gm_cls.sum() * 0.0  # zero loss when no pixel has prob > 0.99 (certainty_loss is not yet defined here)
 
-        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
+        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:, 0], prob)
         losses = {
             f"gm_certainty_loss_{scale}": certainty_loss.mean(),
             f"gm_cls_loss_{scale}": cls_loss.mean(),
         }
-        wandb.log(losses, step = roma.GLOBAL_STEP)
+        wandb.log(losses, step=roma.GLOBAL_STEP)
         return losses
 
-    def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
+    def delta_cls_loss(
+        self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale
+    ):
         with torch.no_grad():
             B, C, H, W = delta_cls.shape
             device = x2.device
             cls_res = round(math.sqrt(C))
-            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
-            G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
-            GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
-        cls_loss = F.cross_entropy(delta_cls, GT, reduction  = 'none')[prob > 0.99]
+            G = torch.meshgrid(
+                *[
+                    torch.linspace(
+                        -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device
+                    )
+                    for _ in range(2)
+                ]
+            )
+            G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2) * offset_scale
+            GT = (
+                (G[None, :, None, None, :] + flow_pre_delta[:, None] - x2[:, None])
+                .norm(dim=-1)
+                .min(dim=1)
+                .indices
+            )
+        cls_loss = F.cross_entropy(delta_cls, GT, reduction="none")[prob > 0.99]
         if not torch.any(cls_loss):
-            cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
-        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
+            cls_loss = certainty_loss * 0.0  # Prevent issues where prob is 0 everywhere
+        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
         losses = {
             f"delta_certainty_loss_{scale}": certainty_loss.mean(),
             f"delta_cls_loss_{scale}": cls_loss.mean(),
         }
-        wandb.log(losses, step = roma.GLOBAL_STEP)
+        wandb.log(losses, step=roma.GLOBAL_STEP)
         return losses
 
-    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
-        epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
+    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode="delta"):
+        epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1)
         if scale == 1:
-            pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
-            wandb.log({"train_pck_05": pck_05}, step = roma.GLOBAL_STEP)
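+            # PCK@0.5px over confident pixels (2 / 512 corresponds to one pixel in normalized coordinates)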
+            pck_05 = (epe[prob > 0.99] < 0.5 * (2 / 512)).float().mean()
+            wandb.log({"train_pck_05": pck_05}, step=roma.GLOBAL_STEP)
 
         ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
         a = self.alpha
         cs = self.c * scale
         x = epe[prob > 0.99]
-        reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
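+        # smooth robust penalty on the endpoint error of confident pixels, with scale cs = c * scale and exponent a = alpha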
+        reg_loss = cs**a * ((x / (cs)) ** 2 + 1**2) ** (a / 2)
         if not torch.any(reg_loss):
-            reg_loss = (ce_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+            reg_loss = ce_loss * 0.0  # Prevent issues where prob is 0 everywhere
         losses = {
             f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
             f"{mode}_regression_loss_{scale}": reg_loss.mean(),
         }
-        wandb.log(losses, step = roma.GLOBAL_STEP)
+        wandb.log(losses, step=roma.GLOBAL_STEP)
         return losses
 
     def forward(self, corresps, batch):
         scales = list(corresps.keys())
         tot_loss = 0.0
         # scale_weights due to differences in scale for regression gradients and classification gradients
-        scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
+        scale_weights = {1: 1, 2: 1, 4: 1, 8: 1, 16: 1}
         for scale in scales:
             scale_corresps = corresps[scale]
-            scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = (
+            (
+                scale_certainty,
+                flow_pre_delta,
+                delta_cls,
+                offset_scale,
+                scale_gm_cls,
+                scale_gm_certainty,
+                flow,
+                scale_gm_flow,
+            ) = (
                 scale_corresps["certainty"],
                 scale_corresps["flow_pre_delta"],
                 scale_corresps.get("delta_cls"),
@@ -115,43 +151,72 @@ class RobustLosses(nn.Module):
                 scale_corresps.get("gm_certainty"),
                 scale_corresps["flow"],
                 scale_corresps.get("gm_flow"),
-
             )
             flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
             b, h, w, d = flow_pre_delta.shape
-            gt_warp, gt_prob = get_gt_warp(                
-            batch["im_A_depth"],
-            batch["im_B_depth"],
-            batch["T_1to2"],
-            batch["K1"],
-            batch["K2"],
-            H=h,
-            W=w,
-        )
+            gt_warp, gt_prob = get_gt_warp(
+                batch["im_A_depth"],
+                batch["im_B_depth"],
+                batch["T_1to2"],
+                batch["K1"],
+                batch["K2"],
+                H=h,
+                W=w,
+            )
             x2 = gt_warp.float()
             prob = gt_prob
-            
+
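+            # restrict supervision to pixels whose previous-scale error is within the local distance threshold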
             if self.local_largest_scale >= scale:
                 prob = prob * (
-                        F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
-                        < (2 / 512) * (self.local_dist[scale] * scale))
-            
+                    F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[
+                        :, 0
+                    ]
+                    < (2 / 512) * (self.local_dist[scale] * scale)
+                )
+
             if scale_gm_cls is not None:
-                gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
-                gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
+                gm_cls_losses = self.gm_cls_loss(
+                    x2, prob, scale_gm_cls, scale_gm_certainty, scale
+                )
+                gm_loss = (
+                    self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"]
+                    + gm_cls_losses[f"gm_cls_loss_{scale}"]
+                )
                 tot_loss = tot_loss + scale_weights[scale] * gm_loss
             elif scale_gm_flow is not None:
-                gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
-                gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
+                gm_flow_losses = self.regression_loss(
+                    x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode="gm"
+                )
+                gm_loss = (
+                    self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"]
+                    + gm_flow_losses[f"gm_regression_loss_{scale}"]
+                )
                 tot_loss = tot_loss + scale_weights[scale] * gm_loss
-            
+
             if delta_cls is not None:
-                delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
-                delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
+                delta_cls_losses = self.delta_cls_loss(
+                    x2,
+                    prob,
+                    flow_pre_delta,
+                    delta_cls,
+                    scale_certainty,
+                    scale,
+                    offset_scale,
+                )
+                delta_cls_loss = (
+                    self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"]
+                    + delta_cls_losses[f"delta_cls_loss_{scale}"]
+                )
                 tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
             else:
-                delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
-                reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
+                delta_regression_losses = self.regression_loss(
+                    x2, prob, flow, scale_certainty, scale
+                )
+                reg_loss = (
+                    self.ce_weight
+                    * delta_regression_losses[f"delta_certainty_loss_{scale}"]
+                    + delta_regression_losses[f"delta_regression_loss_{scale}"]
+                )
                 tot_loss = tot_loss + scale_weights[scale] * reg_loss
-            prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
+            prev_epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1).detach()
         return tot_loss
diff --git a/third_party/Roma/roma/models/__init__.py b/third_party/Roma/roma/models/__init__.py
index 5f20461e2f3a1722e558cefab94c5164be8842c3..3918d67063b9ab7a8ced80c22a5e74f95ff7fd4a 100644
--- a/third_party/Roma/roma/models/__init__.py
+++ b/third_party/Roma/roma/models/__init__.py
@@ -1 +1 @@
-from .model_zoo import roma_outdoor, roma_indoor
\ No newline at end of file
+from .model_zoo import roma_outdoor, roma_indoor
diff --git a/third_party/Roma/roma/models/encoders.py b/third_party/Roma/roma/models/encoders.py
index 69b488743b91905aca6adc3e4d3439421d492051..923a56d7ca30d73884ac5f313d44614998540dc3 100644
--- a/third_party/Roma/roma/models/encoders.py
+++ b/third_party/Roma/roma/models/encoders.py
@@ -8,35 +8,52 @@ import gc
 
 
 class ResNet50(nn.Module):
-    def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None:
+    def __init__(
+        self,
+        pretrained=False,
+        high_res=False,
+        weights=None,
+        dilation=None,
+        freeze_bn=True,
+        anti_aliased=False,
+        early_exit=False,
+        amp=False,
+    ) -> None:
         super().__init__()
         if dilation is None:
-            dilation = [False,False,False]
+            dilation = [False, False, False]
         if anti_aliased:
             pass
         else:
             if weights is not None:
-                self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
+                self.net = tvm.resnet50(
+                    weights=weights, replace_stride_with_dilation=dilation
+                )
             else:
-                self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
-            
+                self.net = tvm.resnet50(
+                    pretrained=pretrained, replace_stride_with_dilation=dilation
+                )
+
         self.high_res = high_res
         self.freeze_bn = freeze_bn
         self.early_exit = early_exit
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+            self.amp_dtype = torch.bfloat16
+        else:
+            self.amp_dtype = torch.float16
 
     def forward(self, x, **kwargs):
-        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
             net = self.net
-            feats = {1:x}
+            feats = {1: x}
             x = net.conv1(x)
             x = net.bn1(x)
             x = net.relu(x)
-            feats[2] = x 
+            feats[2] = x
             x = net.maxpool(x)
             x = net.layer1(x)
-            feats[4] = x 
+            feats[4] = x
             x = net.layer2(x)
             feats[8] = x
             if self.early_exit:
@@ -55,35 +72,45 @@ class ResNet50(nn.Module):
                     m.eval()
                 pass
 
+
 class VGG19(nn.Module):
-    def __init__(self, pretrained=False, amp = False) -> None:
+    def __init__(self, pretrained=False, amp=False) -> None:
         super().__init__()
         self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+            self.amp_dtype = torch.bfloat16
+        else:
+            self.amp_dtype = torch.float16
 
     def forward(self, x, **kwargs):
-        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
             feats = {}
             scale = 1
             for layer in self.layers:
                 if isinstance(layer, nn.MaxPool2d):
                     feats[scale] = x
-                    scale = scale*2
+                    scale = scale * 2
                 x = layer(x)
             return feats
 
+
 class CNNandDinov2(nn.Module):
-    def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None):
+    def __init__(self, cnn_kwargs=None, amp=False, use_vgg=False, dinov2_weights=None):
         super().__init__()
         if dinov2_weights is None:
-            dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
+            dinov2_weights = torch.hub.load_state_dict_from_url(
+                "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
+                map_location="cpu",
+            )
         from .transformer import vit_large
-        vit_kwargs = dict(img_size= 518,
-            patch_size= 14,
-            init_values = 1.0,
-            ffn_layer = "mlp",
-            block_chunks = 0,
+
+        vit_kwargs = dict(
+            img_size=518,
+            patch_size=14,
+            init_values=1.0,
+            ffn_layer="mlp",
+            block_chunks=0,
         )
 
         dinov2_vitl14 = vit_large(**vit_kwargs).eval()
@@ -94,25 +121,35 @@ class CNNandDinov2(nn.Module):
         else:
             self.cnn = VGG19(**cnn_kwargs)
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+            self.amp_dtype = torch.bfloat16
+        else:
+            self.amp_dtype = torch.float16
         if self.amp:
             dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
-        self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
-    
-    
+        self.dinov2_vitl14 = [dinov2_vitl14]  # ugly hack to not show parameters to DDP
+
     def train(self, mode: bool = True):
         return self.cnn.train(mode)
-    
-    def forward(self, x, upsample = False):
-        B,C,H,W = x.shape
+
+    def forward(self, x, upsample=False):
+        B, C, H, W = x.shape
         feature_pyramid = self.cnn(x)
-        
+
         if not upsample:
             with torch.no_grad():
                 if self.dinov2_vitl14[0].device != x.device:
-                    self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
-                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
-                features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
+                    self.dinov2_vitl14[0] = (
+                        self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
+                    )
+                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(
+                    x.to(self.amp_dtype)
+                )
+                features_16 = (
+                    dinov2_features_16["x_norm_patchtokens"]
+                    .permute(0, 2, 1)
+                    .reshape(B, 1024, H // 14, W // 14)
+                )
                 del dinov2_features_16
                 feature_pyramid[16] = features_16
-        return feature_pyramid
\ No newline at end of file
+        return feature_pyramid
diff --git a/third_party/Roma/roma/models/matcher.py b/third_party/Roma/roma/models/matcher.py
index c06e1ba3aebe8dec7ee9f1800a6f4ba55ac8f0d9..3e1cee16b586ef1ff5f18e74b203d20aa1f16b1c 100644
--- a/third_party/Roma/roma/models/matcher.py
+++ b/third_party/Roma/roma/models/matcher.py
@@ -14,6 +14,7 @@ from roma.utils.local_correlation import local_correlation
 from roma.utils.utils import cls_to_flow_refine
 from roma.utils.kde import kde
 
+
 class ConvRefiner(nn.Module):
     def __init__(
         self,
@@ -23,25 +24,29 @@ class ConvRefiner(nn.Module):
         dw=False,
         kernel_size=5,
         hidden_blocks=3,
-        displacement_emb = None,
-        displacement_emb_dim = None,
-        local_corr_radius = None,
-        corr_in_other = None,
-        no_im_B_fm = False,
-        amp = False,
-        concat_logits = False,
-        use_bias_block_1 = True,
-        use_cosine_corr = False,
-        disable_local_corr_grad = False,
-        is_classifier = False,
-        sample_mode = "bilinear",
-        norm_type = nn.BatchNorm2d,
-        bn_momentum = 0.1,
+        displacement_emb=None,
+        displacement_emb_dim=None,
+        local_corr_radius=None,
+        corr_in_other=None,
+        no_im_B_fm=False,
+        amp=False,
+        concat_logits=False,
+        use_bias_block_1=True,
+        use_cosine_corr=False,
+        disable_local_corr_grad=False,
+        is_classifier=False,
+        sample_mode="bilinear",
+        norm_type=nn.BatchNorm2d,
+        bn_momentum=0.1,
     ):
         super().__init__()
         self.bn_momentum = bn_momentum
         self.block1 = self.create_block(
-            in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1,
+            in_dim,
+            hidden_dim,
+            dw=dw,
+            kernel_size=kernel_size,
+            bias=use_bias_block_1,
         )
         self.hidden_blocks = nn.Sequential(
             *[
@@ -59,7 +64,7 @@ class ConvRefiner(nn.Module):
         self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
         if displacement_emb:
             self.has_displacement_emb = True
-            self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
+            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
         else:
             self.has_displacement_emb = False
         self.local_corr_radius = local_corr_radius
@@ -71,16 +76,19 @@ class ConvRefiner(nn.Module):
         self.disable_local_corr_grad = disable_local_corr_grad
         self.is_classifier = is_classifier
         self.sample_mode = sample_mode
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-        
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+            self.amp_dtype = torch.bfloat16
+        else:
+            self.amp_dtype = torch.float16
+
     def create_block(
         self,
         in_dim,
         out_dim,
         dw=False,
         kernel_size=5,
-        bias = True,
-        norm_type = nn.BatchNorm2d,
+        bias=True,
+        norm_type=nn.BatchNorm2d,
     ):
         num_groups = 1 if not dw else in_dim
         if dw:
@@ -96,38 +104,56 @@ class ConvRefiner(nn.Module):
             groups=num_groups,
             bias=bias,
         )
-        norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
+        norm = (
+            norm_type(out_dim, momentum=self.bn_momentum)
+            if norm_type is nn.BatchNorm2d
+            else norm_type(num_channels=out_dim)
+        )
         relu = nn.ReLU(inplace=True)
         conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
         return nn.Sequential(conv1, norm, relu, conv2)
-        
-    def forward(self, x, y, flow, scale_factor = 1, logits = None):
-        b,c,hs,ws = x.shape
-        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+
+    def forward(self, x, y, flow, scale_factor=1, logits=None):
+        b, c, hs, ws = x.shape
+        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
             with torch.no_grad():
-                x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode)
+                x_hat = F.grid_sample(
+                    y,
+                    flow.permute(0, 2, 3, 1),
+                    align_corners=False,
+                    mode=self.sample_mode,
+                )
             if self.has_displacement_emb:
                 im_A_coords = torch.meshgrid(
-                (
-                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
-                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
-                )
+                    (
+                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
+                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                    )
                 )
                 im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
                 im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
-                in_displacement = flow-im_A_coords
-                emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement)
+                in_displacement = flow - im_A_coords
+                emb_in_displacement = self.disp_emb(
+                    40 / 32 * scale_factor * in_displacement
+                )
                 if self.local_corr_radius:
                     if self.corr_in_other:
                         # Corr in other means take a kxk grid around the predicted coordinate in other image
-                        local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow, 
-                                                       sample_mode = self.sample_mode)
+                        local_corr = local_correlation(
+                            x,
+                            y,
+                            local_radius=self.local_corr_radius,
+                            flow=flow,
+                            sample_mode=self.sample_mode,
+                        )
                     else:
-                        raise NotImplementedError("Local corr in own frame should not be used.")
+                        raise NotImplementedError(
+                            "Local corr in own frame should not be used."
+                        )
                     if self.no_im_B_fm:
                         x_hat = torch.zeros_like(x)
                     d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
-                else:    
+                else:
                     d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
             else:
                 if self.no_im_B_fm:
@@ -141,6 +167,7 @@ class ConvRefiner(nn.Module):
         displacement, certainty = d[:, :-1], d[:, -1:]
         return displacement, certainty
 
+
 class CosKernel(nn.Module):  # similar to softmax kernel
     def __init__(self, T, learn_temperature=False):
         super().__init__()
@@ -161,6 +188,7 @@ class CosKernel(nn.Module):  # similar to softmax kernel
         K = ((c - 1.0) / T).exp()
         return K
 
+
 class GP(nn.Module):
     def __init__(
         self,
@@ -174,7 +202,7 @@ class GP(nn.Module):
         only_nearest_neighbour=False,
         sigma_noise=0.1,
         no_cov=False,
-        predict_features = False,
+        predict_features=False,
     ):
         super().__init__()
         self.K = kernel(T=T, learn_temperature=learn_temperature)
@@ -262,7 +290,9 @@ class GP(nn.Module):
         mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
         if not self.no_cov:
             cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
-            cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
+            cov_x = rearrange(
+                cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1
+            )
             local_cov_x = self.get_local_cov(cov_x)
             local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
             gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
@@ -270,11 +300,22 @@ class GP(nn.Module):
             gp_feats = mu_x
         return gp_feats
 
+
 class Decoder(nn.Module):
     def __init__(
-        self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
-        num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
-        flow_upsample_mode = "bilinear"
+        self,
+        embedding_decoder,
+        gps,
+        proj,
+        conv_refiner,
+        detach=False,
+        scales="all",
+        pos_embeddings=None,
+        num_refinement_steps_per_scale=1,
+        warp_noise_std=0.0,
+        displacement_dropout_p=0.0,
+        gm_warp_dropout_p=0.0,
+        flow_upsample_mode="bilinear",
     ):
         super().__init__()
         self.embedding_decoder = embedding_decoder
@@ -296,8 +337,11 @@ class Decoder(nn.Module):
         self.displacement_dropout_p = displacement_dropout_p
         self.gm_warp_dropout_p = gm_warp_dropout_p
         self.flow_upsample_mode = flow_upsample_mode
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-        
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+            self.amp_dtype = torch.bfloat16
+        else:
+            self.amp_dtype = torch.float16
+
     def get_placeholder_flow(self, b, h, w, device):
         coarse_coords = torch.meshgrid(
             (
@@ -310,8 +354,8 @@ class Decoder(nn.Module):
         ].expand(b, h, w, 2)
         coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
         return coarse_coords
-    
-    def get_positional_embedding(self, b, h ,w, device):
+
+    def get_positional_embedding(self, b, h, w, device):
         coarse_coords = torch.meshgrid(
             (
                 torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
@@ -326,16 +370,29 @@ class Decoder(nn.Module):
         coarse_embedded_coords = self.pos_embedding(coarse_coords)
         return coarse_embedded_coords
 
-    def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1):
+    def forward(
+        self,
+        f1,
+        f2,
+        gt_warp=None,
+        gt_prob=None,
+        upsample=False,
+        flow=None,
+        certainty=None,
+        scale_factor=1,
+    ):
         coarse_scales = self.embedding_decoder.scales()
-        all_scales = self.scales if not upsample else ["8", "4", "2", "1"] 
+        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
         sizes = {scale: f1[scale].shape[-2:] for scale in f1}
         h, w = sizes[1]
         b = f1[1].shape[0]
         device = f1[1].device
         coarsest_scale = int(all_scales[0])
         old_stuff = torch.zeros(
-            b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
+            b,
+            self.embedding_decoder.hidden_dim,
+            *sizes[coarsest_scale],
+            device=f1[coarsest_scale].device,
         )
         corresps = {}
         if not upsample:
@@ -343,17 +400,17 @@ class Decoder(nn.Module):
             certainty = 0.0
         else:
             flow = F.interpolate(
-                    flow,
-                    size=sizes[coarsest_scale],
-                    align_corners=False,
-                    mode="bilinear",
-                )
+                flow,
+                size=sizes[coarsest_scale],
+                align_corners=False,
+                mode="bilinear",
+            )
             certainty = F.interpolate(
-                    certainty,
-                    size=sizes[coarsest_scale],
-                    align_corners=False,
-                    mode="bilinear",
-                )
+                certainty,
+                size=sizes[coarsest_scale],
+                align_corners=False,
+                mode="bilinear",
+            )
         displacement = 0.0
         for new_scale in all_scales:
             ins = int(new_scale)
@@ -371,32 +428,59 @@ class Decoder(nn.Module):
                 gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
                     gp_posterior, f1_s, old_stuff, new_scale
                 )
-                
+
                 if self.embedding_decoder.is_classifier:
                     flow = cls_to_flow_refine(
                         gm_warp_or_cls,
-                    ).permute(0,3,1,2)
-                    corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
+                    ).permute(0, 3, 1, 2)
+                    corresps[ins].update(
+                        {
+                            "gm_cls": gm_warp_or_cls,
+                            "gm_certainty": certainty,
+                        }
+                    ) if self.training else None
                 else:
-                    corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
+                    corresps[ins].update(
+                        {
+                            "gm_flow": gm_warp_or_cls,
+                            "gm_certainty": certainty,
+                        }
+                    ) if self.training else None
                     flow = gm_warp_or_cls.detach()
-                    
+
             if new_scale in self.conv_refiner:
-                corresps[ins].update({"flow_pre_delta": flow}) if self.training else None
+                corresps[ins].update(
+                    {"flow_pre_delta": flow}
+                ) if self.training else None
                 delta_flow, delta_certainty = self.conv_refiner[new_scale](
-                    f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty,
-                )                    
-                corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None
-                displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w),
-                                                delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,)
+                    f1_s,
+                    f2_s,
+                    flow,
+                    scale_factor=scale_factor,
+                    logits=certainty,
+                )
+                corresps[ins].update(
+                    {
+                        "delta_flow": delta_flow,
+                    }
+                ) if self.training else None
+                displacement = ins * torch.stack(
+                    (
+                        delta_flow[:, 0].float() / (self.refine_init * w),
+                        delta_flow[:, 1].float() / (self.refine_init * h),
+                    ),
+                    dim=1,
+                )
                 flow = flow + displacement
                 certainty = (
                     certainty + delta_certainty
                 )  # predict both certainty and displacement
-            corresps[ins].update({
-                "certainty": certainty,
-                "flow": flow,             
-            })
+            corresps[ins].update(
+                {
+                    "certainty": certainty,
+                    "flow": flow,
+                }
+            )
             if new_scale != "1":
                 flow = F.interpolate(
                     flow,
@@ -411,7 +495,7 @@ class Decoder(nn.Module):
                 if self.detach:
                     flow = flow.detach()
                     certainty = certainty.detach()
-            #torch.cuda.empty_cache()                
+            # torch.cuda.empty_cache()
         return corresps
 
 
@@ -422,11 +506,11 @@ class RegressionMatcher(nn.Module):
         decoder,
         h=448,
         w=448,
-        sample_mode = "threshold",
-        upsample_preds = False,
-        symmetric = False,
-        name = None,
-        attenuate_cert = None,
+        sample_mode="threshold",
+        upsample_preds=False,
+        symmetric=False,
+        name=None,
+        attenuate_cert=None,
     ):
         super().__init__()
         self.attenuate_cert = attenuate_cert
@@ -438,24 +522,26 @@ class RegressionMatcher(nn.Module):
         self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
         self.sample_mode = sample_mode
         self.upsample_preds = upsample_preds
-        self.upsample_res = (14*16*6, 14*16*6)
+        self.upsample_res = (14 * 16 * 6, 14 * 16 * 6)
         self.symmetric = symmetric
         self.sample_thresh = 0.05
-            
+
     def get_output_resolution(self):
         if not self.upsample_preds:
             return self.h_resized, self.w_resized
         else:
             return self.upsample_res
-    
-    def extract_backbone_features(self, batch, batched = True, upsample = False):
+
+    def extract_backbone_features(self, batch, batched=True, upsample=False):
         x_q = batch["im_A"]
         x_s = batch["im_B"]
         if batched:
-            X = torch.cat((x_q, x_s), dim = 0)
-            feature_pyramid = self.encoder(X, upsample = upsample)
+            X = torch.cat((x_q, x_s), dim=0)
+            feature_pyramid = self.encoder(X, upsample=upsample)
         else:
-            feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample)
+            feature_pyramid = self.encoder(x_q, upsample=upsample), self.encoder(
+                x_s, upsample=upsample
+            )
         return feature_pyramid
 
     def sample(
@@ -473,22 +559,28 @@ class RegressionMatcher(nn.Module):
             certainty.reshape(-1),
         )
         expansion_factor = 4 if "balanced" in self.sample_mode else 1
-        good_samples = torch.multinomial(certainty, 
-                          num_samples = min(expansion_factor*num, len(certainty)), 
-                          replacement=False)
+        good_samples = torch.multinomial(
+            certainty,
+            num_samples=min(expansion_factor * num, len(certainty)),
+            replacement=False,
+        )
         good_matches, good_certainty = matches[good_samples], certainty[good_samples]
         if "balanced" not in self.sample_mode:
             return good_matches, good_certainty
         density = kde(good_matches, std=0.1)
-        p = 1 / (density+1)
-        p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
-        balanced_samples = torch.multinomial(p, 
-                          num_samples = min(num,len(good_certainty)), 
-                          replacement=False)
+        p = 1 / (density + 1)
+        p[
+            density < 10
+        ] = 1e-7  # Basically should have at least 10 perfect neighbours, or around 100 ok ones
+        balanced_samples = torch.multinomial(
+            p, num_samples=min(num, len(good_certainty)), replacement=False
+        )
         return good_matches[balanced_samples], good_certainty[balanced_samples]
 
-    def forward(self, batch, batched = True, upsample = False, scale_factor = 1):
-        feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample)
+    def forward(self, batch, batched=True, upsample=False, scale_factor=1):
+        feature_pyramid = self.extract_backbone_features(
+            batch, batched=batched, upsample=upsample
+        )
         if batched:
             f_q_pyramid = {
                 scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
@@ -498,32 +590,42 @@ class RegressionMatcher(nn.Module):
             }
         else:
             f_q_pyramid, f_s_pyramid = feature_pyramid
-        corresps = self.decoder(f_q_pyramid, 
-                                f_s_pyramid, 
-                                upsample = upsample, 
-                                **(batch["corresps"] if "corresps" in batch else {}),
-                                scale_factor=scale_factor)
-        
+        corresps = self.decoder(
+            f_q_pyramid,
+            f_s_pyramid,
+            upsample=upsample,
+            **(batch["corresps"] if "corresps" in batch else {}),
+            scale_factor=scale_factor,
+        )
+
         return corresps
 
-    def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1):
-        feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample)
+    def forward_symmetric(self, batch, batched=True, upsample=False, scale_factor=1):
+        feature_pyramid = self.extract_backbone_features(
+            batch, batched=batched, upsample=upsample
+        )
         f_q_pyramid = feature_pyramid
         f_s_pyramid = {
-            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0)
+            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim=0)
             for scale, f_scale in feature_pyramid.items()
         }
-        corresps = self.decoder(f_q_pyramid, 
-                                f_s_pyramid, 
-                                upsample = upsample, 
-                                **(batch["corresps"] if "corresps" in batch else {}),
-                                scale_factor=scale_factor)
+        corresps = self.decoder(
+            f_q_pyramid,
+            f_s_pyramid,
+            upsample=upsample,
+            **(batch["corresps"] if "corresps" in batch else {}),
+            scale_factor=scale_factor,
+        )
         return corresps
-    
+
     def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
-        kpts_A, kpts_B = matches[...,:2], matches[...,2:]
-        kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
-        kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
+        kpts_A, kpts_B = matches[..., :2], matches[..., 2:]
+        kpts_A = torch.stack(
+            (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), axis=-1
+        )
+        kpts_B = torch.stack(
+            (W_B / 2 * (kpts_B[..., 0] + 1), H_B / 2 * (kpts_B[..., 1] + 1)), axis=-1
+        )
         return kpts_A, kpts_B
 
     def match(
@@ -532,11 +634,12 @@ class RegressionMatcher(nn.Module):
         im_B_path,
         *args,
         batched=False,
-        device = None,
+        device=None,
     ):
         if device is None:
-            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         from PIL import Image
+
         if isinstance(im_A_path, (str, os.PathLike)):
             im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
         else:
@@ -552,9 +655,9 @@ class RegressionMatcher(nn.Module):
                 # Get images in good format
                 ws = self.w_resized
                 hs = self.h_resized
-                
+
                 test_transform = get_tuple_transform_ops(
-                    resize=(hs, ws), normalize=True, clahe = False
+                    resize=(hs, ws), normalize=True, clahe=False
                 )
                 im_A, im_B = test_transform((im_A, im_B))
                 batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
@@ -564,25 +667,32 @@ class RegressionMatcher(nn.Module):
                 assert w == w2 and h == h2, "For batched images we assume same size"
                 batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
                 if h != self.h_resized or self.w_resized != w:
-                    warn("Model resolution and batch resolution differ, may produce unexpected results")
+                    warn(
+                        "Model resolution and batch resolution differ, may produce unexpected results"
+                    )
                 hs, ws = h, w
             finest_scale = 1
             # Run matcher
             if symmetric:
-                corresps  = self.forward_symmetric(batch)
+                corresps = self.forward_symmetric(batch)
             else:
-                corresps = self.forward(batch, batched = True)
+                corresps = self.forward(batch, batched=True)
 
             if self.upsample_preds:
                 hs, ws = self.upsample_res
-            
+
             if self.attenuate_cert:
                 low_res_certainty = F.interpolate(
-                corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
+                    corresps[16]["certainty"],
+                    size=(hs, ws),
+                    align_corners=False,
+                    mode="bilinear",
                 )
                 cert_clamp = 0
                 factor = 0.5
-                low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
+                low_res_certainty = (
+                    factor * low_res_certainty * (low_res_certainty < cert_clamp)
+                )
 
             if self.upsample_preds:
                 finest_corresps = corresps[finest_scale]
@@ -593,25 +703,33 @@ class RegressionMatcher(nn.Module):
                 im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
                 im_A, im_B = test_transform((im_A, im_B))
                 im_A, im_B = im_A[None].to(device), im_B[None].to(device)
-                scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
+                scale_factor = math.sqrt(
+                    self.upsample_res[0]
+                    * self.upsample_res[1]
+                    / (self.w_resized * self.h_resized)
+                )
                 batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
                 if symmetric:
-                    corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
+                    corresps = self.forward_symmetric(
+                        batch, upsample=True, batched=True, scale_factor=scale_factor
+                    )
                 else:
-                    corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
-            
-            im_A_to_im_B = corresps[finest_scale]["flow"] 
-            certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
+                    corresps = self.forward(
+                        batch, batched=True, upsample=True, scale_factor=scale_factor
+                    )
+
+            im_A_to_im_B = corresps[finest_scale]["flow"]
+            certainty = corresps[finest_scale]["certainty"] - (
+                low_res_certainty if self.attenuate_cert else 0
+            )
             if finest_scale != 1:
                 im_A_to_im_B = F.interpolate(
-                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
+                    im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
                 )
                 certainty = F.interpolate(
-                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
-                )
-            im_A_to_im_B = im_A_to_im_B.permute(
-                0, 2, 3, 1
+                    certainty, size=(hs, ws), align_corners=False, mode="bilinear"
                 )
+            im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
             # Create im_A meshgrid
             im_A_coords = torch.meshgrid(
                 (
@@ -625,25 +743,21 @@ class RegressionMatcher(nn.Module):
             im_A_coords = im_A_coords.permute(0, 2, 3, 1)
             if (im_A_to_im_B.abs() > 1).any() and True:
                 wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
-                certainty[wrong[:,None]] = 0
+                certainty[wrong[:, None]] = 0
             im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
             if symmetric:
                 A_to_B, B_to_A = im_A_to_im_B.chunk(2)
                 q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
                 im_B_coords = im_A_coords
                 s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
-                warp = torch.cat((q_warp, s_warp),dim=2)
+                warp = torch.cat((q_warp, s_warp), dim=2)
                 certainty = torch.cat(certainty.chunk(2), dim=3)
             else:
                 warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
             if batched:
-                return (
-                    warp,
-                    certainty[:, 0]
-                )
+                return (warp, certainty[:, 0])
             else:
                 return (
                     warp[0],
                     certainty[0, 0],
                 )
-
diff --git a/third_party/Roma/roma/models/model_zoo/__init__.py b/third_party/Roma/roma/models/model_zoo/__init__.py
index 91edd4e69f2b39f18d62545a95f2774324ff404b..2ef0b6cf03473500d4198521764cd6dc9ccba784 100644
--- a/third_party/Roma/roma/models/model_zoo/__init__.py
+++ b/third_party/Roma/roma/models/model_zoo/__init__.py
@@ -6,25 +6,41 @@ weight_urls = {
         "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
         "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
     },
-    "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
+    "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",  # hopefully this doesnt change :D
 }
 
+
 def roma_outdoor(device, weights=None, dinov2_weights=None):
     if weights is None:
-        weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["outdoor"],
-                                                     map_location=device)
+        weights = torch.hub.load_state_dict_from_url(
+            weight_urls["roma"]["outdoor"], map_location=device
+        )
     if dinov2_weights is None:
-        dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
-                                                     map_location=device)
-    return roma_model(resolution=(14*8*6,14*8*6), upsample_preds=True,
-               weights=weights,dinov2_weights = dinov2_weights,device=device)
+        dinov2_weights = torch.hub.load_state_dict_from_url(
+            weight_urls["dinov2"], map_location=device
+        )
+    return roma_model(
+        resolution=(14 * 8 * 6, 14 * 8 * 6),
+        upsample_preds=True,
+        weights=weights,
+        dinov2_weights=dinov2_weights,
+        device=device,
+    )
+
 
 def roma_indoor(device, weights=None, dinov2_weights=None):
     if weights is None:
-        weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["indoor"],
-                                                     map_location=device)
+        weights = torch.hub.load_state_dict_from_url(
+            weight_urls["roma"]["indoor"], map_location=device
+        )
     if dinov2_weights is None:
-        dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
-                                                     map_location=device)
-    return roma_model(resolution=(14*8*5,14*8*5), upsample_preds=False,
-               weights=weights,dinov2_weights = dinov2_weights,device=device)
+        dinov2_weights = torch.hub.load_state_dict_from_url(
+            weight_urls["dinov2"], map_location=device
+        )
+    return roma_model(
+        resolution=(14 * 8 * 5, 14 * 8 * 5),
+        upsample_preds=False,
+        weights=weights,
+        dinov2_weights=dinov2_weights,
+        device=device,
+    )
diff --git a/third_party/Roma/roma/models/model_zoo/roma_models.py b/third_party/Roma/roma/models/model_zoo/roma_models.py
index dfb0ff7264880d25f0feb0802e582bf29c84b051..f98ee44f5e2ebd7e43a8e4b17f99b6ed0e85c93a 100644
--- a/third_party/Roma/roma/models/model_zoo/roma_models.py
+++ b/third_party/Roma/roma/models/model_zoo/roma_models.py
@@ -4,87 +4,95 @@ from roma.models.matcher import *
 from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
 from roma.models.encoders import *
 
-def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, **kwargs):
+
+def roma_model(
+    resolution, upsample_preds, device=None, weights=None, dinov2_weights=None, **kwargs
+):
     # roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters
-    torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
-    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
-    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
+    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message="TypedStorage is deprecated"
+    )
     gp_dim = 512
     feat_dim = 512
     decoder_dim = gp_dim + feat_dim
     cls_to_coord_res = 64
     coordinate_decoder = TransformerDecoder(
-        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), 
-        decoder_dim, 
+        nn.Sequential(
+            *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
+        ),
+        decoder_dim,
         cls_to_coord_res**2 + 1,
         is_classifier=True,
-        amp = True,
-        pos_enc = False,)
+        amp=True,
+        pos_enc=False,
+    )
     dw = True
     hidden_blocks = 8
     kernel_size = 5
     displacement_emb = "linear"
     disable_local_corr_grad = True
-    
+
     conv_refiner = nn.ModuleDict(
         {
             "16": ConvRefiner(
-                2 * 512+128+(2*7+1)**2,
-                2 * 512+128+(2*7+1)**2,
+                2 * 512 + 128 + (2 * 7 + 1) ** 2,
+                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                 2 + 1,
                 kernel_size=kernel_size,
                 dw=dw,
                 hidden_blocks=hidden_blocks,
                 displacement_emb=displacement_emb,
                 displacement_emb_dim=128,
-                local_corr_radius = 7,
-                corr_in_other = True,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
+                local_corr_radius=7,
+                corr_in_other=True,
+                amp=True,
+                disable_local_corr_grad=disable_local_corr_grad,
+                bn_momentum=0.01,
             ),
             "8": ConvRefiner(
-                2 * 512+64+(2*3+1)**2,
-                2 * 512+64+(2*3+1)**2,
+                2 * 512 + 64 + (2 * 3 + 1) ** 2,
+                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                 2 + 1,
                 kernel_size=kernel_size,
                 dw=dw,
                 hidden_blocks=hidden_blocks,
                 displacement_emb=displacement_emb,
                 displacement_emb_dim=64,
-                local_corr_radius = 3,
-                corr_in_other = True,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
+                local_corr_radius=3,
+                corr_in_other=True,
+                amp=True,
+                disable_local_corr_grad=disable_local_corr_grad,
+                bn_momentum=0.01,
             ),
             "4": ConvRefiner(
-                2 * 256+32+(2*2+1)**2,
-                2 * 256+32+(2*2+1)**2,
+                2 * 256 + 32 + (2 * 2 + 1) ** 2,
+                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                 2 + 1,
                 kernel_size=kernel_size,
                 dw=dw,
                 hidden_blocks=hidden_blocks,
                 displacement_emb=displacement_emb,
                 displacement_emb_dim=32,
-                local_corr_radius = 2,
-                corr_in_other = True,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
+                local_corr_radius=2,
+                corr_in_other=True,
+                amp=True,
+                disable_local_corr_grad=disable_local_corr_grad,
+                bn_momentum=0.01,
             ),
             "2": ConvRefiner(
-                2 * 64+16,
-                128+16,
+                2 * 64 + 16,
+                128 + 16,
                 2 + 1,
                 kernel_size=kernel_size,
                 dw=dw,
                 hidden_blocks=hidden_blocks,
                 displacement_emb=displacement_emb,
                 displacement_emb_dim=16,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
+                amp=True,
+                disable_local_corr_grad=disable_local_corr_grad,
+                bn_momentum=0.01,
             ),
             "1": ConvRefiner(
                 2 * 9 + 6,
@@ -92,12 +100,12 @@ def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_w
                 2 + 1,
                 kernel_size=kernel_size,
                 dw=dw,
-                hidden_blocks = hidden_blocks,
-                displacement_emb = displacement_emb,
-                displacement_emb_dim = 6,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=6,
+                amp=True,
+                disable_local_corr_grad=disable_local_corr_grad,
+                bn_momentum=0.01,
             ),
         }
     )
@@ -122,36 +130,46 @@ def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_w
     proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
     proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
     proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
-    proj = nn.ModuleDict({
-        "16": proj16,
-        "8": proj8,
-        "4": proj4,
-        "2": proj2,
-        "1": proj1,
-        })
+    proj = nn.ModuleDict(
+        {
+            "16": proj16,
+            "8": proj8,
+            "4": proj4,
+            "2": proj2,
+            "1": proj1,
+        }
+    )
     displacement_dropout_p = 0.0
     gm_warp_dropout_p = 0.0
-    decoder = Decoder(coordinate_decoder, 
-                      gps, 
-                      proj, 
-                      conv_refiner, 
-                      detach=True, 
-                      scales=["16", "8", "4", "2", "1"], 
-                      displacement_dropout_p = displacement_dropout_p,
-                      gm_warp_dropout_p = gm_warp_dropout_p)
-    
+    decoder = Decoder(
+        coordinate_decoder,
+        gps,
+        proj,
+        conv_refiner,
+        detach=True,
+        scales=["16", "8", "4", "2", "1"],
+        displacement_dropout_p=displacement_dropout_p,
+        gm_warp_dropout_p=gm_warp_dropout_p,
+    )
+
     encoder = CNNandDinov2(
-        cnn_kwargs = dict(
-            pretrained=False,
-            amp = True),
-        amp = True,
-        use_vgg = True,
-        dinov2_weights = dinov2_weights
+        cnn_kwargs=dict(pretrained=False, amp=True),
+        amp=True,
+        use_vgg=True,
+        dinov2_weights=dinov2_weights,
     )
-    h,w = resolution
+    h, w = resolution
     symmetric = True
     attenuate_cert = True
-    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds, 
-                                symmetric = symmetric, attenuate_cert=attenuate_cert, **kwargs).to(device)
+    matcher = RegressionMatcher(
+        encoder,
+        decoder,
+        h=h,
+        w=w,
+        upsample_preds=upsample_preds,
+        symmetric=symmetric,
+        attenuate_cert=attenuate_cert,
+        **kwargs
+    ).to(device)
     matcher.load_state_dict(weights)
     return matcher
diff --git a/third_party/Roma/roma/models/transformer/__init__.py b/third_party/Roma/roma/models/transformer/__init__.py
index 4770ebb19f111df14f1539fa3696553d96d4e48b..a4b45d163d7e693b62edb5322a56387f82b27e04 100644
--- a/third_party/Roma/roma/models/transformer/__init__.py
+++ b/third_party/Roma/roma/models/transformer/__init__.py
@@ -7,9 +7,21 @@ from .layers.block import Block
 from .layers.attention import MemEffAttention
 from .dinov2 import vit_large
 
+
 class TransformerDecoder(nn.Module):
-    def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args, 
-                 amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, **kwargs) -> None:
+    def __init__(
+        self,
+        blocks,
+        hidden_dim,
+        out_dim,
+        is_classifier=False,
+        *args,
+        amp=False,
+        pos_enc=True,
+        learned_embeddings=False,
+        embedding_dim=None,
+        **kwargs
+    ) -> None:
         super().__init__(*args, **kwargs)
         self.blocks = blocks
         self.to_out = nn.Linear(hidden_dim, out_dim)
@@ -18,30 +30,44 @@ class TransformerDecoder(nn.Module):
         self._scales = [16]
         self.is_classifier = is_classifier
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+            self.amp_dtype = torch.bfloat16
+        else:
+            self.amp_dtype = torch.float16
         self.pos_enc = pos_enc
         self.learned_embeddings = learned_embeddings
         if self.learned_embeddings:
-            self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))
+            self.learned_pos_embeddings = nn.Parameter(
+                nn.init.kaiming_normal_(
+                    torch.empty((1, hidden_dim, embedding_dim, embedding_dim))
+                )
+            )
 
     def scales(self):
         return self._scales.copy()
 
     def forward(self, gp_posterior, features, old_stuff, new_scale):
         with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp):
-            B,C,H,W = gp_posterior.shape
-            x = torch.cat((gp_posterior, features), dim = 1)
-            B,C,H,W = x.shape
-            grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
+            B, C, H, W = gp_posterior.shape
+            x = torch.cat((gp_posterior, features), dim=1)
+            B, C, H, W = x.shape
+            grid = get_grid(B, H, W, x.device).reshape(B, H * W, 2)
             if self.learned_embeddings:
-                pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
+                pos_enc = (
+                    F.interpolate(
+                        self.learned_pos_embeddings,
+                        size=(H, W),
+                        mode="bilinear",
+                        align_corners=False,
+                    )
+                    .permute(0, 2, 3, 1)
+                    .reshape(1, H * W, C)
+                )
             else:
                 pos_enc = 0
-            tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
+            tokens = x.reshape(B, C, H * W).permute(0, 2, 1) + pos_enc
             z = self.blocks(tokens)
             out = self.to_out(z)
-            out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
+            out = out.permute(0, 2, 1).reshape(B, self.out_dim, H, W)
             warp, certainty = out[:, :-1], out[:, -1:]
             return warp, certainty, None
-
-
diff --git a/third_party/Roma/roma/models/transformer/dinov2.py b/third_party/Roma/roma/models/transformer/dinov2.py
index b556c63096d17239c8603d5fe626c331963099fd..1c27c65b5061cc0113792e40b96eaf7f4266ce18 100644
--- a/third_party/Roma/roma/models/transformer/dinov2.py
+++ b/third_party/Roma/roma/models/transformer/dinov2.py
@@ -18,16 +18,29 @@ import torch.nn as nn
 import torch.utils.checkpoint
 from torch.nn.init import trunc_normal_
 
-from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
-
-
-
-def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+from .layers import (
+    Mlp,
+    PatchEmbed,
+    SwiGLUFFNFused,
+    MemEffAttention,
+    NestedTensorBlock as Block,
+)
+
+
+def named_apply(
+    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
+) -> nn.Module:
     if not depth_first and include_root:
         fn(module=module, name=name)
     for child_name, child_module in module.named_children():
         child_name = ".".join((name, child_name)) if name else child_name
-        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+        named_apply(
+            fn=fn,
+            module=child_module,
+            name=child_name,
+            depth_first=depth_first,
+            include_root=True,
+        )
     if depth_first and include_root:
         fn(module=module, name=name)
     return module
@@ -87,22 +100,33 @@ class DinoVisionTransformer(nn.Module):
         super().__init__()
         norm_layer = partial(nn.LayerNorm, eps=1e-6)
 
-        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_features = (
+            self.embed_dim
+        ) = embed_dim  # num_features for consistency with other models
         self.num_tokens = 1
         self.n_blocks = depth
         self.num_heads = num_heads
         self.patch_size = patch_size
 
-        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        self.patch_embed = embed_layer(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
         num_patches = self.patch_embed.num_patches
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
+        )
 
         if drop_path_uniform is True:
             dpr = [drop_path_rate] * depth
         else:
-            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+            dpr = [
+                x.item() for x in torch.linspace(0, drop_path_rate, depth)
+            ]  # stochastic depth decay rule
 
         if ffn_layer == "mlp":
             ffn_layer = Mlp
@@ -139,7 +163,9 @@ class DinoVisionTransformer(nn.Module):
             chunksize = depth // block_chunks
             for i in range(0, depth, chunksize):
                 # this is to keep the block index consistent if we chunk the block list
-                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+                chunked_blocks.append(
+                    [nn.Identity()] * i + blocks_list[i : i + chunksize]
+                )
             self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
         else:
             self.chunked_blocks = False
@@ -153,7 +179,7 @@ class DinoVisionTransformer(nn.Module):
         self.init_weights()
         for param in self.parameters():
             param.requires_grad = False
-    
+
     @property
     def device(self):
         return self.cls_token.device
@@ -180,20 +206,29 @@ class DinoVisionTransformer(nn.Module):
         w0, h0 = w0 + 0.1, h0 + 0.1
 
         patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            patch_pos_embed.reshape(
+                1, int(math.sqrt(N)), int(math.sqrt(N)), dim
+            ).permute(0, 3, 1, 2),
             scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
             mode="bicubic",
         )
 
-        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        assert (
+            int(w0) == patch_pos_embed.shape[-2]
+            and int(h0) == patch_pos_embed.shape[-1]
+        )
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
+            previous_dtype
+        )
 
     def prepare_tokens_with_masks(self, x, masks=None):
         B, nc, w, h = x.shape
         x = self.patch_embed(x)
         if masks is not None:
-            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+            x = torch.where(
+                masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
+            )
 
         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         x = x + self.interpolate_pos_encoding(x, w, h)
@@ -201,7 +236,10 @@ class DinoVisionTransformer(nn.Module):
         return x
 
     def forward_features_list(self, x_list, masks_list):
-        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+        x = [
+            self.prepare_tokens_with_masks(x, masks)
+            for x, masks in zip(x_list, masks_list)
+        ]
         for blk in self.blocks:
             x = blk(x)
 
@@ -240,26 +278,34 @@ class DinoVisionTransformer(nn.Module):
         x = self.prepare_tokens_with_masks(x)
         # If n is an int, take the n last blocks. If it's a list, take them
         output, total_block_len = [], len(self.blocks)
-        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        blocks_to_take = (
+            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        )
         for i, blk in enumerate(self.blocks):
             x = blk(x)
             if i in blocks_to_take:
                 output.append(x)
-        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+        assert len(output) == len(
+            blocks_to_take
+        ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
         return output
 
     def _get_intermediate_layers_chunked(self, x, n=1):
         x = self.prepare_tokens_with_masks(x)
         output, i, total_block_len = [], 0, len(self.blocks[-1])
         # If n is an int, take the n last blocks. If it's a list, take them
-        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        blocks_to_take = (
+            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        )
         for block_chunk in self.blocks:
             for blk in block_chunk[i:]:  # Passing the nn.Identity()
                 x = blk(x)
                 if i in blocks_to_take:
                     output.append(x)
                 i += 1
-        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+        assert len(output) == len(
+            blocks_to_take
+        ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
         return output
 
     def get_intermediate_layers(
@@ -281,7 +327,9 @@ class DinoVisionTransformer(nn.Module):
         if reshape:
             B, _, w, h = x.shape
             outputs = [
-                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+                out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
+                .permute(0, 3, 1, 2)
+                .contiguous()
                 for out in outputs
             ]
         if return_class_token:
@@ -356,4 +404,4 @@ def vit_giant2(patch_size=16, **kwargs):
         block_fn=partial(Block, attn_class=MemEffAttention),
         **kwargs,
     )
-    return model
\ No newline at end of file
+    return model
diff --git a/third_party/Roma/roma/models/transformer/layers/attention.py b/third_party/Roma/roma/models/transformer/layers/attention.py
index 1f9b0c94b40967dfdff4f261c127cbd21328c905..12f388719bf5f171d59aee238d902bb7915f864b 100644
--- a/third_party/Roma/roma/models/transformer/layers/attention.py
+++ b/third_party/Roma/roma/models/transformer/layers/attention.py
@@ -48,7 +48,11 @@ class Attention(nn.Module):
 
     def forward(self, x: Tensor) -> Tensor:
         B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
 
         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
         attn = q @ k.transpose(-2, -1)
diff --git a/third_party/Roma/roma/models/transformer/layers/block.py b/third_party/Roma/roma/models/transformer/layers/block.py
index 25488f57cc0ad3c692f86b62555f6668e2a66db1..1b5f5158f073788d3d5fe3e09742d4485ef26441 100644
--- a/third_party/Roma/roma/models/transformer/layers/block.py
+++ b/third_party/Roma/roma/models/transformer/layers/block.py
@@ -62,7 +62,9 @@ class Block(nn.Module):
             attn_drop=attn_drop,
             proj_drop=drop,
         )
-        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.ls1 = (
+            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        )
         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
 
         self.norm2 = norm_layer(dim)
@@ -74,7 +76,9 @@ class Block(nn.Module):
             drop=drop,
             bias=ffn_bias,
         )
-        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.ls2 = (
+            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        )
         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
 
         self.sample_drop_ratio = drop_path
@@ -127,7 +131,9 @@ def drop_add_residual_stochastic_depth(
     residual_scale_factor = b / sample_subset_size
 
     # 3) add the residual
-    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+    x_plus_residual = torch.index_add(
+        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+    )
     return x_plus_residual.view_as(x)
 
 
@@ -143,10 +149,16 @@ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None
     if scaling_vector is None:
         x_flat = x.flatten(1)
         residual = residual.flatten(1)
-        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+        x_plus_residual = torch.index_add(
+            x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+        )
     else:
         x_plus_residual = scaled_index_add(
-            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+            x,
+            brange,
+            residual.to(dtype=x.dtype),
+            scaling=scaling_vector,
+            alpha=residual_scale_factor,
         )
     return x_plus_residual
 
@@ -158,7 +170,11 @@ def get_attn_bias_and_cat(x_list, branges=None):
     """
     this will perform the index select, cat the tensors, and provide the attn_bias from cache
     """
-    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+    batch_sizes = (
+        [b.shape[0] for b in branges]
+        if branges is not None
+        else [x.shape[0] for x in x_list]
+    )
     all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
     if all_shapes not in attn_bias_cache.keys():
         seqlens = []
@@ -170,7 +186,9 @@ def get_attn_bias_and_cat(x_list, branges=None):
         attn_bias_cache[all_shapes] = attn_bias
 
     if branges is not None:
-        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
+            1, -1, x_list[0].shape[-1]
+        )
     else:
         tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
         cat_tensors = torch.cat(tensors_bs1, dim=1)
@@ -185,7 +203,9 @@ def drop_add_residual_stochastic_depth_list(
     scaling_vector=None,
 ) -> Tensor:
     # 1) generate random set of indices for dropping samples in the batch
-    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+    branges_scales = [
+        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
+    ]
     branges = [s[0] for s in branges_scales]
     residual_scale_factors = [s[1] for s in branges_scales]
 
@@ -196,8 +216,14 @@ def drop_add_residual_stochastic_depth_list(
     residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
 
     outputs = []
-    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
-        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+    for x, brange, residual, residual_scale_factor in zip(
+        x_list, branges, residual_list, residual_scale_factors
+    ):
+        outputs.append(
+            add_residual(
+                x, brange, residual, residual_scale_factor, scaling_vector
+            ).view_as(x)
+        )
     return outputs
 
 
@@ -220,13 +246,17 @@ class NestedTensorBlock(Block):
                 x_list,
                 residual_func=attn_residual_func,
                 sample_drop_ratio=self.sample_drop_ratio,
-                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+                scaling_vector=self.ls1.gamma
+                if isinstance(self.ls1, LayerScale)
+                else None,
             )
             x_list = drop_add_residual_stochastic_depth_list(
                 x_list,
                 residual_func=ffn_residual_func,
                 sample_drop_ratio=self.sample_drop_ratio,
-                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+                scaling_vector=self.ls2.gamma
+                if isinstance(self.ls2, LayerScale)
+                else None,
             )
             return x_list
         else:
@@ -246,7 +276,9 @@ class NestedTensorBlock(Block):
         if isinstance(x_or_x_list, Tensor):
             return super().forward(x_or_x_list)
         elif isinstance(x_or_x_list, list):
-            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+            assert (
+                XFORMERS_AVAILABLE
+            ), "Please install xFormers for nested tensors usage"
             return self.forward_nested(x_or_x_list)
         else:
             raise AssertionError
diff --git a/third_party/Roma/roma/models/transformer/layers/dino_head.py b/third_party/Roma/roma/models/transformer/layers/dino_head.py
index 7212db92a4fd8d4c7230e284e551a0234e9d8623..1147dd3a3c046aee8d427b42b1055f38a218275b 100644
--- a/third_party/Roma/roma/models/transformer/layers/dino_head.py
+++ b/third_party/Roma/roma/models/transformer/layers/dino_head.py
@@ -23,7 +23,14 @@ class DINOHead(nn.Module):
     ):
         super().__init__()
         nlayers = max(nlayers, 1)
-        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+        self.mlp = _build_mlp(
+            nlayers,
+            in_dim,
+            bottleneck_dim,
+            hidden_dim=hidden_dim,
+            use_bn=use_bn,
+            bias=mlp_bias,
+        )
         self.apply(self._init_weights)
         self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
         self.last_layer.weight_g.data.fill_(1)
@@ -42,7 +49,9 @@ class DINOHead(nn.Module):
         return x
 
 
-def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+def _build_mlp(
+    nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
+):
     if nlayers == 1:
         return nn.Linear(in_dim, bottleneck_dim, bias=bias)
     else:
diff --git a/third_party/Roma/roma/models/transformer/layers/drop_path.py b/third_party/Roma/roma/models/transformer/layers/drop_path.py
index af05625984dd14682cc96a63bf0c97bab1f123b1..a23ba7325d0fd154d5885573770956042ce2311d 100644
--- a/third_party/Roma/roma/models/transformer/layers/drop_path.py
+++ b/third_party/Roma/roma/models/transformer/layers/drop_path.py
@@ -16,7 +16,9 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
     if drop_prob == 0.0 or not training:
         return x
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
     if keep_prob > 0.0:
         random_tensor.div_(keep_prob)
diff --git a/third_party/Roma/roma/models/transformer/layers/patch_embed.py b/third_party/Roma/roma/models/transformer/layers/patch_embed.py
index 574abe41175568d700a389b8b96d1ba554914779..837f952cf9a463444feeb146e0d5b539102ee26c 100644
--- a/third_party/Roma/roma/models/transformer/layers/patch_embed.py
+++ b/third_party/Roma/roma/models/transformer/layers/patch_embed.py
@@ -63,15 +63,21 @@ class PatchEmbed(nn.Module):
 
         self.flatten_embedding = flatten_embedding
 
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
+        )
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
     def forward(self, x: Tensor) -> Tensor:
         _, _, H, W = x.shape
         patch_H, patch_W = self.patch_size
 
-        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
-        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+        assert (
+            H % patch_H == 0
+        ), f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert (
+            W % patch_W == 0
+        ), f"Input image width {W} is not a multiple of patch width: {patch_W}"
 
         x = self.proj(x)  # B C H W
         H, W = x.size(2), x.size(3)
@@ -83,7 +89,13 @@ class PatchEmbed(nn.Module):
 
     def flops(self) -> float:
         Ho, Wo = self.patches_resolution
-        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+        flops = (
+            Ho
+            * Wo
+            * self.embed_dim
+            * self.in_chans
+            * (self.patch_size[0] * self.patch_size[1])
+        )
         if self.norm is not None:
             flops += Ho * Wo * self.embed_dim
         return flops
diff --git a/third_party/Roma/roma/train/train.py b/third_party/Roma/roma/train/train.py
index 5556f7ebf9b6378e1395c125dde093f5e55e7141..eb3deaf1792a315d1cce77a2ee0fd50ae9e98ac1 100644
--- a/third_party/Roma/roma/train/train.py
+++ b/third_party/Roma/roma/train/train.py
@@ -4,41 +4,62 @@ import roma
 import torch
 import wandb
 
-def log_param_statistics(named_parameters, norm_type = 2):
+
+def log_param_statistics(named_parameters, norm_type=2):
     named_parameters = list(named_parameters)
     grads = [p.grad for n, p in named_parameters if p.grad is not None]
-    weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None]
-    names = [n for n,p in named_parameters if p.grad is not None]
+    weight_norms = [
+        p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None
+    ]
+    names = [n for n, p in named_parameters if p.grad is not None]
     param_norm = torch.stack(weight_norms).norm(p=norm_type)
     device = grads[0].device
-    grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads])
+    grad_norms = torch.stack(
+        [torch.norm(g.detach(), norm_type).to(device) for g in grads]
+    )
     nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
     nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
     total_grad_norm = torch.norm(grad_norms, norm_type)
     if torch.any(nans_or_infs):
         print(f"These params have nan or inf grads: {nan_inf_names}")
-    wandb.log({"grad_norm": total_grad_norm.item()}, step = roma.GLOBAL_STEP)
-    wandb.log({"param_norm": param_norm.item()}, step = roma.GLOBAL_STEP)
+    wandb.log({"grad_norm": total_grad_norm.item()}, step=roma.GLOBAL_STEP)
+    wandb.log({"param_norm": param_norm.item()}, step=roma.GLOBAL_STEP)
+
 
-def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs):
+def train_step(
+    train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm=1.0, **kwargs
+):
     optimizer.zero_grad()
     out = model(train_batch)
     l = objective(out, train_batch)
     grad_scaler.scale(l).backward()
     grad_scaler.unscale_(optimizer)
     log_param_statistics(model.named_parameters())
-    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm be?
+    torch.nn.utils.clip_grad_norm_(
+        model.parameters(), grad_clip_norm
+    )  # what should max norm be?
     grad_scaler.step(optimizer)
     grad_scaler.update()
-    wandb.log({"grad_scale": grad_scaler._scale.item()}, step = roma.GLOBAL_STEP)
-    if grad_scaler._scale < 1.:
-        grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale)
-    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE # increment global step
+    wandb.log({"grad_scale": grad_scaler._scale.item()}, step=roma.GLOBAL_STEP)
+    if grad_scaler._scale < 1.0:
+        grad_scaler._scale = torch.tensor(1.0).to(grad_scaler._scale)
+    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE  # increment global step
     return {"train_out": out, "train_loss": l.item()}
 
 
 def train_k_steps(
-    n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None,
+    n_0,
+    k,
+    dataloader,
+    model,
+    objective,
+    optimizer,
+    lr_scheduler,
+    grad_scaler,
+    progress_bar=True,
+    grad_clip_norm=1.0,
+    warmup=None,
+    ema_model=None,
 ):
     for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
         batch = next(dataloader)
@@ -52,7 +73,7 @@ def train_k_steps(
             lr_scheduler=lr_scheduler,
             grad_scaler=grad_scaler,
             n=n,
-            grad_clip_norm = grad_clip_norm,
+            grad_clip_norm=grad_clip_norm,
         )
         if ema_model is not None:
             ema_model.update()
@@ -61,7 +82,10 @@ def train_k_steps(
                 lr_scheduler.step()
         else:
             lr_scheduler.step()
-        [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())]
+        [
+            wandb.log({f"lr_group_{grp}": lr})
+            for grp, lr in enumerate(lr_scheduler.get_last_lr())
+        ]
 
 
 def train_epoch(
diff --git a/third_party/Roma/roma/utils/kde.py b/third_party/Roma/roma/utils/kde.py
index 90a058fb68253cfe23c2a7f21b213bea8e06cfe3..eff7c72dad4a3f90f5ff79d2630427de89838fc5 100644
--- a/third_party/Roma/roma/utils/kde.py
+++ b/third_party/Roma/roma/utils/kde.py
@@ -1,8 +1,9 @@
 import torch
 
-def kde(x, std = 0.1):
+
+def kde(x, std=0.1):
     # use a gaussian kernel to estimate density
-    x = x.half() # Do it in half precision
-    scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
+    x = x.half()  # Do it in half precision
+    scores = (-torch.cdist(x, x) ** 2 / (2 * std**2)).exp()
     density = scores.sum(dim=-1)
-    return density
\ No newline at end of file
+    return density
diff --git a/third_party/Roma/roma/utils/local_correlation.py b/third_party/Roma/roma/utils/local_correlation.py
index 586eef5f154a95968b253ad9701933b55b3a4dd6..84a13c63b52db979000916bcb9511e1d3a5ca7fa 100644
--- a/third_party/Roma/roma/utils/local_correlation.py
+++ b/third_party/Roma/roma/utils/local_correlation.py
@@ -1,47 +1,66 @@
 import torch
 import torch.nn.functional as F
 
+
 def local_correlation(
     feature0,
     feature1,
     local_radius,
     padding_mode="zeros",
-    flow = None,
-    sample_mode = "bilinear",
+    flow=None,
+    sample_mode="bilinear",
 ):
     r = local_radius
-    K = (2*r+1)**2
+    K = (2 * r + 1) ** 2
     B, c, h, w = feature0.size()
     feature0 = feature0.half()
     feature1 = feature1.half()
-    corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype)
+    corr = torch.empty((B, K, h, w), device=feature0.device, dtype=feature0.dtype)
     if flow is None:
         # If flow is None, assume feature0 and feature1 are aligned
         coords = torch.meshgrid(
-                (
-                    torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"),
-                    torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"),
-                ))
-        coords = torch.stack((coords[1], coords[0]), dim=-1)[
-            None
-        ].expand(B, h, w, 2)
+            (
+                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"),
+                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"),
+            )
+        )
+        coords = torch.stack((coords[1], coords[0]), dim=-1)[None].expand(B, h, w, 2)
     else:
-        coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
+        coords = flow.permute(0, 2, 3, 1)  # If using flow, sample around flow target.
     local_window = torch.meshgrid(
-                (
-                    torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device="cuda"),
-                    torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device="cuda"),
-                ))
-    local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
-            None
-        ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2)
+        (
+            torch.linspace(
+                -2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device="cuda"
+            ),
+            torch.linspace(
+                -2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device="cuda"
+            ),
+        )
+    )
+    local_window = (
+        torch.stack((local_window[1], local_window[0]), dim=-1)[None]
+        .expand(1, 2 * r + 1, 2 * r + 1, 2)
+        .reshape(1, (2 * r + 1) ** 2, 2)
+    )
     for _ in range(B):
         with torch.no_grad():
-            local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2).float()
+            local_window_coords = (
+                (coords[_, :, :, None] + local_window[:, None, None])
+                .reshape(1, h, w * (2 * r + 1) ** 2, 2)
+                .float()
+            )
             window_feature = F.grid_sample(
-                feature1[_:_+1].float(), local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, #
+                feature1[_ : _ + 1].float(),
+                local_window_coords,
+                padding_mode=padding_mode,
+                align_corners=False,
+                mode=sample_mode,
             )
-            window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
-        corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
+            window_feature = window_feature.reshape(c, h, w, (2 * r + 1) ** 2)
+        corr[_] = (
+            (feature0[_, ..., None] / (c**0.5) * window_feature)
+            .sum(dim=0)
+            .permute(2, 0, 1)
+        )
     torch.cuda.empty_cache()
-    return corr
\ No newline at end of file
+    return corr
diff --git a/third_party/Roma/roma/utils/transforms.py b/third_party/Roma/roma/utils/transforms.py
index ea6476bd816a31df36f7d1b5417853637b65474b..b33c3f30f422bca6a81aa201952b7bb2d3d906bf 100644
--- a/third_party/Roma/roma/utils/transforms.py
+++ b/third_party/Roma/roma/utils/transforms.py
@@ -16,7 +16,9 @@ class GeometricSequential:
         for t in self.transforms:
             if np.random.rand() < t.p:
                 M = M.matmul(
-                    t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None)
+                    t.compute_transformation(
+                        x, t.generate_parameters((b, c, h, w)), None
+                    )
                 )
         return (
             warp_perspective(
@@ -104,15 +106,14 @@ class RandomPerspective(K.RandomPerspective):
         return dict(start_points=start_points, end_points=end_points)
 
 
-
 class RandomErasing:
-    def __init__(self, p = 0., scale = 0.) -> None:
+    def __init__(self, p=0.0, scale=0.0) -> None:
         self.p = p
         self.scale = scale
-        self.random_eraser = K.RandomErasing(scale = (0.02, scale), p = p)
+        self.random_eraser = K.RandomErasing(scale=(0.02, scale), p=p)
+
     def __call__(self, image, depth):
         if self.p > 0:
             image = self.random_eraser(image)
             depth = self.random_eraser(depth, params=self.random_eraser._params)
         return image, depth
-        
\ No newline at end of file
diff --git a/third_party/Roma/roma/utils/utils.py b/third_party/Roma/roma/utils/utils.py
index d673f679823c833688e2548dd40bf50943796a71..969e1003419f3b7f05874830b79de73363017f01 100644
--- a/third_party/Roma/roma/utils/utils.py
+++ b/third_party/Roma/roma/utils/utils.py
@@ -9,13 +9,14 @@ import torch.nn.functional as F
 from PIL import Image
 import kornia
 
+
 def recover_pose(E, kpts0, kpts1, K0, K1, mask):
     best_num_inliers = 0
-    K0inv = np.linalg.inv(K0[:2,:2])
-    K1inv = np.linalg.inv(K1[:2,:2])
+    K0inv = np.linalg.inv(K0[:2, :2])
+    K1inv = np.linalg.inv(K1[:2, :2])
 
-    kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
-    kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+    kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
+    kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
 
     for _E in np.split(E, len(E) / 3):
         n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
@@ -25,17 +26,16 @@ def recover_pose(E, kpts0, kpts1, K0, K1, mask):
     return ret
 
 
-
 # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
 # --- GEOMETRY ---
 def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
     if len(kpts0) < 5:
         return None
-    K0inv = np.linalg.inv(K0[:2,:2])
-    K1inv = np.linalg.inv(K1[:2,:2])
+    K0inv = np.linalg.inv(K0[:2, :2])
+    K1inv = np.linalg.inv(K1[:2, :2])
 
-    kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
-    kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+    kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
+    kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
     E, mask = cv2.findEssentialMat(
         kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
     )
@@ -51,31 +51,40 @@ def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
                 ret = (R, t, mask.ravel() > 0)
     return ret
 
+
 def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
     if len(kpts0) < 5:
         return None
     method = cv2.USAC_ACCURATE
     F, mask = cv2.findFundamentalMat(
-        kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000
+        kpts0,
+        kpts1,
+        ransacReprojThreshold=norm_thresh,
+        confidence=conf,
+        method=method,
+        maxIters=10000,
     )
-    E = K1.T@F@K0
+    E = K1.T @ F @ K0
     ret = None
     if E is not None:
         best_num_inliers = 0
-        K0inv = np.linalg.inv(K0[:2,:2])
-        K1inv = np.linalg.inv(K1[:2,:2])
+        K0inv = np.linalg.inv(K0[:2, :2])
+        K1inv = np.linalg.inv(K1[:2, :2])
+
+        kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
+        kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
 
-        kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
-        kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
- 
         for _E in np.split(E, len(E) / 3):
-            n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
+            n, R, t, _ = cv2.recoverPose(
+                _E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask
+            )
             if n > best_num_inliers:
                 best_num_inliers = n
                 ret = (R, t, mask.ravel() > 0)
     return ret
 
-def unnormalize_coords(x_n,h,w):
+
+def unnormalize_coords(x_n, h, w):
     x = torch.stack(
         (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
     )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
@@ -155,6 +164,7 @@ def get_depth_tuple_transform_ops_nearest_exact(resize=None):
         ops.append(TupleResizeNearestExact(resize))
     return TupleCompose(ops)
 
+
 def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
     ops = []
     if resize:
@@ -162,7 +172,9 @@ def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
     return TupleCompose(ops)
 
 
-def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None):
+def get_tuple_transform_ops(
+    resize=None, normalize=True, unscale=False, clahe=False, colorjiggle_params=None
+):
     ops = []
     if resize:
         ops.append(TupleResize(resize))
@@ -173,6 +185,7 @@ def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe =
         )  # Imagenet mean/std
     return TupleCompose(ops)
 
+
 class ToTensorScaled(object):
     """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
 
@@ -221,11 +234,15 @@ class TupleToTensorUnscaled(object):
     def __repr__(self):
         return "TupleToTensorUnscaled()"
 
+
 class TupleResizeNearestExact:
     def __init__(self, size):
         self.size = size
+
     def __call__(self, im_tuple):
-        return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple]
+        return [
+            F.interpolate(im, size=self.size, mode="nearest-exact") for im in im_tuple
+        ]
 
     def __repr__(self):
         return "TupleResizeNearestExact(size={})".format(self.size)
@@ -235,17 +252,19 @@ class TupleResize(object):
     def __init__(self, size, mode=InterpolationMode.BICUBIC):
         self.size = size
         self.resize = transforms.Resize(size, mode)
+
     def __call__(self, im_tuple):
         return [self.resize(im) for im in im_tuple]
 
     def __repr__(self):
         return "TupleResize(size={})".format(self.size)
-    
+
+
 class Normalize:
-    def __call__(self,im):
-        mean = im.mean(dim=(1,2), keepdims=True)
-        std = im.std(dim=(1,2), keepdims=True)
-        return (im-mean)/std
+    def __call__(self, im):
+        mean = im.mean(dim=(1, 2), keepdims=True)
+        std = im.std(dim=(1, 2), keepdims=True)
+        return (im - mean) / std
 
 
 class TupleNormalize(object):
@@ -255,7 +274,7 @@ class TupleNormalize(object):
         self.normalize = transforms.Normalize(mean=mean, std=std)
 
     def __call__(self, im_tuple):
-        c,h,w = im_tuple[0].shape
+        c, h, w = im_tuple[0].shape
         if c > 3:
             warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb")
         return [self.normalize(im[:3]) for im in im_tuple]
@@ -281,50 +300,82 @@ class TupleCompose(object):
         format_string += "\n)"
         return format_string
 
+
 @torch.no_grad()
-def cls_to_flow(cls, deterministic_sampling = True):
-    B,C,H,W = cls.shape
+def cls_to_flow(cls, deterministic_sampling=True):
+    B, C, H, W = cls.shape
     device = cls.device
     res = round(math.sqrt(C))
-    G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)])
-    G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
+    G = torch.meshgrid(
+        *[
+            torch.linspace(-1 + 1 / res, 1 - 1 / res, steps=res, device=device)
+            for _ in range(2)
+        ]
+    )
+    G = torch.stack([G[1], G[0]], dim=-1).reshape(C, 2)
     if deterministic_sampling:
         sampled_cls = cls.max(dim=1).indices
     else:
-        sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W)
+        sampled_cls = torch.multinomial(
+            cls.permute(0, 2, 3, 1).reshape(B * H * W, C).softmax(dim=-1), 1
+        ).reshape(B, H, W)
     flow = G[sampled_cls]
     return flow
 
+
 @torch.no_grad()
 def cls_to_flow_refine(cls):
-    B,C,H,W = cls.shape
+    B, C, H, W = cls.shape
     device = cls.device
     res = round(math.sqrt(C))
-    G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)])
-    G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
+    G = torch.meshgrid(
+        *[
+            torch.linspace(-1 + 1 / res, 1 - 1 / res, steps=res, device=device)
+            for _ in range(2)
+        ]
+    )
+    G = torch.stack([G[1], G[0]], dim=-1).reshape(C, 2)
     cls = cls.softmax(dim=1)
     mode = cls.max(dim=1).indices
-    
-    index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long()
-    neighbours = torch.gather(cls, dim = 1, index = index)[...,None]
-    flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]]
-    tot_prob = neighbours.sum(dim=1)  
+
+    index = (
+        torch.stack((mode - 1, mode, mode + 1, mode - res, mode + res), dim=1)
+        .clamp(0, C - 1)
+        .long()
+    )
+    neighbours = torch.gather(cls, dim=1, index=index)[..., None]
+    flow = (
+        neighbours[:, 0] * G[index[:, 0]]
+        + neighbours[:, 1] * G[index[:, 1]]
+        + neighbours[:, 2] * G[index[:, 2]]
+        + neighbours[:, 3] * G[index[:, 3]]
+        + neighbours[:, 4] * G[index[:, 4]]
+    )
+    tot_prob = neighbours.sum(dim=1)
     flow = flow / tot_prob
     return flow
 
 
-def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
-    
+def get_gt_warp(
+    depth1,
+    depth2,
+    T_1to2,
+    K1,
+    K2,
+    depth_interpolation_mode="bilinear",
+    relative_depth_error_threshold=0.05,
+    H=None,
+    W=None,
+):
+
     if H is None:
-        B,H,W = depth1.shape
+        B, H, W = depth1.shape
     else:
         B = depth1.shape[0]
     with torch.no_grad():
         x1_n = torch.meshgrid(
             *[
-                torch.linspace(
-                    -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
-                )
+                torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=depth1.device)
                 for n in (B, H, W)
             ]
         )
@@ -336,15 +387,27 @@ def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bili
             T_1to2.double(),
             K1.double(),
             K2.double(),
-            depth_interpolation_mode = depth_interpolation_mode,
-            relative_depth_error_threshold = relative_depth_error_threshold,
+            depth_interpolation_mode=depth_interpolation_mode,
+            relative_depth_error_threshold=relative_depth_error_threshold,
         )
         prob = mask.float().reshape(B, H, W)
         x2 = x2.reshape(B, H, W, 2)
         return x2, prob
 
+
 @torch.no_grad()
-def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
+def warp_kpts(
+    kpts0,
+    depth0,
+    depth1,
+    T_0to1,
+    K0,
+    K1,
+    smooth_mask=False,
+    return_relative_depth_error=False,
+    depth_interpolation_mode="bilinear",
+    relative_depth_error_threshold=0.05,
+):
     """Warp kpts0 from I0 to I1 with depth, K and Rt
     Also check covisibility and depth consistency.
     Depth is consistent if relative error < 0.2 (hard-coded).
@@ -369,26 +432,44 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return
         # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
         if smooth_mask:
             raise NotImplementedError("Combined bilinear and NN warp not implemented")
-        valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, 
-                  smooth_mask = smooth_mask, 
-                  return_relative_depth_error = return_relative_depth_error, 
-                  depth_interpolation_mode = "bilinear",
-                  relative_depth_error_threshold = relative_depth_error_threshold)
-        valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, 
-                  smooth_mask = smooth_mask, 
-                  return_relative_depth_error = return_relative_depth_error, 
-                  depth_interpolation_mode = "nearest-exact",
-                  relative_depth_error_threshold = relative_depth_error_threshold)
-        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) 
+        valid_bilinear, warp_bilinear = warp_kpts(
+            kpts0,
+            depth0,
+            depth1,
+            T_0to1,
+            K0,
+            K1,
+            smooth_mask=smooth_mask,
+            return_relative_depth_error=return_relative_depth_error,
+            depth_interpolation_mode="bilinear",
+            relative_depth_error_threshold=relative_depth_error_threshold,
+        )
+        valid_nearest, warp_nearest = warp_kpts(
+            kpts0,
+            depth0,
+            depth1,
+            T_0to1,
+            K0,
+            K1,
+            smooth_mask=smooth_mask,
+            return_relative_depth_error=return_relative_depth_error,
+            depth_interpolation_mode="nearest-exact",
+            relative_depth_error_threshold=relative_depth_error_threshold,
+        )
+        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
         warp = warp_bilinear.clone()
-        warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
+        warp[nearest_valid_bilinear_invalid] = warp_nearest[
+            nearest_valid_bilinear_invalid
+        ]
         valid = valid_bilinear | valid_nearest
         return valid, warp
-        
-        
-    kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
-        :, 0, :, 0
-    ]
+
+    kpts0_depth = F.grid_sample(
+        depth0[:, None],
+        kpts0[:, :, None],
+        mode=depth_interpolation_mode,
+        align_corners=False,
+    )[:, 0, :, 0]
     kpts0 = torch.stack(
         (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
     )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
@@ -427,22 +508,26 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return
     # w_kpts0[~covisible_mask, :] = -5 # xd
 
     w_kpts0_depth = F.grid_sample(
-        depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
+        depth1[:, None],
+        w_kpts0[:, :, None],
+        mode=depth_interpolation_mode,
+        align_corners=False,
     )[:, 0, :, 0]
-    
+
     relative_depth_error = (
         (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
     ).abs()
     if not smooth_mask:
         consistent_mask = relative_depth_error < relative_depth_error_threshold
     else:
-        consistent_mask = (-relative_depth_error/smooth_mask).exp()
+        consistent_mask = (-relative_depth_error / smooth_mask).exp()
     valid_mask = nonzero_mask * covisible_mask * consistent_mask
     if return_relative_depth_error:
         return relative_depth_error, w_kpts0
     else:
         return valid_mask, w_kpts0
 
+
 imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
 imagenet_std = torch.tensor([0.229, 0.224, 0.225])
 
@@ -462,7 +547,9 @@ def numpy_to_pil(x: np.ndarray):
 
 def tensor_to_pil(x, unnormalize=False):
     if unnormalize:
-        x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device))
+        x = x * (imagenet_std[:, None, None].to(x.device)) + (
+            imagenet_mean[:, None, None].to(x.device)
+        )
     x = x.detach().permute(1, 2, 0).cpu().numpy()
     x = np.clip(x, 0.0, 1.0)
     return numpy_to_pil(x)
@@ -492,70 +579,63 @@ def compute_relative_pose(R1, t1, R2, t2):
     trans = -rots @ t1 + t2
     return rots, trans
 
+
 @torch.no_grad()
 def reset_opt(opt):
     for group in opt.param_groups:
-        for p in group['params']:
+        for p in group["params"]:
             if p.requires_grad:
                 state = opt.state[p]
                 # State initialization
 
                 # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros_like(p)
+                state["exp_avg"] = torch.zeros_like(p)
                 # Exponential moving average of squared gradient values
-                state['exp_avg_sq'] = torch.zeros_like(p)
+                state["exp_avg_sq"] = torch.zeros_like(p)
                 # Exponential moving average of gradient difference
-                state['exp_avg_diff'] = torch.zeros_like(p)
+                state["exp_avg_diff"] = torch.zeros_like(p)
 
 
 def flow_to_pixel_coords(flow, h1, w1):
-    flow = (
-        torch.stack(
-            (
-                w1 * (flow[..., 0] + 1) / 2,
-                h1 * (flow[..., 1] + 1) / 2,
-            ),
-            axis=-1,
-        )
+    flow = torch.stack(
+        (
+            w1 * (flow[..., 0] + 1) / 2,
+            h1 * (flow[..., 1] + 1) / 2,
+        ),
+        axis=-1,
     )
     return flow
 
+
 def flow_to_normalized_coords(flow, h1, w1):
-    flow = (
-        torch.stack(
-            (
-                2 * (flow[..., 0]) / w1 - 1,
-                2 * (flow[..., 1]) / h1 - 1,
-            ),
-            axis=-1,
-        )
+    flow = torch.stack(
+        (
+            2 * (flow[..., 0]) / w1 - 1,
+            2 * (flow[..., 1]) / h1 - 1,
+        ),
+        axis=-1,
     )
     return flow
 
 
 def warp_to_pixel_coords(warp, h1, w1, h2, w2):
     warp1 = warp[..., :2]
-    warp1 = (
-        torch.stack(
-            (
-                w1 * (warp1[..., 0] + 1) / 2,
-                h1 * (warp1[..., 1] + 1) / 2,
-            ),
-            axis=-1,
-        )
+    warp1 = torch.stack(
+        (
+            w1 * (warp1[..., 0] + 1) / 2,
+            h1 * (warp1[..., 1] + 1) / 2,
+        ),
+        axis=-1,
     )
     warp2 = warp[..., 2:]
-    warp2 = (
-        torch.stack(
-            (
-                w2 * (warp2[..., 0] + 1) / 2,
-                h2 * (warp2[..., 1] + 1) / 2,
-            ),
-            axis=-1,
-        )
+    warp2 = torch.stack(
+        (
+            w2 * (warp2[..., 0] + 1) / 2,
+            h2 * (warp2[..., 1] + 1) / 2,
+        ),
+        axis=-1,
     )
-    return torch.cat((warp1,warp2), dim=-1)
-
+    return torch.cat((warp1, warp2), dim=-1)
 
 
 def signed_point_line_distance(point, line, eps: float = 1e-9):
@@ -576,7 +656,9 @@ def signed_point_line_distance(point, line, eps: float = 1e-9):
     if not line.shape[-1] == 3:
         raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")
 
-    numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2])
+    numerator = (
+        line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2]
+    )
     denominator = line[..., :2].norm(dim=-1)
 
     return numerator / (denominator + eps)
@@ -600,6 +682,7 @@ def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
         the computed Symmetrical distance with shape :math:`(*, N)`.
     """
     import kornia
+
     if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3):
         raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")
 
@@ -611,12 +694,10 @@ def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
 
     return signed_point_line_distance(pts2, line1_in_2)
 
+
 def get_grid(b, h, w, device):
     grid = torch.meshgrid(
-        *[
-            torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
-            for n in (b, h, w)
-        ]
+        *[torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) for n in (b, h, w)]
     )
     grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
     return grid
diff --git a/third_party/SGMNet/components/__init__.py b/third_party/SGMNet/components/__init__.py
index c10d2027efcf985c68abf7185f28b947012cae45..a3a974825d770263feafa99fb09b7b656602584d 100644
--- a/third_party/SGMNet/components/__init__.py
+++ b/third_party/SGMNet/components/__init__.py
@@ -1,3 +1,3 @@
-from . import extractors 
+from . import extractors
 from . import matchers
-from .load_component import load_component
\ No newline at end of file
+from .load_component import load_component
diff --git a/third_party/SGMNet/components/evaluators.py b/third_party/SGMNet/components/evaluators.py
index 59bf0bd7ce3dd085dc86072fc41bad24b9805991..a59af1a1614cfa217b6c50be9826e0ee1832191c 100644
--- a/third_party/SGMNet/components/evaluators.py
+++ b/third_party/SGMNet/components/evaluators.py
@@ -1,127 +1,181 @@
 import numpy as np
 import sys
 import os
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.insert(0, ROOT_DIR)
 
-from utils import evaluation_utils,metrics,fm_utils
+from utils import evaluation_utils, metrics, fm_utils
 import cv2
 
+
 class auc_eval:
-    def __init__(self,config):
-        self.config=config
-        self.err_r,self.err_t,self.err=[],[],[]
-        self.ms=[]
-        self.precision=[]
-
-    def run(self,info):
-        E,r_gt,t_gt=info['e'],info['r_gt'],info['t_gt']
-        K1,K2,img1,img2=info['K1'],info['K2'],info['img1'],info['img2']
-        corr1,corr2=info['corr1'],info['corr2']
-        corr1,corr2=evaluation_utils.normalize_intrinsic(corr1,K1),evaluation_utils.normalize_intrinsic(corr2,K2)
-        size1,size2=max(img1.shape),max(img2.shape)
-        scale1,scale2=self.config['rescale']/size1,self.config['rescale']/size2
-        #ransac
-        ransac_th=4./((K1[0,0]+K1[1,1])*scale1+(K2[0,0]+K2[1,1])*scale2)
-        R_hat,t_hat,E_hat=self.estimate(corr1,corr2,ransac_th)
-        #get pose error
-        err_r, err_t=metrics.evaluate_R_t(r_gt,t_gt,R_hat,t_hat)
-        err=max(err_r,err_t)
-        
-        if len(corr1)>1:
-            inlier_mask=metrics.compute_epi_inlier(corr1,corr2,E,self.config['inlier_th'])
-            precision=inlier_mask.mean()
-            ms=inlier_mask.sum()/len(info['x1'])
+    def __init__(self, config):
+        self.config = config
+        self.err_r, self.err_t, self.err = [], [], []
+        self.ms = []
+        self.precision = []
+
+    def run(self, info):
+        E, r_gt, t_gt = info["e"], info["r_gt"], info["t_gt"]
+        K1, K2, img1, img2 = info["K1"], info["K2"], info["img1"], info["img2"]
+        corr1, corr2 = info["corr1"], info["corr2"]
+        corr1, corr2 = evaluation_utils.normalize_intrinsic(
+            corr1, K1
+        ), evaluation_utils.normalize_intrinsic(corr2, K2)
+        size1, size2 = max(img1.shape), max(img2.shape)
+        scale1, scale2 = self.config["rescale"] / size1, self.config["rescale"] / size2
+        # ransac
+        ransac_th = 4.0 / (
+            (K1[0, 0] + K1[1, 1]) * scale1 + (K2[0, 0] + K2[1, 1]) * scale2
+        )
+        R_hat, t_hat, E_hat = self.estimate(corr1, corr2, ransac_th)
+        # get pose error
+        err_r, err_t = metrics.evaluate_R_t(r_gt, t_gt, R_hat, t_hat)
+        err = max(err_r, err_t)
+
+        if len(corr1) > 1:
+            inlier_mask = metrics.compute_epi_inlier(
+                corr1, corr2, E, self.config["inlier_th"]
+            )
+            precision = inlier_mask.mean()
+            ms = inlier_mask.sum() / len(info["x1"])
         else:
-            ms=precision=0
-        
-        return {'err_r':err_r,'err_t':err_t,'err':err,'ms':ms,'precision':precision}
-
-    def res_inqueue(self,res):
-        self.err_r.append(res['err_r']),self.err_t.append(res['err_t']),self.err.append(res['err'])
-        self.ms.append(res['ms']),self.precision.append(res['precision'])
-
-    def estimate(self,corr1,corr2,th):
+            ms = precision = 0
+
+        return {
+            "err_r": err_r,
+            "err_t": err_t,
+            "err": err,
+            "ms": ms,
+            "precision": precision,
+        }
+
+    def res_inqueue(self, res):
+        self.err_r.append(res["err_r"]), self.err_t.append(
+            res["err_t"]
+        ), self.err.append(res["err"])
+        self.ms.append(res["ms"]), self.precision.append(res["precision"])
+
+    def estimate(self, corr1, corr2, th):
         num_inlier = -1
         if corr1.shape[0] >= 5:
-            E, mask_new = cv2.findEssentialMat(corr1, corr2,method=cv2.RANSAC, threshold=th,prob=1-1e-5)
+            E, mask_new = cv2.findEssentialMat(
+                corr1, corr2, method=cv2.RANSAC, threshold=th, prob=1 - 1e-5
+            )
             if E is None:
-                E=[np.eye(3)]
+                E = np.eye(3)  # keep len(E) a multiple of 3 for np.split below
             for _E in np.split(E, len(E) / 3):
-                _num_inlier, _R, _t, _ = cv2.recoverPose(_E, corr1, corr2,np.eye(3), 1e9,mask=mask_new)
+                _num_inlier, _R, _t, _ = cv2.recoverPose(
+                    _E, corr1, corr2, np.eye(3), 1e9, mask=mask_new
+                )
                 if _num_inlier > num_inlier:
                     num_inlier = _num_inlier
                     R = _R
                     t = _t
                     E = _E
         else:
-            E,R,t=np.eye(3),np.eye(3),np.zeros(3)
-        return R,t,E
+            E, R, t = np.eye(3), np.eye(3), np.zeros(3)
+        return R, t, E
 
     def parse(self):
         ths = np.arange(7) * 5
-        approx_auc=metrics.approx_pose_auc(self.err,ths)
-        exact_auc=metrics.pose_auc(self.err,ths)
-        mean_pre,mean_ms=np.mean(np.asarray(self.precision)),np.mean(np.asarray(self.ms))
-        
-        print('auc th: ',ths[1:])
-        print('approx auc: ',approx_auc)
-        print('exact auc: ', exact_auc)
-        print('mean match score: ',mean_ms*100)
-        print('mean precision: ',mean_pre*100)
-
-        
-
-class FMbench_eval:
+        approx_auc = metrics.approx_pose_auc(self.err, ths)
+        exact_auc = metrics.pose_auc(self.err, ths)
+        mean_pre, mean_ms = np.mean(np.asarray(self.precision)), np.mean(
+            np.asarray(self.ms)
+        )
 
-    def __init__(self,config):
-        self.config=config
-        self.pre,self.pre_post,self.sgd=[],[],[]
-        self.num_corr,self.num_corr_post=[],[]
+        print("auc th: ", ths[1:])
+        print("approx auc: ", approx_auc)
+        print("exact auc: ", exact_auc)
+        print("mean match score: ", mean_ms * 100)
+        print("mean precision: ", mean_pre * 100)
 
-    def run(self,info):
-        corr1,corr2=info['corr1'],info['corr2']
-        F=info['f']
-        img1,img2=info['img1'],info['img2']
 
-        if len(corr1)>1:
-            pre_bf=fm_utils.compute_inlier_rate(corr1,corr2,np.flip(img1.shape[:2]),np.flip(img2.shape[:2]),F,th=self.config['inlier_th']).mean()
-            F_hat,mask_F=cv2.findFundamentalMat(corr1,corr2,method=cv2.FM_RANSAC,ransacReprojThreshold=1,confidence=1-1e-5)
+class FMbench_eval:
+    def __init__(self, config):
+        self.config = config
+        self.pre, self.pre_post, self.sgd = [], [], []
+        self.num_corr, self.num_corr_post = [], []
+
+    def run(self, info):
+        corr1, corr2 = info["corr1"], info["corr2"]
+        F = info["f"]
+        img1, img2 = info["img1"], info["img2"]
+
+        if len(corr1) > 1:
+            pre_bf = fm_utils.compute_inlier_rate(
+                corr1,
+                corr2,
+                np.flip(img1.shape[:2]),
+                np.flip(img2.shape[:2]),
+                F,
+                th=self.config["inlier_th"],
+            ).mean()
+            F_hat, mask_F = cv2.findFundamentalMat(
+                corr1,
+                corr2,
+                method=cv2.FM_RANSAC,
+                ransacReprojThreshold=1,
+                confidence=1 - 1e-5,
+            )
             if F_hat is None:
-                F_hat=np.ones([3,3])
-                mask_F=np.ones([len(corr1)]).astype(bool)
+                F_hat = np.ones([3, 3])
+                mask_F = np.ones([len(corr1)]).astype(bool)
             else:
-                mask_F=mask_F.squeeze().astype(bool)
-            F_hat=F_hat[:3]
-            pre_af=fm_utils.compute_inlier_rate(corr1[mask_F],corr2[mask_F],np.flip(img1.shape[:2]),np.flip(img2.shape[:2]),F,th=self.config['inlier_th']).mean()
-            num_corr_af=mask_F.sum()
-            num_corr=len(corr1)
-            sgd=fm_utils.compute_SGD(F,F_hat,np.flip(img1.shape[:2]),np.flip(img2.shape[:2]))
+                mask_F = mask_F.squeeze().astype(bool)
+            F_hat = F_hat[:3]
+            pre_af = fm_utils.compute_inlier_rate(
+                corr1[mask_F],
+                corr2[mask_F],
+                np.flip(img1.shape[:2]),
+                np.flip(img2.shape[:2]),
+                F,
+                th=self.config["inlier_th"],
+            ).mean()
+            num_corr_af = mask_F.sum()
+            num_corr = len(corr1)
+            sgd = fm_utils.compute_SGD(
+                F, F_hat, np.flip(img1.shape[:2]), np.flip(img2.shape[:2])
+            )
         else:
-            pre_bf,pre_af,sgd=0,0,1e8
-            num_corr,num_corr_af=0,0
-        return {'pre':pre_bf,'pre_post':pre_af,'sgd':sgd,'num_corr':num_corr,'num_corr_post':num_corr_af}
-
-
-    def res_inqueue(self,res):
-        self.pre.append(res['pre']),self.pre_post.append(res['pre_post']),self.sgd.append(res['sgd'])
-        self.num_corr.append(res['num_corr']),self.num_corr_post.append(res['num_corr_post'])
+            pre_bf, pre_af, sgd = 0, 0, 1e8
+            num_corr, num_corr_af = 0, 0
+        return {
+            "pre": pre_bf,
+            "pre_post": pre_af,
+            "sgd": sgd,
+            "num_corr": num_corr,
+            "num_corr_post": num_corr_af,
+        }
+
+    def res_inqueue(self, res):
+        self.pre.append(res["pre"]), self.pre_post.append(
+            res["pre_post"]
+        ), self.sgd.append(res["sgd"])
+        self.num_corr.append(res["num_corr"]), self.num_corr_post.append(
+            res["num_corr_post"]
+        )
 
     def parse(self):
-        for seq_index in range(len(self.config['seq'])):
-            seq=self.config['seq'][seq_index]
-            offset=seq_index*1000
-            pre=np.asarray(self.pre)[offset:offset+1000].mean()
-            pre_post=np.asarray(self.pre_post)[offset:offset+1000].mean()
-            num_corr=np.asarray(self.num_corr)[offset:offset+1000].mean()
-            num_corr_post=np.asarray(self.num_corr_post)[offset:offset+1000].mean()
-            f_recall=(np.asarray(self.sgd)[offset:offset+1000]<self.config['sgd_inlier_th']).mean()
-
-            print(seq,'results:')
-            print('F_recall: ',f_recall)
-            print('precision: ',pre)
-            print('precision_post: ',pre_post)
-            print('num_corr: ',num_corr)
-            print('num_corr_post: ',num_corr_post,'\n')
-
-
+        for seq_index in range(len(self.config["seq"])):
+            seq = self.config["seq"][seq_index]
+            offset = seq_index * 1000
+            pre = np.asarray(self.pre)[offset : offset + 1000].mean()
+            pre_post = np.asarray(self.pre_post)[offset : offset + 1000].mean()
+            num_corr = np.asarray(self.num_corr)[offset : offset + 1000].mean()
+            num_corr_post = np.asarray(self.num_corr_post)[
+                offset : offset + 1000
+            ].mean()
+            f_recall = (
+                np.asarray(self.sgd)[offset : offset + 1000]
+                < self.config["sgd_inlier_th"]
+            ).mean()
+
+            print(seq, "results:")
+            print("F_recall: ", f_recall)
+            print("precision: ", pre)
+            print("precision_post: ", pre_post)
+            print("num_corr: ", num_corr)
+            print("num_corr_post: ", num_corr_post, "\n")
diff --git a/third_party/SGMNet/components/extractors.py b/third_party/SGMNet/components/extractors.py
index 43b03ab35b307fc8dd8af15f9bcf61c61d268918..8cd2a76aaaaf93fd16319b5a1e01f463b50a5d3b 100644
--- a/third_party/SGMNet/components/extractors.py
+++ b/third_party/SGMNet/components/extractors.py
@@ -4,83 +4,104 @@ import torch
 import os
 
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.insert(0, ROOT_DIR)
 
 from superpoint import SuperPoint
 
 
-def resize(img,resize):
-    img_h,img_w=img.shape[0],img.shape[1]
-    cur_size=max(img_h,img_w)
-    if len(resize)==1: 
-      scale1,scale2=resize[0]/cur_size,resize[0]/cur_size
+def resize(img, resize):
+    img_h, img_w = img.shape[0], img.shape[1]
+    cur_size = max(img_h, img_w)
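+    # a single resize value scales the longer side (keeping aspect ratio); two values set height and width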
+    if len(resize) == 1:
+        scale1, scale2 = resize[0] / cur_size, resize[0] / cur_size
     else:
-      scale1,scale2=resize[0]/img_h,resize[1]/img_w
-    new_h,new_w=int(img_h*scale1),int(img_w*scale2)
-    new_img=cv2.resize(img.astype('float32'),(new_w,new_h)).astype('uint8')
-    scale=np.asarray([scale2,scale1])
-    return new_img,scale
+        scale1, scale2 = resize[0] / img_h, resize[1] / img_w
+    new_h, new_w = int(img_h * scale1), int(img_w * scale2)
+    new_img = cv2.resize(img.astype("float32"), (new_w, new_h)).astype("uint8")
+    scale = np.asarray([scale2, scale1])
+    return new_img, scale
 
 
 class ExtractSIFT:
-  def __init__(self,config,root=True):
-    self.num_kp=config['num_kpt']
-    self.contrastThreshold=config['det_th']
-    self.resize=config['resize']
-    self.root=root
-
-  def run(self, img_path):
-    self.sift = cv2.xfeatures2d.SIFT_create(nfeatures=self.num_kp, contrastThreshold=self.contrastThreshold)
-    img = cv2.imread(img_path,cv2.IMREAD_GRAYSCALE)
-    scale=[1,1]
-    if self.resize[0]!=-1:
-      img,scale=resize(img,self.resize)
-    cv_kp, desc = self.sift.detectAndCompute(img, None)
-    kp = np.array([[_kp.pt[0]/scale[1], _kp.pt[1]/scale[0], _kp.response] for _kp in cv_kp]) # N*3
-    index=np.flip(np.argsort(kp[:,2]))
-    kp,desc=kp[index],desc[index]
-    if self.root:
-      desc=np.sqrt(abs(desc/(np.linalg.norm(desc,axis=-1,ord=1)[:,np.newaxis]+1e-8)))
-    return kp[:self.num_kp], desc[:self.num_kp]
+    def __init__(self, config, root=True):
+        self.num_kp = config["num_kpt"]
+        self.contrastThreshold = config["det_th"]
+        self.resize = config["resize"]
+        self.root = root
 
+    def run(self, img_path):
+        self.sift = cv2.xfeatures2d.SIFT_create(
+            nfeatures=self.num_kp, contrastThreshold=self.contrastThreshold
+        )
+        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
+        scale = [1, 1]
+        if self.resize[0] != -1:
+            img, scale = resize(img, self.resize)
+        cv_kp, desc = self.sift.detectAndCompute(img, None)
+        kp = np.array(
+            [
+                [_kp.pt[0] / scale[1], _kp.pt[1] / scale[0], _kp.response]
+                for _kp in cv_kp
+            ]
+        )  # N*3
+        index = np.flip(np.argsort(kp[:, 2]))
+        kp, desc = kp[index], desc[index]
+        if self.root:
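+            # RootSIFT: L1-normalize each descriptor, then take the element-wise square root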
+            desc = np.sqrt(
+                abs(desc / (np.linalg.norm(desc, axis=-1, ord=1)[:, np.newaxis] + 1e-8))
+            )
+        return kp[: self.num_kp], desc[: self.num_kp]
 
 
 class ExtractSuperpoint(object):
-  def __init__(self,config):
-    default_config = {
-      'descriptor_dim': 256,
-      'nms_radius': 4,
-      'detection_threshold': config['det_th'],
-      'max_keypoints': config['num_kpt'],
-      'remove_borders': 4,
-      'model_path':'../weights/sp/superpoint_v1.pth'
-    }
-    self.superpoint_extractor=SuperPoint(default_config)
-    self.superpoint_extractor.eval(),self.superpoint_extractor.cuda()
-    self.num_kp=config['num_kpt']
-    if 'padding' in config.keys():
-      self.padding=config['padding']
-    else:
-      self.padding=False
-    self.resize=config['resize']
+    def __init__(self, config):
+        default_config = {
+            "descriptor_dim": 256,
+            "nms_radius": 4,
+            "detection_threshold": config["det_th"],
+            "max_keypoints": config["num_kpt"],
+            "remove_borders": 4,
+            "model_path": "../weights/sp/superpoint_v1.pth",
+        }
+        self.superpoint_extractor = SuperPoint(default_config)
+        self.superpoint_extractor.eval(), self.superpoint_extractor.cuda()
+        self.num_kp = config["num_kpt"]
+        if "padding" in config.keys():
+            self.padding = config["padding"]
+        else:
+            self.padding = False
+        self.resize = config["resize"]
 
-  def run(self,img_path):
-    img = cv2.imread(img_path,cv2.IMREAD_GRAYSCALE)
-    scale=1
-    if self.resize[0]!=-1:
-      img,scale=resize(img,self.resize)
-    with torch.no_grad():
-      result=self.superpoint_extractor(torch.from_numpy(img/255.).float()[None, None].cuda())
-    score,kpt,desc=result['scores'][0],result['keypoints'][0],result['descriptors'][0]
-    score,kpt,desc=score.cpu().numpy(),kpt.cpu().numpy(),desc.cpu().numpy().T
-    kpt=np.concatenate([kpt/scale,score[:,np.newaxis]],axis=-1)
-    #padding randomly
-    if self.padding:
-      if len(kpt)<self.num_kp:
-        res=int(self.num_kp-len(kpt))
-        pad_x,pad_desc=np.random.uniform(size=[res,2])*(img.shape[0]+img.shape[1])/2,np.random.uniform(size=[res,256])
-        pad_kpt,pad_desc=np.concatenate([pad_x,np.zeros([res,1])],axis=-1),pad_desc/np.linalg.norm(pad_desc,axis=-1)[:,np.newaxis]
-        kpt,desc=np.concatenate([kpt,pad_kpt],axis=0),np.concatenate([desc,pad_desc],axis=0)
-    return kpt,desc
-  
\ No newline at end of file
+    def run(self, img_path):
+        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
+        scale = 1
+        if self.resize[0] != -1:
+            img, scale = resize(img, self.resize)
+        with torch.no_grad():
+            result = self.superpoint_extractor(
+                torch.from_numpy(img / 255.0).float()[None, None].cuda()
+            )
+        score, kpt, desc = (
+            result["scores"][0],
+            result["keypoints"][0],
+            result["descriptors"][0],
+        )
+        score, kpt, desc = score.cpu().numpy(), kpt.cpu().numpy(), desc.cpu().numpy().T
+        kpt = np.concatenate([kpt / scale, score[:, np.newaxis]], axis=-1)
+        # pad to a fixed number of keypoints: random positions with zero score and unit-norm random descriptors
+        if self.padding:
+            if len(kpt) < self.num_kp:
+                res = int(self.num_kp - len(kpt))
+                pad_x, pad_desc = np.random.uniform(size=[res, 2]) * (
+                    img.shape[0] + img.shape[1]
+                ) / 2, np.random.uniform(size=[res, 256])
+                pad_kpt, pad_desc = (
+                    np.concatenate([pad_x, np.zeros([res, 1])], axis=-1),
+                    pad_desc / np.linalg.norm(pad_desc, axis=-1)[:, np.newaxis],
+                )
+                kpt, desc = np.concatenate([kpt, pad_kpt], axis=0), np.concatenate(
+                    [desc, pad_desc], axis=0
+                )
+        return kpt, desc
diff --git a/third_party/SGMNet/components/load_component.py b/third_party/SGMNet/components/load_component.py
index d934655c234ec279a44fef4d3a60fe411acef9f8..1d46389bf64640dc928d08132765b9b4d5e0a8ad 100644
--- a/third_party/SGMNet/components/load_component.py
+++ b/third_party/SGMNet/components/load_component.py
@@ -3,50 +3,54 @@ from . import readers
 from . import evaluators
 from . import extractors
 
-def load_component(compo_name,model_name,config):
-    if compo_name=='extractor':
-        component=load_extractor(model_name,config)
-    elif compo_name=='reader':
-        component=load_reader(model_name,config)
-    elif compo_name=='matcher':
-        component=load_matcher(model_name,config)
-    elif compo_name=='evaluator':
-        component=load_evaluator(model_name,config)
+
+def load_component(compo_name, model_name, config):
+    if compo_name == "extractor":
+        component = load_extractor(model_name, config)
+    elif compo_name == "reader":
+        component = load_reader(model_name, config)
+    elif compo_name == "matcher":
+        component = load_matcher(model_name, config)
+    elif compo_name == "evaluator":
+        component = load_evaluator(model_name, config)
     else:
         raise NotImplementedError
     return component
 
 
-def load_extractor(model_name,config):
-    if model_name=='root':
-        extractor =extractors.ExtractSIFT(config)
-    elif model_name=='sp':
-        extractor=extractors.ExtractSuperpoint(config)
+def load_extractor(model_name, config):
+    if model_name == "root":
+        extractor = extractors.ExtractSIFT(config)
+    elif model_name == "sp":
+        extractor = extractors.ExtractSuperpoint(config)
     else:
         raise NotImplementedError
     return extractor
 
-def load_matcher(model_name,config):
-    if model_name=='SGM':
-        matcher=matchers.GNN_Matcher(config,'SGM')
-    elif model_name=='SG':
-        matcher=matchers.GNN_Matcher(config,'SG')
-    elif model_name=='NN':
-        matcher=matchers.NN_Matcher(config)
+
+def load_matcher(model_name, config):
+    if model_name == "SGM":
+        matcher = matchers.GNN_Matcher(config, "SGM")
+    elif model_name == "SG":
+        matcher = matchers.GNN_Matcher(config, "SG")
+    elif model_name == "NN":
+        matcher = matchers.NN_Matcher(config)
     else:
         raise NotImplementedError
     return matcher
 
-def load_reader(model_name,config):
-    if model_name=='standard':
-        reader=readers.standard_reader(config)
+
+def load_reader(model_name, config):
+    if model_name == "standard":
+        reader = readers.standard_reader(config)
     else:
         raise NotImplementedError
     return reader
 
-def load_evaluator(model_name,config):
-    if model_name=='AUC':
-        evaluator=evaluators.auc_eval(config)
-    elif model_name=='FM':
-        evaluator=evaluators.FMbench_eval(config)
+
+def load_evaluator(model_name, config):
+    if model_name == "AUC":
+        evaluator = evaluators.auc_eval(config)
+    elif model_name == "FM":
+        evaluator = evaluators.FMbench_eval(config)
+    else:
+        raise NotImplementedError
     return evaluator
diff --git a/third_party/SGMNet/components/matchers.py b/third_party/SGMNet/components/matchers.py
index 95fa8897bfaff91372466b3c811106d0bdb69f34..3e160b2fba5a73581b88b6f74816b15981e02ee7 100644
--- a/third_party/SGMNet/components/matchers.py
+++ b/third_party/SGMNet/components/matchers.py
@@ -1,8 +1,9 @@
 import torch
 import numpy as np
 import os
-from collections import OrderedDict,namedtuple
+from collections import OrderedDict, namedtuple
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.insert(0, ROOT_DIR)
 
@@ -10,75 +11,92 @@ from sgmnet import matcher as SGM_Model
 from superglue import matcher as SG_Model
 from utils import evaluation_utils
 
-class GNN_Matcher(object):
 
-    def __init__(self,config,model_name):
-        assert model_name=='SGM' or model_name=='SG'
+class GNN_Matcher(object):
+    def __init__(self, config, model_name):
+        assert model_name == "SGM" or model_name == "SG"
 
-        config=namedtuple('config',config.keys())(*config.values())
-        self.p_th=config.p_th
-        self.model = SGM_Model(config) if model_name=='SGM' else SG_Model(config) 
-        self.model.cuda(),self.model.eval()
-        checkpoint = torch.load(os.path.join(config.model_dir, 'model_best.pth'))
-        #for ddp model
-        if list(checkpoint['state_dict'].items())[0][0].split('.')[0]=='module':
-            new_stat_dict=OrderedDict()
-            for key,value in checkpoint['state_dict'].items():
-                new_stat_dict[key[7:]]=value
-            checkpoint['state_dict']=new_stat_dict
-        self.model.load_state_dict(checkpoint['state_dict'])
+        config = namedtuple("config", config.keys())(*config.values())
+        self.p_th = config.p_th
+        self.model = SGM_Model(config) if model_name == "SGM" else SG_Model(config)
+        self.model.cuda(), self.model.eval()
+        checkpoint = torch.load(os.path.join(config.model_dir, "model_best.pth"))
+        # strip the "module." prefix from checkpoints saved with DistributedDataParallel
+        if list(checkpoint["state_dict"].items())[0][0].split(".")[0] == "module":
+            new_stat_dict = OrderedDict()
+            for key, value in checkpoint["state_dict"].items():
+                new_stat_dict[key[7:]] = value
+            checkpoint["state_dict"] = new_stat_dict
+        self.model.load_state_dict(checkpoint["state_dict"])
 
-    def run(self,test_data):
-        norm_x1,norm_x2=evaluation_utils.normalize_size(test_data['x1'][:,:2],test_data['size1']),\
-                                                    evaluation_utils.normalize_size(test_data['x2'][:,:2],test_data['size2'])
-        x1,x2=np.concatenate([norm_x1,test_data['x1'][:,2,np.newaxis]],axis=-1),np.concatenate([norm_x2,test_data['x2'][:,2,np.newaxis]],axis=-1)
-        feed_data={'x1':torch.from_numpy(x1[np.newaxis]).cuda().float(),
-                   'x2':torch.from_numpy(x2[np.newaxis]).cuda().float(),
-                   'desc1':torch.from_numpy(test_data['desc1'][np.newaxis]).cuda().float(),
-                   'desc2':torch.from_numpy(test_data['desc2'][np.newaxis]).cuda().float()}
+    def run(self, test_data):
+        norm_x1, norm_x2 = evaluation_utils.normalize_size(
+            test_data["x1"][:, :2], test_data["size1"]
+        ), evaluation_utils.normalize_size(test_data["x2"][:, :2], test_data["size2"])
+        x1, x2 = np.concatenate(
+            [norm_x1, test_data["x1"][:, 2, np.newaxis]], axis=-1
+        ), np.concatenate([norm_x2, test_data["x2"][:, 2, np.newaxis]], axis=-1)
+        feed_data = {
+            "x1": torch.from_numpy(x1[np.newaxis]).cuda().float(),
+            "x2": torch.from_numpy(x2[np.newaxis]).cuda().float(),
+            "desc1": torch.from_numpy(test_data["desc1"][np.newaxis]).cuda().float(),
+            "desc2": torch.from_numpy(test_data["desc2"][np.newaxis]).cuda().float(),
+        }
         with torch.no_grad():
-            res=self.model(feed_data,test_mode=True)
-            p=res['p']
-        index1,index2=self.match_p(p[0,:-1,:-1])
-        corr1,corr2=test_data['x1'][:,:2][index1.cpu()],test_data['x2'][:,:2][index2.cpu()]
-        if len(corr1.shape)==1:
-            corr1,corr2=corr1[np.newaxis],corr2[np.newaxis]
-        return corr1,corr2
-    
-    def match_p(self,p):#p N*M
-        score,index=torch.topk(p,k=1,dim=-1)
-        _,index2=torch.topk(p,k=1,dim=-2)
-        mask_th,index,index2=score[:,0]>self.p_th,index[:,0],index2.squeeze(0)
-        mask_mc=index2[index] == torch.arange(len(p)).cuda()
-        mask=mask_th&mask_mc
-        index1,index2=torch.nonzero(mask).squeeze(1),index[mask]
-        return index1,index2
+            res = self.model(feed_data, test_mode=True)
+            p = res["p"]
+        index1, index2 = self.match_p(p[0, :-1, :-1])
+        corr1, corr2 = (
+            test_data["x1"][:, :2][index1.cpu()],
+            test_data["x2"][:, :2][index2.cpu()],
+        )
+        if len(corr1.shape) == 1:
+            corr1, corr2 = corr1[np.newaxis], corr2[np.newaxis]
+        return corr1, corr2
 
+    def match_p(self, p):  # p N*M
+        score, index = torch.topk(p, k=1, dim=-1)
+        _, index2 = torch.topk(p, k=1, dim=-2)
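+        # keep pairs above the probability threshold that are also mutual nearest neighbours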
+        mask_th, index, index2 = score[:, 0] > self.p_th, index[:, 0], index2.squeeze(0)
+        mask_mc = index2[index] == torch.arange(len(p)).cuda()
+        mask = mask_th & mask_mc
+        index1, index2 = torch.nonzero(mask).squeeze(1), index[mask]
+        return index1, index2
 
-class NN_Matcher(object):
 
-    def __init__(self,config):
-        config=namedtuple('config',config.keys())(*config.values())
-        self.mutual_check=config.mutual_check
-        self.ratio_th=config.ratio_th
+class NN_Matcher(object):
+    def __init__(self, config):
+        config = namedtuple("config", config.keys())(*config.values())
+        self.mutual_check = config.mutual_check
+        self.ratio_th = config.ratio_th
 
-    def run(self,test_data):
-        desc1,desc2,x1,x2=test_data['desc1'],test_data['desc2'],test_data['x1'],test_data['x2']
-        desc_mat=np.sqrt(abs((desc1**2).sum(-1)[:,np.newaxis]+(desc2**2).sum(-1)[np.newaxis]-2*desc1@desc2.T))
-        nn_index=np.argpartition(desc_mat,kth=(1,2),axis=-1)
-        dis_value12=np.take_along_axis(desc_mat,nn_index, axis=-1)
-        ratio_score=dis_value12[:,0]/dis_value12[:,1]
-        nn_index1=nn_index[:,0]
-        nn_index2=np.argmin(desc_mat,axis=0)
-        mask_ratio,mask_mutual=ratio_score<self.ratio_th,np.arange(len(x1))==nn_index2[nn_index1]
-        corr1,corr2=x1[:,:2],x2[:,:2][nn_index1]
+    def run(self, test_data):
+        desc1, desc2, x1, x2 = (
+            test_data["desc1"],
+            test_data["desc2"],
+            test_data["x1"],
+            test_data["x2"],
+        )
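+        # pairwise Euclidean distances from ||d1||^2 + ||d2||^2 - 2*d1@d2.T; abs() guards against negative round-off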
+        desc_mat = np.sqrt(
+            abs(
+                (desc1**2).sum(-1)[:, np.newaxis]
+                + (desc2**2).sum(-1)[np.newaxis]
+                - 2 * desc1 @ desc2.T
+            )
+        )
+        nn_index = np.argpartition(desc_mat, kth=(1, 2), axis=-1)
+        dis_value12 = np.take_along_axis(desc_mat, nn_index, axis=-1)
+        ratio_score = dis_value12[:, 0] / dis_value12[:, 1]
+        nn_index1 = nn_index[:, 0]
+        nn_index2 = np.argmin(desc_mat, axis=0)
+        mask_ratio, mask_mutual = (
+            ratio_score < self.ratio_th,
+            np.arange(len(x1)) == nn_index2[nn_index1],
+        )
+        corr1, corr2 = x1[:, :2], x2[:, :2][nn_index1]
         if self.mutual_check:
-            mask=mask_ratio&mask_mutual
+            mask = mask_ratio & mask_mutual
         else:
-            mask=mask_ratio
-        corr1,corr2=corr1[mask],corr2[mask]
-        return corr1,corr2
-
-
-
-
+            mask = mask_ratio
+        corr1, corr2 = corr1[mask], corr2[mask]
+        return corr1, corr2
diff --git a/third_party/SGMNet/components/readers.py b/third_party/SGMNet/components/readers.py
index b03d5e23c5553acefcaa007241270b82407de37b..e6c1e7dd5cb92afdeadf7f04a5086d7c14af22eb 100644
--- a/third_party/SGMNet/components/readers.py
+++ b/third_party/SGMNet/components/readers.py
@@ -3,30 +3,60 @@ import numpy as np
 import h5py
 import cv2
 
+
 class standard_reader:
-    def __init__(self,config):
-        self.raw_dir=config['rawdata_dir']
-        self.dataset=h5py.File(config['dataset_dir'],'r')
-        self.num_kpt=config['num_kpt']
-
-    def run(self,index):
-        K1,K2=np.asarray(self.dataset['K1'][str(index)]),np.asarray(self.dataset['K2'][str(index)])
-        R = np.asarray(self.dataset['R'][str(index)])
-        t = np.asarray(self.dataset['T'][str(index)])
-        t = t / np.sqrt((t ** 2).sum())
-
-        desc1,desc2=self.dataset['desc1'][str(index)][()][:self.num_kpt],self.dataset['desc2'][str(index)][()][:self.num_kpt]
-        x1, x2 = self.dataset['kpt1'][str(index)][()][:self.num_kpt], self.dataset['kpt2'][str(index)][()][:self.num_kpt]
-        e,f=self.dataset['e'][str(index)][()],self.dataset['f'][str(index)][()]
-
-        img1_path,img2_path=self.dataset['img_path1'][str(index)][()][0].decode(),self.dataset['img_path2'][str(index)][()][0].decode()
-        img1,img2=cv2.imread(os.path.join(self.raw_dir,img1_path)),cv2.imread(os.path.join(self.raw_dir,img2_path))
-   
-        info={'index':index,'K1':K1,'K2':K2,'R':R,'t':t,'x1':x1,'x2':x2,'desc1':desc1,'desc2':desc2,'img1':img1,'img2':img2,'e':e,'f':f,'r_gt':R,'t_gt':t}
+    def __init__(self, config):
+        self.raw_dir = config["rawdata_dir"]
+        self.dataset = h5py.File(config["dataset_dir"], "r")
+        self.num_kpt = config["num_kpt"]
+
+    def run(self, index):
+        K1, K2 = np.asarray(self.dataset["K1"][str(index)]), np.asarray(
+            self.dataset["K2"][str(index)]
+        )
+        R = np.asarray(self.dataset["R"][str(index)])
+        t = np.asarray(self.dataset["T"][str(index)])
+        t = t / np.sqrt((t**2).sum())
+
+        desc1, desc2 = (
+            self.dataset["desc1"][str(index)][()][: self.num_kpt],
+            self.dataset["desc2"][str(index)][()][: self.num_kpt],
+        )
+        x1, x2 = (
+            self.dataset["kpt1"][str(index)][()][: self.num_kpt],
+            self.dataset["kpt2"][str(index)][()][: self.num_kpt],
+        )
+        e, f = self.dataset["e"][str(index)][()], self.dataset["f"][str(index)][()]
+
+        img1_path, img2_path = (
+            self.dataset["img_path1"][str(index)][()][0].decode(),
+            self.dataset["img_path2"][str(index)][()][0].decode(),
+        )
+        img1, img2 = cv2.imread(os.path.join(self.raw_dir, img1_path)), cv2.imread(
+            os.path.join(self.raw_dir, img2_path)
+        )
+
+        info = {
+            "index": index,
+            "K1": K1,
+            "K2": K2,
+            "R": R,
+            "t": t,
+            "x1": x1,
+            "x2": x2,
+            "desc1": desc1,
+            "desc2": desc2,
+            "img1": img1,
+            "img2": img2,
+            "e": e,
+            "f": f,
+            "r_gt": R,
+            "t_gt": t,
+        }
         return info
 
     def close(self):
         self.dataset.close()
 
     def __len__(self):
-        return len(self.dataset['K1'])
\ No newline at end of file
+        return len(self.dataset["K1"])
diff --git a/third_party/SGMNet/datadump/check_training_data.py b/third_party/SGMNet/datadump/check_training_data.py
index 40cf939c4bf5217c85f04b0fd402f78526387bb7..0b2df392358206d702b60d9d06d28e4f969f570a 100644
--- a/third_party/SGMNet/datadump/check_training_data.py
+++ b/third_party/SGMNet/datadump/check_training_data.py
@@ -8,67 +8,93 @@ import pyxis as px
 from tqdm import trange
 
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.insert(0, ROOT_DIR)
 
-from utils import evaluation_utils,train_utils
-
-parser = argparse.ArgumentParser(description='checking training data.')
-parser.add_argument('--meta_dir', type=str, default='dataset/valid')
-parser.add_argument('--dataset_dir', type=str, default='dataset')
-parser.add_argument('--desc_dir', type=str, default='desc')
-parser.add_argument('--raw_dir', type=str, default='raw_data')
-parser.add_argument('--desc_suffix', type=str, default='_root_1000.hdf5')
-parser.add_argument('--vis_folder',type=str,default=None)
-args=parser.parse_args()  
+from utils import evaluation_utils, train_utils
 
+parser = argparse.ArgumentParser(description="checking training data.")
+parser.add_argument("--meta_dir", type=str, default="dataset/valid")
+parser.add_argument("--dataset_dir", type=str, default="dataset")
+parser.add_argument("--desc_dir", type=str, default="desc")
+parser.add_argument("--raw_dir", type=str, default="raw_data")
+parser.add_argument("--desc_suffix", type=str, default="_root_1000.hdf5")
+parser.add_argument("--vis_folder", type=str, default=None)
+args = parser.parse_args()
 
 
-if __name__=='__main__':
+if __name__ == "__main__":
     if args.vis_folder is not None and not os.path.exists(args.vis_folder):
         os.mkdir(args.vis_folder)
 
-    pair_num_list=np.loadtxt(os.path.join(args.meta_dir,'pair_num.txt'),dtype=str)
-    pair_seq_list,accu_pair_list=train_utils.parse_pair_seq(pair_num_list)
-    total_pair=int(pair_num_list[0,1])
-    total_inlier_rate,total_corr_num,total_incorr_num=[],[],[]
-    pair_num_list=pair_num_list[1:]
+    pair_num_list = np.loadtxt(os.path.join(args.meta_dir, "pair_num.txt"), dtype=str)
+    pair_seq_list, accu_pair_list = train_utils.parse_pair_seq(pair_num_list)
+    total_pair = int(pair_num_list[0, 1])
+    total_inlier_rate, total_corr_num, total_incorr_num = [], [], []
+    pair_num_list = pair_num_list[1:]
 
     for index in trange(total_pair):
-        seq=pair_seq_list[index]
-        index_within_seq=index-accu_pair_list[seq]
-        with h5py.File(os.path.join(args.dataset_dir,seq,'info.h5py'),'r') as data:
-            corr=data['corr'][str(index_within_seq)][()]
-            corr1,corr2=corr[:,0],corr[:,1]
-            incorr1,incorr2=data['incorr1'][str(index_within_seq)][()],data['incorr2'][str(index_within_seq)][()]
-            img_path1,img_path2=data['img_path1'][str(index_within_seq)][()][0].decode(),data['img_path2'][str(index_within_seq)][()][0].decode()
-            img_name1,img_name2=img_path1.split('/')[-1],img_path2.split('/')[-1]
-            fea_path1,fea_path2=os.path.join(args.desc_dir,seq,img_name1+args.desc_suffix),os.path.join(args.desc_dir,seq,img_name2+args.desc_suffix)
-            with h5py.File(fea_path1,'r') as fea1, h5py.File(fea_path2,'r') as fea2:
-                desc1,kpt1=fea1['descriptors'][()],fea1['keypoints'][()][:,:2]
-                desc2,kpt2=fea2['descriptors'][()],fea2['keypoints'][()][:,:2]
-            sim_mat=desc1@desc2.T
-            nn_index1,nn_index2=np.argmax(sim_mat,axis=1),np.argmax(sim_mat,axis=0)
-            mask_mutual=(nn_index2[nn_index1]==np.arange(len(nn_index1)))[corr1]
-            mask_inlier=nn_index1[corr1]==corr2
-            mask_nn_correct=np.logical_and(mask_mutual,mask_inlier)
-            #statistics
+        seq = pair_seq_list[index]
+        index_within_seq = index - accu_pair_list[seq]
+        with h5py.File(os.path.join(args.dataset_dir, seq, "info.h5py"), "r") as data:
+            corr = data["corr"][str(index_within_seq)][()]
+            corr1, corr2 = corr[:, 0], corr[:, 1]
+            incorr1, incorr2 = (
+                data["incorr1"][str(index_within_seq)][()],
+                data["incorr2"][str(index_within_seq)][()],
+            )
+            img_path1, img_path2 = (
+                data["img_path1"][str(index_within_seq)][()][0].decode(),
+                data["img_path2"][str(index_within_seq)][()][0].decode(),
+            )
+            img_name1, img_name2 = img_path1.split("/")[-1], img_path2.split("/")[-1]
+            fea_path1, fea_path2 = os.path.join(
+                args.desc_dir, seq, img_name1 + args.desc_suffix
+            ), os.path.join(args.desc_dir, seq, img_name2 + args.desc_suffix)
+            with h5py.File(fea_path1, "r") as fea1, h5py.File(fea_path2, "r") as fea2:
+                desc1, kpt1 = fea1["descriptors"][()], fea1["keypoints"][()][:, :2]
+                desc2, kpt2 = fea2["descriptors"][()], fea2["keypoints"][()][:, :2]
+            sim_mat = desc1 @ desc2.T
+            nn_index1, nn_index2 = np.argmax(sim_mat, axis=1), np.argmax(
+                sim_mat, axis=0
+            )
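+            # a ground-truth correspondence counts as NN-correct if it is a mutual nearest neighbour and maps to the right index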
+            mask_mutual = (nn_index2[nn_index1] == np.arange(len(nn_index1)))[corr1]
+            mask_inlier = nn_index1[corr1] == corr2
+            mask_nn_correct = np.logical_and(mask_mutual, mask_inlier)
+            # statistics
             total_inlier_rate.append(mask_nn_correct.mean())
             total_corr_num.append(len(corr1))
-            total_incorr_num.append((len(incorr1)+len(incorr2))/2)
-            #dump visualization
+            total_incorr_num.append((len(incorr1) + len(incorr2)) / 2)
+            # dump visualization
             if args.vis_folder is not None:
-                #draw corr
-                img1,img2=cv2.imread(os.path.join(args.raw_dir,img_path1)),cv2.imread(os.path.join(args.raw_dir,img_path2))
-                corr1_pos,corr2_pos=np.take_along_axis(kpt1,corr1[:,np.newaxis],axis=0),np.take_along_axis(kpt2,corr2[:,np.newaxis],axis=0)
-                dis_corr=evaluation_utils.draw_match(img1,img2,corr1_pos,corr2_pos)
-                cv2.imwrite(os.path.join(args.vis_folder,str(index)+'.png'),dis_corr)
-                #draw incorr
-                incorr1_pos,incorr2_pos=np.take_along_axis(kpt1,incorr1[:,np.newaxis],axis=0),np.take_along_axis(kpt2,incorr2[:,np.newaxis],axis=0)
-                dis_incorr1,dis_incorr2=evaluation_utils.draw_points(img1,incorr1_pos),evaluation_utils.draw_points(img2,incorr2_pos)
-                cv2.imwrite(os.path.join(args.vis_folder,str(index)+'_incorr1.png'),dis_incorr1)
-                cv2.imwrite(os.path.join(args.vis_folder,str(index)+'_incorr2.png'),dis_incorr2)
+                # draw corr
+                img1, img2 = cv2.imread(
+                    os.path.join(args.raw_dir, img_path1)
+                ), cv2.imread(os.path.join(args.raw_dir, img_path2))
+                corr1_pos, corr2_pos = np.take_along_axis(
+                    kpt1, corr1[:, np.newaxis], axis=0
+                ), np.take_along_axis(kpt2, corr2[:, np.newaxis], axis=0)
+                dis_corr = evaluation_utils.draw_match(img1, img2, corr1_pos, corr2_pos)
+                cv2.imwrite(
+                    os.path.join(args.vis_folder, str(index) + ".png"), dis_corr
+                )
+                # draw incorr
+                incorr1_pos, incorr2_pos = np.take_along_axis(
+                    kpt1, incorr1[:, np.newaxis], axis=0
+                ), np.take_along_axis(kpt2, incorr2[:, np.newaxis], axis=0)
+                dis_incorr1, dis_incorr2 = evaluation_utils.draw_points(
+                    img1, incorr1_pos
+                ), evaluation_utils.draw_points(img2, incorr2_pos)
+                cv2.imwrite(
+                    os.path.join(args.vis_folder, str(index) + "_incorr1.png"),
+                    dis_incorr1,
+                )
+                cv2.imwrite(
+                    os.path.join(args.vis_folder, str(index) + "_incorr2.png"),
+                    dis_incorr2,
+                )
 
-    print('NN matching accuracy: ',np.asarray(total_inlier_rate).mean())
-    print('mean corr number: ',np.asarray(total_corr_num).mean())
-    print('mean incorr number: ',np.asarray(total_incorr_num).mean())
+    print("NN matching accuracy: ", np.asarray(total_inlier_rate).mean())
+    print("mean corr number: ", np.asarray(total_corr_num).mean())
+    print("mean incorr number: ", np.asarray(total_incorr_num).mean())
diff --git a/third_party/SGMNet/datadump/dump.py b/third_party/SGMNet/datadump/dump.py
index c30a695639bda467e6af0a607ebbb69d19fd2b54..8c95f7bb348b8b2e388729df071bb331d6556534 100644
--- a/third_party/SGMNet/datadump/dump.py
+++ b/third_party/SGMNet/datadump/dump.py
@@ -1,27 +1,29 @@
 import argparse
 import yaml
 
+
 def str2bool(v):
     return v.lower() in ("true", "1")
 
 
 # Parse command line arguments.
-parser = argparse.ArgumentParser(description='dump eval data.')
-parser.add_argument('--config_path', type=str, default='configs/yfcc.yaml')
+parser = argparse.ArgumentParser(description="dump eval data.")
+parser.add_argument("--config_path", type=str, default="configs/yfcc.yaml")
+
 
 def get_dumper(name):
-    mod = __import__('dumper.{}'.format(name), fromlist=[''])
+    mod = __import__("dumper.{}".format(name), fromlist=[""])
     return getattr(mod, name)
 
 
-if __name__=='__main__':
-    args=parser.parse_args()    
-    with open(args.config_path, 'r') as f:
+if __name__ == "__main__":
+    args = parser.parse_args()
+    with open(args.config_path, "r") as f:
         config = yaml.load(f)
 
-    dataset=get_dumper(config['data_name'])(config)
+    dataset = get_dumper(config["data_name"])(config)
 
     dataset.initialize()
-    if config['extractor']['extract']:
+    if config["extractor"]["extract"]:
         dataset.dump_feature()
-    dataset.format_dump_data()
\ No newline at end of file
+    dataset.format_dump_data()
diff --git a/third_party/SGMNet/datadump/dumper/base_dumper.py b/third_party/SGMNet/datadump/dumper/base_dumper.py
index 075890de54bdcd60871bad380c71dba6f149b254..039c565d9afcb744d30594f3697d45e8d1f234f9 100644
--- a/third_party/SGMNet/datadump/dumper/base_dumper.py
+++ b/third_party/SGMNet/datadump/dumper/base_dumper.py
@@ -3,25 +3,27 @@ import os
 import h5py
 import numpy as np
 from tqdm import trange
-from torch.multiprocessing import Pool,set_start_method
-set_start_method('spawn',force=True)
+from torch.multiprocessing import Pool, set_start_method
+
+set_start_method("spawn", force=True)
 
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
 sys.path.insert(0, ROOT_DIR)
 from components import load_component
 
 
 class BaseDumper(metaclass=ABCMeta):
-    def __init__(self,config):
-        self.config=config
-        self.img_seq=[]
-        self.dump_seq=[]#feature dump seq
-    
+    def __init__(self, config):
+        self.config = config
+        self.img_seq = []
+        self.dump_seq = []  # feature dump seq
+
     @abstractmethod
     def get_seqs(self):
         raise NotImplementedError
-    
+
     @abstractmethod
     def format_dump_folder(self):
         raise NotImplementedError
@@ -29,70 +31,98 @@ class BaseDumper(metaclass=ABCMeta):
     @abstractmethod
     def format_dump_data(self):
         raise NotImplementedError
-    
+
     def initialize(self):
-        self.extractor=load_component('extractor',self.config['extractor']['name'],self.config['extractor'])
+        self.extractor = load_component(
+            "extractor", self.config["extractor"]["name"], self.config["extractor"]
+        )
         self.get_seqs()
         self.format_dump_folder()
 
-
-    def extract(self,index):
-        img_path,dump_path=self.img_seq[index],self.dump_seq[index]
-        if not self.config['extractor']['overwrite'] and os.path.exists(dump_path):
+    def extract(self, index):
+        img_path, dump_path = self.img_seq[index], self.dump_seq[index]
+        if not self.config["extractor"]["overwrite"] and os.path.exists(dump_path):
             return
         kp, desc = self.extractor.run(img_path)
-        self.write_feature(kp,desc,dump_path)
+        self.write_feature(kp, desc, dump_path)
 
     def dump_feature(self):
-        print('Extrating features...')
-        self.num_img=len(self.dump_seq)
-        pool=Pool(self.config['extractor']['num_process'])
-        iteration_num=self.num_img//self.config['extractor']['num_process']
-        if self.num_img%self.config['extractor']['num_process']!=0:
-            iteration_num+=1
+        print("Extrating features...")
+        self.num_img = len(self.dump_seq)
+        pool = Pool(self.config["extractor"]["num_process"])
+        iteration_num = self.num_img // self.config["extractor"]["num_process"]
+        if self.num_img % self.config["extractor"]["num_process"] != 0:
+            iteration_num += 1
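+        # dispatch one chunk of num_process images to the worker pool per iteration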
         for index in trange(iteration_num):
-            indicies_list=range(index*self.config['extractor']['num_process'],min((index+1)*self.config['extractor']['num_process'],self.num_img))
-            pool.map(self.extract,indicies_list)
+            indices_list = range(
+                index * self.config["extractor"]["num_process"],
+                min(
+                    (index + 1) * self.config["extractor"]["num_process"], self.num_img
+                ),
+            )
+            pool.map(self.extract, indices_list)
         pool.close()
         pool.join()
 
-    def write_feature(self,pts, desc, filename):
+    def write_feature(self, pts, desc, filename):
         with h5py.File(filename, "w") as ifp:
-            ifp.create_dataset('keypoints', pts.shape, dtype=np.float32)
-            ifp.create_dataset('descriptors', desc.shape, dtype=np.float32)
+            ifp.create_dataset("keypoints", pts.shape, dtype=np.float32)
+            ifp.create_dataset("descriptors", desc.shape, dtype=np.float32)
             ifp["keypoints"][:] = pts
             ifp["descriptors"][:] = desc
 
     def form_standard_dataset(self):
-        dataset_path=os.path.join(self.config['dataset_dump_dir'],self.config['data_name']+\
-                                '_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])+'.hdf5')
-        
-        pair_data_type=['K1','K2','R','T','e','f']
-        num_pairs=len(self.data['K1'])
-        with h5py.File(dataset_path, 'w') as f:
-            print('collecting pair info...')
+        dataset_path = os.path.join(
+            self.config["dataset_dump_dir"],
+            self.config["data_name"]
+            + "_"
+            + self.config["extractor"]["name"]
+            + "_"
+            + str(self.config["extractor"]["num_kpt"])
+            + ".hdf5",
+        )
+
+        pair_data_type = ["K1", "K2", "R", "T", "e", "f"]
+        num_pairs = len(self.data["K1"])
+        with h5py.File(dataset_path, "w") as f:
+            print("collecting pair info...")
             for type in pair_data_type:
-                dg=f.create_group(type)
+                dg = f.create_group(type)
                 for idx in range(num_pairs):
-                    data_item=np.asarray(self.data[type][idx])
-                    dg.create_dataset(str(idx),data_item.shape,data_item.dtype,data=data_item)
+                    data_item = np.asarray(self.data[type][idx])
+                    dg.create_dataset(
+                        str(idx), data_item.shape, data_item.dtype, data=data_item
+                    )
 
-            for type in ['img_path1','img_path2']:
-                dg=f.create_group(type)
+            for type in ["img_path1", "img_path2"]:
+                dg = f.create_group(type)
                 for idx in range(num_pairs):
-                    dg.create_dataset(str(idx),[1],h5py.string_dtype(encoding='ascii'),data=self.data[type][idx].encode('ascii'))
+                    dg.create_dataset(
+                        str(idx),
+                        [1],
+                        h5py.string_dtype(encoding="ascii"),
+                        data=self.data[type][idx].encode("ascii"),
+                    )
 
-            #dump desc
-            print('collecting desc and kpt...')
-            desc1_g,desc2_g,kpt1_g,kpt2_g=f.create_group('desc1'),f.create_group('desc2'),f.create_group('kpt1'),f.create_group('kpt2')
+            # dump desc
+            print("collecting desc and kpt...")
+            desc1_g, desc2_g, kpt1_g, kpt2_g = (
+                f.create_group("desc1"),
+                f.create_group("desc2"),
+                f.create_group("kpt1"),
+                f.create_group("kpt2"),
+            )
             for idx in trange(num_pairs):
-                desc_file1,desc_file2=h5py.File(self.data['fea_path1'][idx],'r'),h5py.File(self.data['fea_path2'][idx],'r')
-                desc1,desc2,kpt1,kpt2=desc_file1['descriptors'][()],desc_file2['descriptors'][()],desc_file1['keypoints'][()],desc_file2['keypoints'][()]
-                desc1_g.create_dataset(str(idx),desc1.shape,desc1.dtype,data=desc1)
-                desc2_g.create_dataset(str(idx),desc2.shape,desc2.dtype,data=desc2)
-                kpt1_g.create_dataset(str(idx),kpt1.shape,kpt1.dtype,data=kpt1)
-                kpt2_g.create_dataset(str(idx),kpt2.shape,kpt2.dtype,data=kpt2)
-
-          
-
-        
\ No newline at end of file
+                desc_file1, desc_file2 = h5py.File(
+                    self.data["fea_path1"][idx], "r"
+                ), h5py.File(self.data["fea_path2"][idx], "r")
+                desc1, desc2, kpt1, kpt2 = (
+                    desc_file1["descriptors"][()],
+                    desc_file2["descriptors"][()],
+                    desc_file1["keypoints"][()],
+                    desc_file2["keypoints"][()],
+                )
+                desc1_g.create_dataset(str(idx), desc1.shape, desc1.dtype, data=desc1)
+                desc2_g.create_dataset(str(idx), desc2.shape, desc2.dtype, data=desc2)
+                kpt1_g.create_dataset(str(idx), kpt1.shape, kpt1.dtype, data=kpt1)
+                kpt2_g.create_dataset(str(idx), kpt2.shape, kpt2.dtype, data=kpt2)
diff --git a/third_party/SGMNet/datadump/dumper/fmbench.py b/third_party/SGMNet/datadump/dumper/fmbench.py
index 1d80384595704360c78b9c05b8a4cb3ace9d0406..4e64fecc76c3a261dbb2762b049998b228703581 100644
--- a/third_party/SGMNet/datadump/dumper/fmbench.py
+++ b/third_party/SGMNet/datadump/dumper/fmbench.py
@@ -8,85 +8,168 @@ from numpy.core.fromnumeric import reshape
 from .base_dumper import BaseDumper
 
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
 sys.path.insert(0, ROOT_DIR)
 import utils
 
+
 class fmbench(BaseDumper):
-    
     def get_seqs(self):
-        data_dir=os.path.join(self.config['rawdata_dir'])
-        self.split_list=[]
-        for seq in self.config['data_seq']:
-            cur_split_list=np.unique(np.loadtxt(os.path.join(data_dir,seq,'pairs_which_dataset.txt'),dtype=str))
+        data_dir = os.path.join(self.config["rawdata_dir"])
+        self.split_list = []
+        for seq in self.config["data_seq"]:
+            cur_split_list = np.unique(
+                np.loadtxt(
+                    os.path.join(data_dir, seq, "pairs_which_dataset.txt"), dtype=str
+                )
+            )
             self.split_list.append(cur_split_list)
             for split in cur_split_list:
-                split_dir=os.path.join(data_dir,seq,split)
-                dump_dir=os.path.join(self.config['feature_dump_dir'],seq,split)
-                cur_img_seq=glob.glob(os.path.join(split_dir,'Images','*.jpg'))
-                cur_dump_seq=[os.path.join(dump_dir,path.split('/')[-1])+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])\
-                             +'.hdf5' for path in cur_img_seq]
-                self.img_seq+=cur_img_seq
-                self.dump_seq+=cur_dump_seq
+                split_dir = os.path.join(data_dir, seq, split)
+                dump_dir = os.path.join(self.config["feature_dump_dir"], seq, split)
+                cur_img_seq = glob.glob(os.path.join(split_dir, "Images", "*.jpg"))
+                cur_dump_seq = [
+                    os.path.join(dump_dir, path.split("/")[-1])
+                    + "_"
+                    + self.config["extractor"]["name"]
+                    + "_"
+                    + str(self.config["extractor"]["num_kpt"])
+                    + ".hdf5"
+                    for path in cur_img_seq
+                ]
+                self.img_seq += cur_img_seq
+                self.dump_seq += cur_dump_seq
 
     def format_dump_folder(self):
-        if not os.path.exists(self.config['feature_dump_dir']):
-            os.mkdir(self.config['feature_dump_dir'])
-        for seq_index in range(len(self.config['data_seq'])):
-            seq_dir=os.path.join(self.config['feature_dump_dir'],self.config['data_seq'][seq_index])
+        if not os.path.exists(self.config["feature_dump_dir"]):
+            os.mkdir(self.config["feature_dump_dir"])
+        for seq_index in range(len(self.config["data_seq"])):
+            seq_dir = os.path.join(
+                self.config["feature_dump_dir"], self.config["data_seq"][seq_index]
+            )
             if not os.path.exists(seq_dir):
                 os.mkdir(seq_dir)
             for split in self.split_list[seq_index]:
-                split_dir=os.path.join(seq_dir,split)
+                split_dir = os.path.join(seq_dir, split)
                 if not os.path.exists(split_dir):
                     os.mkdir(split_dir)
 
     def format_dump_data(self):
-        print('Formatting data...')
-        self.data={'K1':[],'K2':[],'R':[],'T':[],'e':[],'f':[],'fea_path1':[],'fea_path2':[],'img_path1':[],'img_path2':[]}
+        print("Formatting data...")
+        self.data = {
+            "K1": [],
+            "K2": [],
+            "R": [],
+            "T": [],
+            "e": [],
+            "f": [],
+            "fea_path1": [],
+            "fea_path2": [],
+            "img_path1": [],
+            "img_path2": [],
+        }
 
-        for seq_index in range(len(self.config['data_seq'])):
-            seq=self.config['data_seq'][seq_index]
+        for seq_index in range(len(self.config["data_seq"])):
+            seq = self.config["data_seq"][seq_index]
             print(seq)
-            pair_list=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'pairs_with_gt.txt'),dtype=float)
-            which_split_list=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'pairs_which_dataset.txt'),dtype=str)
+            pair_list = np.loadtxt(
+                os.path.join(self.config["rawdata_dir"], seq, "pairs_with_gt.txt"),
+                dtype=float,
+            )
+            which_split_list = np.loadtxt(
+                os.path.join(
+                    self.config["rawdata_dir"], seq, "pairs_which_dataset.txt"
+                ),
+                dtype=str,
+            )
 
             for pair_index in trange(len(pair_list)):
-                cur_pair=pair_list[pair_index]
-                cur_split=which_split_list[pair_index]
-                index1,index2=int(cur_pair[0]),int(cur_pair[1])
-                #get intrinsic
-                camera=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,cur_split,'Camera.txt'),dtype=float)
-                K1,K2=camera[index1].reshape([3,3]),camera[index2].reshape([3,3])
-                #get pose
-                pose=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,cur_split,'Poses.txt'),dtype=float)
-                pose1,pose2=pose[index1].reshape([3,4]),pose[index2].reshape([3,4])
-                R1,R2,t1,t2=pose1[:3,:3],pose2[:3,:3],pose1[:3,3][:,np.newaxis],pose2[:3,3][:,np.newaxis]
+                cur_pair = pair_list[pair_index]
+                cur_split = which_split_list[pair_index]
+                index1, index2 = int(cur_pair[0]), int(cur_pair[1])
+                # get intrinsic
+                camera = np.loadtxt(
+                    os.path.join(
+                        self.config["rawdata_dir"], seq, cur_split, "Camera.txt"
+                    ),
+                    dtype=float,
+                )
+                K1, K2 = camera[index1].reshape([3, 3]), camera[index2].reshape([3, 3])
+                # get pose
+                pose = np.loadtxt(
+                    os.path.join(
+                        self.config["rawdata_dir"], seq, cur_split, "Poses.txt"
+                    ),
+                    dtype=float,
+                )
+                pose1, pose2 = pose[index1].reshape([3, 4]), pose[index2].reshape(
+                    [3, 4]
+                )
+                R1, R2, t1, t2 = (
+                    pose1[:3, :3],
+                    pose2[:3, :3],
+                    pose1[:3, 3][:, np.newaxis],
+                    pose2[:3, 3][:, np.newaxis],
+                )
                 dR = np.dot(R2, R1.T)
                 dt = t2 - np.dot(dR, t1)
                 dt /= np.sqrt(np.sum(dt**2))
-                
-                e_gt_unnorm = np.reshape(np.matmul(
-                np.reshape(utils.evaluation_utils.np_skew_symmetric(dt.astype('float64').reshape(1, 3)), (3, 3)),
-                np.reshape(dR.astype('float64'), (3, 3))), (3, 3))
+
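+                # ground-truth essential matrix E = [t]_x R, normalized below to unit Frobenius norm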
+                e_gt_unnorm = np.reshape(
+                    np.matmul(
+                        np.reshape(
+                            utils.evaluation_utils.np_skew_symmetric(
+                                dt.astype("float64").reshape(1, 3)
+                            ),
+                            (3, 3),
+                        ),
+                        np.reshape(dR.astype("float64"), (3, 3)),
+                    ),
+                    (3, 3),
+                )
                 e_gt = e_gt_unnorm / np.linalg.norm(e_gt_unnorm)
 
-                f=cur_pair[2:].reshape([3,3])
-                f_gt=f / np.linalg.norm(f)
+                f = cur_pair[2:].reshape([3, 3])
+                f_gt = f / np.linalg.norm(f)
+
+                self.data["K1"].append(K1), self.data["K2"].append(K2)
+                self.data["R"].append(dR), self.data["T"].append(dt)
+                self.data["e"].append(e_gt), self.data["f"].append(f_gt)
+
+                img_path1, img_path2 = os.path.join(
+                    seq, cur_split, "Images", str(index1).zfill(8) + ".jpg"
+                ), os.path.join(seq, cur_split, "Images", str(index2).zfill(8) + ".jpg")
+
+                fea_path1, fea_path2 = os.path.join(
+                    self.config["feature_dump_dir"],
+                    seq,
+                    cur_split,
+                    str(index1).zfill(8)
+                    + ".jpg"
+                    + "_"
+                    + self.config["extractor"]["name"]
+                    + "_"
+                    + str(self.config["extractor"]["num_kpt"])
+                    + ".hdf5",
+                ), os.path.join(
+                    self.config["feature_dump_dir"],
+                    seq,
+                    cur_split,
+                    str(index2).zfill(8)
+                    + ".jpg"
+                    + "_"
+                    + self.config["extractor"]["name"]
+                    + "_"
+                    + str(self.config["extractor"]["num_kpt"])
+                    + ".hdf5",
+                )
 
-                self.data['K1'].append(K1),self.data['K2'].append(K2)
-                self.data['R'].append(dR),self.data['T'].append(dt)
-                self.data['e'].append(e_gt),self.data['f'].append(f_gt)
+                self.data["img_path1"].append(img_path1), self.data["img_path2"].append(
+                    img_path2
+                )
+                self.data["fea_path1"].append(fea_path1), self.data["fea_path2"].append(
+                    fea_path2
+                )
 
-                img_path1,img_path2=os.path.join(seq,cur_split,'Images',str(index1).zfill(8)+'.jpg'),\
-                                    os.path.join(seq,cur_split,'Images',str(index1).zfill(8)+'.jpg')
-                
-                fea_path1,fea_path2=os.path.join(self.config['feature_dump_dir'],seq,cur_split,str(index1).zfill(8)+'.jpg'+'_'+self.config['extractor']['name']
-                                    +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5'),\
-                                    os.path.join(self.config['feature_dump_dir'],seq,cur_split,str(index2).zfill(8)+'.jpg'+'_'+self.config['extractor']['name']
-                                    +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5')
-                
-                self.data['img_path1'].append(img_path1),self.data['img_path2'].append(img_path2)
-                self.data['fea_path1'].append(fea_path1),self.data['fea_path2'].append(fea_path2)
-            
         self.form_standard_dataset()
diff --git a/third_party/SGMNet/datadump/dumper/gl3d_train.py b/third_party/SGMNet/datadump/dumper/gl3d_train.py
index 42a703442037eeb8174bc17f4cbd14c9768db1d1..babcde0bbf2277d50e991a4210e5855c16e9c05a 100644
--- a/third_party/SGMNet/datadump/dumper/gl3d_train.py
+++ b/third_party/SGMNet/datadump/dumper/gl3d_train.py
@@ -10,110 +10,140 @@ import pyxis as px
 from .base_dumper import BaseDumper
 
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
 sys.path.insert(0, ROOT_DIR)
 
-from utils import transformations,data_utils
+from utils import transformations, data_utils
+
 
 class gl3d_train(BaseDumper):
-    
     def get_seqs(self):
-        data_dir=os.path.join(self.config['rawdata_dir'],'data')
-        seq_train=np.loadtxt(os.path.join(self.config['rawdata_dir'],'list','comb','imageset_train.txt'),dtype=str)
-        seq_valid=np.loadtxt(os.path.join(self.config['rawdata_dir'],'list','comb','imageset_test.txt'),dtype=str)
+        data_dir = os.path.join(self.config["rawdata_dir"], "data")
+        seq_train = np.loadtxt(
+            os.path.join(
+                self.config["rawdata_dir"], "list", "comb", "imageset_train.txt"
+            ),
+            dtype=str,
+        )
+        seq_valid = np.loadtxt(
+            os.path.join(
+                self.config["rawdata_dir"], "list", "comb", "imageset_test.txt"
+            ),
+            dtype=str,
+        )
 
-        #filtering seq list
-        self.seq_list,self.train_list,self.valid_list=[],[],[]
+        # filtering seq list
+        self.seq_list, self.train_list, self.valid_list = [], [], []
         for seq in seq_train:
-            if seq not in self.config['exclude_seq']:
+            if seq not in self.config["exclude_seq"]:
                 self.train_list.append(seq)
         for seq in seq_valid:
-            if seq not in self.config['exclude_seq']:
+            if seq not in self.config["exclude_seq"]:
                 self.valid_list.append(seq)
-        seq_list=[]
-        if self.config['dump_train']:
+        seq_list = []
+        if self.config["dump_train"]:
             seq_list.append(self.train_list)
-        if self.config['dump_valid']:
+        if self.config["dump_valid"]:
             seq_list.append(self.valid_list)
-        self.seq_list=np.concatenate(seq_list,axis=0)
+        self.seq_list = np.concatenate(seq_list, axis=0)
 
-        #self.seq_list=self.seq_list[:2]
-        #self.valid_list=self.valid_list[:2]
+        # self.seq_list=self.seq_list[:2]
+        # self.valid_list=self.valid_list[:2]
         for seq in self.seq_list:
-            dump_dir=os.path.join(self.config['feature_dump_dir'],seq)
-            cur_img_seq=glob.glob(os.path.join(data_dir,seq,'undist_images','*.jpg'))
-            cur_dump_seq=[os.path.join(dump_dir,path.split('/')[-1])+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])\
-                             +'.hdf5' for path in cur_img_seq]
-            self.img_seq+=cur_img_seq
-            self.dump_seq+=cur_dump_seq
-
+            dump_dir = os.path.join(self.config["feature_dump_dir"], seq)
+            cur_img_seq = glob.glob(
+                os.path.join(data_dir, seq, "undist_images", "*.jpg")
+            )
+            cur_dump_seq = [
+                os.path.join(dump_dir, path.split("/")[-1])
+                + "_"
+                + self.config["extractor"]["name"]
+                + "_"
+                + str(self.config["extractor"]["num_kpt"])
+                + ".hdf5"
+                for path in cur_img_seq
+            ]
+            self.img_seq += cur_img_seq
+            self.dump_seq += cur_dump_seq
 
     def format_dump_folder(self):
-        if not os.path.exists(self.config['feature_dump_dir']):
-            os.mkdir(self.config['feature_dump_dir'])
+        if not os.path.exists(self.config["feature_dump_dir"]):
+            os.mkdir(self.config["feature_dump_dir"])
         for seq in self.seq_list:
-            seq_dir=os.path.join(self.config['feature_dump_dir'],seq)
+            seq_dir = os.path.join(self.config["feature_dump_dir"], seq)
             if not os.path.exists(seq_dir):
                 os.mkdir(seq_dir)
-        if not os.path.exists(self.config['dataset_dump_dir']):
-            os.mkdir(self.config['dataset_dump_dir'])
-    
+        if not os.path.exists(self.config["dataset_dump_dir"]):
+            os.mkdir(self.config["dataset_dump_dir"])
 
-    def load_geom(self,seq):
+    def load_geom(self, seq):
         # load geometry file
-        geom_file=os.path.join(self.config['rawdata_dir'],'data',seq,'geolabel','cameras.txt')
-        basename_list=np.loadtxt(os.path.join(self.config['rawdata_dir'],'data',seq,'basenames.txt'),dtype=str)
+        geom_file = os.path.join(
+            self.config["rawdata_dir"], "data", seq, "geolabel", "cameras.txt"
+        )
+        basename_list = np.loadtxt(
+            os.path.join(self.config["rawdata_dir"], "data", seq, "basenames.txt"),
+            dtype=str,
+        )
         geom_dict = []
         cameras = np.loadtxt(geom_file)
-        camera_index=0
+        camera_index = 0
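+        # column 0 of cameras.txt stores the image index; images without a camera entry get a None placeholder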
         for base_index in range(len(basename_list)):
-            if base_index<cameras[camera_index][0]:
+            if base_index < cameras[camera_index][0]:
                 geom_dict.append(None)
                 continue
             cur_geom = {}
             ori_img_size = [cameras[camera_index][-2], cameras[camera_index][-1]]
-            scale_factor = [1000. / ori_img_size[0], 1000. / ori_img_size[1]]
-            K = np.asarray([[cameras[camera_index][1], cameras[camera_index][5], cameras[camera_index][3]],
-                            [0, cameras[camera_index][2], cameras[camera_index][4]],
-                            [0, 0, 1]])
+            scale_factor = [1000.0 / ori_img_size[0], 1000.0 / ori_img_size[1]]
+            K = np.asarray(
+                [
+                    [
+                        cameras[camera_index][1],
+                        cameras[camera_index][5],
+                        cameras[camera_index][3],
+                    ],
+                    [0, cameras[camera_index][2], cameras[camera_index][4]],
+                    [0, 0, 1],
+                ]
+            )
             # Rescale calibration according to previous resizing
-            S = np.asarray([[scale_factor[0], 0, 0],
-                            [0, scale_factor[1], 0],
-                            [0, 0, 1]])
+            S = np.asarray(
+                [[scale_factor[0], 0, 0], [0, scale_factor[1], 0], [0, 0, 1]]
+            )
             K = np.dot(S, K)
             cur_geom["K"] = K
-            cur_geom['R'] = cameras[camera_index][9:18].reshape([3, 3])
-            cur_geom['T'] = cameras[camera_index][6:9]
-            cur_geom['size']=np.asarray([1000,1000])
+            cur_geom["R"] = cameras[camera_index][9:18].reshape([3, 3])
+            cur_geom["T"] = cameras[camera_index][6:9]
+            cur_geom["size"] = np.asarray([1000, 1000])
             geom_dict.append(cur_geom)
-            camera_index+=1
+            camera_index += 1
         return geom_dict
 
-
-    def load_depth(self,file_path):
-        with open(os.path.join(file_path), 'rb') as fin:
+    def load_depth(self, file_path):
+        with open(os.path.join(file_path), "rb") as fin:
             color = None
             width = None
             height = None
             scale = None
             data_type = None
-            header = str(fin.readline().decode('UTF-8')).rstrip()
-            if header == 'PF':
+            header = str(fin.readline().decode("UTF-8")).rstrip()
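+            # PFM magic number: "PF" marks a 3-channel float image, "Pf" a single-channel one;
+            # the sign of the scale field read below encodes the endianness of the raw data.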
+            if header == "PF":
                 color = True
-            elif header == 'Pf':
+            elif header == "Pf":
                 color = False
             else:
-                raise Exception('Not a PFM file.')
-            dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8'))
+                raise Exception("Not a PFM file.")
+            dim_match = re.match(r"^(\d+)\s(\d+)\s$", fin.readline().decode("UTF-8"))
             if dim_match:
                 width, height = map(int, dim_match.groups())
             else:
-                raise Exception('Malformed PFM header.')
-            scale = float((fin.readline().decode('UTF-8')).rstrip())
+                raise Exception("Malformed PFM header.")
+            scale = float((fin.readline().decode("UTF-8")).rstrip())
             if scale < 0:  # little-endian
-                data_type = '<f'
+                data_type = "<f"
             else:
-                data_type = '>f'  # big-endian
+                data_type = ">f"  # big-endian
             data_string = fin.read()
             data = np.frombuffer(data_string, dtype=data_type)  # fromstring is deprecated for binary input
             shape = (height, width, 3) if color else (height, width)
@@ -121,128 +151,251 @@ class gl3d_train(BaseDumper):
             data = np.flip(data, 0)
         return data
 
-
-    def dump_info(self,seq,info):
-        pair_type=['dR','dt','K1','K2','size1','size2','corr','incorr1','incorr2']
-        num_pairs=len(info['dR'])
-        os.mkdir(os.path.join(self.config['dataset_dump_dir'],seq))
-        with h5py.File(os.path.join(self.config['dataset_dump_dir'],seq,'info.h5py'), 'w') as f:
+    def dump_info(self, seq, info):
+        pair_type = [
+            "dR",
+            "dt",
+            "K1",
+            "K2",
+            "size1",
+            "size2",
+            "corr",
+            "incorr1",
+            "incorr2",
+        ]
+        num_pairs = len(info["dR"])
+        os.mkdir(os.path.join(self.config["dataset_dump_dir"], seq))
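+        # info.h5py layout: one group per field, one dataset per pair (named by its index).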
+        with h5py.File(
+            os.path.join(self.config["dataset_dump_dir"], seq, "info.h5py"), "w"
+        ) as f:
             for type in pair_type:
-                dg=f.create_group(type)
+                dg = f.create_group(type)
                 for idx in range(num_pairs):
-                    data_item=np.asarray(info[type][idx])
-                    dg.create_dataset(str(idx),data_item.shape,data_item.dtype,data=data_item)
-            for type in ['img_path1','img_path2']:
-                dg=f.create_group(type)
+                    data_item = np.asarray(info[type][idx])
+                    dg.create_dataset(
+                        str(idx), data_item.shape, data_item.dtype, data=data_item
+                    )
+            for type in ["img_path1", "img_path2"]:
+                dg = f.create_group(type)
                 for idx in range(num_pairs):
-                    dg.create_dataset(str(idx),[1],h5py.string_dtype(encoding='ascii'),data=info[type][idx].encode('ascii'))
-        
-        with open(os.path.join(self.config['dataset_dump_dir'],seq,'pair_num.txt'), 'w') as f:
-            f.write(str(info['pair_num']))
-
-    def format_seq(self,index):
-        seq=self.seq_list[index]
-        seq_dir=os.path.join(os.path.join(self.config['rawdata_dir'],'data',seq))
-        basename_list=np.loadtxt(os.path.join(seq_dir,'basenames.txt'),dtype=str)
-        pair_list=np.loadtxt(os.path.join(seq_dir,'geolabel','common_track.txt'),dtype=float)[:,:2].astype(int)
-        overlap_score=np.loadtxt(os.path.join(seq_dir,'geolabel','common_track.txt'),dtype=float)[:,2]
-        geom_dict=self.load_geom(seq)
-
-        #check info existance
-        if os.path.exists(os.path.join(self.config['dataset_dump_dir'],seq,'pair_num.txt')):
+                    dg.create_dataset(
+                        str(idx),
+                        [1],
+                        h5py.string_dtype(encoding="ascii"),
+                        data=info[type][idx].encode("ascii"),
+                    )
+
+        with open(
+            os.path.join(self.config["dataset_dump_dir"], seq, "pair_num.txt"), "w"
+        ) as f:
+            f.write(str(info["pair_num"]))
+
+    def format_seq(self, index):
+        seq = self.seq_list[index]
+        seq_dir = os.path.join(os.path.join(self.config["rawdata_dir"], "data", seq))
+        basename_list = np.loadtxt(os.path.join(seq_dir, "basenames.txt"), dtype=str)
+        pair_list = np.loadtxt(
+            os.path.join(seq_dir, "geolabel", "common_track.txt"), dtype=float
+        )[:, :2].astype(int)
+        overlap_score = np.loadtxt(
+            os.path.join(seq_dir, "geolabel", "common_track.txt"), dtype=float
+        )[:, 2]
+        geom_dict = self.load_geom(seq)
+
+        # check whether this sequence's info has already been dumped; skip if so
+        if os.path.exists(
+            os.path.join(self.config["dataset_dump_dir"], seq, "pair_num.txt")
+        ):
             return
 
-        angle_list=[]
-        #filtering pairs
+        angle_list = []
+        # filtering pairs
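+        # Keep only pairs whose relative rotation angle (2 * acos(q0), in degrees) and
+        # overlap score both fall inside the configured angle_th / overlap_th ranges.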
         for cur_pair in pair_list:
-            pair_index1,pair_index2=cur_pair[0],cur_pair[1]
-            geo1,geo2=geom_dict[pair_index1],geom_dict[pair_index2]
-            dR = np.dot(geo2['R'], geo1['R'].T)
+            pair_index1, pair_index2 = cur_pair[0], cur_pair[1]
+            geo1, geo2 = geom_dict[pair_index1], geom_dict[pair_index2]
+            dR = np.dot(geo2["R"], geo1["R"].T)
             q = transformations.quaternion_from_matrix(dR)
             angle_list.append(math.acos(q[0]) * 2 * 180 / math.pi)
-        angle_list=np.asarray(angle_list)
-        mask_survive=np.logical_and(
-                            np.logical_and(angle_list>self.config['angle_th'][0],angle_list<self.config['angle_th'][1]),
-                            np.logical_and(overlap_score>self.config['overlap_th'][0],overlap_score<self.config['overlap_th'][1])
-                        )
-        pair_list=pair_list[mask_survive]
-        if len(pair_list)<100:
-            print(seq,len(pair_list))
-        #sample pairs
-        shuffled_pair_list=np.random.permutation(pair_list)
-        sample_target=min(self.config['pairs_per_seq'],len(shuffled_pair_list))
-        sample_number=0
-
-        info={'dR':[],'dt':[],'K1':[],'K2':[],'img_path1':[],'img_path2':[],'fea_path1':[],'fea_path2':[],'size1':[],'size2':[],
-            'corr':[],'incorr1':[],'incorr2':[],'pair_num':[]}
+        angle_list = np.asarray(angle_list)
+        mask_survive = np.logical_and(
+            np.logical_and(
+                angle_list > self.config["angle_th"][0],
+                angle_list < self.config["angle_th"][1],
+            ),
+            np.logical_and(
+                overlap_score > self.config["overlap_th"][0],
+                overlap_score < self.config["overlap_th"][1],
+            ),
+        )
+        pair_list = pair_list[mask_survive]
+        if len(pair_list) < 100:
+            print(seq, len(pair_list))
+        # sample pairs
+        shuffled_pair_list = np.random.permutation(pair_list)
+        sample_target = min(self.config["pairs_per_seq"], len(shuffled_pair_list))
+        sample_number = 0
+
+        info = {
+            "dR": [],
+            "dt": [],
+            "K1": [],
+            "K2": [],
+            "img_path1": [],
+            "img_path2": [],
+            "fea_path1": [],
+            "fea_path2": [],
+            "size1": [],
+            "size2": [],
+            "corr": [],
+            "incorr1": [],
+            "incorr2": [],
+            "pair_num": [],
+        }
         for cur_pair in shuffled_pair_list:
-            pair_index1,pair_index2=cur_pair[0],cur_pair[1]
-            geo1,geo2=geom_dict[pair_index1],geom_dict[pair_index2]
-            dR = np.dot(geo2['R'], geo1['R'].T)
+            pair_index1, pair_index2 = cur_pair[0], cur_pair[1]
+            geo1, geo2 = geom_dict[pair_index1], geom_dict[pair_index2]
+            dR = np.dot(geo2["R"], geo1["R"].T)
             t1, t2 = geo1["T"].reshape([3, 1]), geo2["T"].reshape([3, 1])
             dt = t2 - np.dot(dR, t1)
-            K1,K2=geo1['K'],geo2['K']
-            size1,size2=geo1['size'],geo2['size']
-
-            basename1,basename2=basename_list[pair_index1],basename_list[pair_index2]
-            img_path1,img_path2=os.path.join(seq,'undist_images',basename1+'.jpg'),os.path.join(seq,'undist_images',basename2+'.jpg')
-            fea_path1,fea_path2=os.path.join(seq,basename1+'.jpg'+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])+'.hdf5'),\
-                                os.path.join(seq,basename2+'.jpg'+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])+'.hdf5')
-
-            with h5py.File(os.path.join(self.config['feature_dump_dir'],fea_path1),'r') as fea1, \
-                h5py.File(os.path.join(self.config['feature_dump_dir'],fea_path2),'r') as fea2:
-                desc1,desc2=fea1['descriptors'][()],fea2['descriptors'][()]
-                kpt1,kpt2=fea1['keypoints'][()],fea2['keypoints'][()]
-                depth_path1,depth_path2=os.path.join(self.config['rawdata_dir'],'data',seq,'depths',basename1+'.pfm'),\
-                                        os.path.join(self.config['rawdata_dir'],'data',seq,'depths',basename2+'.pfm')
-                depth1,depth2=self.load_depth(depth_path1),self.load_depth(depth_path2)
-                corr_index,incorr_index1,incorr_index2=data_utils.make_corr(kpt1[:,:2],kpt2[:,:2],desc1,desc2,depth1,depth2,K1,K2,dR,dt,size1,size2,
-                                                                            self.config['corr_th'],self.config['incorr_th'],self.config['check_desc'])
-            
-            if len(corr_index)>self.config['min_corr'] and len(incorr_index1)>self.config['min_incorr'] and len(incorr_index2)>self.config['min_incorr']:
-                info['corr'].append(corr_index),info['incorr1'].append(incorr_index1),info['incorr2'].append(incorr_index2)
-                info['dR'].append(dR),info['dt'].append(dt),info['K1'].append(K1),info['K2'].append(K2),info['img_path1'].append(img_path1),info['img_path2'].append(img_path2)
-                info['fea_path1'].append(fea_path1),info['fea_path2'].append(fea_path2),info['size1'].append(size1),info['size2'].append(size2)
-                sample_number+=1
-            if sample_number==sample_target:
+            K1, K2 = geo1["K"], geo2["K"]
+            size1, size2 = geo1["size"], geo2["size"]
+
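+            # Dumped feature files are named "<image name>_<extractor name>_<num_kpt>.hdf5"
+            # inside the per-sequence folder under feature_dump_dir.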
+            basename1, basename2 = (
+                basename_list[pair_index1],
+                basename_list[pair_index2],
+            )
+            img_path1, img_path2 = os.path.join(
+                seq, "undist_images", basename1 + ".jpg"
+            ), os.path.join(seq, "undist_images", basename2 + ".jpg")
+            fea_path1, fea_path2 = os.path.join(
+                seq,
+                basename1
+                + ".jpg"
+                + "_"
+                + self.config["extractor"]["name"]
+                + "_"
+                + str(self.config["extractor"]["num_kpt"])
+                + ".hdf5",
+            ), os.path.join(
+                seq,
+                basename2
+                + ".jpg"
+                + "_"
+                + self.config["extractor"]["name"]
+                + "_"
+                + str(self.config["extractor"]["num_kpt"])
+                + ".hdf5",
+            )
+
+            with h5py.File(
+                os.path.join(self.config["feature_dump_dir"], fea_path1), "r"
+            ) as fea1, h5py.File(
+                os.path.join(self.config["feature_dump_dir"], fea_path2), "r"
+            ) as fea2:
+                desc1, desc2 = fea1["descriptors"][()], fea2["descriptors"][()]
+                kpt1, kpt2 = fea1["keypoints"][()], fea2["keypoints"][()]
+                depth_path1, depth_path2 = os.path.join(
+                    self.config["rawdata_dir"],
+                    "data",
+                    seq,
+                    "depths",
+                    basename1 + ".pfm",
+                ), os.path.join(
+                    self.config["rawdata_dir"],
+                    "data",
+                    seq,
+                    "depths",
+                    basename2 + ".pfm",
+                )
+                depth1, depth2 = self.load_depth(depth_path1), self.load_depth(
+                    depth_path2
+                )
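+                # make_corr returns indices of corresponding keypoint pairs plus, per image,
+                # indices of keypoints without a valid counterpart, derived from depth,
+                # intrinsics and relative pose under the corr/incorr thresholds.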
+                corr_index, incorr_index1, incorr_index2 = data_utils.make_corr(
+                    kpt1[:, :2],
+                    kpt2[:, :2],
+                    desc1,
+                    desc2,
+                    depth1,
+                    depth2,
+                    K1,
+                    K2,
+                    dR,
+                    dt,
+                    size1,
+                    size2,
+                    self.config["corr_th"],
+                    self.config["incorr_th"],
+                    self.config["check_desc"],
+                )
+
+            if (
+                len(corr_index) > self.config["min_corr"]
+                and len(incorr_index1) > self.config["min_incorr"]
+                and len(incorr_index2) > self.config["min_incorr"]
+            ):
+                info["corr"].append(corr_index), info["incorr1"].append(
+                    incorr_index1
+                ), info["incorr2"].append(incorr_index2)
+                info["dR"].append(dR), info["dt"].append(dt), info["K1"].append(
+                    K1
+                ), info["K2"].append(K2), info["img_path1"].append(img_path1), info[
+                    "img_path2"
+                ].append(
+                    img_path2
+                )
+                info["fea_path1"].append(fea_path1), info["fea_path2"].append(
+                    fea_path2
+                ), info["size1"].append(size1), info["size2"].append(size2)
+                sample_number += 1
+            if sample_number == sample_target:
                 break
-        info['pair_num']=sample_number
-        #dump info
-        self.dump_info(seq,info)
+        info["pair_num"] = sample_number
+        # dump info
+        self.dump_info(seq, info)
 
-  
     def collect_meta(self):
-        print('collecting meta info...')
-        dump_path,seq_list=[],[]
-        if self.config['dump_train']:
-            dump_path.append(os.path.join(self.config['dataset_dump_dir'],'train'))
+        print("collecting meta info...")
+        dump_path, seq_list = [], []
+        if self.config["dump_train"]:
+            dump_path.append(os.path.join(self.config["dataset_dump_dir"], "train"))
             seq_list.append(self.train_list)
-        if self.config['dump_valid']:
-            dump_path.append(os.path.join(self.config['dataset_dump_dir'],'valid'))
+        if self.config["dump_valid"]:
+            dump_path.append(os.path.join(self.config["dataset_dump_dir"], "valid"))
             seq_list.append(self.valid_list)
-        for pth,seqs in zip(dump_path,seq_list):
+        for pth, seqs in zip(dump_path, seq_list):
             if not os.path.exists(pth):
                 os.mkdir(pth)
-            pair_num_list,total_pair=[],0
-            for seq_index in range(len(seqs)):    
-                seq=seqs[seq_index]
-                pair_num=np.loadtxt(os.path.join(self.config['dataset_dump_dir'],seq,'pair_num.txt'),dtype=int)
+            pair_num_list, total_pair = [], 0
+            for seq_index in range(len(seqs)):
+                seq = seqs[seq_index]
+                pair_num = np.loadtxt(
+                    os.path.join(self.config["dataset_dump_dir"], seq, "pair_num.txt"),
+                    dtype=int,
+                )
                 pair_num_list.append(str(pair_num))
-                total_pair+=pair_num
-            pair_num_list=np.stack([np.asarray(seqs,dtype=str),np.asarray(pair_num_list,dtype=str)],axis=1)
-            pair_num_list=np.concatenate([np.asarray([['total',str(total_pair)]]),pair_num_list],axis=0)
-            np.savetxt(os.path.join(pth,'pair_num.txt'),pair_num_list,fmt='%s')
-            
+                total_pair += pair_num
+            pair_num_list = np.stack(
+                [np.asarray(seqs, dtype=str), np.asarray(pair_num_list, dtype=str)],
+                axis=1,
+            )
+            pair_num_list = np.concatenate(
+                [np.asarray([["total", str(total_pair)]]), pair_num_list], axis=0
+            )
+            np.savetxt(os.path.join(pth, "pair_num.txt"), pair_num_list, fmt="%s")
+
     def format_dump_data(self):
-        print('Formatting data...')
-        iteration_num=len(self.seq_list)//self.config['num_process']
-        if len(self.seq_list)%self.config['num_process']!=0:
-            iteration_num+=1
-        pool=Pool(self.config['num_process'])
+        print("Formatting data...")
+        iteration_num = len(self.seq_list) // self.config["num_process"]
+        if len(self.seq_list) % self.config["num_process"] != 0:
+            iteration_num += 1
+        pool = Pool(self.config["num_process"])
         for index in trange(iteration_num):
-            indices=range(index*self.config['num_process'],min((index+1)*self.config['num_process'],len(self.seq_list)))
-            pool.map(self.format_seq,indices)
+            indices = range(
+                index * self.config["num_process"],
+                min((index + 1) * self.config["num_process"], len(self.seq_list)),
+            )
+            pool.map(self.format_seq, indices)
         pool.close()
         pool.join()
 
-        self.collect_meta()
\ No newline at end of file
+        self.collect_meta()
diff --git a/third_party/SGMNet/datadump/dumper/scannet.py b/third_party/SGMNet/datadump/dumper/scannet.py
index 2556f727fcc9b4c621e44d9ee5cb4e99cb19b7e8..ac45f41e3530fea49191188146187bcef7bd514d 100644
--- a/third_party/SGMNet/datadump/dumper/scannet.py
+++ b/third_party/SGMNet/datadump/dumper/scannet.py
@@ -7,66 +7,137 @@ import h5py
 from .base_dumper import BaseDumper
 
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
 sys.path.insert(0, ROOT_DIR)
 import utils
 
+
 class scannet(BaseDumper):
     def get_seqs(self):
-        self.pair_list=np.loadtxt('../assets/scannet_eval_list.txt',dtype=str)
-        self.seq_list=np.unique(np.asarray([path.split('/')[0] for path in self.pair_list[:,0]],dtype=str))
-        self.dump_seq,self.img_seq=[],[]
+        self.pair_list = np.loadtxt("../assets/scannet_eval_list.txt", dtype=str)
+        self.seq_list = np.unique(
+            np.asarray([path.split("/")[0] for path in self.pair_list[:, 0]], dtype=str)
+        )
+        self.dump_seq, self.img_seq = [], []
         for seq in self.seq_list:
-            dump_dir=os.path.join(self.config['feature_dump_dir'],seq)
-            cur_img_seq=glob.glob(os.path.join(os.path.join(self.config['rawdata_dir'],seq,'img','*.jpg')))
-            cur_dump_seq=[os.path.join(dump_dir,path.split('/')[-1])+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])\
-                            +'.hdf5' for path in cur_img_seq]
-            self.img_seq+=cur_img_seq
-            self.dump_seq+=cur_dump_seq
+            dump_dir = os.path.join(self.config["feature_dump_dir"], seq)
+            cur_img_seq = glob.glob(
+                os.path.join(
+                    os.path.join(self.config["rawdata_dir"], seq, "img", "*.jpg")
+                )
+            )
+            cur_dump_seq = [
+                os.path.join(dump_dir, path.split("/")[-1])
+                + "_"
+                + self.config["extractor"]["name"]
+                + "_"
+                + str(self.config["extractor"]["num_kpt"])
+                + ".hdf5"
+                for path in cur_img_seq
+            ]
+            self.img_seq += cur_img_seq
+            self.dump_seq += cur_dump_seq
 
     def format_dump_folder(self):
-        if not os.path.exists(self.config['feature_dump_dir']):
-            os.mkdir(self.config['feature_dump_dir'])
+        if not os.path.exists(self.config["feature_dump_dir"]):
+            os.mkdir(self.config["feature_dump_dir"])
         for seq in self.seq_list:
-            seq_dir=os.path.join(self.config['feature_dump_dir'],seq)
+            seq_dir = os.path.join(self.config["feature_dump_dir"], seq)
             if not os.path.exists(seq_dir):
                 os.mkdir(seq_dir)
 
     def format_dump_data(self):
-        print('Formatting data...')
-        self.data={'K1':[],'K2':[],'R':[],'T':[],'e':[],'f':[],'fea_path1':[],'fea_path2':[],'img_path1':[],'img_path2':[]}
+        print("Formatting data...")
+        self.data = {
+            "K1": [],
+            "K2": [],
+            "R": [],
+            "T": [],
+            "e": [],
+            "f": [],
+            "fea_path1": [],
+            "fea_path2": [],
+            "img_path1": [],
+            "img_path2": [],
+        }
 
         for pair in self.pair_list:
-            img_path1,img_path2=pair[0],pair[1]
-            seq=img_path1.split('/')[0]
-            index1,index2=int(img_path1.split('/')[-1][:-4]),int(img_path2.split('/')[-1][:-4])
-            ex1,ex2=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'extrinsic',str(index1)+'.txt'),dtype=float),\
-                    np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'extrinsic',str(index2)+'.txt'),dtype=float)
-            K1,K2=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'intrinsic',str(index1)+'.txt'),dtype=float),\
-                  np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'intrinsic',str(index2)+'.txt'),dtype=float)
-        
+            img_path1, img_path2 = pair[0], pair[1]
+            seq = img_path1.split("/")[0]
+            index1, index2 = int(img_path1.split("/")[-1][:-4]), int(
+                img_path2.split("/")[-1][:-4]
+            )
+            ex1, ex2 = np.loadtxt(
+                os.path.join(
+                    self.config["rawdata_dir"], seq, "extrinsic", str(index1) + ".txt"
+                ),
+                dtype=float,
+            ), np.loadtxt(
+                os.path.join(
+                    self.config["rawdata_dir"], seq, "extrinsic", str(index2) + ".txt"
+                ),
+                dtype=float,
+            )
+            K1, K2 = np.loadtxt(
+                os.path.join(
+                    self.config["rawdata_dir"], seq, "intrinsic", str(index1) + ".txt"
+                ),
+                dtype=float,
+            ), np.loadtxt(
+                os.path.join(
+                    self.config["rawdata_dir"], seq, "intrinsic", str(index2) + ".txt"
+                ),
+                dtype=float,
+            )
 
-            relative_extrinsic=np.matmul(np.linalg.inv(ex2),ex1)
-            dR,dt=relative_extrinsic[:3,:3],relative_extrinsic[:3,3]
+            relative_extrinsic = np.matmul(np.linalg.inv(ex2), ex1)
+            dR, dt = relative_extrinsic[:3, :3], relative_extrinsic[:3, 3]
             dt /= np.sqrt(np.sum(dt**2))
-            
-            e_gt_unnorm = np.reshape(np.matmul(
-            np.reshape(utils.evaluation_utils.np_skew_symmetric(dt.astype('float64').reshape(1, 3)), (3, 3)),
-            np.reshape(dR.astype('float64'), (3, 3))), (3, 3))
+
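+            # Ground-truth epipolar geometry: E = [t]_x R (normalized below), F = K2^-T E K1^-1.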
+            e_gt_unnorm = np.reshape(
+                np.matmul(
+                    np.reshape(
+                        utils.evaluation_utils.np_skew_symmetric(
+                            dt.astype("float64").reshape(1, 3)
+                        ),
+                        (3, 3),
+                    ),
+                    np.reshape(dR.astype("float64"), (3, 3)),
+                ),
+                (3, 3),
+            )
             e_gt = e_gt_unnorm / np.linalg.norm(e_gt_unnorm)
-            f_gt_unnorm=np.linalg.inv(K2.T)@e_gt@np.linalg.inv(K1)
+            f_gt_unnorm = np.linalg.inv(K2.T) @ e_gt @ np.linalg.inv(K1)
             f_gt = f_gt_unnorm / np.linalg.norm(f_gt_unnorm)
 
-            self.data['K1'].append(K1),self.data['K2'].append(K2)
-            self.data['R'].append(dR),self.data['T'].append(dt)
-            self.data['e'].append(e_gt),self.data['f'].append(f_gt)
-            
-            dump_seq_dir=os.path.join(self.config['feature_dump_dir'],seq)
-            fea_path1,fea_path2=os.path.join(dump_seq_dir,img_path1.split('/')[-1]+'_'+self.config['extractor']['name']
-                                +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5'),\
-                                os.path.join(dump_seq_dir,img_path2.split('/')[-1]+'_'+self.config['extractor']['name']
-                                +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5')
-            self.data['img_path1'].append(img_path1),self.data['img_path2'].append(img_path2)
-            self.data['fea_path1'].append(fea_path1),self.data['fea_path2'].append(fea_path2)
+            self.data["K1"].append(K1), self.data["K2"].append(K2)
+            self.data["R"].append(dR), self.data["T"].append(dt)
+            self.data["e"].append(e_gt), self.data["f"].append(f_gt)
+
+            dump_seq_dir = os.path.join(self.config["feature_dump_dir"], seq)
+            fea_path1, fea_path2 = os.path.join(
+                dump_seq_dir,
+                img_path1.split("/")[-1]
+                + "_"
+                + self.config["extractor"]["name"]
+                + "_"
+                + str(self.config["extractor"]["num_kpt"])
+                + ".hdf5",
+            ), os.path.join(
+                dump_seq_dir,
+                img_path2.split("/")[-1]
+                + "_"
+                + self.config["extractor"]["name"]
+                + "_"
+                + str(self.config["extractor"]["num_kpt"])
+                + ".hdf5",
+            )
+            self.data["img_path1"].append(img_path1), self.data["img_path2"].append(
+                img_path2
+            )
+            self.data["fea_path1"].append(fea_path1), self.data["fea_path2"].append(
+                fea_path2
+            )
 
         self.form_standard_dataset()
diff --git a/third_party/SGMNet/datadump/dumper/yfcc.py b/third_party/SGMNet/datadump/dumper/yfcc.py
index 0c52e4324bba3e5ed424fe58af7a94fd3132b1e5..be1efe71775aef04a6e720751d637a093e28c06a 100644
--- a/third_party/SGMNet/datadump/dumper/yfcc.py
+++ b/third_party/SGMNet/datadump/dumper/yfcc.py
@@ -6,82 +6,145 @@ import h5py
 from .base_dumper import BaseDumper
 
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
 sys.path.insert(0, ROOT_DIR)
 import utils
 
+
 class yfcc(BaseDumper):
-    
     def get_seqs(self):
-        data_dir=os.path.join(self.config['rawdata_dir'],'yfcc100m')
-        for seq in self.config['data_seq']:
-            for split in self.config['data_split']:
-                split_dir=os.path.join(data_dir,seq,split)
-                dump_dir=os.path.join(self.config['feature_dump_dir'],seq,split)
-                cur_img_seq=glob.glob(os.path.join(split_dir,'images','*.jpg'))
-                cur_dump_seq=[os.path.join(dump_dir,path.split('/')[-1])+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])\
-                             +'.hdf5' for path in cur_img_seq]
-                self.img_seq+=cur_img_seq
-                self.dump_seq+=cur_dump_seq
+        data_dir = os.path.join(self.config["rawdata_dir"], "yfcc100m")
+        for seq in self.config["data_seq"]:
+            for split in self.config["data_split"]:
+                split_dir = os.path.join(data_dir, seq, split)
+                dump_dir = os.path.join(self.config["feature_dump_dir"], seq, split)
+                cur_img_seq = glob.glob(os.path.join(split_dir, "images", "*.jpg"))
+                cur_dump_seq = [
+                    os.path.join(dump_dir, path.split("/")[-1])
+                    + "_"
+                    + self.config["extractor"]["name"]
+                    + "_"
+                    + str(self.config["extractor"]["num_kpt"])
+                    + ".hdf5"
+                    for path in cur_img_seq
+                ]
+                self.img_seq += cur_img_seq
+                self.dump_seq += cur_dump_seq
 
     def format_dump_folder(self):
-        if not os.path.exists(self.config['feature_dump_dir']):
-            os.mkdir(self.config['feature_dump_dir'])
-        for seq in self.config['data_seq']:
-            seq_dir=os.path.join(self.config['feature_dump_dir'],seq)
+        if not os.path.exists(self.config["feature_dump_dir"]):
+            os.mkdir(self.config["feature_dump_dir"])
+        for seq in self.config["data_seq"]:
+            seq_dir = os.path.join(self.config["feature_dump_dir"], seq)
             if not os.path.exists(seq_dir):
                 os.mkdir(seq_dir)
-            for split in self.config['data_split']:
-                split_dir=os.path.join(seq_dir,split)
+            for split in self.config["data_split"]:
+                split_dir = os.path.join(seq_dir, split)
                 if not os.path.exists(split_dir):
                     os.mkdir(split_dir)
 
     def format_dump_data(self):
-        print('Formatting data...')
-        pair_path=os.path.join(self.config['rawdata_dir'],'pairs')
-        self.data={'K1':[],'K2':[],'R':[],'T':[],'e':[],'f':[],'fea_path1':[],'fea_path2':[],'img_path1':[],'img_path2':[]}
+        print("Formatting data...")
+        pair_path = os.path.join(self.config["rawdata_dir"], "pairs")
+        self.data = {
+            "K1": [],
+            "K2": [],
+            "R": [],
+            "T": [],
+            "e": [],
+            "f": [],
+            "fea_path1": [],
+            "fea_path2": [],
+            "img_path1": [],
+            "img_path2": [],
+        }
+
+        for seq in self.config["data_seq"]:
+            pair_name = os.path.join(pair_path, seq + "-te-1000-pairs.pkl")
+            with open(pair_name, "rb") as f:
+                pairs = pickle.load(f)
 
-        for seq in self.config['data_seq']:
-            pair_name=os.path.join(pair_path,seq+'-te-1000-pairs.pkl')
-            with open(pair_name, 'rb') as f:
-                pairs=pickle.load(f)
-   
-            #generate id list
-            seq_dir=os.path.join(self.config['rawdata_dir'],'yfcc100m',seq,'test')
-            name_list=np.loadtxt(os.path.join(seq_dir,'images.txt'),dtype=str)
-            cam_name_list=np.loadtxt(os.path.join(seq_dir,'calibration.txt'),dtype=str)
+            # generate id list
+            seq_dir = os.path.join(self.config["rawdata_dir"], "yfcc100m", seq, "test")
+            name_list = np.loadtxt(os.path.join(seq_dir, "images.txt"), dtype=str)
+            cam_name_list = np.loadtxt(
+                os.path.join(seq_dir, "calibration.txt"), dtype=str
+            )
 
             for cur_pair in pairs:
-                index1,index2=cur_pair[0],cur_pair[1]
-                cam1,cam2=h5py.File(os.path.join(seq_dir,cam_name_list[index1]),'r'),h5py.File(os.path.join(seq_dir,cam_name_list[index2]),'r')
-                K1,K2=cam1['K'][()],cam2['K'][()]
-                [w1,h1],[w2,h2]=cam1['imsize'][()][0],cam2['imsize'][()][0]
-                cx1,cy1,cx2,cy2 = (w1 - 1.0) * 0.5,(h1 - 1.0) * 0.5, (w2 - 1.0) * 0.5,(h2 - 1.0) * 0.5
-                K1[0,2],K1[1,2],K2[0,2],K2[1,2]=cx1,cy1,cx2,cy2
+                index1, index2 = cur_pair[0], cur_pair[1]
+                cam1, cam2 = h5py.File(
+                    os.path.join(seq_dir, cam_name_list[index1]), "r"
+                ), h5py.File(os.path.join(seq_dir, cam_name_list[index2]), "r")
+                K1, K2 = cam1["K"][()], cam2["K"][()]
+                [w1, h1], [w2, h2] = cam1["imsize"][()][0], cam2["imsize"][()][0]
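+                # Override the principal points with the image centers, (w - 1) / 2 per axis.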
+                cx1, cy1, cx2, cy2 = (
+                    (w1 - 1.0) * 0.5,
+                    (h1 - 1.0) * 0.5,
+                    (w2 - 1.0) * 0.5,
+                    (h2 - 1.0) * 0.5,
+                )
+                K1[0, 2], K1[1, 2], K2[0, 2], K2[1, 2] = cx1, cy1, cx2, cy2
 
-                R1,R2,t1,t2=cam1['R'][()],cam2['R'][()],cam1['T'][()].reshape([3,1]),cam2['T'][()].reshape([3,1])
+                R1, R2, t1, t2 = (
+                    cam1["R"][()],
+                    cam2["R"][()],
+                    cam1["T"][()].reshape([3, 1]),
+                    cam2["T"][()].reshape([3, 1]),
+                )
                 dR = np.dot(R2, R1.T)
                 dt = t2 - np.dot(dR, t1)
                 dt /= np.sqrt(np.sum(dt**2))
-                
-                e_gt_unnorm = np.reshape(np.matmul(
-                np.reshape(utils.evaluation_utils.np_skew_symmetric(dt.astype('float64').reshape(1, 3)), (3, 3)),
-                np.reshape(dR.astype('float64'), (3, 3))), (3, 3))
+
+                e_gt_unnorm = np.reshape(
+                    np.matmul(
+                        np.reshape(
+                            utils.evaluation_utils.np_skew_symmetric(
+                                dt.astype("float64").reshape(1, 3)
+                            ),
+                            (3, 3),
+                        ),
+                        np.reshape(dR.astype("float64"), (3, 3)),
+                    ),
+                    (3, 3),
+                )
                 e_gt = e_gt_unnorm / np.linalg.norm(e_gt_unnorm)
-                f_gt_unnorm=np.linalg.inv(K2.T)@e_gt@np.linalg.inv(K1)
+                f_gt_unnorm = np.linalg.inv(K2.T) @ e_gt @ np.linalg.inv(K1)
                 f_gt = f_gt_unnorm / np.linalg.norm(f_gt_unnorm)
 
-                self.data['K1'].append(K1),self.data['K2'].append(K2)
-                self.data['R'].append(dR),self.data['T'].append(dt)
-                self.data['e'].append(e_gt),self.data['f'].append(f_gt)
-                
-                img_path1,img_path2=os.path.join('yfcc100m',seq,'test',name_list[index1]),os.path.join('yfcc100m',seq,'test',name_list[index2])
-                dump_seq_dir=os.path.join(self.config['feature_dump_dir'],seq,'test')
-                fea_path1,fea_path2=os.path.join(dump_seq_dir,name_list[index1].split('/')[-1]+'_'+self.config['extractor']['name']
-                                    +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5'),\
-                                    os.path.join(dump_seq_dir,name_list[index2].split('/')[-1]+'_'+self.config['extractor']['name']
-                                    +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5')
-                self.data['img_path1'].append(img_path1),self.data['img_path2'].append(img_path2)
-                self.data['fea_path1'].append(fea_path1),self.data['fea_path2'].append(fea_path2)
+                self.data["K1"].append(K1), self.data["K2"].append(K2)
+                self.data["R"].append(dR), self.data["T"].append(dt)
+                self.data["e"].append(e_gt), self.data["f"].append(f_gt)
+
+                img_path1, img_path2 = os.path.join(
+                    "yfcc100m", seq, "test", name_list[index1]
+                ), os.path.join("yfcc100m", seq, "test", name_list[index2])
+                dump_seq_dir = os.path.join(
+                    self.config["feature_dump_dir"], seq, "test"
+                )
+                fea_path1, fea_path2 = os.path.join(
+                    dump_seq_dir,
+                    name_list[index1].split("/")[-1]
+                    + "_"
+                    + self.config["extractor"]["name"]
+                    + "_"
+                    + str(self.config["extractor"]["num_kpt"])
+                    + ".hdf5",
+                ), os.path.join(
+                    dump_seq_dir,
+                    name_list[index2].split("/")[-1]
+                    + "_"
+                    + self.config["extractor"]["name"]
+                    + "_"
+                    + str(self.config["extractor"]["num_kpt"])
+                    + ".hdf5",
+                )
+                self.data["img_path1"].append(img_path1), self.data["img_path2"].append(
+                    img_path2
+                )
+                self.data["fea_path1"].append(fea_path1), self.data["fea_path2"].append(
+                    fea_path2
+                )
 
         self.form_standard_dataset()
diff --git a/third_party/SGMNet/demo/demo.py b/third_party/SGMNet/demo/demo.py
index cbe277e26d09121f5517854a7ea014b0797a2bde..835b20485698fbccb055a8f08024014142666377 100644
--- a/third_party/SGMNet/demo/demo.py
+++ b/third_party/SGMNet/demo/demo.py
@@ -10,36 +10,56 @@ from components import load_component
 from utils import evaluation_utils
 
 import argparse
+
 parser = argparse.ArgumentParser()
-parser.add_argument('--config_path', type=str, default='configs/sgm_config.yaml',
-  help='number of processes.')
-parser.add_argument('--img1_path', type=str, default='demo_1.jpg',
-  help='number of processes.')
-parser.add_argument('--img2_path', type=str, default='demo_2.jpg',
-  help='number of processes.')
+parser.add_argument(
+    "--config_path",
+    type=str,
+    default="configs/sgm_config.yaml",
+    help="path to the matching config file.",
+)
+parser.add_argument(
+    "--img1_path", type=str, default="demo_1.jpg", help="path to the first image."
+)
+parser.add_argument(
+    "--img2_path", type=str, default="demo_2.jpg", help="path to the second image."
+)
 
 
 args = parser.parse_args()
 
-if __name__=='__main__':
-    with open(args.config_path, 'r') as f:
-      demo_config = yaml.load(f)
+if __name__ == "__main__":
+    with open(args.config_path, "r") as f:
+        demo_config = yaml.load(f, Loader=yaml.FullLoader)
+
+    extractor = load_component(
+        "extractor", demo_config["extractor"]["name"], demo_config["extractor"]
+    )
 
-    extractor=load_component('extractor',demo_config['extractor']['name'],demo_config['extractor'])
+    img1, img2 = cv2.imread(args.img1_path), cv2.imread(args.img2_path)
+    size1, size2 = np.flip(np.asarray(img1.shape[:2])), np.flip(
+        np.asarray(img2.shape[:2])
+    )
+    kpt1, desc1 = extractor.run(args.img1_path)
+    kpt2, desc2 = extractor.run(args.img2_path)
 
-    img1,img2=cv2.imread(args.img1_path),cv2.imread(args.img2_path)
-    size1,size2=np.flip(np.asarray(img1.shape[:2])),np.flip(np.asarray(img2.shape[:2]))
-    kpt1,desc1=extractor.run(args.img1_path)
-    kpt2,desc2=extractor.run(args.img2_path)
-    
-    matcher=load_component('matcher',demo_config['matcher']['name'],demo_config['matcher'])
-    test_data={'x1':kpt1,'x2':kpt2,'desc1':desc1,'desc2':desc2,'size1':size1,'size2':size2}
-    corr1,corr2= matcher.run(test_data)
+    matcher = load_component(
+        "matcher", demo_config["matcher"]["name"], demo_config["matcher"]
+    )
+    test_data = {
+        "x1": kpt1,
+        "x2": kpt2,
+        "desc1": desc1,
+        "desc2": desc2,
+        "size1": size1,
+        "size2": size2,
+    }
+    corr1, corr2 = matcher.run(test_data)
 
-    #draw points
+    # draw points
     dis_points_1 = evaluation_utils.draw_points(img1, kpt1)
-    dis_points_2 =  evaluation_utils.draw_points(img2, kpt2)
+    dis_points_2 = evaluation_utils.draw_points(img2, kpt2)
 
-    #visualize match
-    display=evaluation_utils.draw_match(dis_points_1,dis_points_2,corr1,corr2)
-    cv2.imwrite('match.png',display)
+    # visualize match
+    display = evaluation_utils.draw_match(dis_points_1, dis_points_2, corr1, corr2)
+    cv2.imwrite("match.png", display)
diff --git a/third_party/SGMNet/evaluation/eval_cost.py b/third_party/SGMNet/evaluation/eval_cost.py
index dd3f88abc93290c96ed3d7fa8624c3534e006911..972b4c226c84c3f24dfb2b76e0a31b12719166b0 100644
--- a/third_party/SGMNet/evaluation/eval_cost.py
+++ b/third_party/SGMNet/evaluation/eval_cost.py
@@ -1,9 +1,10 @@
 import torch
 import yaml
 import time
-from collections import OrderedDict,namedtuple
+from collections import OrderedDict, namedtuple
 import os
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.insert(0, ROOT_DIR)
 
@@ -12,49 +13,59 @@ from superglue import matcher as SG_Model
 
 
 import argparse
+
 parser = argparse.ArgumentParser()
-parser.add_argument('--matcher_name', type=str, default='SGM',
-  help='number of processes.')
-parser.add_argument('--config_path', type=str, default='configs/cost/sgm_cost.yaml',
-  help='number of processes.')
-parser.add_argument('--num_kpt', type=int, default=4000,
-  help='keypoint number, default:100')
-parser.add_argument('--iter_num', type=int, default=100,
-  help='keypoint number, default:100')
+parser.add_argument(
+    "--matcher_name", type=str, default="SGM", help="matcher to benchmark: SGM or SG."
+)
+parser.add_argument(
+    "--config_path",
+    type=str,
+    default="configs/cost/sgm_cost.yaml",
+    help="path to the matcher config file.",
+)
+parser.add_argument(
+    "--num_kpt", type=int, default=4000, help="keypoint number, default: 4000"
+)
+parser.add_argument(
+    "--iter_num", type=int, default=100, help="number of timed iterations, default: 100"
+)
 
 
-def test_cost(test_data,model):
+def test_cost(test_data, model):
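+    # One warm-up forward pass, then time iter_num passes; cuda.synchronize() is called
+    # before reading the clock so queued asynchronous kernels are included in the timing.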
     with torch.no_grad():
-        #warm up call
-        _=model(test_data)
+        # warm up call
+        _ = model(test_data)
         torch.cuda.synchronize()
-        a=time.time()
+        a = time.time()
         for _ in range(int(args.iter_num)):
-            _=model(test_data)
+            _ = model(test_data)
         torch.cuda.synchronize()
-        b=time.time()
-    print('Average time per run(ms): ',(b-a)/args.iter_num*1e3)
-    print('Peak memory(MB): ',torch.cuda.max_memory_allocated()/1e6)
+        b = time.time()
+    print("Average time per run(ms): ", (b - a) / args.iter_num * 1e3)
+    print("Peak memory(MB): ", torch.cuda.max_memory_allocated() / 1e6)
 
 
-if __name__=='__main__':
-    torch.backends.cudnn.benchmark=False
+if __name__ == "__main__":
+    torch.backends.cudnn.benchmark = False
     args = parser.parse_args()
-    with open(args.config_path, 'r') as f:
-      model_config = yaml.load(f)
-    model_config=namedtuple('model_config',model_config.keys())(*model_config.values())
-    
-    if args.matcher_name=='SGM':
-      model = SGM_Model(model_config) 
-    elif args.matcher_name=='SG':
-      model = SG_Model(model_config)
-    model.cuda(),model.eval()
-    
+    with open(args.config_path, "r") as f:
+        model_config = yaml.load(f, Loader=yaml.FullLoader)
+    model_config = namedtuple("model_config", model_config.keys())(
+        *model_config.values()
+    )
+
+    if args.matcher_name == "SGM":
+        model = SGM_Model(model_config)
+    elif args.matcher_name == "SG":
+        model = SG_Model(model_config)
+    model.cuda(), model.eval()
+
     test_data = {
-            'x1':torch.rand(1,args.num_kpt,2).cuda()-0.5,
-            'x2':torch.rand(1,args.num_kpt,2).cuda()-0.5,
-            'desc1': torch.rand(1,args.num_kpt,128).cuda(),
-            'desc2': torch.rand(1,args.num_kpt,128).cuda()
-            }
+        "x1": torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
+        "x2": torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
+        "desc1": torch.rand(1, args.num_kpt, 128).cuda(),
+        "desc2": torch.rand(1, args.num_kpt, 128).cuda(),
+    }
 
-    test_cost(test_data,model)
+    test_cost(test_data, model)
diff --git a/third_party/SGMNet/evaluation/evaluate.py b/third_party/SGMNet/evaluation/evaluate.py
index dd5229375caa03b2763bf37a266fb76e80f8e25e..ec6c3ed2aa907838ed3d1cc0ed15710bcd5a6e5f 100644
--- a/third_party/SGMNet/evaluation/evaluate.py
+++ b/third_party/SGMNet/evaluation/evaluate.py
@@ -1,5 +1,5 @@
 import os
-from torch.multiprocessing import Process,Manager,set_start_method,Pool
+from torch.multiprocessing import Process, Manager, set_start_method, Pool
 import functools
 import argparse
 import yaml
@@ -7,111 +7,144 @@ import numpy as np
 import sys
 import cv2
 from tqdm import trange
-set_start_method('spawn',force=True)
+
+set_start_method("spawn", force=True)
 
 
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.insert(0, ROOT_DIR)
 
 from components import load_component
-from utils import evaluation_utils,metrics
-
-parser = argparse.ArgumentParser(description='dump eval data.')
-parser.add_argument('--config_path', type=str, default='configs/eval/scannet_eval_sgm.yaml')
-parser.add_argument('--num_process_match', type=int, default=4)
-parser.add_argument('--num_process_eval', type=int, default=4)
-parser.add_argument('--vis_folder',type=str,default=None)
-args=parser.parse_args()    
-
-def feed_match(info,matcher):
-    x1,x2,desc1,desc2,size1,size2=info['x1'],info['x2'],info['desc1'],info['desc2'],info['img1'].shape[:2],info['img2'].shape[:2]
-    test_data = {'x1': x1,'x2': x2,'desc1': desc1,'desc2': desc2,'size1':np.flip(np.asarray(size1)),'size2':np.flip(np.asarray(size2)) }
-    corr1,corr2=matcher.run(test_data)
-    return [corr1,corr2]
-
-
-def reader_handler(config,read_que):
-  reader=load_component('reader',config['name'],config)
-  for index in range(len(reader)):
-    index+=0
-    info=reader.run(index)
-    read_que.put(info)
-  read_que.put('over')
-
-
-def match_handler(config,read_que,match_que):
-  matcher=load_component('matcher',config['name'],config)
-  match_func=functools.partial(feed_match,matcher=matcher)
-  pool = Pool(args.num_process_match)
-  cache=[]
-  while True:
-    item=read_que.get()
-    #clear cache
-    if item=='over':
-      if len(cache)!=0:
-        results=pool.map(match_func,cache)
-        for cur_item,cur_result in zip(cache,results):
-          cur_item['corr1'],cur_item['corr2']=cur_result[0],cur_result[1]
-          match_que.put(cur_item)
-      match_que.put('over')
-      break
-    cache.append(item)
-    #print(len(cache))
-    if len(cache)==args.num_process_match:
-      #matching in parallel
-      results=pool.map(match_func,cache)
-      for cur_item,cur_result in zip(cache,results):
-          cur_item['corr1'],cur_item['corr2']=cur_result[0],cur_result[1]
-          match_que.put(cur_item)
-      cache=[]
-  pool.close()
-  pool.join()
-
-
-def evaluate_handler(config,match_que):
-  evaluator=load_component('evaluator',config['name'],config)
-  pool = Pool(args.num_process_eval)
-  cache=[]
-  for _ in trange(config['num_pair']):
-    item=match_que.get()
-    if item=='over':
-      if len(cache)!=0:
-        results=pool.map(evaluator.run,cache)
-        for cur_res in results:
-          evaluator.res_inqueue(cur_res)
-      break
-    cache.append(item)
-    if len(cache)==args.num_process_eval:
-      results=pool.map(evaluator.run,cache)
-      for cur_res in results:
-          evaluator.res_inqueue(cur_res)
-      cache=[]
-    if args.vis_folder is not None:
-      #dump visualization
-      corr1_norm,corr2_norm=evaluation_utils.normalize_intrinsic(item['corr1'],item['K1']),\
-                            evaluation_utils.normalize_intrinsic(item['corr2'],item['K2'])
-      inlier_mask=metrics.compute_epi_inlier(corr1_norm,corr2_norm,item['e'],config['inlier_th'])
-      display=evaluation_utils.draw_match(item['img1'],item['img2'],item['corr1'],item['corr2'],inlier_mask)
-      cv2.imwrite(os.path.join(args.vis_folder,str(item['index'])+'.png'),display)
-  evaluator.parse()
-
-
-if __name__=='__main__':
-  with open(args.config_path, 'r') as f:
-    config = yaml.load(f)
-  if args.vis_folder is not None and not os.path.exists(args.vis_folder):
-    os.mkdir(args.vis_folder)
-
-  read_que,match_que,estimate_que=Manager().Queue(maxsize=100),Manager().Queue(maxsize=100),Manager().Queue(maxsize=100)
-
-  read_process=Process(target=reader_handler,args=(config['reader'],read_que))
-  match_process=Process(target=match_handler,args=(config['matcher'],read_que,match_que))
-  evaluate_process=Process(target=evaluate_handler,args=(config['evaluator'],match_que))
-
-  read_process.start()
-  match_process.start()
-  evaluate_process.start()
-
-  read_process.join()
-  match_process.join()
-  evaluate_process.join()
\ No newline at end of file
+from utils import evaluation_utils, metrics
+
+parser = argparse.ArgumentParser(description="dump eval data.")
+parser.add_argument(
+    "--config_path", type=str, default="configs/eval/scannet_eval_sgm.yaml"
+)
+parser.add_argument("--num_process_match", type=int, default=4)
+parser.add_argument("--num_process_eval", type=int, default=4)
+parser.add_argument("--vis_folder", type=str, default=None)
+args = parser.parse_args()
+
+
+def feed_match(info, matcher):
+    x1, x2, desc1, desc2, size1, size2 = (
+        info["x1"],
+        info["x2"],
+        info["desc1"],
+        info["desc2"],
+        info["img1"].shape[:2],
+        info["img2"].shape[:2],
+    )
+    test_data = {
+        "x1": x1,
+        "x2": x2,
+        "desc1": desc1,
+        "desc2": desc2,
+        "size1": np.flip(np.asarray(size1)),
+        "size2": np.flip(np.asarray(size2)),
+    }
+    corr1, corr2 = matcher.run(test_data)
+    return [corr1, corr2]
+
+
+def reader_handler(config, read_que):
+    reader = load_component("reader", config["name"], config)
+    for index in range(len(reader)):
+        info = reader.run(index)
+        read_que.put(info)
+    read_que.put("over")
+
+
+def match_handler(config, read_que, match_que):
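+    # Consume items from read_que, match them in batches of num_process_match with a worker
+    # pool, and push the results to match_que; the string "over" is the end-of-stream sentinel.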
+    matcher = load_component("matcher", config["name"], config)
+    match_func = functools.partial(feed_match, matcher=matcher)
+    pool = Pool(args.num_process_match)
+    cache = []
+    while True:
+        item = read_que.get()
+        # clear cache
+        if item == "over":
+            if len(cache) != 0:
+                results = pool.map(match_func, cache)
+                for cur_item, cur_result in zip(cache, results):
+                    cur_item["corr1"], cur_item["corr2"] = cur_result[0], cur_result[1]
+                    match_que.put(cur_item)
+            match_que.put("over")
+            break
+        cache.append(item)
+        # print(len(cache))
+        if len(cache) == args.num_process_match:
+            # matching in parallel
+            results = pool.map(match_func, cache)
+            for cur_item, cur_result in zip(cache, results):
+                cur_item["corr1"], cur_item["corr2"] = cur_result[0], cur_result[1]
+                match_que.put(cur_item)
+            cache = []
+    pool.close()
+    pool.join()
+
+
+def evaluate_handler(config, match_que):
+    evaluator = load_component("evaluator", config["name"], config)
+    pool = Pool(args.num_process_eval)
+    cache = []
+    for _ in trange(config["num_pair"]):
+        item = match_que.get()
+        if item == "over":
+            if len(cache) != 0:
+                results = pool.map(evaluator.run, cache)
+                for cur_res in results:
+                    evaluator.res_inqueue(cur_res)
+            break
+        cache.append(item)
+        if len(cache) == args.num_process_eval:
+            results = pool.map(evaluator.run, cache)
+            for cur_res in results:
+                evaluator.res_inqueue(cur_res)
+            cache = []
+        if args.vis_folder is not None:
+            # dump visualization
+            corr1_norm, corr2_norm = evaluation_utils.normalize_intrinsic(
+                item["corr1"], item["K1"]
+            ), evaluation_utils.normalize_intrinsic(item["corr2"], item["K2"])
+            inlier_mask = metrics.compute_epi_inlier(
+                corr1_norm, corr2_norm, item["e"], config["inlier_th"]
+            )
+            display = evaluation_utils.draw_match(
+                item["img1"], item["img2"], item["corr1"], item["corr2"], inlier_mask
+            )
+            cv2.imwrite(
+                os.path.join(args.vis_folder, str(item["index"]) + ".png"), display
+            )
+    evaluator.parse()
+
+
+if __name__ == "__main__":
+    with open(args.config_path, "r") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+    if args.vis_folder is not None and not os.path.exists(args.vis_folder):
+        os.mkdir(args.vis_folder)
+
+    read_que, match_que, estimate_que = (
+        Manager().Queue(maxsize=100),
+        Manager().Queue(maxsize=100),
+        Manager().Queue(maxsize=100),
+    )
+
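+    # Three-stage pipeline: a reader process feeds a matcher process, which feeds an
+    # evaluator process, all connected through bounded multiprocessing queues.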
+    read_process = Process(target=reader_handler, args=(config["reader"], read_que))
+    match_process = Process(
+        target=match_handler, args=(config["matcher"], read_que, match_que)
+    )
+    evaluate_process = Process(
+        target=evaluate_handler, args=(config["evaluator"], match_que)
+    )
+
+    read_process.start()
+    match_process.start()
+    evaluate_process.start()
+
+    read_process.join()
+    match_process.join()
+    evaluate_process.join()
diff --git a/third_party/SGMNet/sgmnet/__init__.py b/third_party/SGMNet/sgmnet/__init__.py
index 828543beceebb10d05fd9d5fdfcc4b1c91e5af6b..fabeccd0fe21eb5be637602f2b2eb3cfd944d11b 100644
--- a/third_party/SGMNet/sgmnet/__init__.py
+++ b/third_party/SGMNet/sgmnet/__init__.py
@@ -1 +1 @@
-from .match_model import matcher
\ No newline at end of file
+from .match_model import matcher
diff --git a/third_party/SGMNet/sgmnet/match_model.py b/third_party/SGMNet/sgmnet/match_model.py
index 1e55fa5d042b010f8d9a99e006002563a3961ae7..c758cf5d6537fb3c47a2de00cc279857755943ef 100644
--- a/third_party/SGMNet/sgmnet/match_model.py
+++ b/third_party/SGMNet/sgmnet/match_model.py
@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
 
-eps=1e-8
+eps = 1e-8
 
-def sinkhorn(M,r,c,iteration):
+
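+# Sinkhorn normalization: starting from softmax(M), alternately rescale rows and columns
+# for a fixed number of iterations so the marginals approach r and c.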
+def sinkhorn(M, r, c, iteration):
     p = torch.softmax(M, dim=-1)
     u = torch.ones_like(r)
     v = torch.ones_like(c)
@@ -13,46 +14,79 @@ def sinkhorn(M,r,c,iteration):
     p = p * u.unsqueeze(-1) * v.unsqueeze(-2)
     return p
 
-def sink_algorithm(M,dustbin,iteration):
+
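+# Optimal-transport layer: append the dustbin score as an extra row and column so unmatched
+# keypoints can absorb probability mass, then run Sinkhorn on the augmented score matrix.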
+def sink_algorithm(M, dustbin, iteration):
     M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
     M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
-    r = torch.ones([M.shape[0], M.shape[1] - 1],device='cuda')
-    r = torch.cat([r, torch.ones([M.shape[0], 1],device='cuda') * M.shape[1]], dim=-1)
-    c = torch.ones([M.shape[0], M.shape[2] - 1],device='cuda')
-    c = torch.cat([c, torch.ones([M.shape[0], 1],device='cuda') * M.shape[2]], dim=-1)
-    p=sinkhorn(M,r,c,iteration)
+    r = torch.ones([M.shape[0], M.shape[1] - 1], device="cuda")
+    r = torch.cat([r, torch.ones([M.shape[0], 1], device="cuda") * M.shape[1]], dim=-1)
+    c = torch.ones([M.shape[0], M.shape[2] - 1], device="cuda")
+    c = torch.cat([c, torch.ones([M.shape[0], 1], device="cuda") * M.shape[2]], dim=-1)
+    p = sinkhorn(M, r, c, iteration)
     return p
 
-        
-def seeding(nn_index1,nn_index2,x1,x2,topk,match_score,confbar,nms_radius,use_mc=True,test=False):
-    
-    #apply mutual check before nms
+
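+# Seed selection: optionally enforce a mutual nearest-neighbour check, suppress non-maximal
+# scores within an adaptive NMS radius, drop scores below the confidence bar, and return the
+# indices of the top-k surviving seed matches in both images.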
+def seeding(
+    nn_index1,
+    nn_index2,
+    x1,
+    x2,
+    topk,
+    match_score,
+    confbar,
+    nms_radius,
+    use_mc=True,
+    test=False,
+):
+
+    # apply mutual check before nms
     if use_mc:
-        mask_not_mutual=nn_index2.gather(dim=-1,index=nn_index1)!=torch.arange(nn_index1.shape[1],device='cuda')
-        match_score[mask_not_mutual]=-1
-    #NMS
-    pos_dismat1=((x1.norm(p=2,dim=-1)**2).unsqueeze_(-1)+(x1.norm(p=2,dim=-1)**2).unsqueeze_(-2)-2*(x1@x1.transpose(1,2))).abs_().sqrt_()
-    x2=x2.gather(index=nn_index1.unsqueeze(-1).expand(-1,-1,2),dim=1)
-    pos_dismat2=((x2.norm(p=2,dim=-1)**2).unsqueeze_(-1)+(x2.norm(p=2,dim=-1)**2).unsqueeze_(-2)-2*(x2@x2.transpose(1,2))).abs_().sqrt_()
-    radius1, radius2 = nms_radius * pos_dismat1.mean(dim=(1,2),keepdim=True), nms_radius * pos_dismat2.mean(dim=(1,2),keepdim=True)
+        mask_not_mutual = nn_index2.gather(dim=-1, index=nn_index1) != torch.arange(
+            nn_index1.shape[1], device="cuda"
+        )
+        match_score[mask_not_mutual] = -1
+    # NMS
+    pos_dismat1 = (
+        (
+            (x1.norm(p=2, dim=-1) ** 2).unsqueeze_(-1)
+            + (x1.norm(p=2, dim=-1) ** 2).unsqueeze_(-2)
+            - 2 * (x1 @ x1.transpose(1, 2))
+        )
+        .abs_()
+        .sqrt_()
+    )
+    x2 = x2.gather(index=nn_index1.unsqueeze(-1).expand(-1, -1, 2), dim=1)
+    pos_dismat2 = (
+        (
+            (x2.norm(p=2, dim=-1) ** 2).unsqueeze_(-1)
+            + (x2.norm(p=2, dim=-1) ** 2).unsqueeze_(-2)
+            - 2 * (x2 @ x2.transpose(1, 2))
+        )
+        .abs_()
+        .sqrt_()
+    )
+    radius1, radius2 = nms_radius * pos_dismat1.mean(
+        dim=(1, 2), keepdim=True
+    ), nms_radius * pos_dismat2.mean(dim=(1, 2), keepdim=True)
     nms_mask = (pos_dismat1 >= radius1) & (pos_dismat2 >= radius2)
-    mask_not_local_max=(match_score.unsqueeze(-1)>=match_score.unsqueeze(-2))|nms_mask
-    mask_not_local_max=~(mask_not_local_max.min(dim=-1).values)
+    mask_not_local_max = (
+        match_score.unsqueeze(-1) >= match_score.unsqueeze(-2)
+    ) | nms_mask
+    mask_not_local_max = ~(mask_not_local_max.min(dim=-1).values)
     match_score[mask_not_local_max] = -1
- 
-    #confidence bar
-    match_score[match_score<confbar]=-1
-    mask_survive=match_score>0
-    if test:
-        topk=min(mask_survive.sum(dim=1)[0]+2,topk)
-    _,topindex = torch.topk(match_score,topk,dim=-1)#b*k
-    seed_index1,seed_index2=topindex,nn_index1.gather(index=topindex,dim=-1)
-    return seed_index1,seed_index2
 
+    # confidence bar
+    match_score[match_score < confbar] = -1
+    mask_survive = match_score > 0
+    if test:
+        topk = min(mask_survive.sum(dim=1)[0] + 2, topk)
+    _, topindex = torch.topk(match_score, topk, dim=-1)  # b*k
+    seed_index1, seed_index2 = topindex, nn_index1.gather(index=topindex, dim=-1)
+    return seed_index1, seed_index2
 
 
 class PointCN(nn.Module):
-    def __init__(self, channels,out_channels):
+    def __init__(self, channels, out_channels):
         nn.Module.__init__(self)
         self.shot_cut = nn.Conv1d(channels, out_channels, kernel_size=1)
         self.conv = nn.Sequential(
@@ -63,7 +97,7 @@ class PointCN(nn.Module):
             nn.InstanceNorm1d(channels, eps=1e-3),
             nn.SyncBatchNorm(channels),
             nn.ReLU(),
-            nn.Conv1d(channels, out_channels, kernel_size=1)
+            nn.Conv1d(channels, out_channels, kernel_size=1),
         )
 
     def forward(self, x):
@@ -71,152 +105,254 @@ class PointCN(nn.Module):
 
 
 class attention_propagantion(nn.Module):
-
-    def __init__(self,channel,head):
+    def __init__(self, channel, head):
         nn.Module.__init__(self)
-        self.head=head
-        self.head_dim=channel//head
-        self.query_filter,self.key_filter,self.value_filter=nn.Conv1d(channel,channel,kernel_size=1),nn.Conv1d(channel,channel,kernel_size=1),\
-                                                            nn.Conv1d(channel,channel,kernel_size=1)
-        self.mh_filter=nn.Conv1d(channel,channel,kernel_size=1)
-        self.cat_filter=nn.Sequential(nn.Conv1d(2*channel,2*channel, kernel_size=1), nn.SyncBatchNorm(2*channel), nn.ReLU(),
-                                      nn.Conv1d(2*channel, channel, kernel_size=1))
-
-    def forward(self,desc1,desc2,weight_v=None):
-        #desc1(q) attend to desc2(k,v)
-        batch_size=desc1.shape[0]
-        query,key,value=self.query_filter(desc1).view(batch_size,self.head,self.head_dim,-1),self.key_filter(desc2).view(batch_size,self.head,self.head_dim,-1),\
-                        self.value_filter(desc2).view(batch_size,self.head,self.head_dim,-1)
+        self.head = head
+        self.head_dim = channel // head
+        self.query_filter, self.key_filter, self.value_filter = (
+            nn.Conv1d(channel, channel, kernel_size=1),
+            nn.Conv1d(channel, channel, kernel_size=1),
+            nn.Conv1d(channel, channel, kernel_size=1),
+        )
+        self.mh_filter = nn.Conv1d(channel, channel, kernel_size=1)
+        self.cat_filter = nn.Sequential(
+            nn.Conv1d(2 * channel, 2 * channel, kernel_size=1),
+            nn.SyncBatchNorm(2 * channel),
+            nn.ReLU(),
+            nn.Conv1d(2 * channel, channel, kernel_size=1),
+        )
+
+    def forward(self, desc1, desc2, weight_v=None):
+        # desc1(q) attend to desc2(k,v)
+        batch_size = desc1.shape[0]
+        query, key, value = (
+            self.query_filter(desc1).view(batch_size, self.head, self.head_dim, -1),
+            self.key_filter(desc2).view(batch_size, self.head, self.head_dim, -1),
+            self.value_filter(desc2).view(batch_size, self.head, self.head_dim, -1),
+        )
         if weight_v is not None:
-            value=value*weight_v.view(batch_size,1,1,-1)
-        score=torch.softmax(torch.einsum('bhdn,bhdm->bhnm',query,key)/ self.head_dim ** 0.5,dim=-1)
-        add_value=torch.einsum('bhnm,bhdm->bhdn',score,value).reshape(batch_size,self.head_dim*self.head,-1)
-        add_value=self.mh_filter(add_value)
-        desc1_new=desc1+self.cat_filter(torch.cat([desc1,add_value],dim=1))
+            value = value * weight_v.view(batch_size, 1, 1, -1)
+        score = torch.softmax(
+            torch.einsum("bhdn,bhdm->bhnm", query, key) / self.head_dim**0.5, dim=-1
+        )
+        add_value = torch.einsum("bhnm,bhdm->bhdn", score, value).reshape(
+            batch_size, self.head_dim * self.head, -1
+        )
+        add_value = self.mh_filter(add_value)
+        desc1_new = desc1 + self.cat_filter(torch.cat([desc1, add_value], dim=1))
         return desc1_new
 
 
 class hybrid_block(nn.Module):
-    def __init__(self,channel,head):
+    def __init__(self, channel, head):
         nn.Module.__init__(self)
-        self.head=head
-        self.channel=channel
+        self.head = head
+        self.channel = channel
         self.attention_block_down = attention_propagantion(channel, head)
-        self.cluster_filter=nn.Sequential(nn.Conv1d(2*channel,2*channel, kernel_size=1), nn.SyncBatchNorm(2*channel), nn.ReLU(),
-                                         nn.Conv1d(2*channel, 2*channel, kernel_size=1))
-        self.cross_filter=attention_propagantion(channel,head)
-        self.confidence_filter=PointCN(2*channel,1)
-        self.attention_block_self=attention_propagantion(channel,head)
-        self.attention_block_up=attention_propagantion(channel,head)
-        
-    def forward(self,desc1,desc2,seed_index1,seed_index2):
-        cluster1, cluster2 = desc1.gather(dim=-1, index=seed_index1.unsqueeze(1).expand(-1, self.channel, -1)), \
-                             desc2.gather(dim=-1, index=seed_index2.unsqueeze(1).expand(-1, self.channel, -1))
-        
-        #pooling
-        cluster1, cluster2 = self.attention_block_down(cluster1, desc1), self.attention_block_down(cluster2, desc2)
-        concate_cluster=self.cluster_filter(torch.cat([cluster1,cluster2],dim=1))
-        #filtering
-        cluster1,cluster2=self.cross_filter(concate_cluster[:,:self.channel],concate_cluster[:,self.channel:]),\
-                        self.cross_filter(concate_cluster[:,self.channel:],concate_cluster[:,:self.channel])
-        cluster1,cluster2=self.attention_block_self(cluster1,cluster1),self.attention_block_self(cluster2,cluster2)
-        #unpooling
-        seed_weight=self.confidence_filter(torch.cat([cluster1,cluster2],dim=1))
-        seed_weight=torch.sigmoid(seed_weight).squeeze(1)
-        desc1_new,desc2_new=self.attention_block_up(desc1,cluster1,seed_weight),self.attention_block_up(desc2,cluster2,seed_weight)
-        return desc1_new,desc2_new,seed_weight
+        self.cluster_filter = nn.Sequential(
+            nn.Conv1d(2 * channel, 2 * channel, kernel_size=1),
+            nn.SyncBatchNorm(2 * channel),
+            nn.ReLU(),
+            nn.Conv1d(2 * channel, 2 * channel, kernel_size=1),
+        )
+        self.cross_filter = attention_propagantion(channel, head)
+        self.confidence_filter = PointCN(2 * channel, 1)
+        self.attention_block_self = attention_propagantion(channel, head)
+        self.attention_block_up = attention_propagantion(channel, head)
 
+    def forward(self, desc1, desc2, seed_index1, seed_index2):
+        cluster1, cluster2 = desc1.gather(
+            dim=-1, index=seed_index1.unsqueeze(1).expand(-1, self.channel, -1)
+        ), desc2.gather(
+            dim=-1, index=seed_index2.unsqueeze(1).expand(-1, self.channel, -1)
+        )
+
+        # pooling
+        cluster1, cluster2 = self.attention_block_down(
+            cluster1, desc1
+        ), self.attention_block_down(cluster2, desc2)
+        concate_cluster = self.cluster_filter(torch.cat([cluster1, cluster2], dim=1))
+        # filtering
+        cluster1, cluster2 = self.cross_filter(
+            concate_cluster[:, : self.channel], concate_cluster[:, self.channel :]
+        ), self.cross_filter(
+            concate_cluster[:, self.channel :], concate_cluster[:, : self.channel]
+        )
+        cluster1, cluster2 = self.attention_block_self(
+            cluster1, cluster1
+        ), self.attention_block_self(cluster2, cluster2)
+        # unpooling
+        seed_weight = self.confidence_filter(torch.cat([cluster1, cluster2], dim=1))
+        seed_weight = torch.sigmoid(seed_weight).squeeze(1)
+        desc1_new, desc2_new = self.attention_block_up(
+            desc1, cluster1, seed_weight
+        ), self.attention_block_up(desc2, cluster2, seed_weight)
+        return desc1_new, desc2_new, seed_weight
 
 
 class matcher(nn.Module):
-    def __init__(self,config):
+    def __init__(self, config):
         nn.Module.__init__(self)
-        self.seed_top_k=config.seed_top_k
-        self.conf_bar=config.conf_bar
-        self.seed_radius_coe=config.seed_radius_coe
-        self.use_score_encoding=config.use_score_encoding
-        self.detach_iter=config.detach_iter
-        self.seedlayer=config.seedlayer
-        self.layer_num=config.layer_num
-        self.sink_iter=config.sink_iter
-
-        self.position_encoder = nn.Sequential(nn.Conv1d(3, 32, kernel_size=1) if config.use_score_encoding else nn.Conv1d(2, 32, kernel_size=1), 
-                                            nn.SyncBatchNorm(32),nn.ReLU(),
-                                            nn.Conv1d(32, 64, kernel_size=1), nn.SyncBatchNorm(64),nn.ReLU(),
-                                            nn.Conv1d(64, 128, kernel_size=1), nn.SyncBatchNorm(128),nn.ReLU(),
-                                            nn.Conv1d(128, 256, kernel_size=1), nn.SyncBatchNorm(256),nn.ReLU(),
-                                            nn.Conv1d(256, config.net_channels, kernel_size=1))
-     
-        
-        self.hybrid_block=nn.Sequential(*[hybrid_block(config.net_channels, config.head) for _ in range(config.layer_num)])
-        self.final_project = nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1)
-        self.dustbin=nn.Parameter(torch.tensor(1.5,dtype=torch.float32))
-        
-        #if reseeding
-        if len(config.seedlayer)!=1:
-            self.mid_dustbin=nn.ParameterDict({str(i):nn.Parameter(torch.tensor(2,dtype=torch.float32)) for i in config.seedlayer[1:]})
-            self.mid_final_project = nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1)
-       
-    def forward(self,data,test_mode=True):
-        x1, x2, desc1, desc2 = data['x1'][:,:,:2], data['x2'][:,:,:2], data['desc1'], data['desc2']
-        desc1, desc2 = torch.nn.functional.normalize(desc1,dim=-1), torch.nn.functional.normalize(desc2,dim=-1)
+        self.seed_top_k = config.seed_top_k
+        self.conf_bar = config.conf_bar
+        self.seed_radius_coe = config.seed_radius_coe
+        self.use_score_encoding = config.use_score_encoding
+        self.detach_iter = config.detach_iter
+        self.seedlayer = config.seedlayer
+        self.layer_num = config.layer_num
+        self.sink_iter = config.sink_iter
+
+        self.position_encoder = nn.Sequential(
+            nn.Conv1d(3, 32, kernel_size=1)
+            if config.use_score_encoding
+            else nn.Conv1d(2, 32, kernel_size=1),
+            nn.SyncBatchNorm(32),
+            nn.ReLU(),
+            nn.Conv1d(32, 64, kernel_size=1),
+            nn.SyncBatchNorm(64),
+            nn.ReLU(),
+            nn.Conv1d(64, 128, kernel_size=1),
+            nn.SyncBatchNorm(128),
+            nn.ReLU(),
+            nn.Conv1d(128, 256, kernel_size=1),
+            nn.SyncBatchNorm(256),
+            nn.ReLU(),
+            nn.Conv1d(256, config.net_channels, kernel_size=1),
+        )
+
+        self.hybrid_block = nn.Sequential(
+            *[
+                hybrid_block(config.net_channels, config.head)
+                for _ in range(config.layer_num)
+            ]
+        )
+        self.final_project = nn.Conv1d(
+            config.net_channels, config.net_channels, kernel_size=1
+        )
+        self.dustbin = nn.Parameter(torch.tensor(1.5, dtype=torch.float32))
+
+        # if reseeding
+        if len(config.seedlayer) != 1:
+            self.mid_dustbin = nn.ParameterDict(
+                {
+                    str(i): nn.Parameter(torch.tensor(2, dtype=torch.float32))
+                    for i in config.seedlayer[1:]
+                }
+            )
+            self.mid_final_project = nn.Conv1d(
+                config.net_channels, config.net_channels, kernel_size=1
+            )
+
+    def forward(self, data, test_mode=True):
+        x1, x2, desc1, desc2 = (
+            data["x1"][:, :, :2],
+            data["x2"][:, :, :2],
+            data["desc1"],
+            data["desc2"],
+        )
+        desc1, desc2 = torch.nn.functional.normalize(
+            desc1, dim=-1
+        ), torch.nn.functional.normalize(desc2, dim=-1)
         if test_mode:
-            encode_x1,encode_x2=data['x1'],data['x2']
+            encode_x1, encode_x2 = data["x1"], data["x2"]
         else:
-            encode_x1,encode_x2=data['aug_x1'], data['aug_x2']
-    
-        #preparation
-        desc_dismat=(2-2*torch.matmul(desc1,desc2.transpose(1,2))).sqrt_()
-        values,nn_index=torch.topk(desc_dismat,k=2,largest=False,dim=-1,sorted=True)
-        nn_index2=torch.min(desc_dismat,dim=1).indices.squeeze(1)
-        inverse_ratio_score,nn_index1=values[:,:,1]/values[:,:,0],nn_index[:,:,0]#get inverse score
-   
-        #initial seeding
-        seed_index1,seed_index2=seeding(nn_index1,nn_index2,x1,x2,self.seed_top_k[0],inverse_ratio_score,self.conf_bar[0],\
-                                self.seed_radius_coe,test=test_mode) 
-
-        #position encoding
-        desc1,desc2=desc1.transpose(1,2),desc2.transpose(1,2)   
+            encode_x1, encode_x2 = data["aug_x1"], data["aug_x2"]
+
+        # preparation
+        desc_dismat = (2 - 2 * torch.matmul(desc1, desc2.transpose(1, 2))).sqrt_()
+        values, nn_index = torch.topk(
+            desc_dismat, k=2, largest=False, dim=-1, sorted=True
+        )
+        nn_index2 = torch.min(desc_dismat, dim=1).indices.squeeze(1)
+        inverse_ratio_score, nn_index1 = (
+            values[:, :, 1] / values[:, :, 0],
+            nn_index[:, :, 0],
+        )  # get inverse score
+
+        # initial seeding
+        seed_index1, seed_index2 = seeding(
+            nn_index1,
+            nn_index2,
+            x1,
+            x2,
+            self.seed_top_k[0],
+            inverse_ratio_score,
+            self.conf_bar[0],
+            self.seed_radius_coe,
+            test=test_mode,
+        )
+
+        # position encoding
+        desc1, desc2 = desc1.transpose(1, 2), desc2.transpose(1, 2)
         if not self.use_score_encoding:
-            encode_x1,encode_x2=encode_x1[:,:,:2],encode_x2[:,:,:2]
-        encode_x1,encode_x2=encode_x1.transpose(1,2),encode_x2.transpose(1,2)
-        x1_pos_embedding, x2_pos_embedding = self.position_encoder(encode_x1), self.position_encoder(encode_x2)
+            encode_x1, encode_x2 = encode_x1[:, :, :2], encode_x2[:, :, :2]
+        encode_x1, encode_x2 = encode_x1.transpose(1, 2), encode_x2.transpose(1, 2)
+        x1_pos_embedding, x2_pos_embedding = self.position_encoder(
+            encode_x1
+        ), self.position_encoder(encode_x2)
         aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding + desc2
-      
-        seed_weight_tower,mid_p_tower,seed_index_tower,nn_index_tower=[],[],[],[]
-        seed_index_tower.append(torch.stack([seed_index1, seed_index2],dim=-1))
+
+        seed_weight_tower, mid_p_tower, seed_index_tower, nn_index_tower = (
+            [],
+            [],
+            [],
+            [],
+        )
+        seed_index_tower.append(torch.stack([seed_index1, seed_index2], dim=-1))
         nn_index_tower.append(nn_index1)
 
-        seed_para_index=0
+        seed_para_index = 0
         for i in range(self.layer_num):
-            #mid seeding
-            if i in self.seedlayer and i!= 0:
-                seed_para_index+=1
-                aug_desc1,aug_desc2=self.mid_final_project(aug_desc1),self.mid_final_project(aug_desc2)
-                M=torch.matmul(aug_desc1.transpose(1,2),aug_desc2)
-                p=sink_algorithm(M,self.mid_dustbin[str(i)],self.sink_iter[seed_para_index-1])
+            # mid seeding
+            if i in self.seedlayer and i != 0:
+                seed_para_index += 1
+                aug_desc1, aug_desc2 = self.mid_final_project(
+                    aug_desc1
+                ), self.mid_final_project(aug_desc2)
+                M = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2)
+                p = sink_algorithm(
+                    M, self.mid_dustbin[str(i)], self.sink_iter[seed_para_index - 1]
+                )
                 mid_p_tower.append(p)
-                #rematching with p
-                values,nn_index=torch.topk(p[:,:-1,:-1],k=1,dim=-1)
-                nn_index2=torch.max(p[:,:-1,:-1],dim=1).indices.squeeze(1)
-                p_match_score,nn_index1=values[:,:,0],nn_index[:,:,0]
-                #reseeding
-                seed_index1, seed_index2 = seeding(nn_index1,nn_index2,x1,x2,self.seed_top_k[seed_para_index],p_match_score,\
-                                                    self.conf_bar[seed_para_index],self.seed_radius_coe,test=test_mode)
-                seed_index_tower.append(torch.stack([seed_index1, seed_index2],dim=-1)), nn_index_tower.append(nn_index1)
-                if not test_mode and data['step']<self.detach_iter:
-                    aug_desc1,aug_desc2=aug_desc1.detach(),aug_desc2.detach()
-
-            aug_desc1, aug_desc2,seed_weight=self.hybrid_block[i](aug_desc1, aug_desc2,seed_index1,seed_index2)
-            seed_weight_tower.append(seed_weight)
-        
-    
-        aug_desc1,aug_desc2 = self.final_project(aug_desc1), self.final_project(aug_desc2)
-        cmat = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2)
-        p = sink_algorithm(cmat, self.dustbin,self.sink_iter[-1])
-        #seed_weight_tower: l*b*k
-        #seed_index_tower: l*b*k*2
-        #nn_index_tower: seed_l*b
-        return {'p':p,'seed_conf':seed_weight_tower,'seed_index':seed_index_tower,'mid_p':mid_p_tower,'nn_index':nn_index_tower}
+                # rematching with p
+                values, nn_index = torch.topk(p[:, :-1, :-1], k=1, dim=-1)
+                nn_index2 = torch.max(p[:, :-1, :-1], dim=1).indices.squeeze(1)
+                p_match_score, nn_index1 = values[:, :, 0], nn_index[:, :, 0]
+                # reseeding
+                seed_index1, seed_index2 = seeding(
+                    nn_index1,
+                    nn_index2,
+                    x1,
+                    x2,
+                    self.seed_top_k[seed_para_index],
+                    p_match_score,
+                    self.conf_bar[seed_para_index],
+                    self.seed_radius_coe,
+                    test=test_mode,
+                )
+                seed_index_tower.append(
+                    torch.stack([seed_index1, seed_index2], dim=-1)
+                ), nn_index_tower.append(nn_index1)
+                if not test_mode and data["step"] < self.detach_iter:
+                    aug_desc1, aug_desc2 = aug_desc1.detach(), aug_desc2.detach()
 
+            aug_desc1, aug_desc2, seed_weight = self.hybrid_block[i](
+                aug_desc1, aug_desc2, seed_index1, seed_index2
+            )
+            seed_weight_tower.append(seed_weight)
 
+        aug_desc1, aug_desc2 = self.final_project(aug_desc1), self.final_project(
+            aug_desc2
+        )
+        cmat = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2)
+        p = sink_algorithm(cmat, self.dustbin, self.sink_iter[-1])
+        # seed_weight_tower: l*b*k
+        # seed_index_tower: l*b*k*2
+        # nn_index_tower: seed_l*b
+        return {
+            "p": p,
+            "seed_conf": seed_weight_tower,
+            "seed_index": seed_index_tower,
+            "mid_p": mid_p_tower,
+            "nn_index": nn_index_tower,
+        }
diff --git a/third_party/SGMNet/superglue/__init__.py b/third_party/SGMNet/superglue/__init__.py
index 828543beceebb10d05fd9d5fdfcc4b1c91e5af6b..fabeccd0fe21eb5be637602f2b2eb3cfd944d11b 100644
--- a/third_party/SGMNet/superglue/__init__.py
+++ b/third_party/SGMNet/superglue/__init__.py
@@ -1 +1 @@
-from .match_model import matcher
\ No newline at end of file
+from .match_model import matcher
diff --git a/third_party/SGMNet/superglue/match_model.py b/third_party/SGMNet/superglue/match_model.py
index adf5ae53b2385e2bdf4478982fcdd6dbdb014c3c..4a0270dce45a1882397374615156b5310fd181d1 100644
--- a/third_party/SGMNet/superglue/match_model.py
+++ b/third_party/SGMNet/superglue/match_model.py
@@ -3,9 +3,10 @@ import torch.nn as nn
 import time
 
 
-eps=1e-8
+eps = 1e-8
 
-def sinkhorn(M,r,c,iteration):
+
+def sinkhorn(M, r, c, iteration):
     p = torch.softmax(M, dim=-1)
     u = torch.ones_like(r)
     v = torch.ones_like(c)
@@ -15,92 +16,152 @@ def sinkhorn(M,r,c,iteration):
     p = p * u.unsqueeze(-1) * v.unsqueeze(-2)
     return p
 
-def sink_algorithm(M,dustbin,iteration):
+
+def sink_algorithm(M, dustbin, iteration):
     M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
     M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
-    r = torch.ones([M.shape[0], M.shape[1] - 1],device='cuda')
-    r = torch.cat([r, torch.ones([M.shape[0], 1],device='cuda') * M.shape[1]], dim=-1)
-    c = torch.ones([M.shape[0], M.shape[2] - 1],device='cuda')
-    c = torch.cat([c, torch.ones([M.shape[0], 1],device='cuda') * M.shape[2]], dim=-1)
-    p=sinkhorn(M,r,c,iteration)
+    r = torch.ones([M.shape[0], M.shape[1] - 1], device="cuda")
+    r = torch.cat([r, torch.ones([M.shape[0], 1], device="cuda") * M.shape[1]], dim=-1)
+    c = torch.ones([M.shape[0], M.shape[2] - 1], device="cuda")
+    c = torch.cat([c, torch.ones([M.shape[0], 1], device="cuda") * M.shape[2]], dim=-1)
+    p = sinkhorn(M, r, c, iteration)
     return p
 
 
 class attention_block(nn.Module):
-    def __init__(self,channels,head,type):
-        assert type=='self' or type=='cross','invalid attention type'
+    def __init__(self, channels, head, type):
+        assert type == "self" or type == "cross", "invalid attention type"
         nn.Module.__init__(self)
-        self.head=head
-        self.type=type
-        self.head_dim=channels//head
-        self.query_filter=nn.Conv1d(channels, channels, kernel_size=1)
-        self.key_filter=nn.Conv1d(channels,channels,kernel_size=1)
-        self.value_filter=nn.Conv1d(channels,channels,kernel_size=1)
-        self.attention_filter=nn.Sequential(nn.Conv1d(2*channels,2*channels, kernel_size=1),nn.SyncBatchNorm(2*channels), nn.ReLU(),
-                                             nn.Conv1d(2*channels, channels, kernel_size=1))
-        self.mh_filter=nn.Conv1d(channels, channels, kernel_size=1)
-
-    def forward(self,fea1,fea2):
-        batch_size,n,m=fea1.shape[0],fea1.shape[2],fea2.shape[2]
-        query1, key1, value1 = self.query_filter(fea1).view(batch_size,self.head_dim,self.head,-1), self.key_filter(fea1).view(batch_size,self.head_dim,self.head,-1), \
-                               self.value_filter(fea1).view(batch_size,self.head_dim,self.head,-1)
-        query2, key2, value2 = self.query_filter(fea2).view(batch_size,self.head_dim,self.head,-1), self.key_filter(fea2).view(batch_size,self.head_dim,self.head,-1), \
-                               self.value_filter(fea2).view(batch_size,self.head_dim,self.head,-1)
-        if(self.type=='self'):
-            score1,score2=torch.softmax(torch.einsum('bdhn,bdhm->bhnm',query1,key1)/self.head_dim**0.5,dim=-1),\
-                          torch.softmax(torch.einsum('bdhn,bdhm->bhnm',query2,key2)/self.head_dim**0.5,dim=-1)
-            add_value1, add_value2 = torch.einsum('bhnm,bdhm->bdhn', score1, value1), torch.einsum('bhnm,bdhm->bdhn',score2, value2)
+        self.head = head
+        self.type = type
+        self.head_dim = channels // head
+        self.query_filter = nn.Conv1d(channels, channels, kernel_size=1)
+        self.key_filter = nn.Conv1d(channels, channels, kernel_size=1)
+        self.value_filter = nn.Conv1d(channels, channels, kernel_size=1)
+        self.attention_filter = nn.Sequential(
+            nn.Conv1d(2 * channels, 2 * channels, kernel_size=1),
+            nn.SyncBatchNorm(2 * channels),
+            nn.ReLU(),
+            nn.Conv1d(2 * channels, channels, kernel_size=1),
+        )
+        self.mh_filter = nn.Conv1d(channels, channels, kernel_size=1)
+
+    def forward(self, fea1, fea2):
+        batch_size, n, m = fea1.shape[0], fea1.shape[2], fea2.shape[2]
+        query1, key1, value1 = (
+            self.query_filter(fea1).view(batch_size, self.head_dim, self.head, -1),
+            self.key_filter(fea1).view(batch_size, self.head_dim, self.head, -1),
+            self.value_filter(fea1).view(batch_size, self.head_dim, self.head, -1),
+        )
+        query2, key2, value2 = (
+            self.query_filter(fea2).view(batch_size, self.head_dim, self.head, -1),
+            self.key_filter(fea2).view(batch_size, self.head_dim, self.head, -1),
+            self.value_filter(fea2).view(batch_size, self.head_dim, self.head, -1),
+        )
+        if self.type == "self":
+            score1, score2 = torch.softmax(
+                torch.einsum("bdhn,bdhm->bhnm", query1, key1) / self.head_dim**0.5,
+                dim=-1,
+            ), torch.softmax(
+                torch.einsum("bdhn,bdhm->bhnm", query2, key2) / self.head_dim**0.5,
+                dim=-1,
+            )
+            add_value1, add_value2 = torch.einsum(
+                "bhnm,bdhm->bdhn", score1, value1
+            ), torch.einsum("bhnm,bdhm->bdhn", score2, value2)
         else:
-            score1,score2 = torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query1, key2) / self.head_dim ** 0.5,dim=-1), \
-                            torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query2, key1) / self.head_dim ** 0.5, dim=-1)
-            add_value1, add_value2 =torch.einsum('bhnm,bdhm->bdhn',score1,value2),torch.einsum('bhnm,bdhm->bdhn',score2,value1)
-        add_value1,add_value2=self.mh_filter(add_value1.contiguous().view(batch_size,self.head*self.head_dim,n)),self.mh_filter(add_value2.contiguous().view(batch_size,self.head*self.head_dim,m))
-        fea11, fea22 = torch.cat([fea1, add_value1], dim=1), torch.cat([fea2, add_value2], dim=1)
-        fea1, fea2 = fea1+self.attention_filter(fea11), fea2+self.attention_filter(fea22)
-     
-        return fea1,fea2
+            score1, score2 = torch.softmax(
+                torch.einsum("bdhn,bdhm->bhnm", query1, key2) / self.head_dim**0.5,
+                dim=-1,
+            ), torch.softmax(
+                torch.einsum("bdhn,bdhm->bhnm", query2, key1) / self.head_dim**0.5,
+                dim=-1,
+            )
+            add_value1, add_value2 = torch.einsum(
+                "bhnm,bdhm->bdhn", score1, value2
+            ), torch.einsum("bhnm,bdhm->bdhn", score2, value1)
+        add_value1, add_value2 = self.mh_filter(
+            add_value1.contiguous().view(batch_size, self.head * self.head_dim, n)
+        ), self.mh_filter(
+            add_value2.contiguous().view(batch_size, self.head * self.head_dim, m)
+        )
+        fea11, fea22 = torch.cat([fea1, add_value1], dim=1), torch.cat(
+            [fea2, add_value2], dim=1
+        )
+        fea1, fea2 = fea1 + self.attention_filter(fea11), fea2 + self.attention_filter(
+            fea22
+        )
+
+        return fea1, fea2
 
 
 class matcher(nn.Module):
     def __init__(self, config):
         nn.Module.__init__(self)
-        self.use_score_encoding=config.use_score_encoding
-        self.layer_num=config.layer_num
-        self.sink_iter=config.sink_iter
-        self.position_encoder = nn.Sequential(nn.Conv1d(3, 32, kernel_size=1) if config.use_score_encoding else nn.Conv1d(2, 32, kernel_size=1), 
-                                              nn.SyncBatchNorm(32), nn.ReLU(),
-                                              nn.Conv1d(32, 64, kernel_size=1), nn.SyncBatchNorm(64),nn.ReLU(),
-                                              nn.Conv1d(64, 128, kernel_size=1), nn.SyncBatchNorm(128), nn.ReLU(),
-                                              nn.Conv1d(128, 256, kernel_size=1), nn.SyncBatchNorm(256), nn.ReLU(),
-                                              nn.Conv1d(256, config.net_channels, kernel_size=1))
-       
-        self.dustbin=nn.Parameter(torch.tensor(1,dtype=torch.float32,device='cuda'))
-        self.self_attention_block=nn.Sequential(*[attention_block(config.net_channels,config.head,'self') for _ in range(config.layer_num)])
-        self.cross_attention_block=nn.Sequential(*[attention_block(config.net_channels,config.head,'cross') for _ in range(config.layer_num)])
-        self.final_project=nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1)
-
-    def forward(self,data,test_mode=True):
-        desc1, desc2 = data['desc1'], data['desc2']
-        desc1, desc2 = torch.nn.functional.normalize(desc1,dim=-1), torch.nn.functional.normalize(desc2,dim=-1)
-        desc1,desc2=desc1.transpose(1,2),desc2.transpose(1,2)   
+        self.use_score_encoding = config.use_score_encoding
+        self.layer_num = config.layer_num
+        self.sink_iter = config.sink_iter
+        self.position_encoder = nn.Sequential(
+            nn.Conv1d(3, 32, kernel_size=1)
+            if config.use_score_encoding
+            else nn.Conv1d(2, 32, kernel_size=1),
+            nn.SyncBatchNorm(32),
+            nn.ReLU(),
+            nn.Conv1d(32, 64, kernel_size=1),
+            nn.SyncBatchNorm(64),
+            nn.ReLU(),
+            nn.Conv1d(64, 128, kernel_size=1),
+            nn.SyncBatchNorm(128),
+            nn.ReLU(),
+            nn.Conv1d(128, 256, kernel_size=1),
+            nn.SyncBatchNorm(256),
+            nn.ReLU(),
+            nn.Conv1d(256, config.net_channels, kernel_size=1),
+        )
+
+        self.dustbin = nn.Parameter(torch.tensor(1, dtype=torch.float32, device="cuda"))
+        self.self_attention_block = nn.Sequential(
+            *[
+                attention_block(config.net_channels, config.head, "self")
+                for _ in range(config.layer_num)
+            ]
+        )
+        self.cross_attention_block = nn.Sequential(
+            *[
+                attention_block(config.net_channels, config.head, "cross")
+                for _ in range(config.layer_num)
+            ]
+        )
+        self.final_project = nn.Conv1d(
+            config.net_channels, config.net_channels, kernel_size=1
+        )
+
+    def forward(self, data, test_mode=True):
+        desc1, desc2 = data["desc1"], data["desc2"]
+        desc1, desc2 = torch.nn.functional.normalize(
+            desc1, dim=-1
+        ), torch.nn.functional.normalize(desc2, dim=-1)
+        desc1, desc2 = desc1.transpose(1, 2), desc2.transpose(1, 2)
         if test_mode:
-            encode_x1,encode_x2=data['x1'],data['x2']
+            encode_x1, encode_x2 = data["x1"], data["x2"]
         else:
-            encode_x1,encode_x2=data['aug_x1'], data['aug_x2']
+            encode_x1, encode_x2 = data["aug_x1"], data["aug_x2"]
         if not self.use_score_encoding:
-            encode_x1,encode_x2=encode_x1[:,:,:2],encode_x2[:,:,:2]
+            encode_x1, encode_x2 = encode_x1[:, :, :2], encode_x2[:, :, :2]
 
-        encode_x1,encode_x2=encode_x1.transpose(1,2),encode_x2.transpose(1,2)
+        encode_x1, encode_x2 = encode_x1.transpose(1, 2), encode_x2.transpose(1, 2)
 
-        x1_pos_embedding, x2_pos_embedding = self.position_encoder(encode_x1), self.position_encoder(encode_x2)
-        aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding+desc2
+        x1_pos_embedding, x2_pos_embedding = self.position_encoder(
+            encode_x1
+        ), self.position_encoder(encode_x2)
+        aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding + desc2
         for i in range(self.layer_num):
-            aug_desc1,aug_desc2=self.self_attention_block[i](aug_desc1,aug_desc2)
-            aug_desc1,aug_desc2=self.cross_attention_block[i](aug_desc1,aug_desc2)
+            aug_desc1, aug_desc2 = self.self_attention_block[i](aug_desc1, aug_desc2)
+            aug_desc1, aug_desc2 = self.cross_attention_block[i](aug_desc1, aug_desc2)
 
-        aug_desc1,aug_desc2=self.final_project(aug_desc1),self.final_project(aug_desc2)
+        aug_desc1, aug_desc2 = self.final_project(aug_desc1), self.final_project(
+            aug_desc2
+        )
         desc_mat = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2)
-        p = sink_algorithm(desc_mat, self.dustbin,self.sink_iter[0])
-        return {'p':p}
-
-
+        p = sink_algorithm(desc_mat, self.dustbin, self.sink_iter[0])
+        return {"p": p}
diff --git a/third_party/SGMNet/superpoint/__init__.py b/third_party/SGMNet/superpoint/__init__.py
index 111c8882a7bc7512c6191ca86a0e71c3b1404233..f1127dfc54047e2d0d877da1d3eb5c2ed569b85e 100644
--- a/third_party/SGMNet/superpoint/__init__.py
+++ b/third_party/SGMNet/superpoint/__init__.py
@@ -1 +1 @@
-from .superpoint import SuperPoint
\ No newline at end of file
+from .superpoint import SuperPoint
diff --git a/third_party/SGMNet/superpoint/superpoint.py b/third_party/SGMNet/superpoint/superpoint.py
index d4e3ce481409264a3188270ad01aa62b1614377f..38b839cbc731460e487c9359c6e0edcaec7be7c9 100644
--- a/third_party/SGMNet/superpoint/superpoint.py
+++ b/third_party/SGMNet/superpoint/superpoint.py
@@ -3,11 +3,12 @@ from torch import nn
 
 
 def simple_nms(scores, nms_radius):
-    assert(nms_radius >= 0)
+    assert nms_radius >= 0
 
     def max_pool(x):
         return torch.nn.functional.max_pool2d(
-            x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)
+            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
+        )
 
     zeros = torch.zeros_like(scores)
     max_mask = scores == max_pool(scores)
@@ -36,19 +37,21 @@ def top_k_keypoints(keypoints, scores, k):
 def sample_descriptors(keypoints, descriptors, s):
     b, c, h, w = descriptors.shape
     keypoints = keypoints - s / 2 + 0.5
-    keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
-                              ).to(keypoints)[None]
-    keypoints = keypoints*2 - 1  # normalize to (-1, 1)
-    args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
+    keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],).to(
+        keypoints
+    )[None]
+    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
+    args = {"align_corners": True} if int(torch.__version__[2]) > 2 else {}
     descriptors = torch.nn.functional.grid_sample(
-        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
+        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
+    )
     descriptors = torch.nn.functional.normalize(
-        descriptors.reshape(b, c, -1), p=2, dim=1)
+        descriptors.reshape(b, c, -1), p=2, dim=1
+    )
     return descriptors
 
 
 class SuperPoint(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.config = {**config}
@@ -71,16 +74,16 @@ class SuperPoint(nn.Module):
 
         self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
         self.convDb = nn.Conv2d(
-            c5, self.config['descriptor_dim'],
-            kernel_size=1, stride=1, padding=0)
+            c5, self.config["descriptor_dim"], kernel_size=1, stride=1, padding=0
+        )
 
-        self.load_state_dict(torch.load(config['model_path']))
+        self.load_state_dict(torch.load(config["model_path"]))
 
-        mk = self.config['max_keypoints']
+        mk = self.config["max_keypoints"]
         if mk == 0 or mk < -1:
-            raise ValueError('\"max_keypoints\" must be positive or \"-1\"')
+            raise ValueError('"max_keypoints" must be positive or "-1"')
 
-        print('Loaded SuperPoint model')
+        print("Loaded SuperPoint model")
 
     def forward(self, data):
         # Shared Encoder
@@ -101,25 +104,35 @@ class SuperPoint(nn.Module):
         scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
         b, c, h, w = scores.shape
         scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
-        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
-        scores = simple_nms(scores, self.config['nms_radius'])
+        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
+        scores = simple_nms(scores, self.config["nms_radius"])
 
         # Extract keypoints
         keypoints = [
-            torch.nonzero(s > self.config['detection_threshold'])
-            for s in scores]
+            torch.nonzero(s > self.config["detection_threshold"]) for s in scores
+        ]
         scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
 
         # Discard keypoints near the image borders
-        keypoints, scores = list(zip(*[
-            remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
-            for k, s in zip(keypoints, scores)]))
+        keypoints, scores = list(
+            zip(
+                *[
+                    remove_borders(k, s, self.config["remove_borders"], h * 8, w * 8)
+                    for k, s in zip(keypoints, scores)
+                ]
+            )
+        )
 
         # Keep the k keypoints with highest score
-        if self.config['max_keypoints'] >= 0:
-            keypoints, scores = list(zip(*[
-                top_k_keypoints(k, s, self.config['max_keypoints'])
-                for k, s in zip(keypoints, scores)]))
+        if self.config["max_keypoints"] >= 0:
+            keypoints, scores = list(
+                zip(
+                    *[
+                        top_k_keypoints(k, s, self.config["max_keypoints"])
+                        for k, s in zip(keypoints, scores)
+                    ]
+                )
+            )
 
         # Convert (h, w) to (x, y)
         keypoints = [torch.flip(k, [1]).float() for k in keypoints]
@@ -130,11 +143,13 @@ class SuperPoint(nn.Module):
         descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
 
         # Extract descriptors
-        descriptors = [sample_descriptors(k[None], d[None], 8)[0]
-                       for k, d in zip(keypoints, descriptors)]
+        descriptors = [
+            sample_descriptors(k[None], d[None], 8)[0]
+            for k, d in zip(keypoints, descriptors)
+        ]
 
         return {
-            'keypoints': keypoints,
-            'scores': scores,
-            'descriptors': descriptors,
+            "keypoints": keypoints,
+            "scores": scores,
+            "descriptors": descriptors,
         }
diff --git a/third_party/SGMNet/train/config.py b/third_party/SGMNet/train/config.py
index 31c4c1c6deef3d6dd568897f4202d96456586376..3610e40ff0628b1c5c4a2bc2a73d38a6d2cd65b1 100644
--- a/third_party/SGMNet/train/config.py
+++ b/third_party/SGMNet/train/config.py
@@ -1,5 +1,6 @@
 import argparse
 
+
 def str2bool(v):
     return v.lower() in ("true", "1")
 
@@ -18,102 +19,111 @@ def add_argument_group(name):
 # Network
 net_arg = add_argument_group("Network")
 net_arg.add_argument(
-    "--model_name", type=str,default='SGM', help=""
-    "model for training")
+    "--model_name", type=str, default="SGM", help="" "model for training"
+)
 net_arg.add_argument(
-    "--config_path", type=str,default='configs/sgm.yaml', help=""
-    "config path for model")
+    "--config_path",
+    type=str,
+    default="configs/sgm.yaml",
+    help="" "config path for model",
+)
 
 # -----------------------------------------------------------------------------
 # Data
 data_arg = add_argument_group("Data")
 data_arg.add_argument(
-    "--rawdata_path", type=str, default='rawdata', help=""
-    "path for rawdata")
+    "--rawdata_path", type=str, default="rawdata", help="" "path for rawdata"
+)
 data_arg.add_argument(
-    "--dataset_path", type=str, default='dataset', help=""
-    "path for dataset")
+    "--dataset_path", type=str, default="dataset", help="" "path for dataset"
+)
 data_arg.add_argument(
-    "--desc_path", type=str, default='desc', help=""
-    "path for descriptor(kpt) dir")
+    "--desc_path", type=str, default="desc", help="" "path for descriptor(kpt) dir"
+)
 data_arg.add_argument(
-    "--num_kpt", type=int, default=1000, help=""
-    "number of kpt for training")
+    "--num_kpt", type=int, default=1000, help="" "number of kpt for training"
+)
 data_arg.add_argument(
-    "--input_normalize", type=str, default='img', help=""
-    "normalize type for input kpt, img or intrinsic")
+    "--input_normalize",
+    type=str,
+    default="img",
+    help="" "normalize type for input kpt, img or intrinsic",
+)
 data_arg.add_argument(
-    "--data_aug", type=str2bool, default=True, help=""
-    "apply kpt coordinate homography augmentation")
+    "--data_aug",
+    type=str2bool,
+    default=True,
+    help="" "apply kpt coordinate homography augmentation",
+)
 data_arg.add_argument(
-    "--desc_suffix", type=str, default='suffix', help=""
-    "desc file suffix")
+    "--desc_suffix", type=str, default="suffix", help="" "desc file suffix"
+)
 
 
 # -----------------------------------------------------------------------------
 # Loss
 loss_arg = add_argument_group("loss")
+loss_arg.add_argument("--momentum", type=float, default=0.9, help="" "momentum")
 loss_arg.add_argument(
-    "--momentum", type=float, default=0.9, help=""
-    "momentum")
-loss_arg.add_argument(
-    "--seed_loss_weight", type=float, default=250, help=""
-    "confidence loss weight for sgm")
+    "--seed_loss_weight",
+    type=float,
+    default=250,
+    help="" "confidence loss weight for sgm",
+)
 loss_arg.add_argument(
-    "--mid_loss_weight", type=float, default=1, help=""
-    "midseeding loss weight for sgm")
+    "--mid_loss_weight", type=float, default=1, help="" "midseeding loss weight for sgm"
+)
 loss_arg.add_argument(
-    "--inlier_th", type=float, default=5e-3, help=""
-    "inlier threshold for epipolar distance (for sgm and visualization)")
+    "--inlier_th",
+    type=float,
+    default=5e-3,
+    help="" "inlier threshold for epipolar distance (for sgm and visualization)",
+)
 
 
 # -----------------------------------------------------------------------------
 # Training
 train_arg = add_argument_group("Train")
+train_arg.add_argument("--train_lr", type=float, default=1e-4, help="" "learning rate")
+train_arg.add_argument("--train_batch_size", type=int, default=16, help="" "batch size")
 train_arg.add_argument(
-    "--train_lr", type=float, default=1e-4, help=""
-    "learning rate")
-train_arg.add_argument(
-    "--train_batch_size", type=int, default=16, help=""
-    "batch size")
-train_arg.add_argument(
-    "--gpu_id", type=str,default='0', help='id(s) for CUDA_VISIBLE_DEVICES')
-train_arg.add_argument(
-    "--train_iter", type=int, default=1000000, help=""
-    "training iterations to perform")
-train_arg.add_argument(
-    "--log_base", type=str, default="./log/", help=""
-    "log path")
+    "--gpu_id", type=str, default="0", help="id(s) for CUDA_VISIBLE_DEVICES"
+)
 train_arg.add_argument(
-    "--val_intv", type=int, default=20000, help=""
-    "validation interval")
+    "--train_iter", type=int, default=1000000, help="" "training iterations to perform"
+)
+train_arg.add_argument("--log_base", type=str, default="./log/", help="" "log path")
 train_arg.add_argument(
-    "--save_intv", type=int, default=1000, help=""
-    "summary interval")
+    "--val_intv", type=int, default=20000, help="" "validation interval"
+)
 train_arg.add_argument(
-    "--log_intv", type=int, default=100, help=""
-    "log interval")
+    "--save_intv", type=int, default=1000, help="" "summary interval"
+)
+train_arg.add_argument("--log_intv", type=int, default=100, help="" "log interval")
 train_arg.add_argument(
-    "--decay_rate", type=float, default=0.999996, help=""
-    "lr decay rate")
+    "--decay_rate", type=float, default=0.999996, help="" "lr decay rate"
+)
 train_arg.add_argument(
-    "--decay_iter", type=float, default=300000, help=""
-    "lr decay iter")
+    "--decay_iter", type=float, default=300000, help="" "lr decay iter"
+)
 train_arg.add_argument(
-    "--local_rank", type=int, default=0, help=""
-    "local rank for ddp")
+    "--local_rank", type=int, default=0, help="" "local rank for ddp"
+)
 train_arg.add_argument(
-    "--train_vis_folder", type=str, default='.', help=""
-    "visualization folder during training")
+    "--train_vis_folder",
+    type=str,
+    default=".",
+    help="" "visualization folder during training",
+)
 
 # -----------------------------------------------------------------------------
 # Visualization
-vis_arg = add_argument_group('Visualization')
+vis_arg = add_argument_group("Visualization")
 vis_arg.add_argument(
-    "--tqdm_width", type=int, default=79, help=""
-    "width of the tqdm bar"
+    "--tqdm_width", type=int, default=79, help="" "width of the tqdm bar"
 )
 
+
 def get_config():
     config, unparsed = parser.parse_known_args()
     return config, unparsed
@@ -122,5 +132,6 @@ def get_config():
 def print_usage():
     parser.print_usage()
 
+
 #
-# config.py ends here
\ No newline at end of file
+# config.py ends here
diff --git a/third_party/SGMNet/train/dataset.py b/third_party/SGMNet/train/dataset.py
index d07a84e9588b755a86119363f08860187d1668c0..37a97fd6204240e636d4b234f6c855f948c76b99 100644
--- a/third_party/SGMNet/train/dataset.py
+++ b/third_party/SGMNet/train/dataset.py
@@ -7,137 +7,278 @@ import h5py
 import random
 
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
 sys.path.insert(0, ROOT_DIR)
 
-from utils import train_utils,evaluation_utils
+from utils import train_utils, evaluation_utils
 
-torch.multiprocessing.set_sharing_strategy('file_system')
+torch.multiprocessing.set_sharing_strategy("file_system")
 
 
 class Offline_Dataset(data.Dataset):
-    def __init__(self,config,mode):
-        assert mode=='train' or mode=='valid'
+    def __init__(self, config, mode):
+        assert mode == "train" or mode == "valid"
 
         self.config = config
         self.mode = mode
-        metadir=os.path.join(config.dataset_path,'valid') if mode=='valid' else os.path.join(config.dataset_path,'train')
-        
-        pair_num_list=np.loadtxt(os.path.join(metadir,'pair_num.txt'),dtype=str)
-        self.total_pairs=int(pair_num_list[0,1])
-        self.pair_seq_list,self.accu_pair_num=train_utils.parse_pair_seq(pair_num_list)
+        metadir = (
+            os.path.join(config.dataset_path, "valid")
+            if mode == "valid"
+            else os.path.join(config.dataset_path, "train")
+        )
 
+        pair_num_list = np.loadtxt(os.path.join(metadir, "pair_num.txt"), dtype=str)
+        self.total_pairs = int(pair_num_list[0, 1])
+        self.pair_seq_list, self.accu_pair_num = train_utils.parse_pair_seq(
+            pair_num_list
+        )
 
     def collate_fn(self, batch):
-        batch_size, num_pts = len(batch), batch[0]['x1'].shape[0]
-        
+        batch_size, num_pts = len(batch), batch[0]["x1"].shape[0]
+
         data = {}
-        dtype=['x1','x2','kpt1','kpt2','desc1','desc2','num_corr','num_incorr1','num_incorr2','e_gt','pscore1','pscore2','img_path1','img_path2']
+        dtype = [
+            "x1",
+            "x2",
+            "kpt1",
+            "kpt2",
+            "desc1",
+            "desc2",
+            "num_corr",
+            "num_incorr1",
+            "num_incorr2",
+            "e_gt",
+            "pscore1",
+            "pscore2",
+            "img_path1",
+            "img_path2",
+        ]
         for key in dtype:
-            data[key]=[]
+            data[key] = []
         for sample in batch:
             for key in dtype:
                 data[key].append(sample[key])
-           
-        for key in ['x1', 'x2','kpt1','kpt2', 'desc1', 'desc2','e_gt','pscore1','pscore2']:
+
+        for key in [
+            "x1",
+            "x2",
+            "kpt1",
+            "kpt2",
+            "desc1",
+            "desc2",
+            "e_gt",
+            "pscore1",
+            "pscore2",
+        ]:
             data[key] = torch.from_numpy(np.stack(data[key])).float()
-        for key in ['num_corr', 'num_incorr1', 'num_incorr2']:
+        for key in ["num_corr", "num_incorr1", "num_incorr2"]:
             data[key] = torch.from_numpy(np.stack(data[key])).int()
 
         # kpt augmentation with random homography
-        if (self.mode == 'train' and self.config.data_aug):
-            homo_mat = torch.from_numpy(train_utils.get_rnd_homography(batch_size)).unsqueeze(1)
-            aug_seed=random.random() 
-            if aug_seed<0.5:
-                x1_homo = torch.cat([data['x1'], torch.ones([batch_size, num_pts, 1])], dim=-1).unsqueeze(-1)
+        if self.mode == "train" and self.config.data_aug:
+            homo_mat = torch.from_numpy(
+                train_utils.get_rnd_homography(batch_size)
+            ).unsqueeze(1)
+            aug_seed = random.random()
+            if aug_seed < 0.5:
+                x1_homo = torch.cat(
+                    [data["x1"], torch.ones([batch_size, num_pts, 1])], dim=-1
+                ).unsqueeze(-1)
                 x1_homo = torch.matmul(homo_mat.float(), x1_homo.float()).squeeze(-1)
-                data['aug_x1'] = x1_homo[:, :, :2] / x1_homo[:, :, 2].unsqueeze(-1)
-                data['aug_x2']=data['x2']
+                data["aug_x1"] = x1_homo[:, :, :2] / x1_homo[:, :, 2].unsqueeze(-1)
+                data["aug_x2"] = data["x2"]
             else:
-                x2_homo = torch.cat([data['x2'], torch.ones([batch_size, num_pts, 1])], dim=-1).unsqueeze(-1)
+                x2_homo = torch.cat(
+                    [data["x2"], torch.ones([batch_size, num_pts, 1])], dim=-1
+                ).unsqueeze(-1)
                 x2_homo = torch.matmul(homo_mat.float(), x2_homo.float()).squeeze(-1)
-                data['aug_x2'] = x2_homo[:, :, :2] / x2_homo[:, :, 2].unsqueeze(-1)
-                data['aug_x1']=data['x1']
+                data["aug_x2"] = x2_homo[:, :, :2] / x2_homo[:, :, 2].unsqueeze(-1)
+                data["aug_x1"] = data["x1"]
         else:
-            data['aug_x1'],data['aug_x2']=data['x1'],data['x2']
+            data["aug_x1"], data["aug_x2"] = data["x1"], data["x2"]
         return data
 
-
     def __getitem__(self, index):
-        seq=self.pair_seq_list[index]
-        index_within_seq=index-self.accu_pair_num[seq]
+        seq = self.pair_seq_list[index]
+        index_within_seq = index - self.accu_pair_num[seq]
 
-        with h5py.File(os.path.join(self.config.dataset_path,seq,'info.h5py'),'r') as data:
-            R,t = data['dR'][str(index_within_seq)][()], data['dt'][str(index_within_seq)][()]
-            egt = np.reshape(np.matmul(np.reshape(evaluation_utils.np_skew_symmetric(t.astype('float64').reshape(1, 3)), (3, 3)),np.reshape(R.astype('float64'), (3, 3))), (3, 3))
+        with h5py.File(
+            os.path.join(self.config.dataset_path, seq, "info.h5py"), "r"
+        ) as data:
+            R, t = (
+                data["dR"][str(index_within_seq)][()],
+                data["dt"][str(index_within_seq)][()],
+            )
+            egt = np.reshape(
+                np.matmul(
+                    np.reshape(
+                        evaluation_utils.np_skew_symmetric(
+                            t.astype("float64").reshape(1, 3)
+                        ),
+                        (3, 3),
+                    ),
+                    np.reshape(R.astype("float64"), (3, 3)),
+                ),
+                (3, 3),
+            )
             egt = egt / np.linalg.norm(egt)
-            K1, K2 = data['K1'][str(index_within_seq)][()],data['K2'][str(index_within_seq)][()]
-            size1,size2=data['size1'][str(index_within_seq)][()],data['size2'][str(index_within_seq)][()]
-
-            img_path1,img_path2=data['img_path1'][str(index_within_seq)][()][0].decode(),data['img_path2'][str(index_within_seq)][()][0].decode()
-            img_name1,img_name2=img_path1.split('/')[-1],img_path2.split('/')[-1]
-            img_path1,img_path2=os.path.join(self.config.rawdata_path,img_path1),os.path.join(self.config.rawdata_path,img_path2)
-            fea_path1,fea_path2=os.path.join(self.config.desc_path,seq,img_name1+self.config.desc_suffix),\
-                                os.path.join(self.config.desc_path,seq,img_name2+self.config.desc_suffix)
-            with h5py.File(fea_path1,'r') as fea1, h5py.File(fea_path2,'r') as fea2:
-                desc1,kpt1,pscore1=fea1['descriptors'][()],fea1['keypoints'][()][:,:2],fea1['keypoints'][()][:,2]
-                desc2,kpt2,pscore2=fea2['descriptors'][()],fea2['keypoints'][()][:,:2],fea2['keypoints'][()][:,2]
-                kpt1,kpt2,desc1,desc2=kpt1[:self.config.num_kpt],kpt2[:self.config.num_kpt],desc1[:self.config.num_kpt],desc2[:self.config.num_kpt]
+            K1, K2 = (
+                data["K1"][str(index_within_seq)][()],
+                data["K2"][str(index_within_seq)][()],
+            )
+            size1, size2 = (
+                data["size1"][str(index_within_seq)][()],
+                data["size2"][str(index_within_seq)][()],
+            )
+
+            img_path1, img_path2 = (
+                data["img_path1"][str(index_within_seq)][()][0].decode(),
+                data["img_path2"][str(index_within_seq)][()][0].decode(),
+            )
+            img_name1, img_name2 = img_path1.split("/")[-1], img_path2.split("/")[-1]
+            img_path1, img_path2 = os.path.join(
+                self.config.rawdata_path, img_path1
+            ), os.path.join(self.config.rawdata_path, img_path2)
+            fea_path1, fea_path2 = os.path.join(
+                self.config.desc_path, seq, img_name1 + self.config.desc_suffix
+            ), os.path.join(
+                self.config.desc_path, seq, img_name2 + self.config.desc_suffix
+            )
+            with h5py.File(fea_path1, "r") as fea1, h5py.File(fea_path2, "r") as fea2:
+                desc1, kpt1, pscore1 = (
+                    fea1["descriptors"][()],
+                    fea1["keypoints"][()][:, :2],
+                    fea1["keypoints"][()][:, 2],
+                )
+                desc2, kpt2, pscore2 = (
+                    fea2["descriptors"][()],
+                    fea2["keypoints"][()][:, :2],
+                    fea2["keypoints"][()][:, 2],
+                )
+                kpt1, kpt2, desc1, desc2 = (
+                    kpt1[: self.config.num_kpt],
+                    kpt2[: self.config.num_kpt],
+                    desc1[: self.config.num_kpt],
+                    desc2[: self.config.num_kpt],
+                )
 
             # normalize kpt
-            if self.config.input_normalize=='intrinsic':
-                x1, x2 = np.concatenate([kpt1, np.ones([kpt1.shape[0], 1])], axis=-1), np.concatenate(
-                    [kpt2, np.ones([kpt2.shape[0], 1])], axis=-1)
-                x1, x2 = np.matmul(np.linalg.inv(K1), x1.T).T[:, :2], np.matmul(np.linalg.inv(K2), x2.T).T[:, :2]
-            elif self.config.input_normalize=='img' :   
-                x1,x2=(kpt1-size1/2)/size1,(kpt2-size2/2)/size2
-                S1_inv,S2_inv=np.asarray([[size1[0],0,0.5*size1[0]],[0,size1[1],0.5*size1[1]],[0,0,1]]),\
-                            np.asarray([[size2[0],0,0.5*size2[0]],[0,size2[1],0.5*size2[1]],[0,0,1]])
-                M1,M2=np.matmul(np.linalg.inv(K1),S1_inv),np.matmul(np.linalg.inv(K2),S2_inv)
-                egt=np.matmul(np.matmul(M2.transpose(),egt),M1)
+            if self.config.input_normalize == "intrinsic":
+                x1, x2 = np.concatenate(
+                    [kpt1, np.ones([kpt1.shape[0], 1])], axis=-1
+                ), np.concatenate([kpt2, np.ones([kpt2.shape[0], 1])], axis=-1)
+                x1, x2 = (
+                    np.matmul(np.linalg.inv(K1), x1.T).T[:, :2],
+                    np.matmul(np.linalg.inv(K2), x2.T).T[:, :2],
+                )
+            elif self.config.input_normalize == "img":
+                x1, x2 = (kpt1 - size1 / 2) / size1, (kpt2 - size2 / 2) / size2
+                S1_inv, S2_inv = np.asarray(
+                    [
+                        [size1[0], 0, 0.5 * size1[0]],
+                        [0, size1[1], 0.5 * size1[1]],
+                        [0, 0, 1],
+                    ]
+                ), np.asarray(
+                    [
+                        [size2[0], 0, 0.5 * size2[0]],
+                        [0, size2[1], 0.5 * size2[1]],
+                        [0, 0, 1],
+                    ]
+                )
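+                # fold the image normalization into the GT essential matrix (E <- M2^T @ E @ M1) so it matches the normalized x1/x2, then rescale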
+                M1, M2 = np.matmul(np.linalg.inv(K1), S1_inv), np.matmul(
+                    np.linalg.inv(K2), S2_inv
+                )
+                egt = np.matmul(np.matmul(M2.transpose(), egt), M1)
                 egt = egt / np.linalg.norm(egt)
             else:
                 raise NotImplementedError
 
-            corr=data['corr'][str(index_within_seq)][()]
-            incorr1,incorr2=data['incorr1'][str(index_within_seq)][()],data['incorr2'][str(index_within_seq)][()]
-            
-        #permute kpt
-        valid_corr=corr[corr.max(axis=-1)<self.config.num_kpt]
-        valid_incorr1,valid_incorr2=incorr1[incorr1<self.config.num_kpt],incorr2[incorr2<self.config.num_kpt]
-        num_corr, num_incorr1, num_incorr2 = len(valid_corr), len(valid_incorr1), len(valid_incorr2)
-        mask1_invlaid, mask2_invalid = np.ones(x1.shape[0]).astype(bool), np.ones(x2.shape[0]).astype(bool)
+            corr = data["corr"][str(index_within_seq)][()]
+            incorr1, incorr2 = (
+                data["incorr1"][str(index_within_seq)][()],
+                data["incorr2"][str(index_within_seq)][()],
+            )
+
+        # permute kpt
+        valid_corr = corr[corr.max(axis=-1) < self.config.num_kpt]
+        valid_incorr1, valid_incorr2 = (
+            incorr1[incorr1 < self.config.num_kpt],
+            incorr2[incorr2 < self.config.num_kpt],
+        )
+        num_corr, num_incorr1, num_incorr2 = (
+            len(valid_corr),
+            len(valid_incorr1),
+            len(valid_incorr2),
+        )
+        mask1_invlaid, mask2_invalid = np.ones(x1.shape[0]).astype(bool), np.ones(
+            x2.shape[0]
+        ).astype(bool)
         mask1_invlaid[valid_corr[:, 0]] = False
         mask2_invalid[valid_corr[:, 1]] = False
         mask1_invlaid[valid_incorr1] = False
         mask2_invalid[valid_incorr2] = False
-        invalid_index1,invalid_index2=np.nonzero(mask1_invlaid)[0],np.nonzero(mask2_invalid)[0]
+        invalid_index1, invalid_index2 = (
+            np.nonzero(mask1_invlaid)[0],
+            np.nonzero(mask2_invalid)[0],
+        )
 
-        #random sample from point w/o valid annotation 
+        # random sample from point w/o valid annotation
         cur_kpt1 = self.config.num_kpt - num_corr - num_incorr1
         cur_kpt2 = self.config.num_kpt - num_corr - num_incorr2
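+        # fill the remaining slots with randomly sampled unlabelled keypoints (repeated if too few), so every sample has exactly num_kpt points ordered [corr | incorr | rest]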
 
-        if (invalid_index1.shape[0] < cur_kpt1):
-            sub_idx1 = np.concatenate([np.arange(len(invalid_index1)),np.random.randint(len(invalid_index1),size=cur_kpt1-len(invalid_index1))])
-        if (invalid_index1.shape[0] >= cur_kpt1):
-            sub_idx1 =np.random.choice(len(invalid_index1), cur_kpt1,replace=False)
-        if (invalid_index2.shape[0] < cur_kpt2):
-            sub_idx2 = np.concatenate([np.arange(len(invalid_index2)),np.random.randint(len(invalid_index2),size=cur_kpt2-len(invalid_index2))])
-        if (invalid_index2.shape[0] >= cur_kpt2):
-            sub_idx2 = np.random.choice(len(invalid_index2), cur_kpt2,replace=False)
-        
-        per_idx1,per_idx2=np.concatenate([valid_corr[:,0],valid_incorr1,invalid_index1[sub_idx1]]),\
-                          np.concatenate([valid_corr[:,1],valid_incorr2,invalid_index2[sub_idx2]])
-        
-        pscore1,pscore2=pscore1[per_idx1][:,np.newaxis],pscore2[per_idx2][:,np.newaxis]
-        x1,x2=x1[per_idx1][:,:2],x2[per_idx2][:,:2]
-        desc1,desc2=desc1[per_idx1],desc2[per_idx2]
-        kpt1,kpt2=kpt1[per_idx1],kpt2[per_idx2]
-        
-        return {'x1': x1, 'x2': x2, 'kpt1':kpt1,'kpt2':kpt2,'desc1': desc1, 'desc2': desc2, 'num_corr': num_corr, 'num_incorr1': num_incorr1,'num_incorr2': num_incorr2,'e_gt':egt,\
-                'pscore1':pscore1,'pscore2':pscore2,'img_path1':img_path1,'img_path2':img_path2}
+        if invalid_index1.shape[0] < cur_kpt1:
+            sub_idx1 = np.concatenate(
+                [
+                    np.arange(len(invalid_index1)),
+                    np.random.randint(
+                        len(invalid_index1), size=cur_kpt1 - len(invalid_index1)
+                    ),
+                ]
+            )
+        if invalid_index1.shape[0] >= cur_kpt1:
+            sub_idx1 = np.random.choice(len(invalid_index1), cur_kpt1, replace=False)
+        if invalid_index2.shape[0] < cur_kpt2:
+            sub_idx2 = np.concatenate(
+                [
+                    np.arange(len(invalid_index2)),
+                    np.random.randint(
+                        len(invalid_index2), size=cur_kpt2 - len(invalid_index2)
+                    ),
+                ]
+            )
+        if invalid_index2.shape[0] >= cur_kpt2:
+            sub_idx2 = np.random.choice(len(invalid_index2), cur_kpt2, replace=False)
 
-    def __len__(self):
-        return self.total_pairs
+        per_idx1, per_idx2 = np.concatenate(
+            [valid_corr[:, 0], valid_incorr1, invalid_index1[sub_idx1]]
+        ), np.concatenate([valid_corr[:, 1], valid_incorr2, invalid_index2[sub_idx2]])
+
+        pscore1, pscore2 = (
+            pscore1[per_idx1][:, np.newaxis],
+            pscore2[per_idx2][:, np.newaxis],
+        )
+        x1, x2 = x1[per_idx1][:, :2], x2[per_idx2][:, :2]
+        desc1, desc2 = desc1[per_idx1], desc2[per_idx2]
+        kpt1, kpt2 = kpt1[per_idx1], kpt2[per_idx2]
 
+        return {
+            "x1": x1,
+            "x2": x2,
+            "kpt1": kpt1,
+            "kpt2": kpt2,
+            "desc1": desc1,
+            "desc2": desc2,
+            "num_corr": num_corr,
+            "num_incorr1": num_incorr1,
+            "num_incorr2": num_incorr2,
+            "e_gt": egt,
+            "pscore1": pscore1,
+            "pscore2": pscore2,
+            "img_path1": img_path1,
+            "img_path2": img_path2,
+        }
 
+    def __len__(self):
+        return self.total_pairs
diff --git a/third_party/SGMNet/train/loss.py b/third_party/SGMNet/train/loss.py
index fad4234fc5827321c31e72c08ad4a3466bad1c30..227f7c5d237be292e552a25ea899940ec54fc923 100644
--- a/third_party/SGMNet/train/loss.py
+++ b/third_party/SGMNet/train/loss.py
@@ -4,122 +4,195 @@ import numpy as np
 
 def batch_episym(x1, x2, F):
     batch_size, num_pts = x1.shape[0], x1.shape[1]
-    x1 = torch.cat([x1, x1.new_ones(batch_size, num_pts,1)], dim=-1).reshape(batch_size, num_pts,3,1)
-    x2 = torch.cat([x2, x2.new_ones(batch_size, num_pts,1)], dim=-1).reshape(batch_size, num_pts,3,1)
-    F = F.reshape(-1,1,3,3).repeat(1,num_pts,1,1)
-    x2Fx1 = torch.matmul(x2.transpose(2,3), torch.matmul(F, x1)).reshape(batch_size,num_pts)
-    Fx1 = torch.matmul(F,x1).reshape(batch_size,num_pts,3)
-    Ftx2 = torch.matmul(F.transpose(2,3),x2).reshape(batch_size,num_pts,3)
-    ys = (x2Fx1**2 * (
-            1.0 / (Fx1[:, :, 0]**2 + Fx1[:, :, 1]**2 + 1e-15) +
-            1.0 / (Ftx2[:, :, 0]**2 + Ftx2[:, :, 1]**2 + 1e-15))).sqrt()
+    x1 = torch.cat([x1, x1.new_ones(batch_size, num_pts, 1)], dim=-1).reshape(
+        batch_size, num_pts, 3, 1
+    )
+    x2 = torch.cat([x2, x2.new_ones(batch_size, num_pts, 1)], dim=-1).reshape(
+        batch_size, num_pts, 3, 1
+    )
+    F = F.reshape(-1, 1, 3, 3).repeat(1, num_pts, 1, 1)
+    x2Fx1 = torch.matmul(x2.transpose(2, 3), torch.matmul(F, x1)).reshape(
+        batch_size, num_pts
+    )
+    Fx1 = torch.matmul(F, x1).reshape(batch_size, num_pts, 3)
+    Ftx2 = torch.matmul(F.transpose(2, 3), x2).reshape(batch_size, num_pts, 3)
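+    # symmetric epipolar distance: |x2^T F x1| * sqrt(1/((Fx1)_0^2 + (Fx1)_1^2) + 1/((F^T x2)_0^2 + (F^T x2)_1^2))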
+    ys = (
+        x2Fx1**2
+        * (
+            1.0 / (Fx1[:, :, 0] ** 2 + Fx1[:, :, 1] ** 2 + 1e-15)
+            + 1.0 / (Ftx2[:, :, 0] ** 2 + Ftx2[:, :, 1] ** 2 + 1e-15)
+        )
+    ).sqrt()
     return ys
-    
-
-def CELoss(seed_x1,seed_x2,e,confidence,inlier_th,batch_mask=1):
-    #seed_x: b*k*2
-    ys=batch_episym(seed_x1,seed_x2,e)
-    mask_pos,mask_neg=(ys<=inlier_th).float(),(ys>inlier_th).float()
-    num_pos,num_neg=torch.relu(torch.sum(mask_pos, dim=1) - 1.0) + 1.0,torch.relu(torch.sum(mask_neg, dim=1) - 1.0) + 1.0
-    loss_pos,loss_neg=-torch.log(abs(confidence) + 1e-8)*mask_pos,-torch.log(abs(1-confidence)+1e-8)*mask_neg
-    classif_loss = torch.mean(loss_pos * 0.5 / num_pos.unsqueeze(-1) + loss_neg * 0.5 / num_neg.unsqueeze(-1),dim=-1)
-    classif_loss =classif_loss*batch_mask
-    classif_loss=classif_loss.mean()
+
+
+def CELoss(seed_x1, seed_x2, e, confidence, inlier_th, batch_mask=1):
+    # seed_x: b*k*2
+    ys = batch_episym(seed_x1, seed_x2, e)
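+    # seeds whose epipolar distance is <= inlier_th count as positives, the rest as negatives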
+    mask_pos, mask_neg = (ys <= inlier_th).float(), (ys > inlier_th).float()
+    num_pos, num_neg = (
+        torch.relu(torch.sum(mask_pos, dim=1) - 1.0) + 1.0,
+        torch.relu(torch.sum(mask_neg, dim=1) - 1.0) + 1.0,
+    )
+    loss_pos, loss_neg = (
+        -torch.log(abs(confidence) + 1e-8) * mask_pos,
+        -torch.log(abs(1 - confidence) + 1e-8) * mask_neg,
+    )
+    classif_loss = torch.mean(
+        loss_pos * 0.5 / num_pos.unsqueeze(-1) + loss_neg * 0.5 / num_neg.unsqueeze(-1),
+        dim=-1,
+    )
+    classif_loss = classif_loss * batch_mask
+    classif_loss = classif_loss.mean()
     precision = torch.mean(
-        torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1) /
-        (torch.sum((confidence > 0.5).type(confidence.type()), dim=1)+1e-8)
+        torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1)
+        / (torch.sum((confidence > 0.5).type(confidence.type()), dim=1) + 1e-8)
     )
     recall = torch.mean(
-        torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1) /
-        num_pos
+        torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1)
+        / num_pos
     )
-    return classif_loss,precision,recall
+    return classif_loss, precision, recall
 
 
-def CorrLoss(desc_mat,batch_num_corr,batch_num_incorr1,batch_num_incorr2):
-    total_loss_corr,total_loss_incorr=0,0
-    total_acc_corr,total_acc_incorr=0,0
+def CorrLoss(desc_mat, batch_num_corr, batch_num_incorr1, batch_num_incorr2):
+    total_loss_corr, total_loss_incorr = 0, 0
+    total_acc_corr, total_acc_incorr = 0, 0
     batch_size = desc_mat.shape[0]
-    log_p=torch.log(abs(desc_mat)+1e-8)
+    log_p = torch.log(abs(desc_mat) + 1e-8)
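+    # desc_mat is the assignment matrix with a dustbin row/column; the first num_corr diagonal entries correspond to the GT matches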
 
     for i in range(batch_size):
-        cur_log_p=log_p[i]
-        num_corr=batch_num_corr[i]
-        num_incorr1,num_incorr2=batch_num_incorr1[i],batch_num_incorr2[i]
-     
-        #loss and acc
+        cur_log_p = log_p[i]
+        num_corr = batch_num_corr[i]
+        num_incorr1, num_incorr2 = batch_num_incorr1[i], batch_num_incorr2[i]
+
+        # loss and acc
         loss_corr = -torch.diag(cur_log_p)[:num_corr].mean()
-        loss_incorr=(-cur_log_p[num_corr:num_corr+num_incorr1,-1].mean()-cur_log_p[-1,num_corr:num_corr+num_incorr2].mean())/2
+        loss_incorr = (
+            -cur_log_p[num_corr : num_corr + num_incorr1, -1].mean()
+            - cur_log_p[-1, num_corr : num_corr + num_incorr2].mean()
+        ) / 2
 
-        value_row, row_index = torch.max(desc_mat[i,:-1,:-1], dim=-1)
-        value_col, col_index = torch.max(desc_mat[i,:-1,:-1], dim=-2)
-        acc_incorr=((value_row[num_corr:num_corr+num_incorr1]<0.2).float().mean()+
-                    (value_col[num_corr:num_corr+num_incorr2]<0.2).float().mean())/2
+        value_row, row_index = torch.max(desc_mat[i, :-1, :-1], dim=-1)
+        value_col, col_index = torch.max(desc_mat[i, :-1, :-1], dim=-2)
+        acc_incorr = (
+            (value_row[num_corr : num_corr + num_incorr1] < 0.2).float().mean()
+            + (value_col[num_corr : num_corr + num_incorr2] < 0.2).float().mean()
+        ) / 2
 
         acc_row_mask = row_index[:num_corr] == torch.arange(num_corr).cuda()
         acc_col_mask = col_index[:num_corr] == torch.arange(num_corr).cuda()
         acc = (acc_col_mask & acc_row_mask).float().mean()
-     
-        total_loss_corr+=loss_corr
-        total_loss_incorr+=loss_incorr
+
+        total_loss_corr += loss_corr
+        total_loss_incorr += loss_incorr
         total_acc_corr += acc
-        total_acc_incorr+=acc_incorr
+        total_acc_incorr += acc_incorr
 
-    total_acc_corr/=batch_size
-    total_acc_incorr/=batch_size
-    total_loss_corr/=batch_size
-    total_loss_incorr/=batch_size
-    return total_loss_corr,total_loss_incorr,total_acc_corr,total_acc_incorr
+    total_acc_corr /= batch_size
+    total_acc_incorr /= batch_size
+    total_loss_corr /= batch_size
+    total_loss_incorr /= batch_size
+    return total_loss_corr, total_loss_incorr, total_acc_corr, total_acc_incorr
 
 
 class SGMLoss:
-    def __init__(self,config,model_config):
-        self.config=config
-        self.model_config=model_config
-
-    def run(self,data,result):
-        loss_corr,loss_incorr,acc_corr,acc_incorr=CorrLoss(result['p'],data['num_corr'],data['num_incorr1'],data['num_incorr2'])
-        loss_mid_corr_tower,loss_mid_incorr_tower,acc_mid_tower=[],[],[]
-        
-        #mid loss
-        for i in range(len(result['mid_p'])):
-            mid_p=result['mid_p'][i]
-            loss_mid_corr,loss_mid_incorr,mid_acc_corr,mid_acc_incorr=CorrLoss(mid_p,data['num_corr'],data['num_incorr1'],data['num_incorr2'])
-            loss_mid_corr_tower.append(loss_mid_corr),loss_mid_incorr_tower.append(loss_mid_incorr),acc_mid_tower.append(mid_acc_corr)
-        if len(result['mid_p']) != 0:
-            loss_mid_corr_tower,loss_mid_incorr_tower, acc_mid_tower = torch.stack(loss_mid_corr_tower), torch.stack(loss_mid_incorr_tower), torch.stack(acc_mid_tower)
+    def __init__(self, config, model_config):
+        self.config = config
+        self.model_config = model_config
+
+    def run(self, data, result):
+        loss_corr, loss_incorr, acc_corr, acc_incorr = CorrLoss(
+            result["p"], data["num_corr"], data["num_incorr1"], data["num_incorr2"]
+        )
+        loss_mid_corr_tower, loss_mid_incorr_tower, acc_mid_tower = [], [], []
+
+        # mid loss
+        for i in range(len(result["mid_p"])):
+            mid_p = result["mid_p"][i]
+            loss_mid_corr, loss_mid_incorr, mid_acc_corr, mid_acc_incorr = CorrLoss(
+                mid_p, data["num_corr"], data["num_incorr1"], data["num_incorr2"]
+            )
+            loss_mid_corr_tower.append(loss_mid_corr), loss_mid_incorr_tower.append(
+                loss_mid_incorr
+            ), acc_mid_tower.append(mid_acc_corr)
+        if len(result["mid_p"]) != 0:
+            loss_mid_corr_tower, loss_mid_incorr_tower, acc_mid_tower = (
+                torch.stack(loss_mid_corr_tower),
+                torch.stack(loss_mid_incorr_tower),
+                torch.stack(acc_mid_tower),
+            )
         else:
-            loss_mid_corr_tower,loss_mid_incorr_tower, acc_mid_tower= torch.zeros(1).cuda(), torch.zeros(1).cuda(),torch.zeros(1).cuda()
-  
-        #seed confidence loss
-        classif_loss_tower,classif_precision_tower,classif_recall_tower=[],[],[]
-        for layer in range(len(result['seed_conf'])):
-            confidence=result['seed_conf'][layer]
-            seed_index=result['seed_index'][(np.asarray(self.model_config.seedlayer)<=layer).nonzero()[0][-1]]
-            seed_x1,seed_x2=data['x1'].gather(dim=1, index=seed_index[:,:,0,None].expand(-1, -1,2)),\
-                            data['x2'].gather(dim=1, index=seed_index[:,:,1,None].expand(-1, -1,2))
-            classif_loss,classif_precision,classif_recall=CELoss(seed_x1,seed_x2,data['e_gt'],confidence,self.config.inlier_th)
-            classif_loss_tower.append(classif_loss), classif_precision_tower.append(classif_precision), classif_recall_tower.append(classif_recall)
-        classif_loss, classif_precision_tower, classif_recall_tower=torch.stack(classif_loss_tower).mean(),torch.stack(classif_precision_tower), \
-                                                                    torch.stack(classif_recall_tower)
-       
-            
-        classif_loss*=self.config.seed_loss_weight
-        loss_mid_corr_tower*=self.config.mid_loss_weight
-        loss_mid_incorr_tower*=self.config.mid_loss_weight
-        total_loss=loss_corr+loss_incorr+classif_loss+loss_mid_corr_tower.sum()+loss_mid_incorr_tower.sum()
-
-        return {'loss_corr':loss_corr,'loss_incorr':loss_incorr,'acc_corr':acc_corr,'acc_incorr':acc_incorr,'loss_seed_conf':classif_loss,
-                'pre_seed_conf':classif_precision_tower,'recall_seed_conf':classif_recall_tower,'loss_corr_mid':loss_mid_corr_tower,
-                'loss_incorr_mid':loss_mid_incorr_tower,'mid_acc_corr':acc_mid_tower,'total_loss':total_loss}
-        
+            loss_mid_corr_tower, loss_mid_incorr_tower, acc_mid_tower = (
+                torch.zeros(1).cuda(),
+                torch.zeros(1).cuda(),
+                torch.zeros(1).cuda(),
+            )
+
+        # seed confidence loss
+        classif_loss_tower, classif_precision_tower, classif_recall_tower = [], [], []
+        for layer in range(len(result["seed_conf"])):
+            confidence = result["seed_conf"][layer]
+            seed_index = result["seed_index"][
+                (np.asarray(self.model_config.seedlayer) <= layer).nonzero()[0][-1]
+            ]
+            seed_x1, seed_x2 = data["x1"].gather(
+                dim=1, index=seed_index[:, :, 0, None].expand(-1, -1, 2)
+            ), data["x2"].gather(
+                dim=1, index=seed_index[:, :, 1, None].expand(-1, -1, 2)
+            )
+            classif_loss, classif_precision, classif_recall = CELoss(
+                seed_x1, seed_x2, data["e_gt"], confidence, self.config.inlier_th
+            )
+            classif_loss_tower.append(classif_loss), classif_precision_tower.append(
+                classif_precision
+            ), classif_recall_tower.append(classif_recall)
+        classif_loss, classif_precision_tower, classif_recall_tower = (
+            torch.stack(classif_loss_tower).mean(),
+            torch.stack(classif_precision_tower),
+            torch.stack(classif_recall_tower),
+        )
+
+        classif_loss *= self.config.seed_loss_weight
+        loss_mid_corr_tower *= self.config.mid_loss_weight
+        loss_mid_incorr_tower *= self.config.mid_loss_weight
+        total_loss = (
+            loss_corr
+            + loss_incorr
+            + classif_loss
+            + loss_mid_corr_tower.sum()
+            + loss_mid_incorr_tower.sum()
+        )
+
+        return {
+            "loss_corr": loss_corr,
+            "loss_incorr": loss_incorr,
+            "acc_corr": acc_corr,
+            "acc_incorr": acc_incorr,
+            "loss_seed_conf": classif_loss,
+            "pre_seed_conf": classif_precision_tower,
+            "recall_seed_conf": classif_recall_tower,
+            "loss_corr_mid": loss_mid_corr_tower,
+            "loss_incorr_mid": loss_mid_incorr_tower,
+            "mid_acc_corr": acc_mid_tower,
+            "total_loss": total_loss,
+        }
+
+
 class SGLoss:
-    def __init__(self,config,model_config):
-        self.config=config
-        self.model_config=model_config
-        
-    def run(self,data,result):
-        loss_corr,loss_incorr,acc_corr,acc_incorr=CorrLoss(result['p'],data['num_corr'],data['num_incorr1'],data['num_incorr2'])
-        total_loss=loss_corr+loss_incorr
-        return {'loss_corr':loss_corr,'loss_incorr':loss_incorr,'acc_corr':acc_corr,'acc_incorr':acc_incorr,'total_loss':total_loss}
-     
\ No newline at end of file
+    def __init__(self, config, model_config):
+        self.config = config
+        self.model_config = model_config
+
+    def run(self, data, result):
+        loss_corr, loss_incorr, acc_corr, acc_incorr = CorrLoss(
+            result["p"], data["num_corr"], data["num_incorr1"], data["num_incorr2"]
+        )
+        total_loss = loss_corr + loss_incorr
+        return {
+            "loss_corr": loss_corr,
+            "loss_incorr": loss_incorr,
+            "acc_corr": acc_corr,
+            "acc_incorr": acc_incorr,
+            "total_loss": total_loss,
+        }
diff --git a/third_party/SGMNet/train/main.py b/third_party/SGMNet/train/main.py
index 9d4c8fff432a3b2d58c82b9e5f2897a4e702b2dd..00e1bf699a92057c445d4b5f83eb46794d6fb7f7 100644
--- a/third_party/SGMNet/train/main.py
+++ b/third_party/SGMNet/train/main.py
@@ -11,51 +11,72 @@ from train import train
 from config import get_config, print_usage
 
 
-def main(config,model_config):
+def main(config, model_config):
     """The main function."""
     # Initialize network
-    if config.model_name=='SGM':
+    if config.model_name == "SGM":
         model = SGM_Model(model_config)
-    elif config.model_name=='SG':
-        model= SG_Model(model_config)
+    elif config.model_name == "SG":
+        model = SG_Model(model_config)
     else:
         raise NotImplementedError
 
-    #initialize ddp
+    # initialize ddp
     torch.cuda.set_device(config.local_rank)
-    device = torch.device(f'cuda:{config.local_rank}')
+    device = torch.device(f"cuda:{config.local_rank}")
     model.to(device)
-    dist.init_process_group(backend='nccl',init_method='env://')
-    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.local_rank])
-    
-    if config.local_rank==0:
-        os.system('nvidia-smi')
-
-    #initialize dataset
-    train_dataset = Offline_Dataset(config,'train')
-    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,shuffle=True)
-    train_loader=torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size//torch.distributed.get_world_size(),
-            num_workers=8//dist.get_world_size(), pin_memory=False,sampler=train_sampler,collate_fn=train_dataset.collate_fn)
-    
-    valid_dataset = Offline_Dataset(config,'valid')
-    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset,shuffle=False)
-    valid_loader=torch.utils.data.DataLoader(valid_dataset, batch_size=config.train_batch_size,
-                num_workers=8//dist.get_world_size(), pin_memory=False,collate_fn=valid_dataset.collate_fn,sampler=valid_sampler)
-    
-    if config.local_rank==0:
-        print('start training .....')
-    train(model,train_loader, valid_loader, config,model_config)
+    dist.init_process_group(backend="nccl", init_method="env://")
+    model = torch.nn.parallel.DistributedDataParallel(
+        model, device_ids=[config.local_rank]
+    )
+
+    if config.local_rank == 0:
+        os.system("nvidia-smi")
+
+    # initialize dataset
+    train_dataset = Offline_Dataset(config, "train")
+    train_sampler = torch.utils.data.distributed.DistributedSampler(
+        train_dataset, shuffle=True
+    )
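+    # the global batch size and worker count are split evenly across DDP ranks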
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=config.train_batch_size // torch.distributed.get_world_size(),
+        num_workers=8 // dist.get_world_size(),
+        pin_memory=False,
+        sampler=train_sampler,
+        collate_fn=train_dataset.collate_fn,
+    )
+
+    valid_dataset = Offline_Dataset(config, "valid")
+    valid_sampler = torch.utils.data.distributed.DistributedSampler(
+        valid_dataset, shuffle=False
+    )
+    valid_loader = torch.utils.data.DataLoader(
+        valid_dataset,
+        batch_size=config.train_batch_size,
+        num_workers=8 // dist.get_world_size(),
+        pin_memory=False,
+        collate_fn=valid_dataset.collate_fn,
+        sampler=valid_sampler,
+    )
+
+    if config.local_rank == 0:
+        print("start training .....")
+    train(model, train_loader, valid_loader, config, model_config)
+
 
 if __name__ == "__main__":
     # ----------------------------------------
     # Parse configuration
     config, unparsed = get_config()
-    with open(config.config_path, 'r') as f:
+    with open(config.config_path, "r") as f:
         model_config = yaml.load(f)
-    model_config=namedtuple('model_config',model_config.keys())(*model_config.values())
+    model_config = namedtuple("model_config", model_config.keys())(
+        *model_config.values()
+    )
     # If we have unparsed arguments, print usage and exit
     if len(unparsed) > 0:
         print_usage()
         exit(1)
 
-    main(config,model_config)
+    main(config, model_config)
diff --git a/third_party/SGMNet/train/train.py b/third_party/SGMNet/train/train.py
index 31e848e1d2e5f028d4ff3abaf0cc446be7d89c65..b012b7bf231de77972f443ab6979038151d2cfce 100644
--- a/third_party/SGMNet/train/train.py
+++ b/third_party/SGMNet/train/train.py
@@ -5,156 +5,226 @@ import os
 from tensorboardX import SummaryWriter
 import numpy as np
 import cv2
-from loss import SGMLoss,SGLoss
-from valid import valid,dump_train_vis
+from loss import SGMLoss, SGLoss
+from valid import valid, dump_train_vis
 
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.insert(0, ROOT_DIR)
 
 
 from utils import train_utils
 
-def train_step(optimizer, model, match_loss, data,step,pre_avg_loss):
-    data['step']=step
-    result=model(data,test_mode=False)
-    loss_res=match_loss.run(data,result)
-    
+
+def train_step(optimizer, model, match_loss, data, step, pre_avg_loss):
+    data["step"] = step
+    result = model(data, test_mode=False)
+    loss_res = match_loss.run(data, result)
+
     optimizer.zero_grad()
-    loss_res['total_loss'].backward()
-    #apply reduce on all record tensor
+    loss_res["total_loss"].backward()
+    # apply reduce on all record tensor
     for key in loss_res.keys():
-        loss_res[key]=train_utils.reduce_tensor(loss_res[key],'mean')
-  
-    if loss_res['total_loss']<7*pre_avg_loss or step<200 or pre_avg_loss==0:
+        loss_res[key] = train_utils.reduce_tensor(loss_res[key], "mean")
+
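+    # skip the parameter update when the loss jumps above 7x the running average (the first 200 steps always update)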
+    if loss_res["total_loss"] < 7 * pre_avg_loss or step < 200 or pre_avg_loss == 0:
         optimizer.step()
-        unusual_loss=False
+        unusual_loss = False
     else:
         optimizer.zero_grad()
-        unusual_loss=True
-    return loss_res,unusual_loss
+        unusual_loss = True
+    return loss_res, unusual_loss
 
 
-def train(model, train_loader, valid_loader, config,model_config):
+def train(model, train_loader, valid_loader, config, model_config):
     model.train()
     optimizer = optim.Adam(model.parameters(), lr=config.train_lr)
-    
-    if config.model_name=='SGM':
-        match_loss = SGMLoss(config,model_config) 
-    elif config.model_name=='SG':
-        match_loss= SGLoss(config,model_config)
+
+    if config.model_name == "SGM":
+        match_loss = SGMLoss(config, model_config)
+    elif config.model_name == "SG":
+        match_loss = SGLoss(config, model_config)
     else:
         raise NotImplementedError
-    
-    checkpoint_path = os.path.join(config.log_base, 'checkpoint.pth')
+
+    checkpoint_path = os.path.join(config.log_base, "checkpoint.pth")
     config.resume = os.path.isfile(checkpoint_path)
     if config.resume:
-        if config.local_rank==0:
-            print('==> Resuming from checkpoint..')
-        checkpoint = torch.load(checkpoint_path,map_location='cuda:{}'.format(config.local_rank))
-        model.load_state_dict(checkpoint['state_dict'])
-        best_acc = checkpoint['best_acc']
-        start_step = checkpoint['step']
-        optimizer.load_state_dict(checkpoint['optimizer'])
+        if config.local_rank == 0:
+            print("==> Resuming from checkpoint..")
+        checkpoint = torch.load(
+            checkpoint_path, map_location="cuda:{}".format(config.local_rank)
+        )
+        model.load_state_dict(checkpoint["state_dict"])
+        best_acc = checkpoint["best_acc"]
+        start_step = checkpoint["step"]
+        optimizer.load_state_dict(checkpoint["optimizer"])
     else:
         best_acc = -1
         start_step = 0
     train_loader_iter = iter(train_loader)
-    
-    if config.local_rank==0:
-        writer=SummaryWriter(os.path.join(config.log_base,'log_file'))
-
-    train_loader.sampler.set_epoch(start_step*config.train_batch_size//len(train_loader.dataset))
-    pre_avg_loss=0
-    
-    progress_bar=trange(start_step, config.train_iter,ncols=config.tqdm_width) if config.local_rank==0 else range(start_step, config.train_iter)
+
+    if config.local_rank == 0:
+        writer = SummaryWriter(os.path.join(config.log_base, "log_file"))
+
+    train_loader.sampler.set_epoch(
+        start_step * config.train_batch_size // len(train_loader.dataset)
+    )
+    pre_avg_loss = 0
+
+    progress_bar = (
+        trange(start_step, config.train_iter, ncols=config.tqdm_width)
+        if config.local_rank == 0
+        else range(start_step, config.train_iter)
+    )
     for step in progress_bar:
         try:
             train_data = next(train_loader_iter)
         except StopIteration:
-            if config.local_rank==0:
-                print('epoch: ',step*config.train_batch_size//len(train_loader.dataset))
-            train_loader.sampler.set_epoch(step*config.train_batch_size//len(train_loader.dataset))
+            if config.local_rank == 0:
+                print(
+                    "epoch: ",
+                    step * config.train_batch_size // len(train_loader.dataset),
+                )
+            train_loader.sampler.set_epoch(
+                step * config.train_batch_size // len(train_loader.dataset)
+            )
             train_loader_iter = iter(train_loader)
             train_data = next(train_loader_iter)
-    
+
         train_data = train_utils.tocuda(train_data)
-        lr=min(config.train_lr*config.decay_rate**(step-config.decay_iter),config.train_lr)
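+        # exponential LR decay that only takes effect once step exceeds decay_iter (assuming decay_rate < 1), capped at the base LR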
+        lr = min(
+            config.train_lr * config.decay_rate ** (step - config.decay_iter),
+            config.train_lr,
+        )
         for param_group in optimizer.param_groups:
-            param_group['lr'] = lr
+            param_group["lr"] = lr
 
         # run training
-        loss_res,unusual_loss = train_step(optimizer, model, match_loss, train_data,step-start_step,pre_avg_loss)
-        if (step-start_step)<=200:
-            pre_avg_loss=loss_res['total_loss'].data
-        if (step-start_step)>200 and not unusual_loss:
-            pre_avg_loss=pre_avg_loss.data*0.9+loss_res['total_loss'].data*0.1
-        if unusual_loss and config.local_rank==0:
-            print('unusual loss! pre_avg_loss: ',pre_avg_loss,'cur_loss: ',loss_res['total_loss'].data)
-        #log
-        if config.local_rank==0 and step%config.log_intv==0 and not unusual_loss:
-            writer.add_scalar('TotalLoss',loss_res['total_loss'],step)
-            writer.add_scalar('CorrLoss',loss_res['loss_corr'],step)
-            writer.add_scalar('InCorrLoss', loss_res['loss_incorr'], step)
-            writer.add_scalar('dustbin', model.module.dustbin, step)
-
-            if config.model_name=='SGM':
-                writer.add_scalar('SeedConfLoss', loss_res['loss_seed_conf'], step)
-                writer.add_scalar('MidCorrLoss', loss_res['loss_corr_mid'].sum(), step)
-                writer.add_scalar('MidInCorrLoss', loss_res['loss_incorr_mid'].sum(), step)
-            
+        loss_res, unusual_loss = train_step(
+            optimizer, model, match_loss, train_data, step - start_step, pre_avg_loss
+        )
+        if (step - start_step) <= 200:
+            pre_avg_loss = loss_res["total_loss"].data
+        if (step - start_step) > 200 and not unusual_loss:
+            pre_avg_loss = pre_avg_loss.data * 0.9 + loss_res["total_loss"].data * 0.1
+        if unusual_loss and config.local_rank == 0:
+            print(
+                "unusual loss! pre_avg_loss: ",
+                pre_avg_loss,
+                "cur_loss: ",
+                loss_res["total_loss"].data,
+            )
+        # log
+        if config.local_rank == 0 and step % config.log_intv == 0 and not unusual_loss:
+            writer.add_scalar("TotalLoss", loss_res["total_loss"], step)
+            writer.add_scalar("CorrLoss", loss_res["loss_corr"], step)
+            writer.add_scalar("InCorrLoss", loss_res["loss_incorr"], step)
+            writer.add_scalar("dustbin", model.module.dustbin, step)
+
+            if config.model_name == "SGM":
+                writer.add_scalar("SeedConfLoss", loss_res["loss_seed_conf"], step)
+                writer.add_scalar("MidCorrLoss", loss_res["loss_corr_mid"].sum(), step)
+                writer.add_scalar(
+                    "MidInCorrLoss", loss_res["loss_incorr_mid"].sum(), step
+                )
 
         # valid ans save
         b_save = ((step + 1) % config.save_intv) == 0
         b_validate = ((step + 1) % config.val_intv) == 0
         if b_validate:
-            total_loss,acc_corr,acc_incorr,seed_precision_tower,seed_recall_tower,acc_mid=valid(valid_loader, model, match_loss, config,model_config)
-            if config.local_rank==0:
-                writer.add_scalar('ValidAcc', acc_corr, step)
-                writer.add_scalar('ValidLoss', total_loss, step)
-                
-                if config.model_name=='SGM':
+            (
+                total_loss,
+                acc_corr,
+                acc_incorr,
+                seed_precision_tower,
+                seed_recall_tower,
+                acc_mid,
+            ) = valid(valid_loader, model, match_loss, config, model_config)
+            if config.local_rank == 0:
+                writer.add_scalar("ValidAcc", acc_corr, step)
+                writer.add_scalar("ValidLoss", total_loss, step)
+
+                if config.model_name == "SGM":
                     for i in range(len(seed_recall_tower)):
-                        writer.add_scalar('seed_conf_pre_%d'%i,seed_precision_tower[i],step)
-                        writer.add_scalar('seed_conf_recall_%d' % i, seed_precision_tower[i], step)
+                        writer.add_scalar(
+                            "seed_conf_pre_%d" % i, seed_precision_tower[i], step
+                        )
+                        writer.add_scalar(
+                            "seed_conf_recall_%d" % i, seed_recall_tower[i], step
+                        )
                     for i in range(len(acc_mid)):
-                        writer.add_scalar('acc_mid%d'%i,acc_mid[i],step)
-                    print('acc_corr: ',acc_corr.data,'acc_incorr: ',acc_incorr.data,'seed_conf_pre: ',seed_precision_tower.mean().data,
-                     'seed_conf_recall: ',seed_recall_tower.mean().data,'acc_mid: ',acc_mid.mean().data)
+                        writer.add_scalar("acc_mid%d" % i, acc_mid[i], step)
+                    print(
+                        "acc_corr: ",
+                        acc_corr.data,
+                        "acc_incorr: ",
+                        acc_incorr.data,
+                        "seed_conf_pre: ",
+                        seed_precision_tower.mean().data,
+                        "seed_conf_recall: ",
+                        seed_recall_tower.mean().data,
+                        "acc_mid: ",
+                        acc_mid.mean().data,
+                    )
                 else:
-                     print('acc_corr: ',acc_corr.data,'acc_incorr: ',acc_incorr.data)
-                
-                #saving best
+                    print("acc_corr: ", acc_corr.data, "acc_incorr: ", acc_incorr.data)
+
+                # saving best
                 if acc_corr > best_acc:
                     print("Saving best model with va_res = {}".format(acc_corr))
                     best_acc = acc_corr
-                    save_dict={'step': step + 1,
-                    'state_dict': model.state_dict(),
-                    'best_acc': best_acc,
-                    'optimizer' : optimizer.state_dict()}
+                    save_dict = {
+                        "step": step + 1,
+                        "state_dict": model.state_dict(),
+                        "best_acc": best_acc,
+                        "optimizer": optimizer.state_dict(),
+                    }
                     save_dict.update(save_dict)
-                    torch.save(save_dict, os.path.join(config.log_base, 'model_best.pth'))
+                    torch.save(
+                        save_dict, os.path.join(config.log_base, "model_best.pth")
+                    )
 
         if b_save:
-            if config.local_rank==0:
-                save_dict={'step': step + 1,
-                'state_dict': model.state_dict(),
-                'best_acc': best_acc,
-                'optimizer' : optimizer.state_dict()}
+            if config.local_rank == 0:
+                save_dict = {
+                    "step": step + 1,
+                    "state_dict": model.state_dict(),
+                    "best_acc": best_acc,
+                    "optimizer": optimizer.state_dict(),
+                }
                 torch.save(save_dict, checkpoint_path)
-            
-            #draw match results
+
+            # draw match results
             model.eval()
             with torch.no_grad():
-                if config.local_rank==0:
-                    if not os.path.exists(os.path.join(config.train_vis_folder,'train_vis')):
-                        os.mkdir(os.path.join(config.train_vis_folder,'train_vis'))
-                    if not os.path.exists(os.path.join(config.train_vis_folder,'train_vis',config.log_base)):
-                        os.mkdir(os.path.join(config.train_vis_folder,'train_vis',config.log_base))
-                    os.mkdir(os.path.join(config.train_vis_folder,'train_vis',config.log_base,str(step)))
-                res=model(train_data)
-                dump_train_vis(res,train_data,step,config)
+                if config.local_rank == 0:
+                    if not os.path.exists(
+                        os.path.join(config.train_vis_folder, "train_vis")
+                    ):
+                        os.mkdir(os.path.join(config.train_vis_folder, "train_vis"))
+                    if not os.path.exists(
+                        os.path.join(
+                            config.train_vis_folder, "train_vis", config.log_base
+                        )
+                    ):
+                        os.mkdir(
+                            os.path.join(
+                                config.train_vis_folder, "train_vis", config.log_base
+                            )
+                        )
+                    os.mkdir(
+                        os.path.join(
+                            config.train_vis_folder,
+                            "train_vis",
+                            config.log_base,
+                            str(step),
+                        )
+                    )
+                res = model(train_data)
+                dump_train_vis(res, train_data, step, config)
             model.train()
-    
-    if config.local_rank==0:
+
+    if config.local_rank == 0:
         writer.close()
diff --git a/third_party/SGMNet/train/valid.py b/third_party/SGMNet/train/valid.py
index 443694d85104730cd50aeb342326ce593dc5684d..b9873f9b34ff77462d87aaad8c128e3b497fa39a 100644
--- a/third_party/SGMNet/train/valid.py
+++ b/third_party/SGMNet/train/valid.py
@@ -6,72 +6,119 @@ from loss import batch_episym
 from tqdm import tqdm
 
 import sys
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.insert(0, ROOT_DIR)
 
-from utils import evaluation_utils,train_utils
+from utils import evaluation_utils, train_utils
 
 
-def valid(valid_loader, model,match_loss, config,model_config):
+def valid(valid_loader, model, match_loss, config, model_config):
     model.eval()
     loader_iter = iter(valid_loader)
     num_pair = 0
-    total_loss,total_acc_corr,total_acc_incorr=0,0,0
-    total_precision,total_recall=torch.zeros(model_config.layer_num ,device='cuda'),\
-                                 torch.zeros(model_config.layer_num ,device='cuda')
-    total_acc_mid=torch.zeros(len(model_config.seedlayer)-1,device='cuda')
+    total_loss, total_acc_corr, total_acc_incorr = 0, 0, 0
+    total_precision, total_recall = torch.zeros(
+        model_config.layer_num, device="cuda"
+    ), torch.zeros(model_config.layer_num, device="cuda")
+    total_acc_mid = torch.zeros(len(model_config.seedlayer) - 1, device="cuda")
 
     with torch.no_grad():
-        if config.local_rank==0:
-            loader_iter=tqdm(loader_iter)
-            print('validating...')
+        if config.local_rank == 0:
+            loader_iter = tqdm(loader_iter)
+            print("validating...")
         for test_data in loader_iter:
-            num_pair+= 1
+            num_pair += 1
             test_data = train_utils.tocuda(test_data)
-            res= model(test_data)
-            loss_res=match_loss.run(test_data,res)
-           
-            total_acc_corr+=loss_res['acc_corr']
-            total_acc_incorr+=loss_res['acc_incorr']
-            total_loss+=loss_res['total_loss']
+            res = model(test_data)
+            loss_res = match_loss.run(test_data, res)
+
+            total_acc_corr += loss_res["acc_corr"]
+            total_acc_incorr += loss_res["acc_incorr"]
+            total_loss += loss_res["total_loss"]
 
-            if config.model_name=='SGM':
-                total_acc_mid+=loss_res['mid_acc_corr']
-                total_precision,total_recall=total_precision+loss_res['pre_seed_conf'],total_recall+loss_res['recall_seed_conf']
-                
-        total_acc_corr/=num_pair
+            if config.model_name == "SGM":
+                total_acc_mid += loss_res["mid_acc_corr"]
+                total_precision, total_recall = (
+                    total_precision + loss_res["pre_seed_conf"],
+                    total_recall + loss_res["recall_seed_conf"],
+                )
+
+        total_acc_corr /= num_pair
         total_acc_incorr /= num_pair
-        total_precision/=num_pair
-        total_recall/=num_pair
-        total_acc_mid/=num_pair
+        total_precision /= num_pair
+        total_recall /= num_pair
+        total_acc_mid /= num_pair
 
-        #apply tensor reduction
-        total_loss,total_acc_corr,total_acc_incorr,total_precision,total_recall,total_acc_mid=train_utils.reduce_tensor(total_loss,'sum'),\
-                        train_utils.reduce_tensor(total_acc_corr,'mean'),train_utils.reduce_tensor(total_acc_incorr,'mean'),\
-                        train_utils.reduce_tensor(total_precision,'mean'),train_utils.reduce_tensor(total_recall,'mean'),train_utils.reduce_tensor(total_acc_mid,'mean')
+        # apply tensor reduction
+        (
+            total_loss,
+            total_acc_corr,
+            total_acc_incorr,
+            total_precision,
+            total_recall,
+            total_acc_mid,
+        ) = (
+            train_utils.reduce_tensor(total_loss, "sum"),
+            train_utils.reduce_tensor(total_acc_corr, "mean"),
+            train_utils.reduce_tensor(total_acc_incorr, "mean"),
+            train_utils.reduce_tensor(total_precision, "mean"),
+            train_utils.reduce_tensor(total_recall, "mean"),
+            train_utils.reduce_tensor(total_acc_mid, "mean"),
+        )
     model.train()
-    return total_loss,total_acc_corr,total_acc_incorr,total_precision,total_recall,total_acc_mid
-
+    return (
+        total_loss,
+        total_acc_corr,
+        total_acc_incorr,
+        total_precision,
+        total_recall,
+        total_acc_mid,
+    )
 
 
-def dump_train_vis(res,data,step,config):
-    #batch matching
-    p=res['p'][:,:-1,:-1]
-    score,index1=torch.max(p,dim=-1)
-    _,index2=torch.max(p,dim=-2)
-    mask_th=score>0.2
-    mask_mc=index2.gather(index=index1,dim=1) == torch.arange(len(p[0])).cuda()[None]
-    mask_p=mask_th&mask_mc#B*N
+def dump_train_vis(res, data, step, config):
+    # batch matching
+    p = res["p"][:, :-1, :-1]
+    score, index1 = torch.max(p, dim=-1)
+    _, index2 = torch.max(p, dim=-2)
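+    # mutual nearest-neighbour check combined with a 0.2 score threshold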
+    mask_th = score > 0.2
+    mask_mc = index2.gather(index=index1, dim=1) == torch.arange(len(p[0])).cuda()[None]
+    mask_p = mask_th & mask_mc  # B*N
 
-    corr1,corr2=data['x1'],data['x2'].gather(index=index1[:,:,None].expand(-1,-1,2),dim=1)
-    corr1_kpt,corr2_kpt=data['kpt1'],data['kpt2'].gather(index=index1[:,:,None].expand(-1,-1,2),dim=1)
-    epi_dis=batch_episym(corr1,corr2,data['e_gt'])
-    mask_inlier=epi_dis<config.inlier_th#B*N
+    corr1, corr2 = data["x1"], data["x2"].gather(
+        index=index1[:, :, None].expand(-1, -1, 2), dim=1
+    )
+    corr1_kpt, corr2_kpt = data["kpt1"], data["kpt2"].gather(
+        index=index1[:, :, None].expand(-1, -1, 2), dim=1
+    )
+    epi_dis = batch_episym(corr1, corr2, data["e_gt"])
+    mask_inlier = epi_dis < config.inlier_th  # B*N
 
-    #dump vis
-    for cur_mask_p,cur_mask_inlier,cur_corr1,cur_corr2,img_path1,img_path2 in zip(mask_p,mask_inlier,corr1_kpt,corr2_kpt,data['img_path1'],data['img_path2']):
-        img1,img2=cv2.imread(img_path1),cv2.imread(img_path2)
-        dis_play=evaluation_utils.draw_match(img1,img2,cur_corr1[cur_mask_p].cpu().numpy(),cur_corr2[cur_mask_p].cpu().numpy(),inlier=cur_mask_inlier)
-        base_name_seq=os.path.join(img_path1.split('/')[-1]+'_'+img_path2.split('/')[-1]+'_'+img_path1.split('/')[-2])
-        save_path=os.path.join(config.train_vis_folder,'train_vis',config.log_base,str(step),base_name_seq+'.png')
-        cv2.imwrite(save_path,dis_play)
\ No newline at end of file
+    # dump vis
+    for cur_mask_p, cur_mask_inlier, cur_corr1, cur_corr2, img_path1, img_path2 in zip(
+        mask_p, mask_inlier, corr1_kpt, corr2_kpt, data["img_path1"], data["img_path2"]
+    ):
+        img1, img2 = cv2.imread(img_path1), cv2.imread(img_path2)
+        dis_play = evaluation_utils.draw_match(
+            img1,
+            img2,
+            cur_corr1[cur_mask_p].cpu().numpy(),
+            cur_corr2[cur_mask_p].cpu().numpy(),
+            inlier=cur_mask_inlier,
+        )
+        base_name_seq = os.path.join(
+            img_path1.split("/")[-1]
+            + "_"
+            + img_path2.split("/")[-1]
+            + "_"
+            + img_path1.split("/")[-2]
+        )
+        save_path = os.path.join(
+            config.train_vis_folder,
+            "train_vis",
+            config.log_base,
+            str(step),
+            base_name_seq + ".png",
+        )
+        cv2.imwrite(save_path, dis_play)
diff --git a/third_party/SGMNet/utils/__init__.py b/third_party/SGMNet/utils/__init__.py
index 2e456fd7c48ed8d25157a9344e300d412ea47c1c..354f9ed78c66b2df30dd8203ac7a2be95741f7af 100644
--- a/third_party/SGMNet/utils/__init__.py
+++ b/third_party/SGMNet/utils/__init__.py
@@ -2,4 +2,4 @@ from . import fm_utils
 from . import evaluation_utils
 from . import metrics
 from . import transformations
-from . import data_utils
\ No newline at end of file
+from . import data_utils
diff --git a/third_party/SGMNet/utils/data_utils.py b/third_party/SGMNet/utils/data_utils.py
index 0dc51419e667e7c1f13c9ac1f0f37ab9b64325ee..7a6075b2802b1c69a7476364a973cdb5b54af616 100644
--- a/third_party/SGMNet/utils/data_utils.py
+++ b/third_party/SGMNet/utils/data_utils.py
@@ -1,151 +1,233 @@
-import  numpy as np
+import numpy as np
 
 
 def norm_kpt(K, kp):
     kp = np.concatenate([kp, np.ones([kp.shape[0], 1])], axis=1)
     kp = np.matmul(kp, np.linalg.inv(K).T)[:, :2]
     return kp
-    
-def unnorm_kp(K,kp):
+
+
+def unnorm_kp(K, kp):
     kp = np.concatenate([kp, np.ones([kp.shape[0], 1])], axis=1)
-    kp = np.matmul(kp,K.T)[:, :2]
+    kp = np.matmul(kp, K.T)[:, :2]
     return kp
 
-def interpolate_depth(pos, depth):
-        # pos:[y,x]
-        ids = np.array(range(0, pos.shape[0]))
-    
-        h, w = depth.shape
-
-        i = pos[:, 0]
-        j = pos[:, 1]
-        valid_corner=np.logical_and(np.logical_and(i>0,i<h-1),np.logical_and(j>0,j<w-1))
-        i,j=i[valid_corner],j[valid_corner]
-        ids = ids[valid_corner]
-
-        i_top_left = np.floor(i).astype(np.int32)
-        j_top_left = np.floor(j).astype(np.int32)
-
-        i_top_right = np.floor(i).astype(np.int32)
-        j_top_right = np.ceil(j).astype(np.int32)
-
-        i_bottom_left = np.ceil(i).astype(np.int32)
-        j_bottom_left = np.floor(j).astype(np.int32)
-
-        i_bottom_right = np.ceil(i).astype(np.int32)
-        j_bottom_right = np.ceil(j).astype(np.int32)
-        
-        # Valid depth
-        depth_top_left,depth_top_right,depth_down_left,depth_down_right=depth[i_top_left, j_top_left],depth[i_top_right, j_top_right],\
-                                                             depth[i_bottom_left, j_bottom_left],depth[i_bottom_right, j_bottom_right]
-        
-        valid_depth = np.logical_and(
-            np.logical_and(
-                depth_top_left > 0,
-                depth_top_right > 0
-            ),
-            np.logical_and(
-                depth_down_left > 0,
-                depth_down_left > 0
-            )
-        )
-        ids=ids[valid_depth]
-        depth_top_left,depth_top_right,depth_down_left,depth_down_right=depth_top_left[valid_depth],depth_top_right[valid_depth],\
-                                                                        depth_down_left[valid_depth],depth_down_right[valid_depth]
-
-        i,j,i_top_left,j_top_left=i[valid_depth],j[valid_depth],i_top_left[valid_depth],j_top_left[valid_depth]
-        
-        # Interpolation
-        dist_i_top_left = i - i_top_left.astype(np.float32)
-        dist_j_top_left = j - j_top_left.astype(np.float32)
-        w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
-        w_top_right = (1 - dist_i_top_left) * dist_j_top_left
-        w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
-        w_bottom_right = dist_i_top_left * dist_j_top_left
-
-        interpolated_depth = (
-            w_top_left * depth_top_left +
-            w_top_right * depth_top_right+
-            w_bottom_left * depth_down_left +
-            w_bottom_right * depth_down_right
-        )
-        return [interpolated_depth, ids]
 
-
-def reprojection(depth_map,kpt,dR,dt,K1_img2depth,K1,K2):
-    #warp kpt from img1 to img2
+def interpolate_depth(pos, depth):
+    # pos:[y,x]
+    ids = np.array(range(0, pos.shape[0]))
+
+    h, w = depth.shape
+
+    i = pos[:, 0]
+    j = pos[:, 1]
+    valid_corner = np.logical_and(
+        np.logical_and(i > 0, i < h - 1), np.logical_and(j > 0, j < w - 1)
+    )
+    i, j = i[valid_corner], j[valid_corner]
+    ids = ids[valid_corner]
+
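+    # the four integer-grid neighbours around each sub-pixel position, used for bilinear depth interpolation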
+    i_top_left = np.floor(i).astype(np.int32)
+    j_top_left = np.floor(j).astype(np.int32)
+
+    i_top_right = np.floor(i).astype(np.int32)
+    j_top_right = np.ceil(j).astype(np.int32)
+
+    i_bottom_left = np.ceil(i).astype(np.int32)
+    j_bottom_left = np.floor(j).astype(np.int32)
+
+    i_bottom_right = np.ceil(i).astype(np.int32)
+    j_bottom_right = np.ceil(j).astype(np.int32)
+
+    # Valid depth
+    depth_top_left, depth_top_right, depth_down_left, depth_down_right = (
+        depth[i_top_left, j_top_left],
+        depth[i_top_right, j_top_right],
+        depth[i_bottom_left, j_bottom_left],
+        depth[i_bottom_right, j_bottom_right],
+    )
+
+    valid_depth = np.logical_and(
+        np.logical_and(depth_top_left > 0, depth_top_right > 0),
+        np.logical_and(depth_down_left > 0, depth_down_right > 0),
+    )
+    ids = ids[valid_depth]
+    depth_top_left, depth_top_right, depth_down_left, depth_down_right = (
+        depth_top_left[valid_depth],
+        depth_top_right[valid_depth],
+        depth_down_left[valid_depth],
+        depth_down_right[valid_depth],
+    )
+
+    i, j, i_top_left, j_top_left = (
+        i[valid_depth],
+        j[valid_depth],
+        i_top_left[valid_depth],
+        j_top_left[valid_depth],
+    )
+
+    # Interpolation
+    dist_i_top_left = i - i_top_left.astype(np.float32)
+    dist_j_top_left = j - j_top_left.astype(np.float32)
+    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
+    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
+    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
+    w_bottom_right = dist_i_top_left * dist_j_top_left
+
+    interpolated_depth = (
+        w_top_left * depth_top_left
+        + w_top_right * depth_top_right
+        + w_bottom_left * depth_down_left
+        + w_bottom_right * depth_down_right
+    )
+    return [interpolated_depth, ids]
+
+
+def reprojection(depth_map, kpt, dR, dt, K1_img2depth, K1, K2):
+    # warp kpt from img1 to img2
     def swap_axis(data):
         return np.stack([data[:, 1], data[:, 0]], axis=-1)
 
-    kp_depth = unnorm_kp(K1_img2depth,kpt)
+    kp_depth = unnorm_kp(K1_img2depth, kpt)
     uv_depth = swap_axis(kp_depth)
-    z,valid_idx = interpolate_depth(uv_depth, depth_map)
+    z, valid_idx = interpolate_depth(uv_depth, depth_map)
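+    # lift the valid keypoints to 3D with the interpolated depth, apply the relative pose (dR, dt), and project into image 2 with K2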
 
-    norm_kp=norm_kpt(K1,kpt)
-    norm_kp_valid = np.concatenate([norm_kp[valid_idx, :], np.ones((len(valid_idx), 1))], axis=-1)
+    norm_kp = norm_kpt(K1, kpt)
+    norm_kp_valid = np.concatenate(
+        [norm_kp[valid_idx, :], np.ones((len(valid_idx), 1))], axis=-1
+    )
     xyz_valid = norm_kp_valid * z.reshape(-1, 1)
     xyz2 = np.matmul(xyz_valid, dR.T) + dt.reshape(1, 3)
     xy2 = xyz2[:, :2] / xyz2[:, 2:]
     kp2, valid = np.ones(kpt.shape) * 1e5, np.zeros(kpt.shape[0])
-    kp2[valid_idx] = unnorm_kp(K2,xy2)
+    kp2[valid_idx] = unnorm_kp(K2, xy2)
     valid[valid_idx] = 1
     return kp2, valid.astype(bool)
 
-def reprojection_2s(kp1, kp2,depth1, depth2, K1, K2, dR, dt, size1,size2):
-    #size:H*W
-    depth_size1,depth_size2 = [depth1.shape[0], depth1.shape[1]], [depth2.shape[0], depth2.shape[1]]
-    scale_1= [float(depth_size1[0]) / size1[0], float(depth_size1[1]) / size1[1], 1]
-    scale_2= [float(depth_size2[0]) / size2[0], float(depth_size2[1]) / size2[1], 1]
-    K1_img2depth, K2_img2depth = np.diag(np.asarray(scale_1)), np.diag(np.asarray(scale_2))
-    kp1_2_proj, valid1_2 = reprojection(depth1, kp1, dR, dt, K1_img2depth,K1,K2)
-    kp2_1_proj, valid2_1 = reprojection(depth2, kp2, dR.T, -np.matmul(dR.T, dt), K2_img2depth,K2,K1)
-    return [kp1_2_proj,kp2_1_proj],[valid1_2,valid2_1]
-
-def make_corr(kp1,kp2,desc1,desc2,depth1,depth2,K1,K2,dR,dt,size1,size2,corr_th,incorr_th,check_desc=False):
-    #make reprojection
-    [kp1_2,kp2_1],[valid1_2,valid2_1]=reprojection_2s(kp1,kp2,depth1,depth2,K1,K2,dR,dt,size1,size2)
+
+def reprojection_2s(kp1, kp2, depth1, depth2, K1, K2, dR, dt, size1, size2):
+    # size:H*W
+    depth_size1, depth_size2 = [depth1.shape[0], depth1.shape[1]], [
+        depth2.shape[0],
+        depth2.shape[1],
+    ]
+    scale_1 = [float(depth_size1[0]) / size1[0], float(depth_size1[1]) / size1[1], 1]
+    scale_2 = [float(depth_size2[0]) / size2[0], float(depth_size2[1]) / size2[1], 1]
+    K1_img2depth, K2_img2depth = np.diag(np.asarray(scale_1)), np.diag(
+        np.asarray(scale_2)
+    )
+    kp1_2_proj, valid1_2 = reprojection(depth1, kp1, dR, dt, K1_img2depth, K1, K2)
+    kp2_1_proj, valid2_1 = reprojection(
+        depth2, kp2, dR.T, -np.matmul(dR.T, dt), K2_img2depth, K2, K1
+    )
+    return [kp1_2_proj, kp2_1_proj], [valid1_2, valid2_1]
+
+
+def make_corr(
+    kp1,
+    kp2,
+    desc1,
+    desc2,
+    depth1,
+    depth2,
+    K1,
+    K2,
+    dR,
+    dt,
+    size1,
+    size2,
+    corr_th,
+    incorr_th,
+    check_desc=False,
+):
+    # make reprojection
+    [kp1_2, kp2_1], [valid1_2, valid2_1] = reprojection_2s(
+        kp1, kp2, depth1, depth2, K1, K2, dR, dt, size1, size2
+    )
     num_pts1, num_pts2 = kp1.shape[0], kp2.shape[0]
-    #reprojection error
-    dis_mat1=np.sqrt(abs((kp1 ** 2).sum(1,keepdims=True) + (kp2_1 ** 2).sum(1,keepdims=False)[np.newaxis] - 2 * np.matmul(kp1, kp2_1.T)))
-    dis_mat2 =np.sqrt(abs((kp2 ** 2).sum(1,keepdims=True) + (kp1_2 ** 2).sum(1,keepdims=False)[np.newaxis] - 2 * np.matmul(kp2,kp1_2.T)))
-    repro_error = np.maximum(dis_mat1,dis_mat2.T) #n1*n2
-    
+    # reprojection error
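+    # pairwise distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; abs() guards tiny negative values from rounding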
+    dis_mat1 = np.sqrt(
+        abs(
+            (kp1**2).sum(1, keepdims=True)
+            + (kp2_1**2).sum(1, keepdims=False)[np.newaxis]
+            - 2 * np.matmul(kp1, kp2_1.T)
+        )
+    )
+    dis_mat2 = np.sqrt(
+        abs(
+            (kp2**2).sum(1, keepdims=True)
+            + (kp1_2**2).sum(1, keepdims=False)[np.newaxis]
+            - 2 * np.matmul(kp2, kp1_2.T)
+        )
+    )
+    repro_error = np.maximum(dis_mat1, dis_mat2.T)  # shape: (n1, n2)
+
     # find corr index
     nn_sort1 = np.argmin(repro_error, axis=1)
     nn_sort2 = np.argmin(repro_error, axis=0)
     mask_mutual = nn_sort2[nn_sort1] == np.arange(kp1.shape[0])
-    mask_inlier=np.take_along_axis(repro_error,indices=nn_sort1[:,np.newaxis],axis=-1).squeeze(1)<corr_th
-    mask = mask_mutual&mask_inlier
-    corr_index=np.stack([np.arange(num_pts1)[mask], np.arange(num_pts2)[nn_sort1[mask]]], axis=-1)
-    
+    mask_inlier = (
+        np.take_along_axis(
+            repro_error, indices=nn_sort1[:, np.newaxis], axis=-1
+        ).squeeze(1)
+        < corr_th
+    )
+    mask = mask_mutual & mask_inlier
+    corr_index = np.stack(
+        [np.arange(num_pts1)[mask], np.arange(num_pts2)[nn_sort1[mask]]], axis=-1
+    )
+
     if check_desc:
-        #filter kpt in same pos using desc distance(e.g. DoG kpt)
+        # filter keypoints at the same position using descriptor distance (e.g. DoG keypoints)
         x1_valid, x2_valid = kp1[corr_index[:, 0]], kp2[corr_index[:, 1]]
-        mask_samepos1=np.logical_and(x1_valid[:, 0,np.newaxis] == kp1[np.newaxis,:, 0],x1_valid[:, 1,np.newaxis] == kp1[np.newaxis,:, 1])
-        mask_samepos2=np.logical_and(x2_valid[:, 0,np.newaxis]== kp2[np.newaxis,:, 0],x2_valid[:, 1,np.newaxis] == kp2[np.newaxis,:, 1])
-        duplicated_mask=np.logical_or(mask_samepos1.sum(-1)>1,mask_samepos2.sum(-1)>1)
-        duplicated_index=np.nonzero(duplicated_mask)[0]
+        mask_samepos1 = np.logical_and(
+            x1_valid[:, 0, np.newaxis] == kp1[np.newaxis, :, 0],
+            x1_valid[:, 1, np.newaxis] == kp1[np.newaxis, :, 1],
+        )
+        mask_samepos2 = np.logical_and(
+            x2_valid[:, 0, np.newaxis] == kp2[np.newaxis, :, 0],
+            x2_valid[:, 1, np.newaxis] == kp2[np.newaxis, :, 1],
+        )
+        duplicated_mask = np.logical_or(
+            mask_samepos1.sum(-1) > 1, mask_samepos2.sum(-1) > 1
+        )
+        duplicated_index = np.nonzero(duplicated_mask)[0]
 
-        unique_corr_index=corr_index[~duplicated_mask]
-        clean_duplicated_corr=[]
+        unique_corr_index = corr_index[~duplicated_mask]
+        clean_duplicated_corr = []
         for index in duplicated_index:
-            cur_desc1, cur_desc2 = desc1[mask_samepos1[index]], desc2[mask_samepos2[index]]
+            cur_desc1, cur_desc2 = (
+                desc1[mask_samepos1[index]],
+                desc2[mask_samepos2[index]],
+            )
             cur_desc_mat = np.matmul(cur_desc1, cur_desc2.T)
-            cur_max_index =[np.argmax(cur_desc_mat)//cur_desc_mat.shape[1],np.argmax(cur_desc_mat)%cur_desc_mat.shape[1]]
-            clean_duplicated_corr.append(np.stack([np.arange(num_pts1)[mask_samepos1[index]][cur_max_index[0]],
-                        np.arange(num_pts2)[mask_samepos2[index]][cur_max_index[1]]]))
-        
-        clean_corr_index=unique_corr_index
-        if len(clean_duplicated_corr)!=0:
-            clean_duplicated_corr=np.stack(clean_duplicated_corr,axis=0)
-            clean_corr_index=np.concatenate([clean_corr_index,clean_duplicated_corr],axis=0)
+            cur_max_index = [
+                np.argmax(cur_desc_mat) // cur_desc_mat.shape[1],
+                np.argmax(cur_desc_mat) % cur_desc_mat.shape[1],
+            ]
+            clean_duplicated_corr.append(
+                np.stack(
+                    [
+                        np.arange(num_pts1)[mask_samepos1[index]][cur_max_index[0]],
+                        np.arange(num_pts2)[mask_samepos2[index]][cur_max_index[1]],
+                    ]
+                )
+            )
+
+        clean_corr_index = unique_corr_index
+        if len(clean_duplicated_corr) != 0:
+            clean_duplicated_corr = np.stack(clean_duplicated_corr, axis=0)
+            clean_corr_index = np.concatenate(
+                [clean_corr_index, clean_duplicated_corr], axis=0
+            )
     else:
-        clean_corr_index=corr_index
+        clean_corr_index = corr_index
     # find incorr
     mask_incorr1 = np.min(dis_mat2.T[valid1_2], axis=-1) > incorr_th
     mask_incorr2 = np.min(dis_mat1.T[valid2_1], axis=-1) > incorr_th
-    incorr_index1, incorr_index2 = np.arange(num_pts1)[valid1_2][mask_incorr1.squeeze()], \
-                                    np.arange(num_pts2)[valid2_1][mask_incorr2.squeeze()]
-
-    return clean_corr_index,incorr_index1,incorr_index2
+    incorr_index1, incorr_index2 = (
+        np.arange(num_pts1)[valid1_2][mask_incorr1.squeeze()],
+        np.arange(num_pts2)[valid2_1][mask_incorr2.squeeze()],
+    )
 
+    return clean_corr_index, incorr_index1, incorr_index2
diff --git a/third_party/SGMNet/utils/evaluation_utils.py b/third_party/SGMNet/utils/evaluation_utils.py
index 82c4715a192d3c361c849896b035cd91ee56dc42..a65a3075791857f586cc4f537dcb67eecc3ef681 100644
--- a/third_party/SGMNet/utils/evaluation_utils.py
+++ b/third_party/SGMNet/utils/evaluation_utils.py
@@ -2,57 +2,110 @@ import numpy as np
 import h5py
 import cv2
 
-def normalize_intrinsic(x,K):
-    #print(x,K)
-    return (x-K[:2,2])/np.diag(K)[:2]
 
-def normalize_size(x,size,scale=1):
-    size=size.reshape([1,2])
-    norm_fac=size.max()
-    return (x-size/2+0.5)/(norm_fac*scale)
+def normalize_intrinsic(x, K):
+    # print(x,K)
+    return (x - K[:2, 2]) / np.diag(K)[:2]
+
+
+def normalize_size(x, size, scale=1):
+    size = size.reshape([1, 2])
+    norm_fac = size.max()
+    return (x - size / 2 + 0.5) / (norm_fac * scale)
+
 
 def np_skew_symmetric(v):
     zero = np.zeros_like(v[:, 0])
-    M = np.stack([
-        zero, -v[:, 2], v[:, 1],
-        v[:, 2], zero, -v[:, 0],
-        -v[:, 1], v[:, 0], zero,
-    ], axis=1)
+    M = np.stack(
+        [
+            zero,
+            -v[:, 2],
+            v[:, 1],
+            v[:, 2],
+            zero,
+            -v[:, 0],
+            -v[:, 1],
+            v[:, 0],
+            zero,
+        ],
+        axis=1,
+    )
     return M
 
-def draw_points(img,points,color=(0,255,0),radius=3):
+
+def draw_points(img, points, color=(0, 255, 0), radius=3):
     dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
     for i in range(points.shape[0]):
-        cv2.circle(img, dp[i],radius=radius,color=color)
+        cv2.circle(img, dp[i], radius=radius, color=color)
     return img
-    
 
-def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None):
+
+def draw_match(
+    img1,
+    img2,
+    corr1,
+    corr2,
+    inlier=[True],
+    color=None,
+    radius1=1,
+    radius2=1,
+    resize=None,
+):
     if resize is not None:
-        scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]]
-        img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA) 
-        corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis]
-    corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])]
-    corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])]
+        scale1, scale2 = [img1.shape[1] / resize[0], img1.shape[0] / resize[1]], [
+            img2.shape[1] / resize[0],
+            img2.shape[0] / resize[1],
+        ]
+        img1, img2 = cv2.resize(img1, resize, interpolation=cv2.INTER_AREA), cv2.resize(
+            img2, resize, interpolation=cv2.INTER_AREA
+        )
+        corr1, corr2 = (
+            corr1 / np.asarray(scale1)[np.newaxis],
+            corr2 / np.asarray(scale2)[np.newaxis],
+        )
+    corr1_key = [
+        cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])
+    ]
+    corr2_key = [
+        cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])
+    ]
 
     assert len(corr1) == len(corr2)
 
     draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
     if color is None:
-        color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier]
-    if len(color)==1:
-        display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None,
-                              matchColor=color[0],
-                              singlePointColor=color[0],
-                              flags=4
-                              )
+        color = [(0, 255, 0) if cur_inlier else (0, 0, 255) for cur_inlier in inlier]
+    if len(color) == 1:
+        display = cv2.drawMatches(
+            img1,
+            corr1_key,
+            img2,
+            corr2_key,
+            draw_matches,
+            None,
+            matchColor=color[0],
+            singlePointColor=color[0],
+            flags=4,
+        )
     else:
-        height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1]
-        display=np.zeros([height,width,3],np.uint8)
-        display[:img1.shape[0],:img1.shape[1]]=img1
-        display[:img2.shape[0],img1.shape[1]:]=img2
+        height, width = max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1]
+        display = np.zeros([height, width, 3], np.uint8)
+        display[: img1.shape[0], : img1.shape[1]] = img1
+        display[: img2.shape[0], img1.shape[1] :] = img2
         for i in range(len(corr1)):
-            left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1])
-            cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2]))
-            cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA)
-    return display
\ No newline at end of file
+            left_x, left_y, right_x, right_y = (
+                int(corr1[i][0]),
+                int(corr1[i][1]),
+                int(corr2[i][0] + img1.shape[1]),
+                int(corr2[i][1]),
+            )
+            cur_color = (int(color[i][0]), int(color[i][1]), int(color[i][2]))
+            cv2.line(
+                display,
+                (left_x, left_y),
+                (right_x, right_y),
+                cur_color,
+                1,
+                lineType=cv2.LINE_AA,
+            )
+    return display
diff --git a/third_party/SGMNet/utils/fm_utils.py b/third_party/SGMNet/utils/fm_utils.py
index f9cbbeefe5d6b59c1ae1fa26cdaa42146ad22a74..900b73c42723cd9c5bcbef5c758deadcd0b309df 100644
--- a/third_party/SGMNet/utils/fm_utils.py
+++ b/third_party/SGMNet/utils/fm_utils.py
@@ -1,95 +1,100 @@
 import numpy as np
 
 
-def line_to_border(line,size):
-    #line:(a,b,c), ax+by+c=0
-    #size:(W,H) 
-    H,W=size[1],size[0]
-    a,b,c=line[0],line[1],line[2]
-    epsa=1e-8 if a>=0 else -1e-8
-    epsb=1e-8 if b>=0 else -1e-8
-    intersection_list=[]
-    
-    y_left=-c/(b+epsb)
-    y_right=(-c-a*(W-1))/(b+epsb)
-    x_top=-c/(a+epsa)
-    x_down=(-c-b*(H-1))/(a+epsa)
-
-    if y_left>=0 and y_left<=H-1:
-        intersection_list.append([0,y_left])
-    if y_right>=0 and y_right<=H-1:
-        intersection_list.append([W-1,y_right])
-    if x_top>=0 and x_top<=W-1:
-        intersection_list.append([x_top,0])
-    if x_down>=0 and x_down<=W-1:
-        intersection_list.append([x_down,H-1]) 
-    if len(intersection_list)!=2:
+def line_to_border(line, size):
+    # line: (a, b, c) with ax + by + c = 0
+    # size: (W, H)
+    H, W = size[1], size[0]
+    a, b, c = line[0], line[1], line[2]
+    epsa = 1e-8 if a >= 0 else -1e-8
+    epsb = 1e-8 if b >= 0 else -1e-8
+    intersection_list = []
+
+    y_left = -c / (b + epsb)
+    y_right = (-c - a * (W - 1)) / (b + epsb)
+    x_top = -c / (a + epsa)
+    x_down = (-c - b * (H - 1)) / (a + epsa)
+
+    if y_left >= 0 and y_left <= H - 1:
+        intersection_list.append([0, y_left])
+    if y_right >= 0 and y_right <= H - 1:
+        intersection_list.append([W - 1, y_right])
+    if x_top >= 0 and x_top <= W - 1:
+        intersection_list.append([x_top, 0])
+    if x_down >= 0 and x_down <= W - 1:
+        intersection_list.append([x_down, H - 1])
+    if len(intersection_list) != 2:
         return None
-    intersection_list=np.asarray(intersection_list)
+    intersection_list = np.asarray(intersection_list)
     return intersection_list
 
+
 def find_point_in_line(end_point):
-    x_span,y_span=end_point[1,0]-end_point[0,0],end_point[1,1]-end_point[0,1]
-    mv=np.random.uniform()
-    point=np.asarray([end_point[0,0]+x_span*mv,end_point[0,1]+y_span*mv])
+    x_span, y_span = (
+        end_point[1, 0] - end_point[0, 0],
+        end_point[1, 1] - end_point[0, 1],
+    )
+    mv = np.random.uniform()
+    point = np.asarray([end_point[0, 0] + x_span * mv, end_point[0, 1] + y_span * mv])
     return point
 
-def epi_line(point,F):
-    homo=np.concatenate([point,np.ones([len(point),1])],axis=-1)
-    epi=np.matmul(homo,F.T)
+
+def epi_line(point, F):
+    homo = np.concatenate([point, np.ones([len(point), 1])], axis=-1)
+    epi = np.matmul(homo, F.T)
     return epi
 
-def dis_point_to_line(line,point):
-    homo=np.concatenate([point,np.ones([len(point),1])],axis=-1)
-    dis=line*homo
-    dis=dis.sum(axis=-1)/(np.linalg.norm(line[:,:2],axis=-1)+1e-8)
+
+def dis_point_to_line(line, point):
+    homo = np.concatenate([point, np.ones([len(point), 1])], axis=-1)
+    dis = line * homo
+    dis = dis.sum(axis=-1) / (np.linalg.norm(line[:, :2], axis=-1) + 1e-8)
     return abs(dis)
 
-def SGD_oneiter(F1,F2,size1,size2):
-    H1,W1=size1[1],size1[0]
+
+def SGD_oneiter(F1, F2, size1, size2):
+    H1, W1 = size1[1], size1[0]
     factor1 = 1 / np.linalg.norm(size1)
     factor2 = 1 / np.linalg.norm(size2)
-    p0=np.asarray([(W1-1)*np.random.uniform(),(H1-1)*np.random.uniform()])
-    epi1=epi_line(p0[np.newaxis],F1)[0]
-    border_point1=line_to_border(epi1,size2)
+    p0 = np.asarray([(W1 - 1) * np.random.uniform(), (H1 - 1) * np.random.uniform()])
+    epi1 = epi_line(p0[np.newaxis], F1)[0]
+    border_point1 = line_to_border(epi1, size2)
     if border_point1 is None:
         return -1
-    
-    p1=find_point_in_line(border_point1)
-    epi2=epi_line(p0[np.newaxis],F2)
-    d1=dis_point_to_line(epi2,p1[np.newaxis])[0]*factor2
-    epi3=epi_line(p1[np.newaxis],F2.T)
-    d2=dis_point_to_line(epi3,p0[np.newaxis])[0]*factor1
-    return (d1+d2)/2
-
-def compute_SGD(F1,F2,size1,size2):
+
+    p1 = find_point_in_line(border_point1)
+    epi2 = epi_line(p0[np.newaxis], F2)
+    d1 = dis_point_to_line(epi2, p1[np.newaxis])[0] * factor2
+    epi3 = epi_line(p1[np.newaxis], F2.T)
+    d2 = dis_point_to_line(epi3, p0[np.newaxis])[0] * factor1
+    return (d1 + d2) / 2
+
+
+def compute_SGD(F1, F2, size1, size2):
     np.random.seed(1234)
-    N=1000
-    max_iter=N*10
-    count,sgd=0,0
+    N = 1000
+    max_iter = N * 10
+    count, sgd = 0, 0
     for i in range(max_iter):
-        d1=SGD_oneiter(F1,F2,size1,size2)
-        if d1<0:
+        d1 = SGD_oneiter(F1, F2, size1, size2)
+        if d1 < 0:
             continue
-        d2=SGD_oneiter(F2,F1,size1,size2)
-        if d2<0:
+        d2 = SGD_oneiter(F2, F1, size1, size2)
+        if d2 < 0:
             continue
-        count+=1
-        sgd+=(d1+d2)/2
-        if count==N:
+        count += 1
+        sgd += (d1 + d2) / 2
+        if count == N:
             break
-    if count==0:
+    if count == 0:
         return 1
     else:
-        return sgd/count
-
-def compute_inlier_rate(x1,x2,size1,size2,F_gt,th=0.003):
-    t1,t2=np.linalg.norm(size1)*th,np.linalg.norm(size2)*th
-    epi1,epi2=epi_line(x1,F_gt),epi_line(x2,F_gt.T)
-    dis1,dis2=dis_point_to_line(epi1,x2),dis_point_to_line(epi2,x1)
-    mask_inlier=np.logical_and(dis1<t2,dis2<t1)
-    return mask_inlier.mean() if len(mask_inlier)!=0 else 0
-
-
+        return sgd / count
 
 
+def compute_inlier_rate(x1, x2, size1, size2, F_gt, th=0.003):
+    t1, t2 = np.linalg.norm(size1) * th, np.linalg.norm(size2) * th
+    epi1, epi2 = epi_line(x1, F_gt), epi_line(x2, F_gt.T)
+    dis1, dis2 = dis_point_to_line(epi1, x2), dis_point_to_line(epi2, x1)
+    mask_inlier = np.logical_and(dis1 < t2, dis2 < t1)
+    return mask_inlier.mean() if len(mask_inlier) != 0 else 0
diff --git a/third_party/SGMNet/utils/metrics.py b/third_party/SGMNet/utils/metrics.py
index 060a7c09e1f1ecb54a8d9bb77c04555b7bc20857..0c4ddf4f0b9c5d045b627dea1c266b863246e1fd 100644
--- a/third_party/SGMNet/utils/metrics.py
+++ b/third_party/SGMNet/utils/metrics.py
@@ -14,12 +14,12 @@ def evaluate_R_t(R_gt, t_gt, R, t):
     q = quaternion_from_matrix(R)
     q = q / (np.linalg.norm(q) + eps)
     q_gt = q_gt / (np.linalg.norm(q_gt) + eps)
-    loss_q = np.maximum(eps, (1.0 - np.sum(q * q_gt)**2))
-    err_q = np.arccos(1 -  2*loss_q)
+    loss_q = np.maximum(eps, (1.0 - np.sum(q * q_gt) ** 2))
+    err_q = np.arccos(1 - 2 * loss_q)
 
     t = t / (np.linalg.norm(t) + eps)
     t_gt = t_gt / (np.linalg.norm(t_gt) + eps)
-    loss_t = np.maximum(eps, (1.0 - np.sum(t * t_gt)**2))
+    loss_t = np.maximum(eps, (1.0 - np.sum(t * t_gt) ** 2))
     err_t = np.arccos(np.sqrt(1 - loss_t))
     return np.rad2deg(err_q), np.rad2deg(err_t)
 
@@ -28,33 +28,36 @@ def pose_auc(errors, thresholds):
     sort_idx = np.argsort(errors)
     errors = np.array(errors.copy())[sort_idx]
     recall = (np.arange(len(errors)) + 1) / len(errors)
-    errors = np.r_[0., errors]
-    recall = np.r_[0., recall]
+    errors = np.r_[0.0, errors]
+    recall = np.r_[0.0, recall]
     aucs = []
     for t in thresholds[1:]:
         last_index = np.searchsorted(errors, t)
-        r = np.r_[recall[:last_index], recall[last_index-1]]
+        r = np.r_[recall[:last_index], recall[last_index - 1]]
         e = np.r_[errors[:last_index], t]
-        aucs.append(np.trapz(r, x=e)/t)
+        aucs.append(np.trapz(r, x=e) / t)
     return aucs
 
 
-def approx_pose_auc(errors,thresholds):
+def approx_pose_auc(errors, thresholds):
     qt_acc_hist, _ = np.histogram(errors, thresholds)
     num_pair = float(len(errors))
     qt_acc_hist = qt_acc_hist.astype(float) / num_pair
     qt_acc = np.cumsum(qt_acc_hist)
-    approx_aucs=[np.mean(qt_acc[:i]) for i in range(1, len(thresholds))]
+    approx_aucs = [np.mean(qt_acc[:i]) for i in range(1, len(thresholds))]
     return approx_aucs
 
 
-def compute_epi_inlier(x1,x2,E,inlier_th):
-    num_pts1,num_pts2=x1.shape[0],x2.shape[0]
+def compute_epi_inlier(x1, x2, E, inlier_th):
+    num_pts1, num_pts2 = x1.shape[0], x2.shape[0]
     x1_h = np.concatenate([x1, np.ones([num_pts1, 1])], -1)
     x2_h = np.concatenate([x2, np.ones([num_pts2, 1])], -1)
-    ep_line1 = x1_h@E.T
-    ep_line2=  x2_h@E
-    norm_factor=(1/np.sqrt((ep_line1[:,:2]**2).sum(1))+1/np.sqrt((ep_line2[:,:2]**2).sum(1)))/2
-    dis=abs((ep_line1*x2_h).sum(-1))*norm_factor
-    inlier_mask=dis<inlier_th
-    return inlier_mask
\ No newline at end of file
+    ep_line1 = x1_h @ E.T
+    ep_line2 = x2_h @ E
+    norm_factor = (
+        1 / np.sqrt((ep_line1[:, :2] ** 2).sum(1))
+        + 1 / np.sqrt((ep_line2[:, :2] ** 2).sum(1))
+    ) / 2
+    dis = abs((ep_line1 * x2_h).sum(-1)) * norm_factor
+    inlier_mask = dis < inlier_th
+    return inlier_mask
diff --git a/third_party/SGMNet/utils/train_utils.py b/third_party/SGMNet/utils/train_utils.py
index 3572110685bae4e2bd091fdf66d8e7515ef10c1f..f0a843e8dcc916dc1d9e87892650b050d47f1fb6 100644
--- a/third_party/SGMNet/utils/train_utils.py
+++ b/third_party/SGMNet/utils/train_utils.py
@@ -3,22 +3,24 @@ import torch.distributed as dist
 import numpy as np
 import cv2
 
+
 def parse_pair_seq(pair_num_list):
-    #generate pair_seq_list: [#pair_num]:seq
+    # generate pair_seq_list: [#pair_num]:seq
     #              accu_pair_num: dict{seq_name:accumulated_pair}
-    pair_num=int(pair_num_list[0,1])
-    pair_num_list=pair_num_list[1:]
-    pair_seq_list=[]
-    cursor=0
-    accu_pair_num={}
+    pair_num = int(pair_num_list[0, 1])
+    pair_num_list = pair_num_list[1:]
+    pair_seq_list = []
+    cursor = 0
+    accu_pair_num = {}
     for line in pair_num_list:
-       seq,seq_pair_num=line[0],int(line[1])
-       for _ in range(seq_pair_num):
-          pair_seq_list.append(seq)
-       accu_pair_num[seq]=cursor
-       cursor+=seq_pair_num
-    assert pair_num==cursor
-    return pair_seq_list,accu_pair_num
+        seq, seq_pair_num = line[0], int(line[1])
+        for _ in range(seq_pair_num):
+            pair_seq_list.append(seq)
+        accu_pair_num[seq] = cursor
+        cursor += seq_pair_num
+    assert pair_num == cursor
+    return pair_seq_list, accu_pair_num
+
 
 def tocuda(data):
     # convert tensor data in dictionary to cuda when it is a tensor
@@ -26,19 +28,23 @@ def tocuda(data):
         if type(data[key]) == torch.Tensor:
             data[key] = data[key].cuda()
     return data
-    
-def reduce_tensor(tensor,op='mean'): 
+
+
+def reduce_tensor(tensor, op="mean"):
     rt = tensor.detach()
     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    if op=='mean':
+    if op == "mean":
         rt /= dist.get_world_size()
     return rt
 
+
 def get_rnd_homography(batch_size, pert_ratio=0.25):
     corners = np.array([[-1, 1], [1, 1], [-1, -1], [1, -1]], dtype=np.float32)
     homo_tower = []
     for _ in range(batch_size):
-        rnd_pert = np.random.uniform(-2 * pert_ratio, 2 * pert_ratio, (4, 2)).astype(np.float32)
+        rnd_pert = np.random.uniform(-2 * pert_ratio, 2 * pert_ratio, (4, 2)).astype(
+            np.float32
+        )
         pert_corners = corners + rnd_pert
         M = cv2.getPerspectiveTransform(corners, pert_corners)
         homo_tower.append(M)
diff --git a/third_party/SGMNet/utils/transformations.py b/third_party/SGMNet/utils/transformations.py
index e341f8ff976f0ea223717477b1e9bdf15bb03e47..2ed1be31e82283204b10d54376a25a1313a81244 100644
--- a/third_party/SGMNet/utils/transformations.py
+++ b/third_party/SGMNet/utils/transformations.py
@@ -199,8 +199,8 @@ import math
 
 import numpy
 
-__version__ = '2015.07.18'
-__docformat__ = 'restructuredtext en'
+__version__ = "2015.07.18"
+__docformat__ = "restructuredtext en"
 __all__ = ()
 
 
@@ -331,9 +331,13 @@ def rotation_matrix(angle, direction, point=None):
     R = numpy.diag([cosa, cosa, cosa])
     R += numpy.outer(direction, direction) * (1.0 - cosa)
     direction *= sina
-    R += numpy.array([[ 0.0,         -direction[2],  direction[1]],
-                      [ direction[2], 0.0,          -direction[0]],
-                      [-direction[1], direction[0],  0.0]])
+    R += numpy.array(
+        [
+            [0.0, -direction[2], direction[1]],
+            [direction[2], 0.0, -direction[0]],
+            [-direction[1], direction[0], 0.0],
+        ]
+    )
     M = numpy.identity(4)
     M[:3, :3] = R
     if point is not None:
@@ -374,11 +378,11 @@ def rotation_from_matrix(matrix):
     # rotation angle depending on direction
     cosa = (numpy.trace(R33) - 1.0) / 2.0
     if abs(direction[2]) > 1e-8:
-        sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2]
+        sina = (R[1, 0] + (cosa - 1.0) * direction[0] * direction[1]) / direction[2]
     elif abs(direction[1]) > 1e-8:
-        sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1]
+        sina = (R[0, 2] + (cosa - 1.0) * direction[0] * direction[2]) / direction[1]
     else:
-        sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0]
+        sina = (R[2, 1] + (cosa - 1.0) * direction[1] * direction[2]) / direction[0]
     angle = math.atan2(sina, cosa)
     return angle, direction, point
 
@@ -458,8 +462,7 @@ def scale_from_matrix(matrix):
     return factor, origin, direction
 
 
-def projection_matrix(point, normal, direction=None,
-                      perspective=None, pseudo=False):
+def projection_matrix(point, normal, direction=None, perspective=None, pseudo=False):
     """Return matrix to project onto plane defined by point and normal.
 
     Using either perspective point, projection direction, or none of both.
@@ -495,14 +498,13 @@ def projection_matrix(point, normal, direction=None,
     normal = unit_vector(normal[:3])
     if perspective is not None:
         # perspective projection
-        perspective = numpy.array(perspective[:3], dtype=numpy.float64,
-                                  copy=False)
-        M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal)
+        perspective = numpy.array(perspective[:3], dtype=numpy.float64, copy=False)
+        M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective - point, normal)
         M[:3, :3] -= numpy.outer(perspective, normal)
         if pseudo:
             # preserve relative depth
             M[:3, :3] -= numpy.outer(normal, normal)
-            M[:3, 3] = numpy.dot(point, normal) * (perspective+normal)
+            M[:3, 3] = numpy.dot(point, normal) * (perspective + normal)
         else:
             M[:3, 3] = numpy.dot(point, normal) * perspective
         M[3, :3] = -normal
@@ -582,11 +584,10 @@ def projection_from_matrix(matrix, pseudo=False):
         # perspective projection
         i = numpy.where(abs(numpy.real(w)) > 1e-8)[0]
         if not len(i):
-            raise ValueError(
-                "no eigenvector not corresponding to eigenvalue 0")
+            raise ValueError("no eigenvector not corresponding to eigenvalue 0")
         point = numpy.real(V[:, i[-1]]).squeeze()
         point /= point[3]
-        normal = - M[3, :3]
+        normal = -M[3, :3]
         perspective = M[:3, 3] / numpy.dot(point[:3], normal)
         if pseudo:
             perspective -= normal
@@ -633,15 +634,19 @@ def clip_matrix(left, right, bottom, top, near, far, perspective=False):
         if near <= _EPS:
             raise ValueError("invalid frustum: near <= 0")
         t = 2.0 * near
-        M = [[t/(left-right), 0.0, (right+left)/(right-left), 0.0],
-             [0.0, t/(bottom-top), (top+bottom)/(top-bottom), 0.0],
-             [0.0, 0.0, (far+near)/(near-far), t*far/(far-near)],
-             [0.0, 0.0, -1.0, 0.0]]
+        M = [
+            [t / (left - right), 0.0, (right + left) / (right - left), 0.0],
+            [0.0, t / (bottom - top), (top + bottom) / (top - bottom), 0.0],
+            [0.0, 0.0, (far + near) / (near - far), t * far / (far - near)],
+            [0.0, 0.0, -1.0, 0.0],
+        ]
     else:
-        M = [[2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)],
-             [0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)],
-             [0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)],
-             [0.0, 0.0, 0.0, 1.0]]
+        M = [
+            [2.0 / (right - left), 0.0, 0.0, (right + left) / (left - right)],
+            [0.0, 2.0 / (top - bottom), 0.0, (top + bottom) / (bottom - top)],
+            [0.0, 0.0, 2.0 / (far - near), (far + near) / (near - far)],
+            [0.0, 0.0, 0.0, 1.0],
+        ]
     return numpy.array(M)
 
 
@@ -761,7 +766,7 @@ def decompose_matrix(matrix):
     if not numpy.linalg.det(P):
         raise ValueError("matrix is singular")
 
-    scale = numpy.zeros((3, ))
+    scale = numpy.zeros((3,))
     shear = [0.0, 0.0, 0.0]
     angles = [0.0, 0.0, 0.0]
 
@@ -799,15 +804,16 @@ def decompose_matrix(matrix):
         angles[0] = math.atan2(row[1, 2], row[2, 2])
         angles[2] = math.atan2(row[0, 1], row[0, 0])
     else:
-        #angles[0] = math.atan2(row[1, 0], row[1, 1])
+        # angles[0] = math.atan2(row[1, 0], row[1, 1])
         angles[0] = math.atan2(-row[2, 1], row[1, 1])
         angles[2] = 0.0
 
     return scale, shear, angles, translate, perspective
 
 
-def compose_matrix(scale=None, shear=None, angles=None, translate=None,
-                   perspective=None):
+def compose_matrix(
+    scale=None, shear=None, angles=None, translate=None, perspective=None
+):
     """Return transformation matrix from sequence of transformations.
 
     This is the inverse of the decompose_matrix function.
@@ -841,7 +847,7 @@ def compose_matrix(scale=None, shear=None, angles=None, translate=None,
         T[:3, 3] = translate[:3]
         M = numpy.dot(M, T)
     if angles is not None:
-        R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz')
+        R = euler_matrix(angles[0], angles[1], angles[2], "sxyz")
         M = numpy.dot(M, R)
     if shear is not None:
         Z = numpy.identity(4)
@@ -879,11 +885,14 @@ def orthogonalization_matrix(lengths, angles):
     sina, sinb, _ = numpy.sin(angles)
     cosa, cosb, cosg = numpy.cos(angles)
     co = (cosa * cosb - cosg) / (sina * sinb)
-    return numpy.array([
-        [ a*sinb*math.sqrt(1.0-co*co),  0.0,    0.0, 0.0],
-        [-a*sinb*co,                    b*sina, 0.0, 0.0],
-        [ a*cosb,                       b*cosa, c,   0.0],
-        [ 0.0,                          0.0,    0.0, 1.0]])
+    return numpy.array(
+        [
+            [a * sinb * math.sqrt(1.0 - co * co), 0.0, 0.0, 0.0],
+            [-a * sinb * co, b * sina, 0.0, 0.0],
+            [a * cosb, b * cosa, c, 0.0],
+            [0.0, 0.0, 0.0, 1.0],
+        ]
+    )
 
 
 def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True):
@@ -936,11 +945,11 @@ def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True):
 
     # move centroids to origin
     t0 = -numpy.mean(v0, axis=1)
-    M0 = numpy.identity(ndims+1)
+    M0 = numpy.identity(ndims + 1)
     M0[:ndims, ndims] = t0
     v0 += t0.reshape(ndims, 1)
     t1 = -numpy.mean(v1, axis=1)
-    M1 = numpy.identity(ndims+1)
+    M1 = numpy.identity(ndims + 1)
     M1[:ndims, ndims] = t1
     v1 += t1.reshape(ndims, 1)
 
@@ -950,10 +959,10 @@ def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True):
         u, s, vh = numpy.linalg.svd(A.T)
         vh = vh[:ndims].T
         B = vh[:ndims]
-        C = vh[ndims:2*ndims]
+        C = vh[ndims : 2 * ndims]
         t = numpy.dot(C, numpy.linalg.pinv(B))
         t = numpy.concatenate((t, numpy.zeros((ndims, 1))), axis=1)
-        M = numpy.vstack((t, ((0.0,)*ndims) + (1.0,)))
+        M = numpy.vstack((t, ((0.0,) * ndims) + (1.0,)))
     elif usesvd or ndims != 3:
         # Rigid transformation via SVD of covariance matrix
         u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T))
@@ -961,10 +970,10 @@ def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True):
         R = numpy.dot(u, vh)
         if numpy.linalg.det(R) < 0.0:
             # R does not constitute right handed system
-            R -= numpy.outer(u[:, ndims-1], vh[ndims-1, :]*2.0)
+            R -= numpy.outer(u[:, ndims - 1], vh[ndims - 1, :] * 2.0)
             s[-1] *= -1.0
         # homogeneous transformation matrix
-        M = numpy.identity(ndims+1)
+        M = numpy.identity(ndims + 1)
         M[:ndims, :ndims] = R
     else:
         # Rigid transformation matrix via quaternion
@@ -972,10 +981,12 @@ def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True):
         xx, yy, zz = numpy.sum(v0 * v1, axis=1)
         xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1)
         xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1)
-        N = [[xx+yy+zz, 0.0,      0.0,      0.0],
-             [yz-zy,    xx-yy-zz, 0.0,      0.0],
-             [zx-xz,    xy+yx,    yy-xx-zz, 0.0],
-             [xy-yx,    zx+xz,    yz+zy,    zz-xx-yy]]
+        N = [
+            [xx + yy + zz, 0.0, 0.0, 0.0],
+            [yz - zy, xx - yy - zz, 0.0, 0.0],
+            [zx - xz, xy + yx, yy - xx - zz, 0.0],
+            [xy - yx, zx + xz, yz + zy, zz - xx - yy],
+        ]
         # quaternion: eigenvector corresponding to most positive eigenvalue
         w, V = numpy.linalg.eigh(N)
         q = V[:, numpy.argmax(w)]
@@ -1042,11 +1053,10 @@ def superimposition_matrix(v0, v1, scale=False, usesvd=True):
     """
     v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3]
     v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3]
-    return affine_matrix_from_points(v0, v1, shear=False,
-                                     scale=scale, usesvd=usesvd)
+    return affine_matrix_from_points(v0, v1, shear=False, scale=scale, usesvd=usesvd)
 
 
-def euler_matrix(ai, aj, ak, axes='sxyz'):
+def euler_matrix(ai, aj, ak, axes="sxyz"):
     """Return homogeneous rotation matrix from Euler angles and axis sequence.
 
     ai, aj, ak : Euler's roll, pitch and yaw angles
@@ -1072,8 +1082,8 @@ def euler_matrix(ai, aj, ak, axes='sxyz'):
         firstaxis, parity, repetition, frame = axes
 
     i = firstaxis
-    j = _NEXT_AXIS[i+parity]
-    k = _NEXT_AXIS[i-parity+1]
+    j = _NEXT_AXIS[i + parity]
+    k = _NEXT_AXIS[i - parity + 1]
 
     if frame:
         ai, ak = ak, ai
@@ -1082,34 +1092,34 @@ def euler_matrix(ai, aj, ak, axes='sxyz'):
 
     si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak)
     ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak)
-    cc, cs = ci*ck, ci*sk
-    sc, ss = si*ck, si*sk
+    cc, cs = ci * ck, ci * sk
+    sc, ss = si * ck, si * sk
 
     M = numpy.identity(4)
     if repetition:
         M[i, i] = cj
-        M[i, j] = sj*si
-        M[i, k] = sj*ci
-        M[j, i] = sj*sk
-        M[j, j] = -cj*ss+cc
-        M[j, k] = -cj*cs-sc
-        M[k, i] = -sj*ck
-        M[k, j] = cj*sc+cs
-        M[k, k] = cj*cc-ss
+        M[i, j] = sj * si
+        M[i, k] = sj * ci
+        M[j, i] = sj * sk
+        M[j, j] = -cj * ss + cc
+        M[j, k] = -cj * cs - sc
+        M[k, i] = -sj * ck
+        M[k, j] = cj * sc + cs
+        M[k, k] = cj * cc - ss
     else:
-        M[i, i] = cj*ck
-        M[i, j] = sj*sc-cs
-        M[i, k] = sj*cc+ss
-        M[j, i] = cj*sk
-        M[j, j] = sj*ss+cc
-        M[j, k] = sj*cs-sc
+        M[i, i] = cj * ck
+        M[i, j] = sj * sc - cs
+        M[i, k] = sj * cc + ss
+        M[j, i] = cj * sk
+        M[j, j] = sj * ss + cc
+        M[j, k] = sj * cs - sc
         M[k, i] = -sj
-        M[k, j] = cj*si
-        M[k, k] = cj*ci
+        M[k, j] = cj * si
+        M[k, k] = cj * ci
     return M
 
 
-def euler_from_matrix(matrix, axes='sxyz'):
+def euler_from_matrix(matrix, axes="sxyz"):
     """Return Euler angles from rotation matrix for specified axis sequence.
 
     axes : One of 24 axis sequences as string or encoded tuple
@@ -1135,29 +1145,29 @@ def euler_from_matrix(matrix, axes='sxyz'):
         firstaxis, parity, repetition, frame = axes
 
     i = firstaxis
-    j = _NEXT_AXIS[i+parity]
-    k = _NEXT_AXIS[i-parity+1]
+    j = _NEXT_AXIS[i + parity]
+    k = _NEXT_AXIS[i - parity + 1]
 
     M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3]
     if repetition:
-        sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k])
+        sy = math.sqrt(M[i, j] * M[i, j] + M[i, k] * M[i, k])
         if sy > _EPS:
-            ax = math.atan2( M[i, j],  M[i, k])
-            ay = math.atan2( sy,       M[i, i])
-            az = math.atan2( M[j, i], -M[k, i])
+            ax = math.atan2(M[i, j], M[i, k])
+            ay = math.atan2(sy, M[i, i])
+            az = math.atan2(M[j, i], -M[k, i])
         else:
-            ax = math.atan2(-M[j, k],  M[j, j])
-            ay = math.atan2( sy,       M[i, i])
+            ax = math.atan2(-M[j, k], M[j, j])
+            ay = math.atan2(sy, M[i, i])
             az = 0.0
     else:
-        cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i])
+        cy = math.sqrt(M[i, i] * M[i, i] + M[j, i] * M[j, i])
         if cy > _EPS:
-            ax = math.atan2( M[k, j],  M[k, k])
-            ay = math.atan2(-M[k, i],  cy)
-            az = math.atan2( M[j, i],  M[i, i])
+            ax = math.atan2(M[k, j], M[k, k])
+            ay = math.atan2(-M[k, i], cy)
+            az = math.atan2(M[j, i], M[i, i])
         else:
-            ax = math.atan2(-M[j, k],  M[j, j])
-            ay = math.atan2(-M[k, i],  cy)
+            ax = math.atan2(-M[j, k], M[j, j])
+            ay = math.atan2(-M[k, i], cy)
             az = 0.0
 
     if parity:
@@ -1167,7 +1177,7 @@ def euler_from_matrix(matrix, axes='sxyz'):
     return ax, ay, az
 
 
-def euler_from_quaternion(quaternion, axes='sxyz'):
+def euler_from_quaternion(quaternion, axes="sxyz"):
     """Return Euler angles from quaternion for specified axis sequence.
 
     >>> angles = euler_from_quaternion([0.99810947, 0.06146124, 0, 0])
@@ -1178,7 +1188,7 @@ def euler_from_quaternion(quaternion, axes='sxyz'):
     return euler_from_matrix(quaternion_matrix(quaternion), axes)
 
 
-def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
+def quaternion_from_euler(ai, aj, ak, axes="sxyz"):
     """Return quaternion from Euler angles and axis sequence.
 
     ai, aj, ak : Euler's roll, pitch and yaw angles
@@ -1196,8 +1206,8 @@ def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
         firstaxis, parity, repetition, frame = axes
 
     i = firstaxis + 1
-    j = _NEXT_AXIS[i+parity-1] + 1
-    k = _NEXT_AXIS[i-parity] + 1
+    j = _NEXT_AXIS[i + parity - 1] + 1
+    k = _NEXT_AXIS[i - parity] + 1
 
     if frame:
         ai, ak = ak, ai
@@ -1213,22 +1223,22 @@ def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
     sj = math.sin(aj)
     ck = math.cos(ak)
     sk = math.sin(ak)
-    cc = ci*ck
-    cs = ci*sk
-    sc = si*ck
-    ss = si*sk
+    cc = ci * ck
+    cs = ci * sk
+    sc = si * ck
+    ss = si * sk
 
-    q = numpy.empty((4, ))
+    q = numpy.empty((4,))
     if repetition:
-        q[0] = cj*(cc - ss)
-        q[i] = cj*(cs + sc)
-        q[j] = sj*(cc + ss)
-        q[k] = sj*(cs - sc)
+        q[0] = cj * (cc - ss)
+        q[i] = cj * (cs + sc)
+        q[j] = sj * (cc + ss)
+        q[k] = sj * (cs - sc)
     else:
-        q[0] = cj*cc + sj*ss
-        q[i] = cj*sc - sj*cs
-        q[j] = cj*ss + sj*cc
-        q[k] = cj*cs - sj*sc
+        q[0] = cj * cc + sj * ss
+        q[i] = cj * sc - sj * cs
+        q[j] = cj * ss + sj * cc
+        q[k] = cj * cs - sj * sc
     if parity:
         q[j] *= -1.0
 
@@ -1246,8 +1256,8 @@ def quaternion_about_axis(angle, axis):
     q = numpy.array([0.0, axis[0], axis[1], axis[2]])
     qlen = vector_norm(q)
     if qlen > _EPS:
-        q *= math.sin(angle/2.0) / qlen
-    q[0] = math.cos(angle/2.0)
+        q *= math.sin(angle / 2.0) / qlen
+    q[0] = math.cos(angle / 2.0)
     return q
 
 
@@ -1271,11 +1281,14 @@ def quaternion_matrix(quaternion):
         return numpy.identity(4)
     q *= math.sqrt(2.0 / n)
     q = numpy.outer(q, q)
-    return numpy.array([
-        [1.0-q[2, 2]-q[3, 3],     q[1, 2]-q[3, 0],     q[1, 3]+q[2, 0], 0.0],
-        [    q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3],     q[2, 3]-q[1, 0], 0.0],
-        [    q[1, 3]-q[2, 0],     q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0],
-        [                0.0,                 0.0,                 0.0, 1.0]])
+    return numpy.array(
+        [
+            [1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0],
+            [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0],
+            [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0],
+            [0.0, 0.0, 0.0, 1.0],
+        ]
+    )
 
 
 def quaternion_from_matrix(matrix, isprecise=False):
@@ -1316,7 +1329,7 @@ def quaternion_from_matrix(matrix, isprecise=False):
     """
     M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4]
     if isprecise:
-        q = numpy.empty((4, ))
+        q = numpy.empty((4,))
         t = numpy.trace(M)
         if t > M[3, 3]:
             q[0] = t
@@ -1346,10 +1359,14 @@ def quaternion_from_matrix(matrix, isprecise=False):
         m21 = M[2, 1]
         m22 = M[2, 2]
         # symmetric matrix K
-        K = numpy.array([[m00-m11-m22, 0.0,         0.0,         0.0],
-                         [m01+m10,     m11-m00-m22, 0.0,         0.0],
-                         [m02+m20,     m12+m21,     m22-m00-m11, 0.0],
-                         [m21-m12,     m02-m20,     m10-m01,     m00+m11+m22]])
+        K = numpy.array(
+            [
+                [m00 - m11 - m22, 0.0, 0.0, 0.0],
+                [m01 + m10, m11 - m00 - m22, 0.0, 0.0],
+                [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0],
+                [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
+            ]
+        )
         K /= 3.0
         # quaternion is eigenvector of K that corresponds to largest eigenvalue
         w, V = numpy.linalg.eigh(K)
@@ -1369,10 +1386,15 @@ def quaternion_multiply(quaternion1, quaternion0):
     """
     w0, x0, y0, z0 = quaternion0
     w1, x1, y1, z1 = quaternion1
-    return numpy.array([-x1*x0 - y1*y0 - z1*z0 + w1*w0,
-                         x1*w0 + y1*z0 - z1*y0 + w1*x0,
-                        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
-                         x1*y0 - y1*x0 + z1*w0 + w1*z0], dtype=numpy.float64)
+    return numpy.array(
+        [
+            -x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0,
+            x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0,
+            -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0,
+            x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0,
+        ],
+        dtype=numpy.float64,
+    )
 
 
 def quaternion_conjugate(quaternion):
@@ -1488,8 +1510,9 @@ def random_quaternion(rand=None):
     pi2 = math.pi * 2.0
     t1 = pi2 * rand[1]
     t2 = pi2 * rand[2]
-    return numpy.array([numpy.cos(t2)*r2, numpy.sin(t1)*r1,
-                        numpy.cos(t1)*r1, numpy.sin(t2)*r2])
+    return numpy.array(
+        [numpy.cos(t2) * r2, numpy.sin(t1) * r1, numpy.cos(t1) * r1, numpy.sin(t2) * r2]
+    )
 
 
 def random_rotation_matrix(rand=None):
@@ -1530,6 +1553,7 @@ class Arcball(object):
     >>> ball.next()
 
     """
+
     def __init__(self, initial=None):
         """Initialize virtual trackball control.
 
@@ -1548,7 +1572,7 @@ class Arcball(object):
             initial = numpy.array(initial, dtype=numpy.float64)
             if initial.shape == (4, 4):
                 self._qdown = quaternion_from_matrix(initial)
-            elif initial.shape == (4, ):
+            elif initial.shape == (4,):
                 initial /= vector_norm(initial)
                 self._qdown = initial
             else:
@@ -1610,7 +1634,7 @@ class Arcball(object):
 
     def next(self, acceleration=0.0):
         """Continue rotation in direction of last drag."""
-        q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False)
+        q = quaternion_slerp(self._qpre, self._qnow, 2.0 + acceleration, False)
         self._qpre, self._qnow = self._qnow, q
 
     def matrix(self):
@@ -1622,11 +1646,11 @@ def arcball_map_to_sphere(point, center, radius):
     """Return unit sphere coordinates from window coordinates."""
     v0 = (point[0] - center[0]) / radius
     v1 = (center[1] - point[1]) / radius
-    n = v0*v0 + v1*v1
+    n = v0 * v0 + v1 * v1
     if n > 1.0:
         # position outside of sphere
         n = math.sqrt(n)
-        return numpy.array([v0/n, v1/n, 0.0])
+        return numpy.array([v0 / n, v1 / n, 0.0])
     else:
         return numpy.array([v0, v1, math.sqrt(1.0 - n)])
 
@@ -1668,14 +1692,31 @@ _NEXT_AXIS = [1, 2, 0, 1]
 
 # map axes strings to/from tuples of inner axis, parity, repetition, frame
 _AXES2TUPLE = {
-    'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0),
-    'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0),
-    'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0),
-    'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0),
-    'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1),
-    'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1),
-    'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1),
-    'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)}
+    "sxyz": (0, 0, 0, 0),
+    "sxyx": (0, 0, 1, 0),
+    "sxzy": (0, 1, 0, 0),
+    "sxzx": (0, 1, 1, 0),
+    "syzx": (1, 0, 0, 0),
+    "syzy": (1, 0, 1, 0),
+    "syxz": (1, 1, 0, 0),
+    "syxy": (1, 1, 1, 0),
+    "szxy": (2, 0, 0, 0),
+    "szxz": (2, 0, 1, 0),
+    "szyx": (2, 1, 0, 0),
+    "szyz": (2, 1, 1, 0),
+    "rzyx": (0, 0, 0, 1),
+    "rxyx": (0, 0, 1, 1),
+    "ryzx": (0, 1, 0, 1),
+    "rxzx": (0, 1, 1, 1),
+    "rxzy": (1, 0, 0, 1),
+    "ryzy": (1, 0, 1, 1),
+    "rzxy": (1, 1, 0, 1),
+    "ryxy": (1, 1, 1, 1),
+    "ryxz": (2, 0, 0, 1),
+    "rzxz": (2, 0, 1, 1),
+    "rxyz": (2, 1, 0, 1),
+    "rzyz": (2, 1, 1, 1),
+}
 
 _TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items())
 
@@ -1754,7 +1795,7 @@ def unit_vector(data, axis=None, out=None):
         if out is not data:
             out[:] = numpy.array(data, copy=False)
         data = out
-    length = numpy.atleast_1d(numpy.sum(data*data, axis))
+    length = numpy.atleast_1d(numpy.sum(data * data, axis))
     numpy.sqrt(length, length)
     if axis is not None:
         length = numpy.expand_dims(length, axis)
@@ -1878,7 +1919,7 @@ def is_same_transform(matrix0, matrix1):
     return numpy.allclose(matrix0, matrix1)
 
 
-def _import_module(name, package=None, warn=True, prefix='_py_', ignore='_'):
+def _import_module(name, package=None, warn=True, prefix="_py_", ignore="_"):
     """Try import all public attributes from module into global namespace.
 
     Existing attributes with name clashes are renamed with prefix.
@@ -1889,14 +1930,15 @@ def _import_module(name, package=None, warn=True, prefix='_py_', ignore='_'):
     """
     import warnings
     from importlib import import_module
+
     try:
         if not package:
             module = import_module(name)
         else:
-            module = import_module('.' + name, package=package)
+            module = import_module("." + name, package=package)
     except ImportError:
         if warn:
-            #warnings.warn("failed to import module %s" % name)
+            # warnings.warn("failed to import module %s" % name)
             pass
     else:
         for attr in dir(module):
@@ -1911,11 +1953,11 @@ def _import_module(name, package=None, warn=True, prefix='_py_', ignore='_'):
         return True
 
 
-_import_module('_transformations')
+_import_module("_transformations")
 
 if __name__ == "__main__":
     import doctest
     import random  # used in doctests
+
     numpy.set_printoptions(suppress=True, precision=5)
     doctest.testmod()
-
diff --git a/third_party/SOLD2/setup.py b/third_party/SOLD2/setup.py
index 69f72fecdc54cf9b43a7fc55144470e83c5a862d..e6c9cdcb47bdd73758cbd2d5b125dcb91306705f 100644
--- a/third_party/SOLD2/setup.py
+++ b/third_party/SOLD2/setup.py
@@ -1,4 +1,4 @@
 from setuptools import setup
 
 
-setup(name='sold2', version="0.0", packages=['sold2'])
+setup(name="sold2", version="0.0", packages=["sold2"])
diff --git a/third_party/SOLD2/sold2/config/project_config.py b/third_party/SOLD2/sold2/config/project_config.py
index 42ed00d1c1900e71568d1b06ff4f9d19a295232d..6846b4451e038b1c517043ea6db08f3029b79852 100644
--- a/third_party/SOLD2/sold2/config/project_config.py
+++ b/third_party/SOLD2/sold2/config/project_config.py
@@ -5,26 +5,29 @@ import os
 
 
 class Config(object):
-    """ Datasets and experiments folders for the whole project. """
+    """Datasets and experiments folders for the whole project."""
+
     #####################
     ## Dataset setting ##
     #####################
-    DATASET_ROOT = os.getenv("DATASET_ROOT", "./datasets/")  # TODO: path to your datasets folder
+    DATASET_ROOT = os.getenv(
+        "DATASET_ROOT", "./datasets/"
+    )  # TODO: path to your datasets folder
     if not os.path.exists(DATASET_ROOT):
         os.makedirs(DATASET_ROOT)
-    
+
     # Synthetic shape dataset
     synthetic_dataroot = os.path.join(DATASET_ROOT, "synthetic_shapes")
     synthetic_cache_path = os.path.join(DATASET_ROOT, "synthetic_shapes")
     if not os.path.exists(synthetic_dataroot):
         os.makedirs(synthetic_dataroot)
-    
+
     # Exported predictions dataset
     export_dataroot = os.path.join(DATASET_ROOT, "export_datasets")
     export_cache_path = os.path.join(DATASET_ROOT, "export_datasets")
     if not os.path.exists(export_dataroot):
         os.makedirs(export_dataroot)
-    
+
     # Wireframe dataset
     wireframe_dataroot = os.path.join(DATASET_ROOT, "wireframe")
     wireframe_cache_path = os.path.join(DATASET_ROOT, "wireframe")
@@ -32,10 +35,12 @@ class Config(object):
     # Holicity dataset
     holicity_dataroot = os.path.join(DATASET_ROOT, "Holicity")
     holicity_cache_path = os.path.join(DATASET_ROOT, "Holicity")
-    
+
     ########################
     ## Experiment Setting ##
     ########################
-    EXP_PATH = os.getenv("EXP_PATH", "./experiments/")  # TODO: path to your experiments folder
+    EXP_PATH = os.getenv(
+        "EXP_PATH", "./experiments/"
+    )  # TODO: path to your experiments folder
     if not os.path.exists(EXP_PATH):
         os.makedirs(EXP_PATH)
diff --git a/third_party/SOLD2/sold2/dataset/dataset_util.py b/third_party/SOLD2/sold2/dataset/dataset_util.py
index 50439ef3e2958d82719da0f6d10f4a7d98322f9a..67271bc915e6975cad005e9001d2bb430a8baa14 100644
--- a/third_party/SOLD2/sold2/dataset/dataset_util.py
+++ b/third_party/SOLD2/sold2/dataset/dataset_util.py
@@ -8,53 +8,50 @@ from .merge_dataset import MergeDataset
 
 
 def get_dataset(mode="train", dataset_cfg=None):
-    """ Initialize different dataset based on a configuration. """
+    """Initialize different dataset based on a configuration."""
     # Check dataset config is given
     if dataset_cfg is None:
         raise ValueError("[Error] The dataset config is required!")
 
     # Synthetic dataset
     if dataset_cfg["dataset_name"] == "synthetic_shape":
-        dataset = SyntheticShapes(
-            mode, dataset_cfg
-        )
+        dataset = SyntheticShapes(mode, dataset_cfg)
 
         # Get the collate_fn
         from .synthetic_dataset import synthetic_collate_fn
+
         collate_fn = synthetic_collate_fn
 
     # Wireframe dataset
     elif dataset_cfg["dataset_name"] == "wireframe":
-        dataset = WireframeDataset(
-            mode, dataset_cfg
-        )
+        dataset = WireframeDataset(mode, dataset_cfg)
 
         # Get the collate_fn
         from .wireframe_dataset import wireframe_collate_fn
+
         collate_fn = wireframe_collate_fn
-    
+
     # Holicity dataset
     elif dataset_cfg["dataset_name"] == "holicity":
-        dataset = HolicityDataset(
-            mode, dataset_cfg
-        )
+        dataset = HolicityDataset(mode, dataset_cfg)
 
         # Get the collate_fn
         from .holicity_dataset import holicity_collate_fn
+
         collate_fn = holicity_collate_fn
-    
+
     # Dataset merging several datasets in one
     elif dataset_cfg["dataset_name"] == "merge":
-        dataset = MergeDataset(
-            mode, dataset_cfg
-        )
+        dataset = MergeDataset(mode, dataset_cfg)
 
         # Get the collate_fn
         from .holicity_dataset import holicity_collate_fn
+
         collate_fn = holicity_collate_fn
 
     else:
         raise ValueError(
-    "[Error] The dataset '%s' is not supported" % dataset_cfg["dataset_name"])
+            "[Error] The dataset '%s' is not supported" % dataset_cfg["dataset_name"]
+        )
 
     return dataset, collate_fn
diff --git a/third_party/SOLD2/sold2/dataset/holicity_dataset.py b/third_party/SOLD2/sold2/dataset/holicity_dataset.py
index e4437f37bda366983052de902a41467ca01412bd..af182c5ef46d68d595da4c3dda76c1f631d56fcc 100644
--- a/third_party/SOLD2/sold2/dataset/holicity_dataset.py
+++ b/third_party/SOLD2/sold2/dataset/holicity_dataset.py
@@ -26,12 +26,19 @@ from ..misc.train_utils import parse_h5_data
 
 
 def holicity_collate_fn(batch):
-    """ Customized collate_fn. """
-    batch_keys = ["image", "junction_map", "valid_mask", "heatmap",
-                  "heatmap_pos", "heatmap_neg", "homography",
-                  "line_points", "line_indices"]
-    list_keys = ["junctions", "line_map", "line_map_pos",
-                 "line_map_neg", "file_key"]
+    """Customized collate_fn."""
+    batch_keys = [
+        "image",
+        "junction_map",
+        "valid_mask",
+        "heatmap",
+        "heatmap_pos",
+        "heatmap_neg",
+        "homography",
+        "line_points",
+        "line_indices",
+    ]
+    list_keys = ["junctions", "line_map", "line_map_pos", "line_map_neg", "file_key"]
 
     outputs = {}
     for data_key in batch[0].keys():
@@ -40,14 +47,16 @@ def holicity_collate_fn(batch):
         # print(batch_match, list_match)
         if batch_match > 0 and list_match == 0:
             outputs[data_key] = torch_loader.default_collate(
-                [b[data_key] for b in batch])
+                [b[data_key] for b in batch]
+            )
         elif batch_match == 0 and list_match > 0:
             outputs[data_key] = [b[data_key] for b in batch]
         elif batch_match == 0 and list_match == 0:
             continue
         else:
             raise ValueError(
-        "[Error] A key matches batch keys and list keys simultaneously.")
+                "[Error] A key matches batch keys and list keys simultaneously."
+            )
 
     return outputs
 
@@ -57,7 +66,8 @@ class HolicityDataset(Dataset):
         super(HolicityDataset, self).__init__()
         if not mode in ["train", "test"]:
             raise ValueError(
-        "[Error] Unknown mode for Holicity dataset. Only 'train' and 'test'.")
+                "[Error] Unknown mode for Holicity dataset. Only 'train' and 'test'."
+            )
         self.mode = mode
 
         if config is None:
@@ -71,17 +81,18 @@ class HolicityDataset(Dataset):
         self.dataset_name = self.get_dataset_name()
         self.cache_name = self.get_cache_name()
         self.cache_path = cfg.holicity_cache_path
-        
+
         # Get the ground truth source if it exists
         self.gt_source = None
-        if "gt_source_%s"%(self.mode) in self.config:
-            self.gt_source = self.config.get("gt_source_%s"%(self.mode))
+        if "gt_source_%s" % (self.mode) in self.config:
+            self.gt_source = self.config.get("gt_source_%s" % (self.mode))
             self.gt_source = os.path.join(cfg.export_dataroot, self.gt_source)
             # Check the full path exists
             if not os.path.exists(self.gt_source):
                 raise ValueError(
-            "[Error] The specified ground truth source does not exist.")
-        
+                    "[Error] The specified ground truth source does not exist."
+                )
+
         # Get the filename dataset
         print("[Info] Initializing Holicity dataset...")
         self.filename_dataset, self.datapoints = self.construct_dataset()
@@ -92,22 +103,22 @@ class HolicityDataset(Dataset):
         # Print some info
         print("[Info] Successfully initialized dataset")
         print("\t Name: Holicity")
-        print("\t Mode: %s" %(self.mode))
-        print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode),
-                                            "None")))
-        print("\t Counts: %d" %(self.dataset_length))
+        print("\t Mode: %s" % (self.mode))
+        print("\t Gt: %s" % (self.config.get("gt_source_%s" % (self.mode), "None")))
+        print("\t Counts: %d" % (self.dataset_length))
         print("----------------------------------------")
 
     #######################################
     ## Dataset construction related APIs ##
     #######################################
     def construct_dataset(self):
-        """ Construct the dataset (from scratch or from cache). """
+        """Construct the dataset (from scratch or from cache)."""
         # Check if the filename cache exists
         # If cache exists, load from cache
         if self.check_dataset_cache():
-            print("\t Found filename cache %s at %s"%(self.cache_name,
-                                                      self.cache_path))
+            print(
+                "\t Found filename cache %s at %s" % (self.cache_name, self.cache_path)
+            )
             print("\t Load filename cache...")
             filename_dataset, datapoints = self.get_filename_dataset_from_cache()
         # If not, initialize dataset from scratch
@@ -117,56 +128,56 @@ class HolicityDataset(Dataset):
             filename_dataset, datapoints = self.get_filename_dataset()
             print("\t Create filename dataset cache...")
             self.create_filename_dataset_cache(filename_dataset, datapoints)
-        
+
         return filename_dataset, datapoints
-    
+
     def create_filename_dataset_cache(self, filename_dataset, datapoints):
-        """ Create filename dataset cache for faster initialization. """
+        """Create filename dataset cache for faster initialization."""
         # Check cache path exists
         if not os.path.exists(self.cache_path):
             os.makedirs(self.cache_path)
 
         cache_file_path = os.path.join(self.cache_path, self.cache_name)
-        data = {
-            "filename_dataset": filename_dataset,
-            "datapoints": datapoints
-        }
+        data = {"filename_dataset": filename_dataset, "datapoints": datapoints}
         with open(cache_file_path, "wb") as f:
             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
-    
+
     def get_filename_dataset_from_cache(self):
-        """ Get filename dataset from cache. """
+        """Get filename dataset from cache."""
         # Load from pkl cache
         cache_file_path = os.path.join(self.cache_path, self.cache_name)
         with open(cache_file_path, "rb") as f:
             data = pickle.load(f)
-        
+
         return data["filename_dataset"], data["datapoints"]
 
     def get_filename_dataset(self):
-        """ Get the path to the dataset. """
+        """Get the path to the dataset."""
         if self.mode == "train":
             # Contains 5720 or 11872 images
-            dataset_path = [os.path.join(cfg.holicity_dataroot, p)
-                            for p in self.config["train_splits"]]
+            dataset_path = [
+                os.path.join(cfg.holicity_dataroot, p)
+                for p in self.config["train_splits"]
+            ]
         else:
             # Test mode - Contains 520 images
             dataset_path = [os.path.join(cfg.holicity_dataroot, "2018-03")]
-        
+
         # Get paths to all image files
         image_paths = []
         for folder in dataset_path:
-            image_paths += [os.path.join(folder, img)
-                            for img in os.listdir(folder)
-                            if os.path.splitext(img)[-1] == ".jpg"]
+            image_paths += [
+                os.path.join(folder, img)
+                for img in os.listdir(folder)
+                if os.path.splitext(img)[-1] == ".jpg"
+            ]
         image_paths = sorted(image_paths)
 
         # Verify all the images exist
         for idx in range(len(image_paths)):
             image_path = image_paths[idx]
             if not (os.path.exists(image_path)):
-                raise ValueError(
-            "[Error] The image does not exist. %s"%(image_path))
+                raise ValueError("[Error] The image does not exist. %s" % (image_path))
 
         # Construct the filename dataset
         num_pad = int(math.ceil(math.log10(len(image_paths))) + 1)
@@ -176,82 +187,77 @@ class HolicityDataset(Dataset):
             key = self.get_padded_filename(num_pad, idx)
 
             filename_dataset[key] = {"image": image_paths[idx]}
-        
+
         # Get the datapoints
         datapoints = list(sorted(filename_dataset.keys()))
 
         return filename_dataset, datapoints
-    
+
     def get_dataset_name(self):
-        """ Get dataset name from dataset config / default config. """
-        dataset_name = self.config.get("dataset_name",
-                                       self.default_config["dataset_name"])
+        """Get dataset name from dataset config / default config."""
+        dataset_name = self.config.get(
+            "dataset_name", self.default_config["dataset_name"]
+        )
         dataset_name = dataset_name + "_%s" % self.mode
         return dataset_name
-    
+
     def get_cache_name(self):
-        """ Get cache name from dataset config / default config. """
-        dataset_name = self.config.get("dataset_name",
-                                       self.default_config["dataset_name"])
+        """Get cache name from dataset config / default config."""
+        dataset_name = self.config.get(
+            "dataset_name", self.default_config["dataset_name"]
+        )
         dataset_name = dataset_name + "_%s" % self.mode
         # Compose cache name
         cache_name = dataset_name + "_cache.pkl"
         return cache_name
 
     def check_dataset_cache(self):
-        """ Check if dataset cache exists. """
+        """Check if dataset cache exists."""
         cache_file_path = os.path.join(self.cache_path, self.cache_name)
         if os.path.exists(cache_file_path):
             return True
         else:
             return False
-    
+
     @staticmethod
     def get_padded_filename(num_pad, idx):
-        """ Get the padded filename using adaptive padding. """
+        """Get the padded filename using adaptive padding."""
         file_len = len("%d" % (idx))
         filename = "0" * (num_pad - file_len) + "%d" % (idx)
         return filename
 
     def get_default_config(self):
-        """ Get the default configuration. """
+        """Get the default configuration."""
         return {
             "dataset_name": "holicity",
             "train_split": "2018-01",
             "add_augmentation_to_all_splits": False,
-            "preprocessing": {
-                "resize": [512, 512],
-                "blur_size": 11
-            },
-            "augmentation":{
-                "photometric":{
-                    "enable": False
-                },
-                "homographic":{
-                    "enable": False
-                },
+            "preprocessing": {"resize": [512, 512], "blur_size": 11},
+            "augmentation": {
+                "photometric": {"enable": False},
+                "homographic": {"enable": False},
             },
         }
-        
+
     ############################################
     ## Pytorch and preprocessing related APIs ##
     ############################################
     @staticmethod
     def get_data_from_path(data_path):
-        """ Get data from the information from filename dataset. """
+        """Get data from the information from filename dataset."""
         output = {}
 
         # Get image data
         image_path = data_path["image"]
         image = imread(image_path)
         output["image"] = image
-        
+
         return output
-    
+
     @staticmethod
     def convert_line_map(lcnn_line_map, num_junctions):
-        """ Convert the line_pos or line_neg
-            (represented by two junction indexes) to our line map. """
+        """Convert the line_pos or line_neg
+        (represented by two junction indexes) to our line map."""
         # Initialize empty line map
         line_map = np.zeros([num_junctions, num_junctions])
 
@@ -262,59 +268,60 @@ class HolicityDataset(Dataset):
 
             line_map[index1, index2] = 1
             line_map[index2, index1] = 1
-        
+
         return line_map
 
     @staticmethod
     def junc_to_junc_map(junctions, image_size):
-        """ Convert junction points to junction maps. """
+        """Convert junction points to junction maps."""
         junctions = np.round(junctions).astype(np.int)
         # Clip the boundary by image size
-        junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1)
-        junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1)
+        junctions[:, 0] = np.clip(junctions[:, 0], 0.0, image_size[0] - 1)
+        junctions[:, 1] = np.clip(junctions[:, 1], 0.0, image_size[1] - 1)
 
         # Create junction map
         junc_map = np.zeros([image_size[0], image_size[1]])
         junc_map[junctions[:, 0], junctions[:, 1]] = 1
 
         return junc_map[..., None].astype(np.int)
-    
+
     def parse_transforms(self, names, all_transforms):
-        """ Parse the transform. """
-        trans = all_transforms if (names == 'all') \
+        """Parse the transform."""
+        trans = (
+            all_transforms
+            if (names == "all")
             else (names if isinstance(names, list) else [names])
+        )
         assert set(trans) <= set(all_transforms)
         return trans
 
     def get_photo_transform(self):
-        """ Get list of photometric transforms (according to the config). """
+        """Get list of photometric transforms (according to the config)."""
         # Get the photometric transform config
         photo_config = self.config["augmentation"]["photometric"]
         if not photo_config["enable"]:
-            raise ValueError(
-        "[Error] Photometric augmentation is not enabled.")
-        
+            raise ValueError("[Error] Photometric augmentation is not enabled.")
+
         # Parse photometric transforms
-        trans_lst = self.parse_transforms(photo_config["primitives"],
-                                          photoaug.available_augmentations)
-        trans_config_lst = [photo_config["params"].get(p, {})
-                            for p in trans_lst]
+        trans_lst = self.parse_transforms(
+            photo_config["primitives"], photoaug.available_augmentations
+        )
+        trans_config_lst = [photo_config["params"].get(p, {}) for p in trans_lst]
 
         # List of photometric augmentation
         photometric_trans_lst = [
-            getattr(photoaug, trans)(**conf) \
+            getattr(photoaug, trans)(**conf)
             for (trans, conf) in zip(trans_lst, trans_config_lst)
         ]
 
         return photometric_trans_lst
 
     def get_homo_transform(self):
-        """ Get homographic transforms (according to the config). """
+        """Get homographic transforms (according to the config)."""
         # Get homographic transforms for image
         homo_config = self.config["augmentation"]["homographic"]["params"]
         if not self.config["augmentation"]["homographic"]["enable"]:
-            raise ValueError(
-        "[Error] Homographic augmentation is not enabled")
+            raise ValueError("[Error] Homographic augmentation is not enabled")
 
         # Parse the homographic transforms
         image_shape = self.config["preprocessing"]["resize"]
@@ -324,30 +331,33 @@ class HolicityDataset(Dataset):
             min_label_tmp = self.config["generation"]["min_label_len"]
         except:
             min_label_tmp = None
-        
+
         # float label len => fraction
-        if isinstance(min_label_tmp, float): # Skip if not provided
+        if isinstance(min_label_tmp, float):  # Skip if not provided
             min_label_len = min_label_tmp * min(image_shape)
         # int label len => length in pixel
         elif isinstance(min_label_tmp, int):
-            scale_ratio = (self.config["preprocessing"]["resize"]
-                           / self.config["generation"]["image_size"][0])
-            min_label_len = (self.config["generation"]["min_label_len"]
-                             * scale_ratio)
+            scale_ratio = (
+                self.config["preprocessing"]["resize"]
+                / self.config["generation"]["image_size"][0]
+            )
+            min_label_len = self.config["generation"]["min_label_len"] * scale_ratio
         # if none => no restriction
         else:
             min_label_len = 0
-        
+
         # Initialize the transform
         homographic_trans = homoaug.homography_transform(
-            image_shape, homo_config, 0, min_label_len)
+            image_shape, homo_config, 0, min_label_len
+        )
 
         return homographic_trans
 
-    def get_line_points(self, junctions, line_map, H1=None, H2=None,
-                        img_size=None, warp=False):
-        """ Sample evenly points along each line segments
-            and keep track of line idx. """
+    def get_line_points(
+        self, junctions, line_map, H1=None, H2=None, img_size=None, warp=False
+    ):
+        """Sample evenly points along each line segments
+        and keep track of line idx."""
         if np.sum(line_map) == 0:
             # No segment detected in the image
             line_indices = np.zeros(self.config["max_pts"], dtype=int)
@@ -356,35 +366,38 @@ class HolicityDataset(Dataset):
 
         # Extract all pairs of connected junctions
         junc_indices = np.array(
-            [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i])
-        line_segments = np.stack([junctions[junc_indices[:, 0]],
-                                  junctions[junc_indices[:, 1]]], axis=1)
+            [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i]
+        )
+        line_segments = np.stack(
+            [junctions[junc_indices[:, 0]], junctions[junc_indices[:, 1]]], axis=1
+        )
         # line_segments is (num_lines, 2, 2)
-        line_lengths = np.linalg.norm(
-            line_segments[:, 0] - line_segments[:, 1], axis=1)
+        line_lengths = np.linalg.norm(line_segments[:, 0] - line_segments[:, 1], axis=1)
 
         # Sample the points separated by at least min_dist_pts along each line
         # The number of samples depends on the length of the line
-        num_samples = np.minimum(line_lengths // self.config["min_dist_pts"],
-                                 self.config["max_num_samples"])
+        num_samples = np.minimum(
+            line_lengths // self.config["min_dist_pts"], self.config["max_num_samples"]
+        )
         line_points = []
         line_indices = []
         cur_line_idx = 1
         for n in np.arange(2, self.config["max_num_samples"] + 1):
             # Consider all lines where we can fit up to n points
             cur_line_seg = line_segments[num_samples == n]
-            line_points_x = np.linspace(cur_line_seg[:, 0, 0],
-                                        cur_line_seg[:, 1, 0],
-                                        n, axis=-1).flatten()
-            line_points_y = np.linspace(cur_line_seg[:, 0, 1],
-                                        cur_line_seg[:, 1, 1],
-                                        n, axis=-1).flatten()
+            line_points_x = np.linspace(
+                cur_line_seg[:, 0, 0], cur_line_seg[:, 1, 0], n, axis=-1
+            ).flatten()
+            line_points_y = np.linspace(
+                cur_line_seg[:, 0, 1], cur_line_seg[:, 1, 1], n, axis=-1
+            ).flatten()
             jitter = self.config.get("jittering", 0)
             if jitter:
                 # Add a small random jittering of all points along the line
                 angles = np.arctan2(
                     cur_line_seg[:, 1, 0] - cur_line_seg[:, 0, 0],
-                    cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1]).repeat(n)
+                    cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1],
+                ).repeat(n)
                 jitter_hyp = (np.random.rand(len(angles)) * 2 - 1) * jitter
                 line_points_x += jitter_hyp * np.sin(angles)
                 line_points_y += jitter_hyp * np.cos(angles)
@@ -394,10 +407,8 @@ class HolicityDataset(Dataset):
             line_idx = np.arange(cur_line_idx, cur_line_idx + num_cur_lines)
             line_indices.append(line_idx.repeat(n))
             cur_line_idx += num_cur_lines
-        line_points = np.concatenate(line_points,
-                                     axis=0)[:self.config["max_pts"]]
-        line_indices = np.concatenate(line_indices,
-                                      axis=0)[:self.config["max_pts"]]
+        line_points = np.concatenate(line_points, axis=0)[: self.config["max_pts"]]
+        line_indices = np.concatenate(line_indices, axis=0)[: self.config["max_pts"]]
 
         # Warp the points if need be, and filter unvalid ones
         # If the other view is also warped
@@ -419,37 +430,43 @@ class HolicityDataset(Dataset):
             mask = mask_points(warped_points, img_size)
         line_points = line_points[mask]
         line_indices = line_indices[mask]
-        
+
         # Pad the line points to a fixed length
         # Index of 0 means padded line
-        line_indices = np.concatenate([line_indices, np.zeros(
-            self.config["max_pts"] - len(line_indices))], axis=0)
+        line_indices = np.concatenate(
+            [line_indices, np.zeros(self.config["max_pts"] - len(line_indices))], axis=0
+        )
         line_points = np.concatenate(
-            [line_points,
-             np.zeros((self.config["max_pts"] - len(line_points), 2),
-                      dtype=float)], axis=0)
-        
+            [
+                line_points,
+                np.zeros((self.config["max_pts"] - len(line_points), 2), dtype=float),
+            ],
+            axis=0,
+        )
+
         return line_points, line_indices
 
     def export_preprocessing(self, data, numpy=False):
-        """ Preprocess the exported data. """
+        """Preprocess the exported data."""
         # Fetch the corresponding entries
         image = data["image"]
         image_size = image.shape[:2]
 
         # Resize the image before photometric and homographical augmentations
-        if not(list(image_size) == self.config["preprocessing"]["resize"]):
+        if not (list(image_size) == self.config["preprocessing"]["resize"]):
             # Resize the image and the point location.
-            size_old = list(image.shape)[:2] # Only H and W dimensions
+            size_old = list(image.shape)[:2]  # Only H and W dimensions
 
             image = cv2.resize(
-                image, tuple(self.config['preprocessing']['resize'][::-1]),
-                interpolation=cv2.INTER_LINEAR)
+                image,
+                tuple(self.config["preprocessing"]["resize"][::-1]),
+                interpolation=cv2.INTER_LINEAR,
+            )
             image = np.array(image, dtype=np.uint8)
-        
+
         # Optionally convert the image to grayscale
         if self.config["gray_scale"]:
-            image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+            image = (color.rgb2gray(image) * 255.0).astype(np.uint8)
 
         image = photoaug.normalize_image()(image)
 
@@ -459,11 +476,21 @@ class HolicityDataset(Dataset):
             return {"image": to_tensor(image)}
         else:
             return {"image": image}
-    
+
     def train_preprocessing_exported(
-        self, data, numpy=False, disable_homoaug=False, desc_training=False,
-        H1=None, H1_scale=None, H2=None, scale=1., h_crop=None, w_crop=None):
-        """ Train preprocessing for the exported labels. """
+        self,
+        data,
+        numpy=False,
+        disable_homoaug=False,
+        desc_training=False,
+        H1=None,
+        H1_scale=None,
+        H2=None,
+        scale=1.0,
+        h_crop=None,
+        w_crop=None,
+    ):
+        """Train preprocessing for the exported labels."""
         data = copy.deepcopy(data)
         # Fetch the corresponding entries
         image = data["image"]
@@ -483,13 +510,15 @@ class HolicityDataset(Dataset):
                     w_crop = np.random.randint(W_scale - W)
 
         # Resize the image before photometric and homographical augmentations
-        if not(list(image_size) == self.config["preprocessing"]["resize"]):
+        if not (list(image_size) == self.config["preprocessing"]["resize"]):
             # Resize the image and the point location.
-            size_old = list(image.shape)[:2] # Only H and W dimensions
+            size_old = list(image.shape)[:2]  # Only H and W dimensions
 
             image = cv2.resize(
-                image, tuple(self.config['preprocessing']['resize'][::-1]),
-                interpolation=cv2.INTER_LINEAR)
+                image,
+                tuple(self.config["preprocessing"]["resize"][::-1]),
+                interpolation=cv2.INTER_LINEAR,
+            )
             image = np.array(image, dtype=np.uint8)
 
             # # In HW format
@@ -504,7 +533,7 @@ class HolicityDataset(Dataset):
 
         # Optionally convert the image to grayscale
         if self.config["gray_scale"]:
-            image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+            image = (color.rgb2gray(image) * 255.0).astype(np.uint8)
 
         # Check if we need to apply augmentations
         # In training mode => yes.
@@ -514,16 +543,17 @@ class HolicityDataset(Dataset):
             ### Image transform ###
             np.random.shuffle(photo_trans_lst)
             image_transform = transforms.Compose(
-                photo_trans_lst + [photoaug.normalize_image()])
+                photo_trans_lst + [photoaug.normalize_image()]
+            )
         else:
             image_transform = photoaug.normalize_image()
         image = image_transform(image)
 
         # Perform the random scaling
-        if scale != 1.:
+        if scale != 1.0:
             image, junctions, line_map, valid_mask = random_scaling(
-                 image, junctions, line_map, scale,
-                 h_crop=h_crop, w_crop=w_crop)
+                image, junctions, line_map, scale, h_crop=h_crop, w_crop=w_crop
+            )
         else:
             # Declare default valid mask (all ones)
             valid_mask = np.ones(image_size)
@@ -534,20 +564,28 @@ class HolicityDataset(Dataset):
         to_tensor = transforms.ToTensor()
 
         # Check homographic augmentation
-        warp = (self.config["augmentation"]["homographic"]["enable"]
-                and disable_homoaug == False)
+        warp = (
+            self.config["augmentation"]["homographic"]["enable"]
+            and disable_homoaug == False
+        )
         if warp:
             homo_trans = self.get_homo_transform()
             # Perform homographic transform
             if H1 is None:
-                homo_outputs = homo_trans(image, junctions, line_map,
-                                            valid_mask=valid_mask)
+                homo_outputs = homo_trans(
+                    image, junctions, line_map, valid_mask=valid_mask
+                )
             else:
                 homo_outputs = homo_trans(
-                    image, junctions, line_map, homo=H1, scale=H1_scale,
-                    valid_mask=valid_mask)
+                    image,
+                    junctions,
+                    line_map,
+                    homo=H1,
+                    scale=H1_scale,
+                    valid_mask=valid_mask,
+                )
             homography_mat = homo_outputs["homo"]
-            
+
             # Give the warp of the other view
             if H1 is None:
                 H1 = homo_outputs["homo"]
@@ -555,8 +593,8 @@ class HolicityDataset(Dataset):
         # Sample points along each line segments for the descriptor
         if desc_training:
             line_points, line_indices = self.get_line_points(
-                junctions, line_map, H1=H1, H2=H2,
-                img_size=image_size, warp=warp)
+                junctions, line_map, H1=H1, H2=H2, img_size=image_size, warp=warp
+            )
 
         # Record the warped results
         if warp:
@@ -565,52 +603,59 @@ class HolicityDataset(Dataset):
             line_map = homo_outputs["line_map"]
             valid_mask = homo_outputs["valid_mask"]  # Same for pos and neg
             heatmap = homo_outputs["warped_heatmap"]
-            
+
             # Optionally put warping information first.
             if not numpy:
-                outputs["homography_mat"] = to_tensor(
-                    homography_mat).to(torch.float32)[0, ...]
+                outputs["homography_mat"] = to_tensor(homography_mat).to(torch.float32)[
+                    0, ...
+                ]
             else:
                 outputs["homography_mat"] = homography_mat.astype(np.float32)
 
         junction_map = self.junc_to_junc_map(junctions, image_size)
-        
+
         if not numpy:
-            outputs.update({
-                "image": to_tensor(image),
-                "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
-                "junction_map": to_tensor(junction_map).to(torch.int),
-                "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
-                "heatmap": to_tensor(heatmap).to(torch.int32),
-                "valid_mask": to_tensor(valid_mask).to(torch.int32)
-            })
+            outputs.update(
+                {
+                    "image": to_tensor(image),
+                    "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+                    "junction_map": to_tensor(junction_map).to(torch.int),
+                    "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
+                    "heatmap": to_tensor(heatmap).to(torch.int32),
+                    "valid_mask": to_tensor(valid_mask).to(torch.int32),
+                }
+            )
             if desc_training:
-                outputs.update({
-                    "line_points": to_tensor(
-                        line_points).to(torch.float32)[0],
-                    "line_indices": torch.tensor(line_indices,
-                                                 dtype=torch.int)
-                })
+                outputs.update(
+                    {
+                        "line_points": to_tensor(line_points).to(torch.float32)[0],
+                        "line_indices": torch.tensor(line_indices, dtype=torch.int),
+                    }
+                )
         else:
-            outputs.update({
-                "image": image,
-                "junctions": junctions.astype(np.float32),
-                "junction_map": junction_map.astype(np.int32),
-                "line_map": line_map.astype(np.int32),
-                "heatmap": heatmap.astype(np.int32),
-                "valid_mask": valid_mask.astype(np.int32)
-            })
+            outputs.update(
+                {
+                    "image": image,
+                    "junctions": junctions.astype(np.float32),
+                    "junction_map": junction_map.astype(np.int32),
+                    "line_map": line_map.astype(np.int32),
+                    "heatmap": heatmap.astype(np.int32),
+                    "valid_mask": valid_mask.astype(np.int32),
+                }
+            )
             if desc_training:
-                outputs.update({
-                    "line_points": line_points.astype(np.float32),
-                    "line_indices": line_indices.astype(int)
-                })
-        
+                outputs.update(
+                    {
+                        "line_points": line_points.astype(np.float32),
+                        "line_indices": line_indices.astype(int),
+                    }
+                )
+
         return outputs
-    
-    def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.):
-        """ Train preprocessing for paired data for the exported labels
-            for descriptor training. """
+
+    def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.0):
+        """Train preprocessing for paired data for the exported labels
+        for descriptor training."""
         outputs = {}
 
         # Define the random crop for scaling if necessary
@@ -622,51 +667,66 @@ class HolicityDataset(Dataset):
                 h_crop = np.random.randint(H_scale - H)
             if W_scale > W:
                 w_crop = np.random.randint(W_scale - W)
-        
+
         # Sample ref homography first
         homo_config = self.config["augmentation"]["homographic"]["params"]
         image_shape = self.config["preprocessing"]["resize"]
-        ref_H, ref_scale = homoaug.sample_homography(image_shape,
-                                                     **homo_config)
+        ref_H, ref_scale = homoaug.sample_homography(image_shape, **homo_config)
 
         # Data for target view (All augmentation)
         target_data = self.train_preprocessing_exported(
-            data, numpy=numpy, desc_training=True, H1=None, H2=ref_H,
-            scale=scale, h_crop=h_crop, w_crop=w_crop)
+            data,
+            numpy=numpy,
+            desc_training=True,
+            H1=None,
+            H2=ref_H,
+            scale=scale,
+            h_crop=h_crop,
+            w_crop=w_crop,
+        )
 
         # Data for reference view (No homographical augmentation)
         ref_data = self.train_preprocessing_exported(
-            data, numpy=numpy, desc_training=True, H1=ref_H,
-            H1_scale=ref_scale, H2=target_data['homography_mat'].numpy(),
-            scale=scale, h_crop=h_crop, w_crop=w_crop)
+            data,
+            numpy=numpy,
+            desc_training=True,
+            H1=ref_H,
+            H1_scale=ref_scale,
+            H2=target_data["homography_mat"].numpy(),
+            scale=scale,
+            h_crop=h_crop,
+            w_crop=w_crop,
+        )
 
         # Spread ref data
         for key, val in ref_data.items():
             outputs["ref_" + key] = val
-        
+
         # Spread target data
         for key, val in target_data.items():
             outputs["target_" + key] = val
-        
+
         return outputs
 
     def test_preprocessing_exported(self, data, numpy=False):
-        """ Test preprocessing for the exported labels. """
+        """Test preprocessing for the exported labels."""
         data = copy.deepcopy(data)
         # Fetch the corresponding entries
         image = data["image"]
         junctions = data["junctions"]
-        line_map = data["line_map"]      
+        line_map = data["line_map"]
         image_size = image.shape[:2]
 
         # Resize the image before photometric and homographical augmentations
-        if not(list(image_size) == self.config["preprocessing"]["resize"]):
+        if not (list(image_size) == self.config["preprocessing"]["resize"]):
             # Resize the image and the point location.
-            size_old = list(image.shape)[:2] # Only H and W dimensions
+            size_old = list(image.shape)[:2]  # Only H and W dimensions
 
             image = cv2.resize(
-                image, tuple(self.config['preprocessing']['resize'][::-1]),
-                interpolation=cv2.INTER_LINEAR)
+                image,
+                tuple(self.config["preprocessing"]["resize"][::-1]),
+                interpolation=cv2.INTER_LINEAR,
+            )
             image = np.array(image, dtype=np.uint8)
 
             # # In HW format
@@ -676,7 +736,7 @@ class HolicityDataset(Dataset):
 
         # Optionally convert the image to grayscale
         if self.config["gray_scale"]:
-            image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+            image = (color.rgb2gray(image) * 255.0).astype(np.uint8)
 
         # Still need to normalize image
         image_transform = photoaug.normalize_image()
@@ -686,7 +746,7 @@ class HolicityDataset(Dataset):
         junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
         image_size = image.shape[:2]
         heatmap = get_line_heatmap(junctions_xy, line_map, image_size)
-        
+
         # Declare default valid mask (all ones)
         valid_mask = np.ones(image_size)
 
@@ -701,7 +761,7 @@ class HolicityDataset(Dataset):
                 "junction_map": to_tensor(junction_map).to(torch.int),
                 "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
                 "heatmap": to_tensor(heatmap).to(torch.int32),
-                "valid_mask": to_tensor(valid_mask).to(torch.int32)
+                "valid_mask": to_tensor(valid_mask).to(torch.int32),
             }
         else:
             outputs = {
@@ -710,38 +770,36 @@ class HolicityDataset(Dataset):
                 "junction_map": junction_map.astype(np.int32),
                 "line_map": line_map.astype(np.int32),
                 "heatmap": heatmap.astype(np.int32),
-                "valid_mask": valid_mask.astype(np.int32)
+                "valid_mask": valid_mask.astype(np.int32),
             }
-        
+
         return outputs
 
     def __len__(self):
         return self.dataset_length
-    
+
     def get_data_from_key(self, file_key):
-        """ Get data from file_key. """
+        """Get data from file_key."""
         # Check key exists
         if not file_key in self.filename_dataset.keys():
-            raise ValueError(
-        "[Error] the specified key is not in the dataset.")
-        
+            raise ValueError("[Error] the specified key is not in the dataset.")
+
         # Get the data paths
         data_path = self.filename_dataset[file_key]
         # Read in the image and npz labels
         data = self.get_data_from_path(data_path)
 
         # Perform transform and augmentation
-        if (self.mode == "train"
-            or self.config["add_augmentation_to_all_splits"]):
+        if self.mode == "train" or self.config["add_augmentation_to_all_splits"]:
             data = self.train_preprocessing(data, numpy=True)
         else:
             data = self.test_preprocessing(data, numpy=True)
-        
+
         # Add file key to the output
         data["file_key"] = file_key
-        
+
         return data
-    
+
     def __getitem__(self, idx):
         """Return data
         file_key: str, keys used to retrieve data from the filename dataset.
@@ -761,27 +819,25 @@ class HolicityDataset(Dataset):
         if self.gt_source:
             with h5py.File(self.gt_source, "r") as f:
                 exported_label = parse_h5_data(f[file_key])
-            
+
             data["junctions"] = exported_label["junctions"]
             data["line_map"] = exported_label["line_map"]
-        
+
         # Perform transform and augmentation
         return_type = self.config.get("return_type", "single")
         if self.gt_source is None:
             # For export only
             data = self.export_preprocessing(data)
-        elif (self.mode == "train"
-              or self.config["add_augmentation_to_all_splits"]):
+        elif self.mode == "train" or self.config["add_augmentation_to_all_splits"]:
             # Perform random scaling first
             if self.config["augmentation"]["random_scaling"]["enable"]:
                 scale_range = self.config["augmentation"]["random_scaling"]["range"]
                 # Decide the scaling
                 scale = np.random.uniform(min(scale_range), max(scale_range))
             else:
-                scale = 1.
+                scale = 1.0
             if self.mode == "train" and return_type == "paired_desc":
-                data = self.preprocessing_exported_paired_desc(data,
-                                                               scale=scale)
+                data = self.preprocessing_exported_paired_desc(data, scale=scale)
             else:
                 data = self.train_preprocessing_exported(data, scale=scale)
         else:
@@ -789,9 +845,8 @@ class HolicityDataset(Dataset):
                 data = self.preprocessing_exported_paired_desc(data)
             else:
                 data = self.test_preprocessing_exported(data)
-        
+
         # Add file key to the output
         data["file_key"] = file_key
-        
-        return data
 
+        return data
diff --git a/third_party/SOLD2/sold2/dataset/merge_dataset.py b/third_party/SOLD2/sold2/dataset/merge_dataset.py
index 178d3822d56639a49a99f68e392330e388fa8fc3..1f6395873dcfdea0c35898eefbf4c74a8cfac7a1 100644
--- a/third_party/SOLD2/sold2/dataset/merge_dataset.py
+++ b/third_party/SOLD2/sold2/dataset/merge_dataset.py
@@ -14,23 +14,24 @@ class MergeDataset(Dataset):
         # Initialize the datasets
         self._datasets = []
         spec_config = deepcopy(config)
-        for i, d in enumerate(config['datasets']):
-            spec_config['dataset_name'] = d
-            spec_config['gt_source_train'] = config['gt_source_train'][i]
-            spec_config['gt_source_test'] = config['gt_source_test'][i]
+        for i, d in enumerate(config["datasets"]):
+            spec_config["dataset_name"] = d
+            spec_config["gt_source_train"] = config["gt_source_train"][i]
+            spec_config["gt_source_test"] = config["gt_source_test"][i]
             if d == "wireframe":
                 self._datasets.append(WireframeDataset(mode, spec_config))
             elif d == "holicity":
-                spec_config['train_split'] = config['train_splits'][i]
+                spec_config["train_split"] = config["train_splits"][i]
                 self._datasets.append(HolicityDataset(mode, spec_config))
             else:
-                raise ValueError("Unknown dataset: " + d)            
+                raise ValueError("Unknown dataset: " + d)
+
+        self._weights = config["weights"]
 
-        self._weights = config['weights']
-    
     def __getitem__(self, item):
-        dataset = self._datasets[np.random.choice(
-            range(len(self._datasets)), p=self._weights)]
+        dataset = self._datasets[
+            np.random.choice(range(len(self._datasets)), p=self._weights)
+        ]
         return dataset[np.random.randint(len(dataset))]
 
     def __len__(self):
diff --git a/third_party/SOLD2/sold2/dataset/synthetic_dataset.py b/third_party/SOLD2/sold2/dataset/synthetic_dataset.py
index cf5f11e5407e65887f4995291156f7cc361843d1..4a1dab47bd81ec831554ba42a635a350ef7a73dc 100644
--- a/third_party/SOLD2/sold2/dataset/synthetic_dataset.py
+++ b/third_party/SOLD2/sold2/dataset/synthetic_dataset.py
@@ -25,9 +25,8 @@ from ..misc.train_utils import parse_h5_data
 
 
 def synthetic_collate_fn(batch):
-    """ Customized collate_fn. """
-    batch_keys = ["image", "junction_map", "heatmap",
-                  "valid_mask", "homography"]
+    """Customized collate_fn."""
+    batch_keys = ["image", "junction_map", "heatmap", "valid_mask", "homography"]
     list_keys = ["junctions", "line_map", "file_key"]
 
     outputs = {}
@@ -36,27 +35,31 @@ def synthetic_collate_fn(batch):
         list_match = sum([_ in data_key for _ in list_keys])
         # print(batch_match, list_match)
         if batch_match > 0 and list_match == 0:
-            outputs[data_key] = torch_loader.default_collate([b[data_key]
-                                                             for b in batch])
+            outputs[data_key] = torch_loader.default_collate(
+                [b[data_key] for b in batch]
+            )
         elif batch_match == 0 and list_match > 0:
             outputs[data_key] = [b[data_key] for b in batch]
         elif batch_match == 0 and list_match == 0:
             continue
         else:
             raise ValueError(
-        "[Error] A key matches batch keys and list keys simultaneously.")
+                "[Error] A key matches batch keys and list keys simultaneously."
+            )
 
     return outputs
 
 
 class SyntheticShapes(Dataset):
-    """ Dataset of synthetic shapes. """
+    """Dataset of synthetic shapes."""
+
     # Initialize the dataset
     def __init__(self, mode="train", config=None):
         super(SyntheticShapes, self).__init__()
         if not mode in ["train", "val", "test"]:
             raise ValueError(
-        "[Error] Supported dataset modes are 'train', 'val', and 'test'.")
+                "[Error] Supported dataset modes are 'train', 'val', and 'test'."
+            )
         self.mode = mode
 
         # Get configuration
@@ -67,14 +70,14 @@ class SyntheticShapes(Dataset):
 
         # Set all available primitives
         self.available_primitives = [
-            'draw_lines',
-            'draw_polygon',
-            'draw_multiple_polygons',
-            'draw_star',
-            'draw_checkerboard_multiseg',
-            'draw_stripes_multiseg',
-            'draw_cube',
-            'gaussian_noise'
+            "draw_lines",
+            "draw_polygon",
+            "draw_multiple_polygons",
+            "draw_star",
+            "draw_checkerboard_multiseg",
+            "draw_stripes_multiseg",
+            "draw_cube",
+            "gaussian_noise",
         ]
 
         # Some cache setting
@@ -88,11 +91,14 @@ class SyntheticShapes(Dataset):
         self.print_dataset_info()
 
         # Initialize h5 file handle
-        self.dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name + ".h5")
-        
+        self.dataset_path = os.path.join(
+            cfg.synthetic_dataroot, self.dataset_name + ".h5"
+        )
+
         # Fix the random seed for torch and numpy in testing mode
-        if ((self.mode == "val" or self.mode == "test")
-            and self.config["add_augmentation_to_all_splits"]):
+        if (self.mode == "val" or self.mode == "test") and self.config[
+            "add_augmentation_to_all_splits"
+        ]:
             seed = self.config.get("test_augmentation_seed", 200)
             np.random.seed(seed)
             torch.manual_seed(seed)
@@ -104,7 +110,7 @@ class SyntheticShapes(Dataset):
     ## Dataset construction related methods ##
     ##########################################
     def construct_dataset(self):
-        """ Dataset constructor. """
+        """Dataset constructor."""
         # Check if the filename cache exists
         # If cache exists, load from cache
         if self._check_dataset_cache():
@@ -117,13 +123,14 @@ class SyntheticShapes(Dataset):
                 print("\t All files exist!")
             # If not, need to re-export the synthetic dataset
             else:
-                print("\t Some files are missing. Re-export the synthetic shape dataset.")
+                print(
+                    "\t Some files are missing. Re-export the synthetic shape dataset."
+                )
                 self.export_synthetic_shapes()
                 print("\t Initialize filename dataset")
                 filename_dataset, datapoints = self.get_filename_dataset()
                 print("\t Create filename dataset cache...")
-                self.create_filename_dataset_cache(filename_dataset,
-                                                   datapoints)
+                self.create_filename_dataset_cache(filename_dataset, datapoints)
 
         # If not, initialize dataset from scratch
         else:
@@ -135,7 +142,9 @@ class SyntheticShapes(Dataset):
 
             # If export dataset does not exist, export from scratch
             else:
-                print("\t Synthetic dataset does not exist. Export the synthetic dataset.")
+                print(
+                    "\t Synthetic dataset does not exist. Export the synthetic dataset."
+                )
                 self.export_synthetic_shapes()
                 print("\t Initialize filename dataset")
 
@@ -146,7 +155,7 @@ class SyntheticShapes(Dataset):
         return filename_dataset, datapoints
 
     def get_cache_name(self):
-        """ Get cache name from dataset config / default config. """
+        """Get cache name from dataset config / default config."""
         if self.config["dataset_name"] is None:
             dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
         else:
@@ -157,7 +166,7 @@ class SyntheticShapes(Dataset):
         return cache_name
 
     def get_dataset_name(self):
-        """Get dataset name from dataset config / default config. """
+        """Get dataset name from dataset config / default config."""
         if self.config["dataset_name"] is None:
             dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
         else:
@@ -166,7 +175,7 @@ class SyntheticShapes(Dataset):
         return dataset_name
 
     def get_filename_dataset_from_cache(self):
-        """ Get filename dataset from cache. """
+        """Get filename dataset from cache."""
         # Load from the pkl cache
         cache_file_path = os.path.join(self.cache_path, self.cache_name)
         with open(cache_file_path, "rb") as f:
@@ -175,10 +184,9 @@ class SyntheticShapes(Dataset):
         return data["filename_dataset"], data["datapoints"]
 
     def get_filename_dataset(self):
-        """ Get filename dataset from scratch. """
+        """Get filename dataset from scratch."""
         # Path to the exported dataset
-        dataset_path = os.path.join(cfg.synthetic_dataroot,
-                                    self.dataset_name + ".h5")
+        dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name + ".h5")
 
         filename_dataset = {}
         datapoints = []
@@ -187,8 +195,7 @@ class SyntheticShapes(Dataset):
             # Iterate through all the primitives
             for prim_name in f.keys():
                 filenames = sorted(f[prim_name].keys())
-                filenames_full = [os.path.join(prim_name, _)
-                                  for _ in filenames]
+                filenames_full = [os.path.join(prim_name, _) for _ in filenames]
 
                 filename_dataset[prim_name] = filenames_full
                 datapoints += filenames_full
@@ -196,34 +203,30 @@ class SyntheticShapes(Dataset):
         return filename_dataset, datapoints
 
     def create_filename_dataset_cache(self, filename_dataset, datapoints):
-        """ Create filename dataset cache for faster initialization. """
+        """Create filename dataset cache for faster initialization."""
         # Check cache path exists
         if not os.path.exists(self.cache_path):
             os.makedirs(self.cache_path)
 
         cache_file_path = os.path.join(self.cache_path, self.cache_name)
-        data = {
-            "filename_dataset": filename_dataset,
-            "datapoints": datapoints
-        }
+        data = {"filename_dataset": filename_dataset, "datapoints": datapoints}
         with open(cache_file_path, "wb") as f:
             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
 
     def export_synthetic_shapes(self):
-        """ Export synthetic shapes to disk. """
+        """Export synthetic shapes to disk."""
         # Set the global random state for data generation
-        synthetic_util.set_random_state(np.random.RandomState(
-            self.config["generation"]["random_seed"]))
+        synthetic_util.set_random_state(
+            np.random.RandomState(self.config["generation"]["random_seed"])
+        )
 
         # Define the export path
-        dataset_path = os.path.join(cfg.synthetic_dataroot,
-                                    self.dataset_name + ".h5")
+        dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name + ".h5")
 
         # Open h5py file
         with h5py.File(dataset_path, "w", libver="latest") as f:
             # Iterate through all types of shape
-            primitives = self.parse_drawing_primitives(
-                self.config["primitives"])
+            primitives = self.parse_drawing_primitives(self.config["primitives"])
             split_size = self.config["generation"]["split_sizes"][self.mode]
             for prim in primitives:
                 # Create h5 group
@@ -234,22 +237,23 @@ class SyntheticShapes(Dataset):
             f.swmr_mode = True
 
     def export_single_primitive(self, primitive, split_size, group):
-        """ Export single primitive. """
+        """Export single primitive."""
         # Check if the primitive is valid or not
         if primitive not in self.available_primitives:
-            raise ValueError(
-        "[Error]: %s is not a supported primitive" % primitive)
+            raise ValueError("[Error]: %s is not a supported primitive" % primitive)
         # Set the random seed
-        synthetic_util.set_random_state(np.random.RandomState(
-            self.config["generation"]["random_seed"]))
+        synthetic_util.set_random_state(
+            np.random.RandomState(self.config["generation"]["random_seed"])
+        )
 
         # Generate shapes
         print("\t Generating %s ..." % primitive)
         for idx in tqdm(range(split_size), ascii=True):
             # Generate background image
             image = synthetic_util.generate_background(
-                self.config['generation']['image_size'],
-                **self.config['generation']['params']['generate_background'])
+                self.config["generation"]["image_size"],
+                **self.config["generation"]["params"]["generate_background"]
+            )
 
             # Generate points
             drawing_func = getattr(synthetic_util, primitive)
@@ -260,14 +264,21 @@ class SyntheticShapes(Dataset):
             min_label_len = self.config["generation"]["min_label_len"]
 
             # Some only take min_label_len, and gaussian noises take nothing
-            if primitive in ["draw_lines", "draw_polygon",
-                             "draw_multiple_polygons", "draw_star"]:
-                data = drawing_func(image, min_len=min_len,
-                                    min_label_len=min_label_len, **kwarg)
-            elif primitive in ["draw_checkerboard_multiseg",
-                               "draw_stripes_multiseg", "draw_cube"]:
-                data = drawing_func(image, min_label_len=min_label_len,
-                                    **kwarg)
+            if primitive in [
+                "draw_lines",
+                "draw_polygon",
+                "draw_multiple_polygons",
+                "draw_star",
+            ]:
+                data = drawing_func(
+                    image, min_len=min_len, min_label_len=min_label_len, **kwarg
+                )
+            elif primitive in [
+                "draw_checkerboard_multiseg",
+                "draw_stripes_multiseg",
+                "draw_cube",
+            ]:
+                data = drawing_func(image, min_label_len=min_label_len, **kwarg)
             else:
                 data = drawing_func(image, **kwarg)
 
@@ -284,21 +295,24 @@ class SyntheticShapes(Dataset):
             image = cv2.GaussianBlur(image, (blur_size, blur_size), 0)
 
             # Resize the image and the point location.
-            points = (points
-                      * np.array(self.config['preprocessing']['resize'],
-                                 np.float)
-                      / np.array(self.config['generation']['image_size'],
-                                 np.float))
+            points = (
+                points
+                * np.array(self.config["preprocessing"]["resize"], float)
+                / np.array(self.config["generation"]["image_size"], float)
+            )
             image = cv2.resize(
-                image, tuple(self.config['preprocessing']['resize'][::-1]),
-                interpolation=cv2.INTER_LINEAR)
+                image,
+                tuple(self.config["preprocessing"]["resize"][::-1]),
+                interpolation=cv2.INTER_LINEAR,
+            )
             image = np.array(image, dtype=np.uint8)
 
             # Generate the line heatmap after post-processing
             junctions = np.flip(np.round(points).astype(np.int32), axis=1)
-            heatmap = (synthetic_util.get_line_heatmap(
-                junctions, line_map,
-                size=image.shape) * 255.).astype(np.uint8)
+            heatmap = (
+                synthetic_util.get_line_heatmap(junctions, line_map, size=image.shape)
+                * 255.0
+            ).astype(np.uint8)
 
             # Record the data in group
             num_pad = math.ceil(math.log10(split_size)) + 1
@@ -306,17 +320,13 @@ class SyntheticShapes(Dataset):
             file_group = group.create_group(file_key_name)
 
             # Store data
-            file_group.create_dataset("points", data=points,
-                                      compression="gzip")
-            file_group.create_dataset("image", data=image,
-                                      compression="gzip")
-            file_group.create_dataset("line_map", data=line_map,
-                                      compression="gzip")
-            file_group.create_dataset("heatmap", data=heatmap,
-                                      compression="gzip")
+            file_group.create_dataset("points", data=points, compression="gzip")
+            file_group.create_dataset("image", data=image, compression="gzip")
+            file_group.create_dataset("line_map", data=line_map, compression="gzip")
+            file_group.create_dataset("heatmap", data=heatmap, compression="gzip")
 
     def get_default_config(self):
-        """ Get default configuration of the dataset. """
+        """Get default configuration of the dataset."""
         # Initialize the default configuration
         self.default_config = {
             "dataset_name": "synthetic_shape",
@@ -324,43 +334,43 @@ class SyntheticShapes(Dataset):
             "add_augmentation_to_all_splits": False,
             # Shape generation configuration
             "generation": {
-                "split_sizes": {'train': 10000, 'val': 400, 'test': 500},
+                "split_sizes": {"train": 10000, "val": 400, "test": 500},
                 "random_seed": 10,
                 "image_size": [960, 1280],
                 "min_len": 0.09,
                 "min_label_len": 0.1,
-                'params': {
-                    'generate_background': {
-                        'min_kernel_size': 150, 'max_kernel_size': 500,
-                        'min_rad_ratio': 0.02, 'max_rad_ratio': 0.031},
-                    'draw_stripes': {'transform_params': (0.1, 0.1)},
-                    'draw_multiple_polygons': {'kernel_boundaries': (50, 100)}
+                "params": {
+                    "generate_background": {
+                        "min_kernel_size": 150,
+                        "max_kernel_size": 500,
+                        "min_rad_ratio": 0.02,
+                        "max_rad_ratio": 0.031,
+                    },
+                    "draw_stripes": {"transform_params": (0.1, 0.1)},
+                    "draw_multiple_polygons": {"kernel_boundaries": (50, 100)},
                 },
             },
             # Date preprocessing configuration.
-            "preprocessing": {
-                "resize": [240, 320],
-                "blur_size": 11
-            },
-            'augmentation': {
-                'photometric': {
-                    'enable': False,
-                    'primitives': 'all',
-                    'params': {},
-                    'random_order': True,
+            "preprocessing": {"resize": [240, 320], "blur_size": 11},
+            "augmentation": {
+                "photometric": {
+                    "enable": False,
+                    "primitives": "all",
+                    "params": {},
+                    "random_order": True,
                 },
-                'homographic': {
-                    'enable': False,
-                    'params': {},
-                    'valid_border_margin': 0,
+                "homographic": {
+                    "enable": False,
+                    "params": {},
+                    "valid_border_margin": 0,
                 },
-            }
+            },
         }
 
         return self.default_config
 
     def parse_drawing_primitives(self, names):
-        """ Parse the primitives in config to list of primitive names. """
+        """Parse the primitives in config to list of primitive names."""
         if names == "all":
             p = self.available_primitives
         else:
@@ -375,42 +385,42 @@ class SyntheticShapes(Dataset):
 
     @staticmethod
     def get_padded_filename(num_pad, idx):
-        """ Get the padded filename using adaptive padding. """
+        """Get the padded filename using adaptive padding."""
         file_len = len("%d" % (idx))
         filename = "0" * (num_pad - file_len) + "%d" % (idx)
 
         return filename
 
     def print_dataset_info(self):
-        """ Print dataset info. """
+        """Print dataset info."""
         print("\t ---------Summary------------------")
         print("\t Dataset mode: \t\t %s" % self.mode)
         print("\t Number of primitive: \t %d" % len(self.filename_dataset.keys()))
         print("\t Number of data: \t %d" % len(self.datapoints))
         print("\t ----------------------------------")
-    
+
     #########################
     ## Pytorch related API ##
     #########################
     def get_data_from_datapoint(self, datapoint, reader=None):
-        """ Get data given the datapoint
-            (keyname of the h5 dataset e.g. "draw_lines/0000.h5"). """
+        """Get data given the datapoint
+        (keyname of the h5 dataset e.g. "draw_lines/0000.h5")."""
         # Check if the datapoint is valid
         if not datapoint in self.datapoints:
             raise ValueError(
-        "[Error] The specified datapoint is not in available datapoints.")
+                "[Error] The specified datapoint is not in available datapoints."
+            )
 
         # Get data from h5 dataset
         if reader is None:
-            raise ValueError(
-        "[Error] The reader must be provided in __getitem__.")
+            raise ValueError("[Error] The reader must be provided in __getitem__.")
         else:
             data = reader[datapoint]
 
         return parse_h5_data(data)
 
     def get_data_from_signature(self, primitive_name, index):
-        """ Get data given the primitive name and index ("draw_lines", 10) """
+        """Get data given the primitive name and index ("draw_lines", 10)"""
         # Check the primitive name and index
         self._check_primitive_and_index(primitive_name, index)
 
@@ -420,40 +430,41 @@ class SyntheticShapes(Dataset):
         return self.get_data_from_datapoint(datapoint)
 
     def parse_transforms(self, names, all_transforms):
-        trans = all_transforms if (names == 'all') \
+        trans = (
+            all_transforms
+            if (names == "all")
             else (names if isinstance(names, list) else [names])
+        )
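+        # names may be "all", a single transform name, or a list of names; the
+        # assert below guarantees every requested transform is available.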
         assert set(trans) <= set(all_transforms)
         return trans
 
     def get_photo_transform(self):
-        """ Get list of photometric transforms (according to the config). """
+        """Get list of photometric transforms (according to the config)."""
         # Get the photometric transform config
         photo_config = self.config["augmentation"]["photometric"]
         if not photo_config["enable"]:
-            raise ValueError(
-        "[Error] Photometric augmentation is not enabled.")
-        
+            raise ValueError("[Error] Photometric augmentation is not enabled.")
+
         # Parse photometric transforms
-        trans_lst = self.parse_transforms(photo_config["primitives"],
-                                          photoaug.available_augmentations)
-        trans_config_lst = [photo_config["params"].get(p, {})
-                            for p in trans_lst]
+        trans_lst = self.parse_transforms(
+            photo_config["primitives"], photoaug.available_augmentations
+        )
+        trans_config_lst = [photo_config["params"].get(p, {}) for p in trans_lst]
 
         # List of photometric augmentation
         photometric_trans_lst = [
-            getattr(photoaug, trans)(**conf) \
+            getattr(photoaug, trans)(**conf)
             for (trans, conf) in zip(trans_lst, trans_config_lst)
         ]
 
         return photometric_trans_lst
-    
+
     def get_homo_transform(self):
-        """ Get homographic transforms (according to the config). """
+        """Get homographic transforms (according to the config)."""
         # Get homographic transforms for image
         homo_config = self.config["augmentation"]["homographic"]["params"]
         if not self.config["augmentation"]["homographic"]["enable"]:
-            raise ValueError(
-        "[Error] Homographic augmentation is not enabled")
+            raise ValueError("[Error] Homographic augmentation is not enabled")
 
         # Parse the homographic transforms
         # ToDo: use the shape from the config
@@ -464,33 +475,35 @@ class SyntheticShapes(Dataset):
             min_label_tmp = self.config["generation"]["min_label_len"]
         except:
             min_label_tmp = None
-        
+
         # float label len => fraction
-        if isinstance(min_label_tmp, float): # Skip if not provided
+        if isinstance(min_label_tmp, float):  # Skip if not provided
             min_label_len = min_label_tmp * min(image_shape)
         # int label len => length in pixel
         elif isinstance(min_label_tmp, int):
-            scale_ratio = (self.config["preprocessing"]["resize"]
-                           / self.config["generation"]["image_size"][0])
-            min_label_len = (self.config["generation"]["min_label_len"]
-                             * scale_ratio)
+            scale_ratio = (
+                self.config["preprocessing"]["resize"]
+                / self.config["generation"]["image_size"][0]
+            )
+            min_label_len = self.config["generation"]["min_label_len"] * scale_ratio
         # if none => no restriction
         else:
             min_label_len = 0
-        
+
         # Initialize the transform
         homographic_trans = homoaug.homography_transform(
-            image_shape, homo_config, 0, min_label_len)
+            image_shape, homo_config, 0, min_label_len
+        )
 
         return homographic_trans
 
     @staticmethod
     def junc_to_junc_map(junctions, image_size):
-        """ Convert junction points to junction maps. """
+        """Convert junction points to junction maps."""
         junctions = np.round(junctions).astype(np.int)
         # Clip the boundary by image size
-        junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1)
-        junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1)
+        junctions[:, 0] = np.clip(junctions[:, 0], 0.0, image_size[0] - 1)
+        junctions[:, 1] = np.clip(junctions[:, 1], 0.0, image_size[1] - 1)
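+        # Junctions are stored in (row, col) order, so column 0 is clipped against
+        # the image height and column 1 against the image width.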
 
         # Create junction map
         junc_map = np.zeros([image_size[0], image_size[1]])
@@ -499,7 +512,7 @@ class SyntheticShapes(Dataset):
         return junc_map[..., None].astype(np.int)
 
     def train_preprocessing(self, data, disable_homoaug=False):
-        """ Training preprocessing. """
+        """Training preprocessing."""
         # Fetch corresponding entries
         image = data["image"]
         junctions = data["points"]
@@ -509,29 +522,32 @@ class SyntheticShapes(Dataset):
 
         # Resize the image before the photometric and homographic transforms
         # Check if we need to do the resizing
-        if not(list(image.shape) == self.config["preprocessing"]["resize"]):
+        if not (list(image.shape) == self.config["preprocessing"]["resize"]):
             # Resize the image and the point location.
             size_old = list(image.shape)
             image = cv2.resize(
-                image, tuple(self.config['preprocessing']['resize'][::-1]),
-                interpolation=cv2.INTER_LINEAR)
+                image,
+                tuple(self.config["preprocessing"]["resize"][::-1]),
+                interpolation=cv2.INTER_LINEAR,
+            )
             image = np.array(image, dtype=np.uint8)
 
             junctions = (
                 junctions
-                * np.array(self.config['preprocessing']['resize'], np.float)
-                / np.array(size_old, np.float))
+                * np.array(self.config["preprocessing"]["resize"], np.float)
+                / np.array(size_old, np.float)
+            )
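+            # Junction coordinates are rescaled per axis by new_size / old_size so
+            # they stay aligned with the resized image.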
 
             # Generate the line heatmap after post-processing
-            junctions_xy = np.flip(np.round(junctions).astype(np.int32),
-                                   axis=1)
-            heatmap = synthetic_util.get_line_heatmap(junctions_xy, line_map,
-                                                      size=image.shape)
-            heatmap = (heatmap * 255.).astype(np.uint8)
+            junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
+            heatmap = synthetic_util.get_line_heatmap(
+                junctions_xy, line_map, size=image.shape
+            )
+            heatmap = (heatmap * 255.0).astype(np.uint8)
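+            # The heatmap is rasterized with unit intensity and rescaled to the
+            # [0, 255] uint8 range.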
 
             # Update image size
             image_size = image.shape[:2]
-        
+
         # Declare default valid mask (all ones)
         valid_mask = np.ones(image_size)
 
@@ -544,7 +560,8 @@ class SyntheticShapes(Dataset):
             ### Image transform ###
             np.random.shuffle(photo_trans_lst)
             image_transform = transforms.Compose(
-                photo_trans_lst + [photoaug.normalize_image()])
+                photo_trans_lst + [photoaug.normalize_image()]
+            )
         else:
             image_transform = photoaug.normalize_image()
         image = image_transform(image)
@@ -554,40 +571,46 @@ class SyntheticShapes(Dataset):
         # Convert to tensor and return the results
         to_tensor = transforms.ToTensor()
         # Check homographic augmentation
-        if (self.config["augmentation"]["homographic"]["enable"]
-            and disable_homoaug == False):
+        if (
+            self.config["augmentation"]["homographic"]["enable"]
+            and disable_homoaug == False
+        ):
             homo_trans = self.get_homo_transform()
             # Perform homographic transform
             homo_outputs = homo_trans(image, junctions, line_map)
 
             # Record the warped results
-            junctions = homo_outputs["junctions"]    # Should be HW format
+            junctions = homo_outputs["junctions"]  # Should be HW format
             image = homo_outputs["warped_image"]
             line_map = homo_outputs["line_map"]
             heatmap = homo_outputs["warped_heatmap"]
             valid_mask = homo_outputs["valid_mask"]  # Same for pos and neg
             homography_mat = homo_outputs["homo"]
-            
+
             # Optionally put warpping information first.
-            outputs["homography_mat"] = to_tensor(
-                homography_mat).to(torch.float32)[0, ...]
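+            # ToTensor adds a leading channel dimension; indexing [0, ...] drops it
+            # so the homography is stored as a plain float32 matrix.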
+            outputs["homography_mat"] = to_tensor(homography_mat).to(torch.float32)[
+                0, ...
+            ]
 
         junction_map = self.junc_to_junc_map(junctions, image_size)
 
-        outputs.update({
-            "image": to_tensor(image),
-            "junctions": to_tensor(np.ascontiguousarray(
-                junctions).copy()).to(torch.float32)[0, ...],
-            "junction_map": to_tensor(junction_map).to(torch.int),
-            "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
-            "heatmap": to_tensor(heatmap).to(torch.int32),
-            "valid_mask": to_tensor(valid_mask).to(torch.int32),
-        })
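+        # Junction coordinates stay float32, while the map, heatmap and mask
+        # tensors are cast to integer types.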
+        outputs.update(
+            {
+                "image": to_tensor(image),
+                "junctions": to_tensor(np.ascontiguousarray(junctions).copy()).to(
+                    torch.float32
+                )[0, ...],
+                "junction_map": to_tensor(junction_map).to(torch.int),
+                "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
+                "heatmap": to_tensor(heatmap).to(torch.int32),
+                "valid_mask": to_tensor(valid_mask).to(torch.int32),
+            }
+        )
 
         return outputs
 
     def test_preprocessing(self, data):
-        """ Test preprocessing. """
+        """Test preprocessing."""
         # Fetch corresponding entries
         image = data["image"]
         points = data["points"]
@@ -600,20 +623,24 @@ class SyntheticShapes(Dataset):
             # Resize the image and the point location.
             size_old = list(image.shape)
             image = cv2.resize(
-                image, tuple(self.config['preprocessing']['resize'][::-1]),
-                interpolation=cv2.INTER_LINEAR)
+                image,
+                tuple(self.config["preprocessing"]["resize"][::-1]),
+                interpolation=cv2.INTER_LINEAR,
+            )
             image = np.array(image, dtype=np.uint8)
 
-            points = (points
-                      * np.array(self.config['preprocessing']['resize'],
-                                 np.float)
-                      / np.array(size_old, np.float))
+            points = (
+                points
+                * np.array(self.config["preprocessing"]["resize"], np.float)
+                / np.array(size_old, np.float)
+            )
 
             # Generate the line heatmap after post-processing
             junctions = np.flip(np.round(points).astype(np.int32), axis=1)
-            heatmap = synthetic_util.get_line_heatmap(junctions, line_map,
-                                                      size=image.shape)
-            heatmap = (heatmap * 255.).astype(np.uint8)
+            heatmap = synthetic_util.get_line_heatmap(
+                junctions, line_map, size=image.shape
+            )
+            heatmap = (heatmap * 255.0).astype(np.uint8)
 
             # Update image size
             image_size = image.shape[:2]
@@ -638,7 +665,7 @@ class SyntheticShapes(Dataset):
             "junction_map": junction_map,
             "line_map": line_map,
             "heatmap": heatmap,
-            "valid_mask": valid_mask
+            "valid_mask": valid_mask,
         }
 
     def __getitem__(self, index):
@@ -649,8 +676,7 @@ class SyntheticShapes(Dataset):
             data = self.get_data_from_datapoint(datapoint, reader)
 
         # Apply different transforms in different mod.
-        if (self.mode == "train"
-            or self.config["add_augmentation_to_all_splits"]):
+        if self.mode == "train" or self.config["add_augmentation_to_all_splits"]:
             return_type = self.config.get("return_type", "single")
             data = self.train_preprocessing(data)
         else:
@@ -665,7 +691,7 @@ class SyntheticShapes(Dataset):
     ## Some other methods ##
     ########################
     def _check_dataset_cache(self):
-        """ Check if dataset cache exists. """
+        """Check if dataset cache exists."""
         cache_file_path = os.path.join(self.cache_path, self.cache_name)
         if os.path.exists(cache_file_path):
             return True
@@ -673,7 +699,7 @@ class SyntheticShapes(Dataset):
             return False
 
     def _check_export_dataset(self):
-        """ Check if exported dataset exists. """
+        """Check if exported dataset exists."""
         dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name)
         if os.path.exists(dataset_path) and len(os.listdir(dataset_path)) > 0:
             return True
@@ -681,32 +707,30 @@ class SyntheticShapes(Dataset):
             return False
 
     def _check_file_existence(self, filename_dataset):
-        """ Check if all exported file exists. """
+        """Check if all exported file exists."""
         # Path to the exported dataset
-        dataset_path = os.path.join(cfg.synthetic_dataroot, 
-                                    self.dataset_name + ".h5")
+        dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name + ".h5")
 
         flag = True
         # Open the h5 dataset
         with h5py.File(dataset_path, "r") as f:
             # Iterate through all the primitives
             for prim_name in f.keys():
-                if (len(filename_dataset[prim_name])
-                    != len(f[prim_name].keys())):
+                if len(filename_dataset[prim_name]) != len(f[prim_name].keys()):
                     flag = False
 
         return flag
 
     def _check_primitive_and_index(self, primitive, index):
-        """ Check if the primitve and index are valid. """
+        """Check if the primitve and index are valid."""
         # Check primitives
         if not primitive in self.available_primitives:
-            raise ValueError(
-                "[Error] The primitive is not in available primitives.")
+            raise ValueError("[Error] The primitive is not in available primitives.")
 
         prim_len = len(self.filename_dataset[primitive])
         # Check the index
         if not index < prim_len:
             raise ValueError(
                 "[Error] The index exceeds the total file counts %d for %s"
-                % (prim_len, primitive))
+                % (prim_len, primitive)
+            )
diff --git a/third_party/SOLD2/sold2/dataset/synthetic_util.py b/third_party/SOLD2/sold2/dataset/synthetic_util.py
index af009e0ce7e91391e31d7069064ae6121aa84cc0..63e41c5bbcadd4a1a633a2b33392dc6d4fd088ff 100644
--- a/third_party/SOLD2/sold2/dataset/synthetic_util.py
+++ b/third_party/SOLD2/sold2/dataset/synthetic_util.py
@@ -17,8 +17,8 @@ def set_random_state(state):
 
 
 def get_random_color(background_color):
-    """ Output a random scalar in grayscale with a least a small contrast
-        with the background color. """
+    """Output a random scalar in grayscale with a least a small contrast
+    with the background color."""
     color = random_state.randint(256)
     if abs(color - background_color) < 30:  # not enough contrast
         color = (color + 128) % 256
@@ -26,7 +26,7 @@ def get_random_color(background_color):
 
 
 def get_different_color(previous_colors, min_dist=50, max_count=20):
-    """ Output a color that contrasts with the previous colors.
+    """Output a color that contrasts with the previous colors.
     Parameters:
       previous_colors: np.array of the previous colors
       min_dist: the difference between the new color and
@@ -42,7 +42,7 @@ def get_different_color(previous_colors, min_dist=50, max_count=20):
 
 
 def add_salt_and_pepper(img):
-    """ Add salt and pepper noise to an image. """
+    """Add salt and pepper noise to an image."""
     noise = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
     cv.randu(noise, 0, 255)
     black = noise < 30
@@ -53,10 +53,15 @@ def add_salt_and_pepper(img):
     return np.empty((0, 2), dtype=np.int)
 
 
-def generate_background(size=(960, 1280), nb_blobs=100, min_rad_ratio=0.01,
-                        max_rad_ratio=0.05, min_kernel_size=50,
-                        max_kernel_size=300):
-    """ Generate a customized background image.
+def generate_background(
+    size=(960, 1280),
+    nb_blobs=100,
+    min_rad_ratio=0.01,
+    max_rad_ratio=0.05,
+    min_kernel_size=50,
+    max_kernel_size=300,
+):
+    """Generate a customized background image.
     Parameters:
       size: size of the image
       nb_blobs: number of circles to draw
@@ -71,22 +76,30 @@ def generate_background(size=(960, 1280), nb_blobs=100, min_rad_ratio=0.01,
     cv.threshold(img, random_state.randint(256), 255, cv.THRESH_BINARY, img)
     background_color = int(np.mean(img))
     blobs = np.concatenate(
-        [random_state.randint(0, size[1], size=(nb_blobs, 1)),
-         random_state.randint(0, size[0], size=(nb_blobs, 1))], axis=1)
+        [
+            random_state.randint(0, size[1], size=(nb_blobs, 1)),
+            random_state.randint(0, size[0], size=(nb_blobs, 1)),
+        ],
+        axis=1,
+    )
     for i in range(nb_blobs):
         col = get_random_color(background_color)
-        cv.circle(img, (blobs[i][0], blobs[i][1]),
-                  np.random.randint(int(dim * min_rad_ratio),
-                                    int(dim * max_rad_ratio)),
-                  col, -1)
+        cv.circle(
+            img,
+            (blobs[i][0], blobs[i][1]),
+            np.random.randint(int(dim * min_rad_ratio), int(dim * max_rad_ratio)),
+            col,
+            -1,
+        )
     kernel_size = random_state.randint(min_kernel_size, max_kernel_size)
     cv.blur(img, (kernel_size, kernel_size), img)
     return img
 
 
-def generate_custom_background(size, background_color, nb_blobs=3000,
-                               kernel_boundaries=(50, 100)):
-    """ Generate a customized background to fill the shapes.
+def generate_custom_background(
+    size, background_color, nb_blobs=3000, kernel_boundaries=(50, 100)
+):
+    """Generate a customized background to fill the shapes.
     Parameters:
       background_color: average color of the background image
       nb_blobs: number of circles to draw
@@ -95,20 +108,22 @@ def generate_custom_background(size, background_color, nb_blobs=3000,
     img = np.zeros(size, dtype=np.uint8)
     img = img + get_random_color(background_color)
     blobs = np.concatenate(
-        [np.random.randint(0, size[1], size=(nb_blobs, 1)),
-         np.random.randint(0, size[0], size=(nb_blobs, 1))], axis=1)
+        [
+            np.random.randint(0, size[1], size=(nb_blobs, 1)),
+            np.random.randint(0, size[0], size=(nb_blobs, 1)),
+        ],
+        axis=1,
+    )
     for i in range(nb_blobs):
         col = get_random_color(background_color)
-        cv.circle(img, (blobs[i][0], blobs[i][1]),
-                  np.random.randint(20), col, -1)
-    kernel_size = np.random.randint(kernel_boundaries[0],
-                                    kernel_boundaries[1])
+        cv.circle(img, (blobs[i][0], blobs[i][1]), np.random.randint(20), col, -1)
+    kernel_size = np.random.randint(kernel_boundaries[0], kernel_boundaries[1])
     cv.blur(img, (kernel_size, kernel_size), img)
     return img
 
 
 def final_blur(img, kernel_size=(5, 5)):
-    """ Gaussian blur applied to an image.
+    """Gaussian blur applied to an image.
     Parameters:
       kernel_size: size of the kernel
     """
@@ -116,33 +131,39 @@ def final_blur(img, kernel_size=(5, 5)):
 
 
 def ccw(A, B, C, dim):
-    """ Check if the points are listed in counter-clockwise order. """
+    """Check if the points are listed in counter-clockwise order."""
     if dim == 2:  # only 2 dimensions
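+        # 2D orientation test: cross(B - A, C - A) > 0, i.e. A, B and C make a
+        # counter-clockwise turn.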
-        return((C[:, 1] - A[:, 1]) * (B[:, 0] - A[:, 0])
-               > (B[:, 1] - A[:, 1]) * (C[:, 0] - A[:, 0]))
+        return (C[:, 1] - A[:, 1]) * (B[:, 0] - A[:, 0]) > (B[:, 1] - A[:, 1]) * (
+            C[:, 0] - A[:, 0]
+        )
     else:  # dim should be equal to 3
-        return((C[:, 1, :] - A[:, 1, :])
-               * (B[:, 0, :] - A[:, 0, :])
-               > (B[:, 1, :] - A[:, 1, :])
-               * (C[:, 0, :] - A[:, 0, :]))
+        return (C[:, 1, :] - A[:, 1, :]) * (B[:, 0, :] - A[:, 0, :]) > (
+            B[:, 1, :] - A[:, 1, :]
+        ) * (C[:, 0, :] - A[:, 0, :])
 
 
 def intersect(A, B, C, D, dim):
-    """ Return true if line segments AB and CD intersect """
-    return np.any((ccw(A, C, D, dim) != ccw(B, C, D, dim)) &
-                  (ccw(A, B, C, dim) != ccw(A, B, D, dim)))
+    """Return true if line segments AB and CD intersect"""
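+    # AB and CD cross when C, D are on opposite sides of line AB and A, B are on
+    # opposite sides of line CD; np.any reports True if any segment pair crosses.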
+    return np.any(
+        (ccw(A, C, D, dim) != ccw(B, C, D, dim))
+        & (ccw(A, B, C, dim) != ccw(A, B, D, dim))
+    )
 
 
 def keep_points_inside(points, size):
-    """ Keep only the points whose coordinates are inside the dimensions of
-    the image of size 'size' """
-    mask = (points[:, 0] >= 0) & (points[:, 0] < size[1]) &\
-           (points[:, 1] >= 0) & (points[:, 1] < size[0])
+    """Keep only the points whose coordinates are inside the dimensions of
+    the image of size 'size'"""
+    mask = (
+        (points[:, 0] >= 0)
+        & (points[:, 0] < size[1])
+        & (points[:, 1] >= 0)
+        & (points[:, 1] < size[0])
+    )
     return points[mask, :]
 
 
 def get_unique_junctions(segments, min_label_len):
-    """ Get unique junction points from line segments. """
+    """Get unique junction points from line segments."""
     # Get all junctions from segments
     junctions_all = np.concatenate((segments[:, :2], segments[:, 2:]), axis=0)
     if junctions_all.shape[0] == 0:
@@ -159,7 +180,7 @@ def get_unique_junctions(segments, min_label_len):
 
 
 def get_line_map(points: np.ndarray, segments: np.ndarray) -> np.ndarray:
-    """ Get line map given the points and segment sets. """
+    """Get line map given the points and segment sets."""
     # create empty line map
     num_point = points.shape[0]
     line_map = np.zeros([num_point, num_point])
@@ -183,7 +204,7 @@ def get_line_map(points: np.ndarray, segments: np.ndarray) -> np.ndarray:
 
 
 def get_line_heatmap(junctions, line_map, size=[480, 640], thickness=1):
-    """ Get line heat map from junctions and line map. """
+    """Get line heat map from junctions and line map."""
     # Make sure that the thickness is 1
     if not isinstance(thickness, int):
         thickness = int(thickness)
@@ -195,7 +216,7 @@ def get_line_heatmap(junctions, line_map, size=[480, 640], thickness=1):
     # Initialize empty map
     heat_map = np.zeros(size)
 
-    if junctions.shape[0] > 0: # If empty, just return zero map
+    if junctions.shape[0] > 0:  # If empty, just return zero map
         # Iterate through all the junctions
         for idx in range(junctions.shape[0]):
             # if no connectivity, just skip it
@@ -209,13 +230,13 @@ def get_line_heatmap(junctions, line_map, size=[480, 640], thickness=1):
                     point2 = junctions[idx2, :]
 
                     # Draw line
-                    cv.line(heat_map, tuple(point1), tuple(point2), 1., thickness)
+                    cv.line(heat_map, tuple(point1), tuple(point2), 1.0, thickness)
 
     return heat_map
 
 
 def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32):
-    """ Draw random lines and output the positions of the pair of junctions
+    """Draw random lines and output the positions of the pair of junctions
         and line associativities.
     Parameters:
       nb_lines: maximal number of lines
@@ -228,9 +249,9 @@ def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32):
     min_dim = min(img.shape)
 
     # Convert length constrain to pixel if given float number
-    if isinstance(min_len, float) and min_len <= 1.:
+    if isinstance(min_len, float) and min_len <= 1.0:
         min_len = int(min_dim * min_len)
-    if isinstance(min_label_len, float) and min_label_len <= 1.:
+    if isinstance(min_label_len, float) and min_label_len <= 1.0:
         min_label_len = int(min_dim * min_label_len)
 
     # Generate lines one by one
@@ -258,10 +279,8 @@ def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32):
         # Only record the segments longer than min_label_len
         seg_len = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
         if seg_len >= min_label_len:
-            segments = np.concatenate([segments,
-                                       np.array([[x1, y1, x2, y2]])], axis=0)
-            points = np.concatenate([points,
-                                     np.array([[x1, y1], [x2, y2]])], axis=0)
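+            # Each accepted line contributes one (x1, y1, x2, y2) segment row and
+            # its two endpoint junctions.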
+            segments = np.concatenate([segments, np.array([[x1, y1, x2, y2]])], axis=0)
+            points = np.concatenate([points, np.array([[x1, y1], [x2, y2]])], axis=0)
 
     # If no line is drawn, recursively call the function
     if points.shape[0] == 0:
@@ -270,19 +289,16 @@ def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32):
     # Get the line associativity map
     line_map = get_line_map(points, segments)
 
-    return {
-        "points": points,
-        "line_map": line_map
-    }
+    return {"points": points, "line_map": line_map}
 
 
 def check_segment_len(segments, min_len=32):
-    """ Check if one of the segments is too short (True means too short). """
+    """Check if one of the segments is too short (True means too short)."""
     point1_vec = segments[:, :2]
     point2_vec = segments[:, 2:]
     diff = point1_vec - point2_vec
 
-    dist = np.sqrt(np.sum(diff ** 2, axis=1))
+    dist = np.sqrt(np.sum(diff**2, axis=1))
     if np.any(dist < min_len):
         return True
     else:
@@ -290,7 +306,7 @@ def check_segment_len(segments, min_len=32):
 
 
 def draw_polygon(img, max_sides=8, min_len=32, min_label_len=64):
-    """ Draw a polygon with a random number of corners and return the position
+    """Draw a polygon with a random number of corners and return the position
         of the junctions + line map.
     Parameters:
       max_sides: maximal number of sides + 1
@@ -303,31 +319,42 @@ def draw_polygon(img, max_sides=8, min_len=32, min_label_len=64):
     y = random_state.randint(rad, img.shape[0] - rad)
 
     # Convert length constrain to pixel if given float number
-    if isinstance(min_len, float) and min_len <= 1.:
+    if isinstance(min_len, float) and min_len <= 1.0:
         min_len = int(min_dim * min_len)
-    if isinstance(min_label_len, float) and min_label_len <= 1.:
+    if isinstance(min_label_len, float) and min_label_len <= 1.0:
         min_label_len = int(min_dim * min_label_len)
 
     # Sample num_corners points inside the circle
     slices = np.linspace(0, 2 * math.pi, num_corners + 1)
-    angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i])
-              for i in range(num_corners)]
+    angles = [
+        slices[i] + random_state.rand() * (slices[i + 1] - slices[i])
+        for i in range(num_corners)
+    ]
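+    # One angle is drawn uniformly inside each angular slice, which keeps the
+    # corners roughly evenly spread around the sampled circle.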
     points = np.array(
-        [[int(x + max(random_state.rand(), 0.4) * rad * math.cos(a)),
-          int(y + max(random_state.rand(), 0.4) * rad * math.sin(a))]
-         for a in angles])
+        [
+            [
+                int(x + max(random_state.rand(), 0.4) * rad * math.cos(a)),
+                int(y + max(random_state.rand(), 0.4) * rad * math.sin(a)),
+            ]
+            for a in angles
+        ]
+    )
 
     # Filter the points that are too close or that have an angle too flat
-    norms = [np.linalg.norm(points[(i-1) % num_corners, :]
-                            - points[i, :]) for i in range(num_corners)]
+    norms = [
+        np.linalg.norm(points[(i - 1) % num_corners, :] - points[i, :])
+        for i in range(num_corners)
+    ]
     mask = np.array(norms) > 0.01
     points = points[mask, :]
     num_corners = points.shape[0]
-    corner_angles = [angle_between_vectors(points[(i-1) % num_corners, :] -
-                                           points[i, :],
-                                           points[(i+1) % num_corners, :] -
-                                           points[i, :])
-                     for i in range(num_corners)]
+    corner_angles = [
+        angle_between_vectors(
+            points[(i - 1) % num_corners, :] - points[i, :],
+            points[(i + 1) % num_corners, :] - points[i, :],
+        )
+        for i in range(num_corners)
+    ]
     mask = np.array(corner_angles) < (2 * math.pi / 3)
     points = points[mask, :]
     num_corners = points.shape[0]
@@ -349,8 +376,7 @@ def draw_polygon(img, max_sides=8, min_len=32, min_label_len=64):
         seg_len = np.sqrt(np.sum((p1 - p2) ** 2))
         if seg_len >= min_label_len:
             segments = np.concatenate((segments, segment[None, ...]), axis=0)
-        segments_raw = np.concatenate((segments_raw, segment[None, ...]),
-                                      axis=0)
+        segments_raw = np.concatenate((segments_raw, segment[None, ...]), axis=0)
 
     # If not enough corner, just regenerate one
     if (num_corners < 3) or check_segment_len(segments_raw, min_len):
@@ -372,15 +398,12 @@ def draw_polygon(img, max_sides=8, min_len=32, min_label_len=64):
     col = get_random_color(int(np.mean(img)))
     cv.fillPoly(img, [corners], col)
 
-    return {
-        "points": junc_points,
-        "line_map": line_map
-    }
+    return {"points": junc_points, "line_map": line_map}
 
 
 def overlap(center, rad, centers, rads):
-    """ Check that the circle with (center, rad)
-        doesn't overlap with the other circles. """
+    """Check that the circle with (center, rad)
+    doesn't overlap with the other circles."""
     flag = False
     for i in range(len(rads)):
         if np.linalg.norm(center - centers[i]) < rad + rads[i]:
@@ -390,15 +413,22 @@ def overlap(center, rad, centers, rads):
 
 
 def angle_between_vectors(v1, v2):
-    """ Compute the angle (in rad) between the two vectors v1 and v2. """
+    """Compute the angle (in rad) between the two vectors v1 and v2."""
     v1_u = v1 / np.linalg.norm(v1)
     v2_u = v2 / np.linalg.norm(v2)
     return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))
 
 
-def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32,
-                           min_label_len=64, safe_margin=5, **extra):
-    """ Draw multiple polygons with a random number of corners
+def draw_multiple_polygons(
+    img,
+    max_sides=8,
+    nb_polygons=30,
+    min_len=32,
+    min_label_len=64,
+    safe_margin=5,
+    **extra
+):
+    """Draw multiple polygons with a random number of corners
         and return the junction points + line map.
     Parameters:
       max_sides: maximal number of sides + 1
@@ -413,11 +443,11 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32,
 
     min_dim = min(img.shape[0], img.shape[1])
     # Convert length constrain to pixel if given float number
-    if isinstance(min_len, float) and min_len <= 1.:
+    if isinstance(min_len, float) and min_len <= 1.0:
         min_len = int(min_dim * min_len)
-    if isinstance(min_label_len, float) and min_label_len <= 1.:
+    if isinstance(min_label_len, float) and min_label_len <= 1.0:
         min_label_len = int(min_dim * min_label_len)
-    if isinstance(safe_margin, float) and safe_margin <= 1.:
+    if isinstance(safe_margin, float) and safe_margin <= 1.0:
         safe_margin = int(min_dim * safe_margin)
 
     # Sequentially generate polygons
@@ -435,8 +465,10 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32,
 
         # Sample num_corners points inside the circle
         slices = np.linspace(0, 2 * math.pi, num_corners + 1)
-        angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i])
-                  for i in range(num_corners)]
+        angles = [
+            slices[i] + random_state.rand() * (slices[i + 1] - slices[i])
+            for i in range(num_corners)
+        ]
 
         # Sample outer points and inner points
         new_points = []
@@ -444,29 +476,38 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32,
         for a in angles:
             x_offset = max(random_state.rand(), 0.4)
             y_offset = max(random_state.rand(), 0.4)
-            new_points.append([int(x + x_offset * rad * math.cos(a)),
-                               int(y + y_offset * rad * math.sin(a))])
+            new_points.append(
+                [
+                    int(x + x_offset * rad * math.cos(a)),
+                    int(y + y_offset * rad * math.sin(a)),
+                ]
+            )
             new_points_real.append(
-                [int(x + x_offset * rad_real * math.cos(a)),
-                 int(y + y_offset * rad_real * math.sin(a))])
+                [
+                    int(x + x_offset * rad_real * math.cos(a)),
+                    int(y + y_offset * rad_real * math.sin(a)),
+                ]
+            )
         new_points = np.array(new_points)
         new_points_real = np.array(new_points_real)
 
         # Filter the points that are too close or that have an angle too flat
-        norms = [np.linalg.norm(new_points[(i-1) % num_corners, :]
-                                - new_points[i, :])
-                 for i in range(num_corners)]
+        norms = [
+            np.linalg.norm(new_points[(i - 1) % num_corners, :] - new_points[i, :])
+            for i in range(num_corners)
+        ]
         mask = np.array(norms) > 0.01
         new_points = new_points[mask, :]
         new_points_real = new_points_real[mask, :]
 
         num_corners = new_points.shape[0]
         corner_angles = [
-            angle_between_vectors(new_points[(i-1) % num_corners, :] -
-                                  new_points[i, :],
-                                  new_points[(i+1) % num_corners, :] -
-                                  new_points[i, :])
-            for i in range(num_corners)]
+            angle_between_vectors(
+                new_points[(i - 1) % num_corners, :] - new_points[i, :],
+                new_points[(i + 1) % num_corners, :] - new_points[i, :],
+            )
+            for i in range(num_corners)
+        ]
         mask = np.array(corner_angles) < (2 * math.pi / 3)
         new_points = new_points[mask, :]
         new_points_real = new_points_real[mask, :]
@@ -480,28 +521,32 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32,
         new_segments = np.zeros((1, 4, num_corners))
         new_segments[:, 0, :] = [new_points[i][0] for i in range(num_corners)]
         new_segments[:, 1, :] = [new_points[i][1] for i in range(num_corners)]
-        new_segments[:, 2, :] = [new_points[(i+1) % num_corners][0]
-                                 for i in range(num_corners)]
-        new_segments[:, 3, :] = [new_points[(i+1) % num_corners][1]
-                                 for i in range(num_corners)]
+        new_segments[:, 2, :] = [
+            new_points[(i + 1) % num_corners][0] for i in range(num_corners)
+        ]
+        new_segments[:, 3, :] = [
+            new_points[(i + 1) % num_corners][1] for i in range(num_corners)
+        ]
 
         # Segments to record (inner circle)
         new_segments_real = np.zeros((1, 4, num_corners))
-        new_segments_real[:, 0, :] = [new_points_real[i][0]
-                                      for i in range(num_corners)]
-        new_segments_real[:, 1, :] = [new_points_real[i][1]
-                                      for i in range(num_corners)]
+        new_segments_real[:, 0, :] = [new_points_real[i][0] for i in range(num_corners)]
+        new_segments_real[:, 1, :] = [new_points_real[i][1] for i in range(num_corners)]
         new_segments_real[:, 2, :] = [
-            new_points_real[(i + 1) % num_corners][0]
-            for i in range(num_corners)]
+            new_points_real[(i + 1) % num_corners][0] for i in range(num_corners)
+        ]
         new_segments_real[:, 3, :] = [
-            new_points_real[(i + 1) % num_corners][1]
-            for i in range(num_corners)]
+            new_points_real[(i + 1) % num_corners][1] for i in range(num_corners)
+        ]
 
         # Check that the polygon will not overlap with pre-existing shapes
-        if intersect(segments[:, 0:2, None], segments[:, 2:4, None],
-                     new_segments[:, 0:2, :], new_segments[:, 2:4, :],
-                     3) or overlap(np.array([x, y]), rad, centers, rads):
+        if intersect(
+            segments[:, 0:2, None],
+            segments[:, 2:4, None],
+            new_segments[:, 0:2, :],
+            new_segments[:, 2:4, :],
+            3,
+        ) or overlap(np.array([x, y]), rad, centers, rads):
             continue
 
         # Check that the the edges of the polygon is not too short
@@ -515,20 +560,19 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32,
         segments = np.concatenate([segments, new_segments], axis=0)
 
         # Only record the segments longer than min_label_len
-        new_segments_real = np.reshape(np.swapaxes(new_segments_real, 0, 2),
-                                       (-1, 4))
+        new_segments_real = np.reshape(np.swapaxes(new_segments_real, 0, 2), (-1, 4))
         points1 = new_segments_real[:, :2]
         points2 = new_segments_real[:, 2:]
         seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1))
         new_label_segment = new_segments_real[seg_len >= min_label_len, :]
-        label_segments = np.concatenate([label_segments, new_label_segment],
-                                        axis=0)
+        label_segments = np.concatenate([label_segments, new_label_segment], axis=0)
 
         # Color the polygon with a custom background
         corners = new_points_real.reshape((-1, 1, 2))
         mask = np.zeros(img.shape, np.uint8)
         custom_background = generate_custom_background(
-            img.shape, background_color, **extra)
+            img.shape, background_color, **extra
+        )
 
         cv.fillPoly(mask, [corners], 255)
         locs = np.where(mask != 0)
@@ -537,7 +581,8 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32,
 
     # Get all junctions from label segments
     junctions_all = np.concatenate(
-        (label_segments[:, :2], label_segments[:, 2:]), axis=0)
+        (label_segments[:, :2], label_segments[:, 2:]), axis=0
+    )
     if junctions_all.shape[0] == 0:
         junc_points = None
         line_map = None
@@ -548,14 +593,11 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32,
         # Generate line map from points and segments
         line_map = get_line_map(junc_points, label_segments)
 
-    return {
-        "points": junc_points,
-        "line_map": line_map
-    }
+    return {"points": junc_points, "line_map": line_map}
 
 
 def draw_ellipses(img, nb_ellipses=20):
-    """ Draw several ellipses.
+    """Draw several ellipses.
     Parameters:
       nb_ellipses: maximal number of ellipses
     """
@@ -585,16 +627,16 @@ def draw_ellipses(img, nb_ellipses=20):
 
 
 def draw_star(img, nb_branches=6, min_len=32, min_label_len=64):
-    """ Draw a star and return the junction points + line map.
+    """Draw a star and return the junction points + line map.
     Parameters:
       nb_branches: number of branches of the star
     """
     num_branches = random_state.randint(3, nb_branches)
     min_dim = min(img.shape[0], img.shape[1])
     # Convert length constrain to pixel if given float number
-    if isinstance(min_len, float) and min_len <= 1.:
+    if isinstance(min_len, float) and min_len <= 1.0:
         min_len = int(min_dim * min_len)
-    if isinstance(min_label_len, float) and min_label_len <= 1.:
+    if isinstance(min_label_len, float) and min_label_len <= 1.0:
         min_label_len = int(min_dim * min_label_len)
 
     thickness = random_state.randint(min_dim * 0.01, min_dim * 0.025)
@@ -603,12 +645,19 @@ def draw_star(img, nb_branches=6, min_len=32, min_label_len=64):
     y = random_state.randint(rad, img.shape[0] - rad)
     # Sample num_branches points inside the circle
     slices = np.linspace(0, 2 * math.pi, num_branches + 1)
-    angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i])
-              for i in range(num_branches)]
+    angles = [
+        slices[i] + random_state.rand() * (slices[i + 1] - slices[i])
+        for i in range(num_branches)
+    ]
     points = np.array(
-        [[int(x + max(random_state.rand(), 0.3) * rad * math.cos(a)),
-          int(y + max(random_state.rand(), 0.3) * rad * math.sin(a))]
-         for a in angles])
+        [
+            [
+                int(x + max(random_state.rand(), 0.3) * rad * math.cos(a)),
+                int(y + max(random_state.rand(), 0.3) * rad * math.sin(a)),
+            ]
+            for a in angles
+        ]
+    )
     points = np.concatenate(([[x, y]], points), axis=0)
 
     # Generate segments and check the length
@@ -624,7 +673,8 @@ def draw_star(img, nb_branches=6, min_len=32, min_label_len=64):
 
     # Get all junctions from label segments
     junctions_all = np.concatenate(
-        (label_segments[:, :2], label_segments[:, 2:]), axis=0)
+        (label_segments[:, :2], label_segments[:, 2:]), axis=0
+    )
     if junctions_all.shape[0] == 0:
         junc_points = None
         line_map = None
@@ -638,19 +688,25 @@ def draw_star(img, nb_branches=6, min_len=32, min_label_len=64):
     background_color = int(np.mean(img))
     for i in range(1, num_branches + 1):
         col = get_random_color(background_color)
-        cv.line(img, (points[0][0], points[0][1]),
-                (points[i][0], points[i][1]),
-                col, thickness)
-    return {
-        "points": junc_points,
-        "line_map": line_map
-    }
-
-
-def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7,
-                               transform_params=(0.05, 0.15),
-                               min_label_len=64, seed=None):
-    """ Draw a checkerboard and output the junctions + line segments
+        cv.line(
+            img,
+            (points[0][0], points[0][1]),
+            (points[i][0], points[i][1]),
+            col,
+            thickness,
+        )
+    return {"points": junc_points, "line_map": line_map}
+
+
+def draw_checkerboard_multiseg(
+    img,
+    max_rows=7,
+    max_cols=7,
+    transform_params=(0.05, 0.15),
+    min_label_len=64,
+    seed=None,
+):
+    """Draw a checkerboard and output the junctions + line segments
     Parameters:
       max_rows: maximal number of rows + 1
       max_cols: maximal number of cols + 1
@@ -664,57 +720,63 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7,
     background_color = int(np.mean(img))
 
     min_dim = min(img.shape)
-    if isinstance(min_label_len, float) and min_label_len <= 1.:
+    if isinstance(min_label_len, float) and min_label_len <= 1.0:
         min_label_len = int(min_dim * min_label_len)
     # Create the grid
     rows = random_state.randint(3, max_rows)  # number of rows
     cols = random_state.randint(3, max_cols)  # number of cols
     s = min((img.shape[1] - 1) // cols, (img.shape[0] - 1) // rows)
-    x_coord = np.tile(range(cols + 1),
-                      rows + 1).reshape(((rows + 1) * (cols + 1), 1))
-    y_coord = np.repeat(range(rows + 1),
-                        cols + 1).reshape(((rows + 1) * (cols + 1), 1))
+    x_coord = np.tile(range(cols + 1), rows + 1).reshape(((rows + 1) * (cols + 1), 1))
+    y_coord = np.repeat(range(rows + 1), cols + 1).reshape(((rows + 1) * (cols + 1), 1))
     # points are the grid coordinates
     points = s * np.concatenate([x_coord, y_coord], axis=1)
 
     # Warp the grid using an affine transformation and an homography
     alpha_affine = np.max(img.shape) * (
-        transform_params[0] + random_state.rand() * transform_params[1])
+        transform_params[0] + random_state.rand() * transform_params[1]
+    )
     center_square = np.float32(img.shape) // 2
     min_dim = min(img.shape)
     square_size = min_dim // 3
-    pts1 = np.float32([center_square + square_size,
-                       [center_square[0] + square_size,
-                        center_square[1] - square_size],
-                       center_square - square_size,
-                       [center_square[0] - square_size,
-                        center_square[1] + square_size]])
-    pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine,
-                                       size=pts1.shape).astype(np.float32)
+    pts1 = np.float32(
+        [
+            center_square + square_size,
+            [center_square[0] + square_size, center_square[1] - square_size],
+            center_square - square_size,
+            [center_square[0] - square_size, center_square[1] + square_size],
+        ]
+    )
+    pts2 = pts1 + random_state.uniform(
+        -alpha_affine, alpha_affine, size=pts1.shape
+    ).astype(np.float32)
     affine_transform = cv.getAffineTransform(pts1[:3], pts2[:3])
-    pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2,
-                                       size=pts1.shape).astype(np.float32)
+    pts2 = pts1 + random_state.uniform(
+        -alpha_affine / 2, alpha_affine / 2, size=pts1.shape
+    ).astype(np.float32)
     perspective_transform = cv.getPerspectiveTransform(pts1, pts2)
 
     # Apply the affine transformation
-    points = np.transpose(np.concatenate(
-        (points, np.ones(((rows + 1) * (cols + 1), 1))), axis=1))
+    points = np.transpose(
+        np.concatenate((points, np.ones(((rows + 1) * (cols + 1), 1))), axis=1)
+    )
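+    # The grid points are augmented with a homogeneous 1 and transposed to (3, N)
+    # so the 2x3 affine matrix can be applied with a single dot product.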
     warped_points = np.transpose(np.dot(affine_transform, points))
 
     # Apply the homography
-    warped_col0 = np.add(np.sum(np.multiply(
-        warped_points, perspective_transform[0, :2]), axis=1),
-        perspective_transform[0, 2])
-    warped_col1 = np.add(np.sum(np.multiply(
-        warped_points, perspective_transform[1, :2]), axis=1),
-        perspective_transform[1, 2])
-    warped_col2 = np.add(np.sum(np.multiply(
-        warped_points, perspective_transform[2, :2]), axis=1),
-        perspective_transform[2, 2])
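+    # The 3x3 perspective transform is applied manually: warped column i is
+    # H[i, :2] . p + H[i, 2], and the first two rows are divided by the third
+    # below to go back from homogeneous to image coordinates.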
+    warped_col0 = np.add(
+        np.sum(np.multiply(warped_points, perspective_transform[0, :2]), axis=1),
+        perspective_transform[0, 2],
+    )
+    warped_col1 = np.add(
+        np.sum(np.multiply(warped_points, perspective_transform[1, :2]), axis=1),
+        perspective_transform[1, 2],
+    )
+    warped_col2 = np.add(
+        np.sum(np.multiply(warped_points, perspective_transform[2, :2]), axis=1),
+        perspective_transform[2, 2],
+    )
     warped_col0 = np.divide(warped_col0, warped_col2)
     warped_col1 = np.divide(warped_col1, warped_col2)
-    warped_points = np.concatenate(
-        [warped_col0[:, None], warped_col1[:, None]], axis=1)
+    warped_points = np.concatenate([warped_col0[:, None], warped_col1[:, None]], axis=1)
     warped_points_float = warped_points.copy()
     warped_points = warped_points.astype(int)
 
@@ -735,15 +797,30 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7,
             colors[i * cols + j] = col
 
             # Fill the cell
-            cv.fillConvexPoly(img, np.array(
-                [(warped_points[i * (cols + 1) + j, 0],
-                  warped_points[i * (cols + 1) + j, 1]),
-                 (warped_points[i * (cols + 1) + j + 1, 0],
-                  warped_points[i * (cols + 1) + j + 1, 1]),
-                 (warped_points[(i + 1) * (cols + 1) + j + 1, 0],
-                  warped_points[(i + 1) * (cols + 1) + j + 1, 1]),
-                 (warped_points[(i + 1) * (cols + 1) + j, 0],
-                  warped_points[(i + 1) * (cols + 1) + j, 1])]), col)
+            cv.fillConvexPoly(
+                img,
+                np.array(
+                    [
+                        (
+                            warped_points[i * (cols + 1) + j, 0],
+                            warped_points[i * (cols + 1) + j, 1],
+                        ),
+                        (
+                            warped_points[i * (cols + 1) + j + 1, 0],
+                            warped_points[i * (cols + 1) + j + 1, 1],
+                        ),
+                        (
+                            warped_points[(i + 1) * (cols + 1) + j + 1, 0],
+                            warped_points[(i + 1) * (cols + 1) + j + 1, 1],
+                        ),
+                        (
+                            warped_points[(i + 1) * (cols + 1) + j, 0],
+                            warped_points[(i + 1) * (cols + 1) + j, 1],
+                        ),
+                    ]
+                ),
+                col,
+            )
 
     label_segments = np.empty([0, 4], dtype=np.int)
     # Iterate through rows
@@ -751,12 +828,18 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7,
         # Include all the combination of the junctions
         # Iterate through all the combination of junction index in that row
         multi_seg_lst = [
-            np.array([warped_points_float[id1, 0],
-                      warped_points_float[id1, 1],
-                      warped_points_float[id2, 0],
-                      warped_points_float[id2, 1]])[None, ...]
-            for (id1, id2) in combinations(range(
-                row_idx * (cols + 1), (row_idx + 1) * (cols + 1), 1), 2)]
+            np.array(
+                [
+                    warped_points_float[id1, 0],
+                    warped_points_float[id1, 1],
+                    warped_points_float[id2, 0],
+                    warped_points_float[id2, 1],
+                ]
+            )[None, ...]
+            for (id1, id2) in combinations(
+                range(row_idx * (cols + 1), (row_idx + 1) * (cols + 1), 1), 2
+            )
+        ]
         multi_seg = np.concatenate(multi_seg_lst, axis=0)
         label_segments = np.concatenate((label_segments, multi_seg), axis=0)
 
@@ -765,20 +848,31 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7,
         # Include all the combination of the junctions
         # Iterate throuhg all the combination of junction index in that column
         multi_seg_lst = [
-            np.array([warped_points_float[id1, 0],
-                      warped_points_float[id1, 1],
-                      warped_points_float[id2, 0],
-                      warped_points_float[id2, 1]])[None, ...]
-            for (id1, id2) in combinations(range(
-                col_idx, col_idx + ((rows + 1) * (cols + 1)), cols + 1), 2)]
+            np.array(
+                [
+                    warped_points_float[id1, 0],
+                    warped_points_float[id1, 1],
+                    warped_points_float[id2, 0],
+                    warped_points_float[id2, 1],
+                ]
+            )[None, ...]
+            for (id1, id2) in combinations(
+                range(col_idx, col_idx + ((rows + 1) * (cols + 1)), cols + 1), 2
+            )
+        ]
         multi_seg = np.concatenate(multi_seg_lst, axis=0)
         label_segments = np.concatenate((label_segments, multi_seg), axis=0)
 
     label_segments_filtered = np.zeros([0, 4])
     # Define image boundary polygon (in x y manner)
     image_poly = shapely.geometry.Polygon(
-        [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1],
-         [0, img.shape[0] - 1]])
+        [
+            [0, 0],
+            [img.shape[1] - 1, 0],
+            [img.shape[1] - 1, img.shape[0] - 1],
+            [0, img.shape[0] - 1],
+        ]
+    )
     for idx in range(label_segments.shape[0]):
         # Get the line segment
         seg_raw = label_segments[idx, :]
@@ -787,20 +881,21 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7,
         # The line segment is just inside the image.
         if seg.intersection(image_poly) == seg:
             label_segments_filtered = np.concatenate(
-                (label_segments_filtered, seg_raw[None, ...]), axis=0)
+                (label_segments_filtered, seg_raw[None, ...]), axis=0
+            )
 
         # Intersect with the image.
         elif seg.intersects(image_poly):
             # Check intersection
             try:
-                p = np.array(seg.intersection(
-                    image_poly).coords).reshape([-1, 4])
+                p = np.array(seg.intersection(image_poly).coords).reshape([-1, 4])
             # If intersect with eact one point
             except:
                 continue
             segment = p
             label_segments_filtered = np.concatenate(
-                (label_segments_filtered, segment), axis=0)
+                (label_segments_filtered, segment), axis=0
+            )
 
         else:
             continue
@@ -814,8 +909,7 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7,
     label_segments = label_segments[seg_len >= min_label_len, :]
 
     # Get all junctions from label segments
-    junc_points, line_map = get_unique_junctions(label_segments,
-                                                 min_label_len)
+    junc_points, line_map = get_unique_junctions(label_segments, min_label_len)
 
     # Draw lines on the boundaries of the board at random
     nb_rows = random_state.randint(2, rows + 2)
@@ -826,33 +920,52 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7,
         col_idx1 = random_state.randint(cols + 1)
         col_idx2 = random_state.randint(cols + 1)
         col = get_random_color(background_color)
-        cv.line(img, (warped_points[row_idx * (cols + 1) + col_idx1, 0],
-                      warped_points[row_idx * (cols + 1) + col_idx1, 1]),
-                (warped_points[row_idx * (cols + 1) + col_idx2, 0],
-                 warped_points[row_idx * (cols + 1) + col_idx2, 1]),
-                col, thickness)
+        cv.line(
+            img,
+            (
+                warped_points[row_idx * (cols + 1) + col_idx1, 0],
+                warped_points[row_idx * (cols + 1) + col_idx1, 1],
+            ),
+            (
+                warped_points[row_idx * (cols + 1) + col_idx2, 0],
+                warped_points[row_idx * (cols + 1) + col_idx2, 1],
+            ),
+            col,
+            thickness,
+        )
     for _ in range(nb_cols):
         col_idx = random_state.randint(cols + 1)
         row_idx1 = random_state.randint(rows + 1)
         row_idx2 = random_state.randint(rows + 1)
         col = get_random_color(background_color)
-        cv.line(img, (warped_points[row_idx1 * (cols + 1) + col_idx, 0],
-                      warped_points[row_idx1 * (cols + 1) + col_idx, 1]),
-                (warped_points[row_idx2 * (cols + 1) + col_idx, 0],
-                 warped_points[row_idx2 * (cols + 1) + col_idx, 1]),
-                col, thickness)
+        cv.line(
+            img,
+            (
+                warped_points[row_idx1 * (cols + 1) + col_idx, 0],
+                warped_points[row_idx1 * (cols + 1) + col_idx, 1],
+            ),
+            (
+                warped_points[row_idx2 * (cols + 1) + col_idx, 0],
+                warped_points[row_idx2 * (cols + 1) + col_idx, 1],
+            ),
+            col,
+            thickness,
+        )
 
     # Keep only the points inside the image
     points = keep_points_inside(warped_points, img.shape[:2])
-    return {
-        "points": junc_points,
-        "line_map": line_map
-    }
-
-
-def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64,
-                          transform_params=(0.05, 0.15), seed=None):
-    """ Draw stripes in a distorted rectangle
+    return {"points": junc_points, "line_map": line_map}
+
+
+def draw_stripes_multiseg(
+    img,
+    max_nb_cols=13,
+    min_len=0.04,
+    min_label_len=64,
+    transform_params=(0.05, 0.15),
+    seed=None,
+):
+    """Draw stripes in a distorted rectangle
         and output the junctions points + line map.
     Parameters:
       max_nb_cols: maximal number of stripes to be drawn
@@ -868,73 +981,84 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64,
 
     background_color = int(np.mean(img))
     # Create the grid
-    board_size = (int(img.shape[0] * (1 + random_state.rand())),
-                  int(img.shape[1] * (1 + random_state.rand())))
+    board_size = (
+        int(img.shape[0] * (1 + random_state.rand())),
+        int(img.shape[1] * (1 + random_state.rand())),
+    )
 
     # Number of cols
     col = random_state.randint(5, max_nb_cols)
-    cols = np.concatenate([board_size[1] * random_state.rand(col - 1),
-                           np.array([0, board_size[1] - 1])], axis=0)
+    cols = np.concatenate(
+        [board_size[1] * random_state.rand(col - 1), np.array([0, board_size[1] - 1])],
+        axis=0,
+    )
     cols = np.unique(cols.astype(int))
 
     # Remove the indices that are too close
     min_dim = min(img.shape)
 
     # Convert the length constraints to pixels if given as floats
-    if isinstance(min_len, float) and min_len <= 1.:
+    if isinstance(min_len, float) and min_len <= 1.0:
         min_len = int(min_dim * min_len)
-    if isinstance(min_label_len, float) and min_label_len <= 1.:
+    if isinstance(min_label_len, float) and min_label_len <= 1.0:
         min_label_len = int(min_dim * min_label_len)
 
-    cols = cols[(np.concatenate([cols[1:],
-                                 np.array([board_size[1] + min_len])],
-                                axis=0) - cols) >= min_len]
+    cols = cols[
+        (np.concatenate([cols[1:], np.array([board_size[1] + min_len])], axis=0) - cols)
+        >= min_len
+    ]
     # Update the number of cols
     col = cols.shape[0] - 1
     cols = np.reshape(cols, (col + 1, 1))
     cols1 = np.concatenate([cols, np.zeros((col + 1, 1), np.int32)], axis=1)
     cols2 = np.concatenate(
-        [cols, (board_size[0] - 1) * np.ones((col + 1, 1), np.int32)], axis=1)
+        [cols, (board_size[0] - 1) * np.ones((col + 1, 1), np.int32)], axis=1
+    )
     points = np.concatenate([cols1, cols2], axis=0)
 
     # Warp the grid using an affine transformation and a homography
     alpha_affine = np.max(img.shape) * (
-        transform_params[0] + random_state.rand() * transform_params[1])
+        transform_params[0] + random_state.rand() * transform_params[1]
+    )
     center_square = np.float32(img.shape) // 2
     square_size = min(img.shape) // 3
-    pts1 = np.float32([center_square + square_size,
-                       [center_square[0]+square_size,
-                        center_square[1]-square_size],
-                       center_square - square_size,
-                       [center_square[0]-square_size,
-                        center_square[1]+square_size]])
-    pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine,
-                                       size=pts1.shape).astype(np.float32)
+    pts1 = np.float32(
+        [
+            center_square + square_size,
+            [center_square[0] + square_size, center_square[1] - square_size],
+            center_square - square_size,
+            [center_square[0] - square_size, center_square[1] + square_size],
+        ]
+    )
+    pts2 = pts1 + random_state.uniform(
+        -alpha_affine, alpha_affine, size=pts1.shape
+    ).astype(np.float32)
     affine_transform = cv.getAffineTransform(pts1[:3], pts2[:3])
-    pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2,
-                                       size=pts1.shape).astype(np.float32)
+    pts2 = pts1 + random_state.uniform(
+        -alpha_affine / 2, alpha_affine / 2, size=pts1.shape
+    ).astype(np.float32)
     perspective_transform = cv.getPerspectiveTransform(pts1, pts2)
 
     # Apply the affine transformation
-    points = np.transpose(np.concatenate((points,
-                                          np.ones((2 * (col + 1), 1))),
-                                         axis=1))
+    points = np.transpose(np.concatenate((points, np.ones((2 * (col + 1), 1))), axis=1))
     warped_points = np.transpose(np.dot(affine_transform, points))
 
     # Apply the homography
-    warped_col0 = np.add(np.sum(np.multiply(
-        warped_points, perspective_transform[0, :2]), axis=1),
-        perspective_transform[0, 2])
-    warped_col1 = np.add(np.sum(np.multiply(
-        warped_points, perspective_transform[1, :2]), axis=1),
-        perspective_transform[1, 2])
-    warped_col2 = np.add(np.sum(np.multiply(
-        warped_points, perspective_transform[2, :2]), axis=1),
-        perspective_transform[2, 2])
+    warped_col0 = np.add(
+        np.sum(np.multiply(warped_points, perspective_transform[0, :2]), axis=1),
+        perspective_transform[0, 2],
+    )
+    warped_col1 = np.add(
+        np.sum(np.multiply(warped_points, perspective_transform[1, :2]), axis=1),
+        perspective_transform[1, 2],
+    )
+    warped_col2 = np.add(
+        np.sum(np.multiply(warped_points, perspective_transform[2, :2]), axis=1),
+        perspective_transform[2, 2],
+    )
     warped_col0 = np.divide(warped_col0, warped_col2)
     warped_col1 = np.divide(warped_col1, warped_col2)
-    warped_points = np.concatenate(
-        [warped_col0[:, None], warped_col1[:, None]], axis=1)
+    warped_points = np.concatenate([warped_col0[:, None], warped_col1[:, None]], axis=1)
     warped_points_float = warped_points.copy()
     warped_points = warped_points.astype(int)
 
@@ -944,15 +1068,18 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64,
     for i in range(col):
         # Fill the color
         color = (color + 128 + random_state.randint(-30, 30)) % 256
-        cv.fillConvexPoly(img, np.array([(warped_points[i, 0],
-                                          warped_points[i, 1]),
-                                         (warped_points[i+1, 0],
-                                          warped_points[i+1, 1]),
-                                         (warped_points[i+col+2, 0],
-                                          warped_points[i+col+2, 1]),
-                                         (warped_points[i+col+1, 0],
-                                          warped_points[i+col+1, 1])]),
-                          color)
+        cv.fillConvexPoly(
+            img,
+            np.array(
+                [
+                    (warped_points[i, 0], warped_points[i, 1]),
+                    (warped_points[i + 1, 0], warped_points[i + 1, 1]),
+                    (warped_points[i + col + 2, 0], warped_points[i + col + 2, 1]),
+                    (warped_points[i + col + 1, 0], warped_points[i + col + 1, 1]),
+                ]
+            ),
+            color,
+        )
 
     segments = np.zeros([0, 4])
     row = 1  # in stripes case
@@ -960,27 +1087,39 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64,
     for row_idx in range(row + 1):
         # Include all the combinations of the junctions
         # Iterate through all the combinations of junction indices in that row
-        multi_seg_lst = [np.array(
-            [warped_points_float[id1, 0],
-             warped_points_float[id1, 1],
-             warped_points_float[id2, 0],
-             warped_points_float[id2, 1]])[None, ...]
-             for (id1, id2) in combinations(range(
-                 row_idx * (col + 1), (row_idx + 1) * (col + 1), 1), 2)]
+        multi_seg_lst = [
+            np.array(
+                [
+                    warped_points_float[id1, 0],
+                    warped_points_float[id1, 1],
+                    warped_points_float[id2, 0],
+                    warped_points_float[id2, 1],
+                ]
+            )[None, ...]
+            for (id1, id2) in combinations(
+                range(row_idx * (col + 1), (row_idx + 1) * (col + 1), 1), 2
+            )
+        ]
         multi_seg = np.concatenate(multi_seg_lst, axis=0)
         segments = np.concatenate((segments, multi_seg), axis=0)
 
     # Iterate through columns
-    for col_idx in range(col + 1): # for 5 columns, we will have 5 + 1 edges.
+    for col_idx in range(col + 1):  # for 5 columns, we will have 5 + 1 edges.
         # Include all the combinations of the junctions
         # Iterate through all the combinations of junction indices in that column
-        multi_seg_lst = [np.array(
-            [warped_points_float[id1, 0],
-             warped_points_float[id1, 1],
-             warped_points_float[id2, 0],
-             warped_points_float[id2, 1]])[None, ...]
-             for (id1, id2) in combinations(range(
-                 col_idx, col_idx + (row * col) + 2, col + 1), 2)]
+        multi_seg_lst = [
+            np.array(
+                [
+                    warped_points_float[id1, 0],
+                    warped_points_float[id1, 1],
+                    warped_points_float[id2, 0],
+                    warped_points_float[id2, 1],
+                ]
+            )[None, ...]
+            for (id1, id2) in combinations(
+                range(col_idx, col_idx + (row * col) + 2, col + 1), 2
+            )
+        ]
         multi_seg = np.concatenate(multi_seg_lst, axis=0)
         segments = np.concatenate((segments, multi_seg), axis=0)
 
@@ -988,8 +1127,13 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64,
     segments_new = np.zeros([0, 4])
     # Define image boundary polygon (in x y manner)
     image_poly = shapely.geometry.Polygon(
-        [[0, 0], [img.shape[1]-1, 0], [img.shape[1]-1, img.shape[0]-1],
-         [0, img.shape[0]-1]])
+        [
+            [0, 0],
+            [img.shape[1] - 1, 0],
+            [img.shape[1] - 1, img.shape[0] - 1],
+            [0, img.shape[0] - 1],
+        ]
+    )
     for idx in range(segments.shape[0]):
         # Get the line segment
         seg_raw = segments[idx, :]
@@ -997,15 +1141,13 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64,
 
         # The line segment is just inside the image.
         if seg.intersection(image_poly) == seg:
-            segments_new = np.concatenate(
-                (segments_new, seg_raw[None, ...]), axis=0)
+            segments_new = np.concatenate((segments_new, seg_raw[None, ...]), axis=0)
 
         # Intersect with the image.
         elif seg.intersects(image_poly):
             # Check intersection
             try:
-                p = np.array(
-                    seg.intersection(image_poly).coords).reshape([-1, 4])
+                p = np.array(seg.intersection(image_poly).coords).reshape([-1, 4])
             # If intersect at exact one point, just continue.
             except:
                 continue
@@ -1025,7 +1167,8 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64,
 
     # Get all junctions from label segments
     junctions_all = np.concatenate(
-        (label_segments[:, :2], label_segments[:, 2:]), axis=0)
+        (label_segments[:, :2], label_segments[:, 2:]), axis=0
+    )
     if junctions_all.shape[0] == 0:
         junc_points = None
         line_map = None
@@ -1045,32 +1188,44 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64,
         col_idx1 = random_state.randint(col + 1)
         col_idx2 = random_state.randint(col + 1)
         color = get_random_color(background_color)
-        cv.line(img, (warped_points[row_idx + col_idx1, 0],
-                      warped_points[row_idx + col_idx1, 1]),
-                (warped_points[row_idx + col_idx2, 0],
-                 warped_points[row_idx + col_idx2, 1]),
-                color, thickness)
+        cv.line(
+            img,
+            (
+                warped_points[row_idx + col_idx1, 0],
+                warped_points[row_idx + col_idx1, 1],
+            ),
+            (
+                warped_points[row_idx + col_idx2, 0],
+                warped_points[row_idx + col_idx2, 1],
+            ),
+            color,
+            thickness,
+        )
 
     for _ in range(nb_cols):
         col_idx = random_state.randint(col + 1)
         color = get_random_color(background_color)
-        cv.line(img, (warped_points[col_idx, 0],
-                      warped_points[col_idx, 1]),
-                (warped_points[col_idx + col + 1, 0],
-                 warped_points[col_idx + col + 1, 1]),
-                color, thickness)
+        cv.line(
+            img,
+            (warped_points[col_idx, 0], warped_points[col_idx, 1]),
+            (warped_points[col_idx + col + 1, 0], warped_points[col_idx + col + 1, 1]),
+            color,
+            thickness,
+        )
 
     # Keep only the points inside the image
     # points = keep_points_inside(warped_points, img.shape[:2])
-    return {
-        "points": junc_points,
-        "line_map": line_map
-    }
+    return {"points": junc_points, "line_map": line_map}
 
 
-def draw_cube(img, min_size_ratio=0.2, min_label_len=64,
-              scale_interval=(0.4, 0.6), trans_interval=(0.5, 0.2)):
-    """ Draw a 2D projection of a cube and output the visible juntions.
+def draw_cube(
+    img,
+    min_size_ratio=0.2,
+    min_label_len=64,
+    scale_interval=(0.4, 0.6),
+    trans_interval=(0.5, 0.2),
+):
+    """Draw a 2D projection of a cube and output the visible juntions.
     Parameters:
       min_size_ratio: min(img.shape) * min_size_ratio is the smallest
                       achievable cube side size
@@ -1088,46 +1243,68 @@ def draw_cube(img, min_size_ratio=0.2, min_label_len=64,
     lx = min_side + random_state.rand() * 2 * min_dim / 3  # dims of the cube
     ly = min_side + random_state.rand() * 2 * min_dim / 3
     lz = min_side + random_state.rand() * 2 * min_dim / 3
-    cube = np.array([[0, 0, 0],
-                     [lx, 0, 0],
-                     [0, ly, 0],
-                     [lx, ly, 0],
-                     [0, 0, lz],
-                     [lx, 0, lz],
-                     [0, ly, lz],
-                     [lx, ly, lz]])
-    rot_angles = random_state.rand(3) * 3 * math.pi / 10. + math.pi / 10.
-    rotation_1 = np.array([[math.cos(rot_angles[0]),
-                            -math.sin(rot_angles[0]), 0],
-                           [math.sin(rot_angles[0]),
-                            math.cos(rot_angles[0]), 0],
-                           [0, 0, 1]])
-    rotation_2 = np.array([[1, 0, 0],
-                           [0, math.cos(rot_angles[1]),
-                            -math.sin(rot_angles[1])],
-                           [0, math.sin(rot_angles[1]),
-                            math.cos(rot_angles[1])]])
-    rotation_3 = np.array([[math.cos(rot_angles[2]), 0,
-                            -math.sin(rot_angles[2])],
-                           [0, 1, 0],
-                           [math.sin(rot_angles[2]), 0,
-                            math.cos(rot_angles[2])]])
-    scaling = np.array([[scale_interval[0] +
-                         random_state.rand() * scale_interval[1], 0, 0],
-                        [0, scale_interval[0] +
-                         random_state.rand() * scale_interval[1], 0],
-                        [0, 0, scale_interval[0] +
-                         random_state.rand() * scale_interval[1]]])
-    trans = np.array([img.shape[1] * trans_interval[0] +
-                      random_state.randint(-img.shape[1] * trans_interval[1],
-                                           img.shape[1] * trans_interval[1]),
-                      img.shape[0] * trans_interval[0] +
-                      random_state.randint(-img.shape[0] * trans_interval[1],
-                                           img.shape[0] * trans_interval[1]),
-                      0])
+    cube = np.array(
+        [
+            [0, 0, 0],
+            [lx, 0, 0],
+            [0, ly, 0],
+            [lx, ly, 0],
+            [0, 0, lz],
+            [lx, 0, lz],
+            [0, ly, lz],
+            [lx, ly, lz],
+        ]
+    )
+    rot_angles = random_state.rand(3) * 3 * math.pi / 10.0 + math.pi / 10.0
+    rotation_1 = np.array(
+        [
+            [math.cos(rot_angles[0]), -math.sin(rot_angles[0]), 0],
+            [math.sin(rot_angles[0]), math.cos(rot_angles[0]), 0],
+            [0, 0, 1],
+        ]
+    )
+    rotation_2 = np.array(
+        [
+            [1, 0, 0],
+            [0, math.cos(rot_angles[1]), -math.sin(rot_angles[1])],
+            [0, math.sin(rot_angles[1]), math.cos(rot_angles[1])],
+        ]
+    )
+    rotation_3 = np.array(
+        [
+            [math.cos(rot_angles[2]), 0, -math.sin(rot_angles[2])],
+            [0, 1, 0],
+            [math.sin(rot_angles[2]), 0, math.cos(rot_angles[2])],
+        ]
+    )
+    scaling = np.array(
+        [
+            [scale_interval[0] + random_state.rand() * scale_interval[1], 0, 0],
+            [0, scale_interval[0] + random_state.rand() * scale_interval[1], 0],
+            [0, 0, scale_interval[0] + random_state.rand() * scale_interval[1]],
+        ]
+    )
+    trans = np.array(
+        [
+            img.shape[1] * trans_interval[0]
+            + random_state.randint(
+                -img.shape[1] * trans_interval[1], img.shape[1] * trans_interval[1]
+            ),
+            img.shape[0] * trans_interval[0]
+            + random_state.randint(
+                -img.shape[0] * trans_interval[1], img.shape[0] * trans_interval[1]
+            ),
+            0,
+        ]
+    )
     cube = trans + np.transpose(
-        np.dot(scaling, np.dot(rotation_1,
-        np.dot(rotation_2, np.dot(rotation_3, np.transpose(cube))))))
+        np.dot(
+            scaling,
+            np.dot(
+                rotation_1, np.dot(rotation_2, np.dot(rotation_3, np.transpose(cube)))
+            ),
+        )
+    )
 
     # The hidden corner is 0 by construction
     # The front one is 7
@@ -1145,18 +1322,26 @@ def draw_cube(img, min_size_ratio=0.2, min_label_len=64,
         face = faces[face_idx, :]
         # Brute-forcely expand all the segments
         segment = np.array(
-            [np.concatenate((cube[face[0]], cube[face[1]]), axis=0),
-             np.concatenate((cube[face[1]], cube[face[2]]), axis=0),
-             np.concatenate((cube[face[2]], cube[face[3]]), axis=0),
-             np.concatenate((cube[face[3]], cube[face[0]]), axis=0)])
+            [
+                np.concatenate((cube[face[0]], cube[face[1]]), axis=0),
+                np.concatenate((cube[face[1]], cube[face[2]]), axis=0),
+                np.concatenate((cube[face[2]], cube[face[3]]), axis=0),
+                np.concatenate((cube[face[3]], cube[face[0]]), axis=0),
+            ]
+        )
         segments = np.concatenate((segments, segment), axis=0)
 
     # Select and refine the segments
     segments_new = np.zeros([0, 4])
     # Define image boundary polygon (in x y manner)
     image_poly = shapely.geometry.Polygon(
-        [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1],
-         [0, img.shape[0] - 1]])
+        [
+            [0, 0],
+            [img.shape[1] - 1, 0],
+            [img.shape[1] - 1, img.shape[0] - 1],
+            [0, img.shape[0] - 1],
+        ]
+    )
     for idx in range(segments.shape[0]):
         # Get the line segment
         seg_raw = segments[idx, :]
@@ -1164,14 +1349,12 @@ def draw_cube(img, min_size_ratio=0.2, min_label_len=64,
 
         # The line segment is just inside the image.
         if seg.intersection(image_poly) == seg:
-            segments_new = np.concatenate(
-                (segments_new, seg_raw[None, ...]), axis=0)
+            segments_new = np.concatenate((segments_new, seg_raw[None, ...]), axis=0)
 
         # Intersect with the image.
         elif seg.intersects(image_poly):
             try:
-                p = np.array(
-                    seg.intersection(image_poly).coords).reshape([-1, 4])
+                p = np.array(seg.intersection(image_poly).coords).reshape([-1, 4])
             except:
                 continue
             segment = p
@@ -1190,7 +1373,8 @@ def draw_cube(img, min_size_ratio=0.2, min_label_len=64,
 
     # Get all junctions from label segments
     junctions_all = np.concatenate(
-        (label_segments[:, :2], label_segments[:, 2:]), axis=0)
+        (label_segments[:, :2], label_segments[:, 2:]), axis=0
+    )
     if junctions_all.shape[0] == 0:
         junc_points = None
         line_map = None
@@ -1204,29 +1388,25 @@ def draw_cube(img, min_size_ratio=0.2, min_label_len=64,
     # Fill the faces and draw the contours
     col_face = get_random_color(background_color)
     for i in [0, 1, 2]:
-        cv.fillPoly(img, [cube[faces[i]].reshape((-1, 1, 2))],
-                    col_face)
+        cv.fillPoly(img, [cube[faces[i]].reshape((-1, 1, 2))], col_face)
     thickness = random_state.randint(min_dim * 0.003, min_dim * 0.015)
     for i in [0, 1, 2]:
         for j in [0, 1, 2, 3]:
-            col_edge = (col_face + 128
-                        + random_state.randint(-64, 64))\
-                        % 256  # color that constrats with the face color
-            cv.line(img, (cube[faces[i][j], 0], cube[faces[i][j], 1]),
-                    (cube[faces[i][(j + 1) % 4], 0],
-                     cube[faces[i][(j + 1) % 4], 1]),
-                    col_edge, thickness)
+            col_edge = (
+                col_face + 128 + random_state.randint(-64, 64)
+            ) % 256  # color that contrasts with the face color
+            cv.line(
+                img,
+                (cube[faces[i][j], 0], cube[faces[i][j], 1]),
+                (cube[faces[i][(j + 1) % 4], 0], cube[faces[i][(j + 1) % 4], 1]),
+                col_edge,
+                thickness,
+            )
 
-    return {
-        "points": junc_points,
-        "line_map": line_map
-    }
+    return {"points": junc_points, "line_map": line_map}
 
 
 def gaussian_noise(img):
-    """ Apply random noise to the image. """
+    """Apply random noise to the image."""
     cv.randu(img, 0, 255)
-    return {
-        "points": None,
-        "line_map": None
-    }
+    return {"points": None, "line_map": None}
diff --git a/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py b/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py
index d9338abb169f7a86f3c6e702a031e1c0de86c339..b9c63613b57f9064333bf80bd59fa6553f3ccb8e 100644
--- a/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py
+++ b/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py
@@ -12,11 +12,21 @@ import shapely.geometry
 
 
 def sample_homography(
-        shape, perspective=True, scaling=True, rotation=True,
-        translation=True, n_scales=5, n_angles=25, scaling_amplitude=0.1,
-        perspective_amplitude_x=0.1, perspective_amplitude_y=0.1,
-        patch_ratio=0.5, max_angle=pi/2, allow_artifacts=False,
-        translation_overflow=0.):
+    shape,
+    perspective=True,
+    scaling=True,
+    rotation=True,
+    translation=True,
+    n_scales=5,
+    n_angles=25,
+    scaling_amplitude=0.1,
+    perspective_amplitude_x=0.1,
+    perspective_amplitude_y=0.1,
+    patch_ratio=0.5,
+    max_angle=pi / 2,
+    allow_artifacts=False,
+    translation_overflow=0.0,
+):
     """
     Computes the homography transformation between a random patch in the
     original image and a warped projection with the same image size.
@@ -51,11 +61,12 @@ def sample_homography(
         shape = np.array(shape)
 
     # Corners of the output image
-    pts1 = np.array([[0., 0.], [0., 1.], [1., 1.], [1., 0.]])
+    pts1 = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0]])
     # Corners of the input patch
     margin = (1 - patch_ratio) / 2
-    pts2 = margin + np.array([[0, 0], [0, patch_ratio],
-                             [patch_ratio, patch_ratio], [patch_ratio, 0]])
+    pts2 = margin + np.array(
+        [[0, 0], [0, patch_ratio], [patch_ratio, patch_ratio], [patch_ratio, 0]]
+    )
 
     # Random perspective and affine perturbations
     if perspective:
@@ -65,25 +76,25 @@ def sample_homography(
 
         # normal distribution with mean=0, std=perspective_amplitude_y/2
         perspective_displacement = np.random.normal(
-            0., perspective_amplitude_y/2, [1])
-        h_displacement_left = np.random.normal(
-            0., perspective_amplitude_x/2, [1])
-        h_displacement_right = np.random.normal(
-            0., perspective_amplitude_x/2, [1])
-        pts2 += np.stack([np.concatenate([h_displacement_left,
-                                          perspective_displacement], 0),
-                          np.concatenate([h_displacement_left,
-                                          -perspective_displacement], 0),
-                          np.concatenate([h_displacement_right,
-                                          perspective_displacement], 0),
-                          np.concatenate([h_displacement_right,
-                                          -perspective_displacement], 0)])
+            0.0, perspective_amplitude_y / 2, [1]
+        )
+        h_displacement_left = np.random.normal(0.0, perspective_amplitude_x / 2, [1])
+        h_displacement_right = np.random.normal(0.0, perspective_amplitude_x / 2, [1])
+        pts2 += np.stack(
+            [
+                np.concatenate([h_displacement_left, perspective_displacement], 0),
+                np.concatenate([h_displacement_left, -perspective_displacement], 0),
+                np.concatenate([h_displacement_right, perspective_displacement], 0),
+                np.concatenate([h_displacement_right, -perspective_displacement], 0),
+            ]
+        )
 
     # Random scaling: sample several scales, check collision with borders,
     # randomly pick a valid one
     if scaling:
         scales = np.concatenate(
-            [[1.], np.random.normal(1, scaling_amplitude/2, [n_scales])], 0)
+            [[1.0], np.random.normal(1, scaling_amplitude / 2, [n_scales])], 0
+        )
         center = np.mean(pts2, axis=0, keepdims=True)
         scaled = (pts2 - center)[None, ...] * scales[..., None, None] + center
         # all scales are valid except scale=1
@@ -91,17 +102,27 @@ def sample_homography(
             valid = np.array(range(n_scales))
         # Check the valid scale
         else:
-            valid = np.where(np.all((scaled >= 0.)
-                             & (scaled < 1.), (1, 2)))[0]
+            valid = np.where(np.all((scaled >= 0.0) & (scaled < 1.0), (1, 2)))[0]
         # No valid scale found => recursively call
         if valid.shape[0] == 0:
             return sample_homography(
-                shape, perspective, scaling, rotation, translation,
-                n_scales, n_angles, scaling_amplitude, 
-                perspective_amplitude_x, perspective_amplitude_y,
-                patch_ratio, max_angle, allow_artifacts, translation_overflow)
-
-        idx = valid[np.random.uniform(0., valid.shape[0], ()).astype(np.int32)]
+                shape,
+                perspective,
+                scaling,
+                rotation,
+                translation,
+                n_scales,
+                n_angles,
+                scaling_amplitude,
+                perspective_amplitude_x,
+                perspective_amplitude_y,
+                patch_ratio,
+                max_angle,
+                allow_artifacts,
+                translation_overflow,
+            )
+
+        idx = valid[np.random.uniform(0.0, valid.shape[0], ()).astype(np.int32)]
         pts2 = scaled[idx]
 
         # Additionally save and return the selected scale.
@@ -113,39 +134,60 @@ def sample_homography(
         if allow_artifacts:
             t_min += translation_overflow
             t_max += translation_overflow
-        pts2 += (np.stack([np.random.uniform(-t_min[0], t_max[0], ()),
-                           np.random.uniform(-t_min[1],
-                                             t_max[1], ())]))[None, ...]
+        pts2 += (
+            np.stack(
+                [
+                    np.random.uniform(-t_min[0], t_max[0], ()),
+                    np.random.uniform(-t_min[1], t_max[1], ()),
+                ]
+            )
+        )[None, ...]
 
     # Random rotation: sample several rotations, check collision with borders,
     # randomly pick a valid one
     if rotation:
         angles = np.linspace(-max_angle, max_angle, n_angles)
         # in case no rotation is valid
-        angles = np.concatenate([[0.], angles], axis=0)
+        angles = np.concatenate([[0.0], angles], axis=0)
         center = np.mean(pts2, axis=0, keepdims=True)
-        rot_mat = np.reshape(np.stack(
-            [np.cos(angles), -np.sin(angles),
-             np.sin(angles), np.cos(angles)], axis=1), [-1, 2, 2])
-        rotated = np.matmul(
-                np.tile((pts2 - center)[None, ...], [n_angles+1, 1, 1]),
-                rot_mat) + center
+        rot_mat = np.reshape(
+            np.stack(
+                [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)],
+                axis=1,
+            ),
+            [-1, 2, 2],
+        )
+        rotated = (
+            np.matmul(
+                np.tile((pts2 - center)[None, ...], [n_angles + 1, 1, 1]), rot_mat
+            )
+            + center
+        )
         if allow_artifacts:
             # All angles are valid, except angle=0
             valid = np.array(range(n_angles))
         else:
-            valid = np.where(np.all((rotated >= 0.)
-                             & (rotated < 1.), axis=(1, 2)))[0]
-        
+            valid = np.where(np.all((rotated >= 0.0) & (rotated < 1.0), axis=(1, 2)))[0]
+
         if valid.shape[0] == 0:
             return sample_homography(
-                shape, perspective, scaling, rotation, translation,
-                n_scales, n_angles, scaling_amplitude, 
-                perspective_amplitude_x, perspective_amplitude_y,
-                patch_ratio, max_angle, allow_artifacts, translation_overflow)
-
-        idx = valid[np.random.uniform(0., valid.shape[0],
-                                      ()).astype(np.int32)]
+                shape,
+                perspective,
+                scaling,
+                rotation,
+                translation,
+                n_scales,
+                n_angles,
+                scaling_amplitude,
+                perspective_amplitude_x,
+                perspective_amplitude_y,
+                patch_ratio,
+                max_angle,
+                allow_artifacts,
+                translation_overflow,
+            )
+
+        idx = valid[np.random.uniform(0.0, valid.shape[0], ()).astype(np.int32)]
         pts2 = rotated[idx]
 
     # Rescale to actual size
@@ -153,27 +195,33 @@ def sample_homography(
     pts1 *= shape[None, ...]
     pts2 *= shape[None, ...]
 
-    def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
+    def ax(p, q):
+        return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
 
-    def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
+    def ay(p, q):
+        return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
 
-    a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4)
-                      for f in (ax, ay)], axis=0)
-    p_mat = np.transpose(np.stack([[pts2[i][j] for i in range(4)
-                                    for j in range(2)]], axis=0))
+    a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
+    p_mat = np.transpose(
+        np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
+    )
     homo_vec, _, _, _ = np.linalg.lstsq(a_mat, p_mat, rcond=None)
 
     # Compose the homography vector back to matrix
-    homo_mat = np.concatenate([
-        homo_vec[0:3, 0][None, ...], homo_vec[3:6, 0][None, ...],
-        np.concatenate((homo_vec[6], homo_vec[7], [1]),
-                       axis=0)[None, ...]], axis=0)
+    homo_mat = np.concatenate(
+        [
+            homo_vec[0:3, 0][None, ...],
+            homo_vec[3:6, 0][None, ...],
+            np.concatenate((homo_vec[6], homo_vec[7], [1]), axis=0)[None, ...],
+        ],
+        axis=0,
+    )
 
     return homo_mat, selected_scale
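A short sketch of consuming the sampled homography, mirroring how homography_transform.__call__ below uses it (the random image is only a placeholder):

import cv2
import numpy as np

image = np.random.rand(240, 320).astype(np.float32)
homo, scale = sample_homography(image.shape[:2], patch_ratio=0.8, allow_artifacts=True)
warped = cv2.warpPerspective(
    image, homo, (image.shape[1], image.shape[0]), flags=cv2.INTER_LINEAR
)
print(warped.shape, scale)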
 
 
 def convert_to_line_segments(junctions, line_map):
-    """ Convert junctions and line map to line segments. """
+    """Convert junctions and line map to line segments."""
     # Copy the line map
     line_map_tmp = copy.copy(line_map)
 
@@ -188,9 +236,9 @@ def convert_to_line_segments(junctions, line_map):
                 p1 = junctions[idx, :]
                 p2 = junctions[idx2, :]
                 line_segments = np.concatenate(
-                    (line_segments,
-                     np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]),
-                    axis=0)
+                    (line_segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]),
+                    axis=0,
+                )
                 # Update line_map
                 line_map_tmp[idx, idx2] = 0
                 line_map_tmp[idx2, idx] = 0
@@ -198,46 +246,50 @@ def convert_to_line_segments(junctions, line_map):
     return line_segments
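A tiny sketch of the expected shapes (illustrative values): three junctions in the (h, w) convention, with junction 0 connected to junctions 1 and 2.

import numpy as np

junctions = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])  # (N, 2), rows are (h, w)
line_map = np.array([[0, 1, 1],
                     [1, 0, 0],
                     [1, 0, 0]])                               # (N, N) adjacency matrix
segments = convert_to_line_segments(junctions, line_map)
print(segments)  # (2, 4) array, each row [h1, w1, h2, w2]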
 
 
-def compute_valid_mask(image_size, homography,
-                       border_margin, valid_mask=None):
+def compute_valid_mask(image_size, homography, border_margin, valid_mask=None):
     # Warp the mask
     if valid_mask is None:
         initial_mask = np.ones(image_size)
     else:
         initial_mask = valid_mask
     mask = cv2.warpPerspective(
-        initial_mask, homography, (image_size[1], image_size[0]),
-        flags=cv2.INTER_NEAREST)
+        initial_mask,
+        homography,
+        (image_size[1], image_size[0]),
+        flags=cv2.INTER_NEAREST,
+    )
 
     # Optionally perform erosion
     if border_margin > 0:
-        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
-                                           (border_margin*2, )*2)
+        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (border_margin * 2,) * 2)
         mask = cv2.erode(mask, kernel)
-    
+
     # Perform dilation if border_margin is negative
     if border_margin < 0:
-        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
-                                           (abs(int(border_margin))*2, )*2)
+        kernel = cv2.getStructuringElement(
+            cv2.MORPH_ELLIPSE, (abs(int(border_margin)) * 2,) * 2
+        )
         mask = cv2.dilate(mask, kernel)
 
     return mask
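For example, a pure translation leaves a band of invalid pixels, and a positive border_margin erodes the valid region a few pixels further (values are illustrative):

import numpy as np

# A translation-only homography: shift content 20 px right and 10 px down.
H = np.array([[1.0, 0.0, 20.0],
              [0.0, 1.0, 10.0],
              [0.0, 0.0, 1.0]])
mask = compute_valid_mask((240, 320), H, border_margin=3)
print(mask.shape)  # (240, 320); zeros where the warp leaves the image, shrunk by ~3 px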
 
 
 def warp_line_segment(line_segments, homography, image_size):
-    """ Warp the line segments using a homography. """
+    """Warp the line segments using a homography."""
     # Separate the line segments into 2N points to apply the matrix operation
     num_segments = line_segments.shape[0]
 
     junctions = np.concatenate(
-        (line_segments[:, :2], # The first junction of each segment.
-        line_segments[:, 2:]), # The second junction of each segment.
-        axis=0)
+        (
+            line_segments[:, :2],  # The first junction of each segment.
+            line_segments[:, 2:],  # The second junction of each segment.
+        ),
+        axis=0,
+    )
     # Convert to homogeneous coordinates
     # Flip the junctions before converting to homogeneous (xy format)
     junctions = np.flip(junctions, axis=1)
-    junctions = np.concatenate((junctions, np.ones([2*num_segments, 1])),
-                               axis=1)
+    junctions = np.concatenate((junctions, np.ones([2 * num_segments, 1])), axis=1)
     warped_junctions = np.matmul(homography, junctions.T).T
 
     # Convert back to segments
@@ -245,41 +297,43 @@ def warp_line_segment(line_segments, homography, image_size):
     # (Convert back to hw format)
     warped_junctions = np.flip(warped_junctions, axis=1)
     warped_segments = np.concatenate(
-        (warped_junctions[:num_segments, :],
-         warped_junctions[num_segments:, :]),
-        axis=1
+        (warped_junctions[:num_segments, :], warped_junctions[num_segments:, :]), axis=1
     )
 
     # Check the intersections with the boundary
     warped_segments_new = np.zeros([0, 4])
     image_poly = shapely.geometry.Polygon(
-        [[0, 0], [image_size[1]-1, 0], [image_size[1]-1, image_size[0]-1],
-        [0, image_size[0]-1]])
+        [
+            [0, 0],
+            [image_size[1] - 1, 0],
+            [image_size[1] - 1, image_size[0] - 1],
+            [0, image_size[0] - 1],
+        ]
+    )
     for idx in range(warped_segments.shape[0]):
         # Get the line segment
-        seg_raw = warped_segments[idx, :]   # in HW format.
+        seg_raw = warped_segments[idx, :]  # in HW format.
         # Convert to shapely line (flip to xy format)
-        seg = shapely.geometry.LineString([np.flip(seg_raw[:2]), 
-                                           np.flip(seg_raw[2:])])
+        seg = shapely.geometry.LineString([np.flip(seg_raw[:2]), np.flip(seg_raw[2:])])
 
         # The line segment is just inside the image.
         if seg.intersection(image_poly) == seg:
-            warped_segments_new = np.concatenate((warped_segments_new,
-                                                  seg_raw[None, ...]), axis=0)
-        
+            warped_segments_new = np.concatenate(
+                (warped_segments_new, seg_raw[None, ...]), axis=0
+            )
+
         # Intersect with the image.
         elif seg.intersects(image_poly):
             # Check intersection
             try:
-                p = np.array(
-                    seg.intersection(image_poly).coords).reshape([-1, 4])
+                p = np.array(seg.intersection(image_poly).coords).reshape([-1, 4])
             # If intersect at exact one point, just continue.
             except:
                 continue
-            segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:],
-                                     axis=0)])[None, ...]
-            warped_segments_new = np.concatenate(
-                (warped_segments_new, segment), axis=0)
+            segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:], axis=0)])[
+                None, ...
+            ]
+            warped_segments_new = np.concatenate((warped_segments_new, segment), axis=0)
 
         else:
             continue
@@ -289,9 +343,9 @@ def warp_line_segment(line_segments, homography, image_size):
 
 
 class homography_transform(object):
-    """ # Homography transformations. """
-    def __init__(self, image_size, homograpy_config,
-                 border_margin=0, min_label_len=20):
+    """# Homography transformations."""
+
+    def __init__(self, image_size, homograpy_config, border_margin=0, min_label_len=20):
         self.homo_config = homograpy_config
         self.image_size = image_size
         self.target_size = (self.image_size[1], self.image_size[0])
@@ -300,31 +354,33 @@ class homography_transform(object):
             raise ValueError("[Error] min_label_len should be in pixels.")
         self.min_label_len = min_label_len
 
-    def __call__(self, input_image, junctions, line_map,
-                 valid_mask=None, homo=None, scale=None):
+    def __call__(
+        self, input_image, junctions, line_map, valid_mask=None, homo=None, scale=None
+    ):
         # Sample one random homography or use the given one
         if homo is None or scale is None:
-            homo, scale = sample_homography(self.image_size,
-                                            **self.homo_config)
+            homo, scale = sample_homography(self.image_size, **self.homo_config)
 
         # Warp the image
         warped_image = cv2.warpPerspective(
-            input_image, homo, self.target_size, flags=cv2.INTER_LINEAR)
-        
-        valid_mask = compute_valid_mask(self.image_size, homo,
-                                        self.border_margin, valid_mask)
+            input_image, homo, self.target_size, flags=cv2.INTER_LINEAR
+        )
+
+        valid_mask = compute_valid_mask(
+            self.image_size, homo, self.border_margin, valid_mask
+        )
 
         # Convert junctions and line_map back to line segments
         line_segments = convert_to_line_segments(junctions, line_map)
 
         # Warp the segments and check the length.
         # Adjust the min_label_length
-        warped_segments = warp_line_segment(line_segments, homo,
-                                             self.image_size)
+        warped_segments = warp_line_segment(line_segments, homo, self.image_size)
 
         # Convert back to junctions and line_map
-        junctions_new = np.concatenate((warped_segments[:, :2],
-                                        warped_segments[:, 2:]), axis=0)
+        junctions_new = np.concatenate(
+            (warped_segments[:, :2], warped_segments[:, 2:]), axis=0
+        )
         if junctions_new.shape[0] == 0:
             junctions_new = np.zeros([0, 2])
             line_map = np.zeros([0, 0])
@@ -333,11 +389,11 @@ class homography_transform(object):
             junctions_new = np.unique(junctions_new, axis=0)
 
             # Generate line map from points and segments
-            line_map = get_line_map(junctions_new,
-                                    warped_segments).astype(np.int)
+            line_map = get_line_map(junctions_new, warped_segments).astype(np.int)
             # Compute the heatmap
-            warped_heatmap = get_line_heatmap(np.flip(junctions_new, axis=1),
-                                              line_map, self.image_size)
+            warped_heatmap = get_line_heatmap(
+                np.flip(junctions_new, axis=1), line_map, self.image_size
+            )
 
         return {
             "junctions": junctions_new,
@@ -346,5 +402,5 @@ class homography_transform(object):
             "line_map": line_map,
             "warped_heatmap": warped_heatmap,
             "homo": homo,
-            "scale": scale
+            "scale": scale,
         }
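An end-to-end usage sketch for the class above; the config keys simply forward to sample_homography, and the concrete values are illustrative assumptions:

import numpy as np

homo_cfg = {
    "perspective": True,
    "scaling": True,
    "rotation": True,
    "translation": True,
    "patch_ratio": 0.8,
    "allow_artifacts": True,
}
transform = homography_transform((240, 320), homo_cfg, border_margin=3)

image = np.random.rand(240, 320).astype(np.float32)
junctions = np.array([[50.0, 60.0], [50.0, 200.0]])  # (h, w) rows, one segment
line_map = np.array([[0, 1], [1, 0]])
out = transform(image, junctions, line_map)
print(sorted(out.keys()))
print(out["junctions"].shape, out["homo"].shape)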
diff --git a/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py b/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py
index 8fa44bf0efa93a47e5f8012988058f1cbd49324f..5f41192cd2cba7b47939f031027e8dce6e1a406f 100644
--- a/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py
+++ b/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py
@@ -9,17 +9,18 @@ import cv2
 
 # List all the available augmentations
 available_augmentations = [
-    'additive_gaussian_noise',
-    'additive_speckle_noise',
-    'random_brightness',
-    'random_contrast',
-    'additive_shade',
-    'motion_blur'
+    "additive_gaussian_noise",
+    "additive_speckle_noise",
+    "random_brightness",
+    "random_contrast",
+    "additive_shade",
+    "motion_blur",
 ]
 
 
 class additive_gaussian_noise(object):
-    """ Additive gaussian noise. """
+    """Additive gaussian noise."""
+
     def __init__(self, stddev_range=None):
         # If std is not given, use the default setting
         if stddev_range is None:
@@ -30,14 +31,15 @@ class additive_gaussian_noise(object):
     def __call__(self, input_image):
         # Get the noise stddev
         stddev = np.random.uniform(self.stddev_range[0], self.stddev_range[1])
-        noise = np.random.normal(0., stddev, size=input_image.shape)
-        noisy_image = (input_image + noise).clip(0., 255.)
+        noise = np.random.normal(0.0, stddev, size=input_image.shape)
+        noisy_image = (input_image + noise).clip(0.0, 255.0)
 
         return noisy_image
 
 
 class additive_speckle_noise(object):
-    """ Additive speckle noise. """
+    """Additive speckle noise."""
+
     def __init__(self, prob_range=None):
         # If prob range is not given, use the default setting
         if prob_range is None:
@@ -48,7 +50,7 @@ class additive_speckle_noise(object):
     def __call__(self, input_image):
         # Sample
         prob = np.random.uniform(self.prob_range[0], self.prob_range[1])
-        sample = np.random.uniform(0., 1., size=input_image.shape)
+        sample = np.random.uniform(0.0, 1.0, size=input_image.shape)
 
         # Get the mask
         mask0 = sample <= prob
@@ -56,14 +58,15 @@ class additive_speckle_noise(object):
 
         # Mask the image (here we assume the image ranges from 0 to 255)
         noisy = input_image.copy()
-        noisy[mask0] = 0.
-        noisy[mask1] = 255.
+        noisy[mask0] = 0.0
+        noisy[mask1] = 255.0
 
         return noisy
 
 
 class random_brightness(object):
-    """ Brightness change. """
+    """Brightness change."""
+
     def __init__(self, brightness=None):
         # If the brightness is not given, use the default setting
         if brightness is None:
@@ -83,7 +86,8 @@ class random_brightness(object):
 
 
 class random_contrast(object):
-    """ Additive contrast. """
+    """Additive contrast."""
+
     def __init__(self, contrast=None):
         # If the contrast is not given, use the default setting
         if contrast is None:
@@ -103,9 +107,9 @@ class random_contrast(object):
 
 
 class additive_shade(object):
-    """ Additive shade. """
-    def __init__(self, nb_ellipses=20, transparency_range=None,
-                   kernel_size_range=None):
+    """Additive shade."""
+
+    def __init__(self, nb_ellipses=20, transparency_range=None, kernel_size_range=None):
         self.nb_ellipses = nb_ellipses
         if transparency_range is None:
             self.transparency_range = [-0.5, 0.8]
@@ -136,39 +140,40 @@ class additive_shade(object):
         # kernel_size has to be odd
         if (kernel_size % 2) == 0:
             kernel_size += 1
-        mask = cv2.GaussianBlur(mask.astype(np.float32),
-                                (kernel_size, kernel_size), 0)
-        shaded = (input_image[..., None]
-                  * (1 - transparency * mask[..., np.newaxis]/255.))
+        mask = cv2.GaussianBlur(mask.astype(np.float32), (kernel_size, kernel_size), 0)
+        shaded = input_image[..., None] * (
+            1 - transparency * mask[..., np.newaxis] / 255.0
+        )
         shaded = np.clip(shaded, 0, 255)
 
         return np.reshape(shaded, input_image.shape)
 
 
 class motion_blur(object):
-    """ Motion blur. """
+    """Motion blur."""
+
     def __init__(self, max_kernel_size=10):
         self.max_kernel_size = max_kernel_size
 
     def __call__(self, input_image):
         # Either vertical, horizontal or diagonal blur
-        mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up'])
-        ksize = np.random.randint(
-            0, int(round((self.max_kernel_size + 1) / 2))) * 2 + 1
+        mode = np.random.choice(["h", "v", "diag_down", "diag_up"])
+        ksize = np.random.randint(0, int(round((self.max_kernel_size + 1) / 2))) * 2 + 1
         center = int((ksize - 1) / 2)
         kernel = np.zeros((ksize, ksize))
-        if mode == 'h':
-            kernel[center, :] = 1.
-        elif mode == 'v':
-            kernel[:, center] = 1.
-        elif mode == 'diag_down':
+        if mode == "h":
+            kernel[center, :] = 1.0
+        elif mode == "v":
+            kernel[:, center] = 1.0
+        elif mode == "diag_down":
             kernel = np.eye(ksize)
-        elif mode == 'diag_up':
+        elif mode == "diag_up":
             kernel = np.flip(np.eye(ksize), 0)
-        var = ksize * ksize / 16.
+        var = ksize * ksize / 16.0
         grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1)
-        gaussian = np.exp(-(np.square(grid - center)
-                            + np.square(grid.T - center)) / (2. * var))
+        gaussian = np.exp(
+            -(np.square(grid - center) + np.square(grid.T - center)) / (2.0 * var)
+        )
         kernel *= gaussian
         kernel /= np.sum(kernel)
         blurred = cv2.filter2D(input_image, -1, kernel)
@@ -177,7 +182,8 @@ class motion_blur(object):
 
 
 class normalize_image(object):
-    """ Image normalization to the range [0, 1]. """
+    """Image normalization to the range [0, 1]."""
+
     def __init__(self):
         self.normalize_value = 255
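A short sketch of chaining these photometric transforms; each one is a callable on a 0-255 grayscale array, and the particular composition shown here is an assumption, not a pipeline defined in this file:

import numpy as np

image = np.random.uniform(0, 255, size=(240, 320))  # grayscale image in the 0-255 range
for aug in (additive_gaussian_noise(), motion_blur(max_kernel_size=7)):
    image = aug(image)
image = normalize_image()(image)  # maps into [0, 1] per its docstring
print(image.min(), image.max())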
 
diff --git a/third_party/SOLD2/sold2/dataset/transforms/utils.py b/third_party/SOLD2/sold2/dataset/transforms/utils.py
index 5f1ed09e5b32e2ae2f3577e0e8e5491495e7b05b..4e2d9b4234400b16c59773ebcf15ecc557df6cac 100644
--- a/third_party/SOLD2/sold2/dataset/transforms/utils.py
+++ b/third_party/SOLD2/sold2/dataset/transforms/utils.py
@@ -9,7 +9,7 @@ from ..synthetic_util import get_line_map
 from . import homographic_transforms as homoaug
 
 
-def random_scaling(image, junctions, line_map, scale=1., h_crop=0, w_crop=0):
+def random_scaling(image, junctions, line_map, scale=1.0, h_crop=0, w_crop=0):
     H, W = image.shape[:2]
     H_scale, W_scale = round(H * scale), round(W * scale)
 
@@ -18,42 +18,46 @@ def random_scaling(image, junctions, line_map, scale=1., h_crop=0, w_crop=0):
         return (image, junctions, line_map, np.ones([H, W], dtype=np.int))
 
     # Zoom-in => resize and random crop
-    if scale >= 1.:
-        image_big = cv2.resize(image, (W_scale, H_scale),
-                               interpolation=cv2.INTER_LINEAR)
+    if scale >= 1.0:
+        image_big = cv2.resize(
+            image, (W_scale, H_scale), interpolation=cv2.INTER_LINEAR
+        )
         # Crop the image
-        image = image_big[h_crop:h_crop+H, w_crop:w_crop+W, ...]
+        image = image_big[h_crop : h_crop + H, w_crop : w_crop + W, ...]
         valid_mask = np.ones([H, W], dtype=np.int)
 
         # Process junctions
         junctions, line_map = process_junctions_and_line_map(
-            h_crop, w_crop, H, W, H_scale, W_scale,
-            junctions, line_map, "zoom-in")
+            h_crop, w_crop, H, W, H_scale, W_scale, junctions, line_map, "zoom-in"
+        )
     # Zoom-out => resize and pad
     else:
         image_shape_raw = image.shape
-        image_small = cv2.resize(image, (W_scale, H_scale),
-                                 interpolation=cv2.INTER_AREA)
+        image_small = cv2.resize(
+            image, (W_scale, H_scale), interpolation=cv2.INTER_AREA
+        )
         # Decide the pasting location
         h_start = round((H - H_scale) / 2)
         w_start = round((W - W_scale) / 2)
         # Paste the image to the middle
         image = np.zeros(image_shape_raw, dtype=np.float)
-        image[h_start:h_start+H_scale,
-              w_start:w_start+W_scale, ...] = image_small
+        image[
+            h_start : h_start + H_scale, w_start : w_start + W_scale, ...
+        ] = image_small
         valid_mask = np.zeros([H, W], dtype=np.int)
-        valid_mask[h_start:h_start+H_scale, w_start:w_start+W_scale] = 1
+        valid_mask[h_start : h_start + H_scale, w_start : w_start + W_scale] = 1
 
         # Process the junctions
         junctions, line_map = process_junctions_and_line_map(
-            h_start, w_start, H, W, H_scale, W_scale,
-            junctions, line_map, "zoom-out")
+            h_start, w_start, H, W, H_scale, W_scale, junctions, line_map, "zoom-out"
+        )
 
     return image, junctions, line_map, valid_mask
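A usage sketch for the zoom-in branch (all values illustrative); junctions follow the (h, w) convention used throughout these transforms:

import numpy as np

image = np.random.rand(240, 320)
junctions = np.array([[100.0, 150.0], [120.0, 180.0]])  # two junctions, one segment
line_map = np.array([[0, 1], [1, 0]])
image_s, junc_s, line_map_s, valid_mask = random_scaling(
    image, junctions, line_map, scale=1.2, h_crop=10, w_crop=20
)
print(image_s.shape, valid_mask.sum())  # output keeps the (240, 320) size; zoom-in is fully valid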
 
 
-def process_junctions_and_line_map(h_start, w_start, H, W, H_scale, W_scale,
-                                   junctions, line_map, mode="zoom-in"):
+def process_junctions_and_line_map(
+    h_start, w_start, H, W, H_scale, W_scale, junctions, line_map, mode="zoom-in"
+):
     if mode == "zoom-in":
         junctions[:, 0] = junctions[:, 0] * H_scale / H
         junctions[:, 1] = junctions[:, 1] * W_scale / W
@@ -61,53 +65,55 @@ def process_junctions_and_line_map(h_start, w_start, H, W, H_scale, W_scale,
         # Crop segments to the new boundaries
         line_segments_new = np.zeros([0, 4])
         image_poly = sg.Polygon(
-            [[w_start, h_start],
-            [w_start+W, h_start],
-            [w_start+W, h_start+H],
-            [w_start, h_start+H]
-            ])
+            [
+                [w_start, h_start],
+                [w_start + W, h_start],
+                [w_start + W, h_start + H],
+                [w_start, h_start + H],
+            ]
+        )
         for idx in range(line_segments.shape[0]):
             # Get the line segment
-            seg_raw = line_segments[idx, :]   # in HW format.
+            seg_raw = line_segments[idx, :]  # in HW format.
             # Convert to shapely line (flip to xy format)
-            seg = sg.LineString([np.flip(seg_raw[:2]), 
-                                np.flip(seg_raw[2:])])
+            seg = sg.LineString([np.flip(seg_raw[:2]), np.flip(seg_raw[2:])])
             # The line segment is just inside the image.
             if seg.intersection(image_poly) == seg:
                 line_segments_new = np.concatenate(
-                    (line_segments_new, seg_raw[None, ...]), axis=0)
+                    (line_segments_new, seg_raw[None, ...]), axis=0
+                )
             # Intersect with the image.
             elif seg.intersects(image_poly):
                 # Check intersection
                 try:
-                    p = np.array(
-                        seg.intersection(image_poly).coords).reshape([-1, 4])
+                    p = np.array(seg.intersection(image_poly).coords).reshape([-1, 4])
                 # If intersect at exact one point, just continue.
                 except:
                     continue
-                segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:],
-                                         axis=0)])[None, ...]
-                line_segments_new = np.concatenate(
-                    (line_segments_new, segment), axis=0)
+                segment = np.concatenate(
+                    [np.flip(p[0, :2]), np.flip(p[0, 2:], axis=0)]
+                )[None, ...]
+                line_segments_new = np.concatenate((line_segments_new, segment), axis=0)
             else:
                 continue
         line_segments_new = (np.round(line_segments_new)).astype(np.int)
         # Filter segments with 0 length
         segment_lens = np.linalg.norm(
-            line_segments_new[:, :2] - line_segments_new[:, 2:], axis=-1)
+            line_segments_new[:, :2] - line_segments_new[:, 2:], axis=-1
+        )
         seg_mask = segment_lens != 0
         line_segments_new = line_segments_new[seg_mask, :]
         # Convert back to junctions and line_map
         junctions_new = np.concatenate(
-            (line_segments_new[:, :2], line_segments_new[:, 2:]), axis=0)
+            (line_segments_new[:, :2], line_segments_new[:, 2:]), axis=0
+        )
         if junctions_new.shape[0] == 0:
             junctions_new = np.zeros([0, 2])
             line_map = np.zeros([0, 0])
         else:
             junctions_new = np.unique(junctions_new, axis=0)
             # Generate line map from points and segments
-            line_map = get_line_map(junctions_new,
-                                    line_segments_new).astype(np.int)
+            line_map = get_line_map(junctions_new, line_segments_new).astype(np.int)
         junctions_new[:, 0] -= h_start
         junctions_new[:, 1] -= w_start
         junctions = junctions_new
diff --git a/third_party/SOLD2/sold2/dataset/wireframe_dataset.py b/third_party/SOLD2/sold2/dataset/wireframe_dataset.py
index ed5bb910bed1b89934ddaaec3bcddf111ea0faef..44341d7394303188db3ba69123bb4b4212700466 100644
--- a/third_party/SOLD2/sold2/dataset/wireframe_dataset.py
+++ b/third_party/SOLD2/sold2/dataset/wireframe_dataset.py
@@ -27,12 +27,19 @@ from ..misc.geometry_utils import warp_points, mask_points
 
 
 def wireframe_collate_fn(batch):
-    """ Customized collate_fn for wireframe dataset. """
-    batch_keys = ["image", "junction_map", "valid_mask", "heatmap",
-                  "heatmap_pos", "heatmap_neg", "homography",
-                  "line_points", "line_indices"]
-    list_keys = ["junctions", "line_map", "line_map_pos",
-                 "line_map_neg", "file_key"]
+    """Customized collate_fn for wireframe dataset."""
+    batch_keys = [
+        "image",
+        "junction_map",
+        "valid_mask",
+        "heatmap",
+        "heatmap_pos",
+        "heatmap_neg",
+        "homography",
+        "line_points",
+        "line_indices",
+    ]
+    list_keys = ["junctions", "line_map", "line_map_pos", "line_map_neg", "file_key"]
 
     outputs = {}
     for data_key in batch[0].keys():
@@ -41,14 +48,16 @@ def wireframe_collate_fn(batch):
         # print(batch_match, list_match)
         if batch_match > 0 and list_match == 0:
             outputs[data_key] = torch_loader.default_collate(
-                [b[data_key] for b in batch])
+                [b[data_key] for b in batch]
+            )
         elif batch_match == 0 and list_match > 0:
             outputs[data_key] = [b[data_key] for b in batch]
         elif batch_match == 0 and list_match == 0:
             continue
         else:
             raise ValueError(
-        "[Error] A key matches batch keys and list keys simultaneously.")
+                "[Error] A key matches batch keys and list keys simultaneously."
+            )
 
     return outputs
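A sketch of wiring this collate function into a PyTorch DataLoader; the dataset constructor arguments shown here are assumptions:

from torch.utils.data import DataLoader

# `train_config` stands in for an assumed dataset configuration dict.
dataset = WireframeDataset(mode="train", config=train_config)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=wireframe_collate_fn)
batch = next(iter(loader))
# Entries in batch_keys are stacked by default_collate; entries in list_keys
# stay as per-sample Python lists.
print(batch["image"].shape, len(batch["junctions"]))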
 
@@ -58,7 +67,8 @@ class WireframeDataset(Dataset):
         super(WireframeDataset, self).__init__()
         if not mode in ["train", "test"]:
             raise ValueError(
-        "[Error] Unknown mode for Wireframe dataset. Only 'train' and 'test'.")
+                "[Error] Unknown mode for Wireframe dataset. Only 'train' and 'test'."
+            )
         self.mode = mode
 
         if config is None:
@@ -72,18 +82,17 @@ class WireframeDataset(Dataset):
         self.dataset_name = self.get_dataset_name()
         self.cache_name = self.get_cache_name()
         self.cache_path = cfg.wireframe_cache_path
-        
+
         # Get the ground truth source
-        self.gt_source = self.config.get("gt_source_%s"%(self.mode),
-                                         "official")
+        self.gt_source = self.config.get("gt_source_%s" % (self.mode), "official")
         if not self.gt_source == "official":
             # Convert gt_source to full path
             self.gt_source = os.path.join(cfg.export_dataroot, self.gt_source)
             # Check the full path exists
             if not os.path.exists(self.gt_source):
                 raise ValueError(
-            "[Error] The specified ground truth source does not exist.")
-        
+                    "[Error] The specified ground truth source does not exist."
+                )
 
         # Get the filename dataset
         print("[Info] Initializing wireframe dataset...")
@@ -95,22 +104,22 @@ class WireframeDataset(Dataset):
         # Print some info
         print("[Info] Successfully initialized dataset")
         print("\t Name: wireframe")
-        print("\t Mode: %s" %(self.mode))
-        print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode),
-                                            "official")))
-        print("\t Counts: %d" %(self.dataset_length))
+        print("\t Mode: %s" % (self.mode))
+        print("\t Gt: %s" % (self.config.get("gt_source_%s" % (self.mode), "official")))
+        print("\t Counts: %d" % (self.dataset_length))
         print("----------------------------------------")
 
     #######################################
     ## Dataset construction related APIs ##
     #######################################
     def construct_dataset(self):
-        """ Construct the dataset (from scratch or from cache). """
+        """Construct the dataset (from scratch or from cache)."""
         # Check if the filename cache exists
         # If cache exists, load from cache
         if self._check_dataset_cache():
-            print("\t Found filename cache %s at %s"%(self.cache_name,
-                                                      self.cache_path))
+            print(
+                "\t Found filename cache %s at %s" % (self.cache_name, self.cache_path)
+            )
             print("\t Load filename cache...")
             filename_dataset, datapoints = self.get_filename_dataset_from_cache()
         # If not, initialize dataset from scratch
@@ -120,30 +129,27 @@ class WireframeDataset(Dataset):
             filename_dataset, datapoints = self.get_filename_dataset()
             print("\t Create filename dataset cache...")
             self.create_filename_dataset_cache(filename_dataset, datapoints)
-        
+
         return filename_dataset, datapoints
-    
+
     def create_filename_dataset_cache(self, filename_dataset, datapoints):
-        """ Create filename dataset cache for faster initialization. """
+        """Create filename dataset cache for faster initialization."""
         # Check cache path exists
         if not os.path.exists(self.cache_path):
             os.makedirs(self.cache_path)
 
         cache_file_path = os.path.join(self.cache_path, self.cache_name)
-        data = {
-            "filename_dataset": filename_dataset,
-            "datapoints": datapoints
-        }
+        data = {"filename_dataset": filename_dataset, "datapoints": datapoints}
         with open(cache_file_path, "wb") as f:
             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
-    
+
     def get_filename_dataset_from_cache(self):
-        """ Get filename dataset from cache. """
+        """Get filename dataset from cache."""
         # Load from pkl cache
         cache_file_path = os.path.join(self.cache_path, self.cache_name)
         with open(cache_file_path, "rb") as f:
             data = pickle.load(f)
-        
+
         return data["filename_dataset"], data["datapoints"]
 
     def get_filename_dataset(self):
@@ -152,14 +158,18 @@ class WireframeDataset(Dataset):
             dataset_path = os.path.join(cfg.wireframe_dataroot, "train")
         elif self.mode == "test":
             dataset_path = os.path.join(cfg.wireframe_dataroot, "valid")
-        
+
         # Get paths to all image files
-        image_paths = sorted([os.path.join(dataset_path, _)
-                              for _ in os.listdir(dataset_path)\
-                              if os.path.splitext(_)[-1] == ".png"])
+        image_paths = sorted(
+            [
+                os.path.join(dataset_path, _)
+                for _ in os.listdir(dataset_path)
+                if os.path.splitext(_)[-1] == ".png"
+            ]
+        )
         # Get the shared prefix
         prefix_paths = [_.split(".png")[0] for _ in image_paths]
-        
+
         # Get the label paths (different procedure for different split)
         if self.mode == "train":
             label_paths = [_ + "_label.npz" for _ in prefix_paths]
@@ -171,17 +181,18 @@ class WireframeDataset(Dataset):
         for idx in range(len(image_paths)):
             image_path = image_paths[idx]
             label_path = label_paths[idx]
-            if (not (os.path.exists(image_path)
-                and os.path.exists(label_path))):
+            if not (os.path.exists(image_path) and os.path.exists(label_path)):
                 raise ValueError(
-            "[Error] The image and label do not exist. %s"%(image_path))
+                    "[Error] The image and label do not exist. %s" % (image_path)
+                )
             # Further verify mat paths for test split
             if self.mode == "test":
                 mat_path = mat_paths[idx]
                 if not os.path.exists(mat_path):
                     raise ValueError(
-                "[Error] The mat file does not exist. %s"%(mat_path))
-        
+                        "[Error] The mat file does not exist. %s" % (mat_path)
+                    )
+
         # Construct the filename dataset
         num_pad = int(math.ceil(math.log10(len(image_paths))) + 1)
         filename_dataset = {}
@@ -191,25 +202,25 @@ class WireframeDataset(Dataset):
 
             filename_dataset[key] = {
                 "image": image_paths[idx],
-                "label": label_paths[idx]
+                "label": label_paths[idx],
             }
 
         # Get the datapoints
         datapoints = list(sorted(filename_dataset.keys()))
 
         return filename_dataset, datapoints
-    
+
     def get_dataset_name(self):
-        """ Get dataset name from dataset config / default config. """
+        """Get dataset name from dataset config / default config."""
         if self.config["dataset_name"] is None:
             dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
         else:
             dataset_name = self.config["dataset_name"] + "_%s" % self.mode
 
         return dataset_name
-    
+
     def get_cache_name(self):
-        """ Get cache name from dataset config / default config. """
+        """Get cache name from dataset config / default config."""
         if self.config["dataset_name"] is None:
             dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
         else:
@@ -218,35 +229,27 @@ class WireframeDataset(Dataset):
         cache_name = dataset_name + "_cache.pkl"
 
         return cache_name
-    
+
     @staticmethod
     def get_padded_filename(num_pad, idx):
-        """ Get the padded filename using adaptive padding. """
+        """Get the padded filename using adaptive padding."""
         file_len = len("%d" % (idx))
         filename = "0" * (num_pad - file_len) + "%d" % (idx)
 
         return filename
 
     def get_default_config(self):
-        """ Get the default configuration. """
+        """Get the default configuration."""
         return {
             "dataset_name": "wireframe",
             "add_augmentation_to_all_splits": False,
-            "preprocessing": {
-                "resize": [240, 320],
-                "blur_size": 11
-            },
-            "augmentation":{
-                "photometric":{
-                    "enable": False
-                },
-                "homographic":{
-                    "enable": False
-                },
+            "preprocessing": {"resize": [240, 320], "blur_size": 11},
+            "augmentation": {
+                "photometric": {"enable": False},
+                "homographic": {"enable": False},
             },
         }
 
-        
     ############################################
     ## Pytorch and preprocessing related APIs ##
     ############################################
@@ -280,13 +283,13 @@ class WireframeDataset(Dataset):
         # TODO: How to process mat data
         if data_path.get("line_mat") is not None:
             raise NotImplementedError
-        
+
         return output
-    
+
     @staticmethod
     def convert_line_map(lcnn_line_map, num_junctions):
-        """ Convert the line_pos or line_neg
-            (represented by two junction indexes) to our line map. """
+        """Convert the line_pos or line_neg
+        (represented by two junction indexes) to our line map."""
         # Initialize empty line map
         line_map = np.zeros([num_junctions, num_junctions])
 
@@ -297,59 +300,60 @@ class WireframeDataset(Dataset):
 
             line_map[index1, index2] = 1
             line_map[index2, index1] = 1
-        
+
         return line_map
-    
+
     @staticmethod
     def junc_to_junc_map(junctions, image_size):
-        """ Convert junction points to junction maps. """
+        """Convert junction points to junction maps."""
         junctions = np.round(junctions).astype(np.int)
         # Clip the boundary by image size
-        junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1)
-        junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1)
+        junctions[:, 0] = np.clip(junctions[:, 0], 0.0, image_size[0] - 1)
+        junctions[:, 1] = np.clip(junctions[:, 1], 0.0, image_size[1] - 1)
 
         # Create junction map
         junc_map = np.zeros([image_size[0], image_size[1]])
         junc_map[junctions[:, 0], junctions[:, 1]] = 1
 
         return junc_map[..., None].astype(np.int)
-    
+
     def parse_transforms(self, names, all_transforms):
-        """ Parse the transform. """
-        trans = all_transforms if (names == 'all') \
+        """Parse the transform."""
+        trans = (
+            all_transforms
+            if (names == "all")
             else (names if isinstance(names, list) else [names])
+        )
         assert set(trans) <= set(all_transforms)
         return trans
 
     def get_photo_transform(self):
-        """ Get list of photometric transforms (according to the config). """
+        """Get list of photometric transforms (according to the config)."""
         # Get the photometric transform config
         photo_config = self.config["augmentation"]["photometric"]
         if not photo_config["enable"]:
-            raise ValueError(
-        "[Error] Photometric augmentation is not enabled.")
-        
+            raise ValueError("[Error] Photometric augmentation is not enabled.")
+
         # Parse photometric transforms
-        trans_lst = self.parse_transforms(photo_config["primitives"],
-                                          photoaug.available_augmentations)
-        trans_config_lst = [photo_config["params"].get(p, {})
-                            for p in trans_lst]
+        trans_lst = self.parse_transforms(
+            photo_config["primitives"], photoaug.available_augmentations
+        )
+        trans_config_lst = [photo_config["params"].get(p, {}) for p in trans_lst]
 
         # List of photometric augmentation
         photometric_trans_lst = [
-            getattr(photoaug, trans)(**conf) \
+            getattr(photoaug, trans)(**conf)
             for (trans, conf) in zip(trans_lst, trans_config_lst)
         ]
 
         return photometric_trans_lst
 
     def get_homo_transform(self):
-        """ Get homographic transforms (according to the config). """
+        """Get homographic transforms (according to the config)."""
         # Get homographic transforms for image
         homo_config = self.config["augmentation"]["homographic"]["params"]
         if not self.config["augmentation"]["homographic"]["enable"]:
-            raise ValueError(
-        "[Error] Homographic augmentation is not enabled.")
+            raise ValueError("[Error] Homographic augmentation is not enabled.")
 
         # Parse the homographic transforms
         image_shape = self.config["preprocessing"]["resize"]
@@ -359,67 +363,73 @@ class WireframeDataset(Dataset):
             min_label_tmp = self.config["generation"]["min_label_len"]
         except:
             min_label_tmp = None
-        
+
         # float label len => fraction
-        if isinstance(min_label_tmp, float): # Skip if not provided
+        if isinstance(min_label_tmp, float):  # Skip if not provided
             min_label_len = min_label_tmp * min(image_shape)
         # int label len => length in pixel
         elif isinstance(min_label_tmp, int):
-            scale_ratio = (self.config["preprocessing"]["resize"]
-                           / self.config["generation"]["image_size"][0])
-            min_label_len = (self.config["generation"]["min_label_len"]
-                             * scale_ratio)
+            scale_ratio = (
+                self.config["preprocessing"]["resize"]
+                / self.config["generation"]["image_size"][0]
+            )
+            min_label_len = self.config["generation"]["min_label_len"] * scale_ratio
         # if none => no restriction
         else:
             min_label_len = 0
-        
+
         # Initialize the transform
         homographic_trans = homoaug.homography_transform(
-            image_shape, homo_config, 0, min_label_len)
+            image_shape, homo_config, 0, min_label_len
+        )
 
         return homographic_trans
 
-    def get_line_points(self, junctions, line_map, H1=None, H2=None,
-                        img_size=None, warp=False):
-        """ Sample evenly points along each line segments
-            and keep track of line idx. """
+    def get_line_points(
+        self, junctions, line_map, H1=None, H2=None, img_size=None, warp=False
+    ):
+        """Sample evenly points along each line segments
+        and keep track of line idx."""
         if np.sum(line_map) == 0:
             # No segment detected in the image
             line_indices = np.zeros(self.config["max_pts"], dtype=int)
             line_points = np.zeros((self.config["max_pts"], 2), dtype=float)
             return line_points, line_indices
-            
+
         # Extract all pairs of connected junctions
         junc_indices = np.array(
-            [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i])
-        line_segments = np.stack([junctions[junc_indices[:, 0]],
-                                  junctions[junc_indices[:, 1]]], axis=1)
+            [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i]
+        )
+        line_segments = np.stack(
+            [junctions[junc_indices[:, 0]], junctions[junc_indices[:, 1]]], axis=1
+        )
         # line_segments is (num_lines, 2, 2)
-        line_lengths = np.linalg.norm(
-            line_segments[:, 0] - line_segments[:, 1], axis=1)
+        line_lengths = np.linalg.norm(line_segments[:, 0] - line_segments[:, 1], axis=1)
 
         # Sample the points separated by at least min_dist_pts along each line
         # The number of samples depends on the length of the line
-        num_samples = np.minimum(line_lengths // self.config["min_dist_pts"],
-                                 self.config["max_num_samples"])
+        num_samples = np.minimum(
+            line_lengths // self.config["min_dist_pts"], self.config["max_num_samples"]
+        )
         line_points = []
         line_indices = []
         cur_line_idx = 1
         for n in np.arange(2, self.config["max_num_samples"] + 1):
             # Consider all lines where we can fit up to n points
             cur_line_seg = line_segments[num_samples == n]
-            line_points_x = np.linspace(cur_line_seg[:, 0, 0],
-                                        cur_line_seg[:, 1, 0],
-                                        n, axis=-1).flatten()
-            line_points_y = np.linspace(cur_line_seg[:, 0, 1],
-                                        cur_line_seg[:, 1, 1],
-                                        n, axis=-1).flatten()
+            line_points_x = np.linspace(
+                cur_line_seg[:, 0, 0], cur_line_seg[:, 1, 0], n, axis=-1
+            ).flatten()
+            line_points_y = np.linspace(
+                cur_line_seg[:, 0, 1], cur_line_seg[:, 1, 1], n, axis=-1
+            ).flatten()
             jitter = self.config.get("jittering", 0)
             if jitter:
                 # Add a small random jittering of all points along the line
                 angles = np.arctan2(
                     cur_line_seg[:, 1, 0] - cur_line_seg[:, 0, 0],
-                    cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1]).repeat(n)
+                    cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1],
+                ).repeat(n)
                 jitter_hyp = (np.random.rand(len(angles)) * 2 - 1) * jitter
                 line_points_x += jitter_hyp * np.sin(angles)
                 line_points_y += jitter_hyp * np.cos(angles)
@@ -429,10 +439,8 @@ class WireframeDataset(Dataset):
             line_idx = np.arange(cur_line_idx, cur_line_idx + num_cur_lines)
             line_indices.append(line_idx.repeat(n))
             cur_line_idx += num_cur_lines
-        line_points = np.concatenate(line_points,
-                                     axis=0)[:self.config["max_pts"]]
-        line_indices = np.concatenate(line_indices,
-                                      axis=0)[:self.config["max_pts"]]
+        line_points = np.concatenate(line_points, axis=0)[: self.config["max_pts"]]
+        line_indices = np.concatenate(line_indices, axis=0)[: self.config["max_pts"]]
 
         # Warp the points if need be, and filter unvalid ones
         # If the other view is also warped
@@ -454,20 +462,24 @@ class WireframeDataset(Dataset):
             mask = mask_points(warped_points, img_size)
         line_points = line_points[mask]
         line_indices = line_indices[mask]
-        
+
         # Pad the line points to a fixed length
         # Index of 0 means padded line
-        line_indices = np.concatenate([line_indices, np.zeros(
-            self.config["max_pts"] - len(line_indices))], axis=0)
+        line_indices = np.concatenate(
+            [line_indices, np.zeros(self.config["max_pts"] - len(line_indices))], axis=0
+        )
         line_points = np.concatenate(
-            [line_points,
-             np.zeros((self.config["max_pts"] - len(line_points), 2),
-                      dtype=float)], axis=0)
-        
+            [
+                line_points,
+                np.zeros((self.config["max_pts"] - len(line_points), 2), dtype=float),
+            ],
+            axis=0,
+        )
+
         return line_points, line_indices
 
     def train_preprocessing(self, data, numpy=False):
-        """ Train preprocessing for GT data. """
+        """Train preprocessing for GT data."""
         # Fetch the corresponding entries
         image = data["image"]
         junctions = data["junc"][:, :2]
@@ -476,23 +488,27 @@ class WireframeDataset(Dataset):
         image_size = image.shape[:2]
         # Convert junctions to pixel coordinates (from 128x128)
         junctions[:, 0] *= image_size[0] / 128
-        junctions[:, 1] *= image_size[1] / 128 
+        junctions[:, 1] *= image_size[1] / 128
 
         # Resize the image before photometric and homographical augmentations
-        if not(list(image_size) == self.config["preprocessing"]["resize"]):
+        if not (list(image_size) == self.config["preprocessing"]["resize"]):
             # Resize the image and the point location.
-            size_old = list(image.shape)[:2] # Only H and W dimensions
+            size_old = list(image.shape)[:2]  # Only H and W dimensions
 
             image = cv2.resize(
-                image, tuple(self.config['preprocessing']['resize'][::-1]),
-                interpolation=cv2.INTER_LINEAR)
+                image,
+                tuple(self.config["preprocessing"]["resize"][::-1]),
+                interpolation=cv2.INTER_LINEAR,
+            )
             image = np.array(image, dtype=np.uint8)
 
             # In HW format
-            junctions = (junctions * np.array(
-                self.config['preprocessing']['resize'], np.float)
-                         / np.array(size_old, np.float))
-        
+            junctions = (
+                junctions
+                * np.array(self.config["preprocessing"]["resize"], float)
+                / np.array(size_old, float)
+            )
+
         # Convert to positive line map and negative line map (our format)
         num_junctions = junctions.shape[0]
         line_map_pos = self.convert_line_map(line_pos, num_junctions)
@@ -509,7 +525,7 @@ class WireframeDataset(Dataset):
 
         # Optionally convert the image to grayscale
         if self.config["gray_scale"]:
-            image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+            image = (color.rgb2gray(image) * 255.0).astype(np.uint8)
 
         # Check if we need to apply augmentations
         # In training mode => yes.
@@ -519,7 +535,8 @@ class WireframeDataset(Dataset):
             ### Image transform ###
             np.random.shuffle(photo_trans_lst)
             image_transform = transforms.Compose(
-                photo_trans_lst + [photoaug.normalize_image()])
+                photo_trans_lst + [photoaug.normalize_image()]
+            )
         else:
             image_transform = photoaug.normalize_image()
         image = image_transform(image)
@@ -549,13 +566,11 @@ class WireframeDataset(Dataset):
                 "image": to_tensor(image),
                 "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
                 "junction_map": to_tensor(junction_map).to(torch.int),
-                "line_map_pos": to_tensor(
-                    line_map_pos).to(torch.int32)[0, ...],
-                "line_map_neg": to_tensor(
-                    line_map_neg).to(torch.int32)[0, ...],
+                "line_map_pos": to_tensor(line_map_pos).to(torch.int32)[0, ...],
+                "line_map_neg": to_tensor(line_map_neg).to(torch.int32)[0, ...],
                 "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32),
                 "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32),
-                "valid_mask": to_tensor(valid_mask).to(torch.int32)
+                "valid_mask": to_tensor(valid_mask).to(torch.int32),
             }
         else:
             return {
@@ -566,14 +581,23 @@ class WireframeDataset(Dataset):
                 "line_map_neg": line_map_neg.astype(np.int32),
                 "heatmap_pos": heatmap_pos.astype(np.int32),
                 "heatmap_neg": heatmap_neg.astype(np.int32),
-                "valid_mask": valid_mask.astype(np.int32)
+                "valid_mask": valid_mask.astype(np.int32),
             }
-    
+
     def train_preprocessing_exported(
-        self, data, numpy=False, disable_homoaug=False,
-        desc_training=False, H1=None, H1_scale=None, H2=None, scale=1.,
-        h_crop=None, w_crop=None):
-        """ Train preprocessing for the exported labels. """
+        self,
+        data,
+        numpy=False,
+        disable_homoaug=False,
+        desc_training=False,
+        H1=None,
+        H1_scale=None,
+        H2=None,
+        scale=1.0,
+        h_crop=None,
+        w_crop=None,
+    ):
+        """Train preprocessing for the exported labels."""
         data = copy.deepcopy(data)
         # Fetch the corresponding entries
         image = data["image"]
@@ -593,13 +617,15 @@ class WireframeDataset(Dataset):
                     w_crop = np.random.randint(W_scale - W)
 
         # Resize the image before photometric and homographical augmentations
-        if not(list(image_size) == self.config["preprocessing"]["resize"]):
+        if not (list(image_size) == self.config["preprocessing"]["resize"]):
             # Resize the image and the point location.
-            size_old = list(image.shape)[:2] # Only H and W dimensions
+            size_old = list(image.shape)[:2]  # Only H and W dimensions
 
             image = cv2.resize(
-                image, tuple(self.config['preprocessing']['resize'][::-1]),
-                interpolation=cv2.INTER_LINEAR)
+                image,
+                tuple(self.config["preprocessing"]["resize"][::-1]),
+                interpolation=cv2.INTER_LINEAR,
+            )
             image = np.array(image, dtype=np.uint8)
 
             # # In HW format
@@ -614,7 +640,7 @@ class WireframeDataset(Dataset):
 
         # Optionally convert the image to grayscale
         if self.config["gray_scale"]:
-            image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+            image = (color.rgb2gray(image) * 255.0).astype(np.uint8)
 
         # Check if we need to apply augmentations
         # In training mode => yes.
@@ -624,40 +650,49 @@ class WireframeDataset(Dataset):
             ### Image transform ###
             np.random.shuffle(photo_trans_lst)
             image_transform = transforms.Compose(
-                photo_trans_lst + [photoaug.normalize_image()])
+                photo_trans_lst + [photoaug.normalize_image()]
+            )
         else:
             image_transform = photoaug.normalize_image()
         image = image_transform(image)
-        
+
         # Perform the random scaling
-        if scale != 1.:
+        if scale != 1.0:
             image, junctions, line_map, valid_mask = random_scaling(
-                 image, junctions, line_map, scale,
-                 h_crop=h_crop, w_crop=w_crop)
+                image, junctions, line_map, scale, h_crop=h_crop, w_crop=w_crop
+            )
         else:
             # Declare default valid mask (all ones)
             valid_mask = np.ones(image_size)
-            
+
         # Initialize the empty output dict
         outputs = {}
         # Convert to tensor and return the results
         to_tensor = transforms.ToTensor()
 
         # Check homographic augmentation
-        warp = (self.config["augmentation"]["homographic"]["enable"]
-                and disable_homoaug == False)
+        warp = (
+            self.config["augmentation"]["homographic"]["enable"]
+            and not disable_homoaug
+        )
         if warp:
             homo_trans = self.get_homo_transform()
             # Perform homographic transform
             if H1 is None:
                 homo_outputs = homo_trans(
-                    image, junctions, line_map, valid_mask=valid_mask)
+                    image, junctions, line_map, valid_mask=valid_mask
+                )
             else:
                 homo_outputs = homo_trans(
-                    image, junctions, line_map, homo=H1, scale=H1_scale,
-                    valid_mask=valid_mask)
+                    image,
+                    junctions,
+                    line_map,
+                    homo=H1,
+                    scale=H1_scale,
+                    valid_mask=valid_mask,
+                )
             homography_mat = homo_outputs["homo"]
-            
+
             # Give the warp of the other view
             if H1 is None:
                 H1 = homo_outputs["homo"]
@@ -665,8 +700,8 @@ class WireframeDataset(Dataset):
         # Sample points along each line segments for the descriptor
         if desc_training:
             line_points, line_indices = self.get_line_points(
-                junctions, line_map, H1=H1, H2=H2,
-                img_size=image_size, warp=warp)
+                junctions, line_map, H1=H1, H2=H2, img_size=image_size, warp=warp
+            )
 
         # Record the warped results
         if warp:
@@ -675,52 +710,59 @@ class WireframeDataset(Dataset):
             line_map = homo_outputs["line_map"]
             valid_mask = homo_outputs["valid_mask"]  # Same for pos and neg
             heatmap = homo_outputs["warped_heatmap"]
-            
+
             # Optionally put warping information first.
             if not numpy:
-                outputs["homography_mat"] = to_tensor(
-                    homography_mat).to(torch.float32)[0, ...]
+                outputs["homography_mat"] = to_tensor(homography_mat).to(torch.float32)[
+                    0, ...
+                ]
             else:
                 outputs["homography_mat"] = homography_mat.astype(np.float32)
 
         junction_map = self.junc_to_junc_map(junctions, image_size)
-        
+
         if not numpy:
-            outputs.update({
-                "image": to_tensor(image).to(torch.float32),
-                "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
-                "junction_map": to_tensor(junction_map).to(torch.int),
-                "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
-                "heatmap": to_tensor(heatmap).to(torch.int32),
-                "valid_mask": to_tensor(valid_mask).to(torch.int32)
-            })
+            outputs.update(
+                {
+                    "image": to_tensor(image).to(torch.float32),
+                    "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+                    "junction_map": to_tensor(junction_map).to(torch.int),
+                    "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
+                    "heatmap": to_tensor(heatmap).to(torch.int32),
+                    "valid_mask": to_tensor(valid_mask).to(torch.int32),
+                }
+            )
             if desc_training:
-                outputs.update({
-                    "line_points": to_tensor(
-                        line_points).to(torch.float32)[0],
-                    "line_indices": torch.tensor(line_indices,
-                                                 dtype=torch.int)
-                })
+                outputs.update(
+                    {
+                        "line_points": to_tensor(line_points).to(torch.float32)[0],
+                        "line_indices": torch.tensor(line_indices, dtype=torch.int),
+                    }
+                )
         else:
-            outputs.update({
-                "image": image,
-                "junctions": junctions.astype(np.float32),
-                "junction_map": junction_map.astype(np.int32),
-                "line_map": line_map.astype(np.int32),
-                "heatmap": heatmap.astype(np.int32),
-                "valid_mask": valid_mask.astype(np.int32)
-            })
+            outputs.update(
+                {
+                    "image": image,
+                    "junctions": junctions.astype(np.float32),
+                    "junction_map": junction_map.astype(np.int32),
+                    "line_map": line_map.astype(np.int32),
+                    "heatmap": heatmap.astype(np.int32),
+                    "valid_mask": valid_mask.astype(np.int32),
+                }
+            )
             if desc_training:
-                outputs.update({
-                    "line_points": line_points.astype(np.float32),
-                    "line_indices": line_indices.astype(int)
-                })
-        
+                outputs.update(
+                    {
+                        "line_points": line_points.astype(np.float32),
+                        "line_indices": line_indices.astype(int),
+                    }
+                )
+
         return outputs
-    
-    def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.):
-        """ Train preprocessing for paired data for the exported labels
-            for descriptor training. """
+
+    def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.0):
+        """Train preprocessing for paired data for the exported labels
+        for descriptor training."""
         outputs = {}
 
         # Define the random crop for scaling if necessary
@@ -732,36 +774,49 @@ class WireframeDataset(Dataset):
                 h_crop = np.random.randint(H_scale - H)
             if W_scale > W:
                 w_crop = np.random.randint(W_scale - W)
-        
+
         # Sample ref homography first
         homo_config = self.config["augmentation"]["homographic"]["params"]
         image_shape = self.config["preprocessing"]["resize"]
-        ref_H, ref_scale = homoaug.sample_homography(image_shape,
-                                                     **homo_config)
+        ref_H, ref_scale = homoaug.sample_homography(image_shape, **homo_config)
 
         # Data for target view (All augmentation)
         target_data = self.train_preprocessing_exported(
-            data, numpy=numpy, desc_training=True, H1=None, H2=ref_H,
-            scale=scale, h_crop=h_crop, w_crop=w_crop)
+            data,
+            numpy=numpy,
+            desc_training=True,
+            H1=None,
+            H2=ref_H,
+            scale=scale,
+            h_crop=h_crop,
+            w_crop=w_crop,
+        )
 
         # Data for reference view (No homographical augmentation)
         ref_data = self.train_preprocessing_exported(
-            data, numpy=numpy, desc_training=True, H1=ref_H,
-            H1_scale=ref_scale, H2=target_data["homography_mat"].numpy(),
-            scale=scale, h_crop=h_crop, w_crop=w_crop)
+            data,
+            numpy=numpy,
+            desc_training=True,
+            H1=ref_H,
+            H1_scale=ref_scale,
+            H2=target_data["homography_mat"].numpy(),
+            scale=scale,
+            h_crop=h_crop,
+            w_crop=w_crop,
+        )
 
         # Spread ref data
         for key, val in ref_data.items():
             outputs["ref_" + key] = val
-        
+
         # Spread target data
         for key, val in target_data.items():
             outputs["target_" + key] = val
-        
+
         return outputs
 
     def test_preprocessing(self, data, numpy=False):
-        """ Test preprocessing for GT data. """
+        """Test preprocessing for GT data."""
         data = copy.deepcopy(data)
         # Fetch the corresponding entries
         image = data["image"]
@@ -771,31 +826,35 @@ class WireframeDataset(Dataset):
         image_size = image.shape[:2]
         # Convert junctions to pixel coordinates (from 128x128)
         junctions[:, 0] *= image_size[0] / 128
-        junctions[:, 1] *= image_size[1] / 128 
+        junctions[:, 1] *= image_size[1] / 128
 
         # Resize the image before photometric and homographical augmentations
-        if not(list(image_size) == self.config["preprocessing"]["resize"]):
+        if not (list(image_size) == self.config["preprocessing"]["resize"]):
             # Resize the image and the point location.
-            size_old = list(image.shape)[:2] # Only H and W dimensions
+            size_old = list(image.shape)[:2]  # Only H and W dimensions
 
             image = cv2.resize(
-                image, tuple(self.config['preprocessing']['resize'][::-1]),
-                interpolation=cv2.INTER_LINEAR)
+                image,
+                tuple(self.config["preprocessing"]["resize"][::-1]),
+                interpolation=cv2.INTER_LINEAR,
+            )
             image = np.array(image, dtype=np.uint8)
 
             # In HW format
-            junctions = (junctions * np.array(
-                self.config['preprocessing']['resize'], np.float)
-                         / np.array(size_old, np.float))
-        
+            junctions = (
+                junctions
+                * np.array(self.config["preprocessing"]["resize"], float)
+                / np.array(size_old, float)
+            )
+
         # Optionally convert the image to grayscale
         if self.config["gray_scale"]:
-            image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+            image = (color.rgb2gray(image) * 255.0).astype(np.uint8)
 
         # Still need to normalize image
         image_transform = photoaug.normalize_image()
         image = image_transform(image)
-        
+
         # Convert to positive line map and negative line map (our format)
         num_junctions = junctions.shape[0]
         line_map_pos = self.convert_line_map(line_pos, num_junctions)
@@ -819,13 +878,11 @@ class WireframeDataset(Dataset):
                 "image": to_tensor(image),
                 "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
                 "junction_map": to_tensor(junction_map).to(torch.int),
-                "line_map_pos": to_tensor(
-                    line_map_pos).to(torch.int32)[0, ...],
-                "line_map_neg": to_tensor(
-                    line_map_neg).to(torch.int32)[0, ...],
+                "line_map_pos": to_tensor(line_map_pos).to(torch.int32)[0, ...],
+                "line_map_neg": to_tensor(line_map_neg).to(torch.int32)[0, ...],
                 "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32),
                 "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32),
-                "valid_mask": to_tensor(valid_mask).to(torch.int32)
+                "valid_mask": to_tensor(valid_mask).to(torch.int32),
             }
         else:
             return {
@@ -836,26 +893,28 @@ class WireframeDataset(Dataset):
                 "line_map_neg": line_map_neg.astype(np.int32),
                 "heatmap_pos": heatmap_pos.astype(np.int32),
                 "heatmap_neg": heatmap_neg.astype(np.int32),
-                "valid_mask": valid_mask.astype(np.int32)
+                "valid_mask": valid_mask.astype(np.int32),
             }
-    
-    def test_preprocessing_exported(self, data, numpy=False, scale=1.):
-        """ Test preprocessing for the exported labels. """
+
+    def test_preprocessing_exported(self, data, numpy=False, scale=1.0):
+        """Test preprocessing for the exported labels."""
         data = copy.deepcopy(data)
         # Fetch the corresponding entries
         image = data["image"]
         junctions = data["junctions"]
-        line_map = data["line_map"]      
+        line_map = data["line_map"]
         image_size = image.shape[:2]
 
         # Resize the image before photometric and homographical augmentations
-        if not(list(image_size) == self.config["preprocessing"]["resize"]):
+        if not (list(image_size) == self.config["preprocessing"]["resize"]):
             # Resize the image and the point location.
-            size_old = list(image.shape)[:2] # Only H and W dimensions
+            size_old = list(image.shape)[:2]  # Only H and W dimensions
 
             image = cv2.resize(
-                image, tuple(self.config['preprocessing']['resize'][::-1]),
-                interpolation=cv2.INTER_LINEAR)
+                image,
+                tuple(self.config["preprocessing"]["resize"][::-1]),
+                interpolation=cv2.INTER_LINEAR,
+            )
             image = np.array(image, dtype=np.uint8)
 
             # # In HW format
@@ -865,7 +924,7 @@ class WireframeDataset(Dataset):
 
         # Optionally convert the image to grayscale
         if self.config["gray_scale"]:
-            image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+            image = (color.rgb2gray(image) * 255.0).astype(np.uint8)
 
         # Still need to normalize image
         image_transform = photoaug.normalize_image()
@@ -875,7 +934,7 @@ class WireframeDataset(Dataset):
         junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
         image_size = image.shape[:2]
         heatmap = get_line_heatmap(junctions_xy, line_map, image_size)
-        
+
         # Declare default valid mask (all ones)
         valid_mask = np.ones(image_size)
 
@@ -890,7 +949,7 @@ class WireframeDataset(Dataset):
                 "junction_map": to_tensor(junction_map).to(torch.int),
                 "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
                 "heatmap": to_tensor(heatmap).to(torch.int32),
-                "valid_mask": to_tensor(valid_mask).to(torch.int32)
+                "valid_mask": to_tensor(valid_mask).to(torch.int32),
             }
         else:
             outputs = {
@@ -899,20 +958,20 @@ class WireframeDataset(Dataset):
                 "junction_map": junction_map.astype(np.int32),
                 "line_map": line_map.astype(np.int32),
                 "heatmap": heatmap.astype(np.int32),
-                "valid_mask": valid_mask.astype(np.int32)
+                "valid_mask": valid_mask.astype(np.int32),
             }
-        
+
         return outputs
 
     def __len__(self):
         return self.dataset_length
 
     def get_data_from_key(self, file_key):
-        """ Get data from file_key. """
+        """Get data from file_key."""
         # Check key exists
         if not file_key in self.filename_dataset.keys():
             raise ValueError("[Error] the specified key is not in the dataset.")
-        
+
         # Get the data paths
         data_path = self.filename_dataset[file_key]
         # Read in the image and npz labels (but haven't applied any transform)
@@ -923,12 +982,12 @@ class WireframeDataset(Dataset):
             data = self.train_preprocessing(data, numpy=True)
         else:
             data = self.test_preprocessing(data, numpy=True)
-        
+
         # Add file key to the output
         data["file_key"] = file_key
-        
+
         return data
-    
+
     def __getitem__(self, idx):
         """Return data
         file_key: str, keys used to retrieve data from the filename dataset.
@@ -951,30 +1010,27 @@ class WireframeDataset(Dataset):
         if not self.gt_source == "official":
             with h5py.File(self.gt_source, "r") as f:
                 exported_label = parse_h5_data(f[file_key])
-            
+
             data["junctions"] = exported_label["junctions"]
             data["line_map"] = exported_label["line_map"]
-        
+
         # Perform transform and augmentation
         return_type = self.config.get("return_type", "single")
-        if (self.mode == "train"
-            or self.config["add_augmentation_to_all_splits"]):
+        if self.mode == "train" or self.config["add_augmentation_to_all_splits"]:
             # Perform random scaling first
             if self.config["augmentation"]["random_scaling"]["enable"]:
                 scale_range = self.config["augmentation"]["random_scaling"]["range"]
                 # Decide the scaling
                 scale = np.random.uniform(min(scale_range), max(scale_range))
             else:
-                scale = 1.
+                scale = 1.0
             if self.gt_source == "official":
                 data = self.train_preprocessing(data)
             else:
                 if return_type == "paired_desc":
-                    data = self.preprocessing_exported_paired_desc(
-                        data, scale=scale)
+                    data = self.preprocessing_exported_paired_desc(data, scale=scale)
                 else:
-                    data = self.train_preprocessing_exported(data,
-                                                             scale=scale)
+                    data = self.train_preprocessing_exported(data, scale=scale)
         else:
             if self.gt_source == "official":
                 data = self.test_preprocessing(data)
@@ -982,17 +1038,17 @@ class WireframeDataset(Dataset):
                 data = self.preprocessing_exported_paired_desc(data)
             else:
                 data = self.test_preprocessing_exported(data)
-        
+
         # Add file key to the output
         data["file_key"] = file_key
-        
+
         return data
-    
+
     ########################
     ## Some other methods ##
     ########################
     def _check_dataset_cache(self):
-        """ Check if dataset cache exists. """
+        """Check if dataset cache exists."""
         cache_file_path = os.path.join(self.cache_path, self.cache_name)
         if os.path.exists(cache_file_path):
             return True
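
One detail worth noting from the get_line_points hunks in this file: np.linspace accepts array-valued start and stop arguments, so n points can be sampled on every row of a (num_lines, 2, 2) segment array in a single call. A self-contained sketch with made-up coordinates:

import numpy as np

# (num_lines, 2, 2) array of segment endpoint coordinates; values are made up.
line_segments = np.array([[[0.0, 0.0], [10.0, 0.0]],
                          [[0.0, 0.0], [0.0, 4.0]]])
n = 3  # points sampled per segment

# One linspace call per coordinate, vectorised over all segments.
line_points_x = np.linspace(line_segments[:, 0, 0], line_segments[:, 1, 0], n, axis=-1)  # (2, 3)
line_points_y = np.linspace(line_segments[:, 0, 1], line_segments[:, 1, 1], n, axis=-1)  # (2, 3)
points = np.stack([line_points_x.flatten(), line_points_y.flatten()], axis=-1)           # (6, 2)
print(points)
# [[ 0.  0.]
#  [ 5.  0.]
#  [10.  0.]
#  [ 0.  0.]
#  [ 0.  2.]
#  [ 0.  4.]]
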
diff --git a/third_party/SOLD2/sold2/experiment.py b/third_party/SOLD2/sold2/experiment.py
index 3bf4db1c9f148b9e33c6d7d0ba973375cd770a14..0a2d5c0dc359cec13304813ac7732c5968d70a80 100644
--- a/third_party/SOLD2/sold2/experiment.py
+++ b/third_party/SOLD2/sold2/experiment.py
@@ -19,7 +19,7 @@ torch.backends.cudnn.benchmark = True
 
 
 def load_config(config_path):
-    """ Load configurations from a given yaml file. """
+    """Load configurations from a given yaml file."""
     # Check file exists
     if not os.path.exists(config_path):
         raise ValueError("[Error] The provided config path is not valid.")
@@ -32,7 +32,7 @@ def load_config(config_path):
 
 
 def update_config(path, model_cfg=None, dataset_cfg=None):
-    """ Update configuration file from the resume path. """
+    """Update configuration file from the resume path."""
     # Check we need to update or completely override.
     model_cfg = {} if model_cfg is None else model_cfg
     dataset_cfg = {} if dataset_cfg is None else dataset_cfg
@@ -57,23 +57,23 @@ def update_config(path, model_cfg=None, dataset_cfg=None):
 
 
 def record_config(model_cfg, dataset_cfg, output_path):
-    """ Record dataset config to the log path. """
+    """Record dataset config to the log path."""
     # Record model config
     with open(os.path.join(output_path, "model_cfg.yaml"), "w") as f:
-            yaml.safe_dump(model_cfg, f)
-    
+        yaml.safe_dump(model_cfg, f)
+
     # Record dataset config
     with open(os.path.join(output_path, "dataset_cfg.yaml"), "w") as f:
-            yaml.safe_dump(dataset_cfg, f)
-    
+        yaml.safe_dump(dataset_cfg, f)
+
 
 def train(args, dataset_cfg, model_cfg, output_path):
-    """ Training function. """
+    """Training function."""
     # Update model config from the resume path (only in resume mode)
     if args.resume:
         if os.path.realpath(output_path) != os.path.realpath(args.resume_path):
             record_config(model_cfg, dataset_cfg, output_path)
-        
+
     # First time, then write the config file to the output path
     else:
         record_config(model_cfg, dataset_cfg, output_path)
@@ -82,23 +82,32 @@ def train(args, dataset_cfg, model_cfg, output_path):
     train_net(args, dataset_cfg, model_cfg, output_path)
 
 
-def export(args, dataset_cfg, model_cfg, output_path,
-           export_dataset_mode=None, device=torch.device("cuda")):
-    """ Export function. """
+def export(
+    args,
+    dataset_cfg,
+    model_cfg,
+    output_path,
+    export_dataset_mode=None,
+    device=torch.device("cuda"),
+):
+    """Export function."""
     # Choose between normal predictions export or homography adaptation
     if dataset_cfg.get("homography_adaptation") is not None:
         print("[Info] Export predictions with homography adaptation.")
-        export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path,
-                                    export_dataset_mode, device)
+        export_homograpy_adaptation(
+            args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device
+        )
     else:
         print("[Info] Export predictions normally.")
-        export_predictions(args, dataset_cfg, model_cfg, output_path,
-                           export_dataset_mode)
+        export_predictions(
+            args, dataset_cfg, model_cfg, output_path, export_dataset_mode
+        )
 
 
-def main(args, dataset_cfg, model_cfg, export_dataset_mode=None,
-         device=torch.device("cuda")):
-    """ Main function. """
+def main(
+    args, dataset_cfg, model_cfg, export_dataset_mode=None, device=torch.device("cuda")
+):
+    """Main function."""
     # Make the output path
     output_path = os.path.join(cfg.EXP_PATH, args.exp_name)
 
@@ -113,7 +122,14 @@ def main(args, dataset_cfg, model_cfg, export_dataset_mode=None,
         output_path = os.path.join(cfg.export_dataroot, args.exp_name)
         print("[Info] Export mode")
         print("\t Output path: %s" % output_path)
-        export(args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device=device)
+        export(
+            args,
+            dataset_cfg,
+            model_cfg,
+            output_path,
+            export_dataset_mode,
+            device=device,
+        )
     else:
         raise ValueError("[Error]: Unknown mode: " + args.mode)
 
@@ -126,28 +142,43 @@ def set_random_seed(seed):
 if __name__ == "__main__":
     # Parse input arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument("--mode", type=str, default="train",
-                        help="'train' or 'export'.")
-    parser.add_argument("--dataset_config", type=str, default=None,
-                        help="Path to the dataset config.")
-    parser.add_argument("--model_config", type=str, default=None,
-                        help="Path to the model config.")
-    parser.add_argument("--exp_name", type=str, default="exp",
-                        help="Experiment name.")
-    parser.add_argument("--resume", action="store_true", default=False,
-                        help="Load a previously trained model.")
-    parser.add_argument("--pretrained", action="store_true", default=False,
-                        help="Start training from a pre-trained model.")
-    parser.add_argument("--resume_path", default=None,
-                        help="Path from which to resume training.")
-    parser.add_argument("--pretrained_path", default=None,
-                        help="Path to the pre-trained model.")
-    parser.add_argument("--checkpoint_name", default=None,
-                        help="Name of the checkpoint to use.")
-    parser.add_argument("--export_dataset_mode", default=None,
-                        help="'train' or 'test'.")
-    parser.add_argument("--export_batch_size", default=4, type=int,
-                        help="Export batch size.")
+    parser.add_argument(
+        "--mode", type=str, default="train", help="'train' or 'export'."
+    )
+    parser.add_argument(
+        "--dataset_config", type=str, default=None, help="Path to the dataset config."
+    )
+    parser.add_argument(
+        "--model_config", type=str, default=None, help="Path to the model config."
+    )
+    parser.add_argument("--exp_name", type=str, default="exp", help="Experiment name.")
+    parser.add_argument(
+        "--resume",
+        action="store_true",
+        default=False,
+        help="Load a previously trained model.",
+    )
+    parser.add_argument(
+        "--pretrained",
+        action="store_true",
+        default=False,
+        help="Start training from a pre-trained model.",
+    )
+    parser.add_argument(
+        "--resume_path", default=None, help="Path from which to resume training."
+    )
+    parser.add_argument(
+        "--pretrained_path", default=None, help="Path to the pre-trained model."
+    )
+    parser.add_argument(
+        "--checkpoint_name", default=None, help="Name of the checkpoint to use."
+    )
+    parser.add_argument(
+        "--export_dataset_mode", default=None, help="'train' or 'test'."
+    )
+    parser.add_argument(
+        "--export_batch_size", default=4, type=int, help="Export batch size."
+    )
 
     args = parser.parse_args()
 
@@ -159,28 +190,29 @@ if __name__ == "__main__":
         device = torch.device("cpu")
 
     # Check if dataset config and model config is given.
-    if (((args.dataset_config is None) or (args.model_config is None))
-        and (not args.resume) and (args.mode == "train")):
+    if (
+        ((args.dataset_config is None) or (args.model_config is None))
+        and (not args.resume)
+        and (args.mode == "train")
+    ):
         raise ValueError(
-            "[Error] The dataset config and model config should be given in non-resume mode")
+            "[Error] The dataset config and model config should be given in non-resume mode"
+        )
 
     # If resume, check if the resume path has been given
     if args.resume and (args.resume_path is None):
-        raise ValueError(
-            "[Error] Missing resume path.")
+        raise ValueError("[Error] Missing resume path.")
 
     # [Training] Load the config file.
     if args.mode == "train" and (not args.resume):
         # Check the pretrained checkpoint_path exists
         if args.pretrained:
             checkpoint_folder = args.resume_path
-            checkpoint_path = os.path.join(args.pretrained_path,
-                                           args.checkpoint_name)
+            checkpoint_path = os.path.join(args.pretrained_path, args.checkpoint_name)
             if not os.path.exists(checkpoint_path):
-                raise ValueError("[Error] Missing checkpoint: "
-                                 + checkpoint_path)
+                raise ValueError("[Error] Missing checkpoint: " + checkpoint_path)
         dataset_cfg = load_config(args.dataset_config)
-        model_cfg = load_config(args.model_config)       
+        model_cfg = load_config(args.model_config)
 
     # [resume Training, Test, Export] Load the config file.
     elif (args.mode == "train" and args.resume) or (args.mode == "export"):
@@ -195,33 +227,35 @@ if __name__ == "__main__":
             print("[Info] No model config provided. Loading from checkpoint folder.")
             model_cfg_path = os.path.join(checkpoint_folder, "model_cfg.yaml")
             if not os.path.exists(model_cfg_path):
-                raise ValueError(
-                    "[Error] Missing model config in checkpoint path.")
+                raise ValueError("[Error] Missing model config in checkpoint path.")
             model_cfg = load_config(model_cfg_path)
         else:
             model_cfg = load_config(args.model_config)
-        
+
         # Load dataset_cfg from checkpoint folder if not provided
         if args.dataset_config is None:
             print("[Info] No dataset config provided. Loading from checkpoint folder.")
-            dataset_cfg_path = os.path.join(checkpoint_folder,
-                                            "dataset_cfg.yaml")
+            dataset_cfg_path = os.path.join(checkpoint_folder, "dataset_cfg.yaml")
             if not os.path.exists(dataset_cfg_path):
-                raise ValueError(
-                    "[Error] Missing dataset config in checkpoint path.")
+                raise ValueError("[Error] Missing dataset config in checkpoint path.")
             dataset_cfg = load_config(dataset_cfg_path)
         else:
             dataset_cfg = load_config(args.dataset_config)
-        
+
         # Check the --export_dataset_mode flag
         if (args.mode == "export") and (args.export_dataset_mode is None):
             raise ValueError("[Error] Empty --export_dataset_mode flag.")
     else:
         raise ValueError("[Error] Unknown mode: " + args.mode)
-    
+
     # Set the random seed
     seed = dataset_cfg.get("random_seed", 0)
     set_random_seed(seed)
 
-    main(args, dataset_cfg, model_cfg,
-         export_dataset_mode=args.export_dataset_mode, device=device)
+    main(
+        args,
+        dataset_cfg,
+        model_cfg,
+        export_dataset_mode=args.export_dataset_mode,
+        device=device,
+    )
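
Before the export.py hunks, a short sketch of the config-resolution behaviour that the experiment.py changes above reformat: a --model_config or --dataset_config path given on the command line wins, otherwise the YAML recorded in the checkpoint folder by record_config is loaded. The helper name resolve_config is hypothetical and only illustrates the pattern.

import os
import yaml

def resolve_config(cli_path, checkpoint_folder, filename):
    # Prefer the path given on the command line; otherwise fall back to the
    # config recorded next to the checkpoint by record_config().
    path = cli_path if cli_path is not None else os.path.join(checkpoint_folder, filename)
    if not os.path.exists(path):
        raise ValueError("[Error] Missing config file: %s" % path)
    with open(path, "r") as f:
        return yaml.safe_load(f)

# model_cfg = resolve_config(args.model_config, checkpoint_folder, "model_cfg.yaml")
# dataset_cfg = resolve_config(args.dataset_config, checkpoint_folder, "dataset_cfg.yaml")
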
diff --git a/third_party/SOLD2/sold2/export.py b/third_party/SOLD2/sold2/export.py
index 19683d982c6d7fd429b27868b620fd20562d1aa7..ec5bf2dcb1c51999c80b6d1ff170c238883e34a0 100644
--- a/third_party/SOLD2/sold2/export.py
+++ b/third_party/SOLD2/sold2/export.py
@@ -17,7 +17,7 @@ from .dataset.transforms.homographic_transforms import sample_homography
 
 
 def restore_weights(model, state_dict):
-    """ Restore weights in compatible mode. """
+    """Restore weights in compatible mode."""
     # Try to directly load state dict
     try:
         model.load_state_dict(state_dict)
@@ -38,15 +38,14 @@ def restore_weights(model, state_dict):
 
 
 def get_padded_filename(num_pad, idx):
-    """ Get the filename padded with 0. """
+    """Get the filename padded with 0."""
     file_len = len("%d" % (idx))
     filename = "0" * (num_pad - file_len) + "%d" % (idx)
     return filename
 
 
-def export_predictions(args, dataset_cfg, model_cfg, output_path,
-                       export_dataset_mode):
-    """ Export predictions. """
+def export_predictions(args, dataset_cfg, model_cfg, output_path, export_dataset_mode):
+    """Export predictions."""
     # Get the test configuration
     test_cfg = model_cfg["test"]
 
@@ -54,10 +53,14 @@ def export_predictions(args, dataset_cfg, model_cfg, output_path,
     print("\t Initializing dataset and dataloader")
     batch_size = 4
     export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
-    export_loader = DataLoader(export_dataset, batch_size=batch_size,
-                               num_workers=test_cfg.get("num_workers", 4),
-                               shuffle=False, pin_memory=False,
-                               collate_fn=collate_fn)
+    export_loader = DataLoader(
+        export_dataset,
+        batch_size=batch_size,
+        num_workers=test_cfg.get("num_workers", 4),
+        shuffle=False,
+        pin_memory=False,
+        collate_fn=collate_fn,
+    )
     print("\t Successfully intialized dataset and dataloader.")
 
     # Initialize model and load the checkpoint
@@ -87,11 +90,18 @@ def export_predictions(args, dataset_cfg, model_cfg, output_path,
 
             # Convert predictions
             junc_np = convert_junc_predictions(
-                outputs["junctions"], model_cfg["grid_size"],
-                model_cfg["detection_thresh"], 300)
+                outputs["junctions"],
+                model_cfg["grid_size"],
+                model_cfg["detection_thresh"],
+                300,
+            )
             junc_map_np = junc_map.numpy().transpose(0, 2, 3, 1)
-            heatmap_np = softmax(outputs["heatmap"].detach(),
-                                 dim=1).cpu().numpy().transpose(0, 2, 3, 1)
+            heatmap_np = (
+                softmax(outputs["heatmap"].detach(), dim=1)
+                .cpu()
+                .numpy()
+                .transpose(0, 2, 3, 1)
+            )
             heatmap_gt_np = heatmap.numpy().transpose(0, 2, 3, 1)
             valid_mask_np = valid_mask.numpy().transpose(0, 2, 3, 1)
 
@@ -99,15 +109,22 @@ def export_predictions(args, dataset_cfg, model_cfg, output_path,
             current_batch_size = input_images.shape[0]
             for batch_idx in range(current_batch_size):
                 output_data = {
-                    "image": input_images.cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
+                    "image": input_images.cpu()
+                    .numpy()
+                    .transpose(0, 2, 3, 1)[batch_idx],
                     "junc_gt": junc_map_np[batch_idx],
                     "junc_pred": junc_np["junc_pred"][batch_idx],
-                    "junc_pred_nms": junc_np["junc_pred_nms"][batch_idx].astype(np.float32),
+                    "junc_pred_nms": junc_np["junc_pred_nms"][batch_idx].astype(
+                        np.float32
+                    ),
                     "heatmap_gt": heatmap_gt_np[batch_idx],
                     "heatmap_pred": heatmap_np[batch_idx],
                     "valid_mask": valid_mask_np[batch_idx],
-                    "junc_points": data["junctions"][batch_idx].numpy()[0].round().astype(np.int32),
-                    "line_map": data["line_map"][batch_idx].numpy()[0].astype(np.int32)
+                    "junc_points": data["junctions"][batch_idx]
+                    .numpy()[0]
+                    .round()
+                    .astype(np.int32),
+                    "line_map": data["line_map"][batch_idx].numpy()[0].astype(np.int32),
                 }
 
                 # Save data to h5 dataset
@@ -117,19 +134,18 @@ def export_predictions(args, dataset_cfg, model_cfg, output_path,
 
                 # Store data
                 for key, output_data in output_data.items():
-                    f_group.create_dataset(key, data=output_data,
-                                           compression="gzip")
+                    f_group.create_dataset(key, data=output_data, compression="gzip")
                 filename_idx += 1
 
 
-def export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path,
-                                export_dataset_mode, device):
-    """ Export homography adaptation results. """
+def export_homograpy_adaptation(
+    args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device
+):
+    """Export homography adaptation results."""
     # Check if the export_dataset_mode is supported
     supported_modes = ["train", "test"]
     if not export_dataset_mode in supported_modes:
-        raise ValueError(
-            "[Error] The specified export_dataset_mode is not supported.")
+        raise ValueError("[Error] The specified export_dataset_mode is not supported.")
 
     # Get the test configuration
     test_cfg = model_cfg["test"]
@@ -137,66 +153,87 @@ def export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path,
     # Get the homography adaptation configurations
     homography_cfg = dataset_cfg.get("homography_adaptation", None)
     if homography_cfg is None:
-        raise ValueError(
-            "[Error] Empty homography_adaptation entry in config.")
+        raise ValueError("[Error] Empty homography_adaptation entry in config.")
 
     # Create the dataset and dataloader based on the export_dataset_mode
     print("\t Initializing dataset and dataloader")
     batch_size = args.export_batch_size
 
     export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
-    export_loader = DataLoader(export_dataset, batch_size=batch_size,
-                               num_workers=test_cfg.get("num_workers", 4),
-                               shuffle=False, pin_memory=False,
-                               collate_fn=collate_fn)
+    export_loader = DataLoader(
+        export_dataset,
+        batch_size=batch_size,
+        num_workers=test_cfg.get("num_workers", 4),
+        shuffle=False,
+        pin_memory=False,
+        collate_fn=collate_fn,
+    )
     print("\t Successfully intialized dataset and dataloader.")
 
     # Initialize model and load the checkpoint
     model = get_model(model_cfg, mode="test")
-    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name,
-                                       device)
+    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name, device)
     model = restore_weights(model, checkpoint["model_state_dict"])
     model = model.to(device).eval()
     print("\t Successfully initialized model")
 
     # Start the export process
-    print("[Info] Start exporting predictions")    
+    print("[Info] Start exporting predictions")
     output_dataset_path = output_path + ".h5"
     with h5py.File(output_dataset_path, "w", libver="latest") as f:
-        f.swmr_mode=True
+        f.swmr_mode = True
         for _, data in enumerate(tqdm(export_loader, ascii=True)):
             input_images = data["image"].to(device)
             file_keys = data["file_key"]
             batch_size = input_images.shape[0]
-            
+
             # Run the homography adaptation
-            outputs = homography_adaptation(input_images, model,
-                                            model_cfg["grid_size"],
-                                            homography_cfg)
+            outputs = homography_adaptation(
+                input_images, model, model_cfg["grid_size"], homography_cfg
+            )
 
             # Save the entries
             for batch_idx in range(batch_size):
                 # Get the save key
                 save_key = file_keys[batch_idx]
                 output_data = {
-                    "image": input_images.cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
-                    "junc_prob_mean": outputs["junc_probs_mean"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
-                    "junc_prob_max": outputs["junc_probs_max"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
-                    "junc_count": outputs["junc_counts"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
-                    "heatmap_prob_mean": outputs["heatmap_probs_mean"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
-                    "heatmap_prob_max": outputs["heatmap_probs_max"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
-                    "heatmap_cout": outputs["heatmap_counts"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx]
+                    "image": input_images.cpu()
+                    .numpy()
+                    .transpose(0, 2, 3, 1)[batch_idx],
+                    "junc_prob_mean": outputs["junc_probs_mean"]
+                    .cpu()
+                    .numpy()
+                    .transpose(0, 2, 3, 1)[batch_idx],
+                    "junc_prob_max": outputs["junc_probs_max"]
+                    .cpu()
+                    .numpy()
+                    .transpose(0, 2, 3, 1)[batch_idx],
+                    "junc_count": outputs["junc_counts"]
+                    .cpu()
+                    .numpy()
+                    .transpose(0, 2, 3, 1)[batch_idx],
+                    "heatmap_prob_mean": outputs["heatmap_probs_mean"]
+                    .cpu()
+                    .numpy()
+                    .transpose(0, 2, 3, 1)[batch_idx],
+                    "heatmap_prob_max": outputs["heatmap_probs_max"]
+                    .cpu()
+                    .numpy()
+                    .transpose(0, 2, 3, 1)[batch_idx],
+                    "heatmap_cout": outputs["heatmap_counts"]
+                    .cpu()
+                    .numpy()
+                    .transpose(0, 2, 3, 1)[batch_idx],
                 }
 
                 # Create group and write data
                 f_group = f.create_group(save_key)
                 for key, output_data in output_data.items():
-                    f_group.create_dataset(key, data=output_data,
-                                           compression="gzip")
+                    f_group.create_dataset(key, data=output_data, compression="gzip")
 
 
 def homography_adaptation(input_images, model, grid_size, homography_cfg):
-    """ The homography adaptation process.
+    """The homography adaptation process.
     Arguments:
         input_images: The images to be evaluated.
         model: The pytorch model in evaluation mode.
@@ -222,121 +259,140 @@ def homography_adaptation(input_images, model, grid_size, homography_cfg):
     for idx in range(num_iter):
         if idx <= num_iter // 5:
             # Ensure that 20% of the homographies have no artifact
-            H_mat_lst = [sample_homography(
-                [H,W], **homography_cfg_no_artifacts)[0][None]
-                         for _ in range(batch_size)]
+            H_mat_lst = [
+                sample_homography([H, W], **homography_cfg_no_artifacts)[0][None]
+                for _ in range(batch_size)
+            ]
         else:
-            H_mat_lst = [sample_homography(
-                [H,W], **homography_cfg["homographies"])[0][None]
-                         for _ in range(batch_size)]
+            H_mat_lst = [
+                sample_homography([H, W], **homography_cfg["homographies"])[0][None]
+                for _ in range(batch_size)
+            ]
 
         H_mats = np.concatenate(H_mat_lst, axis=0)
         H_tensor = torch.tensor(H_mats, dtype=torch.float, device=device)
         H_inv_tensor = torch.inverse(H_tensor)
 
         # Perform the homography warp
-        images_warped = warp_perspective(input_images, H_tensor, (H, W),
-                                         flags="bilinear")
-        
+        images_warped = warp_perspective(
+            input_images, H_tensor, (H, W), flags="bilinear"
+        )
+
         # Warp the mask
         masks_junc_warped = warp_perspective(
             torch.ones([batch_size, 1, H, W], device=device),
-            H_tensor, (H, W), flags="nearest")
+            H_tensor,
+            (H, W),
+            flags="nearest",
+        )
         masks_heatmap_warped = warp_perspective(
             torch.ones([batch_size, 1, H, W], device=device),
-            H_tensor, (H, W), flags="nearest")
+            H_tensor,
+            (H, W),
+            flags="nearest",
+        )
 
         # Run the network forward pass
         with torch.no_grad():
             outputs = model(images_warped)
-        
+
         # Unwarp and mask the junction prediction
-        junc_prob_warped = pixel_shuffle(softmax(
-            outputs["junctions"], dim=1)[:, :-1, :, :], grid_size)
-        junc_prob = warp_perspective(junc_prob_warped, H_inv_tensor,
-                                     (H, W), flags="bilinear")
+        junc_prob_warped = pixel_shuffle(
+            softmax(outputs["junctions"], dim=1)[:, :-1, :, :], grid_size
+        )
+        junc_prob = warp_perspective(
+            junc_prob_warped, H_inv_tensor, (H, W), flags="bilinear"
+        )
 
         # Create the out of boundary mask
         out_boundary_mask = warp_perspective(
             torch.ones([batch_size, 1, H, W], device=device),
-            H_inv_tensor, (H, W), flags="nearest")
+            H_inv_tensor,
+            (H, W),
+            flags="nearest",
+        )
         out_boundary_mask = adjust_border(out_boundary_mask, device, margin)
 
         junc_prob = junc_prob * out_boundary_mask
-        junc_count = warp_perspective(masks_junc_warped * out_boundary_mask,
-                                      H_inv_tensor, (H, W), flags="nearest")
+        junc_count = warp_perspective(
+            masks_junc_warped * out_boundary_mask, H_inv_tensor, (H, W), flags="nearest"
+        )
 
         # Unwarp the mask and heatmap prediction
         # Always fetch only one channel
         if outputs["heatmap"].shape[1] == 2:
             # Convert to single channel directly from here
-            heatmap_prob_warped = softmax(outputs["heatmap"],
-                                          dim=1)[:, 1:, :, :]
+            heatmap_prob_warped = softmax(outputs["heatmap"], dim=1)[:, 1:, :, :]
         else:
             heatmap_prob_warped = torch.sigmoid(outputs["heatmap"])
-        
+
         heatmap_prob_warped = heatmap_prob_warped * masks_heatmap_warped
-        heatmap_prob = warp_perspective(heatmap_prob_warped, H_inv_tensor,
-                                        (H, W), flags="bilinear")
-        heatmap_count = warp_perspective(masks_heatmap_warped, H_inv_tensor,
-                                         (H, W), flags="nearest")
+        heatmap_prob = warp_perspective(
+            heatmap_prob_warped, H_inv_tensor, (H, W), flags="bilinear"
+        )
+        heatmap_count = warp_perspective(
+            masks_heatmap_warped, H_inv_tensor, (H, W), flags="nearest"
+        )
 
         # Record the results
-        junc_probs[:, idx:idx+1, :, :] = junc_prob
-        heatmap_probs[:, idx:idx+1, :, :] = heatmap_prob
+        junc_probs[:, idx : idx + 1, :, :] = junc_prob
+        heatmap_probs[:, idx : idx + 1, :, :] = heatmap_prob
         junc_counts += junc_count
         heatmap_counts += heatmap_count
 
     # Perform the accumulation operation
     if homography_cfg["min_counts"] > 0:
         min_counts = homography_cfg["min_counts"]
-        junc_count_mask = (junc_counts < min_counts)
-        heatmap_count_mask = (heatmap_counts < min_counts)
+        junc_count_mask = junc_counts < min_counts
+        heatmap_count_mask = heatmap_counts < min_counts
         junc_counts[junc_count_mask] = 0
         heatmap_counts[heatmap_count_mask] = 0
     else:
         junc_count_mask = np.zeros_like(junc_counts, dtype=bool)
         heatmap_count_mask = np.zeros_like(heatmap_counts, dtype=bool)
-    
+
     # Compute the mean accumulation
     junc_probs_mean = torch.sum(junc_probs, dim=1, keepdim=True) / junc_counts
-    junc_probs_mean[junc_count_mask] = 0.
-    heatmap_probs_mean = (torch.sum(heatmap_probs, dim=1, keepdim=True)
-                          / heatmap_counts)
-    heatmap_probs_mean[heatmap_count_mask] = 0.
+    junc_probs_mean[junc_count_mask] = 0.0
+    heatmap_probs_mean = torch.sum(heatmap_probs, dim=1, keepdim=True) / heatmap_counts
+    heatmap_probs_mean[heatmap_count_mask] = 0.0
 
     # Compute the max accumulation
     junc_probs_max = torch.max(junc_probs, dim=1, keepdim=True)[0]
-    junc_probs_max[junc_count_mask] = 0.
+    junc_probs_max[junc_count_mask] = 0.0
     heatmap_probs_max = torch.max(heatmap_probs, dim=1, keepdim=True)[0]
-    heatmap_probs_max[heatmap_count_mask] = 0.
+    heatmap_probs_max[heatmap_count_mask] = 0.0
 
-    return {"junc_probs_mean": junc_probs_mean,
-            "junc_probs_max": junc_probs_max,
-            "junc_counts": junc_counts,
-            "heatmap_probs_mean": heatmap_probs_mean,
-            "heatmap_probs_max": heatmap_probs_max,
-            "heatmap_counts": heatmap_counts}
+    return {
+        "junc_probs_mean": junc_probs_mean,
+        "junc_probs_max": junc_probs_max,
+        "junc_counts": junc_counts,
+        "heatmap_probs_mean": heatmap_probs_mean,
+        "heatmap_probs_max": heatmap_probs_max,
+        "heatmap_counts": heatmap_counts,
+    }
 
 
 def adjust_border(input_masks, device, margin=3):
-    """ Adjust the border of the counts and valid_mask. """
+    """Adjust the border of the counts and valid_mask."""
     # Convert the mask to numpy array
     dtype = input_masks.dtype
     input_masks = np.squeeze(input_masks.cpu().numpy(), axis=1)
 
-    erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
-                                               (margin*2, margin*2))
+    erosion_kernel = cv2.getStructuringElement(
+        cv2.MORPH_ELLIPSE, (margin * 2, margin * 2)
+    )
     batch_size = input_masks.shape[0]
-    
+
     output_mask_lst = []
     # Erode all the masks
     for i in range(batch_size):
         output_mask = cv2.erode(input_masks[i, ...], erosion_kernel)
 
         output_mask_lst.append(
-            torch.tensor(output_mask, dtype=dtype, device=device)[None])
-    
+            torch.tensor(output_mask, dtype=dtype, device=device)[None]
+        )
+
     # Concat back along the batch dimension.
     output_masks = torch.cat(output_mask_lst, dim=0)
     return output_masks.unsqueeze(dim=1)
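
For orientation, `homography_adaptation` above accumulates per-warp probability maps together with per-pixel visibility counts, then reports mean and max maps and zeroes pixels observed fewer than `min_counts` times. A self-contained toy sketch of just that aggregation step (the shapes and the clamp on the divisor are illustrative, not the exact SOLD2 code):

import torch


def aggregate(probs, counts, min_counts=2):
    # probs: (B, num_iter, H, W) per-warp probabilities
    # counts: (B, 1, H, W) how many warps actually observed each pixel
    rare = counts < min_counts
    probs_mean = probs.sum(dim=1, keepdim=True) / counts.clamp(min=1)
    probs_max = probs.max(dim=1, keepdim=True)[0]
    probs_mean[rare] = 0.0
    probs_max[rare] = 0.0
    return probs_mean, probs_max


probs = torch.rand(1, 5, 8, 8)
counts = torch.randint(0, 6, (1, 1, 8, 8)).float()
mean_map, max_map = aggregate(probs, counts)
print(mean_map.shape, max_map.shape)  # torch.Size([1, 1, 8, 8]) twice
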
diff --git a/third_party/SOLD2/sold2/export_line_features.py b/third_party/SOLD2/sold2/export_line_features.py
index 4cbde860a446d758dff254ea5320ca13bb79e6b7..6df203c6ad62a559a1617744b200df283b9bb9a7 100644
--- a/third_party/SOLD2/sold2/export_line_features.py
+++ b/third_party/SOLD2/sold2/export_line_features.py
@@ -12,24 +12,29 @@ from .experiment import load_config
 from .model.line_matcher import LineMatcher
 
 
-def export_descriptors(images_list, ckpt_path, config, device, extension,
-                       output_folder, multiscale=False):
+def export_descriptors(
+    images_list, ckpt_path, config, device, extension, output_folder, multiscale=False
+):
     # Extract the image paths
-    with open(images_list, 'r') as f:
+    with open(images_list, "r") as f:
         image_files = f.readlines()
-    image_files = [path.strip('\n') for path in image_files]
+    image_files = [path.strip("\n") for path in image_files]
 
     # Initialize the line matcher
     line_matcher = LineMatcher(
-        config["model_cfg"], ckpt_path, device, config["line_detector_cfg"],
-        config["line_matcher_cfg"], multiscale)
+        config["model_cfg"],
+        ckpt_path,
+        device,
+        config["line_detector_cfg"],
+        config["line_matcher_cfg"],
+        multiscale,
+    )
     print("\t Successfully initialized model")
 
     # Run the inference on each image and write the output on disk
     for img_path in tqdm(image_files):
         img = cv2.imread(img_path, 0)
-        img = torch.tensor(img[None, None] / 255., dtype=torch.float,
-                           device=device)
+        img = torch.tensor(img[None, None] / 255.0, dtype=torch.float, device=device)
 
         # Run the line detection and description
         ref_detection = line_matcher.line_detection(img)
@@ -39,21 +44,29 @@ def export_descriptors(images_list, ckpt_path, config, device, extension,
         # Write the output on disk
         img_name = os.path.splitext(os.path.basename(img_path))[0]
         output_file = os.path.join(output_folder, img_name + extension)
-        np.savez_compressed(output_file, line_seg=ref_line_seg,
-                            descriptors=ref_descriptors)
+        np.savez_compressed(
+            output_file, line_seg=ref_line_seg, descriptors=ref_descriptors
+        )
 
 
 if __name__ == "__main__":
     # Parse input arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument("--img_list", type=str, required=True,
-                        help="List of input images in a text file.")
-    parser.add_argument("--output_folder", type=str, required=True,
-                        help="Path to the output folder.")
-    parser.add_argument("--config", type=str,
-                        default="config/export_line_features.yaml")
-    parser.add_argument("--checkpoint_path", type=str,
-                        default="pretrained_models/sold2_wireframe.tar")
+    parser.add_argument(
+        "--img_list",
+        type=str,
+        required=True,
+        help="List of input images in a text file.",
+    )
+    parser.add_argument(
+        "--output_folder", type=str, required=True, help="Path to the output folder."
+    )
+    parser.add_argument(
+        "--config", type=str, default="config/export_line_features.yaml"
+    )
+    parser.add_argument(
+        "--checkpoint_path", type=str, default="pretrained_models/sold2_wireframe.tar"
+    )
     parser.add_argument("--multiscale", action="store_true", default=False)
     parser.add_argument("--extension", type=str, default=None)
     args = parser.parse_args()
@@ -67,8 +80,15 @@ if __name__ == "__main__":
     # Get the model config, extension and checkpoint path
     config = load_config(args.config)
     ckpt_path = os.path.abspath(args.checkpoint_path)
-    extension = 'sold2' if args.extension is None else args.extension
+    extension = "sold2" if args.extension is None else args.extension
     extension = "." + extension
 
-    export_descriptors(args.img_list, ckpt_path, config, device, extension,
-                       args.output_folder, args.multiscale)
+    export_descriptors(
+        args.img_list,
+        ckpt_path,
+        config,
+        device,
+        extension,
+        args.output_folder,
+        args.multiscale,
+    )
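
Each image processed by `export_line_features.py` ends up as one compressed `.npz` with `line_seg` and `descriptors` arrays (NumPy appends the `.npz` suffix automatically). A hedged round-trip example with dummy arrays, since the exact shapes depend on the detector configuration:

import numpy as np

# Toy stand-ins for the real detector outputs.
line_seg = np.zeros((4, 2, 2), dtype=np.float32)    # assumed: N segments, 2 endpoints each
descriptors = np.zeros((4, 128), dtype=np.float32)  # assumed descriptor dimension

np.savez_compressed("example.sold2", line_seg=line_seg, descriptors=descriptors)

data = np.load("example.sold2.npz")
print(data["line_seg"].shape, data["descriptors"].shape)
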
diff --git a/third_party/SOLD2/sold2/misc/geometry_utils.py b/third_party/SOLD2/sold2/misc/geometry_utils.py
index 50f0478062cd19ebac812bff62b6c3a3d5f124c2..024430a07b9b094d2eca6e4e9e14edd5105ad1c5 100644
--- a/third_party/SOLD2/sold2/misc/geometry_utils.py
+++ b/third_party/SOLD2/sold2/misc/geometry_utils.py
@@ -7,8 +7,9 @@ import torch
 # Warp a list of points using a homography
 def warp_points(points, homography):
     # Convert to homogeneous and in xy format
-    new_points = np.concatenate([points[..., [1, 0]],
-                                 np.ones_like(points[..., :1])], axis=-1)
+    new_points = np.concatenate(
+        [points[..., [1, 0]], np.ones_like(points[..., :1])], axis=-1
+    )
     # Warp
     new_points = (homography @ new_points.T).T
     # Convert back to inhomogeneous and hw format
@@ -18,10 +19,12 @@ def warp_points(points, homography):
 
 # Mask out the points that are outside of img_size
 def mask_points(points, img_size):
-    mask = ((points[..., 0] >= 0)
-            & (points[..., 0] < img_size[0])
-            & (points[..., 1] >= 0)
-            & (points[..., 1] < img_size[1]))
+    mask = (
+        (points[..., 0] >= 0)
+        & (points[..., 0] < img_size[0])
+        & (points[..., 1] >= 0)
+        & (points[..., 1] < img_size[1])
+    )
     return mask
 
 
@@ -30,8 +33,12 @@ def mask_points(points, img_size):
 def keypoints_to_grid(keypoints, img_size):
     n_points = keypoints.size()[-2]
     device = keypoints.device
-    grid_points = keypoints.float() * 2. / torch.tensor(
-        img_size, dtype=torch.float, device=device) - 1.
+    grid_points = (
+        keypoints.float()
+        * 2.0
+        / torch.tensor(img_size, dtype=torch.float, device=device)
+        - 1.0
+    )
     grid_points = grid_points[..., [1, 0]].view(-1, n_points, 1, 2)
     return grid_points
 
@@ -44,8 +51,9 @@ def get_dist_mask(kp0, kp1, valid_mask, dist_thresh):
     dist_mask1 = torch.norm(kp1.unsqueeze(2) - kp1.unsqueeze(1), dim=-1)
     dist_mask = torch.min(dist_mask0, dist_mask1)
     dist_mask = dist_mask <= dist_thresh
-    dist_mask = dist_mask.repeat(1, 1, b_size).reshape(b_size * n_points,
-                                                       b_size * n_points)
+    dist_mask = dist_mask.repeat(1, 1, b_size).reshape(
+        b_size * n_points, b_size * n_points
+    )
     dist_mask = dist_mask[valid_mask, :][:, valid_mask]
     return dist_mask
 
@@ -75,7 +83,8 @@ def mask_lines(lines, valid_mask):
 def get_common_line_mask(line_indices, valid_mask):
     b_size, n_points = line_indices.shape
     common_mask = line_indices[:, :, None] == line_indices[:, None, :]
-    common_mask = common_mask.repeat(1, 1, b_size).reshape(b_size * n_points,
-                                                           b_size * n_points)
+    common_mask = common_mask.repeat(1, 1, b_size).reshape(
+        b_size * n_points, b_size * n_points
+    )
     common_mask = common_mask[valid_mask, :][:, valid_mask]
     return common_mask
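
Note that `geometry_utils.py` keeps points in (row, col) order: `warp_points` swaps to xy, warps in homogeneous coordinates, and swaps back, while `mask_points` drops points that leave the image. A standalone toy run of that pipeline (the two helpers are restated here so the snippet runs on its own):

import numpy as np


def warp_points(points, homography):
    # (row, col) -> homogeneous (x, y, 1), warp, then back to (row, col).
    pts = np.concatenate([points[..., [1, 0]], np.ones_like(points[..., :1])], axis=-1)
    pts = (homography @ pts.T).T
    return pts[..., [1, 0]] / pts[..., 2:]


def mask_points(points, img_size):
    return ((points[..., 0] >= 0) & (points[..., 0] < img_size[0])
            & (points[..., 1] >= 0) & (points[..., 1] < img_size[1]))


H_mat = np.array([[1.0, 0.0, 10.0], [0.0, 1.0, -5.0], [0.0, 0.0, 1.0]])  # shift x by +10, y by -5
pts = np.array([[2.0, 3.0], [30.0, 30.0]])  # (row, col)
warped = warp_points(pts, H_mat)
print(warped)                         # [[-3. 13.] [25. 40.]]
print(mask_points(warped, (64, 64)))  # [False  True]
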
diff --git a/third_party/SOLD2/sold2/misc/train_utils.py b/third_party/SOLD2/sold2/misc/train_utils.py
index d5ada35eea660df1f78b9f20d9bf7ed726eaee2c..99113247351ceef152f308e793234a952df78166 100644
--- a/third_party/SOLD2/sold2/misc/train_utils.py
+++ b/third_party/SOLD2/sold2/misc/train_utils.py
@@ -10,7 +10,7 @@ import torch
 ## image utils ##
 #################
 def convert_image(input_tensor, axis):
-    """ Convert single channel images to 3-channel images. """
+    """Convert single channel images to 3-channel images."""
     image_lst = [input_tensor for _ in range(3)]
     outputs = np.concatenate(image_lst, axis)
     return outputs
@@ -19,29 +19,32 @@ def convert_image(input_tensor, axis):
 ######################
 ## checkpoint utils ##
 ######################
-def get_latest_checkpoint(checkpoint_root, checkpoint_name,
-                          device=torch.device("cuda")):
-    """ Get the latest checkpoint or by filename. """
+def get_latest_checkpoint(
+    checkpoint_root, checkpoint_name, device=torch.device("cuda")
+):
+    """Get the latest checkpoint or by filename."""
     # Load specific checkpoint
     if checkpoint_name is not None:
         checkpoint = torch.load(
-            os.path.join(checkpoint_root, checkpoint_name),
-            map_location=device)
+            os.path.join(checkpoint_root, checkpoint_name), map_location=device
+        )
     # Load the latest checkpoint
     else:
-        lastest_checkpoint = sorted(os.listdir(os.path.join(
-            checkpoint_root, "*.tar")))[-1]
-        checkpoint = torch.load(os.path.join(
-            checkpoint_root, lastest_checkpoint), map_location=device)
+        latest_checkpoint = sorted(
+            [f for f in os.listdir(checkpoint_root) if f.endswith(".tar")]
+        )[-1]
+        checkpoint = torch.load(
+            os.path.join(checkpoint_root, latest_checkpoint), map_location=device
+        )
     return checkpoint
 
 
 def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
-    """ Remove the outdated checkpoints. """
+    """Remove the outdated checkpoints."""
     # Get sorted list of checkpoints
     checkpoint_list = sorted(
-        [_ for _ in os.listdir(os.path.join(checkpoint_root))
-         if _.endswith(".tar")])
+        [_ for _ in os.listdir(os.path.join(checkpoint_root)) if _.endswith(".tar")]
+    )
 
     # Get the checkpoints to be removed
     if len(checkpoint_list) > max_ckpt:
@@ -55,7 +58,7 @@ def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
 def adapt_checkpoint(state_dict):
     new_state_dict = {}
     for k, v in state_dict.items():
-        if k.startswith('module.'):
+        if k.startswith("module."):
             new_state_dict[k[7:]] = v
         else:
             new_state_dict[k] = v
@@ -66,9 +69,9 @@ def adapt_checkpoint(state_dict):
 ## HDF5 utils ##
 ################
 def parse_h5_data(h5_data):
-    """ Parse h5 dataset. """
+    """Parse h5 dataset."""
     output_data = {}
     for key in h5_data.keys():
         output_data[key] = np.array(h5_data[key])
-        
+
     return output_data
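
`adapt_checkpoint` above exists because `torch.nn.DataParallel` prefixes every parameter name with `module.`, so such a checkpoint cannot be loaded directly into a bare model. A tiny illustration of the same idea with plain dicts (no real checkpoint required):

def strip_module_prefix(state_dict):
    # Same idea as adapt_checkpoint: drop a leading "module." from each key.
    return {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state_dict.items()
    }


state = {"module.backbone.conv1.weight": 1, "head.bias": 2}
print(strip_module_prefix(state))
# {'backbone.conv1.weight': 1, 'head.bias': 2}
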
diff --git a/third_party/SOLD2/sold2/misc/visualize_util.py b/third_party/SOLD2/sold2/misc/visualize_util.py
index 4aa46877f79724221b7caa423de6916acdc021f8..2d1aa38bb992302fe504bc166a3fa113e5365337 100644
--- a/third_party/SOLD2/sold2/misc/visualize_util.py
+++ b/third_party/SOLD2/sold2/misc/visualize_util.py
@@ -20,15 +20,17 @@ def plot_junctions(input_image, junctions, junc_size=3, color=None):
     if image.dtype == np.uint8:
         pass
     # A float type image ranging from 0~1
-    elif image.dtype in [np.float32, np.float64, np.float]  and image.max() <= 2.:
-        image = (image * 255.).astype(np.uint8)
+    elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
+        image = (image * 255.0).astype(np.uint8)
     # A float type image ranging from 0.~255.
-    elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.:
+    elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
         image = image.astype(np.uint8)
     else:
-        raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.")
+        raise ValueError(
+            "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
+        )
 
-    # Check whether the image is single channel 
+    # Check whether the image is single channel
     if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
         # Squeeze to H*W first
         image = image.squeeze()
@@ -46,30 +48,38 @@ def plot_junctions(input_image, junctions, junc_size=3, color=None):
             junctions = junctions.T
         else:
             raise ValueError("[Error] At least one of the two dims should be 2.")
-    
+
     # Round and convert junctions to int (and check the boundary)
     H, W = image.shape[:2]
     junctions = (np.round(junctions)).astype(np.int)
-    junctions[junctions < 0] = 0 
-    junctions[junctions[:, 0] >= H, 0] = H-1  # (first dim) max bounded by H-1
-    junctions[junctions[:, 1] >= W, 1] = W-1  # (second dim) max bounded by W-1
+    junctions[junctions < 0] = 0
+    junctions[junctions[:, 0] >= H, 0] = H - 1  # (first dim) max bounded by H-1
+    junctions[junctions[:, 1] >= W, 1] = W - 1  # (second dim) max bounded by W-1
 
     # Iterate through all the junctions
     num_junc = junctions.shape[0]
     if color is None:
-        color = (0, 255., 0)
+        color = (0, 255.0, 0)
     for idx in range(num_junc):
         # Fetch one junction
         junc = junctions[idx, :]
-        cv2.circle(image, tuple(np.flip(junc)), radius=junc_size, 
-                    color=color, thickness=3)
-    
+        cv2.circle(
+            image, tuple(np.flip(junc)), radius=junc_size, color=color, thickness=3
+        )
+
     return image
 
 
 # Plot line segments given junctions and a line adjacency map
-def plot_line_segments(input_image, junctions, line_map, junc_size=3, 
-                       color=(0, 255., 0), line_width=1, plot_survived_junc=True):
+def plot_line_segments(
+    input_image,
+    junctions,
+    line_map,
+    junc_size=3,
+    color=(0, 255.0, 0),
+    line_width=1,
+    plot_survived_junc=True,
+):
     """
     input_image: can be 0~1 float or 0~255 uint8.
     junctions: Nx2 or 2xN np array.
@@ -85,15 +95,17 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3,
     if image.dtype == np.uint8:
         pass
     # A float type image ranging from 0~1
-    elif image.dtype in [np.float32, np.float64, np.float]  and image.max() <= 2.:
-        image = (image * 255.).astype(np.uint8)
+    elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
+        image = (image * 255.0).astype(np.uint8)
     # A float type image ranging from 0.~255.
-    elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.:
+    elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
         image = image.astype(np.uint8)
     else:
-        raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.")
+        raise ValueError(
+            "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
+        )
 
-    # Check whether the image is single channel 
+    # Check whether the image is single channel
     if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
         # Squeeze to H*W first
         image = image.squeeze()
@@ -111,7 +123,7 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3,
             junctions = junctions.T
         else:
             raise ValueError("[Error] At least one of the two dims should be 2.")
-    
+
     # line_map dimension should be 2
     if not len(line_map.shape) == 2:
         raise ValueError("[Error] line_map should be 2-dim array.")
@@ -122,8 +134,10 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3,
             raise ValueError("[Error] color should have type list or tuple.")
         else:
             if len(color) != 3:
-                raise ValueError("[Error] color should be a list or tuple with length 3.")
-    
+                raise ValueError(
+                    "[Error] color should be a list or tuple with length 3."
+                )
+
     # Make a copy of the line_map
     line_map_tmp = copy.copy(line_map)
 
@@ -136,14 +150,17 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3,
         # record the line segment
         else:
             for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
-                p1 = np.flip(junctions[idx, :])     # Convert to xy format
-                p2 = np.flip(junctions[idx2, :])    # Convert to xy format
-                segments = np.concatenate((segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]), axis=0)
-                
+                p1 = np.flip(junctions[idx, :])  # Convert to xy format
+                p2 = np.flip(junctions[idx2, :])  # Convert to xy format
+                segments = np.concatenate(
+                    (segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]),
+                    axis=0,
+                )
+
                 # Update line_map
                 line_map_tmp[idx, idx2] = 0
                 line_map_tmp[idx2, idx] = 0
-    
+
     # Draw segment pairs
     for idx in range(segments.shape[0]):
         seg = np.round(segments[idx, :]).astype(np.int)
@@ -151,8 +168,14 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3,
         if color != "random":
             color = tuple(color)
         else:
-            color = tuple(np.random.rand(3,))
-        cv2.line(image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width)
+            color = tuple(
+                np.random.rand(
+                    3,
+                )
+            )
+        cv2.line(
+            image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width
+        )
 
     # Also draw the junctions
     if not plot_survived_junc:
@@ -160,45 +183,63 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3,
         for idx in range(num_junc):
             # Fetch one junction
             junc = junctions[idx, :]
-            cv2.circle(image, tuple(np.flip(junc)), radius=junc_size, 
-                    color=(0, 255., 0), thickness=3) 
+            cv2.circle(
+                image,
+                tuple(np.flip(junc)),
+                radius=junc_size,
+                color=(0, 255.0, 0),
+                thickness=3,
+            )
     # Only plot the junctions which are part of a line segment
     else:
         for idx in range(segments.shape[0]):
-            seg = np.round(segments[idx, :]).astype(np.int) # Already in HW format.
-            cv2.circle(image, tuple(seg[:2]), radius=junc_size, 
-                    color=(0, 255., 0), thickness=3)
-            cv2.circle(image, tuple(seg[2:]), radius=junc_size, 
-                    color=(0, 255., 0), thickness=3)
-      
+            seg = np.round(segments[idx, :]).astype(np.int)  # Already in HW format.
+            cv2.circle(
+                image,
+                tuple(seg[:2]),
+                radius=junc_size,
+                color=(0, 255.0, 0),
+                thickness=3,
+            )
+            cv2.circle(
+                image,
+                tuple(seg[2:]),
+                radius=junc_size,
+                color=(0, 255.0, 0),
+                thickness=3,
+            )
+
     return image
 
 
 # Plot line segments given Nx4 or Nx2x2 line segments
-def plot_line_segments_from_segments(input_image, line_segments, junc_size=3, 
-                                     color=(0, 255., 0), line_width=1):
+def plot_line_segments_from_segments(
+    input_image, line_segments, junc_size=3, color=(0, 255.0, 0), line_width=1
+):
     # Create image copy
     image = copy.copy(input_image)
     # Make sure the image is converted to 255 uint8
     if image.dtype == np.uint8:
         pass
     # A float type image ranging from 0~1
-    elif image.dtype in [np.float32, np.float64, np.float]  and image.max() <= 2.:
-        image = (image * 255.).astype(np.uint8)
+    elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
+        image = (image * 255.0).astype(np.uint8)
     # A float type image ranging from 0.~255.
-    elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.:
+    elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
         image = image.astype(np.uint8)
     else:
-        raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.")
+        raise ValueError(
+            "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
+        )
 
-    # Check whether the image is single channel 
+    # Check whether the image is single channel
     if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
         # Squeeze to H*W first
         image = image.squeeze()
 
         # Stack to 3 channels
         image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
-    
+
     # Check if the line_segments are in (1) Nx4, or (2) Nx2x2.
     H, W, _ = image.shape
     # (1) Nx4 format
@@ -207,18 +248,20 @@ def plot_line_segments_from_segments(input_image, line_segments, junc_size=3,
         line_segments = line_segments.astype(np.int32)
 
         # Clip H dimension
-        line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H-1)
-        line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H-1)
+        line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H - 1)
+        line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H - 1)
 
         # Clip W dimension
-        line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W-1)
-        line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W-1)
+        line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W - 1)
+        line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W - 1)
 
         # Convert to Nx2x2 format
         line_segments = np.concatenate(
-            [np.expand_dims(line_segments[:, :2], axis=1),       
-            np.expand_dims(line_segments[:, 2:], axis=1)],
-            axis=1
+            [
+                np.expand_dims(line_segments[:, :2], axis=1),
+                np.expand_dims(line_segments[:, 2:], axis=1),
+            ],
+            axis=1,
         )
 
     # (2) Nx2x2 format
@@ -227,11 +270,13 @@ def plot_line_segments_from_segments(input_image, line_segments, junc_size=3,
         line_segments = line_segments.astype(np.int32)
 
         # Clip H dimension
-        line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H-1)
-        line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W-1)
+        line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H - 1)
+        line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W - 1)
 
     else:
-        raise ValueError("[Error] line_segments should be either Nx4 or Nx2x2 in HW format.")
+        raise ValueError(
+            "[Error] line_segments should be either Nx4 or Nx2x2 in HW format."
+        )
 
     # Draw segment pairs (all segments should be in HW format)
     image = image.copy()
@@ -241,21 +286,41 @@ def plot_line_segments_from_segments(input_image, line_segments, junc_size=3,
         if color != "random":
             color = tuple(color)
         else:
-            color = tuple(np.random.rand(3,))
-        cv2.line(image, tuple(np.flip(seg[0, :])), 
-                        tuple(np.flip(seg[1, :])), 
-                        color=color, thickness=line_width)
+            color = tuple(
+                np.random.rand(
+                    3,
+                )
+            )
+        cv2.line(
+            image,
+            tuple(np.flip(seg[0, :])),
+            tuple(np.flip(seg[1, :])),
+            color=color,
+            thickness=line_width,
+        )
 
         # Also draw the junctions
-        cv2.circle(image, tuple(np.flip(seg[0, :])), radius=junc_size, color=(0, 255., 0), thickness=3)
-        cv2.circle(image, tuple(np.flip(seg[1, :])), radius=junc_size, color=(0, 255., 0), thickness=3)
-    
+        cv2.circle(
+            image,
+            tuple(np.flip(seg[0, :])),
+            radius=junc_size,
+            color=(0, 255.0, 0),
+            thickness=3,
+        )
+        cv2.circle(
+            image,
+            tuple(np.flip(seg[1, :])),
+            radius=junc_size,
+            color=(0, 255.0, 0),
+            thickness=3,
+        )
+
     return image
 
 
 # Additional functions to visualize multiple images at the same time,
 # e.g. for line matching
-def plot_images(imgs, titles=None, cmaps='gray', dpi=100, size=6, pad=.5):
+def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=6, pad=0.5):
     """Plot a set of images horizontally.
     Args:
         imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
@@ -265,7 +330,7 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, size=6, pad=.5):
     n = len(imgs)
     if not isinstance(cmaps, (list, tuple)):
         cmaps = [cmaps] * n
-    figsize = (size*n, size*3/4) if size is not None else None
+    figsize = (size * n, size * 3 / 4) if size is not None else None
     fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
     if n == 1:
         ax = [ax]
@@ -281,7 +346,7 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, size=6, pad=.5):
     fig.tight_layout(pad=pad)
 
 
-def plot_keypoints(kpts, colors='lime', ps=4):
+def plot_keypoints(kpts, colors="lime", ps=4):
     """Plot keypoints for existing images.
     Args:
         kpts: list of ndarrays of size (N, 2).
@@ -295,7 +360,7 @@ def plot_keypoints(kpts, colors='lime', ps=4):
         a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
 
 
-def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
     """Plot matches for a pair of existing images.
     Args:
         kpts0, kpts1: corresponding keypoints of size (N, 2).
@@ -322,11 +387,18 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
         transFigure = fig.transFigure.inverted()
         fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
         fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
-        fig.lines += [matplotlib.lines.Line2D(
-            (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
-            zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
-            alpha=a)
-            for i in range(len(kpts0))]
+        fig.lines += [
+            matplotlib.lines.Line2D(
+                (fkpts0[i, 0], fkpts1[i, 0]),
+                (fkpts0[i, 1], fkpts1[i, 1]),
+                zorder=1,
+                transform=fig.transFigure,
+                c=color[i],
+                linewidth=lw,
+                alpha=a,
+            )
+            for i in range(len(kpts0))
+        ]
 
     # freeze the axes to prevent the transform to change
     ax0.autoscale(enable=False)
@@ -337,8 +409,9 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
         ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2)
 
 
-def plot_lines(lines, line_colors='orange', point_colors='cyan',
-               ps=4, lw=2, indices=(0, 1)):
+def plot_lines(
+    lines, line_colors="orange", point_colors="cyan", ps=4, lw=2, indices=(0, 1)
+):
     """Plot lines and endpoints for existing images.
     Args:
         lines: list of ndarrays of size (N, 2, 2).
@@ -361,16 +434,19 @@ def plot_lines(lines, line_colors='orange', point_colors='cyan',
     # Plot the lines and junctions
     for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
         for i in range(len(l)):
-            line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]),
-                                           (l[i, 0, 1], l[i, 1, 1]),
-                                           zorder=1, c=lc, linewidth=lw)
+            line = matplotlib.lines.Line2D(
+                (l[i, 0, 0], l[i, 1, 0]),
+                (l[i, 0, 1], l[i, 1, 1]),
+                zorder=1,
+                c=lc,
+                linewidth=lw,
+            )
             a.add_line(line)
         pts = l.reshape(-1, 2)
-        a.scatter(pts[:, 0], pts[:, 1],
-                  c=pc, s=ps, linewidths=0, zorder=2)
+        a.scatter(pts[:, 0], pts[:, 1], c=pc, s=ps, linewidths=0, zorder=2)
 
 
-def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.):
+def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.0):
     """Plot matches for a pair of existing images, parametrized by their middle point.
     Args:
         kpts0, kpts1: corresponding middle points of the lines of size (N, 2).
@@ -396,19 +472,25 @@ def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.):
         transFigure = fig.transFigure.inverted()
         fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
         fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
-        fig.lines += [matplotlib.lines.Line2D(
-            (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
-            zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
-            alpha=a)
-            for i in range(len(kpts0))]
+        fig.lines += [
+            matplotlib.lines.Line2D(
+                (fkpts0[i, 0], fkpts1[i, 0]),
+                (fkpts0[i, 1], fkpts1[i, 1]),
+                zorder=1,
+                transform=fig.transFigure,
+                c=color[i],
+                linewidth=lw,
+                alpha=a,
+            )
+            for i in range(len(kpts0))
+        ]
 
     # freeze the axes to prevent the transform to change
     ax0.autoscale(enable=False)
     ax1.autoscale(enable=False)
 
 
-def plot_color_line_matches(lines, correct_matches=None,
-                            lw=2, indices=(0, 1)):
+def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
     """Plot line matches for existing images with multiple colors.
     Args:
         lines: list of ndarrays of size (N, 2, 2).
@@ -417,7 +499,7 @@ def plot_color_line_matches(lines, correct_matches=None,
         indices: indices of the images to draw the matches on.
     """
     n_lines = len(lines[0])
-    colors = sns.color_palette('husl', n_colors=n_lines)
+    colors = sns.color_palette("husl", n_colors=n_lines)
     np.random.shuffle(colors)
     alphas = np.ones(n_lines)
     # If correct_matches is not None, display wrong matches with a low alpha
@@ -436,15 +518,21 @@ def plot_color_line_matches(lines, correct_matches=None,
         transFigure = fig.transFigure.inverted()
         endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
         endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
-        fig.lines += [matplotlib.lines.Line2D(
-            (endpoint0[i, 0], endpoint1[i, 0]),
-            (endpoint0[i, 1], endpoint1[i, 1]),
-            zorder=1, transform=fig.transFigure, c=colors[i],
-            alpha=alphas[i], linewidth=lw) for i in range(n_lines)]
-
-
-def plot_color_lines(lines, correct_matches, wrong_matches,
-                     lw=2, indices=(0, 1)):
+        fig.lines += [
+            matplotlib.lines.Line2D(
+                (endpoint0[i, 0], endpoint1[i, 0]),
+                (endpoint0[i, 1], endpoint1[i, 1]),
+                zorder=1,
+                transform=fig.transFigure,
+                c=colors[i],
+                alpha=alphas[i],
+                linewidth=lw,
+            )
+            for i in range(n_lines)
+        ]
+
+
+def plot_color_lines(lines, correct_matches, wrong_matches, lw=2, indices=(0, 1)):
     """Plot line matches for existing images with multiple colors:
     green for correct matches, red for wrong ones, and blue for the rest.
     Args:
@@ -476,15 +564,21 @@ def plot_color_lines(lines, correct_matches, wrong_matches,
         transFigure = fig.transFigure.inverted()
         endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
         endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
-        fig.lines += [matplotlib.lines.Line2D(
-            (endpoint0[i, 0], endpoint1[i, 0]),
-            (endpoint0[i, 1], endpoint1[i, 1]),
-            zorder=1, transform=fig.transFigure, c=c[i],
-            linewidth=lw) for i in range(len(l))]
+        fig.lines += [
+            matplotlib.lines.Line2D(
+                (endpoint0[i, 0], endpoint1[i, 0]),
+                (endpoint0[i, 1], endpoint1[i, 1]),
+                zorder=1,
+                transform=fig.transFigure,
+                c=c[i],
+                linewidth=lw,
+            )
+            for i in range(len(l))
+        ]
 
 
 def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)):
-    """ Plot line matches for existing images with multiple colors and
+    """Plot line matches for existing images with multiple colors and
         highlight the actually matched subsegments.
     Args:
         lines: list of ndarrays of size (N, 2, 2).
@@ -493,8 +587,9 @@ def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)):
         indices: indices of the images to draw the matches on.
     """
     n_lines = len(lines[0])
-    colors = sns.cubehelix_palette(start=2, rot=-0.2, dark=0.3, light=.7,
-                                   gamma=1.3, hue=1, n_colors=n_lines)
+    colors = sns.cubehelix_palette(
+        start=2, rot=-0.2, dark=0.3, light=0.7, gamma=1.3, hue=1, n_colors=n_lines
+    )
 
     fig = plt.gcf()
     ax = fig.axes
@@ -510,17 +605,31 @@ def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)):
         # Draw full line
         endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
         endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
-        fig.lines += [matplotlib.lines.Line2D(
-            (endpoint0[i, 0], endpoint1[i, 0]),
-            (endpoint0[i, 1], endpoint1[i, 1]),
-            zorder=1, transform=fig.transFigure, c='red',
-            alpha=0.7, linewidth=lw) for i in range(n_lines)]
+        fig.lines += [
+            matplotlib.lines.Line2D(
+                (endpoint0[i, 0], endpoint1[i, 0]),
+                (endpoint0[i, 1], endpoint1[i, 1]),
+                zorder=1,
+                transform=fig.transFigure,
+                c="red",
+                alpha=0.7,
+                linewidth=lw,
+            )
+            for i in range(n_lines)
+        ]
 
         # Draw matched subsegment
         endpoint0 = transFigure.transform(a.transData.transform(ss[:, 0]))
         endpoint1 = transFigure.transform(a.transData.transform(ss[:, 1]))
-        fig.lines += [matplotlib.lines.Line2D(
-            (endpoint0[i, 0], endpoint1[i, 0]),
-            (endpoint0[i, 1], endpoint1[i, 1]),
-            zorder=1, transform=fig.transFigure, c=colors[i],
-            alpha=1, linewidth=lw) for i in range(n_lines)]
\ No newline at end of file
+        fig.lines += [
+            matplotlib.lines.Line2D(
+                (endpoint0[i, 0], endpoint1[i, 0]),
+                (endpoint0[i, 1], endpoint1[i, 1]),
+                zorder=1,
+                transform=fig.transFigure,
+                c=colors[i],
+                alpha=1,
+                linewidth=lw,
+            )
+            for i in range(n_lines)
+        ]
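
The plotting helpers above are meant to be chained on the current matplotlib figure: `plot_images` lays out the axes, then `plot_lines` (or `plot_color_line_matches`) draws on top of them. A minimal usage sketch with random data, assuming the package is importable as `sold2` (adjust the import to your checkout); array shapes follow the docstrings above:

import numpy as np
import matplotlib.pyplot as plt

from sold2.misc.visualize_util import plot_images, plot_lines

img0 = np.random.rand(240, 320)
img1 = np.random.rand(240, 320)
lines0 = np.random.rand(5, 2, 2) * [320, 240]  # (N, 2, 2) endpoints in xy
lines1 = np.random.rand(5, 2, 2) * [320, 240]

plot_images([img0, img1], titles=["image 0", "image 1"])
plot_lines([lines0, lines1], lw=2)
plt.show()
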
diff --git a/third_party/SOLD2/sold2/model/line_detection.py b/third_party/SOLD2/sold2/model/line_detection.py
index 0c186337b0ce2072ddd5246408c538dac2cf325f..8ff379a8de3ff5d54dc807b397f947ea8f361ef9 100644
--- a/third_party/SOLD2/sold2/model/line_detection.py
+++ b/third_party/SOLD2/sold2/model/line_detection.py
@@ -7,14 +7,25 @@ import torch
 
 
 class LineSegmentDetectionModule(object):
-    """ Module extracting line segments from junctions and line heatmaps. """
+    """Module extracting line segments from junctions and line heatmaps."""
+
     def __init__(
-        self, detect_thresh, num_samples=64, sampling_method="local_max",
-        inlier_thresh=0., heatmap_low_thresh=0.15, heatmap_high_thresh=0.2,
-        max_local_patch_radius=3, lambda_radius=2.,
-        use_candidate_suppression=False, nms_dist_tolerance=3., 
-        use_heatmap_refinement=False, heatmap_refine_cfg=None,
-        use_junction_refinement=False, junction_refine_cfg=None):
+        self,
+        detect_thresh,
+        num_samples=64,
+        sampling_method="local_max",
+        inlier_thresh=0.0,
+        heatmap_low_thresh=0.15,
+        heatmap_high_thresh=0.2,
+        max_local_patch_radius=3,
+        lambda_radius=2.0,
+        use_candidate_suppression=False,
+        nms_dist_tolerance=3.0,
+        use_heatmap_refinement=False,
+        heatmap_refine_cfg=None,
+        use_junction_refinement=False,
+        junction_refine_cfg=None,
+    ):
         """
         Parameters:
             detect_thresh: The probability threshold for mean activation (0. ~ 1.)
@@ -41,7 +52,7 @@ class LineSegmentDetectionModule(object):
         self.inlier_thresh = inlier_thresh
         self.local_patch_radius = max_local_patch_radius
         self.lambda_radius = lambda_radius
-        
+
         # Detecting junctions on the boundary parameters
         self.low_thresh = heatmap_low_thresh
         self.high_thresh = heatmap_high_thresh
@@ -65,56 +76,61 @@ class LineSegmentDetectionModule(object):
         self.junction_refine_cfg = junction_refine_cfg
         if self.use_junction_refinement and self.junction_refine_cfg is None:
             raise ValueError("[Error] Missing junction refinement config.")
-        
+
     def convert_inputs(self, inputs, device):
-        """ Convert inputs to desired torch tensor. """
+        """Convert inputs to desired torch tensor."""
         if isinstance(inputs, np.ndarray):
             outputs = torch.tensor(inputs, dtype=torch.float32, device=device)
         elif isinstance(inputs, torch.Tensor):
             outputs = inputs.to(torch.float32).to(device)
         else:
             raise ValueError(
-        "[Error] Inputs must either be torch tensor or numpy ndarray.")
-        
+                "[Error] Inputs must either be torch tensor or numpy ndarray."
+            )
+
         return outputs
-        
+
     def detect(self, junctions, heatmap, device=torch.device("cpu")):
-        """ Main function performing line segment detection. """
+        """Main function performing line segment detection."""
         # Convert inputs to torch tensor
         junctions = self.convert_inputs(junctions, device=device)
         heatmap = self.convert_inputs(heatmap, device=device)
-        
+
         # Perform the heatmap refinement
         if self.use_heatmap_refinement:
             if self.heatmap_refine_cfg["mode"] == "global":
                 heatmap = self.refine_heatmap(
-                    heatmap, 
+                    heatmap,
                     self.heatmap_refine_cfg["ratio"],
-                    self.heatmap_refine_cfg["valid_thresh"]
+                    self.heatmap_refine_cfg["valid_thresh"],
                 )
             elif self.heatmap_refine_cfg["mode"] == "local":
                 heatmap = self.refine_heatmap_local(
-                    heatmap, 
+                    heatmap,
                     self.heatmap_refine_cfg["num_blocks"],
                     self.heatmap_refine_cfg["overlap_ratio"],
                     self.heatmap_refine_cfg["ratio"],
-                    self.heatmap_refine_cfg["valid_thresh"]
+                    self.heatmap_refine_cfg["valid_thresh"],
                 )
-        
+
         # Initialize empty line map
         num_junctions = junctions.shape[0]
-        line_map_pred = torch.zeros([num_junctions, num_junctions],
-                                    device=device, dtype=torch.int32)
-        
+        line_map_pred = torch.zeros(
+            [num_junctions, num_junctions], device=device, dtype=torch.int32
+        )
+
         # Stop if there are not enough junctions
         if num_junctions < 2:
             return line_map_pred, junctions, heatmap
 
         # Generate the candidate map
-        candidate_map = torch.triu(torch.ones(
-            [num_junctions, num_junctions], device=device, dtype=torch.int32),
-                                   diagonal=1)
-        
+        candidate_map = torch.triu(
+            torch.ones(
+                [num_junctions, num_junctions], device=device, dtype=torch.int32
+            ),
+            diagonal=1,
+        )
+
         # Fetch the image boundary
         if len(heatmap.shape) > 2:
             H, W, _ = heatmap.shape
@@ -123,39 +139,47 @@ class LineSegmentDetectionModule(object):
 
         # Optionally perform candidate filtering
         if self.use_candidate_suppression:
-            candidate_map = self.candidate_suppression(junctions,
-                                                       candidate_map)
+            candidate_map = self.candidate_suppression(junctions, candidate_map)
 
         # Fetch the candidates
         candidate_index_map = torch.where(candidate_map)
-        candidate_index_map = torch.cat([candidate_index_map[0][..., None],
-                                         candidate_index_map[1][..., None]],
-                                        dim=-1)
-        
+        candidate_index_map = torch.cat(
+            [candidate_index_map[0][..., None], candidate_index_map[1][..., None]],
+            dim=-1,
+        )
+
         # Get the corresponding start and end junctions
         candidate_junc_start = junctions[candidate_index_map[:, 0], :]
         candidate_junc_end = junctions[candidate_index_map[:, 1], :]
 
         # Get the sampling locations (N x 64)
         sampler = self.torch_sampler.to(device)[None, ...]
-        cand_samples_h = candidate_junc_start[:, 0:1] * sampler + \
-                         candidate_junc_end[:, 0:1] * (1 - sampler)
-        cand_samples_w = candidate_junc_start[:, 1:2] * sampler + \
-                         candidate_junc_end[:, 1:2] * (1 - sampler)
-        
+        cand_samples_h = candidate_junc_start[:, 0:1] * sampler + candidate_junc_end[
+            :, 0:1
+        ] * (1 - sampler)
+        cand_samples_w = candidate_junc_start[:, 1:2] * sampler + candidate_junc_end[
+            :, 1:2
+        ] * (1 - sampler)
+
         # Clip to image boundary
-        cand_h = torch.clamp(cand_samples_h, min=0, max=H-1)
-        cand_w = torch.clamp(cand_samples_w, min=0, max=W-1)
-        
+        cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1)
+        cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1)
+
         # Local maximum search
         if self.sampling_method == "local_max":
             # Compute normalized segment lengths
-            segments_length = torch.sqrt(torch.sum(
-                (candidate_junc_start.to(torch.float32) -
-                 candidate_junc_end.to(torch.float32)) ** 2, dim=-1))
-            normalized_seg_length = (segments_length
-                                     / (((H ** 2) + (W ** 2)) ** 0.5))
-            
+            segments_length = torch.sqrt(
+                torch.sum(
+                    (
+                        candidate_junc_start.to(torch.float32)
+                        - candidate_junc_end.to(torch.float32)
+                    )
+                    ** 2,
+                    dim=-1,
+                )
+            )
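+            # Normalized by the image diagonal so the local search radius used below stays resolution independent.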
+            normalized_seg_length = segments_length / (((H**2) + (W**2)) ** 0.5)
+
             # Perform local max search
             num_cand = cand_h.shape[0]
             group_size = 10000
@@ -163,85 +187,88 @@ class LineSegmentDetectionModule(object):
                 num_iter = math.ceil(num_cand / group_size)
                 sampled_feat_lst = []
                 for iter_idx in range(num_iter):
-                    if not iter_idx == num_iter-1:
-                        cand_h_ = cand_h[iter_idx * group_size:
-                                         (iter_idx+1) * group_size, :]
-                        cand_w_ = cand_w[iter_idx * group_size:
-                                         (iter_idx+1) * group_size, :]
+                    if not iter_idx == num_iter - 1:
+                        cand_h_ = cand_h[
+                            iter_idx * group_size : (iter_idx + 1) * group_size, :
+                        ]
+                        cand_w_ = cand_w[
+                            iter_idx * group_size : (iter_idx + 1) * group_size, :
+                        ]
                         normalized_seg_length_ = normalized_seg_length[
-                            iter_idx * group_size: (iter_idx+1) * group_size]
+                            iter_idx * group_size : (iter_idx + 1) * group_size
+                        ]
                     else:
-                        cand_h_ = cand_h[iter_idx * group_size:, :]
-                        cand_w_ = cand_w[iter_idx * group_size:, :]
+                        cand_h_ = cand_h[iter_idx * group_size :, :]
+                        cand_w_ = cand_w[iter_idx * group_size :, :]
                         normalized_seg_length_ = normalized_seg_length[
-                            iter_idx * group_size:]
+                            iter_idx * group_size :
+                        ]
                     sampled_feat_ = self.detect_local_max(
-                        heatmap, cand_h_, cand_w_, H, W,
-                        normalized_seg_length_, device)
+                        heatmap, cand_h_, cand_w_, H, W, normalized_seg_length_, device
+                    )
                     sampled_feat_lst.append(sampled_feat_)
                 sampled_feat = torch.cat(sampled_feat_lst, dim=0)
             else:
                 sampled_feat = self.detect_local_max(
-                    heatmap, cand_h, cand_w, H, W, 
-                    normalized_seg_length, device)
+                    heatmap, cand_h, cand_w, H, W, normalized_seg_length, device
+                )
         # Bilinear sampling
         elif self.sampling_method == "bilinear":
             # Perform bilinear sampling
-            sampled_feat = self.detect_bilinear(
-                heatmap, cand_h, cand_w, H, W, device)
+            sampled_feat = self.detect_bilinear(heatmap, cand_h, cand_w, H, W, device)
         else:
             raise ValueError("[Error] Unknown sampling method.")
-     
+
         # [Simple threshold detection]
         # detection_results is a mask over all candidates
-        detection_results = (torch.mean(sampled_feat, dim=-1)
-                             > self.detect_thresh)
-        
+        detection_results = torch.mean(sampled_feat, dim=-1) > self.detect_thresh
+
         # [Inlier threshold detection]
-        if self.inlier_thresh > 0.:
-            inlier_ratio = torch.sum(
-                sampled_feat > self.detect_thresh,
-                dim=-1).to(torch.float32) / self.num_samples
+        if self.inlier_thresh > 0.0:
+            inlier_ratio = (
+                torch.sum(sampled_feat > self.detect_thresh, dim=-1).to(torch.float32)
+                / self.num_samples
+            )
             detection_results_inlier = inlier_ratio >= self.inlier_thresh
             detection_results = detection_results * detection_results_inlier
 
         # Convert detection results back to line_map_pred
         detected_junc_indexes = candidate_index_map[detection_results, :]
-        line_map_pred[detected_junc_indexes[:, 0],
-                      detected_junc_indexes[:, 1]] = 1
-        line_map_pred[detected_junc_indexes[:, 1],
-                      detected_junc_indexes[:, 0]] = 1
-        
+        line_map_pred[detected_junc_indexes[:, 0], detected_junc_indexes[:, 1]] = 1
+        line_map_pred[detected_junc_indexes[:, 1], detected_junc_indexes[:, 0]] = 1
+
         # Perform junction refinement
         if self.use_junction_refinement and len(detected_junc_indexes) > 0:
             junctions, line_map_pred = self.refine_junction_perturb(
-                junctions, line_map_pred, heatmap, H, W, device)
+                junctions, line_map_pred, heatmap, H, W, device
+            )
 
         return line_map_pred, junctions, heatmap
-    
+
     def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2):
-        """ Global heatmap refinement method. """
+        """Global heatmap refinement method."""
         # Grab the top 10% values
         heatmap_values = heatmap[heatmap > valid_thresh]
         sorted_values = torch.sort(heatmap_values, descending=True)[0]
         top10_len = math.ceil(sorted_values.shape[0] * ratio)
         max20 = torch.mean(sorted_values[:top10_len])
-        heatmap = torch.clamp(heatmap / max20, min=0., max=1.)
+        heatmap = torch.clamp(heatmap / max20, min=0.0, max=1.0)
         return heatmap
-    
-    def refine_heatmap_local(self, heatmap, num_blocks=5, overlap_ratio=0.5,
-                             ratio=0.2, valid_thresh=2e-3):
-        """ Local heatmap refinement method. """
+
+    def refine_heatmap_local(
+        self, heatmap, num_blocks=5, overlap_ratio=0.5, ratio=0.2, valid_thresh=2e-3
+    ):
+        """Local heatmap refinement method."""
         # Get the shape of the heatmap
         H, W = heatmap.shape
         increase_ratio = 1 - overlap_ratio
         h_block = round(H / (1 + (num_blocks - 1) * increase_ratio))
         w_block = round(W / (1 + (num_blocks - 1) * increase_ratio))
 
-        count_map = torch.zeros(heatmap.shape, dtype=torch.int,
-                                device=heatmap.device)
-        heatmap_output = torch.zeros(heatmap.shape, dtype=torch.float,
-                                     device=heatmap.device)
+        count_map = torch.zeros(heatmap.shape, dtype=torch.int, device=heatmap.device)
+        heatmap_output = torch.zeros(
+            heatmap.shape, dtype=torch.float, device=heatmap.device
+        )
         # Iterate through each block
         for h_idx in range(num_blocks):
             for w_idx in range(num_blocks):
@@ -254,25 +281,29 @@ class LineSegmentDetectionModule(object):
                 subheatmap = heatmap[h_start:h_end, w_start:w_end]
                 if subheatmap.max() > valid_thresh:
                     subheatmap = self.refine_heatmap(
-                        subheatmap, ratio, valid_thresh=valid_thresh)
-                
+                        subheatmap, ratio, valid_thresh=valid_thresh
+                    )
+
                 # Aggregate it to the final heatmap
                 heatmap_output[h_start:h_end, w_start:w_end] += subheatmap
                 count_map[h_start:h_end, w_start:w_end] += 1
-        heatmap_output = torch.clamp(heatmap_output / count_map,
-                                     max=1., min=0.)
+        heatmap_output = torch.clamp(heatmap_output / count_map, max=1.0, min=0.0)
 
         return heatmap_output
 
     def candidate_suppression(self, junctions, candidate_map):
-        """ Suppress overlapping long lines in the candidate segments. """
+        """Suppress overlapping long lines in the candidate segments."""
         # Define the distance tolerance
         dist_tolerance = self.nms_dist_tolerance
 
         # Compute distance between junction pairs
         # (num_junc x 1 x 2) - (1 x num_junc x 2) => num_junc x num_junc map
-        line_dist_map = torch.sum((torch.unsqueeze(junctions, dim=1)
-                                  - junctions[None, ...]) ** 2, dim=-1) ** 0.5
+        line_dist_map = (
+            torch.sum(
+                (torch.unsqueeze(junctions, dim=1) - junctions[None, ...]) ** 2, dim=-1
+            )
+            ** 0.5
+        )
 
         # Fetch all the "detected lines"
         seg_indexes = torch.where(torch.triu(candidate_map, diagonal=1))
@@ -285,20 +316,23 @@ class LineSegmentDetectionModule(object):
         line_dists = line_dist_map[start_point_idxs, end_point_idxs]
 
         # Check whether they are on the line
-        dir_vecs = ((end_points - start_points)
-                    / torch.norm(end_points - start_points,
-                                 dim=-1)[..., None])
+        dir_vecs = (end_points - start_points) / torch.norm(
+            end_points - start_points, dim=-1
+        )[..., None]
         # Get the orthogonal distance
         cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1)
         cand_vecs_norm = torch.norm(cand_vecs, dim=-1)
         # Check whether they are projected directly onto the segment
-        proj = (torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None])
-                / line_dists[..., None, None])
+        proj = (
+            torch.einsum("bij,bjk->bik", cand_vecs, dir_vecs[..., None])
+            / line_dists[..., None, None]
+        )
         # proj is num_segs x num_junction x 1
-        proj_mask = (proj >=0) * (proj <= 1)
+        proj_mask = (proj >= 0) * (proj <= 1)
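+        # A junction can only suppress a candidate if its projection lies between the two endpoints (0 <= proj <= 1).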
         cand_angles = torch.acos(
-            torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None])
-            / cand_vecs_norm[..., None])
+            torch.einsum("bij,bjk->bik", cand_vecs, dir_vecs[..., None])
+            / cand_vecs_norm[..., None]
+        )
         cand_dists = cand_vecs_norm[..., None] * torch.sin(cand_angles)
         junc_dist_mask = cand_dists <= dist_tolerance
         junc_mask = junc_dist_mask * proj_mask
@@ -306,21 +340,21 @@ class LineSegmentDetectionModule(object):
         # Minus starting points
         num_segs = start_point_idxs.shape[0]
         junc_counts = torch.sum(junc_mask, dim=[1, 2])
-        junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs),
-                                         start_point_idxs].to(torch.int)
-        junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs),
-                                         end_point_idxs].to(torch.int)
-        
+        junc_counts -= junc_mask[..., 0][
+            torch.arange(0, num_segs), start_point_idxs
+        ].to(torch.int)
+        junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs), end_point_idxs].to(
+            torch.int
+        )
+
         # Get the invalid candidate mask
         final_mask = junc_counts > 0
-        candidate_map[start_point_idxs[final_mask],
-                      end_point_idxs[final_mask]] = 0
-            
+        candidate_map[start_point_idxs[final_mask], end_point_idxs[final_mask]] = 0
+
         return candidate_map
-    
-    def refine_junction_perturb(self, junctions, line_map_pred,
-                                heatmap, H, W, device):
-        """ Refine the line endpoints in a similar way as in LSD. """
+
+    def refine_junction_perturb(self, junctions, line_map_pred, heatmap, H, W, device):
+        """Refine the line endpoints in a similar way as in LSD."""
         # Get the config
         junction_refine_cfg = self.junction_refine_cfg
 
@@ -330,14 +364,23 @@ class LineSegmentDetectionModule(object):
         side_perturbs = (num_perturbs - 1) // 2
         # Fetch the 2D perturb mat
         perturb_vec = torch.arange(
-            start=-perturb_interval*side_perturbs,
-            end=perturb_interval*(side_perturbs+1),
-            step=perturb_interval, device=device)
+            start=-perturb_interval * side_perturbs,
+            end=perturb_interval * (side_perturbs + 1),
+            step=perturb_interval,
+            device=device,
+        )
         w1_grid, h1_grid, w2_grid, h2_grid = torch.meshgrid(
-            perturb_vec, perturb_vec, perturb_vec, perturb_vec)
-        perturb_tensor = torch.cat([
-            w1_grid[..., None], h1_grid[..., None], 
-            w2_grid[..., None], h2_grid[..., None]], dim=-1)
+            perturb_vec, perturb_vec, perturb_vec, perturb_vec
+        )
+        perturb_tensor = torch.cat(
+            [
+                w1_grid[..., None],
+                h1_grid[..., None],
+                w2_grid[..., None],
+                h2_grid[..., None],
+            ],
+            dim=-1,
+        )
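+        # Flatten to one (2, 2) endpoint-offset pair per perturbation candidate.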
         perturb_tensor_flat = perturb_tensor.view(-1, 2, 2)
 
         # Fetch the junctions and line_map
@@ -351,16 +394,20 @@ class LineSegmentDetectionModule(object):
         start_points = junctions[start_point_idxs, :]
         end_points = junctions[end_point_idxs, :]
 
-        line_segments = torch.cat([start_points.unsqueeze(dim=1),
-                                   end_points.unsqueeze(dim=1)], dim=1)
+        line_segments = torch.cat(
+            [start_points.unsqueeze(dim=1), end_points.unsqueeze(dim=1)], dim=1
+        )
 
-        line_segment_candidates = (line_segments.unsqueeze(dim=1)
-                                   + perturb_tensor_flat[None, ...])
+        line_segment_candidates = (
+            line_segments.unsqueeze(dim=1) + perturb_tensor_flat[None, ...]
+        )
         # Clip the boundaries
         line_segment_candidates[..., 0] = torch.clamp(
-            line_segment_candidates[..., 0], min=0, max=H - 1)
+            line_segment_candidates[..., 0], min=0, max=H - 1
+        )
         line_segment_candidates[..., 1] = torch.clamp(
-            line_segment_candidates[..., 1], min=0, max=W - 1)
+            line_segment_candidates[..., 1], min=0, max=W - 1
+        )
 
         # Iterate through all the segments
         refined_segment_lst = []
@@ -373,36 +420,37 @@ class LineSegmentDetectionModule(object):
 
             # Get the sampling locations (N x 64)
             sampler = self.torch_sampler.to(device)[None, ...]
-            cand_samples_h = (candidate_junc_start[:, 0:1] * sampler +
-                              candidate_junc_end[:, 0:1] * (1 - sampler))
-            cand_samples_w = (candidate_junc_start[:, 1:2] * sampler +
-                              candidate_junc_end[:, 1:2] * (1 - sampler))
-            
+            cand_samples_h = candidate_junc_start[
+                :, 0:1
+            ] * sampler + candidate_junc_end[:, 0:1] * (1 - sampler)
+            cand_samples_w = candidate_junc_start[
+                :, 1:2
+            ] * sampler + candidate_junc_end[:, 1:2] * (1 - sampler)
+
             # Clip to image boundary
             cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1)
             cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1)
 
             # Perform bilinear sampling
-            segment_feat = self.detect_bilinear(
-                heatmap, cand_h, cand_w, H, W, device)
+            segment_feat = self.detect_bilinear(heatmap, cand_h, cand_w, H, W, device)
             segment_results = torch.mean(segment_feat, dim=-1)
             max_idx = torch.argmax(segment_results)
             refined_segment_lst.append(segment[max_idx, ...][None, ...])
-        
+
         # Concatenate back to segments
         refined_segments = torch.cat(refined_segment_lst, dim=0)
 
         # Convert back to junctions and line_map
         junctions_new = torch.cat(
-            [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0)
+            [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0
+        )
         junctions_new = torch.unique(junctions_new, dim=0)
-        line_map_new = self.segments_to_line_map(junctions_new,
-                                                 refined_segments)
+        line_map_new = self.segments_to_line_map(junctions_new, refined_segments)
 
         return junctions_new, line_map_new
-    
+
     def segments_to_line_map(self, junctions, segments):
-        """ Convert the list of segments to line map. """
+        """Convert the list of segments to line map."""
         # Create empty line map
         device = junctions.device
         num_junctions = junctions.shape[0]
@@ -416,10 +464,8 @@ class LineSegmentDetectionModule(object):
             junction2 = seg[1, :]
 
             # Get index
-            idx_junction1 = torch.where(
-                (junctions == junction1).sum(axis=1) == 2)[0]
-            idx_junction2 = torch.where(
-                (junctions == junction2).sum(axis=1) == 2)[0]
+            idx_junction1 = torch.where((junctions == junction1).sum(axis=1) == 2)[0]
+            idx_junction2 = torch.where((junctions == junction2).sum(axis=1) == 2)[0]
 
             # label the corresponding entries
             line_map[idx_junction1, idx_junction2] = 1
@@ -428,7 +474,7 @@ class LineSegmentDetectionModule(object):
         return line_map
 
     def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device):
-        """ Detection by bilinear sampling. """
+        """Detection by bilinear sampling."""
         # Get the floor and ceiling locations
         cand_h_floor = torch.floor(cand_h).to(torch.long)
         cand_h_ceil = torch.ceil(cand_h).to(torch.long)
@@ -437,63 +483,83 @@ class LineSegmentDetectionModule(object):
 
         # Perform the bilinear sampling
         cand_samples_feat = (
-            heatmap[cand_h_floor, cand_w_floor] * (cand_h_ceil - cand_h)
-            * (cand_w_ceil - cand_w) + heatmap[cand_h_floor, cand_w_ceil]
-            * (cand_h_ceil - cand_h) * (cand_w - cand_w_floor) +
-            heatmap[cand_h_ceil, cand_w_floor] * (cand_h - cand_h_floor)
-            * (cand_w_ceil - cand_w) + heatmap[cand_h_ceil, cand_w_ceil]
-            * (cand_h - cand_h_floor) * (cand_w - cand_w_floor))
-        
+            heatmap[cand_h_floor, cand_w_floor]
+            * (cand_h_ceil - cand_h)
+            * (cand_w_ceil - cand_w)
+            + heatmap[cand_h_floor, cand_w_ceil]
+            * (cand_h_ceil - cand_h)
+            * (cand_w - cand_w_floor)
+            + heatmap[cand_h_ceil, cand_w_floor]
+            * (cand_h - cand_h_floor)
+            * (cand_w_ceil - cand_w)
+            + heatmap[cand_h_ceil, cand_w_ceil]
+            * (cand_h - cand_h_floor)
+            * (cand_w - cand_w_floor)
+        )
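+        # The four neighbouring pixels are combined with standard bilinear weights (products of the offsets to floor/ceil).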
+
         return cand_samples_feat
-    
-    def detect_local_max(self, heatmap, cand_h, cand_w, H, W,
-                         normalized_seg_length, device):
-        """ Detection by local maximum search. """
+
+    def detect_local_max(
+        self, heatmap, cand_h, cand_w, H, W, normalized_seg_length, device
+    ):
+        """Detection by local maximum search."""
         # Compute the distance threshold
-        dist_thresh = (0.5 * (2 ** 0.5)
-                       + self.lambda_radius * normalized_seg_length)
+        dist_thresh = 0.5 * (2**0.5) + self.lambda_radius * normalized_seg_length
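+        # Longer segments get a proportionally larger tolerance radius around each sampled point.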
         # Make it N x 64
-        dist_thresh = torch.repeat_interleave(dist_thresh[..., None],
-                                              self.num_samples, dim=-1)
-        
+        dist_thresh = torch.repeat_interleave(
+            dist_thresh[..., None], self.num_samples, dim=-1
+        )
+
         # Compute the candidate points
-        cand_points = torch.cat([cand_h[..., None], cand_w[..., None]],
-                                dim=-1)
-        cand_points_round = torch.round(cand_points) # N x 64 x 2
-        
+        cand_points = torch.cat([cand_h[..., None], cand_w[..., None]], dim=-1)
+        cand_points_round = torch.round(cand_points)  # N x 64 x 2
+
         # Construct local patches 9x9 = 81
-        patch_mask = torch.zeros([int(2 * self.local_patch_radius + 1), 
-                                  int(2 * self.local_patch_radius + 1)],
-                                 device=device)
+        patch_mask = torch.zeros(
+            [
+                int(2 * self.local_patch_radius + 1),
+                int(2 * self.local_patch_radius + 1),
+            ],
+            device=device,
+        )
         patch_center = torch.tensor(
-            [[self.local_patch_radius, self.local_patch_radius]], 
-            device=device, dtype=torch.float32)
+            [[self.local_patch_radius, self.local_patch_radius]],
+            device=device,
+            dtype=torch.float32,
+        )
         H_patch_points, W_patch_points = torch.where(patch_mask >= 0)
-        patch_points = torch.cat([H_patch_points[..., None],
-                                  W_patch_points[..., None]], dim=-1)
+        patch_points = torch.cat(
+            [H_patch_points[..., None], W_patch_points[..., None]], dim=-1
+        )
         # Fetch the circle region
-        patch_center_dist = torch.sqrt(torch.sum(
-            (patch_points - patch_center) ** 2, dim=-1))
-        patch_points = (patch_points[patch_center_dist
-                        <= self.local_patch_radius, :])
+        patch_center_dist = torch.sqrt(
+            torch.sum((patch_points - patch_center) ** 2, dim=-1)
+        )
+        patch_points = patch_points[patch_center_dist <= self.local_patch_radius, :]
         # Shift [0, 0] to the center
         patch_points = patch_points - self.local_patch_radius
-        
+
         # Construct local patch mask
-        patch_points_shifted = (torch.unsqueeze(cand_points_round, dim=2)
-                                + patch_points[None, None, ...])
-        patch_dist = torch.sqrt(torch.sum((torch.unsqueeze(cand_points, dim=2)
-                                          - patch_points_shifted) ** 2,
-                                          dim=-1))
+        patch_points_shifted = (
+            torch.unsqueeze(cand_points_round, dim=2) + patch_points[None, None, ...]
+        )
+        patch_dist = torch.sqrt(
+            torch.sum(
+                (torch.unsqueeze(cand_points, dim=2) - patch_points_shifted) ** 2,
+                dim=-1,
+            )
+        )
         patch_dist_mask = patch_dist < dist_thresh[..., None]
-        
+
         # Get all points => num_points_center x num_patch_points x 2
-        points_H = torch.clamp(patch_points_shifted[:, :, :, 0], min=0,
-                               max=H - 1).to(torch.long)
-        points_W = torch.clamp(patch_points_shifted[:, :, :, 1], min=0,
-                               max=W - 1).to(torch.long)
+        points_H = torch.clamp(patch_points_shifted[:, :, :, 0], min=0, max=H - 1).to(
+            torch.long
+        )
+        points_W = torch.clamp(patch_points_shifted[:, :, :, 1], min=0, max=W - 1).to(
+            torch.long
+        )
         points = torch.cat([points_H[..., None], points_W[..., None]], dim=-1)
-        
+
         # Sample the feature (N x 64 x 81)
         sampled_feat = heatmap[points[:, :, :, 0], points[:, :, :, 1]]
         # Filtering using the valid mask
@@ -502,5 +568,5 @@ class LineSegmentDetectionModule(object):
             sampled_feat_lmax = torch.empty(0, 64)
         else:
             sampled_feat_lmax, _ = torch.max(sampled_feat, dim=-1)
-        
+
         return sampled_feat_lmax
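
A minimal usage sketch of the module reformatted above, for orientation only (it is not part of this patch). It relies solely on the constructor arguments and the detect() signature visible in the hunks; the sold2.model.line_detection import path is assumed, and the junction/heatmap inputs are synthetic placeholders.

    import numpy as np
    import torch

    from sold2.model.line_detection import LineSegmentDetectionModule  # import path assumed

    # Synthetic placeholders: three junctions in (h, w) pixel coordinates and a random line heatmap.
    junctions = np.array([[10.0, 12.0], [10.0, 90.0], [60.0, 40.0]])
    heatmap = np.random.rand(128, 128).astype(np.float32)

    detector = LineSegmentDetectionModule(
        detect_thresh=0.5,           # mean heatmap activation required along a candidate segment
        num_samples=64,              # points sampled between every junction pair
        sampling_method="bilinear",  # "local_max" would enable the local maximum search instead
    )
    line_map, junctions_out, heatmap_out = detector.detect(
        junctions, heatmap, device=torch.device("cpu")
    )
    # line_map is a symmetric num_junctions x num_junctions matrix, 1 where a segment joins two junctions.
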
diff --git a/third_party/SOLD2/sold2/model/line_detector.py b/third_party/SOLD2/sold2/model/line_detector.py
index 2f3d059e130178c482e8e569171ef9e0370424c7..33429f8bc48d21d223efaf83ab6a8f1375b359ec 100644
--- a/third_party/SOLD2/sold2/model/line_detector.py
+++ b/third_party/SOLD2/sold2/model/line_detector.py
@@ -14,7 +14,7 @@ from ..misc.train_utils import adapt_checkpoint
 
 
 def line_map_to_segments(junctions, line_map):
-    """ Convert a line map to a Nx2x2 list of segments. """ 
+    """Convert a line map to a Nx2x2 list of segments."""
     line_map_tmp = line_map.copy()
 
     output_segments = np.zeros([0, 2, 2])
@@ -27,22 +27,23 @@ def line_map_to_segments(junctions, line_map):
             for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
                 p1 = junctions[idx, :]  # HW format
                 p2 = junctions[idx2, :]
-                single_seg = np.concatenate([p1[None, ...], p2[None, ...]],
-                                            axis=0)
+                single_seg = np.concatenate([p1[None, ...], p2[None, ...]], axis=0)
                 output_segments = np.concatenate(
-                    (output_segments, single_seg[None, ...]), axis=0)
-                
+                    (output_segments, single_seg[None, ...]), axis=0
+                )
+
                 # Update line_map
                 line_map_tmp[idx, idx2] = 0
                 line_map_tmp[idx2, idx] = 0
-    
+
     return output_segments
 
 
 class LineDetector(object):
-    def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg,
-                 junc_detect_thresh=None):
-        """ SOLD² line detector taking raw images as input.
+    def __init__(
+        self, model_cfg, ckpt_path, device, line_detector_cfg, junc_detect_thresh=None
+    ):
+        """SOLD² line detector taking raw images as input.
         Parameters:
             model_cfg: config for CNN model
             ckpt_path: path to the weights
@@ -51,7 +52,7 @@ class LineDetector(object):
         # Get loss weights if dynamic weighting
         _, loss_weights = get_loss_and_weights(model_cfg, device)
         self.device = device
-        
+
         # Initialize the cnn backbone
         self.model = get_model(model_cfg, loss_weights)
         checkpoint = torch.load(ckpt_path, map_location=self.device)
@@ -65,20 +66,21 @@ class LineDetector(object):
         if junc_detect_thresh is not None:
             self.junc_detect_thresh = junc_detect_thresh
         else:
-            self.junc_detect_thresh = model_cfg.get("detection_thresh", 1/65)
+            self.junc_detect_thresh = model_cfg.get("detection_thresh", 1 / 65)
         self.max_num_junctions = model_cfg.get("max_num_junctions", 300)
 
         # Initialize the line detector
         self.line_detector_cfg = line_detector_cfg
         self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)
-    
-    def __call__(self, input_image, valid_mask=None,
-                 return_heatmap=False, profile=False):
+
+    def __call__(
+        self, input_image, valid_mask=None, return_heatmap=False, profile=False
+    ):
         # Now we restrict input_image to 4D torch tensor
-        if ((not len(input_image.shape) == 4)
-            or (not isinstance(input_image, torch.Tensor))):
-            raise ValueError(
-        "[Error] the input image should be a 4D torch tensor.")
+        if (not len(input_image.shape) == 4) or (
+            not isinstance(input_image, torch.Tensor)
+        ):
+            raise ValueError("[Error] the input image should be a 4D torch tensor.")
 
         # Move the input to corresponding device
         input_image = input_image.to(self.device)
@@ -89,15 +91,18 @@ class LineDetector(object):
             net_outputs = self.model(input_image)
 
         junc_np = convert_junc_predictions(
-            net_outputs["junctions"], self.grid_size,
-            self.junc_detect_thresh, self.max_num_junctions)
+            net_outputs["junctions"],
+            self.grid_size,
+            self.junc_detect_thresh,
+            self.max_num_junctions,
+        )
         if valid_mask is None:
             junctions = np.where(junc_np["junc_pred_nms"].squeeze())
         else:
-            junctions = np.where(junc_np["junc_pred_nms"].squeeze()
-                                 * valid_mask)
+            junctions = np.where(junc_np["junc_pred_nms"].squeeze() * valid_mask)
         junctions = np.concatenate(
-            [junctions[0][..., None], junctions[1][..., None]], axis=-1)
+            [junctions[0][..., None], junctions[1][..., None]], axis=-1
+        )
 
         if net_outputs["heatmap"].shape[1] == 2:
             # Convert to single channel directly from here
@@ -108,7 +113,8 @@ class LineDetector(object):
 
         # Run the line detector.
         line_map, junctions, heatmap = self.line_detector.detect(
-            junctions, heatmap, device=self.device)
+            junctions, heatmap, device=self.device
+        )
         heatmap = heatmap.cpu().numpy()
         if isinstance(line_map, torch.Tensor):
             line_map = line_map.cpu().numpy()
@@ -123,5 +129,5 @@ class LineDetector(object):
             outputs["heatmap"] = heatmap
         if profile:
             outputs["time"] = end_time - start_time
-        
+
         return outputs
diff --git a/third_party/SOLD2/sold2/model/line_matcher.py b/third_party/SOLD2/sold2/model/line_matcher.py
index bc5a003573c91313e2295c75871edcb1c113662a..458a5e3141c0ad27c0ba665dbd72d5ce0c1c9a86 100644
--- a/third_party/SOLD2/sold2/model/line_matcher.py
+++ b/third_party/SOLD2/sold2/model/line_matcher.py
@@ -19,14 +19,23 @@ from .line_detector import line_map_to_segments
 
 
 class LineMatcher(object):
-    """ Full line matcher including line detection and matching
-        with the Needleman-Wunsch algorithm. """
-    def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg,
-                 line_matcher_cfg, multiscale=False, scales=[1., 2.]):
+    """Full line matcher including line detection and matching
+    with the Needleman-Wunsch algorithm."""
+
+    def __init__(
+        self,
+        model_cfg,
+        ckpt_path,
+        device,
+        line_detector_cfg,
+        line_matcher_cfg,
+        multiscale=False,
+        scales=[1.0, 2.0],
+    ):
         # Get loss weights if dynamic weighting
         _, loss_weights = get_loss_and_weights(model_cfg, device)
         self.device = device
-        
+
         # Initialize the cnn backbone
         self.model = get_model(model_cfg, loss_weights)
         checkpoint = torch.load(ckpt_path, map_location=self.device)
@@ -46,23 +55,22 @@ class LineMatcher(object):
 
         # Initialize the line matcher
         self.line_matcher = WunschLineMatcher(**line_matcher_cfg)
-        
+
         # Print some debug messages
         for key, val in line_detector_cfg.items():
             print(f"[Debug] {key}: {val}")
         # print("[Debug] detect_thresh: %f" % (line_detector_cfg["detect_thresh"]))
         # print("[Debug] num_samples: %d" % (line_detector_cfg["num_samples"]))
-        
-
 
     # Perform line detection and descriptor inference on a single image
-    def line_detection(self, input_image, valid_mask=None,
-                       desc_only=False, profile=False):
+    def line_detection(
+        self, input_image, valid_mask=None, desc_only=False, profile=False
+    ):
         # Restrict input_image to 4D torch tensor
-        if ((not len(input_image.shape) == 4)
-            or (not isinstance(input_image, torch.Tensor))):
-            raise ValueError(
-                "[Error] the input image should be a 4D torch tensor")
+        if (not len(input_image.shape) == 4) or (
+            not isinstance(input_image, torch.Tensor)
+        ):
+            raise ValueError("[Error] the input image should be a 4D torch tensor")
 
         # Move the input to corresponding device
         input_image = input_image.to(self.device)
@@ -76,29 +84,40 @@ class LineMatcher(object):
 
         if not desc_only:
             junc_np = convert_junc_predictions(
-                net_outputs["junctions"], self.grid_size,
-                self.junc_detect_thresh, self.max_num_junctions)
+                net_outputs["junctions"],
+                self.grid_size,
+                self.junc_detect_thresh,
+                self.max_num_junctions,
+            )
             if valid_mask is None:
                 junctions = np.where(junc_np["junc_pred_nms"].squeeze())
             else:
-                junctions = np.where(
-                    junc_np["junc_pred_nms"].squeeze() * valid_mask)
-            junctions = np.concatenate([junctions[0][..., None],
-                                        junctions[1][..., None]], axis=-1)
+                junctions = np.where(junc_np["junc_pred_nms"].squeeze() * valid_mask)
+            junctions = np.concatenate(
+                [junctions[0][..., None], junctions[1][..., None]], axis=-1
+            )
 
             if net_outputs["heatmap"].shape[1] == 2:
                 # Convert to single channel directly from here
-                heatmap = softmax(
-                    net_outputs["heatmap"],
-                    dim=1)[:, 1:, :, :].cpu().numpy().transpose(0, 2, 3, 1)
+                heatmap = (
+                    softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :]
+                    .cpu()
+                    .numpy()
+                    .transpose(0, 2, 3, 1)
+                )
             else:
-                heatmap = torch.sigmoid(
-                    net_outputs["heatmap"]).cpu().numpy().transpose(0, 2, 3, 1)
+                heatmap = (
+                    torch.sigmoid(net_outputs["heatmap"])
+                    .cpu()
+                    .numpy()
+                    .transpose(0, 2, 3, 1)
+                )
             heatmap = heatmap[0, :, :, 0]
 
             # Run the line detector.
             line_map, junctions, heatmap = self.line_detector.detect(
-                junctions, heatmap, device=self.device)
+                junctions, heatmap, device=self.device
+            )
             if isinstance(line_map, torch.Tensor):
                 line_map = line_map.cpu().numpy()
             if isinstance(junctions, torch.Tensor):
@@ -115,7 +134,9 @@ class LineMatcher(object):
                     line_segments_inlier = []
                     for inlier_idx in range(num_inlier_thresh):
                         line_map_tmp = line_map[detect_idx, inlier_idx, :, :]
-                        line_segments_tmp = line_map_to_segments(junctions, line_map_tmp)
+                        line_segments_tmp = line_map_to_segments(
+                            junctions, line_map_tmp
+                        )
                         line_segments_inlier.append(line_segments_tmp)
                     line_segments.append(line_segments_inlier)
             else:
@@ -127,18 +148,24 @@ class LineMatcher(object):
 
         if profile:
             outputs["time"] = end_time - start_time
-        
+
         return outputs
 
     # Perform line detection and descriptor inference at multiple scales
-    def multiscale_line_detection(self, input_image, valid_mask=None,
-                                  desc_only=False, profile=False,
-                                  scales=[1., 2.], aggregation='mean'):
+    def multiscale_line_detection(
+        self,
+        input_image,
+        valid_mask=None,
+        desc_only=False,
+        profile=False,
+        scales=[1.0, 2.0],
+        aggregation="mean",
+    ):
         # Restrict input_image to 4D torch tensor
-        if ((not len(input_image.shape) == 4)
-            or (not isinstance(input_image, torch.Tensor))):
-            raise ValueError(
-                "[Error] the input image should be a 4D torch tensor")
+        if (not len(input_image.shape) == 4) or (
+            not isinstance(input_image, torch.Tensor)
+        ):
+            raise ValueError("[Error] the input image should be a 4D torch tensor")
 
         # Move the input to corresponding device
         input_image = input_image.to(self.device)
@@ -150,34 +177,39 @@ class LineMatcher(object):
         junctions, heatmaps, descriptors = [], [], []
         for s in scales:
             # Resize the image
-            resized_img = F.interpolate(input_image, scale_factor=s,
-                                        mode='bilinear')
+            resized_img = F.interpolate(input_image, scale_factor=s, mode="bilinear")
 
             # Forward of the CNN backbone
             with torch.no_grad():
                 net_outputs = self.model(resized_img)
 
-            descriptors.append(F.interpolate(
-                net_outputs["descriptors"], size=desc_size, mode="bilinear"))
+            descriptors.append(
+                F.interpolate(
+                    net_outputs["descriptors"], size=desc_size, mode="bilinear"
+                )
+            )
 
             if not desc_only:
                 junc_prob = convert_junc_predictions(
-                    net_outputs["junctions"], self.grid_size)["junc_pred"]
-                junctions.append(cv2.resize(junc_prob.squeeze(),
-                                 (img_size[1], img_size[0]),
-                                 interpolation=cv2.INTER_LINEAR))
+                    net_outputs["junctions"], self.grid_size
+                )["junc_pred"]
+                junctions.append(
+                    cv2.resize(
+                        junc_prob.squeeze(),
+                        (img_size[1], img_size[0]),
+                        interpolation=cv2.INTER_LINEAR,
+                    )
+                )
 
                 if net_outputs["heatmap"].shape[1] == 2:
                     # Convert to single channel directly from here
-                    heatmap = softmax(net_outputs["heatmap"],
-                                      dim=1)[:, 1:, :, :]
+                    heatmap = softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :]
                 else:
                     heatmap = torch.sigmoid(net_outputs["heatmap"])
-                heatmaps.append(F.interpolate(heatmap, size=img_size,
-                                              mode="bilinear"))
+                heatmaps.append(F.interpolate(heatmap, size=img_size, mode="bilinear"))
 
         # Aggregate the results
-        if aggregation == 'mean':
+        if aggregation == "mean":
             # Aggregation through the mean activation
             descriptors = torch.stack(descriptors, dim=0).mean(0)
         else:
@@ -186,7 +218,7 @@ class LineMatcher(object):
         outputs = {"descriptor": descriptors}
 
         if not desc_only:
-            if aggregation == 'mean':
+            if aggregation == "mean":
                 junctions = np.stack(junctions, axis=0).mean(0)[None]
                 heatmap = torch.stack(heatmaps, dim=0).mean(0)[0, 0, :, :]
                 heatmap = heatmap.cpu().numpy()
@@ -197,18 +229,23 @@ class LineMatcher(object):
 
             # Extract junctions
             junc_pred_nms = super_nms(
-                junctions[..., None], self.grid_size,
-                self.junc_detect_thresh, self.max_num_junctions)
+                junctions[..., None],
+                self.grid_size,
+                self.junc_detect_thresh,
+                self.max_num_junctions,
+            )
             if valid_mask is None:
                 junctions = np.where(junc_pred_nms.squeeze())
             else:
                 junctions = np.where(junc_pred_nms.squeeze() * valid_mask)
-            junctions = np.concatenate([junctions[0][..., None],
-                                        junctions[1][..., None]], axis=-1)
+            junctions = np.concatenate(
+                [junctions[0][..., None], junctions[1][..., None]], axis=-1
+            )
 
             # Run the line detector.
             line_map, junctions, heatmap = self.line_detector.detect(
-                junctions, heatmap, device=self.device)
+                junctions, heatmap, device=self.device
+            )
             if isinstance(line_map, torch.Tensor):
                 line_map = line_map.cpu().numpy()
             if isinstance(junctions, torch.Tensor):
@@ -226,7 +263,8 @@ class LineMatcher(object):
                     for inlier_idx in range(num_inlier_thresh):
                         line_map_tmp = line_map[detect_idx, inlier_idx, :, :]
                         line_segments_tmp = line_map_to_segments(
-                            junctions, line_map_tmp)
+                            junctions, line_map_tmp
+                        )
                         line_segments_inlier.append(line_segments_tmp)
                     line_segments.append(line_segments_inlier)
             else:
@@ -238,25 +276,25 @@ class LineMatcher(object):
 
         if profile:
             outputs["time"] = end_time - start_time
-        
+
         return outputs
-    
+
     def __call__(self, images, valid_masks=[None, None], profile=False):
         # Line detection and descriptor inference on both images
         if self.multiscale:
             forward_outputs = [
                 self.multiscale_line_detection(
-                    images[0], valid_masks[0], profile=profile,
-                    scales=self.scales),
+                    images[0], valid_masks[0], profile=profile, scales=self.scales
+                ),
                 self.multiscale_line_detection(
-                    images[1], valid_masks[1], profile=profile,
-                    scales=self.scales)]
+                    images[1], valid_masks[1], profile=profile, scales=self.scales
+                ),
+            ]
         else:
             forward_outputs = [
-                self.line_detection(images[0], valid_masks[0],
-                                    profile=profile),
-                self.line_detection(images[1], valid_masks[1],
-                                    profile=profile)]
+                self.line_detection(images[0], valid_masks[0], profile=profile),
+                self.line_detection(images[1], valid_masks[1], profile=profile),
+            ]
         line_seg1 = forward_outputs[0]["line_segments"]
         line_seg2 = forward_outputs[1]["line_segments"]
         desc1 = forward_outputs[0]["descriptor"]
@@ -264,16 +302,15 @@ class LineMatcher(object):
 
         # Match the lines in both images
         start_time = time.time()
-        matches = self.line_matcher.forward(line_seg1, line_seg2,
-                                            desc1, desc2)
+        matches = self.line_matcher.forward(line_seg1, line_seg2, desc1, desc2)
         end_time = time.time()
 
-        outputs = {"line_segments": [line_seg1, line_seg2],
-                   "matches": matches}
+        outputs = {"line_segments": [line_seg1, line_seg2], "matches": matches}
 
         if profile:
-            outputs["line_detection_time"] = (forward_outputs[0]["time"]
-                                              + forward_outputs[1]["time"])
+            outputs["line_detection_time"] = (
+                forward_outputs[0]["time"] + forward_outputs[1]["time"]
+            )
             outputs["line_matching_time"] = end_time - start_time
-        
+
         return outputs
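
A hypothetical end-to-end sketch of the LineMatcher wrapper reformatted above, again not part of the patch: the config dictionaries and checkpoint path are placeholders to be filled from the actual SOLD2 configuration, the sold2.model.line_matcher import path is assumed, and only the constructor and __call__ signatures visible in the hunks are used.

    import torch

    from sold2.model.line_matcher import LineMatcher  # import path assumed

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_cfg = {}          # placeholder: real SOLD2 backbone config dict
    line_detector_cfg = {}  # placeholder: kwargs for LineSegmentDetectionModule
    line_matcher_cfg = {}   # placeholder: kwargs for WunschLineMatcher
    ckpt_path = "path/to/checkpoint.tar"  # placeholder: pretrained SOLD2 weights

    matcher = LineMatcher(
        model_cfg, ckpt_path, device, line_detector_cfg, line_matcher_cfg, multiscale=False
    )

    # Both inputs must be 4D tensors (B x C x H x W); grayscale images in [0, 1] are assumed here.
    img0 = torch.rand(1, 1, 480, 640, device=device)
    img1 = torch.rand(1, 1, 480, 640, device=device)
    outputs = matcher([img0, img1], profile=True)

    line_seg1, line_seg2 = outputs["line_segments"]  # per-image arrays of line segments
    matches = outputs["matches"]                     # for each line in image 0, index of its match in image 1 (-1 if none)
    print(outputs["line_detection_time"], outputs["line_matching_time"])
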
diff --git a/third_party/SOLD2/sold2/model/line_matching.py b/third_party/SOLD2/sold2/model/line_matching.py
index 89b71879e3104f9a8b52c1cf5e534cd124fe83b2..bfceb5a161732c3f7f4cf97e988d5e369a4c25fa 100644
--- a/third_party/SOLD2/sold2/model/line_matching.py
+++ b/third_party/SOLD2/sold2/model/line_matching.py
@@ -10,11 +10,19 @@ from ..misc.geometry_utils import keypoints_to_grid
 
 
 class WunschLineMatcher(object):
-    """ Class matching two sets of line segments
-        with the Needleman-Wunsch algorithm. """
-    def __init__(self, cross_check=True, num_samples=10, min_dist_pts=8,
-                 top_k_candidates=10, grid_size=8, sampling="regular",
-                 line_score=False):
+    """Class matching two sets of line segments
+    with the Needleman-Wunsch algorithm."""
+
+    def __init__(
+        self,
+        cross_check=True,
+        num_samples=10,
+        min_dist_pts=8,
+        top_k_candidates=10,
+        grid_size=8,
+        sampling="regular",
+        line_score=False,
+    ):
         self.cross_check = cross_check
         self.num_samples = num_samples
         self.min_dist_pts = min_dist_pts
@@ -27,13 +35,11 @@ class WunschLineMatcher(object):
 
     def forward(self, line_seg1, line_seg2, desc1, desc2):
         """
-            Find the best matches between two sets of line segments
-            and their corresponding descriptors.
+        Find the best matches between two sets of line segments
+        and their corresponding descriptors.
         """
-        img_size1 = (desc1.shape[2] * self.grid_size,
-                     desc1.shape[3] * self.grid_size)
-        img_size2 = (desc2.shape[2] * self.grid_size,
-                     desc2.shape[3] * self.grid_size)
+        img_size1 = (desc1.shape[2] * self.grid_size, desc1.shape[3] * self.grid_size)
+        img_size2 = (desc2.shape[2] * self.grid_size, desc2.shape[3] * self.grid_size)
         device = desc1.device
 
         # Default case when an image has no lines
@@ -48,13 +54,17 @@ class WunschLineMatcher(object):
             line_points2, valid_points2 = self.sample_line_points(line_seg2)
         else:
             line_points1, valid_points1 = self.sample_salient_points(
-                line_seg1, desc1, img_size1, self.sampling_mode)
+                line_seg1, desc1, img_size1, self.sampling_mode
+            )
             line_points2, valid_points2 = self.sample_salient_points(
-                line_seg2, desc2, img_size2, self.sampling_mode)
-        line_points1 = torch.tensor(line_points1.reshape(-1, 2),
-                                    dtype=torch.float, device=device)
-        line_points2 = torch.tensor(line_points2.reshape(-1, 2),
-                                    dtype=torch.float, device=device)
+                line_seg2, desc2, img_size2, self.sampling_mode
+            )
+        line_points1 = torch.tensor(
+            line_points1.reshape(-1, 2), dtype=torch.float, device=device
+        )
+        line_points2 = torch.tensor(
+            line_points2.reshape(-1, 2), dtype=torch.float, device=device
+        )
 
         # Extract the descriptors for each point
         grid1 = keypoints_to_grid(line_points1, img_size1)
@@ -67,8 +77,9 @@ class WunschLineMatcher(object):
         scores = desc1.t() @ desc2
         scores[~valid_points1.flatten()] = -1
         scores[:, ~valid_points2.flatten()] = -1
-        scores = scores.reshape(len(line_seg1), self.num_samples,
-                                len(line_seg2), self.num_samples)
+        scores = scores.reshape(
+            len(line_seg1), self.num_samples, len(line_seg2), self.num_samples
+        )
         scores = scores.permute(0, 2, 1, 3)
         # scores.shape = (n_lines1, n_lines2, num_samples, num_samples)
 
@@ -77,16 +88,15 @@ class WunschLineMatcher(object):
 
         # [Optionally] filter matches with mutual nearest neighbor filtering
         if self.cross_check:
-            matches2 = self.filter_and_match_lines(
-                scores.permute(1, 0, 3, 2))
+            matches2 = self.filter_and_match_lines(scores.permute(1, 0, 3, 2))
             mutual = matches2[matches] == np.arange(len(line_seg1))
             matches[~mutual] = -1
 
         return matches
 
     def d2_net_saliency_score(self, desc):
-        """ Compute the D2-Net saliency score
-            on a 3D or 4D descriptor. """
+        """Compute the D2-Net saliency score
+        on a 3D or 4D descriptor."""
         is_3d = len(desc.shape) == 3
         b_size = len(desc)
         feat = F.relu(desc)
@@ -94,11 +104,9 @@ class WunschLineMatcher(object):
         # Compute the soft local max
         exp = torch.exp(feat)
         if is_3d:
-            sum_exp = 3 * F.avg_pool1d(exp, kernel_size=3, stride=1,
-                                       padding=1)
+            sum_exp = 3 * F.avg_pool1d(exp, kernel_size=3, stride=1, padding=1)
         else:
-            sum_exp = 9 * F.avg_pool2d(exp, kernel_size=3, stride=1,
-                                       padding=1)
+            sum_exp = 9 * F.avg_pool2d(exp, kernel_size=3, stride=1, padding=1)
         soft_local_max = exp / sum_exp
 
         # Compute the depth-wise maximum
@@ -116,7 +124,7 @@ class WunschLineMatcher(object):
         return score
 
     def asl_feat_saliency_score(self, desc):
-        """ Compute the ASLFeat saliency score on a 3D or 4D descriptor. """
+        """Compute the ASLFeat saliency score on a 3D or 4D descriptor."""
         is_3d = len(desc.shape) == 3
         b_size = len(desc)
 
@@ -141,8 +149,7 @@ class WunschLineMatcher(object):
         score = score / normalization
         return score
 
-    def sample_salient_points(self, line_seg, desc, img_size,
-                              saliency_type='d2_net'):
+    def sample_salient_points(self, line_seg, desc, img_size, saliency_type="d2_net"):
         """
         Sample the most salient points along each line segments, with a
         minimal distance between each point. Pad the remaining points.
@@ -167,8 +174,9 @@ class WunschLineMatcher(object):
         line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1)
 
         # The number of samples depends on the length of the line
-        num_samples_lst = np.clip(line_lengths // self.min_dist_pts,
-                                  2, self.num_samples)
+        num_samples_lst = np.clip(
+            line_lengths // self.min_dist_pts, 2, self.num_samples
+        )
         line_points = np.empty((num_lines, self.num_samples, 2), dtype=float)
         valid_points = np.empty((num_lines, self.num_samples), dtype=bool)
 
@@ -182,17 +190,19 @@ class WunschLineMatcher(object):
             cur_num_lines = len(cur_line_seg)
             if cur_num_lines == 0:
                 continue
-            line_points_x = np.linspace(cur_line_seg[:, 0, 0],
-                                        cur_line_seg[:, 1, 0],
-                                        sample_rate, axis=-1)
-            line_points_y = np.linspace(cur_line_seg[:, 0, 1],
-                                        cur_line_seg[:, 1, 1],
-                                        sample_rate, axis=-1)
-            cur_line_points = np.stack([line_points_x, line_points_y],
-                                       axis=-1).reshape(-1, 2)
+            line_points_x = np.linspace(
+                cur_line_seg[:, 0, 0], cur_line_seg[:, 1, 0], sample_rate, axis=-1
+            )
+            line_points_y = np.linspace(
+                cur_line_seg[:, 0, 1], cur_line_seg[:, 1, 1], sample_rate, axis=-1
+            )
+            cur_line_points = np.stack([line_points_x, line_points_y], axis=-1).reshape(
+                -1, 2
+            )
             # cur_line_points is of shape (n_cur_lines * sample_rate, 2)
-            cur_line_points = torch.tensor(cur_line_points, dtype=torch.float,
-                                           device=device)
+            cur_line_points = torch.tensor(
+                cur_line_points, dtype=torch.float, device=device
+            )
             grid_points = keypoints_to_grid(cur_line_points, img_size)
 
             if self.line_score:
@@ -206,25 +216,26 @@ class WunschLineMatcher(object):
                 else:
                     scores = self.asl_feat_saliency_score(line_desc)
             else:
-                scores = F.grid_sample(score.unsqueeze(1),
-                                       grid_points).squeeze()
+                scores = F.grid_sample(score.unsqueeze(1), grid_points).squeeze()
 
             # Take the most salient point in n distinct regions
             scores = scores.reshape(-1, n, n_samples_per_region)
             best = torch.max(scores, dim=2, keepdim=True)[1].cpu().numpy()
-            cur_line_points = cur_line_points.reshape(-1, n,
-                                                      n_samples_per_region, 2)
+            cur_line_points = cur_line_points.reshape(-1, n, n_samples_per_region, 2)
             cur_line_points = np.take_along_axis(
-                cur_line_points, best[..., None], axis=2)[:, :, 0]
+                cur_line_points, best[..., None], axis=2
+            )[:, :, 0]
 
             # Pad
-            cur_valid_points = np.ones((cur_num_lines, self.num_samples),
-                                       dtype=bool)
+            cur_valid_points = np.ones((cur_num_lines, self.num_samples), dtype=bool)
             cur_valid_points[:, n:] = False
-            cur_line_points = np.concatenate([
-                cur_line_points,
-                np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)],
-                axis=1)
+            cur_line_points = np.concatenate(
+                [
+                    cur_line_points,
+                    np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float),
+                ],
+                axis=1,
+            )
 
             line_points[cur_mask] = cur_line_points
             valid_points[cur_mask] = cur_valid_points
@@ -246,31 +257,34 @@ class WunschLineMatcher(object):
 
         # Sample the points separated by at least min_dist_pts along each line
         # The number of samples depends on the length of the line
-        num_samples_lst = np.clip(line_lengths // self.min_dist_pts,
-                                  2, self.num_samples)
+        num_samples_lst = np.clip(
+            line_lengths // self.min_dist_pts, 2, self.num_samples
+        )
         line_points = np.empty((num_lines, self.num_samples, 2), dtype=float)
         valid_points = np.empty((num_lines, self.num_samples), dtype=bool)
         for n in np.arange(2, self.num_samples + 1):
             # Consider all lines where we can fit up to n points
             cur_mask = num_samples_lst == n
             cur_line_seg = line_seg[cur_mask]
-            line_points_x = np.linspace(cur_line_seg[:, 0, 0],
-                                        cur_line_seg[:, 1, 0],
-                                        n, axis=-1)
-            line_points_y = np.linspace(cur_line_seg[:, 0, 1],
-                                        cur_line_seg[:, 1, 1],
-                                        n, axis=-1)
+            line_points_x = np.linspace(
+                cur_line_seg[:, 0, 0], cur_line_seg[:, 1, 0], n, axis=-1
+            )
+            line_points_y = np.linspace(
+                cur_line_seg[:, 0, 1], cur_line_seg[:, 1, 1], n, axis=-1
+            )
             cur_line_points = np.stack([line_points_x, line_points_y], axis=-1)
 
             # Pad
             cur_num_lines = len(cur_line_seg)
-            cur_valid_points = np.ones((cur_num_lines, self.num_samples),
-                                       dtype=bool)
+            cur_valid_points = np.ones((cur_num_lines, self.num_samples), dtype=bool)
             cur_valid_points[:, n:] = False
-            cur_line_points = np.concatenate([
-                cur_line_points,
-                np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)],
-                axis=1)
+            cur_line_points = np.concatenate(
+                [
+                    cur_line_points,
+                    np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float),
+                ],
+                axis=1,
+            )
 
             line_points[cur_mask] = cur_line_points
             valid_points[cur_mask] = cur_valid_points
@@ -290,23 +304,18 @@ class WunschLineMatcher(object):
         # Pre-filter the pairs and keep the top k best candidate lines
         line_scores1 = scores.max(3)[0]
         valid_scores1 = line_scores1 != -1
-        line_scores1 = ((line_scores1 * valid_scores1).sum(2)
-                        / valid_scores1.sum(2))
+        line_scores1 = (line_scores1 * valid_scores1).sum(2) / valid_scores1.sum(2)
         line_scores2 = scores.max(2)[0]
         valid_scores2 = line_scores2 != -1
-        line_scores2 = ((line_scores2 * valid_scores2).sum(2)
-                        / valid_scores2.sum(2))
+        line_scores2 = (line_scores2 * valid_scores2).sum(2) / valid_scores2.sum(2)
         line_scores = (line_scores1 + line_scores2) / 2
-        topk_lines = torch.argsort(line_scores,
-                                dim=1)[:, -self.top_k_candidates:]
+        topk_lines = torch.argsort(line_scores, dim=1)[:, -self.top_k_candidates :]
         scores, topk_lines = scores.cpu().numpy(), topk_lines.cpu().numpy()
         # topk_lines.shape = (n_lines1, top_k_candidates)
-        top_scores = np.take_along_axis(scores, topk_lines[:, :, None, None],
-                                        axis=1)
+        top_scores = np.take_along_axis(scores, topk_lines[:, :, None, None], axis=1)
 
         # Consider the reversed line segments as well
-        top_scores = np.concatenate([top_scores, top_scores[..., ::-1]],
-                                    axis=1)
+        top_scores = np.concatenate([top_scores, top_scores[..., ::-1]], axis=1)
 
         # Compute the line distance matrix with Needleman-Wunsch algo and
         # retrieve the closest line neighbor
@@ -339,30 +348,33 @@ class WunschLineMatcher(object):
             for j in range(m):
                 nw_grid[:, i + 1, j + 1] = np.maximum(
                     np.maximum(nw_grid[:, i + 1, j], nw_grid[:, i, j + 1]),
-                    nw_grid[:, i, j] + nw_scores[:, i, j])
+                    nw_grid[:, i, j] + nw_scores[:, i, j],
+                )
 
         return nw_grid[:, -1, -1]
 
     def get_pairwise_distance(self, line_seg1, line_seg2, desc1, desc2):
         """
-            Compute the OPPOSITE of the NW score for pairs of line segments
-            and their corresponding descriptors.
+        Compute the OPPOSITE of the NW score for pairs of line segments
+        and their corresponding descriptors.
         """
         num_lines = len(line_seg1)
-        assert num_lines == len(line_seg2), "The same number of lines is required in pairwise score."
-        img_size1 = (desc1.shape[2] * self.grid_size,
-                     desc1.shape[3] * self.grid_size)
-        img_size2 = (desc2.shape[2] * self.grid_size,
-                     desc2.shape[3] * self.grid_size)
+        assert num_lines == len(
+            line_seg2
+        ), "The same number of lines is required in pairwise score."
+        img_size1 = (desc1.shape[2] * self.grid_size, desc1.shape[3] * self.grid_size)
+        img_size2 = (desc2.shape[2] * self.grid_size, desc2.shape[3] * self.grid_size)
         device = desc1.device
 
         # Sample points regularly along each line
         line_points1, valid_points1 = self.sample_line_points(line_seg1)
         line_points2, valid_points2 = self.sample_line_points(line_seg2)
-        line_points1 = torch.tensor(line_points1.reshape(-1, 2),
-                                    dtype=torch.float, device=device)
-        line_points2 = torch.tensor(line_points2.reshape(-1, 2),
-                                    dtype=torch.float, device=device)
+        line_points1 = torch.tensor(
+            line_points1.reshape(-1, 2), dtype=torch.float, device=device
+        )
+        line_points2 = torch.tensor(
+            line_points2.reshape(-1, 2), dtype=torch.float, device=device
+        )
 
         # Extract the descriptors for each point
         grid1 = keypoints_to_grid(line_points1, img_size1)
@@ -374,9 +386,8 @@ class WunschLineMatcher(object):
 
         # Compute the distance between line points for every pair of lines
         # Assign a score of -1 for unvalid points
-        scores = torch.einsum('dns,dnt->nst', desc1, desc2).cpu().numpy()
-        scores = scores.reshape(num_lines * self.num_samples,
-                                self.num_samples)
+        scores = torch.einsum("dns,dnt->nst", desc1, desc2).cpu().numpy()
+        scores = scores.reshape(num_lines * self.num_samples, self.num_samples)
         scores[~valid_points1.flatten()] = -1
         scores = scores.reshape(num_lines, self.num_samples, self.num_samples)
         scores = scores.transpose(1, 0, 2).reshape(self.num_samples, -1)
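
For reference (outside the patch), a minimal single-pair sketch of the Needleman-Wunsch recurrence that the hunks above only re-wrap; names are placeholders, and the patched code vectorizes the same nw_grid update over all candidate line pairs:

import numpy as np


def nw_score(scores: np.ndarray) -> float:
    # scores[i, j] is the descriptor similarity between sampled point i of the
    # first line and point j of the second; gaps cost nothing, so the best of
    # skip-left, skip-up, or align-and-add is kept, mirroring the nw_grid update above.
    n, m = scores.shape
    grid = np.zeros((n + 1, m + 1))
    for i in range(n):
        for j in range(m):
            grid[i + 1, j + 1] = max(
                grid[i + 1, j],             # skip a point of the second line
                grid[i, j + 1],             # skip a point of the first line
                grid[i, j] + scores[i, j],  # align point i with point j
            )
    return float(grid[-1, -1])
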
diff --git a/third_party/SOLD2/sold2/model/loss.py b/third_party/SOLD2/sold2/model/loss.py
index aaad3c67f3fd59db308869901f8a56623901e318..c1d2bfd232958fc19a4a775fe561dd5089079bff 100644
--- a/third_party/SOLD2/sold2/model/loss.py
+++ b/third_party/SOLD2/sold2/model/loss.py
@@ -7,17 +7,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 from kornia.geometry import warp_perspective
 
-from ..misc.geometry_utils import (keypoints_to_grid, get_dist_mask,
-                                   get_common_line_mask)
+from ..misc.geometry_utils import keypoints_to_grid, get_dist_mask, get_common_line_mask
 
 
 def get_loss_and_weights(model_cfg, device=torch.device("cuda")):
-    """ Get loss functions and either static or dynamic weighting. """
+    """Get loss functions and either static or dynamic weighting."""
     # Get the global weighting policy
     w_policy = model_cfg.get("weighting_policy", "static")
     if not w_policy in ["static", "dynamic"]:
         raise ValueError("[Error] Not supported weighting policy.")
-    
+
     loss_func = {}
     loss_weight = {}
     # Get junction loss function and weight
@@ -27,14 +26,16 @@ def get_loss_and_weights(model_cfg, device=torch.device("cuda")):
 
     # Get heatmap loss function and weight
     w_heatmap, heatmap_loss_func = get_heatmap_loss_and_weight(
-        model_cfg, w_policy, device)
+        model_cfg, w_policy, device
+    )
     loss_func["heatmap_loss"] = heatmap_loss_func.to(device)
     loss_weight["w_heatmap"] = w_heatmap
 
     # [Optionally] get descriptor loss function and weight
     if model_cfg.get("descriptor_loss_func", None) is not None:
         w_descriptor, descriptor_loss_func = get_descriptor_loss_and_weight(
-            model_cfg, w_policy)
+            model_cfg, w_policy
+        )
         loss_func["descriptor_loss"] = descriptor_loss_func.to(device)
         loss_weight["w_desc"] = w_descriptor
 
@@ -42,26 +43,26 @@ def get_loss_and_weights(model_cfg, device=torch.device("cuda")):
 
 
 def get_junction_loss_and_weight(model_cfg, global_w_policy):
-    """ Get the junction loss function and weight. """
+    """Get the junction loss function and weight."""
     junction_loss_cfg = model_cfg.get("junction_loss_cfg", {})
-    
+
     # Get the junction loss weight
     w_policy = junction_loss_cfg.get("policy", global_w_policy)
     if w_policy == "static":
         w_junc = torch.tensor(model_cfg["w_junc"], dtype=torch.float32)
     elif w_policy == "dynamic":
         w_junc = nn.Parameter(
-            torch.tensor(model_cfg["w_junc"], dtype=torch.float32),
-            requires_grad=True)
+            torch.tensor(model_cfg["w_junc"], dtype=torch.float32), requires_grad=True
+        )
     else:
-        raise ValueError(
-    "[Error] Unknown weighting policy for junction loss weight.")
+        raise ValueError("[Error] Unknown weighting policy for junction loss weight.")
 
     # Get the junction loss function
     junc_loss_name = model_cfg.get("junction_loss_func", "superpoint")
     if junc_loss_name == "superpoint":
-        junc_loss_func = JunctionDetectionLoss(model_cfg["grid_size"],
-                                               model_cfg["keep_border_valid"])
+        junc_loss_func = JunctionDetectionLoss(
+            model_cfg["grid_size"], model_cfg["keep_border_valid"]
+        )
     else:
         raise ValueError("[Error] Not supported junction loss function.")
 
@@ -69,7 +70,7 @@ def get_junction_loss_and_weight(model_cfg, global_w_policy):
 
 
 def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device):
-    """ Get the heatmap loss function and weight. """
+    """Get the heatmap loss function and weight."""
     heatmap_loss_cfg = model_cfg.get("heatmap_loss_cfg", {})
 
     # Get the heatmap loss weight
@@ -78,19 +79,20 @@ def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device):
         w_heatmap = torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32)
     elif w_policy == "dynamic":
         w_heatmap = nn.Parameter(
-            torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32), 
-            requires_grad=True)
+            torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32),
+            requires_grad=True,
+        )
     else:
-        raise ValueError(
-    "[Error] Unknown weighting policy for junction loss weight.")
+        raise ValueError("[Error] Unknown weighting policy for junction loss weight.")
 
     # Get the corresponding heatmap loss based on the config
     heatmap_loss_name = model_cfg.get("heatmap_loss_func", "cross_entropy")
     if heatmap_loss_name == "cross_entropy":
         # Get the heatmap class weight (always static)
-        heatmap_class_w = model_cfg.get("w_heatmap_class", 1.)
-        class_weight = torch.tensor(
-            np.array([1., heatmap_class_w])).to(torch.float).to(device)
+        heatmap_class_w = model_cfg.get("w_heatmap_class", 1.0)
+        class_weight = (
+            torch.tensor(np.array([1.0, heatmap_class_w])).to(torch.float).to(device)
+        )
         heatmap_loss_func = HeatmapLoss(class_weight=class_weight)
     else:
         raise ValueError("[Error] Not supported heatmap loss function.")
@@ -99,28 +101,28 @@ def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device):
 
 
 def get_descriptor_loss_and_weight(model_cfg, global_w_policy):
-    """ Get the descriptor loss function and weight. """
+    """Get the descriptor loss function and weight."""
     descriptor_loss_cfg = model_cfg.get("descriptor_loss_cfg", {})
-    
+
     # Get the descriptor loss weight
     w_policy = descriptor_loss_cfg.get("policy", global_w_policy)
     if w_policy == "static":
         w_descriptor = torch.tensor(model_cfg["w_desc"], dtype=torch.float32)
     elif w_policy == "dynamic":
-        w_descriptor = nn.Parameter(torch.tensor(model_cfg["w_desc"],
-                                    dtype=torch.float32), requires_grad=True)
+        w_descriptor = nn.Parameter(
+            torch.tensor(model_cfg["w_desc"], dtype=torch.float32), requires_grad=True
+        )
     else:
-        raise ValueError(
-    "[Error] Unknown weighting policy for descriptor loss weight.")
+        raise ValueError("[Error] Unknown weighting policy for descriptor loss weight.")
 
     # Get the descriptor loss function
-    descriptor_loss_name = model_cfg.get("descriptor_loss_func",
-                                         "regular_sampling")
+    descriptor_loss_name = model_cfg.get("descriptor_loss_func", "regular_sampling")
     if descriptor_loss_name == "regular_sampling":
         descriptor_loss_func = TripletDescriptorLoss(
             descriptor_loss_cfg["grid_size"],
             descriptor_loss_cfg["dist_threshold"],
-            descriptor_loss_cfg["margin"])
+            descriptor_loss_cfg["margin"],
+        )
     else:
         raise ValueError("[Error] Not supported descriptor loss function.")
 
@@ -128,79 +130,88 @@ def get_descriptor_loss_and_weight(model_cfg, global_w_policy):
 
 
 def space_to_depth(input_tensor, grid_size):
-    """ PixelUnshuffle for pytorch. """
+    """PixelUnshuffle for pytorch."""
     N, C, H, W = input_tensor.size()
     # (N, C, H//bs, bs, W//bs, bs)
     x = input_tensor.view(N, C, H // grid_size, grid_size, W // grid_size, grid_size)
     # (N, bs, bs, C, H//bs, W//bs)
     x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
     # (N, C*bs^2, H//bs, W//bs)
-    x = x.view(N, C * (grid_size ** 2), H // grid_size, W // grid_size)
+    x = x.view(N, C * (grid_size**2), H // grid_size, W // grid_size)
     return x
 
 
-def junction_detection_loss(junction_map, junc_predictions, valid_mask=None,
-                            grid_size=8, keep_border=True):
-    """ Junction detection loss. """
+def junction_detection_loss(
+    junction_map, junc_predictions, valid_mask=None, grid_size=8, keep_border=True
+):
+    """Junction detection loss."""
     # Convert junc_map to channel tensor
     junc_map = space_to_depth(junction_map, grid_size)
     map_shape = junc_map.shape[-2:]
     batch_size = junc_map.shape[0]
-    dust_bin_label = torch.ones(
-        [batch_size, 1, map_shape[0],
-         map_shape[1]]).to(junc_map.device).to(torch.int)
-    junc_map = torch.cat([junc_map*2, dust_bin_label], dim=1)
+    dust_bin_label = (
+        torch.ones([batch_size, 1, map_shape[0], map_shape[1]])
+        .to(junc_map.device)
+        .to(torch.int)
+    )
+    junc_map = torch.cat([junc_map * 2, dust_bin_label], dim=1)
     labels = torch.argmax(
-        junc_map.to(torch.float) +
-        torch.distributions.Uniform(0, 0.1).sample(junc_map.shape).to(junc_map.device),
-        dim=1)
+        junc_map.to(torch.float)
+        + torch.distributions.Uniform(0, 0.1)
+        .sample(junc_map.shape)
+        .to(junc_map.device),
+        dim=1,
+    )
 
     # Also convert the valid mask to channel tensor
-    valid_mask = (torch.ones(junction_map.shape) if valid_mask is None
-                  else valid_mask)
+    valid_mask = torch.ones(junction_map.shape) if valid_mask is None else valid_mask
     valid_mask = space_to_depth(valid_mask, grid_size)
-    
+
     # Compute junction loss on the border patch or not
     if keep_border:
-        valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int),
-                               dim=1, keepdim=True) > 0
+        valid_mask = (
+            torch.sum(valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True) > 0
+        )
     else:
-        valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int),
-                               dim=1, keepdim=True) >= grid_size * grid_size
+        valid_mask = (
+            torch.sum(valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True)
+            >= grid_size * grid_size
+        )
 
     # Compute the classification loss
     loss_func = nn.CrossEntropyLoss(reduction="none")
     # The loss still need NCHW format
-    loss = loss_func(input=junc_predictions,
-                     target=labels.to(torch.long))
-    
+    loss = loss_func(input=junc_predictions, target=labels.to(torch.long))
+
     # Weighted sum by the valid mask
-    loss_ = torch.sum(loss * torch.squeeze(valid_mask.to(torch.float),
-                                           dim=1), dim=[0, 1, 2])
-    loss_final = loss_ / torch.sum(torch.squeeze(valid_mask.to(torch.float),
-                                                 dim=1))
+    loss_ = torch.sum(
+        loss * torch.squeeze(valid_mask.to(torch.float), dim=1), dim=[0, 1, 2]
+    )
+    loss_final = loss_ / torch.sum(torch.squeeze(valid_mask.to(torch.float), dim=1))
 
     return loss_final
 
 
-def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None,
-                 class_weight=None):
-    """ Heatmap prediction loss. """
+def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None, class_weight=None):
+    """Heatmap prediction loss."""
     # Compute the classification loss on each pixel
     if class_weight is None:
         loss_func = nn.CrossEntropyLoss(reduction="none")
     else:
         loss_func = nn.CrossEntropyLoss(class_weight, reduction="none")
 
-    loss = loss_func(input=heatmap_pred,
-                     target=torch.squeeze(heatmap_gt.to(torch.long), dim=1))
+    loss = loss_func(
+        input=heatmap_pred, target=torch.squeeze(heatmap_gt.to(torch.long), dim=1)
+    )
 
     # Weighted sum by the valid mask
     # Sum over H and W
-    loss_spatial_sum = torch.sum(loss * torch.squeeze(
-        valid_mask.to(torch.float), dim=1), dim=[1, 2])
-    valid_spatial_sum = torch.sum(torch.squeeze(valid_mask.to(torch.float32),
-                                                dim=1), dim=[1, 2])
+    loss_spatial_sum = torch.sum(
+        loss * torch.squeeze(valid_mask.to(torch.float), dim=1), dim=[1, 2]
+    )
+    valid_spatial_sum = torch.sum(
+        torch.squeeze(valid_mask.to(torch.float32), dim=1), dim=[1, 2]
+    )
     # Mean to single scalar over batch dimension
     loss = torch.sum(loss_spatial_sum) / torch.sum(valid_spatial_sum)
 
@@ -208,19 +219,22 @@ def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None,
 
 
 class JunctionDetectionLoss(nn.Module):
-    """ Junction detection loss. """
+    """Junction detection loss."""
+
     def __init__(self, grid_size, keep_border):
         super(JunctionDetectionLoss, self).__init__()
         self.grid_size = grid_size
         self.keep_border = keep_border
 
     def forward(self, prediction, target, valid_mask=None):
-        return junction_detection_loss(target, prediction, valid_mask,
-                                       self.grid_size, self.keep_border)
+        return junction_detection_loss(
+            target, prediction, valid_mask, self.grid_size, self.keep_border
+        )
 
 
 class HeatmapLoss(nn.Module):
-    """ Heatmap prediction loss. """
+    """Heatmap prediction loss."""
+
     def __init__(self, class_weight):
         super(HeatmapLoss, self).__init__()
         self.class_weight = class_weight
@@ -230,7 +244,8 @@ class HeatmapLoss(nn.Module):
 
 
 class RegularizationLoss(nn.Module):
-    """ Module for regularization loss. """
+    """Module for regularization loss."""
+
     def __init__(self):
         super(RegularizationLoss, self).__init__()
         self.name = "regularization_loss"
@@ -242,14 +257,23 @@ class RegularizationLoss(nn.Module):
         for _, val in loss_weights.items():
             if isinstance(val, nn.Parameter):
                 loss += val
-        
+
         return loss
 
 
-def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices,
-                 epoch, grid_size=8, dist_threshold=8,
-                 init_dist_threshold=64, margin=1):
-    """ Regular triplet loss for descriptor learning. """
+def triplet_loss(
+    desc_pred1,
+    desc_pred2,
+    points1,
+    points2,
+    line_indices,
+    epoch,
+    grid_size=8,
+    dist_threshold=8,
+    init_dist_threshold=64,
+    margin=1,
+):
+    """Regular triplet loss for descriptor learning."""
     b_size, _, Hc, Wc = desc_pred1.size()
     img_size = (Hc * grid_size, Wc * grid_size)
     device = desc_pred1.device
@@ -259,12 +283,11 @@ def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices,
     valid_points = line_indices.bool().flatten()
     n_correct_points = torch.sum(valid_points).item()
     if n_correct_points == 0:
-        return torch.tensor(0., dtype=torch.float, device=device)
+        return torch.tensor(0.0, dtype=torch.float, device=device)
 
     # Check which keypoints are too close to be matched
     # dist_threshold is decreased at each epoch for easier training
-    dist_threshold = max(dist_threshold,
-                         2 * init_dist_threshold // (epoch + 1))
+    dist_threshold = max(dist_threshold, 2 * init_dist_threshold // (epoch + 1))
     dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold)
 
     # Additionally ban negative mining along the same line
@@ -276,11 +299,17 @@ def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices,
     grid2 = keypoints_to_grid(points2, img_size)
 
     # Extract the descriptors
-    desc1 = F.grid_sample(desc_pred1, grid1).permute(
-        0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
+    desc1 = (
+        F.grid_sample(desc_pred1, grid1)
+        .permute(0, 2, 3, 1)
+        .reshape(b_size * n_points, -1)[valid_points]
+    )
     desc1 = F.normalize(desc1, dim=1)
-    desc2 = F.grid_sample(desc_pred2, grid2).permute(
-        0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
+    desc2 = (
+        F.grid_sample(desc_pred2, grid2)
+        .permute(0, 2, 3, 1)
+        .reshape(b_size * n_points, -1)[valid_points]
+    )
     desc2 = F.normalize(desc2, dim=1)
     desc_dists = 2 - 2 * (desc1 @ desc2.t())
 
@@ -288,20 +317,23 @@ def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices,
     pos_dist = torch.diag(desc_dists)
 
     # Negative distance loss
-    max_dist = torch.tensor(4., dtype=torch.float, device=device)
+    max_dist = torch.tensor(4.0, dtype=torch.float, device=device)
     desc_dists[
         torch.arange(n_correct_points, dtype=torch.long),
-        torch.arange(n_correct_points, dtype=torch.long)] = max_dist
+        torch.arange(n_correct_points, dtype=torch.long),
+    ] = max_dist
     desc_dists[dist_mask] = max_dist
-    neg_dist = torch.min(torch.min(desc_dists, dim=1)[0],
-                         torch.min(desc_dists, dim=0)[0])
+    neg_dist = torch.min(
+        torch.min(desc_dists, dim=1)[0], torch.min(desc_dists, dim=0)[0]
+    )
 
     triplet_loss = F.relu(margin + pos_dist - neg_dist)
     return triplet_loss, grid1, grid2, valid_points
 
 
 class TripletDescriptorLoss(nn.Module):
-    """ Triplet descriptor loss. """
+    """Triplet descriptor loss."""
+
     def __init__(self, grid_size, dist_threshold, margin):
         super(TripletDescriptorLoss, self).__init__()
         self.grid_size = grid_size
@@ -309,23 +341,35 @@ class TripletDescriptorLoss(nn.Module):
         self.dist_threshold = dist_threshold
         self.margin = margin
 
-    def forward(self, desc_pred1, desc_pred2, points1,
-                points2, line_indices, epoch):
-        return self.descriptor_loss(desc_pred1, desc_pred2, points1,
-                                    points2, line_indices, epoch)
+    def forward(self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch):
+        return self.descriptor_loss(
+            desc_pred1, desc_pred2, points1, points2, line_indices, epoch
+        )
 
     # The descriptor loss based on regularly sampled points along the lines
-    def descriptor_loss(self, desc_pred1, desc_pred2, points1,
-                        points2, line_indices, epoch):
-        return torch.mean(triplet_loss(
-            desc_pred1, desc_pred2, points1, points2, line_indices, epoch,
-            self.grid_size, self.dist_threshold, self.init_dist_threshold,
-            self.margin)[0])
+    def descriptor_loss(
+        self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch
+    ):
+        return torch.mean(
+            triplet_loss(
+                desc_pred1,
+                desc_pred2,
+                points1,
+                points2,
+                line_indices,
+                epoch,
+                self.grid_size,
+                self.dist_threshold,
+                self.init_dist_threshold,
+                self.margin,
+            )[0]
+        )
 
 
 class TotalLoss(nn.Module):
-    """ Total loss summing junction, heatma, descriptor
-        and regularization losses. """
+    """Total loss summing junction, heatma, descriptor
+    and regularization losses."""
+
     def __init__(self, loss_funcs, loss_weights, weighting_policy):
         super(TotalLoss, self).__init__()
         # Whether we need to compute the descriptor loss
@@ -338,23 +382,26 @@ class TotalLoss(nn.Module):
         # Always add regularization loss (it will return zero if not used)
         self.loss_funcs["reg_loss"] = RegularizationLoss().cuda()
 
-    def forward(self, junc_pred, junc_target, heatmap_pred,
-                heatmap_target, valid_mask=None):
-        """ Detection only loss. """
+    def forward(
+        self, junc_pred, junc_target, heatmap_pred, heatmap_target, valid_mask=None
+    ):
+        """Detection only loss."""
         # Compute the junction loss
-        junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target,
-                                                 valid_mask)
+        junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target, valid_mask)
         # Compute the heatmap loss
         heatmap_loss = self.loss_funcs["heatmap_loss"](
-            heatmap_pred, heatmap_target, valid_mask)
+            heatmap_pred, heatmap_target, valid_mask
+        )
 
         # Compute the total loss.
         if self.weighting_policy == "dynamic":
             reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
-            total_loss = junc_loss * torch.exp(-self.loss_weights["w_junc"]) + \
-                         heatmap_loss * torch.exp(-self.loss_weights["w_heatmap"]) + \
-                         reg_loss
-            
+            total_loss = (
+                junc_loss * torch.exp(-self.loss_weights["w_junc"])
+                + heatmap_loss * torch.exp(-self.loss_weights["w_heatmap"])
+                + reg_loss
+            )
+
             return {
                 "total_loss": total_loss,
                 "junc_loss": junc_loss,
@@ -363,32 +410,47 @@ class TotalLoss(nn.Module):
                 "w_junc": torch.exp(-self.loss_weights["w_junc"]).item(),
                 "w_heatmap": torch.exp(-self.loss_weights["w_heatmap"]).item(),
             }
-        
+
         elif self.weighting_policy == "static":
-            total_loss = junc_loss * self.loss_weights["w_junc"] + \
-                         heatmap_loss * self.loss_weights["w_heatmap"]
-            
+            total_loss = (
+                junc_loss * self.loss_weights["w_junc"]
+                + heatmap_loss * self.loss_weights["w_heatmap"]
+            )
+
             return {
                 "total_loss": total_loss,
                 "junc_loss": junc_loss,
-                "heatmap_loss": heatmap_loss
+                "heatmap_loss": heatmap_loss,
             }
 
         else:
             raise ValueError("[Error] Unknown weighting policy.")
-    
-    def forward_descriptors(self, 
-            junc_map_pred1, junc_map_pred2, junc_map_target1,
-            junc_map_target2, heatmap_pred1, heatmap_pred2, heatmap_target1,
-            heatmap_target2, line_points1, line_points2, line_indices,
-            desc_pred1, desc_pred2, epoch, valid_mask1=None,
-            valid_mask2=None):
-        """ Loss for detection + description. """
+
+    def forward_descriptors(
+        self,
+        junc_map_pred1,
+        junc_map_pred2,
+        junc_map_target1,
+        junc_map_target2,
+        heatmap_pred1,
+        heatmap_pred2,
+        heatmap_target1,
+        heatmap_target2,
+        line_points1,
+        line_points2,
+        line_indices,
+        desc_pred1,
+        desc_pred2,
+        epoch,
+        valid_mask1=None,
+        valid_mask2=None,
+    ):
+        """Loss for detection + description."""
         # Compute junction loss
         junc_loss = self.loss_funcs["junc_loss"](
-            torch.cat([junc_map_pred1, junc_map_pred2], dim=0), 
+            torch.cat([junc_map_pred1, junc_map_pred2], dim=0),
             torch.cat([junc_map_target1, junc_map_target2], dim=0),
-            torch.cat([valid_mask1, valid_mask2], dim=0)
+            torch.cat([valid_mask1, valid_mask2], dim=0),
         )
         # Get junction loss weight (dynamic or not)
         if isinstance(self.loss_weights["w_junc"], nn.Parameter):
@@ -398,9 +460,9 @@ class TotalLoss(nn.Module):
 
         # Compute heatmap loss
         heatmap_loss = self.loss_funcs["heatmap_loss"](
-            torch.cat([heatmap_pred1, heatmap_pred2], dim=0), 
+            torch.cat([heatmap_pred1, heatmap_pred2], dim=0),
             torch.cat([heatmap_target1, heatmap_target2], dim=0),
-            torch.cat([valid_mask1, valid_mask2], dim=0)
+            torch.cat([valid_mask1, valid_mask2], dim=0),
         )
         # Get heatmap loss weight (dynamic or not)
         if isinstance(self.loss_weights["w_heatmap"], nn.Parameter):
@@ -410,8 +472,8 @@ class TotalLoss(nn.Module):
 
         # Compute the descriptor loss
         descriptor_loss = self.loss_funcs["descriptor_loss"](
-            desc_pred1, desc_pred2, line_points1,
-            line_points2, line_indices, epoch)
+            desc_pred1, desc_pred2, line_points1, line_points2, line_indices, epoch
+        )
         # Get descriptor loss weight (dynamic or not)
         if isinstance(self.loss_weights["w_desc"], nn.Parameter):
             w_descriptor = torch.exp(-self.loss_weights["w_desc"])
@@ -419,27 +481,27 @@ class TotalLoss(nn.Module):
             w_descriptor = self.loss_weights["w_desc"]
 
         # Update the total loss
-        total_loss = (junc_loss * w_junc
-                      + heatmap_loss * w_heatmap
-                      + descriptor_loss * w_descriptor)
+        total_loss = (
+            junc_loss * w_junc
+            + heatmap_loss * w_heatmap
+            + descriptor_loss * w_descriptor
+        )
         outputs = {
             "junc_loss": junc_loss,
             "heatmap_loss": heatmap_loss,
-            "w_junc": w_junc.item() \
-                if isinstance(w_junc, nn.Parameter) else w_junc,
-            "w_heatmap": w_heatmap.item() \
-                if isinstance(w_heatmap, nn.Parameter) else w_heatmap,
+            "w_junc": w_junc.item() if isinstance(w_junc, nn.Parameter) else w_junc,
+            "w_heatmap": w_heatmap.item()
+            if isinstance(w_heatmap, nn.Parameter)
+            else w_heatmap,
             "descriptor_loss": descriptor_loss,
-            "w_desc": w_descriptor.item() \
-                if isinstance(w_descriptor, nn.Parameter) else w_descriptor
+            "w_desc": w_descriptor.item()
+            if isinstance(w_descriptor, nn.Parameter)
+            else w_descriptor,
         }
-        
+
         # Compute the regularization loss
         reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
         total_loss += reg_loss
-        outputs.update({
-            "reg_loss": reg_loss,
-            "total_loss": total_loss
-        })
+        outputs.update({"reg_loss": reg_loss, "total_loss": total_loss})
 
         return outputs
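
For reference (outside the patch), a minimal sketch of the dynamic weighting scheme that TotalLoss above implements, reduced to two losses with illustrative names: each loss is scaled by exp(-w) with a learnable w, and the raw weights are added back (the RegularizationLoss contribution) so they stay bounded:

import torch
import torch.nn as nn

# Two illustrative learnable log-weights, initialized to 0 (i.e. exp(-w) = 1).
w_junc = nn.Parameter(torch.tensor(0.0))
w_heatmap = nn.Parameter(torch.tensor(0.0))


def dynamic_total_loss(junc_loss, heatmap_loss):
    # Each loss is down-weighted by exp(-w); the raw w's are added back so the
    # optimizer cannot drive the weights to -inf and zero out the losses.
    reg_loss = w_junc + w_heatmap
    return (
        junc_loss * torch.exp(-w_junc)
        + heatmap_loss * torch.exp(-w_heatmap)
        + reg_loss
    )
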
diff --git a/third_party/SOLD2/sold2/model/lr_scheduler.py b/third_party/SOLD2/sold2/model/lr_scheduler.py
index 3faa4f68a67564719008a932b40c16c5e908949f..fa3f5903c92a61f01eaa8aed95fb2261212f3762 100644
--- a/third_party/SOLD2/sold2/model/lr_scheduler.py
+++ b/third_party/SOLD2/sold2/model/lr_scheduler.py
@@ -5,18 +5,17 @@ import torch
 
 
 def get_lr_scheduler(lr_decay, lr_decay_cfg, optimizer):
-    """ Get the learning rate scheduler according to the config. """
+    """Get the learning rate scheduler according to the config."""
     # If no lr_decay is specified => return None
     if (lr_decay == False) or (lr_decay_cfg is None):
         schduler = None
     # Exponential decay
     elif (lr_decay == True) and (lr_decay_cfg["policy"] == "exp"):
         schduler = torch.optim.lr_scheduler.ExponentialLR(
-            optimizer, 
-            gamma=lr_decay_cfg["gamma"]
+            optimizer, gamma=lr_decay_cfg["gamma"]
         )
     # Unknown policy
     else:
         raise ValueError("[Error] Unknow learning rate decay policy!")
 
-    return schduler
\ No newline at end of file
+    return schduler
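
For reference (outside the patch), a hypothetical usage of get_lr_scheduler above; it assumes the sold2 package is importable, and the lr and gamma values are illustrative, with config keys mirroring the patched code:

import torch

from sold2.model.lr_scheduler import get_lr_scheduler

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = get_lr_scheduler(
    lr_decay=True, lr_decay_cfg={"policy": "exp", "gamma": 0.95}, optimizer=optimizer
)
# After each epoch, scheduler.step() multiplies the learning rate by gamma.
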
diff --git a/third_party/SOLD2/sold2/model/metrics.py b/third_party/SOLD2/sold2/model/metrics.py
index 0894a7207ee4afa344cb332c605c715b14db73a4..668daaf99acb9bbb80d7ca2746926f9d79d55cf0 100644
--- a/third_party/SOLD2/sold2/model/metrics.py
+++ b/third_party/SOLD2/sold2/model/metrics.py
@@ -10,15 +10,26 @@ from ..misc.geometry_utils import keypoints_to_grid
 
 
 class Metrics(object):
-    """ Metric evaluation calculator. """
-    def __init__(self, detection_thresh, prob_thresh, grid_size,
-                 junc_metric_lst=None, heatmap_metric_lst=None,
-                 pr_metric_lst=None, desc_metric_lst=None):
+    """Metric evaluation calculator."""
+
+    def __init__(
+        self,
+        detection_thresh,
+        prob_thresh,
+        grid_size,
+        junc_metric_lst=None,
+        heatmap_metric_lst=None,
+        pr_metric_lst=None,
+        desc_metric_lst=None,
+    ):
         # List supported metrics
-        self.supported_junc_metrics = ["junc_precision", "junc_precision_nms",
-                                       "junc_recall", "junc_recall_nms"]
-        self.supported_heatmap_metrics = ["heatmap_precision",
-                                          "heatmap_recall"]
+        self.supported_junc_metrics = [
+            "junc_precision",
+            "junc_precision_nms",
+            "junc_recall",
+            "junc_recall_nms",
+        ]
+        self.supported_heatmap_metrics = ["heatmap_precision", "heatmap_recall"]
         self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"]
         self.supported_desc_metrics = ["matching_score"]
 
@@ -38,14 +49,13 @@ class Metrics(object):
         # For the descriptors, the default None assumes no desc metric at all
         if desc_metric_lst is None:
             self.desc_metric_lst = []
-        elif desc_metric_lst == 'all':
+        elif desc_metric_lst == "all":
             self.desc_metric_lst = self.supported_desc_metrics
         else:
             self.desc_metric_lst = desc_metric_lst
 
         if not self._check_metrics():
-            raise ValueError(
-                "[Error] Some elements in the metric_lst are invalid.")
+            raise ValueError("[Error] Some elements in the metric_lst are invalid.")
 
         # Metric mapping table
         self.metric_table = {
@@ -57,18 +67,29 @@ class Metrics(object):
             "heatmap_recall": heatmap_recall(prob_thresh),
             "junc_pr": junction_pr(),
             "junc_nms_pr": junction_pr(),
-            "matching_score": matching_score(grid_size)
+            "matching_score": matching_score(grid_size),
         }
 
         # Initialize the results
         self.metric_results = {}
         for key in self.metric_table.keys():
-            self.metric_results[key] = 0.
-
-    def evaluate(self, junc_pred, junc_pred_nms, junc_gt, heatmap_pred,
-                 heatmap_gt, valid_mask, line_points1=None, line_points2=None,
-                 desc_pred1=None, desc_pred2=None, valid_points=None):
-        """ Perform evaluation. """
+            self.metric_results[key] = 0.0
+
+    def evaluate(
+        self,
+        junc_pred,
+        junc_pred_nms,
+        junc_gt,
+        heatmap_pred,
+        heatmap_gt,
+        valid_mask,
+        line_points1=None,
+        line_points2=None,
+        desc_pred1=None,
+        desc_pred2=None,
+        valid_points=None,
+    ):
+        """Perform evaluation."""
         for metric in self.junc_metric_lst:
             # If nms metrics then use nms to compute it.
             if "nms" in metric:
@@ -77,27 +98,31 @@ class Metrics(object):
             else:
                 junc_pred_input = junc_pred
             self.metric_results[metric] = self.metric_table[metric](
-                junc_pred_input, junc_gt, valid_mask)
+                junc_pred_input, junc_gt, valid_mask
+            )
 
         for metric in self.heatmap_metric_lst:
             self.metric_results[metric] = self.metric_table[metric](
-                heatmap_pred, heatmap_gt, valid_mask)
+                heatmap_pred, heatmap_gt, valid_mask
+            )
 
         for metric in self.pr_metric_lst:
             if "nms" in metric:
                 self.metric_results[metric] = self.metric_table[metric](
-                    junc_pred_nms, junc_gt, valid_mask)
+                    junc_pred_nms, junc_gt, valid_mask
+                )
             else:
                 self.metric_results[metric] = self.metric_table[metric](
-                    junc_pred, junc_gt, valid_mask)
+                    junc_pred, junc_gt, valid_mask
+                )
 
         for metric in self.desc_metric_lst:
             self.metric_results[metric] = self.metric_table[metric](
-                line_points1, line_points2, desc_pred1,
-                desc_pred2, valid_points)
+                line_points1, line_points2, desc_pred1, desc_pred2, valid_points
+            )
 
     def _check_metrics(self):
-        """ Check if all input metrics are valid. """
+        """Check if all input metrics are valid."""
         flag = True
         for metric in self.junc_metric_lst:
             if not metric in self.supported_junc_metrics:
@@ -116,19 +141,31 @@ class Metrics(object):
 
 
 class AverageMeter(object):
-    def __init__(self, junc_metric_lst=None, heatmap_metric_lst=None,
-                 is_training=True, desc_metric_lst=None):
+    def __init__(
+        self,
+        junc_metric_lst=None,
+        heatmap_metric_lst=None,
+        is_training=True,
+        desc_metric_lst=None,
+    ):
         # List supported metrics
-        self.supported_junc_metrics = ["junc_precision", "junc_precision_nms",
-                                       "junc_recall", "junc_recall_nms"]
-        self.supported_heatmap_metrics = ["heatmap_precision",
-                                          "heatmap_recall"]
+        self.supported_junc_metrics = [
+            "junc_precision",
+            "junc_precision_nms",
+            "junc_recall",
+            "junc_recall_nms",
+        ]
+        self.supported_heatmap_metrics = ["heatmap_precision", "heatmap_recall"]
         self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"]
         self.supported_desc_metrics = ["matching_score"]
         # Record loss in training mode
         # if is_training:
         self.supported_loss = [
-            "junc_loss", "heatmap_loss", "descriptor_loss", "total_loss"]
+            "junc_loss",
+            "heatmap_loss",
+            "descriptor_loss",
+            "total_loss",
+        ]
 
         self.is_training = is_training
 
@@ -144,21 +181,23 @@ class AverageMeter(object):
         # For the descriptors, the default None assumes no desc metric at all
         if desc_metric_lst is None:
             self.desc_metric_lst = []
-        elif desc_metric_lst == 'all':
+        elif desc_metric_lst == "all":
             self.desc_metric_lst = self.supported_desc_metrics
         else:
             self.desc_metric_lst = desc_metric_lst
 
         if not self._check_metrics():
-            raise ValueError(
-                "[Error] Some elements in the metric_lst are invalid.")
+            raise ValueError("[Error] Some elements in the metric_lst are invalid.")
 
         # Initialize the results
         self.metric_results = {}
-        for key in (self.supported_junc_metrics
-                    + self.supported_heatmap_metrics
-                    + self.supported_loss + self.supported_desc_metrics):
-            self.metric_results[key] = 0.
+        for key in (
+            self.supported_junc_metrics
+            + self.supported_heatmap_metrics
+            + self.supported_loss
+            + self.supported_desc_metrics
+        ):
+            self.metric_results[key] = 0.0
         for key in self.supported_pr_metrics:
             zero_lst = [0 for _ in range(50)]
             self.metric_results[key] = {
@@ -167,7 +206,7 @@ class AverageMeter(object):
                 "fp": zero_lst,
                 "fn": zero_lst,
                 "precision": zero_lst,
-                "recall": zero_lst
+                "recall": zero_lst,
             }
 
         # Initialize total count
@@ -176,18 +215,18 @@ class AverageMeter(object):
     def update(self, metrics, loss_dict=None, num_samples=1):
         # loss should be given in the training mode
         if self.is_training and (loss_dict is None):
-            raise ValueError(
-                "[Error] loss info should be given in the training mode.")
+            raise ValueError("[Error] loss info should be given in the training mode.")
 
         # update total counts
         self.count += num_samples
 
         # update all the metrics
-        for met in (self.supported_junc_metrics
-                    + self.supported_heatmap_metrics
-                    + self.supported_desc_metrics):
-            self.metric_results[met] += (num_samples
-                                         * metrics.metric_results[met])
+        for met in (
+            self.supported_junc_metrics
+            + self.supported_heatmap_metrics
+            + self.supported_desc_metrics
+        ):
+            self.metric_results[met] += num_samples * metrics.metric_results[met]
 
         # Update all the losses
         for loss in loss_dict.keys():
@@ -200,8 +239,8 @@ class AverageMeter(object):
                 # Update each interval
                 for idx in range(len(self.metric_results[pr_met][key])):
                     self.metric_results[pr_met][key][idx] += (
-                        num_samples
-                        * metrics.metric_results[pr_met][key][idx])
+                        num_samples * metrics.metric_results[pr_met][key][idx]
+                    )
 
     def average(self):
         results = {}
@@ -217,21 +256,22 @@ class AverageMeter(object):
                     "fp": self.metric_results[met]["fp"],
                     "fn": self.metric_results[met]["fn"],
                     "precision": [],
-                    "recall": []
+                    "recall": [],
                 }
                 for idx in range(len(self.metric_results[met]["precision"])):
                     met_results["precision"].append(
-                        self.metric_results[met]["precision"][idx]
-                        / self.count)
+                        self.metric_results[met]["precision"][idx] / self.count
+                    )
                     met_results["recall"].append(
-                        self.metric_results[met]["recall"][idx] / self.count)
+                        self.metric_results[met]["recall"][idx] / self.count
+                    )
 
                 results[met] = met_results
 
         return results
 
     def _check_metrics(self):
-        """ Check if all input metrics are valid. """
+        """Check if all input metrics are valid."""
         flag = True
         for metric in self.junc_metric_lst:
             if not metric in self.supported_junc_metrics:
@@ -250,7 +290,8 @@ class AverageMeter(object):
 
 
 class junction_precision(object):
-    """ Junction precision. """
+    """Junction precision."""
+
     def __init__(self, detection_thresh):
         self.detection_thresh = detection_thresh
 
@@ -262,8 +303,7 @@ class junction_precision(object):
 
         # Deal with the corner case of the prediction
         if np.sum(junc_pred) > 0:
-            precision = (np.sum(junc_pred * junc_gt.squeeze())
-                         / np.sum(junc_pred))
+            precision = np.sum(junc_pred * junc_gt.squeeze()) / np.sum(junc_pred)
         else:
             precision = 0
 
@@ -271,7 +311,8 @@ class junction_precision(object):
 
 
 class junction_recall(object):
-    """ Junction recall. """
+    """Junction recall."""
+
     def __init__(self, detection_thresh):
         self.detection_thresh = detection_thresh
 
@@ -291,7 +332,8 @@ class junction_recall(object):
 
 
 class junction_pr(object):
-    """ Junction precision-recall info. """
+    """Junction precision-recall info."""
+
     def __init__(self, num_threshold=50):
         self.max = 0.4
         step = self.max / num_threshold
@@ -316,12 +358,21 @@ class junction_pr(object):
             # Compute tp, fp, tn, fn
             junc_gt = junc_gt.squeeze()
             tp = np.sum(junc_pred * junc_gt)
-            tn = np.sum((junc_pred == 0).astype(np.float)
-                        * (junc_gt == 0).astype(np.float) * valid_mask)
-            fp = np.sum((junc_pred == 1).astype(np.float)
-                        * (junc_gt == 0).astype(np.float) * valid_mask)
-            fn = np.sum((junc_pred == 0).astype(np.float)
-                        * (junc_gt == 1).astype(np.float) * valid_mask)
+            tn = np.sum(
+                (junc_pred == 0).astype(np.float)
+                * (junc_gt == 0).astype(np.float)
+                * valid_mask
+            )
+            fp = np.sum(
+                (junc_pred == 1).astype(np.float)
+                * (junc_gt == 0).astype(np.float)
+                * valid_mask
+            )
+            fn = np.sum(
+                (junc_pred == 0).astype(np.float)
+                * (junc_gt == 1).astype(np.float)
+                * valid_mask
+            )
 
             tp_lst.append(tp)
             tn_lst.append(tn)
@@ -336,12 +387,13 @@ class junction_pr(object):
             "fp": np.array(fp_lst),
             "fn": np.array(fn_lst),
             "precision": np.array(precision_lst),
-            "recall": np.array(recall_lst)
+            "recall": np.array(recall_lst),
         }
 
 
 class heatmap_precision(object):
-    """ Heatmap precision. """
+    """Heatmap precision."""
+
     def __init__(self, prob_thresh):
         self.prob_thresh = prob_thresh
 
@@ -352,16 +404,18 @@ class heatmap_precision(object):
 
         # Deal with the corner case of the prediction
         if np.sum(heatmap_pred) > 0:
-            precision = (np.sum(heatmap_pred * heatmap_gt.squeeze())
-                         / np.sum(heatmap_pred))
+            precision = np.sum(heatmap_pred * heatmap_gt.squeeze()) / np.sum(
+                heatmap_pred
+            )
         else:
-            precision = 0.
+            precision = 0.0
 
         return precision
 
 
 class heatmap_recall(object):
-    """ Heatmap recall. """
+    """Heatmap recall."""
+
     def __init__(self, prob_thresh):
         self.prob_thresh = prob_thresh
 
@@ -372,21 +426,20 @@ class heatmap_recall(object):
 
         # Deal with the corner case of the ground truth
         if np.sum(heatmap_gt) > 0:
-            recall = (np.sum(heatmap_pred * heatmap_gt.squeeze())
-                      / np.sum(heatmap_gt))
+            recall = np.sum(heatmap_pred * heatmap_gt.squeeze()) / np.sum(heatmap_gt)
         else:
-            recall = 0.
+            recall = 0.0
 
         return recall
 
 
 class matching_score(object):
-    """ Descriptors matching score. """
+    """Descriptors matching score."""
+
     def __init__(self, grid_size):
         self.grid_size = grid_size
 
-    def __call__(self, points1, points2, desc_pred1,
-                 desc_pred2, line_indices):
+    def __call__(self, points1, points2, desc_pred1, desc_pred2, line_indices):
         b_size, _, Hc, Wc = desc_pred1.size()
         img_size = (Hc * self.grid_size, Wc * self.grid_size)
         device = desc_pred1.device
@@ -396,32 +449,37 @@ class matching_score(object):
         valid_points = line_indices.bool().flatten()
         n_correct_points = torch.sum(valid_points).item()
         if n_correct_points == 0:
-            return torch.tensor(0., dtype=torch.float, device=device)
+            return torch.tensor(0.0, dtype=torch.float, device=device)
 
         # Convert the keypoints to a grid suitable for interpolation
         grid1 = keypoints_to_grid(points1, img_size)
         grid2 = keypoints_to_grid(points2, img_size)
 
         # Extract the descriptors
-        desc1 = F.grid_sample(desc_pred1, grid1).permute(
-            0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
+        desc1 = (
+            F.grid_sample(desc_pred1, grid1)
+            .permute(0, 2, 3, 1)
+            .reshape(b_size * n_points, -1)[valid_points]
+        )
         desc1 = F.normalize(desc1, dim=1)
-        desc2 = F.grid_sample(desc_pred2, grid2).permute(
-            0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
+        desc2 = (
+            F.grid_sample(desc_pred2, grid2)
+            .permute(0, 2, 3, 1)
+            .reshape(b_size * n_points, -1)[valid_points]
+        )
         desc2 = F.normalize(desc2, dim=1)
         desc_dists = 2 - 2 * (desc1 @ desc2.t())
 
         # Compute percentage of correct matches
         matches0 = torch.min(desc_dists, dim=1)[1]
         matches1 = torch.min(desc_dists, dim=0)[1]
-        matching_score = (matches1[matches0]
-                          == torch.arange(len(matches0)).to(device))
+        matching_score = matches1[matches0] == torch.arange(len(matches0)).to(device)
         matching_score = matching_score.float().mean()
         return matching_score
 
 
 def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0):
-    """ Non-maximum suppression adapted from SuperPoint. """
+    """Non-maximum suppression adapted from SuperPoint."""
     # Iterate through batch dimension
     im_h = prob_predictions.shape[1]
     im_w = prob_predictions.shape[2]
@@ -430,17 +488,19 @@ def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0):
         # print(i)
         prob_pred = prob_predictions[i, ...]
         # Filter the points using prob_thresh
-        coord = np.where(prob_pred >= prob_thresh) # HW format
-        points = np.concatenate((coord[0][..., None], coord[1][..., None]),
-                                axis=1) # HW format
+        coord = np.where(prob_pred >= prob_thresh)  # HW format
+        points = np.concatenate(
+            (coord[0][..., None], coord[1][..., None]), axis=1
+        )  # HW format
 
         # Get the probability score
         prob_score = prob_pred[points[:, 0], points[:, 1]]
 
         # Perform super nms
         # Modify the in_points to xy format (instead of HW format)
-        in_points = np.concatenate((coord[1][..., None], coord[0][..., None],
-                                    prob_score), axis=1).T
+        in_points = np.concatenate(
+            (coord[1][..., None], coord[0][..., None], prob_score), axis=1
+        ).T
         keep_points_, keep_inds = nms_fast(in_points, im_h, im_w, dist_thresh)
         # Remember to flip outputs back to HW format
         keep_points = np.round(np.flip(keep_points_[:2, :], axis=0).T)
@@ -454,8 +514,9 @@ def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0):
 
         # Re-compose the probability map
         output_map = np.zeros([im_h, im_w])
-        output_map[keep_points[:, 0].astype(np.int),
-                   keep_points[:, 1].astype(np.int)] = keep_score.squeeze()
+        output_map[
+            keep_points[:, 0].astype(np.int), keep_points[:, 1].astype(np.int)
+        ] = keep_score.squeeze()
 
         output_lst.append(output_map[None, ...])
 
@@ -506,14 +567,14 @@ def nms_fast(in_corners, H, W, dist_thresh):
         inds[rcorners[1, i], rcorners[0, i]] = i
     # Pad the border of the grid, so that we can NMS points near the border.
     pad = dist_thresh
-    grid = np.pad(grid, ((pad, pad), (pad, pad)), mode='constant')
+    grid = np.pad(grid, ((pad, pad), (pad, pad)), mode="constant")
     # Iterate through points, highest to lowest conf, suppress neighborhood.
     count = 0
     for i, rc in enumerate(rcorners.T):
         # Account for top and left padding.
         pt = (rc[0] + pad, rc[1] + pad)
         if grid[pt[1], pt[0]] == 1:  # If not yet suppressed.
-            grid[pt[1] - pad:pt[1] + pad + 1, pt[0] - pad:pt[0] + pad + 1] = 0
+            grid[pt[1] - pad : pt[1] + pad + 1, pt[0] - pad : pt[0] + pad + 1] = 0
             grid[pt[1], pt[0]] = -1
             count += 1
     # Get all surviving -1's and return sorted array of remaining corners.
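
For reference (outside the patch), a minimal sketch of the mutual nearest-neighbour check behind matching_score in the hunks above; desc1/desc2 are placeholders for L2-normalized point descriptors where row i of each corresponds to the same ground-truth point:

import torch


def mutual_matching_score(desc1: torch.Tensor, desc2: torch.Tensor) -> torch.Tensor:
    # For unit-norm descriptors, 2 - 2 * cosine similarity equals the squared L2 distance.
    desc_dists = 2 - 2 * (desc1 @ desc2.t())
    matches0 = torch.min(desc_dists, dim=1)[1]  # best desc2 match for each desc1 row
    matches1 = torch.min(desc_dists, dim=0)[1]  # best desc1 match for each desc2 row
    # A match counts as correct when it is cycle-consistent with the identity
    # correspondence (row i of desc1 <-> row i of desc2).
    correct = matches1[matches0] == torch.arange(len(matches0), device=desc1.device)
    return correct.float().mean()
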
diff --git a/third_party/SOLD2/sold2/model/model_util.py b/third_party/SOLD2/sold2/model/model_util.py
index f70d80da40a72c207edfcfc1509e820846f0b731..037239e45d50123c7d679e36df5c6b0de314fa8b 100644
--- a/third_party/SOLD2/sold2/model/model_util.py
+++ b/third_party/SOLD2/sold2/model/model_util.py
@@ -9,7 +9,7 @@ from .nets.descriptor_decoder import SuperpointDescriptor
 
 
 def get_model(model_cfg=None, loss_weights=None, mode="train"):
-    """ Get model based on the model configuration. """
+    """Get model based on the model configuration."""
     # Check dataset config is given
     if model_cfg is None:
         raise ValueError("[Error] The model config is required!")
@@ -18,26 +18,27 @@ def get_model(model_cfg=None, loss_weights=None, mode="train"):
     print("\n\n\t--------Initializing model----------")
     supported_arch = ["simple"]
     if not model_cfg["model_architecture"] in supported_arch:
-        raise ValueError(
-            "[Error] The model architecture is not in supported arch!")
+        raise ValueError("[Error] The model architecture is not in supported arch!")
 
     if model_cfg["model_architecture"] == "simple":
         model = SOLD2Net(model_cfg)
     else:
-        raise ValueError(
-            "[Error] The model architecture is not in supported arch!")
+        raise ValueError("[Error] The model architecture is not in supported arch!")
 
     # Optionally register loss weights to the model
     if mode == "train":
         if loss_weights is not None:
             for param_name, param in loss_weights.items():
                 if isinstance(param, nn.Parameter):
-                    print("\t [Debug] Adding %s with value %f to model"
-                          % (param_name, param.item()))
+                    print(
+                        "\t [Debug] Adding %s with value %f to model"
+                        % (param_name, param.item())
+                    )
                     model.register_parameter(param_name, param)
         else:
             raise ValueError(
-                "[Error] the loss weights can not be None in dynamic weighting mode during training.")
+                "[Error] The loss weights cannot be None in dynamic weighting mode during training."
+            )
 
     # Display some summary info.
     print("\tModel architecture: %s" % model_cfg["model_architecture"])
@@ -50,7 +51,8 @@ def get_model(model_cfg=None, loss_weights=None, mode="train"):
 
 
 class SOLD2Net(nn.Module):
-    """ Full network for SOLD². """
+    """Full network for SOLD²."""
+
     def __init__(self, model_cfg):
         super(SOLD2Net, self).__init__()
         self.name = model_cfg["model_name"]
@@ -65,8 +67,7 @@ class SOLD2Net(nn.Module):
         self.junction_decoder = self.get_junction_decoder()
 
         # List supported heatmap decoder options
-        self.supported_heatmap_decoder = ["pixel_shuffle",
-                                          "pixel_shuffle_single"]
+        self.supported_heatmap_decoder = ["pixel_shuffle", "pixel_shuffle_single"]
         self.heatmap_decoder = self.get_heatmap_decoder()
 
         # List supported descriptor decoder options
@@ -96,10 +97,9 @@ class SOLD2Net(nn.Module):
         return outputs
 
     def get_backbone(self):
-        """ Retrieve the backbone encoder network. """
+        """Retrieve the backbone encoder network."""
         if not self.cfg["backbone"] in self.supported_backbone:
-            raise ValueError(
-                "[Error] The backbone selection is not supported.")
+            raise ValueError("[Error] The backbone selection is not supported.")
 
         # lcnn backbone (stacked hourglass)
         if self.cfg["backbone"] == "lcnn":
@@ -113,79 +113,73 @@ class SOLD2Net(nn.Module):
             feat_channel = 128
 
         else:
-            raise ValueError(
-                "[Error] The backbone selection is not supported.")
+            raise ValueError("[Error] The backbone selection is not supported.")
 
         return backbone, feat_channel
 
     def get_junction_decoder(self):
-        """ Get the junction decoder. """
-        if (not self.cfg["junction_decoder"]
-            in self.supported_junction_decoder):
-            raise ValueError(
-                "[Error] The junction decoder selection is not supported.")
+        """Get the junction decoder."""
+        if not self.cfg["junction_decoder"] in self.supported_junction_decoder:
+            raise ValueError("[Error] The junction decoder selection is not supported.")
 
         # superpoint decoder
         if self.cfg["junction_decoder"] == "superpoint_decoder":
-            decoder = SuperpointDecoder(self.feat_channel,
-                                        self.cfg["backbone"])
+            decoder = SuperpointDecoder(self.feat_channel, self.cfg["backbone"])
         else:
-            raise ValueError(
-                "[Error] The junction decoder selection is not supported.")
+            raise ValueError("[Error] The junction decoder selection is not supported.")
 
         return decoder
 
     def get_heatmap_decoder(self):
-        """ Get the heatmap decoder. """
+        """Get the heatmap decoder."""
         if not self.cfg["heatmap_decoder"] in self.supported_heatmap_decoder:
-            raise ValueError(
-                "[Error] The heatmap decoder selection is not supported.")
+            raise ValueError("[Error] The heatmap decoder selection is not supported.")
 
         # Pixel_shuffle decoder
         if self.cfg["heatmap_decoder"] == "pixel_shuffle":
             if self.cfg["backbone"] == "lcnn":
-                decoder = PixelShuffleDecoder(self.feat_channel,
-                                              num_upsample=2)
+                decoder = PixelShuffleDecoder(self.feat_channel, num_upsample=2)
             elif self.cfg["backbone"] == "superpoint":
-                decoder = PixelShuffleDecoder(self.feat_channel,
-                                              num_upsample=3)
+                decoder = PixelShuffleDecoder(self.feat_channel, num_upsample=3)
             else:
                 raise ValueError("[Error] Unknown backbone option.")
         # Pixel_shuffle decoder with single channel output
         elif self.cfg["heatmap_decoder"] == "pixel_shuffle_single":
             if self.cfg["backbone"] == "lcnn":
                 decoder = PixelShuffleDecoder(
-                    self.feat_channel, num_upsample=2, output_channel=1)
+                    self.feat_channel, num_upsample=2, output_channel=1
+                )
             elif self.cfg["backbone"] == "superpoint":
                 decoder = PixelShuffleDecoder(
-                    self.feat_channel, num_upsample=3, output_channel=1)
+                    self.feat_channel, num_upsample=3, output_channel=1
+                )
             else:
                 raise ValueError("[Error] Unknown backbone option.")
         else:
-            raise ValueError(
-                "[Error] The heatmap decoder selection is not supported.")
+            raise ValueError("[Error] The heatmap decoder selection is not supported.")
 
         return decoder
 
     def get_descriptor_decoder(self):
-        """ Get the descriptor decoder. """
-        if (not self.cfg["descriptor_decoder"]
-            in self.supported_descriptor_decoder):
+        """Get the descriptor decoder."""
+        if not self.cfg["descriptor_decoder"] in self.supported_descriptor_decoder:
             raise ValueError(
-                "[Error] The descriptor decoder selection is not supported.")
+                "[Error] The descriptor decoder selection is not supported."
+            )
 
         # SuperPoint descriptor
         if self.cfg["descriptor_decoder"] == "superpoint_descriptor":
             decoder = SuperpointDescriptor(self.feat_channel)
         else:
             raise ValueError(
-                "[Error] The descriptor decoder selection is not supported.")
+                "[Error] The descriptor decoder selection is not supported."
+            )
 
         return decoder
 
 
 def weight_init(m):
-    """ Weight initialization function. """
+    """Weight initialization function."""
     # Conv2D
     if isinstance(m, nn.Conv2d):
         init.xavier_normal_(m.weight.data)
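Reviewer note (not part of the patch): a rough usage sketch for the get_model helper reformatted above. The key names are the ones this file reads; real SOLD2 configs carry additional fields (losses, grid_size, backbone options), so treat this dict, and in particular the backbone_cfg key, as an assumption rather than a complete configuration.

from sold2.model.model_util import get_model

model_cfg = {
    "model_name": "sold2_example",            # hypothetical name
    "model_architecture": "simple",           # the only supported architecture
    "backbone": "lcnn",                       # or "superpoint"
    "backbone_cfg": {},                       # assumed key for backbone-specific options
    "junction_decoder": "superpoint_decoder",
    "heatmap_decoder": "pixel_shuffle",       # or "pixel_shuffle_single"
    "descriptor_decoder": "superpoint_descriptor",
}
# Any mode other than "train" skips the loss-weight registration branch above.
model = get_model(model_cfg, loss_weights=None, mode="test")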
diff --git a/third_party/SOLD2/sold2/model/nets/backbone.py b/third_party/SOLD2/sold2/model/nets/backbone.py
index 71f260aef108c77d54319cab7bc082c3c51112e7..26b5a1366223b9148bc110ec28917cc1f81b5cbf 100644
--- a/third_party/SOLD2/sold2/model/nets/backbone.py
+++ b/third_party/SOLD2/sold2/model/nets/backbone.py
@@ -5,49 +5,46 @@ from .lcnn_hourglass import MultitaskHead, hg
 
 
 class HourglassBackbone(nn.Module):
-    """ Hourglass backbone. """
-    def __init__(self, input_channel=1, depth=4, num_stacks=2,
-                 num_blocks=1, num_classes=5):
+    """Hourglass backbone."""
+
+    def __init__(
+        self, input_channel=1, depth=4, num_stacks=2, num_blocks=1, num_classes=5
+    ):
         super(HourglassBackbone, self).__init__()
         self.head = MultitaskHead
-        self.net = hg(**{
-            "head": self.head,
-            "depth": depth,
-            "num_stacks": num_stacks,
-            "num_blocks": num_blocks,
-            "num_classes": num_classes,
-            "input_channels": input_channel
-        })
+        self.net = hg(
+            **{
+                "head": self.head,
+                "depth": depth,
+                "num_stacks": num_stacks,
+                "num_blocks": num_blocks,
+                "num_classes": num_classes,
+                "input_channels": input_channel,
+            }
+        )
 
     def forward(self, input_images):
         return self.net(input_images)[1]
 
 
 class SuperpointBackbone(nn.Module):
-    """ SuperPoint backbone. """
+    """SuperPoint backbone."""
+
     def __init__(self):
         super(SuperpointBackbone, self).__init__()
         self.relu = torch.nn.ReLU(inplace=True)
         self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
         c1, c2, c3, c4 = 64, 64, 128, 128
         # Shared Encoder.
-        self.conv1a = torch.nn.Conv2d(1, c1, kernel_size=3,
-                                      stride=1, padding=1)
-        self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3,
-                                      stride=1, padding=1)
-        self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3,
-                                      stride=1, padding=1)
-        self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3,
-                                      stride=1, padding=1)
-        self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3,
-                                      stride=1, padding=1)
-        self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3,
-                                      stride=1, padding=1)
-        self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3,
-                                      stride=1, padding=1)
-        self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3,
-                                      stride=1, padding=1)
-    
+        self.conv1a = torch.nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
+        self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
+        self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
+        self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
+        self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
+        self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
+        self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
+        self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
+
     def forward(self, input_images):
         # Shared Encoder.
         x = self.relu(self.conv1a(input_images))
diff --git a/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py b/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py
index 6ed4306fad764efab2c22ede9cae253c9b17d6c2..449bac37e6b0e6ff7802c0dbcea92f4829786578 100644
--- a/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py
+++ b/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py
@@ -3,17 +3,18 @@ import torch.nn as nn
 
 
 class SuperpointDescriptor(nn.Module):
-    """ Descriptor decoder based on the SuperPoint arcihtecture. """
+    """Descriptor decoder based on the SuperPoint arcihtecture."""
+
     def __init__(self, input_feat_dim=128):
         super(SuperpointDescriptor, self).__init__()
         self.relu = torch.nn.ReLU(inplace=True)
-        self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3,
-                                      stride=1, padding=1)        
-        self.convPb = torch.nn.Conv2d(256, 128, kernel_size=1,
-                                      stride=1, padding=0)
+        self.convPa = torch.nn.Conv2d(
+            input_feat_dim, 256, kernel_size=3, stride=1, padding=1
+        )
+        self.convPb = torch.nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0)
 
     def forward(self, input_features):
         feat = self.relu(self.convPa(input_features))
         semi = self.convPb(feat)
 
-        return semi
\ No newline at end of file
+        return semi
diff --git a/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py b/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py
index bd5157ca740c8c7e25f2183b2a3c1fefa813deca..11828426a2852fb3e9ee3e6a3310ca89cbcd4d78 100644
--- a/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py
+++ b/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py
@@ -2,7 +2,8 @@ import torch.nn as nn
 
 
 class PixelShuffleDecoder(nn.Module):
-    """ Pixel shuffle decoder. """
+    """Pixel shuffle decoder."""
+
     def __init__(self, input_feat_dim=128, num_upsample=2, output_channel=2):
         super(PixelShuffleDecoder, self).__init__()
         # Get channel parameters
@@ -10,35 +11,46 @@ class PixelShuffleDecoder(nn.Module):
 
         # Define the pixel shuffle
         self.pixshuffle = nn.PixelShuffle(2)
-        
+
         # Process the feature
         self.conv_block_lst = []
         # The input block
         self.conv_block_lst.append(
             nn.Sequential(
-                nn.Conv2d(input_feat_dim, self.channel_conf[0],
-                          kernel_size=3, stride=1, padding=1),
+                nn.Conv2d(
+                    input_feat_dim,
+                    self.channel_conf[0],
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                ),
                 nn.BatchNorm2d(self.channel_conf[0]),
-                nn.ReLU(inplace=True)
-        ))
+                nn.ReLU(inplace=True),
+            )
+        )
 
         # Intermediate block
         for channel in self.channel_conf[1:-1]:
             self.conv_block_lst.append(
                 nn.Sequential(
-                    nn.Conv2d(channel, channel, kernel_size=3,
-                              stride=1, padding=1),
+                    nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
                     nn.BatchNorm2d(channel),
-                    nn.ReLU(inplace=True)
-            ))
-        
+                    nn.ReLU(inplace=True),
+                )
+            )
+
         # Output block
         self.conv_block_lst.append(
-            nn.Conv2d(self.channel_conf[-1], output_channel,
-                      kernel_size=1, stride=1, padding=0)
+            nn.Conv2d(
+                self.channel_conf[-1],
+                output_channel,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            )
         )
         self.conv_block_lst = nn.ModuleList(self.conv_block_lst)
-    
+
     # Get num of channels based on number of upsampling.
     def get_channel_conf(self, num_upsample):
         if num_upsample == 2:
@@ -52,7 +64,7 @@ class PixelShuffleDecoder(nn.Module):
         for block in self.conv_block_lst[:-1]:
             out = block(out)
             out = self.pixshuffle(out)
-        
+
         # Output layer
         out = self.conv_block_lst[-1](out)
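Reviewer note (illustrative, not part of the patch; the import path assumes the package layout under third_party/SOLD2): a small shape sketch for PixelShuffleDecoder. With num_upsample=2 the forward pass applies two PixelShuffle(2) steps, so the spatial resolution grows by a factor of four; the channel schedule comes from get_channel_conf, which is only partially visible in this hunk.

import torch
from sold2.model.nets.heatmap_decoder import PixelShuffleDecoder

decoder = PixelShuffleDecoder(input_feat_dim=128, num_upsample=2, output_channel=1)
feat = torch.randn(1, 128, 32, 32)   # a backbone feature map
heatmap = decoder(feat)
print(heatmap.shape)                 # expected: torch.Size([1, 1, 128, 128])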
 
diff --git a/third_party/SOLD2/sold2/model/nets/junction_decoder.py b/third_party/SOLD2/sold2/model/nets/junction_decoder.py
index d2bb649518896501c784940028a772d688c2b3a7..ea90a6b6821d994461dee83f85a6d2851d78e055 100644
--- a/third_party/SOLD2/sold2/model/nets/junction_decoder.py
+++ b/third_party/SOLD2/sold2/model/nets/junction_decoder.py
@@ -3,25 +3,27 @@ import torch.nn as nn
 
 
 class SuperpointDecoder(nn.Module):
-    """ Junction decoder based on the SuperPoint architecture. """
+    """Junction decoder based on the SuperPoint architecture."""
+
     def __init__(self, input_feat_dim=128, backbone_name="lcnn"):
         super(SuperpointDecoder, self).__init__()
         self.relu = torch.nn.ReLU(inplace=True)
         # Perform strided convolution when using lcnn backbone.
         if backbone_name == "lcnn":
-            self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3,
-                                          stride=2, padding=1)
+            self.convPa = torch.nn.Conv2d(
+                input_feat_dim, 256, kernel_size=3, stride=2, padding=1
+            )
         elif backbone_name == "superpoint":
-            self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3,
-                                          stride=1, padding=1)
+            self.convPa = torch.nn.Conv2d(
+                input_feat_dim, 256, kernel_size=3, stride=1, padding=1
+            )
         else:
             raise ValueError("[Error] Unknown backbone option.")
-        
-        self.convPb = torch.nn.Conv2d(256, 65, kernel_size=1,
-                                      stride=1, padding=0)
+
+        self.convPb = torch.nn.Conv2d(256, 65, kernel_size=1, stride=1, padding=0)
 
     def forward(self, input_features):
         feat = self.relu(self.convPa(input_features))
         semi = self.convPb(feat)
 
-        return semi
\ No newline at end of file
+        return semi
diff --git a/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py b/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py
index a9dc78eef34e7ee146166b1b66c10070799d63f3..c25594d9dda28624337546fd8fec27e1c59b452f 100644
--- a/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py
+++ b/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py
@@ -39,8 +39,7 @@ class Bottleneck2D(nn.Module):
         self.bn1 = nn.BatchNorm2d(inplanes)
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
         self.bn2 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
-                               stride=stride, padding=1)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
         self.bn3 = nn.BatchNorm2d(planes)
         self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
         self.relu = nn.ReLU(inplace=True)
@@ -116,15 +115,17 @@ class Hourglass(nn.Module):
 class HourglassNet(nn.Module):
     """Hourglass model from Newell et al ECCV 2016"""
 
-    def __init__(self, block, head, depth, num_stacks, num_blocks,
-                 num_classes, input_channels):
+    def __init__(
+        self, block, head, depth, num_stacks, num_blocks, num_classes, input_channels
+    ):
         super(HourglassNet, self).__init__()
 
         self.inplanes = 64
         self.num_feats = 128
         self.num_stacks = num_stacks
-        self.conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7,
-                               stride=2, padding=3)
+        self.conv1 = nn.Conv2d(
+            input_channels, self.inplanes, kernel_size=7, stride=2, padding=3
+        )
         self.bn1 = nn.BatchNorm2d(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
         self.layer1 = self._make_residual(block, self.inplanes, 1)
@@ -215,12 +216,11 @@ class HourglassNet(nn.Module):
 def hg(**kwargs):
     model = HourglassNet(
         Bottleneck2D,
-        head=kwargs.get("head",
-                        lambda c_in, c_out: nn.Conv2D(c_in, c_out, 1)),
+        head=kwargs.get("head", lambda c_in, c_out: nn.Conv2d(c_in, c_out, 1)),
         depth=kwargs["depth"],
         num_stacks=kwargs["num_stacks"],
         num_blocks=kwargs["num_blocks"],
         num_classes=kwargs["num_classes"],
-        input_channels=kwargs["input_channels"]
+        input_channels=kwargs["input_channels"],
     )
     return model
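Reviewer note (illustrative, not part of the patch): a construction sketch for the hourglass backbone wired up by hg above. The keyword values mirror the defaults of HourglassBackbone.__init__ earlier in this diff; the import path assumes the sold2 package layout.

import torch
from sold2.model.nets.lcnn_hourglass import MultitaskHead, hg

net = hg(
    head=MultitaskHead,
    depth=4,
    num_stacks=2,
    num_blocks=1,
    num_classes=5,
    input_channels=1,
)
outputs = net(torch.randn(1, 1, 128, 128))
features = outputs[1]   # HourglassBackbone keeps element [1] as the feature map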
diff --git a/third_party/SOLD2/sold2/postprocess/convert_homography_results.py b/third_party/SOLD2/sold2/postprocess/convert_homography_results.py
index 352eebbde00f6d8a9c20517dccd7024fd0758ffd..61045777bde0190e872c1c3983f1172ef36d8f1c 100644
--- a/third_party/SOLD2/sold2/postprocess/convert_homography_results.py
+++ b/third_party/SOLD2/sold2/postprocess/convert_homography_results.py
@@ -2,6 +2,7 @@
 Convert the aggregation results from the homography adaptation to GT labels.
 """
 import sys
+
 sys.path.append("../")
 import os
 import yaml
@@ -17,9 +18,10 @@ from model.metrics import super_nms
 from misc.train_utils import parse_h5_data
 
 
-def convert_raw_exported_predictions(input_data, grid_size=8,
-                                     detect_thresh=1/65, topk=300):
-    """ Convert the exported junctions and heatmaps predictions
+def convert_raw_exported_predictions(
+    input_data, grid_size=8, detect_thresh=1 / 65, topk=300
+):
+    """Convert the exported junctions and heatmaps predictions
         to a standard format.
     Arguments:
         input_data: the raw data (dict) decoded from the hdf5 dataset
@@ -31,28 +33,29 @@ def convert_raw_exported_predictions(input_data, grid_size=8,
     # Check the input_data is from (1) single prediction,
     # or (2) homography adaptation.
     # Homography adaptation raw predictions
-    if (("junc_prob_mean" in input_data.keys())
-        and ("heatmap_prob_mean" in input_data.keys())):
+    if ("junc_prob_mean" in input_data.keys()) and (
+        "heatmap_prob_mean" in input_data.keys()
+    ):
         # Get the junction predictions and convert if to Nx2 format
         junc_prob = input_data["junc_prob_mean"]
         junc_pred_np = junc_prob[None, ...]
-        junc_pred_np_nms = super_nms(junc_pred_np, grid_size,
-                                     detect_thresh, topk)
+        junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk)
         junctions = np.where(junc_pred_np_nms.squeeze())
-        junc_points_pred = np.concatenate([junctions[0][..., None],
-                                           junctions[1][..., None]], axis=-1)
+        junc_points_pred = np.concatenate(
+            [junctions[0][..., None], junctions[1][..., None]], axis=-1
+        )
 
         # Get the heatmap predictions
         heatmap_pred = input_data["heatmap_prob_mean"].squeeze()
         valid_mask = np.ones(heatmap_pred.shape, dtype=np.int32)
-        
+
     # Single predictions
     else:
         # Get the junction point predictions and convert to Nx2 format
         junc_points_pred = np.where(input_data["junc_pred_nms"])
         junc_points_pred = np.concatenate(
-            [junc_points_pred[0][..., None],
-             junc_points_pred[1][..., None]], axis=-1)
+            [junc_points_pred[0][..., None], junc_points_pred[1][..., None]], axis=-1
+        )
 
         # Get the heatmap predictions
         heatmap_pred = input_data["heatmap_pred"]
@@ -61,34 +64,29 @@ def convert_raw_exported_predictions(input_data, grid_size=8,
     return {
         "junctions_pred": junc_points_pred,
         "heatmap_pred": heatmap_pred,
-        "valid_mask": valid_mask
+        "valid_mask": valid_mask,
     }
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("input_dataset", type=str,
-                        help="Name of the exported dataset.")
-    parser.add_argument("output_dataset", type=str,
-                        help="Name of the output dataset.")
-    parser.add_argument("config", type=str,
-                        help="Path to the model config.")
-    args = parser.parse_args()    
-    
+    parser.add_argument("input_dataset", type=str, help="Name of the exported dataset.")
+    parser.add_argument("output_dataset", type=str, help="Name of the output dataset.")
+    parser.add_argument("config", type=str, help="Path to the model config.")
+    args = parser.parse_args()
+
     # Define the path to the input exported dataset
-    exported_dataset_path = os.path.join(cfg.export_dataroot,
-                                         args.input_dataset)
+    exported_dataset_path = os.path.join(cfg.export_dataroot, args.input_dataset)
     if not os.path.exists(exported_dataset_path):
         raise ValueError("Missing input dataset: " + exported_dataset_path)
     exported_dataset = h5py.File(exported_dataset_path, "r")
 
     # Define the output path for the results
-    output_dataset_path = os.path.join(cfg.export_dataroot,
-                                       args.output_dataset)
+    output_dataset_path = os.path.join(cfg.export_dataroot, args.output_dataset)
 
     device = torch.device("cuda")
     nms_device = torch.device("cuda")
-    
+
     # Read the config file
     if not os.path.exists(args.config):
         raise ValueError("Missing config file: " + args.config)
@@ -96,41 +94,43 @@ if __name__ == "__main__":
         config = yaml.safe_load(f)
     model_cfg = config["model_cfg"]
     line_detector_cfg = config["line_detector_cfg"]
-    
+
     # Initialize the line detection module
     line_detector = LineSegmentDetectionModule(**line_detector_cfg)
 
     # Iterate through all the dataset keys
     with h5py.File(output_dataset_path, "w") as output_dataset:
-        for idx, output_key in enumerate(tqdm(list(exported_dataset.keys()),
-                                              ascii=True)):
+        for idx, output_key in enumerate(
+            tqdm(list(exported_dataset.keys()), ascii=True)
+        ):
             # Get the data
             data = parse_h5_data(exported_dataset[output_key])
 
             # Preprocess the data
             converted_data = convert_raw_exported_predictions(
-                data, grid_size=model_cfg["grid_size"],
-                detect_thresh=model_cfg["detection_thresh"])
+                data,
+                grid_size=model_cfg["grid_size"],
+                detect_thresh=model_cfg["detection_thresh"],
+            )
             junctions_pred_raw = converted_data["junctions_pred"]
             heatmap_pred = converted_data["heatmap_pred"]
             valid_mask = converted_data["valid_mask"]
 
             line_map_pred, junctions_pred, heatmap_pred = line_detector.detect(
-                junctions_pred_raw, heatmap_pred, device=device)
+                junctions_pred_raw, heatmap_pred, device=device
+            )
             if isinstance(line_map_pred, torch.Tensor):
                 line_map_pred = line_map_pred.cpu().numpy()
             if isinstance(junctions_pred, torch.Tensor):
                 junctions_pred = junctions_pred.cpu().numpy()
             if isinstance(heatmap_pred, torch.Tensor):
                 heatmap_pred = heatmap_pred.cpu().numpy()
-            
-            output_data = {"junctions": junctions_pred,
-                           "line_map": line_map_pred}
+
+            output_data = {"junctions": junctions_pred, "line_map": line_map_pred}
 
             # Record it to the h5 dataset
             f_group = output_dataset.create_group(output_key)
 
             # Store data
             for key, output_data in output_data.items():
-                f_group.create_dataset(key, data=output_data,
-                                       compression="gzip")
+                f_group.create_dataset(key, data=output_data, compression="gzip")
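Reviewer note (not part of the patch): the script above takes three positional arguments, so a typical invocation looks like "python convert_homography_results.py exported_preds.h5 converted_labels.h5 train_config.yaml", run from the sold2/postprocess directory so that the sys.path.append("../") at the top of the file resolves the model and misc imports. The first two arguments are dataset names resolved against cfg.export_dataroot rather than full paths, and the file names shown here are hypothetical placeholders.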
diff --git a/third_party/SOLD2/sold2/train.py b/third_party/SOLD2/sold2/train.py
index 2064e00e6d192f9202f011c3626d6f53c4fe6270..148c9b23464d975f1efc03ea459c82d4a0759b05 100644
--- a/third_party/SOLD2/sold2/train.py
+++ b/third_party/SOLD2/sold2/train.py
@@ -15,12 +15,15 @@ from .model.model_util import get_model
 from .model.loss import TotalLoss, get_loss_and_weights
 from .model.metrics import AverageMeter, Metrics, super_nms
 from .model.lr_scheduler import get_lr_scheduler
-from .misc.train_utils import (convert_image, get_latest_checkpoint,
-                               remove_old_checkpoints)
+from .misc.train_utils import (
+    convert_image,
+    get_latest_checkpoint,
+    remove_old_checkpoints,
+)
 
 
 def customized_collate_fn(batch):
-    """ Customized collate_fn. """
+    """Customized collate_fn."""
     batch_keys = ["image", "junction_map", "heatmap", "valid_mask"]
     list_keys = ["junctions", "line_map"]
 
@@ -34,14 +37,14 @@ def customized_collate_fn(batch):
 
 
 def restore_weights(model, state_dict, strict=True):
-    """ Restore weights in compatible mode. """
+    """Restore weights in compatible mode."""
     # Try to directly load state dict
     try:
         model.load_state_dict(state_dict, strict=strict)
     # Deal with some version compatibility issue (catch version incompatible)
     except:
         err = model.load_state_dict(state_dict, strict=False)
-        
+
         # missing keys are those in model but not in state_dict
         missing_keys = err.missing_keys
         # Unexpected keys are those in state_dict but not in model
@@ -53,12 +56,12 @@ def restore_weights(model, state_dict, strict=True):
             dict_keys = [_ for _ in unexpected_keys if not "tracked" in _]
             model_dict[key] = state_dict[dict_keys[idx]]
         model.load_state_dict(model_dict)
-    
+
     return model
 
 
 def train_net(args, dataset_cfg, model_cfg, output_path):
-    """ Main training function. """
+    """Main training function."""
     # Add some version compatibility check
     if model_cfg.get("weighting_policy") is None:
         # Default to static
@@ -74,44 +77,50 @@ def train_net(args, dataset_cfg, model_cfg, output_path):
     test_dataset, test_collate_fn = get_dataset("test", dataset_cfg)
 
     # Create the dataloader
-    train_loader = DataLoader(train_dataset,
-                              batch_size=train_cfg["batch_size"],
-                              num_workers=8,
-                              shuffle=True, pin_memory=True,
-                              collate_fn=train_collate_fn)
-    test_loader = DataLoader(test_dataset,
-                             batch_size=test_cfg.get("batch_size", 1),
-                             num_workers=test_cfg.get("num_workers", 1),
-                             shuffle=False, pin_memory=False,
-                             collate_fn=test_collate_fn)
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=train_cfg["batch_size"],
+        num_workers=8,
+        shuffle=True,
+        pin_memory=True,
+        collate_fn=train_collate_fn,
+    )
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=test_cfg.get("batch_size", 1),
+        num_workers=test_cfg.get("num_workers", 1),
+        shuffle=False,
+        pin_memory=False,
+        collate_fn=test_collate_fn,
+    )
     print("\t Successfully intialized dataloaders.")
 
-
     # Get the loss function and weight first
     loss_funcs, loss_weights = get_loss_and_weights(model_cfg)
 
     # If resume.
     if args.resume:
         # Create model and load the state dict
-        checkpoint = get_latest_checkpoint(args.resume_path,
-                                           args.checkpoint_name)
+        checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
         model = get_model(model_cfg, loss_weights)
         model = restore_weights(model, checkpoint["model_state_dict"])
         model = model.cuda()
         optimizer = torch.optim.Adam(
-            [{"params": model.parameters(),
-              "initial_lr": model_cfg["learning_rate"]}], 
-            model_cfg["learning_rate"], 
-            amsgrad=True)
+            [{"params": model.parameters(), "initial_lr": model_cfg["learning_rate"]}],
+            model_cfg["learning_rate"],
+            amsgrad=True,
+        )
         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
         # Optionally get the learning rate scheduler
         scheduler = get_lr_scheduler(
             lr_decay=model_cfg.get("lr_decay", False),
             lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
-            optimizer=optimizer)
+            optimizer=optimizer,
+        )
         # If we start to use learning rate scheduler from the middle
-        if ((scheduler is not None)
-            and (checkpoint.get("scheduler_state_dict", None) is not None)):
+        if (scheduler is not None) and (
+            checkpoint.get("scheduler_state_dict", None) is not None
+        ):
             scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
         start_epoch = checkpoint["epoch"] + 1
     # Initialize all the components.
@@ -121,40 +130,45 @@ def train_net(args, dataset_cfg, model_cfg, output_path):
         # Optionally get the pretrained weights
         if args.pretrained:
             print("\t [Debug] Loading pretrained weights...")
-            checkpoint = get_latest_checkpoint(args.pretrained_path,
-                                               args.checkpoint_name)
+            checkpoint = get_latest_checkpoint(
+                args.pretrained_path, args.checkpoint_name
+            )
             # If auto weighting restore from non-auto weighting
-            model = restore_weights(model, checkpoint["model_state_dict"],
-                                    strict=False)
+            model = restore_weights(model, checkpoint["model_state_dict"], strict=False)
             print("\t [Debug] Finished loading pretrained weights!")
-        
+
         model = model.cuda()
         optimizer = torch.optim.Adam(
-            [{"params": model.parameters(),
-              "initial_lr": model_cfg["learning_rate"]}], 
-            model_cfg["learning_rate"], 
-            amsgrad=True)
+            [{"params": model.parameters(), "initial_lr": model_cfg["learning_rate"]}],
+            model_cfg["learning_rate"],
+            amsgrad=True,
+        )
         # Optionally get the learning rate scheduler
         scheduler = get_lr_scheduler(
             lr_decay=model_cfg.get("lr_decay", False),
             lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
-            optimizer=optimizer)
+            optimizer=optimizer,
+        )
         start_epoch = 0
-    
+
     print("\t Successfully initialized model")
 
     # Define the total loss
     policy = model_cfg.get("weighting_policy", "static")
     loss_func = TotalLoss(loss_funcs, loss_weights, policy).cuda()
     if "descriptor_decoder" in model_cfg:
-        metric_func = Metrics(model_cfg["detection_thresh"],
-                              model_cfg["prob_thresh"],
-                              model_cfg["descriptor_loss_cfg"]["grid_size"],
-                              desc_metric_lst='all')
+        metric_func = Metrics(
+            model_cfg["detection_thresh"],
+            model_cfg["prob_thresh"],
+            model_cfg["descriptor_loss_cfg"]["grid_size"],
+            desc_metric_lst="all",
+        )
     else:
-        metric_func = Metrics(model_cfg["detection_thresh"],
-                              model_cfg["prob_thresh"],
-                              model_cfg["grid_size"])
+        metric_func = Metrics(
+            model_cfg["detection_thresh"],
+            model_cfg["prob_thresh"],
+            model_cfg["grid_size"],
+        )
 
     # Define the summary writer
     logdir = os.path.join(output_path, "log")
@@ -176,7 +190,8 @@ def train_net(args, dataset_cfg, model_cfg, output_path):
             metric_func=metric_func,
             train_loader=train_loader,
             writer=writer,
-            epoch=epoch)
+            epoch=epoch,
+        )
 
         # Do the validation
         print("\n\n================== Validation ==================")
@@ -187,21 +202,22 @@ def train_net(args, dataset_cfg, model_cfg, output_path):
             metric_func=metric_func,
             val_loader=test_loader,
             writer=writer,
-            epoch=epoch)
+            epoch=epoch,
+        )
 
         # Update the scheduler
         if scheduler is not None:
             scheduler.step()
 
         # Save checkpoints
-        file_name = os.path.join(output_path,
-                                 "checkpoint-epoch%03d-end.tar"%(epoch))
+        file_name = os.path.join(output_path, "checkpoint-epoch%03d-end.tar" % (epoch))
         print("[Info] Saving checkpoint %s ..." % file_name)
         save_dict = {
             "epoch": epoch,
             "model_state_dict": model.state_dict(),
             "optimizer_state_dict": optimizer.state_dict(),
-            "model_cfg": model_cfg}
+            "model_cfg": model_cfg,
+        }
         if scheduler is not None:
             save_dict.update({"scheduler_state_dict": scheduler.state_dict()})
         torch.save(save_dict, file_name)
@@ -210,16 +226,17 @@ def train_net(args, dataset_cfg, model_cfg, output_path):
         remove_old_checkpoints(output_path, model_cfg.get("max_ckpt", 15))
 
 
-def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func,
-                       train_loader, writer, epoch):
-    """ Train for one epoch. """
+def train_single_epoch(
+    model, model_cfg, optimizer, loss_func, metric_func, train_loader, writer, epoch
+):
+    """Train for one epoch."""
     # Switch the model to training mode
     model.train()
 
     # Initialize the average meter
     compute_descriptors = loss_func.compute_descriptors
     if compute_descriptors:
-        average_meter = AverageMeter(is_training=True, desc_metric_lst='all')
+        average_meter = AverageMeter(is_training=True, desc_metric_lst="all")
     else:
         average_meter = AverageMeter(is_training=True)
 
@@ -244,11 +261,23 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func,
 
             # Compute losses
             losses = loss_func.forward_descriptors(
-                outputs["junctions"], outputs2["junctions"],
-                junc_map, junc_map2, outputs["heatmap"], outputs2["heatmap"],
-                heatmap, heatmap2, line_points, line_points2,
-                line_indices, outputs['descriptors'], outputs2['descriptors'],
-                epoch, valid_mask, valid_mask2)
+                outputs["junctions"],
+                outputs2["junctions"],
+                junc_map,
+                junc_map2,
+                outputs["heatmap"],
+                outputs2["heatmap"],
+                heatmap,
+                heatmap2,
+                line_points,
+                line_points2,
+                line_indices,
+                outputs["descriptors"],
+                outputs2["descriptors"],
+                epoch,
+                valid_mask,
+                valid_mask2,
+            )
         else:
             junc_map = data["junction_map"].cuda()
             heatmap = data["heatmap"].cuda()
@@ -260,58 +289,74 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func,
 
             # Compute losses
             losses = loss_func(
-                outputs["junctions"], junc_map,
-                outputs["heatmap"], heatmap,
-                valid_mask)
-        
+                outputs["junctions"], junc_map, outputs["heatmap"], heatmap, valid_mask
+            )
+
         total_loss = losses["total_loss"]
 
         # Update the model
         optimizer.zero_grad()
-        total_loss.backward()                     
+        total_loss.backward()
         optimizer.step()
 
         # Compute the global step
         global_step = epoch * len(train_loader) + idx
         ############## Measure the metric error #########################
         # Only do this when needed
-        if (((idx % model_cfg["disp_freq"]) == 0)
-            or ((idx % model_cfg["summary_freq"]) == 0)):
+        if ((idx % model_cfg["disp_freq"]) == 0) or (
+            (idx % model_cfg["summary_freq"]) == 0
+        ):
             junc_np = convert_junc_predictions(
-                outputs["junctions"], model_cfg["grid_size"],
-                model_cfg["detection_thresh"], 300)
+                outputs["junctions"],
+                model_cfg["grid_size"],
+                model_cfg["detection_thresh"],
+                300,
+            )
             junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)
 
             # Always fetch only one channel (compatible with L1, L2, and CE)
             if outputs["heatmap"].shape[1] == 2:
-                heatmap_np = softmax(outputs["heatmap"].detach(),
-                                     dim=1).cpu().numpy()
+                heatmap_np = softmax(outputs["heatmap"].detach(), dim=1).cpu().numpy()
                 heatmap_np = heatmap_np.transpose(0, 2, 3, 1)[:, :, :, 1:]
             else:
                 heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
                 heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)
-            
+
             heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
             valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)
 
             # Evaluate metric results
             if compute_descriptors:
                 metric_func.evaluate(
-                    junc_np["junc_pred"], junc_np["junc_pred_nms"],
-                    junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np,
-                    line_points, line_points2, outputs["descriptors"],
-                    outputs2["descriptors"], line_indices)
+                    junc_np["junc_pred"],
+                    junc_np["junc_pred_nms"],
+                    junc_map_np,
+                    heatmap_np,
+                    heatmap_gt_np,
+                    valid_mask_np,
+                    line_points,
+                    line_points2,
+                    outputs["descriptors"],
+                    outputs2["descriptors"],
+                    line_indices,
+                )
             else:
                 metric_func.evaluate(
-                    junc_np["junc_pred"], junc_np["junc_pred_nms"],
-                    junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np)
+                    junc_np["junc_pred"],
+                    junc_np["junc_pred_nms"],
+                    junc_map_np,
+                    heatmap_np,
+                    heatmap_gt_np,
+                    valid_mask_np,
+                )
             # Update average meter
             junc_loss = losses["junc_loss"].item()
             heatmap_loss = losses["heatmap_loss"].item()
             loss_dict = {
                 "junc_loss": junc_loss,
                 "heatmap_loss": heatmap_loss,
-                "total_loss": total_loss.item()}
+                "total_loss": total_loss.item(),
+            }
             if compute_descriptors:
                 descriptor_loss = losses["descriptor_loss"].item()
                 loss_dict["descriptor_loss"] = losses["descriptor_loss"].item()
@@ -323,34 +368,75 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func,
             results = metric_func.metric_results
             average = average_meter.average()
             # Get gpu memory usage in GB
-            gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024 ** 3)
+            gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
             if compute_descriptors:
-                print("Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f), gpu_mem=%.4fGB"
-                      % (epoch, model_cfg["epochs"], idx, len(train_loader),
-                         total_loss.item(), average["total_loss"], junc_loss,
-                         average["junc_loss"], heatmap_loss,
-                         average["heatmap_loss"], descriptor_loss,
-                         average["descriptor_loss"], gpu_mem_usage))
+                print(
+                    "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f), gpu_mem=%.4fGB"
+                    % (
+                        epoch,
+                        model_cfg["epochs"],
+                        idx,
+                        len(train_loader),
+                        total_loss.item(),
+                        average["total_loss"],
+                        junc_loss,
+                        average["junc_loss"],
+                        heatmap_loss,
+                        average["heatmap_loss"],
+                        descriptor_loss,
+                        average["descriptor_loss"],
+                        gpu_mem_usage,
+                    )
+                )
             else:
-                print("Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), gpu_mem=%.4fGB"
-                      % (epoch, model_cfg["epochs"], idx, len(train_loader),
-                         total_loss.item(), average["total_loss"],
-                         junc_loss, average["junc_loss"], heatmap_loss,
-                         average["heatmap_loss"], gpu_mem_usage))
-            print("\t Junction     precision=%.4f (%.4f) / recall=%.4f (%.4f)"
-                  % (results["junc_precision"], average["junc_precision"],
-                     results["junc_recall"], average["junc_recall"]))
-            print("\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
-                  % (results["junc_precision_nms"],
-                     average["junc_precision_nms"],
-                     results["junc_recall_nms"], average["junc_recall_nms"]))
-            print("\t Heatmap      precision=%.4f (%.4f) / recall=%.4f (%.4f)"
-                  %(results["heatmap_precision"],
+                print(
+                    "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), gpu_mem=%.4fGB"
+                    % (
+                        epoch,
+                        model_cfg["epochs"],
+                        idx,
+                        len(train_loader),
+                        total_loss.item(),
+                        average["total_loss"],
+                        junc_loss,
+                        average["junc_loss"],
+                        heatmap_loss,
+                        average["heatmap_loss"],
+                        gpu_mem_usage,
+                    )
+                )
+            print(
+                "\t Junction     precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+                % (
+                    results["junc_precision"],
+                    average["junc_precision"],
+                    results["junc_recall"],
+                    average["junc_recall"],
+                )
+            )
+            print(
+                "\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+                % (
+                    results["junc_precision_nms"],
+                    average["junc_precision_nms"],
+                    results["junc_recall_nms"],
+                    average["junc_recall_nms"],
+                )
+            )
+            print(
+                "\t Heatmap      precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+                % (
+                    results["heatmap_precision"],
                     average["heatmap_precision"],
-                    results["heatmap_recall"], average["heatmap_recall"]))
+                    results["heatmap_recall"],
+                    average["heatmap_recall"],
+                )
+            )
             if compute_descriptors:
-                print("\t Descriptors  matching score=%.4f (%.4f)"
-                      %(results["matching_score"], average["matching_score"]))
+                print(
+                    "\t Descriptors  matching score=%.4f (%.4f)"
+                    % (results["matching_score"], average["matching_score"])
+                )
 
         # Record summaries
         if (idx % model_cfg["summary_freq"]) == 0:
@@ -362,7 +448,8 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func,
                 "heatmap_loss": heatmap_loss,
                 "total_loss": total_loss.detach().cpu().numpy(),
                 "metrics": results,
-                "average": average}
+                "average": average,
+            }
             # Add descriptor terms
             if compute_descriptors:
                 scalar_summaries["descriptor_loss"] = descriptor_loss
@@ -374,10 +461,13 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func,
             scalar_summaries["reg_loss"] = losses["reg_loss"].item()
 
             num_images = 3
-            junc_pred_binary = (junc_np["junc_pred"][:num_images, ...]
-                                > model_cfg["detection_thresh"])
-            junc_pred_nms_binary = (junc_np["junc_pred_nms"][:num_images, ...]
-                                    > model_cfg["detection_thresh"])
+            junc_pred_binary = (
+                junc_np["junc_pred"][:num_images, ...] > model_cfg["detection_thresh"]
+            )
+            junc_pred_nms_binary = (
+                junc_np["junc_pred_nms"][:num_images, ...]
+                > model_cfg["detection_thresh"]
+            )
             image_summaries = {
                 "image": input_images.cpu().numpy()[:num_images, ...],
                 "valid_mask": valid_mask_np[:num_images, ...],
@@ -386,22 +476,23 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func,
                 "junc_map_gt": junc_map_np[:num_images, ...],
                 "junc_prob_map": junc_np["junc_prob"][:num_images, ...],
                 "heatmap_pred": heatmap_np[:num_images, ...],
-                "heatmap_gt": heatmap_gt_np[:num_images, ...]}
+                "heatmap_gt": heatmap_gt_np[:num_images, ...],
+            }
             # Record the training summary
             record_train_summaries(
-                writer, global_step, scalars=scalar_summaries,
-                images=image_summaries)
+                writer, global_step, scalars=scalar_summaries, images=image_summaries
+            )
 
 
 def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch):
-    """ Validation. """
+    """Validation."""
     # Switch the model to eval mode
     model.eval()
 
     # Initialize the average meter
     compute_descriptors = loss_func.compute_descriptors
     if compute_descriptors:
-        average_meter = AverageMeter(is_training=True, desc_metric_lst='all')
+        average_meter = AverageMeter(is_training=True, desc_metric_lst="all")
     else:
         average_meter = AverageMeter(is_training=True)
 
@@ -427,11 +518,23 @@ def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch
 
                 # Compute losses
                 losses = loss_func.forward_descriptors(
-                    outputs["junctions"], outputs2["junctions"],
-                    junc_map, junc_map2, outputs["heatmap"],
-                    outputs2["heatmap"], heatmap, heatmap2, line_points,
-                    line_points2, line_indices, outputs['descriptors'],
-                    outputs2['descriptors'], epoch, valid_mask, valid_mask2)
+                    outputs["junctions"],
+                    outputs2["junctions"],
+                    junc_map,
+                    junc_map2,
+                    outputs["heatmap"],
+                    outputs2["heatmap"],
+                    heatmap,
+                    heatmap2,
+                    line_points,
+                    line_points2,
+                    line_indices,
+                    outputs["descriptors"],
+                    outputs2["descriptors"],
+                    epoch,
+                    valid_mask,
+                    valid_mask2,
+                )
         else:
             junc_map = data["junction_map"].cuda()
             heatmap = data["heatmap"].cuda()
@@ -444,47 +547,70 @@ def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch
 
                 # Compute losses
                 losses = loss_func(
-                    outputs["junctions"], junc_map,
-                    outputs["heatmap"], heatmap,
-                    valid_mask)
+                    outputs["junctions"],
+                    junc_map,
+                    outputs["heatmap"],
+                    heatmap,
+                    valid_mask,
+                )
         total_loss = losses["total_loss"]
 
         ############## Measure the metric error #########################
         junc_np = convert_junc_predictions(
-            outputs["junctions"], model_cfg["grid_size"],
-            model_cfg["detection_thresh"], 300)
+            outputs["junctions"],
+            model_cfg["grid_size"],
+            model_cfg["detection_thresh"],
+            300,
+        )
         junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)
         # Always fetch only one channel (compatible with L1, L2, and CE)
         if outputs["heatmap"].shape[1] == 2:
-            heatmap_np = softmax(outputs["heatmap"].detach(),
-                                 dim=1).cpu().numpy().transpose(0, 2, 3, 1)
+            heatmap_np = (
+                softmax(outputs["heatmap"].detach(), dim=1)
+                .cpu()
+                .numpy()
+                .transpose(0, 2, 3, 1)
+            )
             heatmap_np = heatmap_np[:, :, :, 1:]
         else:
             heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
             heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)
 
-
         heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
         valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)
 
         # Evaluate metric results
         if compute_descriptors:
             metric_func.evaluate(
-                junc_np["junc_pred"], junc_np["junc_pred_nms"],
-                junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np,
-                line_points, line_points2, outputs["descriptors"],
-                outputs2["descriptors"], line_indices)
+                junc_np["junc_pred"],
+                junc_np["junc_pred_nms"],
+                junc_map_np,
+                heatmap_np,
+                heatmap_gt_np,
+                valid_mask_np,
+                line_points,
+                line_points2,
+                outputs["descriptors"],
+                outputs2["descriptors"],
+                line_indices,
+            )
         else:
             metric_func.evaluate(
-                junc_np["junc_pred"], junc_np["junc_pred_nms"], junc_map_np,
-                heatmap_np, heatmap_gt_np, valid_mask_np)
+                junc_np["junc_pred"],
+                junc_np["junc_pred_nms"],
+                junc_map_np,
+                heatmap_np,
+                heatmap_gt_np,
+                valid_mask_np,
+            )
         # Update average meter
         junc_loss = losses["junc_loss"].item()
         heatmap_loss = losses["heatmap_loss"].item()
         loss_dict = {
             "junc_loss": junc_loss,
             "heatmap_loss": heatmap_loss,
-            "total_loss": total_loss.item()}
+            "total_loss": total_loss.item(),
+        }
         if compute_descriptors:
             descriptor_loss = losses["descriptor_loss"].item()
             loss_dict["descriptor_loss"] = losses["descriptor_loss"].item()
@@ -495,32 +621,67 @@ def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch
             results = metric_func.metric_results
             average = average_meter.average()
             if compute_descriptors:
-                print("Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f)"
-                      % (idx, len(val_loader),
-                         total_loss.item(), average["total_loss"],
-                         junc_loss, average["junc_loss"],
-                         heatmap_loss, average["heatmap_loss"],
-                         descriptor_loss, average["descriptor_loss"]))
+                print(
+                    "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f)"
+                    % (
+                        idx,
+                        len(val_loader),
+                        total_loss.item(),
+                        average["total_loss"],
+                        junc_loss,
+                        average["junc_loss"],
+                        heatmap_loss,
+                        average["heatmap_loss"],
+                        descriptor_loss,
+                        average["descriptor_loss"],
+                    )
+                )
             else:
-                print("Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f)"
-                      % (idx, len(val_loader),
-                         total_loss.item(), average["total_loss"],
-                         junc_loss, average["junc_loss"],
-                         heatmap_loss, average["heatmap_loss"]))
-            print("\t Junction     precision=%.4f (%.4f) / recall=%.4f (%.4f)"
-                  % (results["junc_precision"], average["junc_precision"],
-                     results["junc_recall"], average["junc_recall"]))
-            print("\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
-                  % (results["junc_precision_nms"],
-                     average["junc_precision_nms"],
-                     results["junc_recall_nms"], average["junc_recall_nms"]))
-            print("\t Heatmap      precision=%.4f (%.4f) / recall=%.4f (%.4f)"
-                  % (results["heatmap_precision"],
-                     average["heatmap_precision"],
-                     results["heatmap_recall"], average["heatmap_recall"]))
+                print(
+                    "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f)"
+                    % (
+                        idx,
+                        len(val_loader),
+                        total_loss.item(),
+                        average["total_loss"],
+                        junc_loss,
+                        average["junc_loss"],
+                        heatmap_loss,
+                        average["heatmap_loss"],
+                    )
+                )
+            print(
+                "\t Junction     precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+                % (
+                    results["junc_precision"],
+                    average["junc_precision"],
+                    results["junc_recall"],
+                    average["junc_recall"],
+                )
+            )
+            print(
+                "\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+                % (
+                    results["junc_precision_nms"],
+                    average["junc_precision_nms"],
+                    results["junc_recall_nms"],
+                    average["junc_recall_nms"],
+                )
+            )
+            print(
+                "\t Heatmap      precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+                % (
+                    results["heatmap_precision"],
+                    average["heatmap_precision"],
+                    results["heatmap_recall"],
+                    average["heatmap_recall"],
+                )
+            )
             if compute_descriptors:
-                print("\t Descriptors  matching score=%.4f (%.4f)"
-                      %(results["matching_score"], average["matching_score"]))
+                print(
+                    "\t Descriptors  matching score=%.4f (%.4f)"
+                    % (results["matching_score"], average["matching_score"])
+                )
 
     # Record summaries
     average = average_meter.average()
@@ -529,143 +690,182 @@ def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch
     record_test_summaries(writer, epoch, scalar_summaries)
 
 
-def convert_junc_predictions(predictions, grid_size,
-                             detect_thresh=1/65, topk=300):
-    """ Convert torch predictions to numpy arrays for evaluation. """
+def convert_junc_predictions(predictions, grid_size, detect_thresh=1 / 65, topk=300):
+    """Convert torch predictions to numpy arrays for evaluation."""
     # Convert to probability outputs first
     junc_prob = softmax(predictions.detach(), dim=1).cpu()
     junc_pred = junc_prob[:, :-1, :, :]
 
     junc_prob_np = junc_prob.numpy().transpose(0, 2, 3, 1)[:, :, :, :-1]
     junc_prob_np = np.sum(junc_prob_np, axis=-1)
-    junc_pred_np = pixel_shuffle(
-        junc_pred, grid_size).cpu().numpy().transpose(0, 2, 3, 1)
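+    # Expand the grid-wise junction predictions back to full image resolution before applying NMS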
+    junc_pred_np = (
+        pixel_shuffle(junc_pred, grid_size).cpu().numpy().transpose(0, 2, 3, 1)
+    )
     junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk)
     junc_pred_np = junc_pred_np.squeeze(-1)
 
-    return {"junc_pred": junc_pred_np, "junc_pred_nms": junc_pred_np_nms,
-            "junc_prob": junc_prob_np}
+    return {
+        "junc_pred": junc_pred_np,
+        "junc_pred_nms": junc_pred_np_nms,
+        "junc_prob": junc_prob_np,
+    }
 
 
 def record_train_summaries(writer, global_step, scalars, images):
-    """ Record training summaries. """
+    """Record training summaries."""
     # Record the scalar summaries
     results = scalars["metrics"]
     average = scalars["average"]
 
     # GPU memory part
     # Get gpu memory usage in GB
-    gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024 ** 3)
+    gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
     writer.add_scalar("GPU/GPU_memory_usage", gpu_mem_usage, global_step)
 
     # Loss part
-    writer.add_scalar("Train_loss/junc_loss", scalars["junc_loss"],
-                      global_step)
-    writer.add_scalar("Train_loss/heatmap_loss", scalars["heatmap_loss"],
-                      global_step)
-    writer.add_scalar("Train_loss/total_loss", scalars["total_loss"],
-                      global_step)
+    writer.add_scalar("Train_loss/junc_loss", scalars["junc_loss"], global_step)
+    writer.add_scalar("Train_loss/heatmap_loss", scalars["heatmap_loss"], global_step)
+    writer.add_scalar("Train_loss/total_loss", scalars["total_loss"], global_step)
     # Add regularization loss
     if "reg_loss" in scalars.keys():
-        writer.add_scalar("Train_loss/reg_loss", scalars["reg_loss"],
-                          global_step)
+        writer.add_scalar("Train_loss/reg_loss", scalars["reg_loss"], global_step)
     # Add descriptor loss
     if "descriptor_loss" in scalars.keys():
         key = "descriptor_loss"
-        writer.add_scalar("Train_loss/%s"%(key), scalars[key], global_step)
-        writer.add_scalar("Train_loss_average/%s"%(key), average[key],
-                          global_step)
-    
+        writer.add_scalar("Train_loss/%s" % (key), scalars[key], global_step)
+        writer.add_scalar("Train_loss_average/%s" % (key), average[key], global_step)
+
     # Record weighting
     for key in scalars.keys():
         if "w_" in key:
-            writer.add_scalar("Train_weight/%s"%(key), scalars[key],
-                              global_step)
-    
+            writer.add_scalar("Train_weight/%s" % (key), scalars[key], global_step)
+
     # Smoothed loss
-    writer.add_scalar("Train_loss_average/junc_loss", average["junc_loss"],
-                      global_step)
-    writer.add_scalar("Train_loss_average/heatmap_loss",
-                      average["heatmap_loss"], global_step)
-    writer.add_scalar("Train_loss_average/total_loss", average["total_loss"],
-                      global_step)
+    writer.add_scalar("Train_loss_average/junc_loss", average["junc_loss"], global_step)
+    writer.add_scalar(
+        "Train_loss_average/heatmap_loss", average["heatmap_loss"], global_step
+    )
+    writer.add_scalar(
+        "Train_loss_average/total_loss", average["total_loss"], global_step
+    )
     # Add smoothed descriptor loss
     if "descriptor_loss" in average.keys():
-        writer.add_scalar("Train_loss_average/descriptor_loss",
-                          average["descriptor_loss"], global_step)
+        writer.add_scalar(
+            "Train_loss_average/descriptor_loss",
+            average["descriptor_loss"],
+            global_step,
+        )
 
     # Metrics part
-    writer.add_scalar("Train_metrics/junc_precision",
-                      results["junc_precision"], global_step)
-    writer.add_scalar("Train_metrics/junc_precision_nms",
-                      results["junc_precision_nms"], global_step)
-    writer.add_scalar("Train_metrics/junc_recall",
-                      results["junc_recall"], global_step)
-    writer.add_scalar("Train_metrics/junc_recall_nms",
-                      results["junc_recall_nms"], global_step)
-    writer.add_scalar("Train_metrics/heatmap_precision",
-                      results["heatmap_precision"], global_step)
-    writer.add_scalar("Train_metrics/heatmap_recall",
-                      results["heatmap_recall"], global_step)
+    writer.add_scalar(
+        "Train_metrics/junc_precision", results["junc_precision"], global_step
+    )
+    writer.add_scalar(
+        "Train_metrics/junc_precision_nms", results["junc_precision_nms"], global_step
+    )
+    writer.add_scalar("Train_metrics/junc_recall", results["junc_recall"], global_step)
+    writer.add_scalar(
+        "Train_metrics/junc_recall_nms", results["junc_recall_nms"], global_step
+    )
+    writer.add_scalar(
+        "Train_metrics/heatmap_precision", results["heatmap_precision"], global_step
+    )
+    writer.add_scalar(
+        "Train_metrics/heatmap_recall", results["heatmap_recall"], global_step
+    )
     # Add descriptor metric
     if "matching_score" in results.keys():
-        writer.add_scalar("Train_metrics/matching_score",
-                          results["matching_score"], global_step)
+        writer.add_scalar(
+            "Train_metrics/matching_score", results["matching_score"], global_step
+        )
 
     # Average part
-    writer.add_scalar("Train_metrics_average/junc_precision",
-                      average["junc_precision"], global_step)
-    writer.add_scalar("Train_metrics_average/junc_precision_nms",
-                      average["junc_precision_nms"], global_step)
-    writer.add_scalar("Train_metrics_average/junc_recall",
-                      average["junc_recall"], global_step)
-    writer.add_scalar("Train_metrics_average/junc_recall_nms",
-                      average["junc_recall_nms"], global_step)
-    writer.add_scalar("Train_metrics_average/heatmap_precision",
-                      average["heatmap_precision"], global_step)
-    writer.add_scalar("Train_metrics_average/heatmap_recall",
-                      average["heatmap_recall"], global_step)
+    writer.add_scalar(
+        "Train_metrics_average/junc_precision", average["junc_precision"], global_step
+    )
+    writer.add_scalar(
+        "Train_metrics_average/junc_precision_nms",
+        average["junc_precision_nms"],
+        global_step,
+    )
+    writer.add_scalar(
+        "Train_metrics_average/junc_recall", average["junc_recall"], global_step
+    )
+    writer.add_scalar(
+        "Train_metrics_average/junc_recall_nms", average["junc_recall_nms"], global_step
+    )
+    writer.add_scalar(
+        "Train_metrics_average/heatmap_precision",
+        average["heatmap_precision"],
+        global_step,
+    )
+    writer.add_scalar(
+        "Train_metrics_average/heatmap_recall", average["heatmap_recall"], global_step
+    )
     # Add smoothed descriptor metric
     if "matching_score" in average.keys():
-        writer.add_scalar("Train_metrics_average/matching_score",
-                          average["matching_score"], global_step)
+        writer.add_scalar(
+            "Train_metrics_average/matching_score",
+            average["matching_score"],
+            global_step,
+        )
 
     # Record the image summary
     # Image part
     image_tensor = convert_image(images["image"], 1)
     valid_masks = convert_image(images["valid_mask"], -1)
-    writer.add_images("Train/images", image_tensor, global_step,
-                      dataformats="NCHW")
-    writer.add_images("Train/valid_map", valid_masks, global_step,
-                      dataformats="NHWC")
+    writer.add_images("Train/images", image_tensor, global_step, dataformats="NCHW")
+    writer.add_images("Train/valid_map", valid_masks, global_step, dataformats="NHWC")
 
     # Heatmap part
-    writer.add_images("Train/heatmap_gt",
-                      convert_image(images["heatmap_gt"], -1), global_step,
-                      dataformats="NHWC")
-    writer.add_images("Train/heatmap_pred",
-                      convert_image(images["heatmap_pred"], -1), global_step,
-                      dataformats="NHWC")
+    writer.add_images(
+        "Train/heatmap_gt",
+        convert_image(images["heatmap_gt"], -1),
+        global_step,
+        dataformats="NHWC",
+    )
+    writer.add_images(
+        "Train/heatmap_pred",
+        convert_image(images["heatmap_pred"], -1),
+        global_step,
+        dataformats="NHWC",
+    )
 
     # Junction prediction part
     junc_plots = plot_junction_detection(
-        image_tensor, images["junc_map_pred"],
-        images["junc_map_pred_nms"], images["junc_map_gt"])
-    writer.add_images("Train/junc_gt", junc_plots["junc_gt_plot"] / 255.,
-                      global_step, dataformats="NHWC")
-    writer.add_images("Train/junc_pred", junc_plots["junc_pred_plot"] / 255.,
-                      global_step, dataformats="NHWC")
-    writer.add_images("Train/junc_pred_nms",
-                      junc_plots["junc_pred_nms_plot"] / 255., global_step,
-                      dataformats="NHWC")
+        image_tensor,
+        images["junc_map_pred"],
+        images["junc_map_pred_nms"],
+        images["junc_map_gt"],
+    )
+    writer.add_images(
+        "Train/junc_gt",
+        junc_plots["junc_gt_plot"] / 255.0,
+        global_step,
+        dataformats="NHWC",
+    )
+    writer.add_images(
+        "Train/junc_pred",
+        junc_plots["junc_pred_plot"] / 255.0,
+        global_step,
+        dataformats="NHWC",
+    )
+    writer.add_images(
+        "Train/junc_pred_nms",
+        junc_plots["junc_pred_nms_plot"] / 255.0,
+        global_step,
+        dataformats="NHWC",
+    )
     writer.add_images(
         "Train/junc_prob_map",
         convert_image(images["junc_prob_map"][..., None], axis=-1),
-        global_step, dataformats="NHWC")
+        global_step,
+        dataformats="NHWC",
+    )
 
 
 def record_test_summaries(writer, epoch, scalars):
-    """ Record testing summaries. """
+    """Record testing summaries."""
     average = scalars["average"]
 
     # Average loss
@@ -675,30 +875,30 @@ def record_test_summaries(writer, epoch, scalars):
     # Add descriptor loss
     if "descriptor_loss" in average.keys():
         key = "descriptor_loss"
-        writer.add_scalar("Val_loss/%s"%(key), average[key], epoch)
+        writer.add_scalar("Val_loss/%s" % (key), average[key], epoch)
 
     # Average metrics
-    writer.add_scalar("Val_metrics/junc_precision", average["junc_precision"],
-                      epoch)
-    writer.add_scalar("Val_metrics/junc_precision_nms",
-                      average["junc_precision_nms"], epoch)
-    writer.add_scalar("Val_metrics/junc_recall",
-                      average["junc_recall"], epoch)
-    writer.add_scalar("Val_metrics/junc_recall_nms",
-                      average["junc_recall_nms"], epoch)
-    writer.add_scalar("Val_metrics/heatmap_precision",
-                      average["heatmap_precision"], epoch)
-    writer.add_scalar("Val_metrics/heatmap_recall",
-                      average["heatmap_recall"], epoch)
+    writer.add_scalar("Val_metrics/junc_precision", average["junc_precision"], epoch)
+    writer.add_scalar(
+        "Val_metrics/junc_precision_nms", average["junc_precision_nms"], epoch
+    )
+    writer.add_scalar("Val_metrics/junc_recall", average["junc_recall"], epoch)
+    writer.add_scalar("Val_metrics/junc_recall_nms", average["junc_recall_nms"], epoch)
+    writer.add_scalar(
+        "Val_metrics/heatmap_precision", average["heatmap_precision"], epoch
+    )
+    writer.add_scalar("Val_metrics/heatmap_recall", average["heatmap_recall"], epoch)
     # Add descriptor metric
     if "matching_score" in average.keys():
-        writer.add_scalar("Val_metrics/matching_score",
-                          average["matching_score"], epoch)
+        writer.add_scalar(
+            "Val_metrics/matching_score", average["matching_score"], epoch
+        )
 
 
-def plot_junction_detection(image_tensor, junc_pred_tensor,
-                            junc_pred_nms_tensor, junc_gt_tensor):
-    """ Plot the junction points on images. """
+def plot_junction_detection(
+    image_tensor, junc_pred_tensor, junc_pred_nms_tensor, junc_gt_tensor
+):
+    """Plot the junction points on images."""
     # Get the batch_size
     batch_size = image_tensor.shape[0]
 
@@ -708,45 +908,61 @@ def plot_junction_detection(image_tensor, junc_pred_tensor,
     junc_gt_lst = []
     for i in range(batch_size):
         # Convert image to 255 uint8
-        image = (image_tensor[i, :, :, :]
-                 * 255.).astype(np.uint8).transpose(1,2,0)
+        image = (image_tensor[i, :, :, :] * 255.0).astype(np.uint8).transpose(1, 2, 0)
 
         # Plot groundtruth onto image
         junc_gt = junc_gt_tensor[i, ...]
         coord_gt = np.where(junc_gt.squeeze() > 0)
-        points_gt = np.concatenate((coord_gt[0][..., None],
-                                    coord_gt[1][..., None]),
-                                    axis=1)
+        points_gt = np.concatenate(
+            (coord_gt[0][..., None], coord_gt[1][..., None]), axis=1
+        )
         plot_gt = image.copy()
         for id in range(points_gt.shape[0]):
-            cv2.circle(plot_gt, tuple(np.flip(points_gt[id, :])), 3,
-                       color=(255, 0, 0), thickness=2)
+            cv2.circle(
+                plot_gt,
+                tuple(np.flip(points_gt[id, :])),
+                3,
+                color=(255, 0, 0),
+                thickness=2,
+            )
         junc_gt_lst.append(plot_gt[None, ...])
 
         # Plot junc_pred
         junc_pred = junc_pred_tensor[i, ...]
         coord_pred = np.where(junc_pred > 0)
-        points_pred = np.concatenate((coord_pred[0][..., None],
-                                      coord_pred[1][..., None]),
-                                      axis=1)
+        points_pred = np.concatenate(
+            (coord_pred[0][..., None], coord_pred[1][..., None]), axis=1
+        )
         plot_pred = image.copy()
         for id in range(points_pred.shape[0]):
-            cv2.circle(plot_pred, tuple(np.flip(points_pred[id, :])), 3,
-                       color=(0, 255, 0), thickness=2)
+            cv2.circle(
+                plot_pred,
+                tuple(np.flip(points_pred[id, :])),
+                3,
+                color=(0, 255, 0),
+                thickness=2,
+            )
         junc_pred_lst.append(plot_pred[None, ...])
 
         # Plot junc_pred_nms
         junc_pred_nms = junc_pred_nms_tensor[i, ...]
         coord_pred_nms = np.where(junc_pred_nms > 0)
-        points_pred_nms = np.concatenate((coord_pred_nms[0][..., None],
-                                          coord_pred_nms[1][..., None]),
-                                          axis=1)
+        points_pred_nms = np.concatenate(
+            (coord_pred_nms[0][..., None], coord_pred_nms[1][..., None]), axis=1
+        )
         plot_pred_nms = image.copy()
         for id in range(points_pred_nms.shape[0]):
-            cv2.circle(plot_pred_nms, tuple(np.flip(points_pred_nms[id, :])),
-                       3, color=(0, 255, 0), thickness=2)
+            cv2.circle(
+                plot_pred_nms,
+                tuple(np.flip(points_pred_nms[id, :])),
+                3,
+                color=(0, 255, 0),
+                thickness=2,
+            )
         junc_pred_nms_lst.append(plot_pred_nms[None, ...])
 
-    return {"junc_gt_plot": np.concatenate(junc_gt_lst, axis=0),
-            "junc_pred_plot": np.concatenate(junc_pred_lst, axis=0),
-            "junc_pred_nms_plot": np.concatenate(junc_pred_nms_lst, axis=0)}
+    return {
+        "junc_gt_plot": np.concatenate(junc_gt_lst, axis=0),
+        "junc_pred_plot": np.concatenate(junc_pred_lst, axis=0),
+        "junc_pred_nms_plot": np.concatenate(junc_pred_nms_lst, axis=0),
+    }
diff --git a/third_party/SuperGluePretrainedNetwork/demo_superglue.py b/third_party/SuperGluePretrainedNetwork/demo_superglue.py
index 32d4ad3c7df1b7da141c4c6aa51f871a7d756aaf..c639efd7481052b842c640d4aa23aaf18e0eb449 100644
--- a/third_party/SuperGluePretrainedNetwork/demo_superglue.py
+++ b/third_party/SuperGluePretrainedNetwork/demo_superglue.py
@@ -51,69 +51,110 @@ import matplotlib.cm as cm
 import torch
 
 from models.matching import Matching
-from models.utils import (AverageTimer, VideoStreamer,
-                          make_matching_plot_fast, frame2tensor)
+from models.utils import (
+    AverageTimer,
+    VideoStreamer,
+    make_matching_plot_fast,
+    frame2tensor,
+)
 
 torch.set_grad_enabled(False)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description='SuperGlue demo',
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+        description="SuperGlue demo",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
     parser.add_argument(
-        '--input', type=str, default='0',
-        help='ID of a USB webcam, URL of an IP camera, '
-             'or path to an image directory or movie file')
+        "--input",
+        type=str,
+        default="0",
+        help="ID of a USB webcam, URL of an IP camera, "
+        "or path to an image directory or movie file",
+    )
     parser.add_argument(
-        '--output_dir', type=str, default=None,
-        help='Directory where to write output frames (If None, no output)')
+        "--output_dir",
+        type=str,
+        default=None,
+        help="Directory where to write output frames (If None, no output)",
+    )
 
     parser.add_argument(
-        '--image_glob', type=str, nargs='+', default=['*.png', '*.jpg', '*.jpeg'],
-        help='Glob if a directory of images is specified')
+        "--image_glob",
+        type=str,
+        nargs="+",
+        default=["*.png", "*.jpg", "*.jpeg"],
+        help="Glob if a directory of images is specified",
+    )
     parser.add_argument(
-        '--skip', type=int, default=1,
-        help='Images to skip if input is a movie or directory')
+        "--skip",
+        type=int,
+        default=1,
+        help="Images to skip if input is a movie or directory",
+    )
     parser.add_argument(
-        '--max_length', type=int, default=1000000,
-        help='Maximum length if input is a movie or directory')
+        "--max_length",
+        type=int,
+        default=1000000,
+        help="Maximum length if input is a movie or directory",
+    )
     parser.add_argument(
-        '--resize', type=int, nargs='+', default=[640, 480],
-        help='Resize the input image before running inference. If two numbers, '
-             'resize to the exact dimensions, if one number, resize the max '
-             'dimension, if -1, do not resize')
+        "--resize",
+        type=int,
+        nargs="+",
+        default=[640, 480],
+        help="Resize the input image before running inference. If two numbers, "
+        "resize to the exact dimensions, if one number, resize the max "
+        "dimension, if -1, do not resize",
+    )
 
     parser.add_argument(
-        '--superglue', choices={'indoor', 'outdoor'}, default='indoor',
-        help='SuperGlue weights')
+        "--superglue",
+        choices={"indoor", "outdoor"},
+        default="indoor",
+        help="SuperGlue weights",
+    )
     parser.add_argument(
-        '--max_keypoints', type=int, default=-1,
-        help='Maximum number of keypoints detected by Superpoint'
-             ' (\'-1\' keeps all keypoints)')
+        "--max_keypoints",
+        type=int,
+        default=-1,
+        help="Maximum number of keypoints detected by Superpoint"
+        " ('-1' keeps all keypoints)",
+    )
     parser.add_argument(
-        '--keypoint_threshold', type=float, default=0.005,
-        help='SuperPoint keypoint detector confidence threshold')
+        "--keypoint_threshold",
+        type=float,
+        default=0.005,
+        help="SuperPoint keypoint detector confidence threshold",
+    )
     parser.add_argument(
-        '--nms_radius', type=int, default=4,
-        help='SuperPoint Non Maximum Suppression (NMS) radius'
-        ' (Must be positive)')
+        "--nms_radius",
+        type=int,
+        default=4,
+        help="SuperPoint Non Maximum Suppression (NMS) radius" " (Must be positive)",
+    )
     parser.add_argument(
-        '--sinkhorn_iterations', type=int, default=20,
-        help='Number of Sinkhorn iterations performed by SuperGlue')
+        "--sinkhorn_iterations",
+        type=int,
+        default=20,
+        help="Number of Sinkhorn iterations performed by SuperGlue",
+    )
     parser.add_argument(
-        '--match_threshold', type=float, default=0.2,
-        help='SuperGlue match threshold')
+        "--match_threshold", type=float, default=0.2, help="SuperGlue match threshold"
+    )
 
     parser.add_argument(
-        '--show_keypoints', action='store_true',
-        help='Show the detected keypoints')
+        "--show_keypoints", action="store_true", help="Show the detected keypoints"
+    )
     parser.add_argument(
-        '--no_display', action='store_true',
-        help='Do not display images to screen. Useful if running remotely')
+        "--no_display",
+        action="store_true",
+        help="Do not display images to screen. Useful if running remotely",
+    )
     parser.add_argument(
-        '--force_cpu', action='store_true',
-        help='Force pytorch to run in CPU mode.')
+        "--force_cpu", action="store_true", help="Force pytorch to run in CPU mode."
+    )
 
     opt = parser.parse_args()
     print(opt)
@@ -121,138 +162,160 @@ if __name__ == '__main__':
     if len(opt.resize) == 2 and opt.resize[1] == -1:
         opt.resize = opt.resize[0:1]
     if len(opt.resize) == 2:
-        print('Will resize to {}x{} (WxH)'.format(
-            opt.resize[0], opt.resize[1]))
+        print("Will resize to {}x{} (WxH)".format(opt.resize[0], opt.resize[1]))
     elif len(opt.resize) == 1 and opt.resize[0] > 0:
-        print('Will resize max dimension to {}'.format(opt.resize[0]))
+        print("Will resize max dimension to {}".format(opt.resize[0]))
     elif len(opt.resize) == 1:
-        print('Will not resize images')
+        print("Will not resize images")
     else:
-        raise ValueError('Cannot specify more than two integers for --resize')
+        raise ValueError("Cannot specify more than two integers for --resize")
 
-    device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu'
-    print('Running inference on device \"{}\"'.format(device))
+    device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu"
+    print('Running inference on device "{}"'.format(device))
     config = {
-        'superpoint': {
-            'nms_radius': opt.nms_radius,
-            'keypoint_threshold': opt.keypoint_threshold,
-            'max_keypoints': opt.max_keypoints
+        "superpoint": {
+            "nms_radius": opt.nms_radius,
+            "keypoint_threshold": opt.keypoint_threshold,
+            "max_keypoints": opt.max_keypoints,
+        },
+        "superglue": {
+            "weights": opt.superglue,
+            "sinkhorn_iterations": opt.sinkhorn_iterations,
+            "match_threshold": opt.match_threshold,
         },
-        'superglue': {
-            'weights': opt.superglue,
-            'sinkhorn_iterations': opt.sinkhorn_iterations,
-            'match_threshold': opt.match_threshold,
-        }
     }
     matching = Matching(config).eval().to(device)
-    keys = ['keypoints', 'scores', 'descriptors']
+    keys = ["keypoints", "scores", "descriptors"]
 
-    vs = VideoStreamer(opt.input, opt.resize, opt.skip,
-                       opt.image_glob, opt.max_length)
+    vs = VideoStreamer(opt.input, opt.resize, opt.skip, opt.image_glob, opt.max_length)
     frame, ret = vs.next_frame()
-    assert ret, 'Error when reading the first frame (try different --input?)'
+    assert ret, "Error when reading the first frame (try different --input?)"
 
     frame_tensor = frame2tensor(frame, device)
-    last_data = matching.superpoint({'image': frame_tensor})
-    last_data = {k+'0': last_data[k] for k in keys}
-    last_data['image0'] = frame_tensor
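+    # Cache the SuperPoint features of the first frame as the matching anchor;
+    # pressing 'n' during the demo replaces the anchor with the current frame.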
+    last_data = matching.superpoint({"image": frame_tensor})
+    last_data = {k + "0": last_data[k] for k in keys}
+    last_data["image0"] = frame_tensor
     last_frame = frame
     last_image_id = 0
 
     if opt.output_dir is not None:
-        print('==> Will write outputs to {}'.format(opt.output_dir))
+        print("==> Will write outputs to {}".format(opt.output_dir))
         Path(opt.output_dir).mkdir(exist_ok=True)
 
     # Create a window to display the demo.
     if not opt.no_display:
-        cv2.namedWindow('SuperGlue matches', cv2.WINDOW_NORMAL)
-        cv2.resizeWindow('SuperGlue matches', 640*2, 480)
+        cv2.namedWindow("SuperGlue matches", cv2.WINDOW_NORMAL)
+        cv2.resizeWindow("SuperGlue matches", 640 * 2, 480)
     else:
-        print('Skipping visualization, will not show a GUI.')
+        print("Skipping visualization, will not show a GUI.")
 
     # Print the keyboard help menu.
-    print('==> Keyboard control:\n'
-          '\tn: select the current frame as the anchor\n'
-          '\te/r: increase/decrease the keypoint confidence threshold\n'
-          '\td/f: increase/decrease the match filtering threshold\n'
-          '\tk: toggle the visualization of keypoints\n'
-          '\tq: quit')
+    print(
+        "==> Keyboard control:\n"
+        "\tn: select the current frame as the anchor\n"
+        "\te/r: increase/decrease the keypoint confidence threshold\n"
+        "\td/f: increase/decrease the match filtering threshold\n"
+        "\tk: toggle the visualization of keypoints\n"
+        "\tq: quit"
+    )
 
     timer = AverageTimer()
 
     while True:
         frame, ret = vs.next_frame()
         if not ret:
-            print('Finished demo_superglue.py')
+            print("Finished demo_superglue.py")
             break
-        timer.update('data')
+        timer.update("data")
         stem0, stem1 = last_image_id, vs.i - 1
 
         frame_tensor = frame2tensor(frame, device)
-        pred = matching({**last_data, 'image1': frame_tensor})
-        kpts0 = last_data['keypoints0'][0].cpu().numpy()
-        kpts1 = pred['keypoints1'][0].cpu().numpy()
-        matches = pred['matches0'][0].cpu().numpy()
-        confidence = pred['matching_scores0'][0].cpu().numpy()
-        timer.update('forward')
+        pred = matching({**last_data, "image1": frame_tensor})
+        kpts0 = last_data["keypoints0"][0].cpu().numpy()
+        kpts1 = pred["keypoints1"][0].cpu().numpy()
+        matches = pred["matches0"][0].cpu().numpy()
+        confidence = pred["matching_scores0"][0].cpu().numpy()
+        timer.update("forward")
 
         valid = matches > -1
         mkpts0 = kpts0[valid]
         mkpts1 = kpts1[matches[valid]]
         color = cm.jet(confidence[valid])
         text = [
-            'SuperGlue',
-            'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)),
-            'Matches: {}'.format(len(mkpts0))
+            "SuperGlue",
+            "Keypoints: {}:{}".format(len(kpts0), len(kpts1)),
+            "Matches: {}".format(len(mkpts0)),
         ]
-        k_thresh = matching.superpoint.config['keypoint_threshold']
-        m_thresh = matching.superglue.config['match_threshold']
+        k_thresh = matching.superpoint.config["keypoint_threshold"]
+        m_thresh = matching.superglue.config["match_threshold"]
         small_text = [
-            'Keypoint Threshold: {:.4f}'.format(k_thresh),
-            'Match Threshold: {:.2f}'.format(m_thresh),
-            'Image Pair: {:06}:{:06}'.format(stem0, stem1),
+            "Keypoint Threshold: {:.4f}".format(k_thresh),
+            "Match Threshold: {:.2f}".format(m_thresh),
+            "Image Pair: {:06}:{:06}".format(stem0, stem1),
         ]
         out = make_matching_plot_fast(
-            last_frame, frame, kpts0, kpts1, mkpts0, mkpts1, color, text,
-            path=None, show_keypoints=opt.show_keypoints, small_text=small_text)
+            last_frame,
+            frame,
+            kpts0,
+            kpts1,
+            mkpts0,
+            mkpts1,
+            color,
+            text,
+            path=None,
+            show_keypoints=opt.show_keypoints,
+            small_text=small_text,
+        )
 
         if not opt.no_display:
-            cv2.imshow('SuperGlue matches', out)
+            cv2.imshow("SuperGlue matches", out)
             key = chr(cv2.waitKey(1) & 0xFF)
-            if key == 'q':
+            if key == "q":
                 vs.cleanup()
-                print('Exiting (via q) demo_superglue.py')
+                print("Exiting (via q) demo_superglue.py")
                 break
-            elif key == 'n':  # set the current frame as anchor
-                last_data = {k+'0': pred[k+'1'] for k in keys}
-                last_data['image0'] = frame_tensor
+            elif key == "n":  # set the current frame as anchor
+                last_data = {k + "0": pred[k + "1"] for k in keys}
+                last_data["image0"] = frame_tensor
                 last_frame = frame
-                last_image_id = (vs.i - 1)
-            elif key in ['e', 'r']:
+                last_image_id = vs.i - 1
+            elif key in ["e", "r"]:
                 # Increase/decrease keypoint threshold by 10% each keypress.
-                d = 0.1 * (-1 if key == 'e' else 1)
-                matching.superpoint.config['keypoint_threshold'] = min(max(
-                    0.0001, matching.superpoint.config['keypoint_threshold']*(1+d)), 1)
-                print('\nChanged the keypoint threshold to {:.4f}'.format(
-                    matching.superpoint.config['keypoint_threshold']))
-            elif key in ['d', 'f']:
+                d = 0.1 * (-1 if key == "e" else 1)
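+                # Clamp the updated threshold to the valid range [0.0001, 1]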
+                matching.superpoint.config["keypoint_threshold"] = min(
+                    max(
+                        0.0001,
+                        matching.superpoint.config["keypoint_threshold"] * (1 + d),
+                    ),
+                    1,
+                )
+                print(
+                    "\nChanged the keypoint threshold to {:.4f}".format(
+                        matching.superpoint.config["keypoint_threshold"]
+                    )
+                )
+            elif key in ["d", "f"]:
                 # Increase/decrease match threshold by 0.05 each keypress.
-                d = 0.05 * (-1 if key == 'd' else 1)
-                matching.superglue.config['match_threshold'] = min(max(
-                    0.05, matching.superglue.config['match_threshold']+d), .95)
-                print('\nChanged the match threshold to {:.2f}'.format(
-                    matching.superglue.config['match_threshold']))
-            elif key == 'k':
+                d = 0.05 * (-1 if key == "d" else 1)
+                matching.superglue.config["match_threshold"] = min(
+                    max(0.05, matching.superglue.config["match_threshold"] + d), 0.95
+                )
+                print(
+                    "\nChanged the match threshold to {:.2f}".format(
+                        matching.superglue.config["match_threshold"]
+                    )
+                )
+            elif key == "k":
                 opt.show_keypoints = not opt.show_keypoints
 
-        timer.update('viz')
+        timer.update("viz")
         timer.print()
 
         if opt.output_dir is not None:
-            #stem = 'matches_{:06}_{:06}'.format(last_image_id, vs.i-1)
-            stem = 'matches_{:06}_{:06}'.format(stem0, stem1)
-            out_file = str(Path(opt.output_dir, stem + '.png'))
-            print('\nWriting image to {}'.format(out_file))
+            # stem = 'matches_{:06}_{:06}'.format(last_image_id, vs.i-1)
+            stem = "matches_{:06}_{:06}".format(stem0, stem1)
+            out_file = str(Path(opt.output_dir, stem + ".png"))
+            print("\nWriting image to {}".format(out_file))
             cv2.imwrite(out_file, out)
 
     cv2.destroyAllWindows()
diff --git a/third_party/SuperGluePretrainedNetwork/match_pairs.py b/third_party/SuperGluePretrainedNetwork/match_pairs.py
index 7079687cf69fd71d810ec80442548ad2a7b869e0..9dcbcadd3ca8efc053cf4ea33c825ff75728bef1 100644
--- a/third_party/SuperGluePretrainedNetwork/match_pairs.py
+++ b/third_party/SuperGluePretrainedNetwork/match_pairs.py
@@ -53,118 +53,176 @@ import torch
 
 
 from models.matching import Matching
-from models.utils import (compute_pose_error, compute_epipolar_error,
-                          estimate_pose, make_matching_plot,
-                          error_colormap, AverageTimer, pose_auc, read_image,
-                          rotate_intrinsics, rotate_pose_inplane,
-                          scale_intrinsics)
+from models.utils import (
+    compute_pose_error,
+    compute_epipolar_error,
+    estimate_pose,
+    make_matching_plot,
+    error_colormap,
+    AverageTimer,
+    pose_auc,
+    read_image,
+    rotate_intrinsics,
+    rotate_pose_inplane,
+    scale_intrinsics,
+)
 
 torch.set_grad_enabled(False)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description='Image pair matching and pose evaluation with SuperGlue',
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+        description="Image pair matching and pose evaluation with SuperGlue",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
 
     parser.add_argument(
-        '--input_pairs', type=str, default='assets/scannet_sample_pairs_with_gt.txt',
-        help='Path to the list of image pairs')
+        "--input_pairs",
+        type=str,
+        default="assets/scannet_sample_pairs_with_gt.txt",
+        help="Path to the list of image pairs",
+    )
     parser.add_argument(
-        '--input_dir', type=str, default='assets/scannet_sample_images/',
-        help='Path to the directory that contains the images')
+        "--input_dir",
+        type=str,
+        default="assets/scannet_sample_images/",
+        help="Path to the directory that contains the images",
+    )
     parser.add_argument(
-        '--output_dir', type=str, default='dump_match_pairs/',
-        help='Path to the directory in which the .npz results and optionally,'
-             'the visualization images are written')
+        "--output_dir",
+        type=str,
+        default="dump_match_pairs/",
+        help="Path to the directory in which the .npz results and optionally,"
+        "the visualization images are written",
+    )
 
     parser.add_argument(
-        '--max_length', type=int, default=-1,
-        help='Maximum number of pairs to evaluate')
+        "--max_length", type=int, default=-1, help="Maximum number of pairs to evaluate"
+    )
     parser.add_argument(
-        '--resize', type=int, nargs='+', default=[640, 480],
-        help='Resize the input image before running inference. If two numbers, '
-             'resize to the exact dimensions, if one number, resize the max '
-             'dimension, if -1, do not resize')
+        "--resize",
+        type=int,
+        nargs="+",
+        default=[640, 480],
+        help="Resize the input image before running inference. If two numbers, "
+        "resize to the exact dimensions, if one number, resize the max "
+        "dimension, if -1, do not resize",
+    )
     parser.add_argument(
-        '--resize_float', action='store_true',
-        help='Resize the image after casting uint8 to float')
+        "--resize_float",
+        action="store_true",
+        help="Resize the image after casting uint8 to float",
+    )
 
     parser.add_argument(
-        '--superglue', choices={'indoor', 'outdoor'}, default='indoor',
-        help='SuperGlue weights')
+        "--superglue",
+        choices={"indoor", "outdoor"},
+        default="indoor",
+        help="SuperGlue weights",
+    )
     parser.add_argument(
-        '--max_keypoints', type=int, default=1024,
-        help='Maximum number of keypoints detected by Superpoint'
-             ' (\'-1\' keeps all keypoints)')
+        "--max_keypoints",
+        type=int,
+        default=1024,
+        help="Maximum number of keypoints detected by Superpoint"
+        " ('-1' keeps all keypoints)",
+    )
     parser.add_argument(
-        '--keypoint_threshold', type=float, default=0.005,
-        help='SuperPoint keypoint detector confidence threshold')
+        "--keypoint_threshold",
+        type=float,
+        default=0.005,
+        help="SuperPoint keypoint detector confidence threshold",
+    )
     parser.add_argument(
-        '--nms_radius', type=int, default=4,
-        help='SuperPoint Non Maximum Suppression (NMS) radius'
-        ' (Must be positive)')
+        "--nms_radius",
+        type=int,
+        default=4,
+        help="SuperPoint Non Maximum Suppression (NMS) radius" " (Must be positive)",
+    )
     parser.add_argument(
-        '--sinkhorn_iterations', type=int, default=20,
-        help='Number of Sinkhorn iterations performed by SuperGlue')
+        "--sinkhorn_iterations",
+        type=int,
+        default=20,
+        help="Number of Sinkhorn iterations performed by SuperGlue",
+    )
     parser.add_argument(
-        '--match_threshold', type=float, default=0.2,
-        help='SuperGlue match threshold')
+        "--match_threshold", type=float, default=0.2, help="SuperGlue match threshold"
+    )
 
     parser.add_argument(
-        '--viz', action='store_true',
-        help='Visualize the matches and dump the plots')
+        "--viz", action="store_true", help="Visualize the matches and dump the plots"
+    )
     parser.add_argument(
-        '--eval', action='store_true',
-        help='Perform the evaluation'
-             ' (requires ground truth pose and intrinsics)')
+        "--eval",
+        action="store_true",
+        help="Perform the evaluation" " (requires ground truth pose and intrinsics)",
+    )
     parser.add_argument(
-        '--fast_viz', action='store_true',
-        help='Use faster image visualization with OpenCV instead of Matplotlib')
+        "--fast_viz",
+        action="store_true",
+        help="Use faster image visualization with OpenCV instead of Matplotlib",
+    )
     parser.add_argument(
-        '--cache', action='store_true',
-        help='Skip the pair if output .npz files are already found')
+        "--cache",
+        action="store_true",
+        help="Skip the pair if output .npz files are already found",
+    )
     parser.add_argument(
-        '--show_keypoints', action='store_true',
-        help='Plot the keypoints in addition to the matches')
+        "--show_keypoints",
+        action="store_true",
+        help="Plot the keypoints in addition to the matches",
+    )
     parser.add_argument(
-        '--viz_extension', type=str, default='png', choices=['png', 'pdf'],
-        help='Visualization file extension. Use pdf for highest-quality.')
+        "--viz_extension",
+        type=str,
+        default="png",
+        choices=["png", "pdf"],
+        help="Visualization file extension. Use pdf for highest-quality.",
+    )
     parser.add_argument(
-        '--opencv_display', action='store_true',
-        help='Visualize via OpenCV before saving output images')
+        "--opencv_display",
+        action="store_true",
+        help="Visualize via OpenCV before saving output images",
+    )
     parser.add_argument(
-        '--shuffle', action='store_true',
-        help='Shuffle ordering of pairs before processing')
+        "--shuffle",
+        action="store_true",
+        help="Shuffle ordering of pairs before processing",
+    )
     parser.add_argument(
-        '--force_cpu', action='store_true',
-        help='Force pytorch to run in CPU mode.')
+        "--force_cpu", action="store_true", help="Force pytorch to run in CPU mode."
+    )
 
     opt = parser.parse_args()
     print(opt)
 
-    assert not (opt.opencv_display and not opt.viz), 'Must use --viz with --opencv_display'
-    assert not (opt.opencv_display and not opt.fast_viz), 'Cannot use --opencv_display without --fast_viz'
-    assert not (opt.fast_viz and not opt.viz), 'Must use --viz with --fast_viz'
-    assert not (opt.fast_viz and opt.viz_extension == 'pdf'), 'Cannot use pdf extension with --fast_viz'
+    assert not (
+        opt.opencv_display and not opt.viz
+    ), "Must use --viz with --opencv_display"
+    assert not (
+        opt.opencv_display and not opt.fast_viz
+    ), "Cannot use --opencv_display without --fast_viz"
+    assert not (opt.fast_viz and not opt.viz), "Must use --viz with --fast_viz"
+    assert not (
+        opt.fast_viz and opt.viz_extension == "pdf"
+    ), "Cannot use pdf extension with --fast_viz"
 
     if len(opt.resize) == 2 and opt.resize[1] == -1:
         opt.resize = opt.resize[0:1]
     if len(opt.resize) == 2:
-        print('Will resize to {}x{} (WxH)'.format(
-            opt.resize[0], opt.resize[1]))
+        print("Will resize to {}x{} (WxH)".format(opt.resize[0], opt.resize[1]))
     elif len(opt.resize) == 1 and opt.resize[0] > 0:
-        print('Will resize max dimension to {}'.format(opt.resize[0]))
+        print("Will resize max dimension to {}".format(opt.resize[0]))
     elif len(opt.resize) == 1:
-        print('Will not resize images')
+        print("Will not resize images")
     else:
-        raise ValueError('Cannot specify more than two integers for --resize')
+        raise ValueError("Cannot specify more than two integers for --resize")
 
-    with open(opt.input_pairs, 'r') as f:
+    with open(opt.input_pairs, "r") as f:
         pairs = [l.split() for l in f.readlines()]
 
     if opt.max_length > -1:
-        pairs = pairs[0:np.min([len(pairs), opt.max_length])]
+        pairs = pairs[0 : np.min([len(pairs), opt.max_length])]
 
     if opt.shuffle:
         random.Random(0).shuffle(pairs)
@@ -172,48 +230,50 @@ if __name__ == '__main__':
     if opt.eval:
         if not all([len(p) == 38 for p in pairs]):
             raise ValueError(
-                'All pairs should have ground truth info for evaluation.'
-                'File \"{}\" needs 38 valid entries per row'.format(opt.input_pairs))
+                "All pairs should have ground truth info for evaluation."
+                'File "{}" needs 38 valid entries per row'.format(opt.input_pairs)
+            )
 
     # Load the SuperPoint and SuperGlue models.
-    device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu'
-    print('Running inference on device \"{}\"'.format(device))
+    device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu"
+    print('Running inference on device "{}"'.format(device))
     config = {
-        'superpoint': {
-            'nms_radius': opt.nms_radius,
-            'keypoint_threshold': opt.keypoint_threshold,
-            'max_keypoints': opt.max_keypoints
+        "superpoint": {
+            "nms_radius": opt.nms_radius,
+            "keypoint_threshold": opt.keypoint_threshold,
+            "max_keypoints": opt.max_keypoints,
+        },
+        "superglue": {
+            "weights": opt.superglue,
+            "sinkhorn_iterations": opt.sinkhorn_iterations,
+            "match_threshold": opt.match_threshold,
         },
-        'superglue': {
-            'weights': opt.superglue,
-            'sinkhorn_iterations': opt.sinkhorn_iterations,
-            'match_threshold': opt.match_threshold,
-        }
     }
     matching = Matching(config).eval().to(device)
 
     # Create the output directories if they do not exist already.
     input_dir = Path(opt.input_dir)
-    print('Looking for data in directory \"{}\"'.format(input_dir))
+    print('Looking for data in directory "{}"'.format(input_dir))
     output_dir = Path(opt.output_dir)
     output_dir.mkdir(exist_ok=True, parents=True)
-    print('Will write matches to directory \"{}\"'.format(output_dir))
+    print('Will write matches to directory "{}"'.format(output_dir))
     if opt.eval:
-        print('Will write evaluation results',
-              'to directory \"{}\"'.format(output_dir))
+        print("Will write evaluation results", 'to directory "{}"'.format(output_dir))
     if opt.viz:
-        print('Will write visualization images to',
-              'directory \"{}\"'.format(output_dir))
+        print("Will write visualization images to", 'directory "{}"'.format(output_dir))
 
     timer = AverageTimer(newline=True)
     for i, pair in enumerate(pairs):
         name0, name1 = pair[:2]
         stem0, stem1 = Path(name0).stem, Path(name1).stem
-        matches_path = output_dir / '{}_{}_matches.npz'.format(stem0, stem1)
-        eval_path = output_dir / '{}_{}_evaluation.npz'.format(stem0, stem1)
-        viz_path = output_dir / '{}_{}_matches.{}'.format(stem0, stem1, opt.viz_extension)
-        viz_eval_path = output_dir / \
-            '{}_{}_evaluation.{}'.format(stem0, stem1, opt.viz_extension)
+        matches_path = output_dir / "{}_{}_matches.npz".format(stem0, stem1)
+        eval_path = output_dir / "{}_{}_evaluation.npz".format(stem0, stem1)
+        viz_path = output_dir / "{}_{}_matches.{}".format(
+            stem0, stem1, opt.viz_extension
+        )
+        viz_eval_path = output_dir / "{}_{}_evaluation.{}".format(
+            stem0, stem1, opt.viz_extension
+        )
 
         # Handle --cache logic.
         do_match = True
@@ -225,31 +285,30 @@ if __name__ == '__main__':
                 try:
                     results = np.load(matches_path)
                 except:
-                    raise IOError('Cannot load matches .npz file: %s' %
-                                  matches_path)
+                    raise IOError("Cannot load matches .npz file: %s" % matches_path)
 
-                kpts0, kpts1 = results['keypoints0'], results['keypoints1']
-                matches, conf = results['matches'], results['match_confidence']
+                kpts0, kpts1 = results["keypoints0"], results["keypoints1"]
+                matches, conf = results["matches"], results["match_confidence"]
                 do_match = False
             if opt.eval and eval_path.exists():
                 try:
                     results = np.load(eval_path)
                 except:
-                    raise IOError('Cannot load eval .npz file: %s' % eval_path)
-                err_R, err_t = results['error_R'], results['error_t']
-                precision = results['precision']
-                matching_score = results['matching_score']
-                num_correct = results['num_correct']
-                epi_errs = results['epipolar_errors']
+                    raise IOError("Cannot load eval .npz file: %s" % eval_path)
+                err_R, err_t = results["error_R"], results["error_t"]
+                precision = results["precision"]
+                matching_score = results["matching_score"]
+                num_correct = results["num_correct"]
+                epi_errs = results["epipolar_errors"]
                 do_eval = False
             if opt.viz and viz_path.exists():
                 do_viz = False
             if opt.viz and opt.eval and viz_eval_path.exists():
                 do_viz_eval = False
-            timer.update('load_cache')
+            timer.update("load_cache")
 
         if not (do_match or do_eval or do_viz or do_viz_eval):
-            timer.print('Finished pair {:5} of {:5}'.format(i, len(pairs)))
+            timer.print("Finished pair {:5} of {:5}".format(i, len(pairs)))
             continue
 
         # If a rotation integer is provided (e.g. from EXIF data), use it:
@@ -260,26 +319,35 @@ if __name__ == '__main__':
 
         # Load the image pair.
         image0, inp0, scales0 = read_image(
-            input_dir / name0, device, opt.resize, rot0, opt.resize_float)
+            input_dir / name0, device, opt.resize, rot0, opt.resize_float
+        )
         image1, inp1, scales1 = read_image(
-            input_dir / name1, device, opt.resize, rot1, opt.resize_float)
+            input_dir / name1, device, opt.resize, rot1, opt.resize_float
+        )
         if image0 is None or image1 is None:
-            print('Problem reading image pair: {} {}'.format(
-                input_dir/name0, input_dir/name1))
+            print(
+                "Problem reading image pair: {} {}".format(
+                    input_dir / name0, input_dir / name1
+                )
+            )
             exit(1)
-        timer.update('load_image')
+        timer.update("load_image")
 
         if do_match:
             # Perform the matching.
-            pred = matching({'image0': inp0, 'image1': inp1})
+            pred = matching({"image0": inp0, "image1": inp1})
             pred = {k: v[0].cpu().numpy() for k, v in pred.items()}
-            kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
-            matches, conf = pred['matches0'], pred['matching_scores0']
-            timer.update('matcher')
+            kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
+            matches, conf = pred["matches0"], pred["matching_scores0"]
+            timer.update("matcher")
 
             # Write the matches to disk.
-            out_matches = {'keypoints0': kpts0, 'keypoints1': kpts1,
-                           'matches': matches, 'match_confidence': conf}
+            out_matches = {
+                "keypoints0": kpts0,
+                "keypoints1": kpts1,
+                "matches": matches,
+                "match_confidence": conf,
+            }
             np.savez(str(matches_path), **out_matches)
 
         # Keep the matching keypoints.
@@ -290,7 +358,7 @@ if __name__ == '__main__':
 
         if do_eval:
             # Estimate the pose and compute the pose error.
-            assert len(pair) == 38, 'Pair does not have ground truth info'
+            assert len(pair) == 38, "Pair does not have ground truth info"
             K0 = np.array(pair[4:13]).astype(float).reshape(3, 3)
             K1 = np.array(pair[13:22]).astype(float).reshape(3, 3)
             T_0to1 = np.array(pair[22:]).astype(float).reshape(4, 4)
@@ -318,7 +386,7 @@ if __name__ == '__main__':
             precision = np.mean(correct) if len(correct) > 0 else 0
             matching_score = num_correct / len(kpts0) if len(kpts0) > 0 else 0
 
-            thresh = 1.  # In pixels relative to resized image size.
+            thresh = 1.0  # In pixels relative to resized image size.
             ret = estimate_pose(mkpts0, mkpts1, K0, K1, thresh)
             if ret is None:
                 err_t, err_R = np.inf, np.inf
@@ -327,77 +395,103 @@ if __name__ == '__main__':
                 err_t, err_R = compute_pose_error(T_0to1, R, t)
 
             # Write the evaluation results to disk.
-            out_eval = {'error_t': err_t,
-                        'error_R': err_R,
-                        'precision': precision,
-                        'matching_score': matching_score,
-                        'num_correct': num_correct,
-                        'epipolar_errors': epi_errs}
+            out_eval = {
+                "error_t": err_t,
+                "error_R": err_R,
+                "precision": precision,
+                "matching_score": matching_score,
+                "num_correct": num_correct,
+                "epipolar_errors": epi_errs,
+            }
             np.savez(str(eval_path), **out_eval)
-            timer.update('eval')
+            timer.update("eval")
 
         if do_viz:
             # Visualize the matches.
             color = cm.jet(mconf)
             text = [
-                'SuperGlue',
-                'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)),
-                'Matches: {}'.format(len(mkpts0)),
+                "SuperGlue",
+                "Keypoints: {}:{}".format(len(kpts0), len(kpts1)),
+                "Matches: {}".format(len(mkpts0)),
             ]
             if rot0 != 0 or rot1 != 0:
-                text.append('Rotation: {}:{}'.format(rot0, rot1))
+                text.append("Rotation: {}:{}".format(rot0, rot1))
 
             # Display extra parameter info.
-            k_thresh = matching.superpoint.config['keypoint_threshold']
-            m_thresh = matching.superglue.config['match_threshold']
+            k_thresh = matching.superpoint.config["keypoint_threshold"]
+            m_thresh = matching.superglue.config["match_threshold"]
             small_text = [
-                'Keypoint Threshold: {:.4f}'.format(k_thresh),
-                'Match Threshold: {:.2f}'.format(m_thresh),
-                'Image Pair: {}:{}'.format(stem0, stem1),
+                "Keypoint Threshold: {:.4f}".format(k_thresh),
+                "Match Threshold: {:.2f}".format(m_thresh),
+                "Image Pair: {}:{}".format(stem0, stem1),
             ]
 
             make_matching_plot(
-                image0, image1, kpts0, kpts1, mkpts0, mkpts1, color,
-                text, viz_path, opt.show_keypoints,
-                opt.fast_viz, opt.opencv_display, 'Matches', small_text)
-
-            timer.update('viz_match')
+                image0,
+                image1,
+                kpts0,
+                kpts1,
+                mkpts0,
+                mkpts1,
+                color,
+                text,
+                viz_path,
+                opt.show_keypoints,
+                opt.fast_viz,
+                opt.opencv_display,
+                "Matches",
+                small_text,
+            )
+
+            timer.update("viz_match")
 
         if do_viz_eval:
             # Visualize the evaluation results for the image pair.
             color = np.clip((epi_errs - 0) / (1e-3 - 0), 0, 1)
             color = error_colormap(1 - color)
-            deg, delta = ' deg', 'Delta '
+            deg, delta = " deg", "Delta "
             if not opt.fast_viz:
-                deg, delta = '°', '$\\Delta$'
-            e_t = 'FAIL' if np.isinf(err_t) else '{:.1f}{}'.format(err_t, deg)
-            e_R = 'FAIL' if np.isinf(err_R) else '{:.1f}{}'.format(err_R, deg)
+                deg, delta = "°", "$\\Delta$"
+            e_t = "FAIL" if np.isinf(err_t) else "{:.1f}{}".format(err_t, deg)
+            e_R = "FAIL" if np.isinf(err_R) else "{:.1f}{}".format(err_R, deg)
             text = [
-                'SuperGlue',
-                '{}R: {}'.format(delta, e_R), '{}t: {}'.format(delta, e_t),
-                'inliers: {}/{}'.format(num_correct, (matches > -1).sum()),
+                "SuperGlue",
+                "{}R: {}".format(delta, e_R),
+                "{}t: {}".format(delta, e_t),
+                "inliers: {}/{}".format(num_correct, (matches > -1).sum()),
             ]
             if rot0 != 0 or rot1 != 0:
-                text.append('Rotation: {}:{}'.format(rot0, rot1))
+                text.append("Rotation: {}:{}".format(rot0, rot1))
 
             # Display extra parameter info (only works with --fast_viz).
-            k_thresh = matching.superpoint.config['keypoint_threshold']
-            m_thresh = matching.superglue.config['match_threshold']
+            k_thresh = matching.superpoint.config["keypoint_threshold"]
+            m_thresh = matching.superglue.config["match_threshold"]
             small_text = [
-                'Keypoint Threshold: {:.4f}'.format(k_thresh),
-                'Match Threshold: {:.2f}'.format(m_thresh),
-                'Image Pair: {}:{}'.format(stem0, stem1),
+                "Keypoint Threshold: {:.4f}".format(k_thresh),
+                "Match Threshold: {:.2f}".format(m_thresh),
+                "Image Pair: {}:{}".format(stem0, stem1),
             ]
 
             make_matching_plot(
-                image0, image1, kpts0, kpts1, mkpts0,
-                mkpts1, color, text, viz_eval_path,
-                opt.show_keypoints, opt.fast_viz,
-                opt.opencv_display, 'Relative Pose', small_text)
-
-            timer.update('viz_eval')
-
-        timer.print('Finished pair {:5} of {:5}'.format(i, len(pairs)))
+                image0,
+                image1,
+                kpts0,
+                kpts1,
+                mkpts0,
+                mkpts1,
+                color,
+                text,
+                viz_eval_path,
+                opt.show_keypoints,
+                opt.fast_viz,
+                opt.opencv_display,
+                "Relative Pose",
+                small_text,
+            )
+
+            timer.update("viz_eval")
+
+        timer.print("Finished pair {:5} of {:5}".format(i, len(pairs)))
 
     if opt.eval:
         # Collate the results into a final table and print to terminal.
@@ -407,19 +501,21 @@ if __name__ == '__main__':
         for pair in pairs:
             name0, name1 = pair[:2]
             stem0, stem1 = Path(name0).stem, Path(name1).stem
-            eval_path = output_dir / \
-                '{}_{}_evaluation.npz'.format(stem0, stem1)
+            eval_path = output_dir / "{}_{}_evaluation.npz".format(stem0, stem1)
             results = np.load(eval_path)
-            pose_error = np.maximum(results['error_t'], results['error_R'])
+            pose_error = np.maximum(results["error_t"], results["error_R"])
             pose_errors.append(pose_error)
-            precisions.append(results['precision'])
-            matching_scores.append(results['matching_score'])
+            precisions.append(results["precision"])
+            matching_scores.append(results["matching_score"])
         thresholds = [5, 10, 20]
         aucs = pose_auc(pose_errors, thresholds)
-        aucs = [100.*yy for yy in aucs]
-        prec = 100.*np.mean(precisions)
-        ms = 100.*np.mean(matching_scores)
-        print('Evaluation Results (mean over {} pairs):'.format(len(pairs)))
-        print('AUC@5\t AUC@10\t AUC@20\t Prec\t MScore\t')
-        print('{:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t'.format(
-            aucs[0], aucs[1], aucs[2], prec, ms))
+        aucs = [100.0 * yy for yy in aucs]
+        prec = 100.0 * np.mean(precisions)
+        ms = 100.0 * np.mean(matching_scores)
+        print("Evaluation Results (mean over {} pairs):".format(len(pairs)))
+        print("AUC@5\t AUC@10\t AUC@20\t Prec\t MScore\t")
+        print(
+            "{:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t".format(
+                aucs[0], aucs[1], aucs[2], prec, ms
+            )
+        )
diff --git a/third_party/SuperGluePretrainedNetwork/models/matching.py b/third_party/SuperGluePretrainedNetwork/models/matching.py
index 5d174208d146373230a8a68dd1420fc59c180633..c5c0eda3337d021464eb6283e57b7412c08afb03 100644
--- a/third_party/SuperGluePretrainedNetwork/models/matching.py
+++ b/third_party/SuperGluePretrainedNetwork/models/matching.py
@@ -47,14 +47,15 @@ from .superglue import SuperGlue
 
 
 class Matching(torch.nn.Module):
-    """ Image Matching Frontend (SuperPoint + SuperGlue) """
+    """Image Matching Frontend (SuperPoint + SuperGlue)"""
+
     def __init__(self, config={}):
         super().__init__()
-        self.superpoint = SuperPoint(config.get('superpoint', {}))
-        self.superglue = SuperGlue(config.get('superglue', {}))
+        self.superpoint = SuperPoint(config.get("superpoint", {}))
+        self.superglue = SuperGlue(config.get("superglue", {}))
 
     def forward(self, data):
-        """ Run SuperPoint (optionally) and SuperGlue
+        """Run SuperPoint (optionally) and SuperGlue
         SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input
         Args:
           data: dictionary with minimal keys: ['image0', 'image1']
@@ -62,12 +63,12 @@ class Matching(torch.nn.Module):
         pred = {}
 
         # Extract SuperPoint (keypoints, scores, descriptors) if not provided
-        if 'keypoints0' not in data:
-            pred0 = self.superpoint({'image': data['image0']})
-            pred = {**pred, **{k+'0': v for k, v in pred0.items()}}
-        if 'keypoints1' not in data:
-            pred1 = self.superpoint({'image': data['image1']})
-            pred = {**pred, **{k+'1': v for k, v in pred1.items()}}
+        if "keypoints0" not in data:
+            pred0 = self.superpoint({"image": data["image0"]})
+            pred = {**pred, **{k + "0": v for k, v in pred0.items()}}
+        if "keypoints1" not in data:
+            pred1 = self.superpoint({"image": data["image1"]})
+            pred = {**pred, **{k + "1": v for k, v in pred1.items()}}
 
         # Batch all features
         # We should either have i) one image per batch, or
diff --git a/third_party/SuperGluePretrainedNetwork/models/superglue.py b/third_party/SuperGluePretrainedNetwork/models/superglue.py
index 5316234dee9be9cdc083e3b4bebe97a6e51e587d..70156e07b83614b1dfb36207ea96b4b79a6ddbb9 100644
--- a/third_party/SuperGluePretrainedNetwork/models/superglue.py
+++ b/third_party/SuperGluePretrainedNetwork/models/superglue.py
@@ -49,13 +49,12 @@ from torch import nn
 
 
 def MLP(channels: List[int], do_bn: bool = True) -> nn.Module:
-    """ Multi-layer perceptron """
+    """Multi-layer perceptron"""
     n = len(channels)
     layers = []
     for i in range(1, n):
-        layers.append(
-            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
-        if i < (n-1):
+        layers.append(nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
+        if i < (n - 1):
             if do_bn:
                 layers.append(nn.BatchNorm1d(channels[i]))
             layers.append(nn.ReLU())
@@ -63,17 +62,18 @@ def MLP(channels: List[int], do_bn: bool = True) -> nn.Module:
 
 
 def normalize_keypoints(kpts, image_shape):
-    """ Normalize keypoints locations based on image image_shape"""
+    """Normalize keypoints locations based on image image_shape"""
     _, _, height, width = image_shape
     one = kpts.new_tensor(1)
-    size = torch.stack([one*width, one*height])[None]
+    size = torch.stack([one * width, one * height])[None]
     center = size / 2
     scaling = size.max(1, keepdim=True).values * 0.7
     return (kpts - center[:, None, :]) / scaling[:, None, :]
 
 
 class KeypointEncoder(nn.Module):
-    """ Joint encoding of visual appearance and location using MLPs"""
+    """Joint encoding of visual appearance and location using MLPs"""
+
     def __init__(self, feature_dim: int, layers: List[int]) -> None:
         super().__init__()
         self.encoder = MLP([3] + layers + [feature_dim])
@@ -84,15 +84,18 @@ class KeypointEncoder(nn.Module):
         return self.encoder(torch.cat(inputs, dim=1))
 
 
-def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
+def attention(
+    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
     dim = query.shape[1]
-    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
+    scores = torch.einsum("bdhn,bdhm->bhnm", query, key) / dim**0.5
     prob = torch.nn.functional.softmax(scores, dim=-1)
-    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
+    return torch.einsum("bhnm,bdhm->bdhn", prob, value), prob
 
 
 class MultiHeadedAttention(nn.Module):
-    """ Multi-head attention to increase model expressivitiy """
+    """Multi-head attention to increase model expressivitiy"""
+
     def __init__(self, num_heads: int, d_model: int):
         super().__init__()
         assert d_model % num_heads == 0
@@ -101,19 +104,23 @@ class MultiHeadedAttention(nn.Module):
         self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
         self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
 
-    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> torch.Tensor:
         batch_dim = query.size(0)
-        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
-                             for l, x in zip(self.proj, (query, key, value))]
+        query, key, value = [
+            l(x).view(batch_dim, self.dim, self.num_heads, -1)
+            for l, x in zip(self.proj, (query, key, value))
+        ]
         x, _ = attention(query, key, value)
-        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
+        return self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1))
 
 
 class AttentionalPropagation(nn.Module):
     def __init__(self, feature_dim: int, num_heads: int):
         super().__init__()
         self.attn = MultiHeadedAttention(num_heads, feature_dim)
-        self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
+        self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim])
         nn.init.constant_(self.mlp[-1].bias, 0.0)
 
     def forward(self, x: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
@@ -124,14 +131,16 @@ class AttentionalPropagation(nn.Module):
 class AttentionalGNN(nn.Module):
     def __init__(self, feature_dim: int, layer_names: List[str]) -> None:
         super().__init__()
-        self.layers = nn.ModuleList([
-            AttentionalPropagation(feature_dim, 4)
-            for _ in range(len(layer_names))])
+        self.layers = nn.ModuleList(
+            [AttentionalPropagation(feature_dim, 4) for _ in range(len(layer_names))]
+        )
         self.names = layer_names
 
-    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
+    def forward(
+        self, desc0: torch.Tensor, desc1: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         for layer, name in zip(self.layers, self.names):
-            if name == 'cross':
+            if name == "cross":
                 src0, src1 = desc1, desc0
             else:  # if name == 'self':
                 src0, src1 = desc0, desc1
@@ -140,8 +149,10 @@ class AttentionalGNN(nn.Module):
         return desc0, desc1
 
 
-def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor:
-    """ Perform Sinkhorn Normalization in Log-space for stability"""
+def log_sinkhorn_iterations(
+    Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int
+) -> torch.Tensor:
+    """Perform Sinkhorn Normalization in Log-space for stability"""
     u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
     for _ in range(iters):
         u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
@@ -149,20 +160,23 @@ def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch
     return Z + u.unsqueeze(2) + v.unsqueeze(1)
 
 
-def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
-    """ Perform Differentiable Optimal Transport in Log-space for stability"""
+def log_optimal_transport(
+    scores: torch.Tensor, alpha: torch.Tensor, iters: int
+) -> torch.Tensor:
+    """Perform Differentiable Optimal Transport in Log-space for stability"""
     b, m, n = scores.shape
     one = scores.new_tensor(1)
-    ms, ns = (m*one).to(scores), (n*one).to(scores)
+    ms, ns = (m * one).to(scores), (n * one).to(scores)
 
     bins0 = alpha.expand(b, m, 1)
     bins1 = alpha.expand(b, 1, n)
     alpha = alpha.expand(b, 1, 1)
 
-    couplings = torch.cat([torch.cat([scores, bins0], -1),
-                           torch.cat([bins1, alpha], -1)], 1)
+    couplings = torch.cat(
+        [torch.cat([scores, bins0], -1), torch.cat([bins1, alpha], -1)], 1
+    )
 
-    norm = - (ms + ns).log()
+    norm = -(ms + ns).log()
     log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])
     log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
     log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
@@ -194,13 +208,14 @@ class SuperGlue(nn.Module):
     Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763
 
     """
+
     default_config = {
-        'descriptor_dim': 256,
-        'weights': 'indoor',
-        'keypoint_encoder': [32, 64, 128, 256],
-        'GNN_layers': ['self', 'cross'] * 9,
-        'sinkhorn_iterations': 100,
-        'match_threshold': 0.2,
+        "descriptor_dim": 256,
+        "weights": "indoor",
+        "keypoint_encoder": [32, 64, 128, 256],
+        "GNN_layers": ["self", "cross"] * 9,
+        "sinkhorn_iterations": 100,
+        "match_threshold": 0.2,
     }
 
     def __init__(self, config):
@@ -208,46 +223,51 @@ class SuperGlue(nn.Module):
         self.config = {**self.default_config, **config}
 
         self.kenc = KeypointEncoder(
-            self.config['descriptor_dim'], self.config['keypoint_encoder'])
+            self.config["descriptor_dim"], self.config["keypoint_encoder"]
+        )
 
         self.gnn = AttentionalGNN(
-            feature_dim=self.config['descriptor_dim'], layer_names=self.config['GNN_layers'])
+            feature_dim=self.config["descriptor_dim"],
+            layer_names=self.config["GNN_layers"],
+        )
 
         self.final_proj = nn.Conv1d(
-            self.config['descriptor_dim'], self.config['descriptor_dim'],
-            kernel_size=1, bias=True)
+            self.config["descriptor_dim"],
+            self.config["descriptor_dim"],
+            kernel_size=1,
+            bias=True,
+        )
 
-        bin_score = torch.nn.Parameter(torch.tensor(1.))
-        self.register_parameter('bin_score', bin_score)
+        bin_score = torch.nn.Parameter(torch.tensor(1.0))
+        self.register_parameter("bin_score", bin_score)
 
-        assert self.config['weights'] in ['indoor', 'outdoor']
+        assert self.config["weights"] in ["indoor", "outdoor"]
         path = Path(__file__).parent
-        path = path / 'weights/superglue_{}.pth'.format(self.config['weights'])
+        path = path / "weights/superglue_{}.pth".format(self.config["weights"])
         self.load_state_dict(torch.load(str(path)))
-        print('Loaded SuperGlue model (\"{}\" weights)'.format(
-            self.config['weights']))
+        print('Loaded SuperGlue model ("{}" weights)'.format(self.config["weights"]))
 
     def forward(self, data):
         """Run SuperGlue on a pair of keypoints and descriptors"""
-        desc0, desc1 = data['descriptors0'], data['descriptors1']
-        kpts0, kpts1 = data['keypoints0'], data['keypoints1']
+        desc0, desc1 = data["descriptors0"], data["descriptors1"]
+        kpts0, kpts1 = data["keypoints0"], data["keypoints1"]
 
         if kpts0.shape[1] == 0 or kpts1.shape[1] == 0:  # no keypoints
             shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
             return {
-                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int),
-                'matches1': kpts1.new_full(shape1, -1, dtype=torch.int),
-                'matching_scores0': kpts0.new_zeros(shape0),
-                'matching_scores1': kpts1.new_zeros(shape1),
+                "matches0": kpts0.new_full(shape0, -1, dtype=torch.int),
+                "matches1": kpts1.new_full(shape1, -1, dtype=torch.int),
+                "matching_scores0": kpts0.new_zeros(shape0),
+                "matching_scores1": kpts1.new_zeros(shape1),
             }
 
         # Keypoint normalization.
-        kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
-        kpts1 = normalize_keypoints(kpts1, data['image1'].shape)
+        kpts0 = normalize_keypoints(kpts0, data["image0"].shape)
+        kpts1 = normalize_keypoints(kpts1, data["image1"].shape)
 
         # Keypoint MLP encoder.
-        desc0 = desc0 + self.kenc(kpts0, data['scores0'])
-        desc1 = desc1 + self.kenc(kpts1, data['scores1'])
+        desc0 = desc0 + self.kenc(kpts0, data["scores0"])
+        desc1 = desc1 + self.kenc(kpts1, data["scores1"])
 
         # Multi-layer Transformer network.
         desc0, desc1 = self.gnn(desc0, desc1)
@@ -256,13 +276,13 @@ class SuperGlue(nn.Module):
         mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
 
         # Compute matching descriptor distance.
-        scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
-        scores = scores / self.config['descriptor_dim']**.5
+        scores = torch.einsum("bdn,bdm->bnm", mdesc0, mdesc1)
+        scores = scores / self.config["descriptor_dim"] ** 0.5
 
         # Run the optimal transport.
         scores = log_optimal_transport(
-            scores, self.bin_score,
-            iters=self.config['sinkhorn_iterations'])
+            scores, self.bin_score, iters=self.config["sinkhorn_iterations"]
+        )
 
         # Get the matches with score above "match_threshold".
         max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
@@ -272,13 +292,13 @@ class SuperGlue(nn.Module):
         zero = scores.new_tensor(0)
         mscores0 = torch.where(mutual0, max0.values.exp(), zero)
         mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
-        valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
+        valid0 = mutual0 & (mscores0 > self.config["match_threshold"])
         valid1 = mutual1 & valid0.gather(1, indices1)
         indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
         indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
         return {
-            'matches0': indices0, # use -1 for invalid match
-            'matches1': indices1, # use -1 for invalid match
-            'matching_scores0': mscores0,
-            'matching_scores1': mscores1,
+            "matches0": indices0,  # use -1 for invalid match
+            "matches1": indices1,  # use -1 for invalid match
+            "matching_scores0": mscores0,
+            "matching_scores1": mscores1,
         }
diff --git a/third_party/SuperGluePretrainedNetwork/models/superpoint.py b/third_party/SuperGluePretrainedNetwork/models/superpoint.py
index b837d938f755850180ddc168e957742e874adacd..ab9712eed30ea30f1578cabb97c0c8f2fbed8c7c 100644
--- a/third_party/SuperGluePretrainedNetwork/models/superpoint.py
+++ b/third_party/SuperGluePretrainedNetwork/models/superpoint.py
@@ -44,13 +44,15 @@ from pathlib import Path
 import torch
 from torch import nn
 
+
 def simple_nms(scores, nms_radius: int):
-    """ Fast Non-maximum suppression to remove nearby points """
-    assert(nms_radius >= 0)
+    """Fast Non-maximum suppression to remove nearby points"""
+    assert nms_radius >= 0
 
     def max_pool(x):
         return torch.nn.functional.max_pool2d(
-            x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)
+            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
+        )
 
     zeros = torch.zeros_like(scores)
     max_mask = scores == max_pool(scores)
@@ -63,7 +65,7 @@ def simple_nms(scores, nms_radius: int):
 
 
 def remove_borders(keypoints, scores, border: int, height: int, width: int):
-    """ Removes keypoints too close to the border """
+    """Removes keypoints too close to the border"""
     mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
     mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
     mask = mask_h & mask_w
@@ -78,17 +80,20 @@ def top_k_keypoints(keypoints, scores, k: int):
 
 
 def sample_descriptors(keypoints, descriptors, s: int = 8):
-    """ Interpolate descriptors at keypoint locations """
+    """Interpolate descriptors at keypoint locations"""
     b, c, h, w = descriptors.shape
     keypoints = keypoints - s / 2 + 0.5
-    keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
-                              ).to(keypoints)[None]
-    keypoints = keypoints*2 - 1  # normalize to (-1, 1)
-    args = {'align_corners': True} if torch.__version__ >= '1.3' else {}
+    keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)]).to(
+        keypoints
+    )[None]
+    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
+    args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
     descriptors = torch.nn.functional.grid_sample(
-        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
+        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
+    )
     descriptors = torch.nn.functional.normalize(
-        descriptors.reshape(b, c, -1), p=2, dim=1)
+        descriptors.reshape(b, c, -1), p=2, dim=1
+    )
     return descriptors
 
 
@@ -100,12 +105,13 @@ class SuperPoint(nn.Module):
     Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
 
     """
+
     default_config = {
-        'descriptor_dim': 256,
-        'nms_radius': 4,
-        'keypoint_threshold': 0.005,
-        'max_keypoints': -1,
-        'remove_borders': 4,
+        "descriptor_dim": 256,
+        "nms_radius": 4,
+        "keypoint_threshold": 0.005,
+        "max_keypoints": -1,
+        "remove_borders": 4,
     }
 
     def __init__(self, config):
@@ -130,22 +136,22 @@ class SuperPoint(nn.Module):
 
         self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
         self.convDb = nn.Conv2d(
-            c5, self.config['descriptor_dim'],
-            kernel_size=1, stride=1, padding=0)
+            c5, self.config["descriptor_dim"], kernel_size=1, stride=1, padding=0
+        )
 
-        path = Path(__file__).parent / 'weights/superpoint_v1.pth'
+        path = Path(__file__).parent / "weights/superpoint_v1.pth"
         self.load_state_dict(torch.load(str(path)))
 
-        mk = self.config['max_keypoints']
+        mk = self.config["max_keypoints"]
         if mk == 0 or mk < -1:
-            raise ValueError('\"max_keypoints\" must be positive or \"-1\"')
+            raise ValueError('"max_keypoints" must be positive or "-1"')
 
-        print('Loaded SuperPoint model')
+        print("Loaded SuperPoint model")
 
     def forward(self, data):
-        """ Compute keypoints, scores, descriptors for image """
+        """Compute keypoints, scores, descriptors for image"""
         # Shared Encoder
-        x = self.relu(self.conv1a(data['image']))
+        x = self.relu(self.conv1a(data["image"]))
         x = self.relu(self.conv1b(x))
         x = self.pool(x)
         x = self.relu(self.conv2a(x))
@@ -163,25 +169,35 @@ class SuperPoint(nn.Module):
         scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
         b, _, h, w = scores.shape
         scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
-        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
-        scores = simple_nms(scores, self.config['nms_radius'])
+        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
+        scores = simple_nms(scores, self.config["nms_radius"])
 
         # Extract keypoints
         keypoints = [
-            torch.nonzero(s > self.config['keypoint_threshold'])
-            for s in scores]
+            torch.nonzero(s > self.config["keypoint_threshold"]) for s in scores
+        ]
         scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
 
         # Discard keypoints near the image borders
-        keypoints, scores = list(zip(*[
-            remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
-            for k, s in zip(keypoints, scores)]))
+        keypoints, scores = list(
+            zip(
+                *[
+                    remove_borders(k, s, self.config["remove_borders"], h * 8, w * 8)
+                    for k, s in zip(keypoints, scores)
+                ]
+            )
+        )
 
         # Keep the k keypoints with highest score
-        if self.config['max_keypoints'] >= 0:
-            keypoints, scores = list(zip(*[
-                top_k_keypoints(k, s, self.config['max_keypoints'])
-                for k, s in zip(keypoints, scores)]))
+        if self.config["max_keypoints"] >= 0:
+            keypoints, scores = list(
+                zip(
+                    *[
+                        top_k_keypoints(k, s, self.config["max_keypoints"])
+                        for k, s in zip(keypoints, scores)
+                    ]
+                )
+            )
 
         # Convert (h, w) to (x, y)
         keypoints = [torch.flip(k, [1]).float() for k in keypoints]
@@ -192,11 +208,13 @@ class SuperPoint(nn.Module):
         descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
 
         # Extract descriptors
-        descriptors = [sample_descriptors(k[None], d[None], 8)[0]
-                       for k, d in zip(keypoints, descriptors)]
+        descriptors = [
+            sample_descriptors(k[None], d[None], 8)[0]
+            for k, d in zip(keypoints, descriptors)
+        ]
 
         return {
-            'keypoints': keypoints,
-            'scores': scores,
-            'descriptors': descriptors,
+            "keypoints": keypoints,
+            "scores": scores,
+            "descriptors": descriptors,
         }
diff --git a/third_party/SuperGluePretrainedNetwork/models/utils.py b/third_party/SuperGluePretrainedNetwork/models/utils.py
index 1206244aa2a004d9f653782de798bfef9e5e726b..d302ff84cf316f3dad016f1f23bbb54518566d2e 100644
--- a/third_party/SuperGluePretrainedNetwork/models/utils.py
+++ b/third_party/SuperGluePretrainedNetwork/models/utils.py
@@ -51,11 +51,12 @@ import cv2
 import torch
 import matplotlib.pyplot as plt
 import matplotlib
-matplotlib.use('Agg')
+
+matplotlib.use("Agg")
 
 
 class AverageTimer:
-    """ Class to help manage printing simple timing of code execution. """
+    """Class to help manage printing simple timing of code execution."""
 
     def __init__(self, smoothing=0.3, newline=False):
         self.smoothing = smoothing
@@ -71,7 +72,7 @@ class AverageTimer:
         for name in self.will_print:
             self.will_print[name] = False
 
-    def update(self, name='default'):
+    def update(self, name="default"):
         now = time.time()
         dt = now - self.last_time
         if name in self.times:
@@ -80,29 +81,30 @@ class AverageTimer:
         self.will_print[name] = True
         self.last_time = now
 
-    def print(self, text='Timer'):
-        total = 0.
-        print('[{}]'.format(text), end=' ')
+    def print(self, text="Timer"):
+        total = 0.0
+        print("[{}]".format(text), end=" ")
         for key in self.times:
             val = self.times[key]
             if self.will_print[key]:
-                print('%s=%.3f' % (key, val), end=' ')
+                print("%s=%.3f" % (key, val), end=" ")
                 total += val
-        print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ')
+        print("total=%.3f sec {%.1f FPS}" % (total, 1.0 / total), end=" ")
         if self.newline:
             print(flush=True)
         else:
-            print(end='\r', flush=True)
+            print(end="\r", flush=True)
         self.reset()
 
 
 class VideoStreamer:
-    """ Class to help process image streams. Four types of possible inputs:"
-        1.) USB Webcam.
-        2.) An IP camera
-        3.) A directory of images (files in directory matching 'image_glob').
-        4.) A video file, such as an .mp4 or .avi file.
+    """Class to help process image streams. Four types of possible inputs:"
+    1.) USB Webcam.
+    2.) An IP camera
+    3.) A directory of images (files in directory matching 'image_glob').
+    4.) A video file, such as an .mp4 or .avi file.
     """
+
     def __init__(self, basedir, resize, skip, image_glob, max_length=1000000):
         self._ip_grabbed = False
         self._ip_running = False
@@ -119,45 +121,45 @@ class VideoStreamer:
         self.skip = skip
         self.max_length = max_length
         if isinstance(basedir, int) or basedir.isdigit():
-            print('==> Processing USB webcam input: {}'.format(basedir))
+            print("==> Processing USB webcam input: {}".format(basedir))
             self.cap = cv2.VideoCapture(int(basedir))
             self.listing = range(0, self.max_length)
-        elif basedir.startswith(('http', 'rtsp')):
-            print('==> Processing IP camera input: {}'.format(basedir))
+        elif basedir.startswith(("http", "rtsp")):
+            print("==> Processing IP camera input: {}".format(basedir))
             self.cap = cv2.VideoCapture(basedir)
             self.start_ip_camera_thread()
             self._ip_camera = True
             self.listing = range(0, self.max_length)
         elif Path(basedir).is_dir():
-            print('==> Processing image directory input: {}'.format(basedir))
+            print("==> Processing image directory input: {}".format(basedir))
             self.listing = list(Path(basedir).glob(image_glob[0]))
             for j in range(1, len(image_glob)):
                 image_path = list(Path(basedir).glob(image_glob[j]))
                 self.listing = self.listing + image_path
             self.listing.sort()
-            self.listing = self.listing[::self.skip]
+            self.listing = self.listing[:: self.skip]
             self.max_length = np.min([self.max_length, len(self.listing)])
             if self.max_length == 0:
-                raise IOError('No images found (maybe bad \'image_glob\' ?)')
-            self.listing = self.listing[:self.max_length]
+                raise IOError("No images found (maybe bad 'image_glob' ?)")
+            self.listing = self.listing[: self.max_length]
             self.camera = False
         elif Path(basedir).exists():
-            print('==> Processing video input: {}'.format(basedir))
+            print("==> Processing video input: {}".format(basedir))
             self.cap = cv2.VideoCapture(basedir)
             self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
             num_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
             self.listing = range(0, num_frames)
-            self.listing = self.listing[::self.skip]
+            self.listing = self.listing[:: self.skip]
             self.video_file = True
             self.max_length = np.min([self.max_length, len(self.listing)])
-            self.listing = self.listing[:self.max_length]
+            self.listing = self.listing[: self.max_length]
         else:
-            raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir))
+            raise ValueError('VideoStreamer input "{}" not recognized.'.format(basedir))
         if self.camera and not self.cap.isOpened():
-            raise IOError('Could not read camera')
+            raise IOError("Could not read camera")
 
     def load_image(self, impath):
-        """ Read image as grayscale and resize to img_size.
+        """Read image as grayscale and resize to img_size.
         Inputs
             impath: Path to input image.
         Returns
@@ -165,15 +167,14 @@ class VideoStreamer:
         """
         grayim = cv2.imread(impath, 0)
         if grayim is None:
-            raise Exception('Error reading image %s' % impath)
+            raise Exception("Error reading image %s" % impath)
         w, h = grayim.shape[1], grayim.shape[0]
         w_new, h_new = process_resize(w, h, self.resize)
-        grayim = cv2.resize(
-            grayim, (w_new, h_new), interpolation=self.interp)
+        grayim = cv2.resize(grayim, (w_new, h_new), interpolation=self.interp)
         return grayim
 
     def next_frame(self):
-        """ Return the next frame, and increment internal counter.
+        """Return the next frame, and increment internal counter.
         Returns
              image: Next H x W image.
              status: True or False depending whether image was loaded.
@@ -184,9 +185,9 @@ class VideoStreamer:
         if self.camera:
 
             if self._ip_camera:
-                #Wait for first image, making sure we haven't exited
+                # Wait for first image, making sure we haven't exited
                 while self._ip_grabbed is False and self._ip_exited is False:
-                    time.sleep(.001)
+                    time.sleep(0.001)
 
                 ret, image = self._ip_grabbed, self._ip_image.copy()
                 if ret is False:
@@ -194,15 +195,14 @@ class VideoStreamer:
             else:
                 ret, image = self.cap.read()
             if ret is False:
-                print('VideoStreamer: Cannot get image from camera')
+                print("VideoStreamer: Cannot get image from camera")
                 return (None, False)
             w, h = image.shape[1], image.shape[0]
             if self.video_file:
                 self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.listing[self.i])
 
             w_new, h_new = process_resize(w, h, self.resize)
-            image = cv2.resize(image, (w_new, h_new),
-                               interpolation=self.interp)
+            image = cv2.resize(image, (w_new, h_new), interpolation=self.interp)
             image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
         else:
             image_file = str(self.listing[self.i])
@@ -229,19 +229,20 @@ class VideoStreamer:
             self._ip_image = img
             self._ip_grabbed = ret
             self._ip_index += 1
-            #print('IPCAMERA THREAD got frame {}'.format(self._ip_index))
-
+            # print('IPCAMERA THREAD got frame {}'.format(self._ip_index))
 
     def cleanup(self):
         self._ip_running = False
 
+
 # --- PREPROCESSING ---
 
+
 def process_resize(w, h, resize):
-    assert(len(resize) > 0 and len(resize) <= 2)
+    assert len(resize) > 0 and len(resize) <= 2
     if len(resize) == 1 and resize[0] > -1:
         scale = resize[0] / max(h, w)
-        w_new, h_new = int(round(w*scale)), int(round(h*scale))
+        w_new, h_new = int(round(w * scale)), int(round(h * scale))
     elif len(resize) == 1 and resize[0] == -1:
         w_new, h_new = w, h
     else:  # len(resize) == 2:
@@ -249,15 +250,15 @@ def process_resize(w, h, resize):
 
     # Issue warning if resolution is too small or too large.
     if max(w_new, h_new) < 160:
-        print('Warning: input resolution is very small, results may vary')
+        print("Warning: input resolution is very small, results may vary")
     elif max(w_new, h_new) > 2000:
-        print('Warning: input resolution is very large, results may vary')
+        print("Warning: input resolution is very large, results may vary")
 
     return w_new, h_new
 
 
 def frame2tensor(frame, device):
-    return torch.from_numpy(frame/255.).float()[None, None].to(device)
+    return torch.from_numpy(frame / 255.0).float()[None, None].to(device)
 
 
 def read_image(path, device, resize, rotation, resize_float):
@@ -269,9 +270,9 @@ def read_image(path, device, resize, rotation, resize_float):
     scales = (float(w) / float(w_new), float(h) / float(h_new))
 
     if resize_float:
-        image = cv2.resize(image.astype('float32'), (w_new, h_new))
+        image = cv2.resize(image.astype("float32"), (w_new, h_new))
     else:
-        image = cv2.resize(image, (w_new, h_new)).astype('float32')
+        image = cv2.resize(image, (w_new, h_new)).astype("float32")
 
     if rotation != 0:
         image = np.rot90(image, k=rotation)
@@ -296,16 +297,15 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
     kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
 
     E, mask = cv2.findEssentialMat(
-        kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf,
-        method=cv2.RANSAC)
+        kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC
+    )
 
     assert E is not None
 
     best_num_inliers = 0
     ret = None
     for _E in np.split(E, len(E) / 3):
-        n, R, t, _ = cv2.recoverPose(
-            _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+        n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
         if n > best_num_inliers:
             best_num_inliers = n
             ret = (R, t[:, 0], mask.ravel() > 0)
@@ -315,36 +315,42 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
 def rotate_intrinsics(K, image_shape, rot):
     """image_shape is the shape of the image after rotation"""
     assert rot <= 3
-    h, w = image_shape[:2][::-1 if (rot % 2) else 1]
+    h, w = image_shape[:2][:: -1 if (rot % 2) else 1]
     fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
     rot = rot % 4
     if rot == 1:
-        return np.array([[fy, 0., cy],
-                         [0., fx, w-1-cx],
-                         [0., 0., 1.]], dtype=K.dtype)
+        return np.array(
+            [[fy, 0.0, cy], [0.0, fx, w - 1 - cx], [0.0, 0.0, 1.0]], dtype=K.dtype
+        )
     elif rot == 2:
-        return np.array([[fx, 0., w-1-cx],
-                         [0., fy, h-1-cy],
-                         [0., 0., 1.]], dtype=K.dtype)
+        return np.array(
+            [[fx, 0.0, w - 1 - cx], [0.0, fy, h - 1 - cy], [0.0, 0.0, 1.0]],
+            dtype=K.dtype,
+        )
     else:  # if rot == 3:
-        return np.array([[fy, 0., h-1-cy],
-                         [0., fx, cx],
-                         [0., 0., 1.]], dtype=K.dtype)
+        return np.array(
+            [[fy, 0.0, h - 1 - cy], [0.0, fx, cx], [0.0, 0.0, 1.0]], dtype=K.dtype
+        )
 
 
 def rotate_pose_inplane(i_T_w, rot):
     rotation_matrices = [
-        np.array([[np.cos(r), -np.sin(r), 0., 0.],
-                  [np.sin(r), np.cos(r), 0., 0.],
-                  [0., 0., 1., 0.],
-                  [0., 0., 0., 1.]], dtype=np.float32)
+        np.array(
+            [
+                [np.cos(r), -np.sin(r), 0.0, 0.0],
+                [np.sin(r), np.cos(r), 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0],
+                [0.0, 0.0, 0.0, 1.0],
+            ],
+            dtype=np.float32,
+        )
         for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
     ]
     return np.dot(rotation_matrices[rot], i_T_w)
 
 
 def scale_intrinsics(K, scales):
-    scales = np.diag([1./scales[0], 1./scales[1], 1.])
+    scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
     return np.dot(scales, K)
 
 
@@ -359,24 +365,22 @@ def compute_epipolar_error(kpts0, kpts1, T_0to1, K0, K1):
     kpts1 = to_homogeneous(kpts1)
 
     t0, t1, t2 = T_0to1[:3, 3]
-    t_skew = np.array([
-        [0, -t2, t1],
-        [t2, 0, -t0],
-        [-t1, t0, 0]
-    ])
+    t_skew = np.array([[0, -t2, t1], [t2, 0, -t0], [-t1, t0, 0]])
     E = t_skew @ T_0to1[:3, :3]
 
     Ep0 = kpts0 @ E.T  # N x 3
     p1Ep0 = np.sum(kpts1 * Ep0, -1)  # N
     Etp1 = kpts1 @ E  # N x 3
-    d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2)
-                    + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2))
+    d = p1Ep0**2 * (
+        1.0 / (Ep0[:, 0] ** 2 + Ep0[:, 1] ** 2)
+        + 1.0 / (Etp1[:, 0] ** 2 + Etp1[:, 1] ** 2)
+    )
     return d
 
 
 def angle_error_mat(R1, R2):
     cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
-    cos = np.clip(cos, -1., 1.)  # numercial errors can make it out of bounds
+    cos = np.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
     return np.rad2deg(np.abs(np.arccos(cos)))
 
 
@@ -398,27 +402,27 @@ def pose_auc(errors, thresholds):
     sort_idx = np.argsort(errors)
     errors = np.array(errors.copy())[sort_idx]
     recall = (np.arange(len(errors)) + 1) / len(errors)
-    errors = np.r_[0., errors]
-    recall = np.r_[0., recall]
+    errors = np.r_[0.0, errors]
+    recall = np.r_[0.0, recall]
     aucs = []
     for t in thresholds:
         last_index = np.searchsorted(errors, t)
-        r = np.r_[recall[:last_index], recall[last_index-1]]
+        r = np.r_[recall[:last_index], recall[last_index - 1]]
         e = np.r_[errors[:last_index], t]
-        aucs.append(np.trapz(r, x=e)/t)
+        aucs.append(np.trapz(r, x=e) / t)
     return aucs
 
 
 # --- VISUALIZATION ---
 
 
-def plot_image_pair(imgs, dpi=100, size=6, pad=.5):
+def plot_image_pair(imgs, dpi=100, size=6, pad=0.5):
     n = len(imgs)
-    assert n == 2, 'number of images must be two'
-    figsize = (size*n, size*3/4) if size is not None else None
+    assert n == 2, "number of images must be two"
+    figsize = (size * n, size * 3 / 4) if size is not None else None
     _, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
     for i in range(n):
-        ax[i].imshow(imgs[i], cmap=plt.get_cmap('gray'), vmin=0, vmax=255)
+        ax[i].imshow(imgs[i], cmap=plt.get_cmap("gray"), vmin=0, vmax=255)
         ax[i].get_yaxis().set_ticks([])
         ax[i].get_xaxis().set_ticks([])
         for spine in ax[i].spines.values():  # remove frame
@@ -426,7 +430,7 @@ def plot_image_pair(imgs, dpi=100, size=6, pad=.5):
     plt.tight_layout(pad=pad)
 
 
-def plot_keypoints(kpts0, kpts1, color='w', ps=2):
+def plot_keypoints(kpts0, kpts1, color="w", ps=2):
     ax = plt.gcf().axes
     ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
     ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
@@ -441,59 +445,116 @@ def plot_matches(kpts0, kpts1, color, lw=1.5, ps=4):
     fkpts0 = transFigure.transform(ax[0].transData.transform(kpts0))
     fkpts1 = transFigure.transform(ax[1].transData.transform(kpts1))
 
-    fig.lines = [matplotlib.lines.Line2D(
-        (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), zorder=1,
-        transform=fig.transFigure, c=color[i], linewidth=lw)
-                 for i in range(len(kpts0))]
+    fig.lines = [
+        matplotlib.lines.Line2D(
+            (fkpts0[i, 0], fkpts1[i, 0]),
+            (fkpts0[i, 1], fkpts1[i, 1]),
+            zorder=1,
+            transform=fig.transFigure,
+            c=color[i],
+            linewidth=lw,
+        )
+        for i in range(len(kpts0))
+    ]
     ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
     ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
 
 
-def make_matching_plot(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
-                       color, text, path, show_keypoints=False,
-                       fast_viz=False, opencv_display=False,
-                       opencv_title='matches', small_text=[]):
+def make_matching_plot(
+    image0,
+    image1,
+    kpts0,
+    kpts1,
+    mkpts0,
+    mkpts1,
+    color,
+    text,
+    path,
+    show_keypoints=False,
+    fast_viz=False,
+    opencv_display=False,
+    opencv_title="matches",
+    small_text=[],
+):
 
     if fast_viz:
-        make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
-                                color, text, path, show_keypoints, 10,
-                                opencv_display, opencv_title, small_text)
+        make_matching_plot_fast(
+            image0,
+            image1,
+            kpts0,
+            kpts1,
+            mkpts0,
+            mkpts1,
+            color,
+            text,
+            path,
+            show_keypoints,
+            10,
+            opencv_display,
+            opencv_title,
+            small_text,
+        )
         return
 
     plot_image_pair([image0, image1])
     if show_keypoints:
-        plot_keypoints(kpts0, kpts1, color='k', ps=4)
-        plot_keypoints(kpts0, kpts1, color='w', ps=2)
+        plot_keypoints(kpts0, kpts1, color="k", ps=4)
+        plot_keypoints(kpts0, kpts1, color="w", ps=2)
     plot_matches(mkpts0, mkpts1, color)
 
     fig = plt.gcf()
-    txt_color = 'k' if image0[:100, :150].mean() > 200 else 'w'
+    txt_color = "k" if image0[:100, :150].mean() > 200 else "w"
     fig.text(
-        0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
-        fontsize=15, va='top', ha='left', color=txt_color)
-
-    txt_color = 'k' if image0[-100:, :150].mean() > 200 else 'w'
+        0.01,
+        0.99,
+        "\n".join(text),
+        transform=fig.axes[0].transAxes,
+        fontsize=15,
+        va="top",
+        ha="left",
+        color=txt_color,
+    )
+
+    txt_color = "k" if image0[-100:, :150].mean() > 200 else "w"
     fig.text(
-        0.01, 0.01, '\n'.join(small_text), transform=fig.axes[0].transAxes,
-        fontsize=5, va='bottom', ha='left', color=txt_color)
-
-    plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+        0.01,
+        0.01,
+        "\n".join(small_text),
+        transform=fig.axes[0].transAxes,
+        fontsize=5,
+        va="bottom",
+        ha="left",
+        color=txt_color,
+    )
+
+    plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
     plt.close()
 
 
-def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0,
-                            mkpts1, color, text, path=None,
-                            show_keypoints=False, margin=10,
-                            opencv_display=False, opencv_title='',
-                            small_text=[]):
+def make_matching_plot_fast(
+    image0,
+    image1,
+    kpts0,
+    kpts1,
+    mkpts0,
+    mkpts1,
+    color,
+    text,
+    path=None,
+    show_keypoints=False,
+    margin=10,
+    opencv_display=False,
+    opencv_title="",
+    small_text=[],
+):
     H0, W0 = image0.shape
     H1, W1 = image1.shape
     H, W = max(H0, H1), W0 + W1 + margin
 
-    out = 255*np.ones((H, W), np.uint8)
+    out = 255 * np.ones((H, W), np.uint8)
     out[:H0, :W0] = image0
-    out[:H1, W0+margin:] = image1
-    out = np.stack([out]*3, -1)
+    out[:H1, W0 + margin :] = image1
+    out = np.stack([out] * 3, -1)
 
     if show_keypoints:
         kpts0, kpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
@@ -503,42 +564,77 @@ def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0,
             cv2.circle(out, (x, y), 2, black, -1, lineType=cv2.LINE_AA)
             cv2.circle(out, (x, y), 1, white, -1, lineType=cv2.LINE_AA)
         for x, y in kpts1:
-            cv2.circle(out, (x + margin + W0, y), 2, black, -1,
-                       lineType=cv2.LINE_AA)
-            cv2.circle(out, (x + margin + W0, y), 1, white, -1,
-                       lineType=cv2.LINE_AA)
+            cv2.circle(out, (x + margin + W0, y), 2, black, -1, lineType=cv2.LINE_AA)
+            cv2.circle(out, (x + margin + W0, y), 1, white, -1, lineType=cv2.LINE_AA)
 
     mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
-    color = (np.array(color[:, :3])*255).astype(int)[:, ::-1]
+    color = (np.array(color[:, :3]) * 255).astype(int)[:, ::-1]
     for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color):
         c = c.tolist()
-        cv2.line(out, (x0, y0), (x1 + margin + W0, y1),
-                 color=c, thickness=1, lineType=cv2.LINE_AA)
+        cv2.line(
+            out,
+            (x0, y0),
+            (x1 + margin + W0, y1),
+            color=c,
+            thickness=1,
+            lineType=cv2.LINE_AA,
+        )
         # display line end-points as circles
         cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA)
-        cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1,
-                   lineType=cv2.LINE_AA)
+        cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1, lineType=cv2.LINE_AA)
 
     # Scale factor for consistent visualization across scales.
-    sc = min(H / 640., 2.0)
+    sc = min(H / 640.0, 2.0)
 
     # Big text.
     Ht = int(30 * sc)  # text height
     txt_color_fg = (255, 255, 255)
     txt_color_bg = (0, 0, 0)
     for i, t in enumerate(text):
-        cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
-                    1.0*sc, txt_color_bg, 2, cv2.LINE_AA)
-        cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
-                    1.0*sc, txt_color_fg, 1, cv2.LINE_AA)
+        cv2.putText(
+            out,
+            t,
+            (int(8 * sc), Ht * (i + 1)),
+            cv2.FONT_HERSHEY_DUPLEX,
+            1.0 * sc,
+            txt_color_bg,
+            2,
+            cv2.LINE_AA,
+        )
+        cv2.putText(
+            out,
+            t,
+            (int(8 * sc), Ht * (i + 1)),
+            cv2.FONT_HERSHEY_DUPLEX,
+            1.0 * sc,
+            txt_color_fg,
+            1,
+            cv2.LINE_AA,
+        )
 
     # Small text.
     Ht = int(18 * sc)  # text height
     for i, t in enumerate(reversed(small_text)):
-        cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
-                    0.5*sc, txt_color_bg, 2, cv2.LINE_AA)
-        cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
-                    0.5*sc, txt_color_fg, 1, cv2.LINE_AA)
+        cv2.putText(
+            out,
+            t,
+            (int(8 * sc), int(H - Ht * (i + 0.6))),
+            cv2.FONT_HERSHEY_DUPLEX,
+            0.5 * sc,
+            txt_color_bg,
+            2,
+            cv2.LINE_AA,
+        )
+        cv2.putText(
+            out,
+            t,
+            (int(8 * sc), int(H - Ht * (i + 0.6))),
+            cv2.FONT_HERSHEY_DUPLEX,
+            0.5 * sc,
+            txt_color_fg,
+            1,
+            cv2.LINE_AA,
+        )
 
     if path is not None:
         cv2.imwrite(str(path), out)
@@ -552,4 +648,5 @@ def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0,
 
 def error_colormap(x):
     return np.clip(
-        np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)], -1), 0, 1)
+        np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x)], -1), 0, 1
+    )
diff --git a/third_party/TopicFM/configs/data/base.py b/third_party/TopicFM/configs/data/base.py
index 6cab7e67019a6fee2657c1a28609c8aca5b2a1d8..1897a84393e186cc46f34fe856243756e8393a2a 100644
--- a/third_party/TopicFM/configs/data/base.py
+++ b/third_party/TopicFM/configs/data/base.py
@@ -4,6 +4,7 @@ Setups in data configs will override all existed setups!
 """
 
 from yacs.config import CfgNode as CN
+
 _CN = CN()
 _CN.DATASET = CN()
 _CN.TRAINER = CN()
diff --git a/third_party/TopicFM/configs/data/megadepth_trainval.py b/third_party/TopicFM/configs/data/megadepth_trainval.py
index 215b5c34cc41d36aa4444a58ca0cb69afbc11952..7b7b0a77e26bbf6e7b7ceb2cd54f8c2e3b709db4 100644
--- a/third_party/TopicFM/configs/data/megadepth_trainval.py
+++ b/third_party/TopicFM/configs/data/megadepth_trainval.py
@@ -11,9 +11,13 @@ cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
 TEST_BASE_PATH = "data/megadepth/index"
 cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
 cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
-cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500"
-cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
-cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0   # for both test and val
+cfg.DATASET.VAL_NPZ_ROOT = (
+    cfg.DATASET.TEST_NPZ_ROOT
+) = f"{TEST_BASE_PATH}/scene_info_val_1500"
+cfg.DATASET.VAL_LIST_PATH = (
+    cfg.DATASET.TEST_LIST_PATH
+) = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0  # for both test and val
 
 # 368 scenes in total for MegaDepth
 # (with difficulty balanced (further split each scene to 3 sub-scenes))
diff --git a/third_party/TopicFM/configs/model/outdoor/model_ds.py b/third_party/TopicFM/configs/model/outdoor/model_ds.py
index 2c090edbfbdcd66cea225c39af6f62da8feb50b9..e0c234e8b3c932656052aa58836ed2b158344fb5 100644
--- a/third_party/TopicFM/configs/model/outdoor/model_ds.py
+++ b/third_party/TopicFM/configs/model/outdoor/model_ds.py
@@ -1,6 +1,6 @@
 from src.config.default import _CN as cfg
 
-cfg.MODEL.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+cfg.MODEL.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
 cfg.MODEL.COARSE.N_SAMPLES = 8
 
 cfg.TRAINER.CANONICAL_LR = 1e-2
diff --git a/third_party/TopicFM/flop_counter.py b/third_party/TopicFM/flop_counter.py
index ea87fa0139897434ca52b369450aa82203311181..915f703bd76146e54a3f2f9e819a7b1b85f2d700 100644
--- a/third_party/TopicFM/flop_counter.py
+++ b/third_party/TopicFM/flop_counter.py
@@ -27,7 +27,7 @@ def coarse_model_flops(coarse_model, config, inputs):
     return flops.total() / 1e9
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     path_img0 = "assets/scannet_sample_images/scene0711_00_frame-001680.jpg"
     path_img1 = "assets/scannet_sample_images/scene0711_00_frame-001995.jpg"
     img0, img1 = read_scannet_gray(path_img0), read_scannet_gray(path_img1)
@@ -35,21 +35,48 @@ if __name__ == '__main__':
 
     # LoFTR
     loftr_conf = dict(default_cfg)
-    feat_c0, loftr_featnet_flops0 = feat_net_flops(loftr_featnet, loftr_conf["resnetfpn"], img0)
-    feat_c1, loftr_featnet_flops1 = feat_net_flops(loftr_featnet, loftr_conf["resnetfpn"], img1)
-    print("FLOPs of feature extraction in LoFTR: {} GFLOPs".format((loftr_featnet_flops0 + loftr_featnet_flops1)/2))
-    feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
-    feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
-    loftr_coarse_model_flops = coarse_model_flops(LocalFeatureTransformer, loftr_conf["coarse"], (feat_c0, feat_c1))
-    print("FLOPs of coarse matching model in LoFTR: {} GFLOPs".format(loftr_coarse_model_flops))
+    feat_c0, loftr_featnet_flops0 = feat_net_flops(
+        loftr_featnet, loftr_conf["resnetfpn"], img0
+    )
+    feat_c1, loftr_featnet_flops1 = feat_net_flops(
+        loftr_featnet, loftr_conf["resnetfpn"], img1
+    )
+    print(
+        "FLOPs of feature extraction in LoFTR: {} GFLOPs".format(
+            (loftr_featnet_flops0 + loftr_featnet_flops1) / 2
+        )
+    )
+    feat_c0 = rearrange(feat_c0, "n c h w -> n (h w) c")
+    feat_c1 = rearrange(feat_c1, "n c h w -> n (h w) c")
+    loftr_coarse_model_flops = coarse_model_flops(
+        LocalFeatureTransformer, loftr_conf["coarse"], (feat_c0, feat_c1)
+    )
+    print(
+        "FLOPs of coarse matching model in LoFTR: {} GFLOPs".format(
+            loftr_coarse_model_flops
+        )
+    )
 
     # TopicFM
     topicfm_conf = get_model_cfg()
-    feat_c0, topicfm_featnet_flops0 = feat_net_flops(topicfm_featnet, topicfm_conf["fpn"], img0)
-    feat_c1, topicfm_featnet_flops1 = feat_net_flops(topicfm_featnet, topicfm_conf["fpn"], img1)
-    print("FLOPs of feature extraction in TopicFM: {} GFLOPs".format((topicfm_featnet_flops0 + topicfm_featnet_flops1) / 2))
-    feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
-    feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
-    topicfm_coarse_model_flops = coarse_model_flops(TopicFormer, topicfm_conf["coarse"], (feat_c0, feat_c1))
-    print("FLOPs of coarse matching model in TopicFM: {} GFLOPs".format(topicfm_coarse_model_flops))
-
+    feat_c0, topicfm_featnet_flops0 = feat_net_flops(
+        topicfm_featnet, topicfm_conf["fpn"], img0
+    )
+    feat_c1, topicfm_featnet_flops1 = feat_net_flops(
+        topicfm_featnet, topicfm_conf["fpn"], img1
+    )
+    print(
+        "FLOPs of feature extraction in TopicFM: {} GFLOPs".format(
+            (topicfm_featnet_flops0 + topicfm_featnet_flops1) / 2
+        )
+    )
+    feat_c0 = rearrange(feat_c0, "n c h w -> n (h w) c")
+    feat_c1 = rearrange(feat_c1, "n c h w -> n (h w) c")
+    topicfm_coarse_model_flops = coarse_model_flops(
+        TopicFormer, topicfm_conf["coarse"], (feat_c0, feat_c1)
+    )
+    print(
+        "FLOPs of coarse matching model in TopicFM: {} GFLOPs".format(
+            topicfm_coarse_model_flops
+        )
+    )
diff --git a/third_party/TopicFM/src/__init__.py b/third_party/TopicFM/src/__init__.py
index 30caef94f911f99e0c12510d8181b3c1537daf1a..aa7ba68e1b8fa7c7854ca49680c07d54d468d83e 100644
--- a/third_party/TopicFM/src/__init__.py
+++ b/third_party/TopicFM/src/__init__.py
@@ -1,11 +1,13 @@
 from yacs.config import CfgNode
 from .config.default import _CN
 
+
 def lower_config(yacs_cfg):
     if not isinstance(yacs_cfg, CfgNode):
         return yacs_cfg
     return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
 
+
 def get_model_cfg():
     cfg = lower_config(lower_config(_CN))
-    return cfg["model"]
\ No newline at end of file
+    return cfg["model"]
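As a quick illustration (not part of the diff): lower_config turns the yacs CfgNode into a plain nested dict with lower-cased keys, so the model section can be indexed like an ordinary dictionary. The import path assumes third_party/TopicFM is on sys.path so that its src package is importable.

from src import get_model_cfg  # assumes TopicFM's `src` package is importable

cfg = get_model_cfg()
print(cfg["coarse"]["d_model"])    # 256 with the defaults in config/default.py below
print(cfg["match_coarse"]["thr"])  # 0.2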
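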
diff --git a/third_party/TopicFM/src/config/default.py b/third_party/TopicFM/src/config/default.py
index 591558b3f358cdce0e9e72e94acba702b2a4e896..a252b1a13952480b5c22e50d6b90432f5a328112 100644
--- a/third_party/TopicFM/src/config/default.py
+++ b/third_party/TopicFM/src/config/default.py
@@ -1,9 +1,10 @@
 from yacs.config import CfgNode as CN
+
 _CN = CN()
 
 ##############  ↓  MODEL Pipeline  ↓  ##############
 _CN.MODEL = CN()
-_CN.MODEL.BACKBONE_TYPE = 'FPN'
+_CN.MODEL.BACKBONE_TYPE = "FPN"
 _CN.MODEL.RESOLUTION = (8, 2)  # options: [(8, 2), (16, 4)]
 _CN.MODEL.FINE_WINDOW_SIZE = 5  # window_size in fine_level, must be odd
 _CN.MODEL.FINE_CONCAT_COARSE_FEAT = False
@@ -18,8 +19,8 @@ _CN.MODEL.COARSE = CN()
 _CN.MODEL.COARSE.D_MODEL = 256
 _CN.MODEL.COARSE.D_FFN = 256
 _CN.MODEL.COARSE.NHEAD = 8
-_CN.MODEL.COARSE.LAYER_NAMES = ['seed', 'seed', 'seed', 'seed', 'seed']
-_CN.MODEL.COARSE.ATTENTION = 'linear'  # options: ['linear', 'full']
+_CN.MODEL.COARSE.LAYER_NAMES = ["seed", "seed", "seed", "seed", "seed"]
+_CN.MODEL.COARSE.ATTENTION = "linear"  # options: ['linear', 'full']
 _CN.MODEL.COARSE.TEMP_BUG_FIX = True
 _CN.MODEL.COARSE.N_TOPICS = 100
 _CN.MODEL.COARSE.N_SAMPLES = 6
@@ -29,7 +30,7 @@ _CN.MODEL.COARSE.N_TOPIC_TRANSFORMERS = 1
 _CN.MODEL.MATCH_COARSE = CN()
 _CN.MODEL.MATCH_COARSE.THR = 0.2
 _CN.MODEL.MATCH_COARSE.BORDER_RM = 2
-_CN.MODEL.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+_CN.MODEL.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
 _CN.MODEL.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
 _CN.MODEL.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2  # training tricks: save GPU memory
 _CN.MODEL.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200  # training tricks: avoid DDP deadlock
@@ -40,8 +41,8 @@ _CN.MODEL.FINE = CN()
 _CN.MODEL.FINE.D_MODEL = 128
 _CN.MODEL.FINE.D_FFN = 128
 _CN.MODEL.FINE.NHEAD = 4
-_CN.MODEL.FINE.LAYER_NAMES = ['cross'] * 1
-_CN.MODEL.FINE.ATTENTION = 'linear'
+_CN.MODEL.FINE.LAYER_NAMES = ["cross"] * 1
+_CN.MODEL.FINE.ATTENTION = "linear"
 _CN.MODEL.FINE.N_TOPICS = 1
 
 # 5. MODEL Losses
@@ -57,7 +58,7 @@ _CN.MODEL.LOSS.NEG_WEIGHT = 1.0
 # use `_CN.MODEL.MATCH_COARSE.MATCH_TYPE`
 
 # -- # fine-level
-_CN.MODEL.LOSS.FINE_TYPE = 'l2_with_std'  # ['l2_with_std', 'l2']
+_CN.MODEL.LOSS.FINE_TYPE = "l2_with_std"  # ['l2_with_std', 'l2']
 _CN.MODEL.LOSS.FINE_WEIGHT = 1.0
 _CN.MODEL.LOSS.FINE_CORRECT_THR = 1.0  # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
 
@@ -75,25 +76,33 @@ _CN.DATASET.TRAIN_INTRINSIC_PATH = None
 _CN.DATASET.VAL_DATA_ROOT = None
 _CN.DATASET.VAL_POSE_ROOT = None  # (optional directory for poses)
 _CN.DATASET.VAL_NPZ_ROOT = None
-_CN.DATASET.VAL_LIST_PATH = None    # None if val data from all scenes are bundled into a single npz file
+_CN.DATASET.VAL_LIST_PATH = (
+    None  # None if val data from all scenes are bundled into a single npz file
+)
 _CN.DATASET.VAL_INTRINSIC_PATH = None
 # testing
 _CN.DATASET.TEST_DATA_SOURCE = None
 _CN.DATASET.TEST_DATA_ROOT = None
 _CN.DATASET.TEST_POSE_ROOT = None  # (optional directory for poses)
 _CN.DATASET.TEST_NPZ_ROOT = None
-_CN.DATASET.TEST_LIST_PATH = None   # None if test data from all scenes are bundled into a single npz file
+_CN.DATASET.TEST_LIST_PATH = (
+    None  # None if test data from all scenes are bundled into a single npz file
+)
 _CN.DATASET.TEST_INTRINSIC_PATH = None
 _CN.DATASET.TEST_IMGSIZE = None
 
 # 2. dataset config
 # general options
-_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4  # discard data with overlap_score < min_overlap_score
+_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = (
+    0.4  # discard data with overlap_score < min_overlap_score
+)
 _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
 _CN.DATASET.AUGMENTATION_TYPE = None  # options: [None, 'dark', 'mobile']
 
 # MegaDepth options
-_CN.DATASET.MGDPT_IMG_RESIZE = 640  # resize the longer side, zero-pad bottom-right to square.
+_CN.DATASET.MGDPT_IMG_RESIZE = (
+    640  # resize the longer side, zero-pad bottom-right to square.
+)
 _CN.DATASET.MGDPT_IMG_PAD = True  # pad img to square with size = MGDPT_IMG_RESIZE
 _CN.DATASET.MGDPT_DEPTH_PAD = True  # pad depthmap to square with size = 2000
 _CN.DATASET.MGDPT_DF = 8
@@ -109,17 +118,17 @@ _CN.TRAINER.FIND_LR = False  # use learning rate finder from pytorch-lightning
 # optimizer
 _CN.TRAINER.OPTIMIZER = "adamw"  # [adam, adamw]
 _CN.TRAINER.TRUE_LR = None  # this will be calculated automatically at runtime
-_CN.TRAINER.ADAM_DECAY = 0.  # ADAM: for adam
+_CN.TRAINER.ADAM_DECAY = 0.0  # ADAM: for adam
 _CN.TRAINER.ADAMW_DECAY = 0.01
 
 # step-based warm-up
-_CN.TRAINER.WARMUP_TYPE = 'linear'  # [linear, constant]
-_CN.TRAINER.WARMUP_RATIO = 0.
+_CN.TRAINER.WARMUP_TYPE = "linear"  # [linear, constant]
+_CN.TRAINER.WARMUP_RATIO = 0.0
 _CN.TRAINER.WARMUP_STEP = 4800
 
 # learning rate scheduler
-_CN.TRAINER.SCHEDULER = 'MultiStepLR'  # [MultiStepLR, CosineAnnealing, ExponentialLR]
-_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch'    # [epoch, step]
+_CN.TRAINER.SCHEDULER = "MultiStepLR"  # [MultiStepLR, CosineAnnealing, ExponentialLR]
+_CN.TRAINER.SCHEDULER_INTERVAL = "epoch"  # [epoch, step]
 _CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12]  # MSLR: MultiStepLR
 _CN.TRAINER.MSLR_GAMMA = 0.5
 _CN.TRAINER.COSA_TMAX = 30  # COSA: CosineAnnealing
@@ -127,25 +136,33 @@ _CN.TRAINER.ELR_GAMMA = 0.999992  # ELR: ExponentialLR, this value for 'step' in
 
 # plotting related
 _CN.TRAINER.ENABLE_PLOTTING = True
-_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32     # number of val/test paris for plotting
-_CN.TRAINER.PLOT_MODE = 'evaluation'  # ['evaluation', 'confidence']
-_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
+_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32  # number of val/test pairs for plotting
+_CN.TRAINER.PLOT_MODE = "evaluation"  # ['evaluation', 'confidence']
+_CN.TRAINER.PLOT_MATCHES_ALPHA = "dynamic"
 
 # geometric metrics and pose solver
-_CN.TRAINER.EPI_ERR_THR = 5e-4  # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
-_CN.TRAINER.POSE_GEO_MODEL = 'E'  # ['E', 'F', 'H']
-_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC'  # [RANSAC, DEGENSAC, MAGSAC]
+_CN.TRAINER.EPI_ERR_THR = (
+    5e-4  # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
+)
+_CN.TRAINER.POSE_GEO_MODEL = "E"  # ['E', 'F', 'H']
+_CN.TRAINER.POSE_ESTIMATION_METHOD = "RANSAC"  # [RANSAC, DEGENSAC, MAGSAC]
 _CN.TRAINER.RANSAC_PIXEL_THR = 0.5
 _CN.TRAINER.RANSAC_CONF = 0.99999
 _CN.TRAINER.RANSAC_MAX_ITERS = 10000
 _CN.TRAINER.USE_MAGSACPP = False
 
 # data sampler for train_dataloader
-_CN.TRAINER.DATA_SAMPLER = 'scene_balance'  # options: ['scene_balance', 'random', 'normal']
+_CN.TRAINER.DATA_SAMPLER = (
+    "scene_balance"  # options: ['scene_balance', 'random', 'normal']
+)
 # 'scene_balance' config
 _CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
-_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True  # whether sample each scene with replacement or not
-_CN.TRAINER.SB_SUBSET_SHUFFLE = True  # after sampling from scenes, whether shuffle within the epoch or not
+_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = (
+    True  # whether sample each scene with replacement or not
+)
+_CN.TRAINER.SB_SUBSET_SHUFFLE = (
+    True  # after sampling from scenes, whether shuffle within the epoch or not
+)
 _CN.TRAINER.SB_REPEAT = 1  # repeat N times for training the sampled data
 # 'random' config
 _CN.TRAINER.RDM_REPLACEMENT = True
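A small worked example of how the resolution settings above combine (illustrative numbers only): with MODEL.RESOLUTION = (8, 2) and DATASET.MGDPT_IMG_RESIZE = 640, the coarse scale is 1/8 = 0.125, so a square-padded 640x640 image yields an 80x80 coarse grid, and matches are refined at 1/2 resolution inside FINE_WINDOW_SIZE = 5 windows.

img_size, coarse_stride, fine_stride = 640, 8, 2  # MGDPT_IMG_RESIZE, MODEL.RESOLUTION
coarse_hw = img_size // coarse_stride             # 80 -> 80 * 80 = 6400 coarse tokens
fine_hw = img_size // fine_stride                 # 320, refined within 5x5 fine windows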
diff --git a/third_party/TopicFM/src/datasets/aachen.py b/third_party/TopicFM/src/datasets/aachen.py
index ebfeee4dbfbd78770976ec027ceee8ef333a4574..71f2dd18855f3536a5159e7f420044d6536d960b 100644
--- a/third_party/TopicFM/src/datasets/aachen.py
+++ b/third_party/TopicFM/src/datasets/aachen.py
@@ -9,7 +9,7 @@ class AachenDataset(Dataset):
         self.img_path = img_path
         self.img_resize = img_resize
         self.down_factor = down_factor
-        with open(match_list_path, 'r') as f:
+        with open(match_list_path, "r") as f:
             self.raw_pairs = f.readlines()
         print("number of matching pairs: ", len(self.raw_pairs))
 
@@ -18,12 +18,20 @@ class AachenDataset(Dataset):
 
     def __getitem__(self, idx):
         raw_pair = self.raw_pairs[idx]
-        image_name0, image_name1 = raw_pair.strip('\n').split(' ')
+        image_name0, image_name1 = raw_pair.strip("\n").split(" ")
         path_img0 = os.path.join(self.img_path, image_name0)
         path_img1 = os.path.join(self.img_path, image_name1)
-        img0, scale0 = read_img_gray(path_img0, resize=self.img_resize, down_factor=self.down_factor)
-        img1, scale1 = read_img_gray(path_img1, resize=self.img_resize, down_factor=self.down_factor)
-        return {"image0": img0, "image1": img1,
-                "scale0": scale0, "scale1": scale1,
-                "pair_names": (image_name0, image_name1),
-                "dataset_name": "AachenDayNight"}
\ No newline at end of file
+        img0, scale0 = read_img_gray(
+            path_img0, resize=self.img_resize, down_factor=self.down_factor
+        )
+        img1, scale1 = read_img_gray(
+            path_img1, resize=self.img_resize, down_factor=self.down_factor
+        )
+        return {
+            "image0": img0,
+            "image1": img1,
+            "scale0": scale0,
+            "scale1": scale1,
+            "pair_names": (image_name0, image_name1),
+            "dataset_name": "AachenDayNight",
+        }
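For context, the match list read above is a plain-text file with one whitespace-separated image pair per line; the parsing below mirrors __getitem__ (the file names are made up for illustration).

raw_pair = "db/1000.jpg query/night/IMG_0001.jpg\n"  # one hypothetical pairs-file line
image_name0, image_name1 = raw_pair.strip("\n").split(" ")
assert (image_name0, image_name1) == ("db/1000.jpg", "query/night/IMG_0001.jpg")
# Each sample then exposes: image0/image1, scale0/scale1, pair_names, dataset_name.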
diff --git a/third_party/TopicFM/src/datasets/custom_dataloader.py b/third_party/TopicFM/src/datasets/custom_dataloader.py
index 46d55d4f4d56d2c96cd42b6597834f945a5eb20d..eb3bd7a083baf5d0a1e8a9a21b97a08dcc22f163 100644
--- a/third_party/TopicFM/src/datasets/custom_dataloader.py
+++ b/third_party/TopicFM/src/datasets/custom_dataloader.py
@@ -28,99 +28,124 @@ class TestDataLoader(DataLoader):
 
         # 2. dataset config
         # general options
-        self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST  # 0.4, omit data with overlap_score < min_overlap_score
+        self.min_overlap_score_test = (
+            config.DATASET.MIN_OVERLAP_SCORE_TEST
+        )  # 0.4, omit data with overlap_score < min_overlap_score
 
         # MegaDepth options
-        if dataset_name == 'megadepth':
+        if dataset_name == "megadepth":
             self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE  # 800
             self.mgdpt_img_pad = True
             self.mgdpt_depth_pad = True
             self.mgdpt_df = 8
             self.coarse_scale = 0.125
-        if dataset_name == 'scannet':
+        if dataset_name == "scannet":
             self.img_resize = config.DATASET.TEST_IMGSIZE
 
-        if (dataset_name == 'megadepth') or (dataset_name == 'scannet'):
+        if (dataset_name == "megadepth") or (dataset_name == "scannet"):
             test_dataset = self._setup_dataset(
                 self.test_data_root,
                 self.test_npz_root,
                 self.test_list_path,
                 self.test_intrinsic_path,
-                mode='test',
+                mode="test",
                 min_overlap_score=self.min_overlap_score_test,
-                pose_dir=self.test_pose_root)
-        elif dataset_name == 'aachen_v1.1':
-            test_dataset = AachenDataset(self.test_data_root, self.test_list_path,
-                                         img_resize=config.DATASET.TEST_IMGSIZE)
-        elif dataset_name == 'inloc':
-            test_dataset = InLocDataset(self.test_data_root, self.test_list_path,
-                                        img_resize=config.DATASET.TEST_IMGSIZE)
+                pose_dir=self.test_pose_root,
+            )
+        elif dataset_name == "aachen_v1.1":
+            test_dataset = AachenDataset(
+                self.test_data_root,
+                self.test_list_path,
+                img_resize=config.DATASET.TEST_IMGSIZE,
+            )
+        elif dataset_name == "inloc":
+            test_dataset = InLocDataset(
+                self.test_data_root,
+                self.test_list_path,
+                img_resize=config.DATASET.TEST_IMGSIZE,
+            )
         else:
             raise "unknown dataset"
 
         self.test_loader_params = {
-            'batch_size': 1,
-            'shuffle': False,
-            'num_workers': 4,
-            'pin_memory': True
+            "batch_size": 1,
+            "shuffle": False,
+            "num_workers": 4,
+            "pin_memory": True,
         }
 
         # sampler = Seq(self.test_dataset, shuffle=False)
         super(TestDataLoader, self).__init__(test_dataset, **self.test_loader_params)
 
-    def _setup_dataset(self,
-                       data_root,
-                       split_npz_root,
-                       scene_list_path,
-                       intri_path,
-                       mode='train',
-                       min_overlap_score=0.,
-                       pose_dir=None):
-        """ Setup train / val / test set"""
-        with open(scene_list_path, 'r') as f:
+    def _setup_dataset(
+        self,
+        data_root,
+        split_npz_root,
+        scene_list_path,
+        intri_path,
+        mode="train",
+        min_overlap_score=0.0,
+        pose_dir=None,
+    ):
+        """Setup train / val / test set"""
+        with open(scene_list_path, "r") as f:
             npz_names = [name.split()[0] for name in f.readlines()]
         local_npz_names = npz_names
 
-        return self._build_concat_dataset(data_root, local_npz_names, split_npz_root, intri_path,
-                                          mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
+        return self._build_concat_dataset(
+            data_root,
+            local_npz_names,
+            split_npz_root,
+            intri_path,
+            mode=mode,
+            min_overlap_score=min_overlap_score,
+            pose_dir=pose_dir,
+        )
 
     def _build_concat_dataset(
-            self,
-            data_root,
-            npz_names,
-            npz_dir,
-            intrinsic_path,
-            mode,
-            min_overlap_score=0.,
-            pose_dir=None
+        self,
+        data_root,
+        npz_names,
+        npz_dir,
+        intrinsic_path,
+        mode,
+        min_overlap_score=0.0,
+        pose_dir=None,
     ):
         datasets = []
         # augment_fn = self.augment_fn if mode == 'train' else None
         data_source = self.test_data_source
-        if str(data_source).lower() == 'megadepth':
-            npz_names = [f'{n}.npz' for n in npz_names]
+        if str(data_source).lower() == "megadepth":
+            npz_names = [f"{n}.npz" for n in npz_names]
         for npz_name in tqdm(npz_names):
             # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
             npz_path = osp.join(npz_dir, npz_name)
-            if data_source == 'ScanNet':
+            if data_source == "ScanNet":
                 datasets.append(
-                    ScanNetDataset(data_root,
-                                   npz_path,
-                                   intrinsic_path,
-                                   mode=mode, img_resize=self.img_resize,
-                                   min_overlap_score=min_overlap_score,
-                                   pose_dir=pose_dir))
-            elif data_source == 'MegaDepth':
+                    ScanNetDataset(
+                        data_root,
+                        npz_path,
+                        intrinsic_path,
+                        mode=mode,
+                        img_resize=self.img_resize,
+                        min_overlap_score=min_overlap_score,
+                        pose_dir=pose_dir,
+                    )
+                )
+            elif data_source == "MegaDepth":
                 datasets.append(
-                    MegaDepthDataset(data_root,
-                                     npz_path,
-                                     mode=mode,
-                                     min_overlap_score=min_overlap_score,
-                                     img_resize=self.mgdpt_img_resize,
-                                     df=self.mgdpt_df,
-                                     img_padding=self.mgdpt_img_pad,
-                                     depth_padding=self.mgdpt_depth_pad,
-                                     coarse_scale=self.coarse_scale))
+                    MegaDepthDataset(
+                        data_root,
+                        npz_path,
+                        mode=mode,
+                        min_overlap_score=min_overlap_score,
+                        img_resize=self.mgdpt_img_resize,
+                        df=self.mgdpt_df,
+                        img_padding=self.mgdpt_img_pad,
+                        depth_padding=self.mgdpt_depth_pad,
+                        coarse_scale=self.coarse_scale,
+                    )
+                )
             else:
                 raise NotImplementedError()
         return ConcatDataset(datasets)
diff --git a/third_party/TopicFM/src/datasets/inloc.py b/third_party/TopicFM/src/datasets/inloc.py
index 5421099d11b4dbbea8c09568c493d844d5c6a1b0..dc176761b7626aafd90e9674c5d85ff6e95f537c 100644
--- a/third_party/TopicFM/src/datasets/inloc.py
+++ b/third_party/TopicFM/src/datasets/inloc.py
@@ -9,7 +9,7 @@ class InLocDataset(Dataset):
         self.img_path = img_path
         self.img_resize = img_resize
         self.down_factor = down_factor
-        with open(match_list_path, 'r') as f:
+        with open(match_list_path, "r") as f:
             self.raw_pairs = f.readlines()
         print("number of matching pairs: ", len(self.raw_pairs))
 
@@ -18,12 +18,20 @@ class InLocDataset(Dataset):
 
     def __getitem__(self, idx):
         raw_pair = self.raw_pairs[idx]
-        image_name0, image_name1 = raw_pair.strip('\n').split(' ')
+        image_name0, image_name1 = raw_pair.strip("\n").split(" ")
         path_img0 = os.path.join(self.img_path, image_name0)
         path_img1 = os.path.join(self.img_path, image_name1)
-        img0, scale0 = read_img_gray(path_img0, resize=self.img_resize, down_factor=self.down_factor)
-        img1, scale1 = read_img_gray(path_img1, resize=self.img_resize, down_factor=self.down_factor)
-        return {"image0": img0, "image1": img1,
-                "scale0": scale0, "scale1": scale1,
-                "pair_names": (image_name0, image_name1),
-                "dataset_name": "InLoc"}
\ No newline at end of file
+        img0, scale0 = read_img_gray(
+            path_img0, resize=self.img_resize, down_factor=self.down_factor
+        )
+        img1, scale1 = read_img_gray(
+            path_img1, resize=self.img_resize, down_factor=self.down_factor
+        )
+        return {
+            "image0": img0,
+            "image1": img1,
+            "scale0": scale0,
+            "scale1": scale1,
+            "pair_names": (image_name0, image_name1),
+            "dataset_name": "InLoc",
+        }
diff --git a/third_party/TopicFM/src/datasets/megadepth.py b/third_party/TopicFM/src/datasets/megadepth.py
index e92768e72e373c2a8ebeaf1158f9710fb1bfb5f1..77516327ebed8ca4ea8be9692a7077d94f03ee5b 100644
--- a/third_party/TopicFM/src/datasets/megadepth.py
+++ b/third_party/TopicFM/src/datasets/megadepth.py
@@ -9,20 +9,22 @@ from src.utils.dataset import read_megadepth_gray, read_megadepth_depth
 
 
 class MegaDepthDataset(Dataset):
-    def __init__(self,
-                 root_dir,
-                 npz_path,
-                 mode='train',
-                 min_overlap_score=0.4,
-                 img_resize=None,
-                 df=None,
-                 img_padding=False,
-                 depth_padding=False,
-                 augment_fn=None,
-                 **kwargs):
+    def __init__(
+        self,
+        root_dir,
+        npz_path,
+        mode="train",
+        min_overlap_score=0.4,
+        img_resize=None,
+        df=None,
+        img_padding=False,
+        depth_padding=False,
+        augment_fn=None,
+        **kwargs
+    ):
         """
         Manage one scene (npz_path) of the MegaDepth dataset.
-        
+
         Args:
             root_dir (str): megadepth root directory that has `phoenix`.
             npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
@@ -38,30 +40,38 @@ class MegaDepthDataset(Dataset):
         super().__init__()
         self.root_dir = root_dir
         self.mode = mode
-        self.scene_id = npz_path.split('.')[0]
+        self.scene_id = npz_path.split(".")[0]
 
         # prepare scene_info and pair_info
-        if mode == 'test' and min_overlap_score != 0:
-            logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.")
+        if mode == "test" and min_overlap_score != 0:
+            logger.warning(
+                "You are using `min_overlap_score`!=0 in test mode. Set to 0."
+            )
             min_overlap_score = 0
         self.scene_info = np.load(npz_path, allow_pickle=True)
-        self.pair_infos = self.scene_info['pair_infos'].copy()
-        del self.scene_info['pair_infos']
-        self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]
+        self.pair_infos = self.scene_info["pair_infos"].copy()
+        del self.scene_info["pair_infos"]
+        self.pair_infos = [
+            pair_info
+            for pair_info in self.pair_infos
+            if pair_info[1] > min_overlap_score
+        ]
 
         # parameters for image resizing, padding and depthmap padding
-        if mode == 'train':
+        if mode == "train":
             assert img_resize is not None and img_padding and depth_padding
         self.img_resize = img_resize
-        if mode == 'val':
+        if mode == "val":
             self.img_resize = 864
         self.df = df
         self.img_padding = img_padding
-        self.depth_max_size = 2000 if depth_padding else None  # the upperbound of depthmaps size in megadepth.
+        self.depth_max_size = (
+            2000 if depth_padding else None
+        )  # the upper bound of depthmap sizes in MegaDepth.
 
         # for training LoFTR
-        self.augment_fn = augment_fn if mode == 'train' else None
-        self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125)
+        self.augment_fn = augment_fn if mode == "train" else None
+        self.coarse_scale = getattr(kwargs, "coarse_scale", 0.125)
 
     def __len__(self):
         return len(self.pair_infos)
@@ -70,60 +80,77 @@ class MegaDepthDataset(Dataset):
         (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
 
         # read grayscale image and mask. (1, h, w) and (h, w)
-        img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
-        img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
-        
+        img_name0 = osp.join(self.root_dir, self.scene_info["image_paths"][idx0])
+        img_name1 = osp.join(self.root_dir, self.scene_info["image_paths"][idx1])
+
         # TODO: Support augmentation & handle seeds for each worker correctly.
         image0, mask0, scale0 = read_megadepth_gray(
-            img_name0, self.img_resize, self.df, self.img_padding, None)
-            # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+            img_name0, self.img_resize, self.df, self.img_padding, None
+        )
+        # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
         image1, mask1, scale1 = read_megadepth_gray(
-            img_name1, self.img_resize, self.df, self.img_padding, None)
-            # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+            img_name1, self.img_resize, self.df, self.img_padding, None
+        )
+        # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
 
         # read depth. shape: (h, w)
-        if self.mode in ['train', 'val']:
+        if self.mode in ["train", "val"]:
             depth0 = read_megadepth_depth(
-                osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
+                osp.join(self.root_dir, self.scene_info["depth_paths"][idx0]),
+                pad_to=self.depth_max_size,
+            )
             depth1 = read_megadepth_depth(
-                osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
+                osp.join(self.root_dir, self.scene_info["depth_paths"][idx1]),
+                pad_to=self.depth_max_size,
+            )
         else:
             depth0 = depth1 = torch.tensor([])
 
         # read intrinsics of original size
-        K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
-        K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
+        K_0 = torch.tensor(
+            self.scene_info["intrinsics"][idx0].copy(), dtype=torch.float
+        ).reshape(3, 3)
+        K_1 = torch.tensor(
+            self.scene_info["intrinsics"][idx1].copy(), dtype=torch.float
+        ).reshape(3, 3)
 
         # read and compute relative poses
-        T0 = self.scene_info['poses'][idx0]
-        T1 = self.scene_info['poses'][idx1]
-        T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4]  # (4, 4)
+        T0 = self.scene_info["poses"][idx0]
+        T1 = self.scene_info["poses"][idx1]
+        T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[
+            :4, :4
+        ]  # (4, 4)
         T_1to0 = T_0to1.inverse()
 
         data = {
-            'image0': image0,  # (1, h, w)
-            'depth0': depth0,  # (h, w)
-            'image1': image1,
-            'depth1': depth1,
-            'T_0to1': T_0to1,  # (4, 4)
-            'T_1to0': T_1to0,
-            'K0': K_0,  # (3, 3)
-            'K1': K_1,
-            'scale0': scale0,  # [scale_w, scale_h]
-            'scale1': scale1,
-            'dataset_name': 'MegaDepth',
-            'scene_id': self.scene_id,
-            'pair_id': idx,
-            'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
+            "image0": image0,  # (1, h, w)
+            "depth0": depth0,  # (h, w)
+            "image1": image1,
+            "depth1": depth1,
+            "T_0to1": T_0to1,  # (4, 4)
+            "T_1to0": T_1to0,
+            "K0": K_0,  # (3, 3)
+            "K1": K_1,
+            "scale0": scale0,  # [scale_w, scale_h]
+            "scale1": scale1,
+            "dataset_name": "MegaDepth",
+            "scene_id": self.scene_id,
+            "pair_id": idx,
+            "pair_names": (
+                self.scene_info["image_paths"][idx0],
+                self.scene_info["image_paths"][idx1],
+            ),
         }
 
         # for LoFTR training
         if mask0 is not None:  # img_padding is True
             if self.coarse_scale:
-                [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
-                                                       scale_factor=self.coarse_scale,
-                                                       mode='nearest',
-                                                       recompute_scale_factor=False)[0].bool()
-            data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
+                [ts_mask_0, ts_mask_1] = F.interpolate(
+                    torch.stack([mask0, mask1], dim=0)[None].float(),
+                    scale_factor=self.coarse_scale,
+                    mode="nearest",
+                    recompute_scale_factor=False,
+                )[0].bool()
+            data.update({"mask0": ts_mask_0, "mask1": ts_mask_1})
 
         return data
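A small numeric sanity check of the T_0to1 convention used above (and in ScanNetDataset below), assuming the stored poses are 4x4 world-to-camera extrinsics; the numbers are purely illustrative.

import numpy as np

T0 = np.eye(4)                                # camera 0 at the world origin
T1 = np.eye(4)
T1[0, 3] = -1.0                               # camera 1 sits one unit along +x in the world
T_0to1 = T1 @ np.linalg.inv(T0)               # same formula as in __getitem__
p_cam0 = np.array([0.0, 0.0, 5.0, 1.0])       # homogeneous point 5 units in front of camera 0
p_cam1 = T_0to1 @ p_cam0                      # -> [-1., 0., 5., 1.] in camera-1 coordinates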
diff --git a/third_party/TopicFM/src/datasets/sampler.py b/third_party/TopicFM/src/datasets/sampler.py
index 81b6f435645632a013476f9a665a0861ab7fcb61..131111c4cf69cd8770058dfac2be717aa183978e 100644
--- a/third_party/TopicFM/src/datasets/sampler.py
+++ b/third_party/TopicFM/src/datasets/sampler.py
@@ -3,10 +3,10 @@ from torch.utils.data import Sampler, ConcatDataset
 
 
 class RandomConcatSampler(Sampler):
-    """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
+    """Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
     in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
     However, it is impossible to sample data without replacement between epochs, unless building a stateful sampler that lives across the entire training phase.
-    
+
     In the current implementation, the randomness of sampling is ensured regardless of whether the sampler is recreated across epochs and whether `torch.manual_seed()` is called.
     Args:
         shuffle (bool): shuffle the randomly sampled indices across all sub-datasets.
@@ -18,16 +18,19 @@ class RandomConcatSampler(Sampler):
     TODO: Add a `set_epoch()` method to fulfill sampling without replacement across epochs.
           ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
     """
-    def __init__(self,
-                 data_source: ConcatDataset,
-                 n_samples_per_subset: int,
-                 subset_replacement: bool=True,
-                 shuffle: bool=True,
-                 repeat: int=1,
-                 seed: int=None):
+
+    def __init__(
+        self,
+        data_source: ConcatDataset,
+        n_samples_per_subset: int,
+        subset_replacement: bool = True,
+        shuffle: bool = True,
+        repeat: int = 1,
+        seed: int = None,
+    ):
         if not isinstance(data_source, ConcatDataset):
             raise TypeError("data_source should be torch.utils.data.ConcatDataset")
-        
+
         self.data_source = data_source
         self.n_subset = len(self.data_source.datasets)
         self.n_samples_per_subset = n_samples_per_subset
@@ -37,27 +40,37 @@ class RandomConcatSampler(Sampler):
         self.shuffle = shuffle
         self.generator = torch.manual_seed(seed)
         assert self.repeat >= 1
-        
+
     def __len__(self):
         return self.n_samples
-    
+
     def __iter__(self):
         indices = []
         # sample from each sub-dataset
         for d_idx in range(self.n_subset):
-            low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
+            low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1]
             high = self.data_source.cumulative_sizes[d_idx]
             if self.subset_replacement:
-                rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
-                                            generator=self.generator, dtype=torch.int64)
+                rand_tensor = torch.randint(
+                    low,
+                    high,
+                    (self.n_samples_per_subset,),
+                    generator=self.generator,
+                    dtype=torch.int64,
+                )
             else:  # sample without replacement
                 len_subset = len(self.data_source.datasets[d_idx])
                 rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
                 if len_subset >= self.n_samples_per_subset:
-                    rand_tensor = rand_tensor[:self.n_samples_per_subset]
-                else: # padding with replacement
-                    rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
-                                                            generator=self.generator, dtype=torch.int64)
+                    rand_tensor = rand_tensor[: self.n_samples_per_subset]
+                else:  # padding with replacement
+                    rand_tensor_replacement = torch.randint(
+                        low,
+                        high,
+                        (self.n_samples_per_subset - len_subset,),
+                        generator=self.generator,
+                        dtype=torch.int64,
+                    )
                     rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
             indices.append(rand_tensor)
         indices = torch.cat(indices)
@@ -72,6 +85,6 @@ class RandomConcatSampler(Sampler):
                 _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
                 repeat_indices = map(_choice, repeat_indices)
             indices = torch.cat([indices, *repeat_indices], 0)
-        
+
         assert indices.shape[0] == self.n_samples
         return iter(indices.tolist())
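A hypothetical usage sketch of the sampler above (scene sizes and parameters are made up): every epoch draws `n_samples_per_subset` indices from each sub-dataset, so small scenes are not drowned out by large ones.

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset
from src.datasets.sampler import RandomConcatSampler

scenes = [TensorDataset(torch.arange(n)) for n in (10, 50, 200)]  # three toy "scenes"
concat = ConcatDataset(scenes)
sampler = RandomConcatSampler(
    concat, n_samples_per_subset=32, subset_replacement=True, shuffle=True, repeat=1, seed=66
)
loader = DataLoader(concat, sampler=sampler, batch_size=4)
# With repeat=1 this yields 3 * 32 = 96 indices per epoch, regardless of scene size.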
diff --git a/third_party/TopicFM/src/datasets/scannet.py b/third_party/TopicFM/src/datasets/scannet.py
index fb5dab7b150a3c6f54eb07b0459bbf3e9ba58fbf..b955c4fa1609625be2c6c1a0ed6665109908bba0 100644
--- a/third_party/TopicFM/src/datasets/scannet.py
+++ b/third_party/TopicFM/src/datasets/scannet.py
@@ -10,20 +10,22 @@ from src.utils.dataset import (
     read_scannet_gray,
     read_scannet_depth,
     read_scannet_pose,
-    read_scannet_intrinsic
+    read_scannet_intrinsic,
 )
 
 
 class ScanNetDataset(utils.data.Dataset):
-    def __init__(self,
-                 root_dir,
-                 npz_path,
-                 intrinsic_path,
-                 mode='train',
-                 min_overlap_score=0.4,
-                 augment_fn=None,
-                 pose_dir=None,
-                 **kwargs):
+    def __init__(
+        self,
+        root_dir,
+        npz_path,
+        intrinsic_path,
+        mode="train",
+        min_overlap_score=0.4,
+        augment_fn=None,
+        pose_dir=None,
+        **kwargs,
+    ):
         """Manage one scene of ScanNet Dataset.
         Args:
             root_dir (str): ScanNet root directory that contains scene folders.
@@ -38,78 +40,88 @@ class ScanNetDataset(utils.data.Dataset):
         self.root_dir = root_dir
         self.pose_dir = pose_dir if pose_dir is not None else root_dir
         self.mode = mode
-        self.img_resize = (640, 480) if 'img_resize' not in kwargs else kwargs['img_resize']
+        self.img_resize = (
+            (640, 480) if "img_resize" not in kwargs else kwargs["img_resize"]
+        )
 
         # prepare data_names, intrinsics and extrinsics(T)
         with np.load(npz_path) as data:
-            self.data_names = data['name']
-            if 'score' in data.keys() and mode not in ['val' or 'test']:
-                kept_mask = data['score'] > min_overlap_score
+            self.data_names = data["name"]
+            if "score" in data.keys() and mode not in ["val" or "test"]:
+                kept_mask = data["score"] > min_overlap_score
                 self.data_names = self.data_names[kept_mask]
         self.intrinsics = dict(np.load(intrinsic_path))
 
         # for training LoFTR
-        self.augment_fn = augment_fn if mode == 'train' else None
+        self.augment_fn = augment_fn if mode == "train" else None
 
     def __len__(self):
         return len(self.data_names)
 
     def _read_abs_pose(self, scene_name, name):
-        pth = osp.join(self.pose_dir,
-                       scene_name,
-                       'pose', f'{name}.txt')
+        pth = osp.join(self.pose_dir, scene_name, "pose", f"{name}.txt")
         return read_scannet_pose(pth)
 
     def _compute_rel_pose(self, scene_name, name0, name1):
         pose0 = self._read_abs_pose(scene_name, name0)
         pose1 = self._read_abs_pose(scene_name, name1)
-        
+
         return np.matmul(pose1, inv(pose0))  # (4, 4)
 
     def __getitem__(self, idx):
         data_name = self.data_names[idx]
         scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
-        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+        scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}"
 
         # read the grayscale image which will be resized to (1, 480, 640)
-        img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
-        img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')
-        
+        img_name0 = osp.join(self.root_dir, scene_name, "color", f"{stem_name_0}.jpg")
+        img_name1 = osp.join(self.root_dir, scene_name, "color", f"{stem_name_1}.jpg")
+
         # TODO: Support augmentation & handle seeds for each worker correctly.
         image0 = read_scannet_gray(img_name0, resize=self.img_resize, augment_fn=None)
-                                #    augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+        #    augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
         image1 = read_scannet_gray(img_name1, resize=self.img_resize, augment_fn=None)
-                                #    augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+        #    augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
 
         # read the depthmap which is stored as (480, 640)
-        if self.mode in ['train', 'val']:
-            depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'))
-            depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'))
+        if self.mode in ["train", "val"]:
+            depth0 = read_scannet_depth(
+                osp.join(self.root_dir, scene_name, "depth", f"{stem_name_0}.png")
+            )
+            depth1 = read_scannet_depth(
+                osp.join(self.root_dir, scene_name, "depth", f"{stem_name_1}.png")
+            )
         else:
             depth0 = depth1 = torch.tensor([])
 
         # read the intrinsic of depthmap
-        K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)
+        K_0 = K_1 = torch.tensor(
+            self.intrinsics[scene_name].copy(), dtype=torch.float
+        ).reshape(3, 3)
 
         # read and compute relative poses
-        T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
-                              dtype=torch.float32)
+        T_0to1 = torch.tensor(
+            self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
+            dtype=torch.float32,
+        )
         T_1to0 = T_0to1.inverse()
 
         data = {
-            'image0': image0,   # (1, h, w)
-            'depth0': depth0,   # (h, w)
-            'image1': image1,
-            'depth1': depth1,
-            'T_0to1': T_0to1,   # (4, 4)
-            'T_1to0': T_1to0,
-            'K0': K_0,  # (3, 3)
-            'K1': K_1,
-            'dataset_name': 'ScanNet',
-            'scene_id': scene_name,
-            'pair_id': idx,
-            'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
-                           osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
+            "image0": image0,  # (1, h, w)
+            "depth0": depth0,  # (h, w)
+            "image1": image1,
+            "depth1": depth1,
+            "T_0to1": T_0to1,  # (4, 4)
+            "T_1to0": T_1to0,
+            "K0": K_0,  # (3, 3)
+            "K1": K_1,
+            "dataset_name": "ScanNet",
+            "scene_id": scene_name,
+            "pair_id": idx,
+            "pair_names": (
+                osp.join(scene_name, "color", f"{stem_name_0}.jpg"),
+                osp.join(scene_name, "color", f"{stem_name_1}.jpg"),
+            ),
         }
 
         return data
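Illustrative only (values made up): how one ScanNet npz entry is turned into the on-disk paths read above.

scene_id, sub_id, stem0, stem1 = 707, 0, 120, 165      # one data_names entry
scene_name = f"scene{scene_id:04d}_{sub_id:02d}"       # "scene0707_00"
img0 = f"{scene_name}/color/{stem0}.jpg"               # grayscale image, resized to (640, 480)
pose0 = f"{scene_name}/pose/{stem0}.txt"               # absolute pose read by _read_abs_pose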
diff --git a/third_party/TopicFM/src/lightning_trainer/data.py b/third_party/TopicFM/src/lightning_trainer/data.py
index 8deb713b6300e0e9e8a261e2230031174b452862..95f6a5eeecf39a993b86674242eacb7b42f8a566 100644
--- a/third_party/TopicFM/src/lightning_trainer/data.py
+++ b/third_party/TopicFM/src/lightning_trainer/data.py
@@ -16,7 +16,7 @@ from torch.utils.data import (
     ConcatDataset,
     DistributedSampler,
     RandomSampler,
-    dataloader
+    dataloader,
 )
 
 from src.utils.augment import build_augmentor
@@ -29,10 +29,11 @@ from src.datasets.sampler import RandomConcatSampler
 
 
 class MultiSceneDataModule(pl.LightningDataModule):
-    """ 
+    """
     For distributed training, each training process is assigned
     only a part of the training scenes to reduce memory overhead.
     """
+
     def __init__(self, args, config):
         super().__init__()
 
@@ -60,47 +61,51 @@ class MultiSceneDataModule(pl.LightningDataModule):
 
         # 2. dataset config
         # general options
-        self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST  # 0.4, omit data with overlap_score < min_overlap_score
+        self.min_overlap_score_test = (
+            config.DATASET.MIN_OVERLAP_SCORE_TEST
+        )  # 0.4, omit data with overlap_score < min_overlap_score
         self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
-        self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE)  # None, options: [None, 'dark', 'mobile']
+        self.augment_fn = build_augmentor(
+            config.DATASET.AUGMENTATION_TYPE
+        )  # None, options: [None, 'dark', 'mobile']
 
         # MegaDepth options
         self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE  # 840
-        self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD   # True
-        self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD   # True
+        self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD  # True
+        self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD  # True
         self.mgdpt_df = config.DATASET.MGDPT_DF  # 8
         self.coarse_scale = 1 / config.MODEL.RESOLUTION[0]  # 0.125. for training loftr.
 
         # 3.loader parameters
         self.train_loader_params = {
-            'batch_size': args.batch_size,
-            'num_workers': args.num_workers,
-            'pin_memory': getattr(args, 'pin_memory', True)
+            "batch_size": args.batch_size,
+            "num_workers": args.num_workers,
+            "pin_memory": getattr(args, "pin_memory", True),
         }
         self.val_loader_params = {
-            'batch_size': 1,
-            'shuffle': False,
-            'num_workers': args.num_workers,
-            'pin_memory': getattr(args, 'pin_memory', True)
+            "batch_size": 1,
+            "shuffle": False,
+            "num_workers": args.num_workers,
+            "pin_memory": getattr(args, "pin_memory", True),
         }
         self.test_loader_params = {
-            'batch_size': 1,
-            'shuffle': False,
-            'num_workers': args.num_workers,
-            'pin_memory': True
+            "batch_size": 1,
+            "shuffle": False,
+            "num_workers": args.num_workers,
+            "pin_memory": True,
         }
-        
+
         # 4. sampler
         self.data_sampler = config.TRAINER.DATA_SAMPLER
         self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
         self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
         self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
         self.repeat = config.TRAINER.SB_REPEAT
-        
+
         # (optional) RandomSampler for debugging
 
         # misc configurations
-        self.parallel_load_data = getattr(args, 'parallel_load_data', False)
+        self.parallel_load_data = getattr(args, "parallel_load_data", False)
         self.seed = config.TRAINER.SEED  # 66
 
     def setup(self, stage=None):
@@ -110,7 +115,7 @@ class MultiSceneDataModule(pl.LightningDataModule):
             stage (str): 'fit' in training phase, and 'test' in testing phase.
         """
 
-        assert stage in ['fit', 'test'], "stage must be either fit or test"
+        assert stage in ["fit", "test"], "stage must be either fit or test"
 
         try:
             self.world_size = dist.get_world_size()
@@ -121,73 +126,94 @@ class MultiSceneDataModule(pl.LightningDataModule):
             self.rank = 0
             logger.warning(str(ae) + " (set world_size=1 and rank=0)")
 
-        if stage == 'fit':
+        if stage == "fit":
             self.train_dataset = self._setup_dataset(
                 self.train_data_root,
                 self.train_npz_root,
                 self.train_list_path,
                 self.train_intrinsic_path,
-                mode='train',
+                mode="train",
                 min_overlap_score=self.min_overlap_score_train,
-                pose_dir=self.train_pose_root)
+                pose_dir=self.train_pose_root,
+            )
             # setup multiple (optional) validation subsets
             if isinstance(self.val_list_path, (list, tuple)):
                 self.val_dataset = []
                 if not isinstance(self.val_npz_root, (list, tuple)):
-                    self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
+                    self.val_npz_root = [
+                        self.val_npz_root for _ in range(len(self.val_list_path))
+                    ]
                 for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
-                    self.val_dataset.append(self._setup_dataset(
-                        self.val_data_root,
-                        npz_root,
-                        npz_list,
-                        self.val_intrinsic_path,
-                        mode='val',
-                        min_overlap_score=self.min_overlap_score_test,
-                        pose_dir=self.val_pose_root))
+                    self.val_dataset.append(
+                        self._setup_dataset(
+                            self.val_data_root,
+                            npz_root,
+                            npz_list,
+                            self.val_intrinsic_path,
+                            mode="val",
+                            min_overlap_score=self.min_overlap_score_test,
+                            pose_dir=self.val_pose_root,
+                        )
+                    )
             else:
                 self.val_dataset = self._setup_dataset(
                     self.val_data_root,
                     self.val_npz_root,
                     self.val_list_path,
                     self.val_intrinsic_path,
-                    mode='val',
+                    mode="val",
                     min_overlap_score=self.min_overlap_score_test,
-                    pose_dir=self.val_pose_root)
-            logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
+                    pose_dir=self.val_pose_root,
+                )
+            logger.info(f"[rank:{self.rank}] Train & Val Dataset loaded!")
         else:  # stage == 'test'
             self.test_dataset = self._setup_dataset(
                 self.test_data_root,
                 self.test_npz_root,
                 self.test_list_path,
                 self.test_intrinsic_path,
-                mode='test',
+                mode="test",
                 min_overlap_score=self.min_overlap_score_test,
-                pose_dir=self.test_pose_root)
-            logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')
+                pose_dir=self.test_pose_root,
+            )
+            logger.info(f"[rank:{self.rank}]: Test Dataset loaded!")
 
-    def _setup_dataset(self,
-                       data_root,
-                       split_npz_root,
-                       scene_list_path,
-                       intri_path,
-                       mode='train',
-                       min_overlap_score=0.,
-                       pose_dir=None):
-        """ Setup train / val / test set"""
-        with open(scene_list_path, 'r') as f:
+    def _setup_dataset(
+        self,
+        data_root,
+        split_npz_root,
+        scene_list_path,
+        intri_path,
+        mode="train",
+        min_overlap_score=0.0,
+        pose_dir=None,
+    ):
+        """Setup train / val / test set"""
+        with open(scene_list_path, "r") as f:
             npz_names = [name.split()[0] for name in f.readlines()]
 
-        if mode == 'train':
-            local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed)
+        if mode == "train":
+            local_npz_names = get_local_split(
+                npz_names, self.world_size, self.rank, self.seed
+            )
         else:
             local_npz_names = npz_names
-        logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')
-        
-        dataset_builder = self._build_concat_dataset_parallel \
-                            if self.parallel_load_data \
-                            else self._build_concat_dataset
-        return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path,
-                                mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
+        logger.info(f"[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.")
+
+        dataset_builder = (
+            self._build_concat_dataset_parallel
+            if self.parallel_load_data
+            else self._build_concat_dataset
+        )
+        return dataset_builder(
+            data_root,
+            local_npz_names,
+            split_npz_root,
+            intri_path,
+            mode=mode,
+            min_overlap_score=min_overlap_score,
+            pose_dir=pose_dir,
+        )
 
     def _build_concat_dataset(
         self,
@@ -196,44 +222,56 @@ class MultiSceneDataModule(pl.LightningDataModule):
         npz_dir,
         intrinsic_path,
         mode,
-        min_overlap_score=0.,
-        pose_dir=None
+        min_overlap_score=0.0,
+        pose_dir=None,
     ):
         datasets = []
-        augment_fn = self.augment_fn if mode == 'train' else None
-        data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
-        if str(data_source).lower() == 'megadepth':
-            npz_names = [f'{n}.npz' for n in npz_names]
-        for npz_name in tqdm(npz_names,
-                             desc=f'[rank:{self.rank}] loading {mode} datasets',
-                             disable=int(self.rank) != 0):
+        augment_fn = self.augment_fn if mode == "train" else None
+        data_source = (
+            self.trainval_data_source
+            if mode in ["train", "val"]
+            else self.test_data_source
+        )
+        if str(data_source).lower() == "megadepth":
+            npz_names = [f"{n}.npz" for n in npz_names]
+        for npz_name in tqdm(
+            npz_names,
+            desc=f"[rank:{self.rank}] loading {mode} datasets",
+            disable=int(self.rank) != 0,
+        ):
             # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
             npz_path = osp.join(npz_dir, npz_name)
-            if data_source == 'ScanNet':
+            if data_source == "ScanNet":
                 datasets.append(
-                    ScanNetDataset(data_root,
-                                   npz_path,
-                                   intrinsic_path,
-                                   mode=mode,
-                                   min_overlap_score=min_overlap_score,
-                                   augment_fn=augment_fn,
-                                   pose_dir=pose_dir))
-            elif data_source == 'MegaDepth':
+                    ScanNetDataset(
+                        data_root,
+                        npz_path,
+                        intrinsic_path,
+                        mode=mode,
+                        min_overlap_score=min_overlap_score,
+                        augment_fn=augment_fn,
+                        pose_dir=pose_dir,
+                    )
+                )
+            elif data_source == "MegaDepth":
                 datasets.append(
-                    MegaDepthDataset(data_root,
-                                     npz_path,
-                                     mode=mode,
-                                     min_overlap_score=min_overlap_score,
-                                     img_resize=self.mgdpt_img_resize,
-                                     df=self.mgdpt_df,
-                                     img_padding=self.mgdpt_img_pad,
-                                     depth_padding=self.mgdpt_depth_pad,
-                                     augment_fn=augment_fn,
-                                     coarse_scale=self.coarse_scale))
+                    MegaDepthDataset(
+                        data_root,
+                        npz_path,
+                        mode=mode,
+                        min_overlap_score=min_overlap_score,
+                        img_resize=self.mgdpt_img_resize,
+                        df=self.mgdpt_df,
+                        img_padding=self.mgdpt_img_pad,
+                        depth_padding=self.mgdpt_depth_pad,
+                        augment_fn=augment_fn,
+                        coarse_scale=self.coarse_scale,
+                    )
+                )
             else:
                 raise NotImplementedError()
         return ConcatDataset(datasets)
-    
+
     def _build_concat_dataset_parallel(
         self,
         data_root,
@@ -241,77 +279,118 @@ class MultiSceneDataModule(pl.LightningDataModule):
         npz_dir,
         intrinsic_path,
         mode,
-        min_overlap_score=0.,
+        min_overlap_score=0.0,
         pose_dir=None,
     ):
-        augment_fn = self.augment_fn if mode == 'train' else None
-        data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
-        if str(data_source).lower() == 'megadepth':
-            npz_names = [f'{n}.npz' for n in npz_names]
-        with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets',
-                              total=len(npz_names), disable=int(self.rank) != 0)):
-            if data_source == 'ScanNet':
-                datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
-                    delayed(lambda x: _build_dataset(
-                        ScanNetDataset,
-                        data_root,
-                        osp.join(npz_dir, x),
-                        intrinsic_path,
-                        mode=mode,
-                        min_overlap_score=min_overlap_score,
-                        augment_fn=augment_fn,
-                        pose_dir=pose_dir))(name)
-                    for name in npz_names)
-            elif data_source == 'MegaDepth':
+        augment_fn = self.augment_fn if mode == "train" else None
+        data_source = (
+            self.trainval_data_source
+            if mode in ["train", "val"]
+            else self.test_data_source
+        )
+        if str(data_source).lower() == "megadepth":
+            npz_names = [f"{n}.npz" for n in npz_names]
+        with tqdm_joblib(
+            tqdm(
+                desc=f"[rank:{self.rank}] loading {mode} datasets",
+                total=len(npz_names),
+                disable=int(self.rank) != 0,
+            )
+        ):
+            if data_source == "ScanNet":
+                datasets = Parallel(
+                    n_jobs=math.floor(
+                        len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()
+                    )
+                )(
+                    delayed(
+                        lambda x: _build_dataset(
+                            ScanNetDataset,
+                            data_root,
+                            osp.join(npz_dir, x),
+                            intrinsic_path,
+                            mode=mode,
+                            min_overlap_score=min_overlap_score,
+                            augment_fn=augment_fn,
+                            pose_dir=pose_dir,
+                        )
+                    )(name)
+                    for name in npz_names
+                )
+            elif data_source == "MegaDepth":
                 # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
                 raise NotImplementedError()
-                datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
-                    delayed(lambda x: _build_dataset(
-                        MegaDepthDataset,
-                        data_root,
-                        osp.join(npz_dir, x),
-                        mode=mode,
-                        min_overlap_score=min_overlap_score,
-                        img_resize=self.mgdpt_img_resize,
-                        df=self.mgdpt_df,
-                        img_padding=self.mgdpt_img_pad,
-                        depth_padding=self.mgdpt_depth_pad,
-                        augment_fn=augment_fn,
-                        coarse_scale=self.coarse_scale))(name)
-                    for name in npz_names)
+                datasets = Parallel(
+                    n_jobs=math.floor(
+                        len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()
+                    )
+                )(
+                    delayed(
+                        lambda x: _build_dataset(
+                            MegaDepthDataset,
+                            data_root,
+                            osp.join(npz_dir, x),
+                            mode=mode,
+                            min_overlap_score=min_overlap_score,
+                            img_resize=self.mgdpt_img_resize,
+                            df=self.mgdpt_df,
+                            img_padding=self.mgdpt_img_pad,
+                            depth_padding=self.mgdpt_depth_pad,
+                            augment_fn=augment_fn,
+                            coarse_scale=self.coarse_scale,
+                        )
+                    )(name)
+                    for name in npz_names
+                )
             else:
-                raise ValueError(f'Unknown dataset: {data_source}')
+                raise ValueError(f"Unknown dataset: {data_source}")
         return ConcatDataset(datasets)
 
     def train_dataloader(self):
-        """ Build training dataloader for ScanNet / MegaDepth. """
-        assert self.data_sampler in ['scene_balance']
-        logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
-        if self.data_sampler == 'scene_balance':
-            sampler = RandomConcatSampler(self.train_dataset,
-                                          self.n_samples_per_subset,
-                                          self.subset_replacement,
-                                          self.shuffle, self.repeat, self.seed)
+        """Build training dataloader for ScanNet / MegaDepth."""
+        assert self.data_sampler in ["scene_balance"]
+        logger.info(
+            f"[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!)."
+        )
+        if self.data_sampler == "scene_balance":
+            sampler = RandomConcatSampler(
+                self.train_dataset,
+                self.n_samples_per_subset,
+                self.subset_replacement,
+                self.shuffle,
+                self.repeat,
+                self.seed,
+            )
         else:
             sampler = None
-        dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
+        dataloader = DataLoader(
+            self.train_dataset, sampler=sampler, **self.train_loader_params
+        )
         return dataloader
-    
+
     def val_dataloader(self):
-        """ Build validation dataloader for ScanNet / MegaDepth. """
-        logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
+        """Build validation dataloader for ScanNet / MegaDepth."""
+        logger.info(
+            f"[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init."
+        )
         if not isinstance(self.val_dataset, abc.Sequence):
             sampler = DistributedSampler(self.val_dataset, shuffle=False)
-            return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
+            return DataLoader(
+                self.val_dataset, sampler=sampler, **self.val_loader_params
+            )
         else:
             dataloaders = []
             for dataset in self.val_dataset:
                 sampler = DistributedSampler(dataset, shuffle=False)
-                dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
+                dataloaders.append(
+                    DataLoader(dataset, sampler=sampler, **self.val_loader_params)
+                )
             return dataloaders
 
     def test_dataloader(self, *args, **kwargs):
-        logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
+        logger.info(
+            f"[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init."
+        )
         sampler = DistributedSampler(self.test_dataset, shuffle=False)
         return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
 
diff --git a/third_party/TopicFM/src/lightning_trainer/trainer.py b/third_party/TopicFM/src/lightning_trainer/trainer.py
index acf51f66130be66b7d3294ca5c081a2df3856d96..cce4839b536eba974426309eca10415547479f50 100644
--- a/third_party/TopicFM/src/lightning_trainer/trainer.py
+++ b/third_party/TopicFM/src/lightning_trainer/trainer.py
@@ -1,4 +1,3 @@
-
 from collections import defaultdict
 import pprint
 from loguru import logger
@@ -10,13 +9,16 @@ import pytorch_lightning as pl
 from matplotlib import pyplot as plt
 
 from src.models import TopicFM
-from src.models.utils.supervision import compute_supervision_coarse, compute_supervision_fine
+from src.models.utils.supervision import (
+    compute_supervision_coarse,
+    compute_supervision_fine,
+)
 from src.losses.loss import TopicFMLoss
 from src.optimizers import build_optimizer, build_scheduler
 from src.utils.metrics import (
     compute_symmetrical_epipolar_errors,
     compute_pose_errors,
-    aggregate_metrics
+    aggregate_metrics,
 )
 from src.utils.plotting import make_matching_figures
 from src.utils.comm import gather, all_gather
@@ -34,168 +36,225 @@ class PL_Trainer(pl.LightningModule):
         # Misc
         self.config = config  # full config
         _config = lower_config(self.config)
-        self.model_cfg = lower_config(_config['model'])
+        self.model_cfg = lower_config(_config["model"])
         self.profiler = profiler or PassThroughProfiler()
-        self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
+        self.n_vals_plot = max(
+            config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1
+        )
 
         # Matcher: TopicFM
-        self.matcher = TopicFM(config=_config['model'])
+        self.matcher = TopicFM(config=_config["model"])
         self.loss = TopicFMLoss(_config)
 
         # Pretrained weights
         if pretrained_ckpt:
-            state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
+            state_dict = torch.load(pretrained_ckpt, map_location="cpu")["state_dict"]
             self.matcher.load_state_dict(state_dict, strict=True)
-            logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
-        
+            logger.info(f"Load '{pretrained_ckpt}' as pretrained checkpoint")
+
         # Testing
         self.dump_dir = dump_dir
-        
+
     def configure_optimizers(self):
         # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
         optimizer = build_optimizer(self, self.config)
         scheduler = build_scheduler(self.config, optimizer)
         return [optimizer], [scheduler]
-    
+
     def optimizer_step(
-            self, epoch, batch_idx, optimizer, optimizer_idx,
-            optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+        self,
+        epoch,
+        batch_idx,
+        optimizer,
+        optimizer_idx,
+        optimizer_closure,
+        on_tpu,
+        using_native_amp,
+        using_lbfgs,
+    ):
         # learning rate warm up
         warmup_step = self.config.TRAINER.WARMUP_STEP
         if self.trainer.global_step < warmup_step:
-            if self.config.TRAINER.WARMUP_TYPE == 'linear':
+            if self.config.TRAINER.WARMUP_TYPE == "linear":
                 base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
-                lr = base_lr + \
-                    (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
-                    abs(self.config.TRAINER.TRUE_LR - base_lr)
+                lr = base_lr + (
+                    self.trainer.global_step / self.config.TRAINER.WARMUP_STEP
+                ) * abs(self.config.TRAINER.TRUE_LR - base_lr)
                 for pg in optimizer.param_groups:
-                    pg['lr'] = lr
-            elif self.config.TRAINER.WARMUP_TYPE == 'constant':
+                    pg["lr"] = lr
+            elif self.config.TRAINER.WARMUP_TYPE == "constant":
                 pass
             else:
-                raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
+                raise ValueError(
+                    f"Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}"
+                )
 
         # update params
         optimizer.step(closure=optimizer_closure)
         optimizer.zero_grad()
-    
+
     def _trainval_inference(self, batch):
         with self.profiler.profile("Compute coarse supervision"):
             compute_supervision_coarse(batch, self.config)
-        
+
         with self.profiler.profile("TopicFM"):
             self.matcher(batch)
-        
+
         with self.profiler.profile("Compute fine supervision"):
             compute_supervision_fine(batch, self.config)
-            
+
         with self.profiler.profile("Compute losses"):
             self.loss(batch)
-    
+
     def _compute_metrics(self, batch):
         with self.profiler.profile("Copmute metrics"):
-            compute_symmetrical_epipolar_errors(batch)  # compute epi_errs for each match
-            compute_pose_errors(batch, self.config)  # compute R_errs, t_errs, pose_errs for each pair
+            compute_symmetrical_epipolar_errors(
+                batch
+            )  # compute epi_errs for each match
+            compute_pose_errors(
+                batch, self.config
+            )  # compute R_errs, t_errs, pose_errs for each pair
 
-            rel_pair_names = list(zip(*batch['pair_names']))
-            bs = batch['image0'].size(0)
+            rel_pair_names = list(zip(*batch["pair_names"]))
+            bs = batch["image0"].size(0)
             metrics = {
                 # to filter duplicate pairs caused by DistributedSampler
-                'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
-                'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
-                'R_errs': batch['R_errs'],
-                't_errs': batch['t_errs'],
-                'inliers': batch['inliers']}
-            ret_dict = {'metrics': metrics}
+                "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)],
+                "epi_errs": [
+                    batch["epi_errs"][batch["m_bids"] == b].cpu().numpy()
+                    for b in range(bs)
+                ],
+                "R_errs": batch["R_errs"],
+                "t_errs": batch["t_errs"],
+                "inliers": batch["inliers"],
+            }
+            ret_dict = {"metrics": metrics}
         return ret_dict, rel_pair_names
-    
+
     def training_step(self, batch, batch_idx):
         self._trainval_inference(batch)
-        
+
         # logging
-        if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
+        if (
+            self.trainer.global_rank == 0
+            and self.global_step % self.trainer.log_every_n_steps == 0
+        ):
             # scalars
-            for k, v in batch['loss_scalars'].items():
-                self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step)
+            for k, v in batch["loss_scalars"].items():
+                self.logger.experiment.add_scalar(f"train/{k}", v, self.global_step)
 
             # figures
             if self.config.TRAINER.ENABLE_PLOTTING:
-                compute_symmetrical_epipolar_errors(batch)  # compute epi_errs for each match
-                figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
+                compute_symmetrical_epipolar_errors(
+                    batch
+                )  # compute epi_errs for each match
+                figures = make_matching_figures(
+                    batch, self.config, self.config.TRAINER.PLOT_MODE
+                )
                 for k, v in figures.items():
-                    self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step)
+                    self.logger.experiment.add_figure(
+                        f"train_match/{k}", v, self.global_step
+                    )
 
-        return {'loss': batch['loss']}
+        return {"loss": batch["loss"]}
 
     def training_epoch_end(self, outputs):
-        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
+        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
         if self.trainer.global_rank == 0:
             self.logger.experiment.add_scalar(
-                'train/avg_loss_on_epoch', avg_loss,
-                global_step=self.current_epoch)
-    
+                "train/avg_loss_on_epoch", avg_loss, global_step=self.current_epoch
+            )
+
     def validation_step(self, batch, batch_idx):
         self._trainval_inference(batch)
-        
+
         ret_dict, _ = self._compute_metrics(batch)
-        
+
         val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
         figures = {self.config.TRAINER.PLOT_MODE: []}
         if batch_idx % val_plot_interval == 0:
-            figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE)
+            figures = make_matching_figures(
+                batch, self.config, mode=self.config.TRAINER.PLOT_MODE
+            )
 
         return {
             **ret_dict,
-            'loss_scalars': batch['loss_scalars'],
-            'figures': figures,
+            "loss_scalars": batch["loss_scalars"],
+            "figures": figures,
         }
-        
+
     def validation_epoch_end(self, outputs):
         # handle multiple validation sets
-        multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
+        multi_outputs = (
+            [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
+        )
         multi_val_metrics = defaultdict(list)
-        
+
         for valset_idx, outputs in enumerate(multi_outputs):
             # since pl performs sanity_check at the very begining of the training
             cur_epoch = self.trainer.current_epoch
-            if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
+            if (
+                not self.trainer.resume_from_checkpoint
+                and self.trainer.running_sanity_check
+            ):
                 cur_epoch = -1
 
             # 1. loss_scalars: dict of list, on cpu
-            _loss_scalars = [o['loss_scalars'] for o in outputs]
-            loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
+            _loss_scalars = [o["loss_scalars"] for o in outputs]
+            loss_scalars = {
+                k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars]))
+                for k in _loss_scalars[0]
+            }
 
             # 2. val metrics: dict of list, numpy
-            _metrics = [o['metrics'] for o in outputs]
-            metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
-            # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 
-            val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
+            _metrics = [o["metrics"] for o in outputs]
+            metrics = {
+                k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics])))
+                for k in _metrics[0]
+            }
+            # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0
+            val_metrics_4tb = aggregate_metrics(
+                metrics, self.config.TRAINER.EPI_ERR_THR
+            )
             for thr in [5, 10, 20]:
-                multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
-            
+                multi_val_metrics[f"auc@{thr}"].append(val_metrics_4tb[f"auc@{thr}"])
+
             # 3. figures
-            _figures = [o['figures'] for o in outputs]
-            figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
+            _figures = [o["figures"] for o in outputs]
+            figures = {
+                k: flattenList(gather(flattenList([_me[k] for _me in _figures])))
+                for k in _figures[0]
+            }
 
             # tensorboard records only on rank 0
             if self.trainer.global_rank == 0:
                 for k, v in loss_scalars.items():
                     mean_v = torch.stack(v).mean()
-                    self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
+                    self.logger.experiment.add_scalar(
+                        f"val_{valset_idx}/avg_{k}", mean_v, global_step=cur_epoch
+                    )
 
                 for k, v in val_metrics_4tb.items():
-                    self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
-                
+                    self.logger.experiment.add_scalar(
+                        f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch
+                    )
+
                 for k, v in figures.items():
                     if self.trainer.global_rank == 0:
                         for plot_idx, fig in enumerate(v):
                             self.logger.experiment.add_figure(
-                                f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
-            plt.close('all')
+                                f"val_match_{valset_idx}/{k}/pair-{plot_idx}",
+                                fig,
+                                cur_epoch,
+                                close=True,
+                            )
+            plt.close("all")
 
         for thr in [5, 10, 20]:
             # log on all ranks for ModelCheckpoint callback to work properly
-            self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}'])))  # ckpt monitors on this
+            self.log(
+                f"auc@{thr}", torch.tensor(np.mean(multi_val_metrics[f"auc@{thr}"]))
+            )  # ckpt monitors on this
 
     def test_step(self, batch, batch_idx):
         with self.profiler.profile("TopicFM"):
@@ -206,39 +265,46 @@ class PL_Trainer(pl.LightningModule):
         with self.profiler.profile("dump_results"):
             if self.dump_dir is not None:
                 # dump results for further analysis
-                keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'}
-                pair_names = list(zip(*batch['pair_names']))
-                bs = batch['image0'].shape[0]
+                keys_to_save = {"mkpts0_f", "mkpts1_f", "mconf", "epi_errs"}
+                pair_names = list(zip(*batch["pair_names"]))
+                bs = batch["image0"].shape[0]
                 dumps = []
                 for b_id in range(bs):
                     item = {}
-                    mask = batch['m_bids'] == b_id
-                    item['pair_names'] = pair_names[b_id]
-                    item['identifier'] = '#'.join(rel_pair_names[b_id])
+                    mask = batch["m_bids"] == b_id
+                    item["pair_names"] = pair_names[b_id]
+                    item["identifier"] = "#".join(rel_pair_names[b_id])
                     for key in keys_to_save:
                         item[key] = batch[key][mask].cpu().numpy()
-                    for key in ['R_errs', 't_errs', 'inliers']:
+                    for key in ["R_errs", "t_errs", "inliers"]:
                         item[key] = batch[key][b_id]
                     dumps.append(item)
-                ret_dict['dumps'] = dumps
+                ret_dict["dumps"] = dumps
 
         return ret_dict
 
     def test_epoch_end(self, outputs):
         # metrics: dict of list, numpy
-        _metrics = [o['metrics'] for o in outputs]
-        metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+        _metrics = [o["metrics"] for o in outputs]
+        metrics = {
+            k: flattenList(gather(flattenList([_me[k] for _me in _metrics])))
+            for k in _metrics[0]
+        }
 
         # [{key: [{...}, *#bs]}, *#batch]
         if self.dump_dir is not None:
             Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
-            _dumps = flattenList([o['dumps'] for o in outputs])  # [{...}, #bs*#batch]
+            _dumps = flattenList([o["dumps"] for o in outputs])  # [{...}, #bs*#batch]
             dumps = flattenList(gather(_dumps))  # [{...}, #proc*#bs*#batch]
-            logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')
+            logger.info(
+                f"Prediction and evaluation results will be saved to: {self.dump_dir}"
+            )
 
         if self.trainer.global_rank == 0:
             print(self.profiler.summary())
-            val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
-            logger.info('\n' + pprint.pformat(val_metrics_4tb))
+            val_metrics_4tb = aggregate_metrics(
+                metrics, self.config.TRAINER.EPI_ERR_THR
+            )
+            logger.info("\n" + pprint.pformat(val_metrics_4tb))
             if self.dump_dir is not None:
-                np.save(Path(self.dump_dir) / 'TopicFM_pred_eval', dumps)
+                np.save(Path(self.dump_dir) / "TopicFM_pred_eval", dumps)
diff --git a/third_party/TopicFM/src/losses/loss.py b/third_party/TopicFM/src/losses/loss.py
index 4be58498579c9fe649ed0ce2d42f230e59cef581..e386bb557285a290962477179e9a3a36b665368f 100644
--- a/third_party/TopicFM/src/losses/loss.py
+++ b/third_party/TopicFM/src/losses/loss.py
@@ -13,10 +13,10 @@ def sample_non_matches(pos_mask, match_ids=None, sampling_ratio=10):
             return ~pos_mask
 
         neg_mask = torch.zeros_like(pos_mask)
-        probs = torch.ones((HW - 1)//3, device=pos_mask.device)
+        probs = torch.ones((HW - 1) // 3, device=pos_mask.device)
         for _ in range(sampling_ratio):
             d = torch.multinomial(probs, len(j_ids), replacement=True)
-            sampled_j_ids = (j_ids + d*3 + 1) % HW
+            sampled_j_ids = (j_ids + d * 3 + 1) % HW
             neg_mask[b_ids, i_ids, sampled_j_ids] = True
         # neg_mask = neg_matrix == 1
     else:
@@ -29,18 +29,20 @@ class TopicFMLoss(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config  # config under the global namespace
-        self.loss_config = config['model']['loss']
-        self.match_type = self.config['model']['match_coarse']['match_type']
-        
+        self.loss_config = config["model"]["loss"]
+        self.match_type = self.config["model"]["match_coarse"]["match_type"]
+
         # coarse-level
-        self.correct_thr = self.loss_config['fine_correct_thr']
-        self.c_pos_w = self.loss_config['pos_weight']
-        self.c_neg_w = self.loss_config['neg_weight']
+        self.correct_thr = self.loss_config["fine_correct_thr"]
+        self.c_pos_w = self.loss_config["pos_weight"]
+        self.c_neg_w = self.loss_config["neg_weight"]
         # fine-level
-        self.fine_type = self.loss_config['fine_type']
+        self.fine_type = self.loss_config["fine_type"]
 
-    def compute_coarse_loss(self, conf, topic_mat, conf_gt, match_ids=None, weight=None):
-        """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
+    def compute_coarse_loss(
+        self, conf, topic_mat, conf_gt, match_ids=None, weight=None
+    ):
+        """Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
         Args:
             conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
             conf_gt (torch.Tensor): (N, HW0, HW1)
@@ -53,30 +55,30 @@ class TopicFMLoss(nn.Module):
         if not pos_mask.any():  # assign a wrong gt
             pos_mask[0, 0, 0] = True
             if weight is not None:
-                weight[0, 0, 0] = 0.
-            c_pos_w = 0.
+                weight[0, 0, 0] = 0.0
+            c_pos_w = 0.0
         if not neg_mask.any():
             neg_mask[0, 0, 0] = True
             if weight is not None:
-                weight[0, 0, 0] = 0.
-            c_neg_w = 0.
+                weight[0, 0, 0] = 0.0
+            c_neg_w = 0.0
 
         conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
-        alpha = self.loss_config['focal_alpha']
+        alpha = self.loss_config["focal_alpha"]
 
         loss = 0.0
         if isinstance(topic_mat, torch.Tensor):
             pos_topic = topic_mat[pos_mask]
-            loss_pos_topic = - alpha * (pos_topic + 1e-6).log()
+            loss_pos_topic = -alpha * (pos_topic + 1e-6).log()
             neg_topic = topic_mat[neg_mask]
-            loss_neg_topic = - alpha * (1 - neg_topic + 1e-6).log()
+            loss_neg_topic = -alpha * (1 - neg_topic + 1e-6).log()
             if weight is not None:
                 loss_pos_topic = loss_pos_topic * weight[pos_mask]
                 loss_neg_topic = loss_neg_topic * weight[neg_mask]
             loss = loss_pos_topic.mean() + loss_neg_topic.mean()
 
         pos_conf = conf[pos_mask]
-        loss_pos = - alpha * pos_conf.log()
+        loss_pos = -alpha * pos_conf.log()
         # handle loss weights
         if weight is not None:
             # Different from dense-spvs, the loss w.r.t. padded regions aren't directly zeroed out,
@@ -86,11 +88,11 @@ class TopicFMLoss(nn.Module):
         loss = loss + c_pos_w * loss_pos.mean()
 
         return loss
-        
+
     def compute_fine_loss(self, expec_f, expec_f_gt):
-        if self.fine_type == 'l2_with_std':
+        if self.fine_type == "l2_with_std":
             return self._compute_fine_loss_l2_std(expec_f, expec_f_gt)
-        elif self.fine_type == 'l2':
+        elif self.fine_type == "l2":
             return self._compute_fine_loss_l2(expec_f, expec_f_gt)
         else:
             raise NotImplementedError()
@@ -101,9 +103,13 @@ class TopicFMLoss(nn.Module):
             expec_f (torch.Tensor): [M, 2] <x, y>
             expec_f_gt (torch.Tensor): [M, 2] <x, y>
         """
-        correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
+        correct_mask = (
+            torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
+        )
         if correct_mask.sum() == 0:
-            if self.training:  # this seldomly happen when training, since we pad prediction with gt
+            if (
+                self.training
+            ):  # this seldom happens when training, since we pad predictions with gt
                 logger.warning("assign a false supervision to avoid ddp deadlock")
                 correct_mask[0] = True
             else:
@@ -118,34 +124,45 @@ class TopicFMLoss(nn.Module):
             expec_f_gt (torch.Tensor): [M, 2] <x, y>
         """
         # correct_mask tells you which pair to compute fine-loss
-        correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
+        correct_mask = (
+            torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
+        )
 
         # use std as weight that measures uncertainty
         std = expec_f[:, 2]
-        inverse_std = 1. / torch.clamp(std, min=1e-10)
-        weight = (inverse_std / torch.mean(inverse_std)).detach()  # avoid minizing loss through increase std
+        inverse_std = 1.0 / torch.clamp(std, min=1e-10)
+        weight = (
+            inverse_std / torch.mean(inverse_std)
+        ).detach()  # avoid minimizing the loss by increasing std
 
         # corner case: no correct coarse match found
         if not correct_mask.any():
-            if self.training:  # this seldomly happen during training, since we pad prediction with gt
-                               # sometimes there is not coarse-level gt at all.
+            if (
+                self.training
+            ):  # this seldom happens during training, since we pad predictions with gt
+                # sometimes there is no coarse-level gt at all.
                 logger.warning("assign a false supervision to avoid ddp deadlock")
                 correct_mask[0] = True
-                weight[0] = 0.
+                weight[0] = 0.0
             else:
                 return None
 
         # l2 loss with std
-        offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask, :2]) ** 2).sum(-1)
+        offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask, :2]) ** 2).sum(
+            -1
+        )
         loss = (offset_l2 * weight[correct_mask]).mean()
 
         return loss
-    
+
     @torch.no_grad()
     def compute_c_weight(self, data):
-        """ compute element-wise weights for computing coarse-level loss. """
-        if 'mask0' in data:
-            c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float()
+        """compute element-wise weights for computing coarse-level loss."""
+        if "mask0" in data:
+            c_weight = (
+                data["mask0"].flatten(-2)[..., None]
+                * data["mask1"].flatten(-2)[:, None]
+            ).float()
         else:
             c_weight = None
         return c_weight
@@ -163,20 +180,24 @@ class TopicFMLoss(nn.Module):
         c_weight = self.compute_c_weight(data)
 
         # 1. coarse-level loss
-        loss_c = self.compute_coarse_loss(data['conf_matrix'], data['topic_matrix'],
-            data['conf_matrix_gt'], match_ids=(data['spv_b_ids'], data['spv_i_ids'], data['spv_j_ids']),
-            weight=c_weight)
-        loss = loss_c * self.loss_config['coarse_weight']
+        loss_c = self.compute_coarse_loss(
+            data["conf_matrix"],
+            data["topic_matrix"],
+            data["conf_matrix_gt"],
+            match_ids=(data["spv_b_ids"], data["spv_i_ids"], data["spv_j_ids"]),
+            weight=c_weight,
+        )
+        loss = loss_c * self.loss_config["coarse_weight"]
         loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
 
         # 2. fine-level loss
-        loss_f = self.compute_fine_loss(data['expec_f'], data['expec_f_gt'])
+        loss_f = self.compute_fine_loss(data["expec_f"], data["expec_f_gt"])
         if loss_f is not None:
-            loss += loss_f * self.loss_config['fine_weight']
-            loss_scalars.update({"loss_f":  loss_f.clone().detach().cpu()})
+            loss += loss_f * self.loss_config["fine_weight"]
+            loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()})
         else:
             assert self.training is False
-            loss_scalars.update({'loss_f': torch.tensor(1.)})  # 1 is the upper bound
+            loss_scalars.update({"loss_f": torch.tensor(1.0)})  # 1 is the upper bound
 
-        loss_scalars.update({'loss': loss.clone().detach().cpu()})
+        loss_scalars.update({"loss": loss.clone().detach().cpu()})
         data.update({"loss": loss, "loss_scalars": loss_scalars})
diff --git a/third_party/TopicFM/src/models/backbone/__init__.py b/third_party/TopicFM/src/models/backbone/__init__.py
index 53f98db4e910b46716bed7cfc6ebbf8c8bfad399..72a80de20ba3f6bc02454f4930b25d6b18f4b34f 100644
--- a/third_party/TopicFM/src/models/backbone/__init__.py
+++ b/third_party/TopicFM/src/models/backbone/__init__.py
@@ -2,4 +2,4 @@ from .fpn import FPN
 
 
 def build_backbone(config):
-    return FPN(config['fpn'])
+    return FPN(config["fpn"])
diff --git a/third_party/TopicFM/src/models/backbone/fpn.py b/third_party/TopicFM/src/models/backbone/fpn.py
index 93cc475f57317f9dbb8132cdfe0297391972f9e2..7f38ec13f196793a00cacbaaa3eb7c0a5d8e9605 100644
--- a/third_party/TopicFM/src/models/backbone/fpn.py
+++ b/third_party/TopicFM/src/models/backbone/fpn.py
@@ -4,12 +4,16 @@ import torch.nn.functional as F
 
 def conv1x1(in_planes, out_planes, stride=1):
     """1x1 convolution without padding"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
+    return nn.Conv2d(
+        in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False
+    )
 
 
 def conv3x3(in_planes, out_planes, stride=1):
     """3x3 convolution with padding"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+    return nn.Conv2d(
+        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+    )
 
 
 class ConvBlock(nn.Module):
@@ -22,7 +26,7 @@ class ConvBlock(nn.Module):
     def forward(self, x):
         y = self.conv(x)
         if self.bn:
-            y = self.bn(y) #F.layer_norm(y, y.shape[1:])
+            y = self.bn(y)  # F.layer_norm(y, y.shape[1:])
         y = self.act(y)
         return y
 
@@ -37,14 +41,16 @@ class FPN(nn.Module):
         super().__init__()
         # Config
         block = ConvBlock
-        initial_dim = config['initial_dim']
-        block_dims = config['block_dims']
+        initial_dim = config["initial_dim"]
+        block_dims = config["block_dims"]
 
         # Class Variable
         self.in_planes = initial_dim
 
         # Networks
-        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+        self.conv1 = nn.Conv2d(
+            1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False
+        )
         self.bn1 = nn.BatchNorm2d(initial_dim)
         self.relu = nn.ReLU(inplace=True)
 
@@ -72,7 +78,7 @@ class FPN(nn.Module):
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
@@ -94,16 +100,22 @@ class FPN(nn.Module):
         x4 = self.layer4(x3)  # 1/16
 
         # FPN
-        x4_out_2x = F.interpolate(x4, scale_factor=2., mode='bilinear', align_corners=True)
+        x4_out_2x = F.interpolate(
+            x4, scale_factor=2.0, mode="bilinear", align_corners=True
+        )
         x3_out = self.layer3_outconv(x3)
-        x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
+        x3_out = self.layer3_outconv2(x3_out + x4_out_2x)
 
-        x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+        x3_out_2x = F.interpolate(
+            x3_out, scale_factor=2.0, mode="bilinear", align_corners=True
+        )
         x2_out = self.layer2_outconv(x2)
-        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+        x2_out = self.layer2_outconv2(x2_out + x3_out_2x)
 
-        x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
+        x2_out_2x = F.interpolate(
+            x2_out, scale_factor=2.0, mode="bilinear", align_corners=True
+        )
         x1_out = self.layer1_outconv(x1)
-        x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+        x1_out = self.layer1_outconv2(x1_out + x2_out_2x)
 
         return [x3_out, x1_out]
diff --git a/third_party/TopicFM/src/models/modules/fine_preprocess.py b/third_party/TopicFM/src/models/modules/fine_preprocess.py
index 4c8d264c1895be8f4e124fc3982d4e0d3b876af3..4cdce2d327ebc88371769946a292824f834729a5 100644
--- a/third_party/TopicFM/src/models/modules/fine_preprocess.py
+++ b/third_party/TopicFM/src/models/modules/fine_preprocess.py
@@ -9,15 +9,15 @@ class FinePreprocess(nn.Module):
         super().__init__()
 
         self.config = config
-        self.cat_c_feat = config['fine_concat_coarse_feat']
-        self.W = self.config['fine_window_size']
+        self.cat_c_feat = config["fine_concat_coarse_feat"]
+        self.W = self.config["fine_window_size"]
 
-        d_model_c = self.config['coarse']['d_model']
-        d_model_f = self.config['fine']['d_model']
+        d_model_c = self.config["coarse"]["d_model"]
+        d_model_f = self.config["fine"]["d_model"]
         self.d_model_f = d_model_f
         if self.cat_c_feat:
             self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
-            self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
+            self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True)
 
         self._reset_parameters()
 
@@ -28,32 +28,48 @@ class FinePreprocess(nn.Module):
 
     def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
         W = self.W
-        stride = data['hw0_f'][0] // data['hw0_c'][0]
+        stride = data["hw0_f"][0] // data["hw0_c"][0]
 
-        data.update({'W': W})
-        if data['b_ids'].shape[0] == 0:
+        data.update({"W": W})
+        if data["b_ids"].shape[0] == 0:
             feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
             feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
             return feat0, feat1
 
         # 1. unfold(crop) all local windows
-        feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
-        feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
-        feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
-        feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+        feat_f0_unfold = F.unfold(
+            feat_f0, kernel_size=(W, W), stride=stride, padding=W // 2
+        )
+        feat_f0_unfold = rearrange(feat_f0_unfold, "n (c ww) l -> n l ww c", ww=W**2)
+        feat_f1_unfold = F.unfold(
+            feat_f1, kernel_size=(W, W), stride=stride, padding=W // 2
+        )
+        feat_f1_unfold = rearrange(feat_f1_unfold, "n (c ww) l -> n l ww c", ww=W**2)
 
         # 2. select only the predicted matches
-        feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']]  # [n, ww, cf]
-        feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
+        feat_f0_unfold = feat_f0_unfold[data["b_ids"], data["i_ids"]]  # [n, ww, cf]
+        feat_f1_unfold = feat_f1_unfold[data["b_ids"], data["j_ids"]]
 
         # option: use coarse-level feature as context: concat and linear
         if self.cat_c_feat:
-            feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
-                                                   feat_c1[data['b_ids'], data['j_ids']]], 0))  # [2n, c]
-            feat_cf_win = self.merge_feat(torch.cat([
-                torch.cat([feat_f0_unfold, feat_f1_unfold], 0),  # [2n, ww, cf]
-                repeat(feat_c_win, 'n c -> n ww c', ww=W**2),  # [2n, ww, cf]
-            ], -1))
+            feat_c_win = self.down_proj(
+                torch.cat(
+                    [
+                        feat_c0[data["b_ids"], data["i_ids"]],
+                        feat_c1[data["b_ids"], data["j_ids"]],
+                    ],
+                    0,
+                )
+            )  # [2n, c]
+            feat_cf_win = self.merge_feat(
+                torch.cat(
+                    [
+                        torch.cat([feat_f0_unfold, feat_f1_unfold], 0),  # [2n, ww, cf]
+                        repeat(feat_c_win, "n c -> n ww c", ww=W**2),  # [2n, ww, cf]
+                    ],
+                    -1,
+                )
+            )
             feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
 
         return feat_f0_unfold, feat_f1_unfold
diff --git a/third_party/TopicFM/src/models/modules/linear_attention.py b/third_party/TopicFM/src/models/modules/linear_attention.py
index af6cd825033e98b7be15cc694ce28110ef84cc93..57b86b3ba682da62f9ff65893aa0ccd6753d32af 100644
--- a/third_party/TopicFM/src/models/modules/linear_attention.py
+++ b/third_party/TopicFM/src/models/modules/linear_attention.py
@@ -18,7 +18,7 @@ class LinearAttention(Module):
         self.eps = eps
 
     def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
-        """ Multi-Head linear attention proposed in "Transformers are RNNs"
+        """Multi-Head linear attention proposed in "Transformers are RNNs"
         Args:
             queries: [N, L, H, D]
             keys: [N, S, H, D]
@@ -54,7 +54,7 @@ class FullAttention(Module):
         self.dropout = Dropout(attention_dropout)
 
     def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
-        """ Multi-head scaled dot-product attention, a.k.a full attention.
+        """Multi-head scaled dot-product attention, a.k.a full attention.
         Args:
             queries: [N, L, H, D]
             keys: [N, S, H, D]
@@ -68,10 +68,12 @@ class FullAttention(Module):
         # Compute the unnormalized attention and apply the masks
         QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
         if kv_mask is not None:
-            QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]).bool(), -1e9)
+            QK.masked_fill_(
+                ~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]).bool(), -1e9
+            )
 
         # Compute the attention and the weighted average
-        softmax_temp = 1. / queries.size(3)**.5  # sqrt(D)
+        softmax_temp = 1.0 / queries.size(3) ** 0.5  # sqrt(D)
         A = torch.softmax(softmax_temp * QK, dim=2)
         if self.use_dropout:
             A = self.dropout(A)
diff --git a/third_party/TopicFM/src/models/modules/transformer.py b/third_party/TopicFM/src/models/modules/transformer.py
index 27ff8f6554844b1e14a7094fcbad40876f766db8..cef17ca689cd0f844c1d6bd6c0f987a3e0c3be59 100644
--- a/third_party/TopicFM/src/models/modules/transformer.py
+++ b/third_party/TopicFM/src/models/modules/transformer.py
@@ -8,10 +8,7 @@ from .linear_attention import LinearAttention, FullAttention
 
 
 class LoFTREncoderLayer(nn.Module):
-    def __init__(self,
-                 d_model,
-                 nhead,
-                 attention='linear'):
+    def __init__(self, d_model, nhead, attention="linear"):
         super(LoFTREncoderLayer, self).__init__()
 
         self.dim = d_model // nhead
@@ -21,14 +18,14 @@ class LoFTREncoderLayer(nn.Module):
         self.q_proj = nn.Linear(d_model, d_model, bias=False)
         self.k_proj = nn.Linear(d_model, d_model, bias=False)
         self.v_proj = nn.Linear(d_model, d_model, bias=False)
-        self.attention = LinearAttention() if attention == 'linear' else FullAttention()
+        self.attention = LinearAttention() if attention == "linear" else FullAttention()
         self.merge = nn.Linear(d_model, d_model, bias=False)
 
         # feed-forward network
         self.mlp = nn.Sequential(
-            nn.Linear(d_model*2, d_model*2, bias=False),
+            nn.Linear(d_model * 2, d_model * 2, bias=False),
             nn.GELU(),
-            nn.Linear(d_model*2, d_model, bias=False),
+            nn.Linear(d_model * 2, d_model, bias=False),
         )
 
         # norm and dropout
@@ -50,8 +47,10 @@ class LoFTREncoderLayer(nn.Module):
         query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
         key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
         value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
-        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
-        message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]
+        message = self.attention(
+            query, key, value, q_mask=x_mask, kv_mask=source_mask
+        )  # [N, L, (H, D)]
+        message = self.merge(message.view(bs, -1, self.nhead * self.dim))  # [N, L, C]
         message = self.norm1(message)
 
         # feed-forward network
@@ -68,18 +67,33 @@ class TopicFormer(nn.Module):
         super(TopicFormer, self).__init__()
 
         self.config = config
-        self.d_model = config['d_model']
-        self.nhead = config['nhead']
-        self.layer_names = config['layer_names']
-        encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
-        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
-
-        self.topic_transformers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2*config['n_topic_transformers'])]) if config['n_samples'] > 0 else None #nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2)])
-        self.n_iter_topic_transformer = config['n_topic_transformers']
+        self.d_model = config["d_model"]
+        self.nhead = config["nhead"]
+        self.layer_names = config["layer_names"]
+        encoder_layer = LoFTREncoderLayer(
+            config["d_model"], config["nhead"], config["attention"]
+        )
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]
+        )
 
-        self.seed_tokens = nn.Parameter(torch.randn(config['n_topics'], config['d_model']))
-        self.register_parameter('seed_tokens', self.seed_tokens)
-        self.n_samples = config['n_samples']
+        self.topic_transformers = (
+            nn.ModuleList(
+                [
+                    copy.deepcopy(encoder_layer)
+                    for _ in range(2 * config["n_topic_transformers"])
+                ]
+            )
+            if config["n_samples"] > 0
+            else None
+        )  # nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2)])
+        self.n_iter_topic_transformer = config["n_topic_transformers"]
+
+        self.seed_tokens = nn.Parameter(
+            torch.randn(config["n_topics"], config["d_model"])
+        )
+        self.register_parameter("seed_tokens", self.seed_tokens)
+        self.n_samples = config["n_samples"]
 
         self._reset_parameters()
 
@@ -94,9 +108,9 @@ class TopicFormer(nn.Module):
             topics (torch.Tensor): [N, L+S, K]
         """
         prob_topics0, prob_topics1 = prob_topics[:, :L], prob_topics[:, L:]
-        topics0, topics1  = topics[:, :L], topics[:, L:]
+        topics0, topics1 = topics[:, :L], topics[:, L:]
 
-        theta0 = F.normalize(prob_topics0.sum(dim=1), p=1, dim=-1) # [N, K]
+        theta0 = F.normalize(prob_topics0.sum(dim=1), p=1, dim=-1)  # [N, K]
         theta1 = F.normalize(prob_topics1.sum(dim=1), p=1, dim=-1)
         theta = F.normalize(theta0 * theta1, p=1, dim=-1)
         if self.n_samples == 0:
@@ -106,18 +120,28 @@ class TopicFormer(nn.Module):
             sampled_values = torch.gather(theta, dim=-1, index=sampled_inds)
         else:
             sampled_values, sampled_inds = torch.topk(theta, self.n_samples, dim=-1)
-        sampled_topics0 = torch.gather(topics0, dim=-1, index=sampled_inds.unsqueeze(1).repeat(1, topics0.shape[1], 1))
-        sampled_topics1 = torch.gather(topics1, dim=-1, index=sampled_inds.unsqueeze(1).repeat(1, topics1.shape[1], 1))
+        sampled_topics0 = torch.gather(
+            topics0,
+            dim=-1,
+            index=sampled_inds.unsqueeze(1).repeat(1, topics0.shape[1], 1),
+        )
+        sampled_topics1 = torch.gather(
+            topics1,
+            dim=-1,
+            index=sampled_inds.unsqueeze(1).repeat(1, topics1.shape[1], 1),
+        )
         return sampled_topics0, sampled_topics1
 
     def reduce_feat(self, feat, topick, N, C):
         len_topic = topick.sum(dim=-1).int()
         max_len = len_topic.max().item()
         selected_ids = topick.bool()
-        resized_feat = torch.zeros((N, max_len, C), dtype=torch.float32, device=feat.device)
+        resized_feat = torch.zeros(
+            (N, max_len, C), dtype=torch.float32, device=feat.device
+        )
         new_mask = torch.zeros_like(resized_feat[..., 0]).bool()
         for i in range(N):
-            new_mask[i, :len_topic[i]] = True
+            new_mask[i, : len_topic[i]] = True
         resized_feat[new_mask, :] = feat[selected_ids, :]
         return resized_feat, new_mask, selected_ids
 
@@ -130,8 +154,16 @@ class TopicFormer(nn.Module):
             mask1 (torch.Tensor): [N, S] (optional)
         """
 
-        assert self.d_model == feat0.shape[2], "the feature number of src and transformer must be equal"
-        N, L, S, C, K = feat0.shape[0], feat0.shape[1], feat1.shape[1], feat0.shape[2], self.config['n_topics']
+        assert (
+            self.d_model == feat0.shape[2]
+        ), "the feature number of src and transformer must be equal"
+        N, L, S, C, K = (
+            feat0.shape[0],
+            feat0.shape[1],
+            feat1.shape[1],
+            feat0.shape[2],
+            self.config["n_topics"],
+        )
 
         seeds = self.seed_tokens.unsqueeze(0).repeat(N, 1, 1)
 
@@ -142,18 +174,20 @@ class TopicFormer(nn.Module):
             mask = None
 
         for layer, name in zip(self.layers, self.layer_names):
-            if name == 'seed':
+            if name == "seed":
                 # seeds = layer(seeds, feat0, None, mask0)
                 # seeds = layer(seeds, feat1, None, mask1)
                 seeds = layer(seeds, feat, None, mask)
-            elif name == 'feat':
+            elif name == "feat":
                 feat0 = layer(feat0, seeds, mask0, None)
                 feat1 = layer(feat1, seeds, mask1, None)
 
         dmatrix = torch.einsum("nmd,nkd->nmk", feat, seeds)
         prob_topics = F.softmax(dmatrix, dim=-1)
 
-        feat_topics = torch.zeros_like(dmatrix).scatter_(-1, torch.argmax(dmatrix, dim=-1, keepdim=True), 1.0)
+        feat_topics = torch.zeros_like(dmatrix).scatter_(
+            -1, torch.argmax(dmatrix, dim=-1, keepdim=True), 1.0
+        )
 
         if mask is not None:
             feat_topics = feat_topics * mask.unsqueeze(-1)
@@ -163,35 +197,57 @@ class TopicFormer(nn.Module):
             logger.warning("topic distribution is highly sparse!")
         sampled_topics = self.sample_topic(prob_topics.detach(), feat_topics, L)
         if sampled_topics is not None:
-            updated_feat0, updated_feat1 = torch.zeros_like(feat0), torch.zeros_like(feat1)
+            updated_feat0, updated_feat1 = torch.zeros_like(feat0), torch.zeros_like(
+                feat1
+            )
             s_topics0, s_topics1 = sampled_topics
             for k in range(s_topics0.shape[-1]):
-                topick0, topick1 = s_topics0[..., k], s_topics1[..., k] # [N, L+S]
+                topick0, topick1 = s_topics0[..., k], s_topics1[..., k]  # [N, L+S]
                 if (topick0.sum() > 0) and (topick1.sum() > 0):
-                    new_feat0, new_mask0, selected_ids0 = self.reduce_feat(feat0, topick0, N, C)
-                    new_feat1, new_mask1, selected_ids1 = self.reduce_feat(feat1, topick1, N, C)
+                    new_feat0, new_mask0, selected_ids0 = self.reduce_feat(
+                        feat0, topick0, N, C
+                    )
+                    new_feat1, new_mask1, selected_ids1 = self.reduce_feat(
+                        feat1, topick1, N, C
+                    )
                     for idt in range(self.n_iter_topic_transformer):
-                        new_feat0 = self.topic_transformers[idt*2](new_feat0, new_feat0, new_mask0, new_mask0)
-                        new_feat1 = self.topic_transformers[idt*2](new_feat1, new_feat1, new_mask1, new_mask1)
-                        new_feat0 = self.topic_transformers[idt*2+1](new_feat0, new_feat1, new_mask0, new_mask1)
-                        new_feat1 = self.topic_transformers[idt*2+1](new_feat1, new_feat0, new_mask1, new_mask0)
+                        new_feat0 = self.topic_transformers[idt * 2](
+                            new_feat0, new_feat0, new_mask0, new_mask0
+                        )
+                        new_feat1 = self.topic_transformers[idt * 2](
+                            new_feat1, new_feat1, new_mask1, new_mask1
+                        )
+                        new_feat0 = self.topic_transformers[idt * 2 + 1](
+                            new_feat0, new_feat1, new_mask0, new_mask1
+                        )
+                        new_feat1 = self.topic_transformers[idt * 2 + 1](
+                            new_feat1, new_feat0, new_mask1, new_mask0
+                        )
                     updated_feat0[selected_ids0, :] = new_feat0[new_mask0, :]
                     updated_feat1[selected_ids1, :] = new_feat1[new_mask1, :]
 
             feat0 = (1 - s_topics0.sum(dim=-1, keepdim=True)) * feat0 + updated_feat0
             feat1 = (1 - s_topics1.sum(dim=-1, keepdim=True)) * feat1 + updated_feat1
 
-        conf_matrix = torch.einsum("nlc,nsc->nls", feat0, feat1) / C**.5 #(C * temperature)
+        conf_matrix = (
+            torch.einsum("nlc,nsc->nls", feat0, feat1) / C**0.5
+        )  # (C * temperature)
         if self.training:
-            topic_matrix = torch.einsum("nlk,nsk->nls", prob_topics[:, :L], prob_topics[:, L:])
-            outlier_mask = torch.einsum("nlk,nsk->nls", feat_topics[:, :L], feat_topics[:, L:])
+            topic_matrix = torch.einsum(
+                "nlk,nsk->nls", prob_topics[:, :L], prob_topics[:, L:]
+            )
+            outlier_mask = torch.einsum(
+                "nlk,nsk->nls", feat_topics[:, :L], feat_topics[:, L:]
+            )
         else:
             topic_matrix = {"img0": feat_topics[:, :L], "img1": feat_topics[:, L:]}
             outlier_mask = torch.ones_like(conf_matrix)
         if mask0 is not None:
-            outlier_mask = (outlier_mask * mask0[..., None] * mask1[:, None]) #.bool()
+            outlier_mask = outlier_mask * mask0[..., None] * mask1[:, None]  # .bool()
         conf_matrix.masked_fill_(~outlier_mask.bool(), -1e9)
-        conf_matrix = F.softmax(conf_matrix, 1) * F.softmax(conf_matrix, 2)  # * topic_matrix
+        conf_matrix = F.softmax(conf_matrix, 1) * F.softmax(
+            conf_matrix, 2
+        )  # * topic_matrix
 
         return feat0, feat1, conf_matrix, topic_matrix
 
@@ -203,11 +259,15 @@ class LocalFeatureTransformer(nn.Module):
         super(LocalFeatureTransformer, self).__init__()
 
         self.config = config
-        self.d_model = config['d_model']
-        self.nhead = config['nhead']
-        self.layer_names = config['layer_names']
-        encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
-        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2)]) #len(self.layer_names))])
+        self.d_model = config["d_model"]
+        self.nhead = config["nhead"]
+        self.layer_names = config["layer_names"]
+        encoder_layer = LoFTREncoderLayer(
+            config["d_model"], config["nhead"], config["attention"]
+        )
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for _ in range(2)]
+        )  # len(self.layer_names))])
         self._reset_parameters()
 
     def _reset_parameters(self):
@@ -224,7 +284,9 @@ class LocalFeatureTransformer(nn.Module):
             mask1 (torch.Tensor): [N, S] (optional)
         """
 
-        assert self.d_model == feat0.shape[2], "the feature number of src and transformer must be equal"
+        assert (
+            self.d_model == feat0.shape[2]
+        ), "the feature number of src and transformer must be equal"
 
         feat0 = self.layers[0](feat0, feat1, mask0, mask1)
         feat1 = self.layers[1](feat1, feat0, mask1, mask0)
diff --git a/third_party/TopicFM/src/models/topic_fm.py b/third_party/TopicFM/src/models/topic_fm.py
index 95cd22f9b66d08760382fe4cd22c4df918cc9f68..2556bdbb489574e13a5e5af60be87c546473d406 100644
--- a/third_party/TopicFM/src/models/topic_fm.py
+++ b/third_party/TopicFM/src/models/topic_fm.py
@@ -17,14 +17,14 @@ class TopicFM(nn.Module):
         # Modules
         self.backbone = build_backbone(config)
 
-        self.loftr_coarse = TopicFormer(config['coarse'])
-        self.coarse_matching = CoarseMatching(config['match_coarse'])
+        self.loftr_coarse = TopicFormer(config["coarse"])
+        self.coarse_matching = CoarseMatching(config["match_coarse"])
         self.fine_preprocess = FinePreprocess(config)
         self.loftr_fine = LocalFeatureTransformer(config["fine"])
         self.fine_matching = FineMatching()
 
     def forward(self, data):
-        """ 
+        """
         Update:
             data (dict): {
                 'image0': (torch.Tensor): (N, 1, H, W)
@@ -34,46 +34,65 @@ class TopicFM(nn.Module):
             }
         """
         # 1. Local Feature CNN
-        data.update({
-            'bs': data['image0'].size(0),
-            'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
-        })
+        data.update(
+            {
+                "bs": data["image0"].size(0),
+                "hw0_i": data["image0"].shape[2:],
+                "hw1_i": data["image1"].shape[2:],
+            }
+        )
 
-        if data['hw0_i'] == data['hw1_i']:  # faster & better BN convergence
-            feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
-            (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
+        if data["hw0_i"] == data["hw1_i"]:  # faster & better BN convergence
+            feats_c, feats_f = self.backbone(
+                torch.cat([data["image0"], data["image1"]], dim=0)
+            )
+            (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(
+                data["bs"]
+            ), feats_f.split(data["bs"])
         else:  # handle different input shapes
-            (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
+            (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(
+                data["image0"]
+            ), self.backbone(data["image1"])
 
-        data.update({
-            'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
-            'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
-        })
+        data.update(
+            {
+                "hw0_c": feat_c0.shape[2:],
+                "hw1_c": feat_c1.shape[2:],
+                "hw0_f": feat_f0.shape[2:],
+                "hw1_f": feat_f1.shape[2:],
+            }
+        )
 
         # 2. coarse-level loftr module
-        feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
-        feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
+        feat_c0 = rearrange(feat_c0, "n c h w -> n (h w) c")
+        feat_c1 = rearrange(feat_c1, "n c h w -> n (h w) c")
 
         mask_c0 = mask_c1 = None  # mask is useful in training
-        if 'mask0' in data:
-            mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
+        if "mask0" in data:
+            mask_c0, mask_c1 = data["mask0"].flatten(-2), data["mask1"].flatten(-2)
 
-        feat_c0, feat_c1, conf_matrix, topic_matrix = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
-        data.update({"conf_matrix": conf_matrix, "topic_matrix": topic_matrix}) ######
+        feat_c0, feat_c1, conf_matrix, topic_matrix = self.loftr_coarse(
+            feat_c0, feat_c1, mask_c0, mask_c1
+        )
+        data.update({"conf_matrix": conf_matrix, "topic_matrix": topic_matrix})  ######
 
         # 3. match coarse-level
         self.coarse_matching(data)
 
         # 4. fine-level refinement
-        feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0.detach(), feat_c1.detach(), data)
+        feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
+            feat_f0, feat_f1, feat_c0.detach(), feat_c1.detach(), data
+        )
         if feat_f0_unfold.size(0) != 0:  # at least one coarse level predicted
-            feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
+            feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
+                feat_f0_unfold, feat_f1_unfold
+            )
 
         # 5. match fine-level
         self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
 
     def load_state_dict(self, state_dict, *args, **kwargs):
         for k in list(state_dict.keys()):
-            if k.startswith('matcher.'):
-                state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
+            if k.startswith("matcher."):
+                state_dict[k.replace("matcher.", "", 1)] = state_dict.pop(k)
         return super().load_state_dict(state_dict, *args, **kwargs)
diff --git a/third_party/TopicFM/src/models/utils/coarse_matching.py b/third_party/TopicFM/src/models/utils/coarse_matching.py
index 75adbb5cc465220e759a044f96f86c08da2d7a50..0cd0ea3db496fe50f82bf7660696e96e26b23484 100644
--- a/third_party/TopicFM/src/models/utils/coarse_matching.py
+++ b/third_party/TopicFM/src/models/utils/coarse_matching.py
@@ -5,8 +5,9 @@ from einops.einops import rearrange
 
 INF = 1e9
 
+
 def mask_border(m, b: int, v):
-    """ Mask borders with value
+    """Mask borders with value
     Args:
         m (torch.Tensor): [N, H0, W0, H1, W1]
         b (int)
@@ -37,22 +38,21 @@ def mask_border_with_padding(m, bd, v, p_m0, p_m1):
     h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
     h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
     for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
-        m[b_idx, h0 - bd:] = v
-        m[b_idx, :, w0 - bd:] = v
-        m[b_idx, :, :, h1 - bd:] = v
-        m[b_idx, :, :, :, w1 - bd:] = v
+        m[b_idx, h0 - bd :] = v
+        m[b_idx, :, w0 - bd :] = v
+        m[b_idx, :, :, h1 - bd :] = v
+        m[b_idx, :, :, :, w1 - bd :] = v
 
 
 def compute_max_candidates(p_m0, p_m1):
     """Compute the max candidates of all pairs within a batch
-    
+
     Args:
         p_m0, p_m1 (torch.Tensor): padded masks
     """
     h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
     h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
-    max_cand = torch.sum(
-        torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
+    max_cand = torch.sum(torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
     return max_cand
 
 
@@ -61,26 +61,27 @@ class CoarseMatching(nn.Module):
         super().__init__()
         self.config = config
         # general config
-        self.thr = config['thr']
-        self.border_rm = config['border_rm']
+        self.thr = config["thr"]
+        self.border_rm = config["border_rm"]
         # -- # for trainig fine-level LoFTR
-        self.train_coarse_percent = config['train_coarse_percent']
-        self.train_pad_num_gt_min = config['train_pad_num_gt_min']
+        self.train_coarse_percent = config["train_coarse_percent"]
+        self.train_pad_num_gt_min = config["train_pad_num_gt_min"]
 
         # we provide 2 options for differentiable matching
-        self.match_type = config['match_type']
-        if self.match_type == 'dual_softmax':
-            self.temperature = config['dsmax_temperature']
-        elif self.match_type == 'sinkhorn':
+        self.match_type = config["match_type"]
+        if self.match_type == "dual_softmax":
+            self.temperature = config["dsmax_temperature"]
+        elif self.match_type == "sinkhorn":
             try:
                 from .superglue import log_optimal_transport
             except ImportError:
                 raise ImportError("download superglue.py first!")
             self.log_optimal_transport = log_optimal_transport
             self.bin_score = nn.Parameter(
-                torch.tensor(config['skh_init_bin_score'], requires_grad=True))
-            self.skh_iters = config['skh_iters']
-            self.skh_prefilter = config['skh_prefilter']
+                torch.tensor(config["skh_init_bin_score"], requires_grad=True)
+            )
+            self.skh_iters = config["skh_iters"]
+            self.skh_prefilter = config["skh_prefilter"]
         else:
             raise NotImplementedError()
 
@@ -99,7 +100,7 @@ class CoarseMatching(nn.Module):
                 'mconf' (torch.Tensor): [M]}
             NOTE: M' != M during training.
         """
-        conf_matrix = data['conf_matrix']
+        conf_matrix = data["conf_matrix"]
         # predict coarse matches from conf_matrix
         data.update(**self.get_coarse_match(conf_matrix, data))
 
@@ -121,28 +122,33 @@ class CoarseMatching(nn.Module):
                 'mconf' (torch.Tensor): [M]}
         """
         axes_lengths = {
-            'h0c': data['hw0_c'][0],
-            'w0c': data['hw0_c'][1],
-            'h1c': data['hw1_c'][0],
-            'w1c': data['hw1_c'][1]
+            "h0c": data["hw0_c"][0],
+            "w0c": data["hw0_c"][1],
+            "h1c": data["hw1_c"][0],
+            "w1c": data["hw1_c"][1],
         }
         _device = conf_matrix.device
         # 1. confidence thresholding
         mask = conf_matrix > self.thr
-        mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
-                         **axes_lengths)
-        if 'mask0' not in data:
+        mask = rearrange(
+            mask, "b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c", **axes_lengths
+        )
+        if "mask0" not in data:
             mask_border(mask, self.border_rm, False)
         else:
-            mask_border_with_padding(mask, self.border_rm, False,
-                                     data['mask0'], data['mask1'])
-        mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
-                         **axes_lengths)
+            mask_border_with_padding(
+                mask, self.border_rm, False, data["mask0"], data["mask1"]
+            )
+        mask = rearrange(
+            mask, "b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)", **axes_lengths
+        )
 
         # 2. mutual nearest
-        mask = mask \
-            * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
+        mask = (
+            mask
+            * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0])
             * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
+        )
 
         # 3. find all valid coarse matches
         # this only works when at most one `True` in each row
@@ -157,16 +163,17 @@ class CoarseMatching(nn.Module):
             # NOTE:
             # The sampling is performed across all pairs in a batch without manually balancing
             # #samples for fine-level increases w.r.t. batch_size
-            if 'mask0' not in data:
-                num_candidates_max = mask.size(0) * max(
-                    mask.size(1), mask.size(2))
+            if "mask0" not in data:
+                num_candidates_max = mask.size(0) * max(mask.size(1), mask.size(2))
             else:
                 num_candidates_max = compute_max_candidates(
-                    data['mask0'], data['mask1'])
-            num_matches_train = int(num_candidates_max *
-                                    self.train_coarse_percent)
+                    data["mask0"], data["mask1"]
+                )
+            num_matches_train = int(num_candidates_max * self.train_coarse_percent)
             num_matches_pred = len(b_ids)
-            assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
+            assert (
+                self.train_pad_num_gt_min < num_matches_train
+            ), "min-num-gt-pad should be less than num-train-matches"
 
             # pred_indices is to select from prediction
             if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
@@ -174,44 +181,55 @@ class CoarseMatching(nn.Module):
             else:
                 pred_indices = torch.randint(
                     num_matches_pred,
-                    (num_matches_train - self.train_pad_num_gt_min, ),
-                    device=_device)
+                    (num_matches_train - self.train_pad_num_gt_min,),
+                    device=_device,
+                )
 
             # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
             gt_pad_indices = torch.randint(
-                    len(data['spv_b_ids']),
-                    (max(num_matches_train - num_matches_pred,
-                        self.train_pad_num_gt_min), ),
-                    device=_device)
-            mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device)  # set conf of gt paddings to all zero
+                len(data["spv_b_ids"]),
+                (max(num_matches_train - num_matches_pred, self.train_pad_num_gt_min),),
+                device=_device,
+            )
+            mconf_gt = torch.zeros(
+                len(data["spv_b_ids"]), device=_device
+            )  # set conf of gt paddings to all zero
 
             b_ids, i_ids, j_ids, mconf = map(
-                lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
-                                       dim=0),
-                *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
-                     [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
+                lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], dim=0),
+                *zip(
+                    [b_ids, data["spv_b_ids"]],
+                    [i_ids, data["spv_i_ids"]],
+                    [j_ids, data["spv_j_ids"]],
+                    [mconf, mconf_gt],
+                )
+            )
 
         # These matches select patches that feed into fine-level network
-        coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
+        coarse_matches = {"b_ids": b_ids, "i_ids": i_ids, "j_ids": j_ids}
 
         # 4. Update with matches in original image resolution
-        scale = data['hw0_i'][0] / data['hw0_c'][0]
-        scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
-        scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
-        mkpts0_c = torch.stack(
-            [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
-            dim=1) * scale0
-        mkpts1_c = torch.stack(
-            [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
-            dim=1) * scale1
+        scale = data["hw0_i"][0] / data["hw0_c"][0]
+        scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale
+        scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale
+        mkpts0_c = (
+            torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1)
+            * scale0
+        )
+        mkpts1_c = (
+            torch.stack([j_ids % data["hw1_c"][1], j_ids // data["hw1_c"][1]], dim=1)
+            * scale1
+        )
 
         # These matches is the current prediction (for visualization)
-        coarse_matches.update({
-            'gt_mask': mconf == 0,
-            'm_bids': b_ids[mconf != 0],  # mconf == 0 => gt matches
-            'mkpts0_c': mkpts0_c[mconf != 0],
-            'mkpts1_c': mkpts1_c[mconf != 0],
-            'mconf': mconf[mconf != 0]
-        })
+        coarse_matches.update(
+            {
+                "gt_mask": mconf == 0,
+                "m_bids": b_ids[mconf != 0],  # mconf == 0 => gt matches
+                "mkpts0_c": mkpts0_c[mconf != 0],
+                "mkpts1_c": mkpts1_c[mconf != 0],
+                "mconf": mconf[mconf != 0],
+            }
+        )
 
         return coarse_matches
diff --git a/third_party/TopicFM/src/models/utils/fine_matching.py b/third_party/TopicFM/src/models/utils/fine_matching.py
index 018f2fe475600b319998c263a97237ce135c3aaf..7156e3e1f22e2e341062565e5ad6baee41dd9bc6 100644
--- a/third_party/TopicFM/src/models/utils/fine_matching.py
+++ b/third_party/TopicFM/src/models/utils/fine_matching.py
@@ -27,39 +27,57 @@ class FineMatching(nn.Module):
         """
         M, WW, C = feat_f0.shape
         W = int(math.sqrt(WW))
-        scale = data['hw0_i'][0] / data['hw0_f'][0]
+        scale = data["hw0_i"][0] / data["hw0_f"][0]
         self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
 
         # corner case: if no coarse matches found
         if M == 0:
-            assert self.training == False, "M is always >0, when training, see coarse_matching.py"
+            assert (
+                not self.training
+            ), "M is always > 0 during training; see coarse_matching.py"
             # logger.warning('No matches found in coarse-level.')
-            data.update({
-                'expec_f': torch.empty(0, 3, device=feat_f0.device),
-                'mkpts0_f': data['mkpts0_c'],
-                'mkpts1_f': data['mkpts1_c'],
-            })
+            data.update(
+                {
+                    "expec_f": torch.empty(0, 3, device=feat_f0.device),
+                    "mkpts0_f": data["mkpts0_c"],
+                    "mkpts1_f": data["mkpts1_c"],
+                }
+            )
             return
 
-        feat_f0_picked = feat_f0[:, WW//2, :]
+        feat_f0_picked = feat_f0[:, WW // 2, :]
 
-        sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
-        softmax_temp = 1. / C**.5
+        sim_matrix = torch.einsum("mc,mrc->mr", feat_f0_picked, feat_f1)
+        softmax_temp = 1.0 / C**0.5
         heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1)
-        feat_f1_picked = (feat_f1 * heatmap.unsqueeze(-1)).sum(dim=1) # [M, C]
+        feat_f1_picked = (feat_f1 * heatmap.unsqueeze(-1)).sum(dim=1)  # [M, C]
         heatmap = heatmap.view(-1, W, W)
 
         # compute coordinates from heatmap
-        coords1_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]  # [M, 2]
-        grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2)  # [1, WW, 2]
+        coords1_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[
+            0
+        ]  # [M, 2]
+        grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(
+            1, -1, 2
+        )  # [1, WW, 2]
 
         # compute std over <x, y>
-        var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords1_normalized**2  # [M, 2]
-        std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1)  # [M]  clamp needed for numerical stability
-        
+        var = (
+            torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1)
+            - coords1_normalized**2
+        )  # [M, 2]
+        std = torch.sum(
+            torch.sqrt(torch.clamp(var, min=1e-10)), -1
+        )  # [M]  clamp needed for numerical stability
+
         # for fine-level supervision
-        data.update({'expec_f': torch.cat([coords1_normalized, std.unsqueeze(1)], -1),
-                     'descriptors0': feat_f0_picked.detach(), 'descriptors1': feat_f1_picked.detach()})
+        data.update(
+            {
+                "expec_f": torch.cat([coords1_normalized, std.unsqueeze(1)], -1),
+                "descriptors0": feat_f0_picked.detach(),
+                "descriptors1": feat_f1_picked.detach(),
+            }
+        )
 
         # compute absolute kpt coords
         self.get_fine_match(coords1_normalized, data)
@@ -70,11 +88,13 @@ class FineMatching(nn.Module):
 
         # mkpts0_f and mkpts1_f
         # scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale
-        mkpts0_f = data['mkpts0_c'] # + (coords0_normed * (W // 2) * scale0 )[:len(data['mconf'])]
-        scale1 = scale * data['scale1'][data['b_ids']] if 'scale1' in data else scale
-        mkpts1_f = data['mkpts1_c'] + (coords1_normed * (W // 2) * scale1)[:len(data['mconf'])]
+        mkpts0_f = data[
+            "mkpts0_c"
+        ]  # + (coords0_normed * (W // 2) * scale0 )[:len(data['mconf'])]
+        scale1 = scale * data["scale1"][data["b_ids"]] if "scale1" in data else scale
+        mkpts1_f = (
+            data["mkpts1_c"]
+            + (coords1_normed * (W // 2) * scale1)[: len(data["mconf"])]
+        )
 
-        data.update({
-            "mkpts0_f": mkpts0_f,
-            "mkpts1_f": mkpts1_f
-        })
+        data.update({"mkpts0_f": mkpts0_f, "mkpts1_f": mkpts1_f})
diff --git a/third_party/TopicFM/src/models/utils/geometry.py b/third_party/TopicFM/src/models/utils/geometry.py
index f95cdb65b48324c4f4ceb20231b1bed992b41116..6101f738f2b2b7ee014fcb53a4032391939ed8cd 100644
--- a/third_party/TopicFM/src/models/utils/geometry.py
+++ b/third_party/TopicFM/src/models/utils/geometry.py
@@ -3,10 +3,10 @@ import torch
 
 @torch.no_grad()
 def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
-    """ Warp kpts0 from I0 to I1 with depth, K and Rt
+    """Warp kpts0 from I0 to I1 with depth, K and Rt
     Also check covisibility and depth consistency.
     Depth is consistent if relative error < 0.2 (hard-coded).
-    
+
     Args:
         kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
         depth0 (torch.Tensor): [N, H, W],
@@ -22,33 +22,52 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
 
     # Sample depth, get calculable_mask on depth != 0
     kpts0_depth = torch.stack(
-        [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
+        [
+            depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]]
+            for i in range(kpts0.shape[0])
+        ],
+        dim=0,
     )  # (N, L)
     nonzero_mask = kpts0_depth != 0
 
     # Unproject
-    kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None]  # (N, L, 3)
+    kpts0_h = (
+        torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
+        * kpts0_depth[..., None]
+    )  # (N, L, 3)
     kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)
 
     # Rigid Transform
-    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]    # (N, 3, L)
+    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
     w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
 
     # Project
     w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
-    w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4)  # (N, L, 2), +1e-4 to avoid zero depth
+    w_kpts0 = w_kpts0_h[:, :, :2] / (
+        w_kpts0_h[:, :, [2]] + 1e-4
+    )  # (N, L, 2), +1e-4 to avoid zero depth
 
     # Covisible Check
     h, w = depth1.shape[1:3]
-    covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
-        (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
+    covisible_mask = (
+        (w_kpts0[:, :, 0] > 0)
+        * (w_kpts0[:, :, 0] < w - 1)
+        * (w_kpts0[:, :, 1] > 0)
+        * (w_kpts0[:, :, 1] < h - 1)
+    )
     w_kpts0_long = w_kpts0.long()
     w_kpts0_long[~covisible_mask, :] = 0
 
     w_kpts0_depth = torch.stack(
-        [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
+        [
+            depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]]
+            for i in range(w_kpts0_long.shape[0])
+        ],
+        dim=0,
     )  # (N, L)
-    consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
+    consistent_mask = (
+        (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
+    ).abs() < 0.2
     valid_mask = nonzero_mask * covisible_mask * consistent_mask
 
     return valid_mask, w_kpts0
diff --git a/third_party/TopicFM/src/models/utils/supervision.py b/third_party/TopicFM/src/models/utils/supervision.py
index 1f1f0478fdcbe7f8ceffbc4aff4d507cec55bbd2..86f167e95439d588c998ca32b9296c3482484215 100644
--- a/third_party/TopicFM/src/models/utils/supervision.py
+++ b/third_party/TopicFM/src/models/utils/supervision.py
@@ -13,7 +13,7 @@ from .geometry import warp_kpts
 @torch.no_grad()
 def mask_pts_at_padded_regions(grid_pt, mask):
     """For megadepth dataset, zero-padding exists in images"""
-    mask = repeat(mask, 'n h w -> n (h w) c', c=2)
+    mask = repeat(mask, "n h w -> n (h w) c", c=2)
     grid_pt[~mask.bool()] = 0
     return grid_pt
 
@@ -30,37 +30,55 @@ def spvs_coarse(data, config):
             'spv_w_pt0_i': [N, hw0, 2], in original image resolution
             'spv_pt1_i': [N, hw1, 2], in original image resolution
         }
-        
+
     NOTE:
         - for scannet dataset, there're 3 kinds of resolution {i, c, f}
         - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
     """
     # 1. misc
-    device = data['image0'].device
-    N, _, H0, W0 = data['image0'].shape
-    _, _, H1, W1 = data['image1'].shape
-    scale = config['MODEL']['RESOLUTION'][0]
-    scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
-    scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
+    device = data["image0"].device
+    N, _, H0, W0 = data["image0"].shape
+    _, _, H1, W1 = data["image1"].shape
+    scale = config["MODEL"]["RESOLUTION"][0]
+    scale0 = scale * data["scale0"][:, None] if "scale0" in data else scale
+    scale1 = scale * data["scale1"][:, None] if "scale0" in data else scale
     h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
 
     # 2. warp grids
     # create kpts in meshgrid and resize them to image resolution
-    grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1)    # [N, hw, 2]
+    grid_pt0_c = (
+        create_meshgrid(h0, w0, False, device).reshape(1, h0 * w0, 2).repeat(N, 1, 1)
+    )  # [N, hw, 2]
     grid_pt0_i = scale0 * grid_pt0_c
-    grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
+    grid_pt1_c = (
+        create_meshgrid(h1, w1, False, device).reshape(1, h1 * w1, 2).repeat(N, 1, 1)
+    )
     grid_pt1_i = scale1 * grid_pt1_c
 
     # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
-    if 'mask0' in data:
-        grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
-        grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
+    if "mask0" in data:
+        grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data["mask0"])
+        grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data["mask1"])
 
     # warp kpts bi-directionally and resize them to coarse-level resolution
     # (no depth consistency check, since it leads to worse results experimentally)
     # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
-    _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
-    _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
+    _, w_pt0_i = warp_kpts(
+        grid_pt0_i,
+        data["depth0"],
+        data["depth1"],
+        data["T_0to1"],
+        data["K0"],
+        data["K1"],
+    )
+    _, w_pt1_i = warp_kpts(
+        grid_pt1_i,
+        data["depth1"],
+        data["depth0"],
+        data["T_1to0"],
+        data["K1"],
+        data["K0"],
+    )
     w_pt0_c = w_pt0_i / scale1
     w_pt1_c = w_pt1_i / scale0
 
@@ -72,21 +90,26 @@ def spvs_coarse(data, config):
 
     # corner case: out of boundary
     def out_bound_mask(pt, w, h):
-        return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+        return (
+            (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+        )
+
     nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
     nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
 
-    loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
-    correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
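+    # mutual-nearest check: keep cells whose 0->1->0 round trip returns to itself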
+    loop_back = torch.stack(
+        [nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0
+    )
+    correct_0to1 = loop_back == torch.arange(h0 * w0, device=device)[None].repeat(N, 1)
     correct_0to1[:, 0] = False  # ignore the top-left corner
 
     # 4. construct a gt conf_matrix
-    conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
+    conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1, device=device)
     b_ids, i_ids = torch.where(correct_0to1 != 0)
     j_ids = nearest_index1[b_ids, i_ids]
 
     conf_matrix_gt[b_ids, i_ids, j_ids] = 1
-    data.update({'conf_matrix_gt': conf_matrix_gt})
+    data.update({"conf_matrix_gt": conf_matrix_gt})
 
     # 5. save coarse matches(gt) for training fine level
     if len(b_ids) == 0:
@@ -96,30 +119,26 @@ def spvs_coarse(data, config):
         i_ids = torch.tensor([0], device=device)
         j_ids = torch.tensor([0], device=device)
 
-    data.update({
-        'spv_b_ids': b_ids,
-        'spv_i_ids': i_ids,
-        'spv_j_ids': j_ids
-    })
+    data.update({"spv_b_ids": b_ids, "spv_i_ids": i_ids, "spv_j_ids": j_ids})
 
     # 6. save intermediate results (for fast fine-level computation)
-    data.update({
-        'spv_w_pt0_i': w_pt0_i,
-        'spv_pt1_i': grid_pt1_i
-    })
+    data.update({"spv_w_pt0_i": w_pt0_i, "spv_pt1_i": grid_pt1_i})
 
 
 def compute_supervision_coarse(data, config):
-    assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
-    data_source = data['dataset_name'][0]
-    if data_source.lower() in ['scannet', 'megadepth']:
+    assert (
+        len(set(data["dataset_name"])) == 1
+    ), "Mixed-dataset training is not supported!"
+    data_source = data["dataset_name"][0]
+    if data_source.lower() in ["scannet", "megadepth"]:
         spvs_coarse(data, config)
     else:
-        raise ValueError(f'Unknown data source: {data_source}')
+        raise ValueError(f"Unknown data source: {data_source}")
 
 
 ##############  ↓  Fine-Level supervision  ↓  ##############
 
+
 @torch.no_grad()
 def spvs_fine(data, config):
     """
@@ -129,23 +148,25 @@ def spvs_fine(data, config):
     """
     # 1. misc
     # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
-    w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
-    scale = config['MODEL']['RESOLUTION'][1]
-    radius = config['MODEL']['FINE_WINDOW_SIZE'] // 2
+    w_pt0_i, pt1_i = data["spv_w_pt0_i"], data["spv_pt1_i"]
+    scale = config["MODEL"]["RESOLUTION"][1]
+    radius = config["MODEL"]["FINE_WINDOW_SIZE"] // 2
 
     # 2. get coarse prediction
-    b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
+    b_ids, i_ids, j_ids = data["b_ids"], data["i_ids"], data["j_ids"]
 
     # 3. compute gt
-    scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
+    scale = scale * data["scale1"][b_ids] if "scale0" in data else scale
     # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
-    expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius  # [M, 2]
+    expec_f_gt = (
+        (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius
+    )  # [M, 2]
     data.update({"expec_f_gt": expec_f_gt})
 
 
 def compute_supervision_fine(data, config):
-    data_source = data['dataset_name'][0]
-    if data_source.lower() in ['scannet', 'megadepth']:
+    data_source = data["dataset_name"][0]
+    if data_source.lower() in ["scannet", "megadepth"]:
         spvs_fine(data, config)
     else:
         raise NotImplementedError
diff --git a/third_party/TopicFM/src/optimizers/__init__.py b/third_party/TopicFM/src/optimizers/__init__.py
index e1db2285352586c250912bdd2c4ae5029620ab5f..e4e36c22e00217deccacd589f8924b2f74589456 100644
--- a/third_party/TopicFM/src/optimizers/__init__.py
+++ b/third_party/TopicFM/src/optimizers/__init__.py
@@ -7,9 +7,13 @@ def build_optimizer(model, config):
     lr = config.TRAINER.TRUE_LR
 
     if name == "adam":
-        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY)
+        return torch.optim.Adam(
+            model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY
+        )
     elif name == "adamw":
-        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY)
+        return torch.optim.AdamW(
+            model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY
+        )
     else:
         raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
 
@@ -24,18 +28,27 @@ def build_scheduler(config, optimizer):
             'frequency': x, (optional)
         }
     """
-    scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
+    scheduler = {"interval": config.TRAINER.SCHEDULER_INTERVAL}
     name = config.TRAINER.SCHEDULER
 
-    if name == 'MultiStepLR':
+    if name == "MultiStepLR":
         scheduler.update(
-            {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
-    elif name == 'CosineAnnealing':
+            {
+                "scheduler": MultiStepLR(
+                    optimizer,
+                    config.TRAINER.MSLR_MILESTONES,
+                    gamma=config.TRAINER.MSLR_GAMMA,
+                )
+            }
+        )
+    elif name == "CosineAnnealing":
         scheduler.update(
-            {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
-    elif name == 'ExponentialLR':
+            {"scheduler": CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}
+        )
+    elif name == "ExponentialLR":
         scheduler.update(
-            {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
+            {"scheduler": ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}
+        )
     else:
         raise NotImplementedError()
 
diff --git a/third_party/TopicFM/src/utils/augment.py b/third_party/TopicFM/src/utils/augment.py
index d7c5d3e11b6fe083aaeff7555bb7ce3a4bfb755d..068751c6c07091bbaed76debd43a73155f61b9bd 100644
--- a/third_party/TopicFM/src/utils/augment.py
+++ b/third_party/TopicFM/src/utils/augment.py
@@ -7,16 +7,21 @@ class DarkAug(object):
     """
 
     def __init__(self) -> None:
-        self.augmentor = A.Compose([
-            A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
-            A.Blur(p=0.1, blur_limit=(3, 9)),
-            A.MotionBlur(p=0.2, blur_limit=(3, 25)),
-            A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
-            A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
-        ], p=0.75)
+        self.augmentor = A.Compose(
+            [
+                A.RandomBrightnessContrast(
+                    p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)
+                ),
+                A.Blur(p=0.1, blur_limit=(3, 9)),
+                A.MotionBlur(p=0.2, blur_limit=(3, 25)),
+                A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
+                A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)),
+            ],
+            p=0.75,
+        )
 
     def __call__(self, x):
-        return self.augmentor(image=x)['image']
+        return self.augmentor(image=x)["image"]
 
 
 class MobileAug(object):
@@ -25,31 +30,36 @@ class MobileAug(object):
     """
 
     def __init__(self):
-        self.augmentor = A.Compose([
-            A.MotionBlur(p=0.25),
-            A.ColorJitter(p=0.5),
-            A.RandomRain(p=0.1),  # random occlusion
-            A.RandomSunFlare(p=0.1),
-            A.JpegCompression(p=0.25),
-            A.ISONoise(p=0.25)
-        ], p=1.0)
+        self.augmentor = A.Compose(
+            [
+                A.MotionBlur(p=0.25),
+                A.ColorJitter(p=0.5),
+                A.RandomRain(p=0.1),  # random occlusion
+                A.RandomSunFlare(p=0.1),
+                A.JpegCompression(p=0.25),
+                A.ISONoise(p=0.25),
+            ],
+            p=1.0,
+        )
 
     def __call__(self, x):
-        return self.augmentor(image=x)['image']
+        return self.augmentor(image=x)["image"]
 
 
 def build_augmentor(method=None, **kwargs):
     if method is not None:
-        raise NotImplementedError('Using of augmentation functions are not supported yet!')
-    if method == 'dark':
+        raise NotImplementedError(
+            "Use of augmentation functions is not supported yet!"
+        )
+    if method == "dark":
         return DarkAug()
-    elif method == 'mobile':
+    elif method == "mobile":
         return MobileAug()
     elif method is None:
         return None
     else:
-        raise ValueError(f'Invalid augmentation method: {method}')
+        raise ValueError(f"Invalid augmentation method: {method}")
 
 
-if __name__ == '__main__':
-    augmentor = build_augmentor('FDA')
+if __name__ == "__main__":
+    augmentor = build_augmentor("FDA")
diff --git a/third_party/TopicFM/src/utils/comm.py b/third_party/TopicFM/src/utils/comm.py
index 26ec9517cc47e224430106d8ae9aa99a3fe49167..9f578cda8933cc358934c645fcf413c63ab4d79d 100644
--- a/third_party/TopicFM/src/utils/comm.py
+++ b/third_party/TopicFM/src/utils/comm.py
@@ -98,11 +98,11 @@ def _serialize_to_tensor(data, group):
     device = torch.device("cpu" if backend == "gloo" else "cuda")
 
     buffer = pickle.dumps(data)
-    if len(buffer) > 1024 ** 3:
+    if len(buffer) > 1024**3:
         logger = logging.getLogger(__name__)
         logger.warning(
             "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
-                get_rank(), len(buffer) / (1024 ** 3), device
+                get_rank(), len(buffer) / (1024**3), device
             )
         )
     storage = torch.ByteStorage.from_buffer(buffer)
@@ -122,7 +122,8 @@ def _pad_to_largest_tensor(tensor, group):
     ), "comm.gather/all_gather must be called from ranks within the given group!"
     local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
     size_list = [
-        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+        torch.zeros([1], dtype=torch.int64, device=tensor.device)
+        for _ in range(world_size)
     ]
     dist.all_gather(size_list, local_size, group=group)
 
@@ -133,7 +134,9 @@ def _pad_to_largest_tensor(tensor, group):
     # we pad the tensor because torch all_gather does not support
     # gathering tensors of different shapes
     if local_size != max_size:
-        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+        padding = torch.zeros(
+            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
+        )
         tensor = torch.cat((tensor, padding), dim=0)
     return size_list, tensor
 
@@ -164,7 +167,8 @@ def all_gather(data, group=None):
 
     # receiving Tensor from all ranks
     tensor_list = [
-        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
+        for _ in size_list
     ]
     dist.all_gather(tensor_list, tensor, group=group)
 
@@ -205,7 +209,8 @@ def gather(data, dst=0, group=None):
     if rank == dst:
         max_size = max(size_list)
         tensor_list = [
-            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
+            for _ in size_list
         ]
         dist.gather(tensor, tensor_list, dst=dst, group=group)
 
@@ -228,7 +233,7 @@ def shared_random_seed():
 
     All workers must call this function, otherwise it will deadlock.
     """
-    ints = np.random.randint(2 ** 31)
+    ints = np.random.randint(2**31)
     all_ints = all_gather(ints)
     return all_ints[0]
 
diff --git a/third_party/TopicFM/src/utils/dataloader.py b/third_party/TopicFM/src/utils/dataloader.py
index 6da37b880a290c2bb3ebb028d0c8dab592acc5c1..b980dfd344714870ecdacd9e7a9742f51c3ee14d 100644
--- a/third_party/TopicFM/src/utils/dataloader.py
+++ b/third_party/TopicFM/src/utils/dataloader.py
@@ -3,21 +3,22 @@ import numpy as np
 
 # --- PL-DATAMODULE ---
 
+
 def get_local_split(items: list, world_size: int, rank: int, seed: int):
-    """ The local rank only loads a split of the dataset. """
+    """The local rank only loads a split of the dataset."""
     n_items = len(items)
     items_permute = np.random.RandomState(seed).permutation(items)
     if n_items % world_size == 0:
         padded_items = items_permute
     else:
         padding = np.random.RandomState(seed).choice(
-            items,
-            world_size - (n_items % world_size),
-            replace=True)
+            items, world_size - (n_items % world_size), replace=True
+        )
         padded_items = np.concatenate([items_permute, padding])
-        assert len(padded_items) % world_size == 0, \
-            f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}'
+        assert (
+            len(padded_items) % world_size == 0
+        ), f"len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}"
     n_per_rank = len(padded_items) // world_size
-    local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]
+    local_items = padded_items[n_per_rank * rank : n_per_rank * (rank + 1)]
 
     return local_items
diff --git a/third_party/TopicFM/src/utils/dataset.py b/third_party/TopicFM/src/utils/dataset.py
index 647bbadd821b6c90736ed45462270670b1017b0b..f26722dddcc15516b1986182a246b0cdb52c347a 100644
--- a/third_party/TopicFM/src/utils/dataset.py
+++ b/third_party/TopicFM/src/utils/dataset.py
@@ -12,8 +12,11 @@ MEGADEPTH_CLIENT = SCANNET_CLIENT = None
 
 # --- DATA IO ---
 
+
 def load_array_from_s3(
-    path, client, cv_type,
+    path,
+    client,
+    cv_type,
     use_h5py=False,
 ):
     byte_str = client.Get(path)
@@ -23,7 +26,7 @@ def load_array_from_s3(
             data = cv2.imdecode(raw_array, cv_type)
         else:
             f = io.BytesIO(byte_str)
-            data = np.array(h5py.File(f, 'r')['/depth'])
+            data = np.array(h5py.File(f, "r")["/depth"])
     except Exception as ex:
         print(f"==> Data loading failure: {path}")
         raise ex
@@ -33,9 +36,8 @@ def load_array_from_s3(
 
 
 def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
-    cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
-                else cv2.IMREAD_COLOR
-    if str(path).startswith('s3://'):
+    cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR
+    if str(path).startswith("s3://"):
         image = load_array_from_s3(str(path), client, cv_type)
     else:
         image = cv2.imread(str(path), cv_type)
@@ -49,9 +51,9 @@ def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
 
 
 def get_resized_wh(w, h, resize=None):
-    if (resize is not None) and (max(h,w) > resize):  # resize the longer edge
+    if (resize is not None) and (max(h, w) > resize):  # resize the longer edge
         scale = resize / max(h, w)
-        w_new, h_new = int(round(w*scale)), int(round(h*scale))
+        w_new, h_new = int(round(w * scale)), int(round(h * scale))
     else:
         w_new, h_new = w, h
     return w_new, h_new
@@ -66,20 +68,22 @@ def get_divisible_wh(w, h, df=None):
 
 
 def pad_bottom_right(inp, pad_size, ret_mask=False):
-    assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
+    assert isinstance(pad_size, int) and pad_size >= max(
+        inp.shape[-2:]
+    ), f"{pad_size} < {max(inp.shape[-2:])}"
     mask = None
     if inp.ndim == 2:
         padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
-        padded[:inp.shape[0], :inp.shape[1]] = inp
+        padded[: inp.shape[0], : inp.shape[1]] = inp
         if ret_mask:
             mask = np.zeros((pad_size, pad_size), dtype=bool)
-            mask[:inp.shape[0], :inp.shape[1]] = True
+            mask[: inp.shape[0], : inp.shape[1]] = True
     elif inp.ndim == 3:
         padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
-        padded[:, :inp.shape[1], :inp.shape[2]] = inp
+        padded[:, : inp.shape[1], : inp.shape[2]] = inp
         if ret_mask:
             mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
-            mask[:, :inp.shape[1], :inp.shape[2]] = True
+            mask[:, : inp.shape[1], : inp.shape[2]] = True
     else:
         raise NotImplementedError()
     return padded, mask
@@ -87,6 +91,7 @@ def pad_bottom_right(inp, pad_size, ret_mask=False):
 
 # --- MEGADEPTH ---
 
+
 def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
     """
     Args:
@@ -96,7 +101,7 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
     Returns:
         image (torch.tensor): (1, h, w)
         mask (torch.tensor): (h, w)
-        scale (torch.tensor): [w/w_new, h/h_new]        
+        scale (torch.tensor): [w/w_new, h/h_new]
     """
     # read image
     image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
@@ -107,25 +112,27 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
     w_new, h_new = get_divisible_wh(w_new, h_new, df)
 
     image = cv2.resize(image, (w_new, h_new))
-    scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
+    scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float)
 
     if padding:  # padding
-        pad_to = resize #max(h_new, w_new)
+        pad_to = resize  # max(h_new, w_new)
         image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
     else:
         mask = None
 
-    image = torch.from_numpy(image).float()[None] / 255  # (h, w) -> (1, h, w) and normalized
+    image = (
+        torch.from_numpy(image).float()[None] / 255
+    )  # (h, w) -> (1, h, w) and normalized
     mask = torch.from_numpy(mask) if mask is not None else None
 
     return image, mask, scale
 
 
 def read_megadepth_depth(path, pad_to=None):
-    if str(path).startswith('s3://'):
+    if str(path).startswith("s3://"):
         depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
     else:
-        depth = np.array(h5py.File(path, 'r')['depth'])
+        depth = np.array(h5py.File(path, "r")["depth"])
     if pad_to is not None:
         depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
     depth = torch.from_numpy(depth).float()  # (h, w)
@@ -134,6 +141,7 @@ def read_megadepth_depth(path, pad_to=None):
 
 # --- ScanNet ---
 
+
 def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
     """
     Args:
@@ -142,7 +150,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
     Returns:
         image (torch.tensor): (1, h, w)
         mask (torch.tensor): (h, w)
-        scale (torch.tensor): [w/w_new, h/h_new]        
+        scale (torch.tensor): [w/w_new, h/h_new]
     """
     # read and resize image
     image = imread_gray(path, augment_fn)
@@ -155,6 +163,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
 
 # ---- evaluation datasets: HLoc, Aachen, InLoc
 
+
 def read_img_gray(path, resize=None, down_factor=16):
     # read and resize image
     image = imread_gray(path, None)
@@ -174,7 +183,7 @@ def read_img_gray(path, resize=None, down_factor=16):
 
 
 def read_scannet_depth(path):
-    if str(path).startswith('s3://'):
+    if str(path).startswith("s3://"):
         depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
     else:
         depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
@@ -184,18 +193,17 @@ def read_scannet_depth(path):
 
 
 def read_scannet_pose(path):
-    """ Read ScanNet's Camera2World pose and transform it to World2Camera.
-    
+    """Read ScanNet's Camera2World pose and transform it to World2Camera.
+
     Returns:
         pose_w2c (np.ndarray): (4, 4)
     """
-    cam2world = np.loadtxt(path, delimiter=' ')
+    cam2world = np.loadtxt(path, delimiter=" ")
     world2cam = inv(cam2world)
     return world2cam
 
 
 def read_scannet_intrinsic(path):
-    """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
-    """
-    intrinsic = np.loadtxt(path, delimiter=' ')
+    """Read ScanNet's intrinsic matrix and return the 3x3 matrix."""
+    intrinsic = np.loadtxt(path, delimiter=" ")
     return intrinsic[:-1, :-1]
diff --git a/third_party/TopicFM/src/utils/metrics.py b/third_party/TopicFM/src/utils/metrics.py
index a93c31ed1d151cd41e2449a19be2d6abc5f9d419..6190b04f0af85680a0c951f74309c0b66c80e1e5 100644
--- a/third_party/TopicFM/src/utils/metrics.py
+++ b/third_party/TopicFM/src/utils/metrics.py
@@ -9,6 +9,7 @@ from kornia.geometry.conversions import convert_points_to_homogeneous
 
 # --- METRICS ---
 
+
 def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
     # angle error between 2 vectors
     t_gt = T_0to1[:3, 3]
@@ -21,7 +22,7 @@ def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
     # angle error between 2 rotation matrices
     R_gt = T_0to1[:3, :3]
     cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
-    cos = np.clip(cos, -1., 1.)  # handle numercial errors
+    cos = np.clip(cos, -1.0, 1.0)  # handle numerical errors
     R_err = np.rad2deg(np.abs(np.arccos(cos)))
 
     return t_err, R_err
@@ -43,30 +44,36 @@ def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
     p1Ep0 = torch.sum(pts1 * Ep0, -1)  # [N,]
     Etp1 = pts1 @ E  # [N, 3]
 
-    d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2))  # N
+    d = p1Ep0**2 * (
+        1.0 / (Ep0[:, 0] ** 2 + Ep0[:, 1] ** 2)
+        + 1.0 / (Etp1[:, 0] ** 2 + Etp1[:, 1] ** 2)
+    )  # N
     return d
 
 
 def compute_symmetrical_epipolar_errors(data):
-    """ 
+    """
     Update:
         data (dict):{"epi_errs": [M]}
     """
-    Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
-    E_mat = Tx @ data['T_0to1'][:, :3, :3]
+    Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3])
+    E_mat = Tx @ data["T_0to1"][:, :3, :3]
 
-    m_bids = data['m_bids']
-    pts0 = data['mkpts0_f']
-    pts1 = data['mkpts1_f']
+    m_bids = data["m_bids"]
+    pts0 = data["mkpts0_f"]
+    pts1 = data["mkpts1_f"]
 
     epi_errs = []
     for bs in range(Tx.size(0)):
         mask = m_bids == bs
         epi_errs.append(
-            symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
+            symmetric_epipolar_distance(
+                pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs]
+            )
+        )
     epi_errs = torch.cat(epi_errs, dim=0)
 
-    data.update({'epi_errs': epi_errs})
+    data.update({"epi_errs": epi_errs})
 
 
 def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
@@ -81,7 +88,8 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
 
     # compute pose with cv2
     E, mask = cv2.findEssentialMat(
-        kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC)
+        kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC
+    )
     if E is None:
         print("\nE is None while trying to recover pose.\n")
         return None
@@ -99,7 +107,7 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
 
 
 def compute_pose_errors(data, config=None, ransac_thr=0.5, ransac_conf=0.99999):
-    """ 
+    """
     Update:
         data (dict):{
             "R_errs" List[float]: [N]
@@ -107,35 +115,40 @@ def compute_pose_errors(data, config=None, ransac_thr=0.5, ransac_conf=0.99999):
             "inliers" List[np.ndarray]: [N]
         }
     """
-    pixel_thr = config.TRAINER.RANSAC_PIXEL_THR if config is not None else ransac_thr  # 0.5
+    pixel_thr = (
+        config.TRAINER.RANSAC_PIXEL_THR if config is not None else ransac_thr
+    )  # 0.5
     conf = config.TRAINER.RANSAC_CONF if config is not None else ransac_conf  # 0.99999
-    data.update({'R_errs': [], 't_errs': [], 'inliers': []})
+    data.update({"R_errs": [], "t_errs": [], "inliers": []})
 
-    m_bids = data['m_bids'].cpu().numpy()
-    pts0 = data['mkpts0_f'].cpu().numpy()
-    pts1 = data['mkpts1_f'].cpu().numpy()
-    K0 = data['K0'].cpu().numpy()
-    K1 = data['K1'].cpu().numpy()
-    T_0to1 = data['T_0to1'].cpu().numpy()
+    m_bids = data["m_bids"].cpu().numpy()
+    pts0 = data["mkpts0_f"].cpu().numpy()
+    pts1 = data["mkpts1_f"].cpu().numpy()
+    K0 = data["K0"].cpu().numpy()
+    K1 = data["K1"].cpu().numpy()
+    T_0to1 = data["T_0to1"].cpu().numpy()
 
     for bs in range(K0.shape[0]):
         mask = m_bids == bs
-        ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf)
+        ret = estimate_pose(
+            pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf
+        )
 
         if ret is None:
-            data['R_errs'].append(np.inf)
-            data['t_errs'].append(np.inf)
-            data['inliers'].append(np.array([]).astype(np.bool))
+            data["R_errs"].append(np.inf)
+            data["t_errs"].append(np.inf)
+            data["inliers"].append(np.array([]).astype(np.bool))
         else:
             R, t, inliers = ret
             t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
-            data['R_errs'].append(R_err)
-            data['t_errs'].append(t_err)
-            data['inliers'].append(inliers)
+            data["R_errs"].append(R_err)
+            data["t_errs"].append(t_err)
+            data["inliers"].append(inliers)
 
 
 # --- METRIC AGGREGATION ---
 
+
 def error_auc(errors, thresholds):
     """
     Args:
@@ -149,11 +162,11 @@ def error_auc(errors, thresholds):
     thresholds = [5, 10, 20]
     for thr in thresholds:
         last_index = np.searchsorted(errors, thr)
-        y = recall[:last_index] + [recall[last_index-1]]
+        y = recall[:last_index] + [recall[last_index - 1]]
         x = errors[:last_index] + [thr]
         aucs.append(np.trapz(y, x) / thr)
 
-    return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
+    return {f"auc@{t}": auc for t, auc in zip(thresholds, aucs)}
 
 
 def epidist_prec(errors, thresholds, ret_dict=False):
@@ -165,29 +178,33 @@ def epidist_prec(errors, thresholds, ret_dict=False):
             prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
         precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
     if ret_dict:
-        return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)}
+        return {f"prec@{t:.0e}": prec for t, prec in zip(thresholds, precs)}
     else:
         return precs
 
 
 def aggregate_metrics(metrics, epi_err_thr=5e-4):
-    """ Aggregate metrics for the whole dataset:
+    """Aggregate metrics for the whole dataset:
     (This method should be called once per dataset)
     1. AUC of the pose error (angular) at the threshold [5, 10, 20]
     2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth)
     """
     # filter duplicates
-    unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers']))
+    unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics["identifiers"]))
     unq_ids = list(unq_ids.values())
-    logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...')
+    logger.info(f"Aggregating metrics over {len(unq_ids)} unique items...")
 
     # pose auc
     angular_thresholds = [5, 10, 20]
-    pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids]
+    pose_errors = np.max(np.stack([metrics["R_errs"], metrics["t_errs"]]), axis=0)[
+        unq_ids
+    ]
     aucs = error_auc(pose_errors, angular_thresholds)  # (auc@5, auc@10, auc@20)
 
     # matching precision
     dist_thresholds = [epi_err_thr]
-    precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True)  # (prec@err_thr)
+    precs = epidist_prec(
+        np.array(metrics["epi_errs"], dtype=object)[unq_ids], dist_thresholds, True
+    )  # (prec@err_thr)
 
     return {**aucs, **precs}
diff --git a/third_party/TopicFM/src/utils/misc.py b/third_party/TopicFM/src/utils/misc.py
index 9c8db04666519753ea2df43903ab6c47ec00a9a1..461077d77f1628c67055d841a5e70c29c7b82ade 100644
--- a/third_party/TopicFM/src/utils/misc.py
+++ b/third_party/TopicFM/src/utils/misc.py
@@ -24,7 +24,7 @@ def upper_config(dict_cfg):
 
 def log_on(condition, message, level):
     if condition:
-        assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']
+        assert level in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]
         logger.log(level, message)
 
 
@@ -34,32 +34,35 @@ def get_rank_zero_only_logger(logger: _Logger):
     else:
         for _level in logger._core.levels.keys():
             level = _level.lower()
-            setattr(logger, level,
-                    lambda x: None)
+            setattr(logger, level, lambda x: None)
         logger._log = lambda x: None
     return logger
 
 
 def setup_gpus(gpus: Union[str, int]) -> int:
-    """ A temporary fix for pytorch-lighting 1.3.x """
+    """A temporary fix for pytorch-lighting 1.3.x"""
     gpus = str(gpus)
     gpu_ids = []
-    
-    if ',' not in gpus:
+
+    if "," not in gpus:
         n_gpus = int(gpus)
         return n_gpus if n_gpus != -1 else torch.cuda.device_count()
     else:
-        gpu_ids = [i.strip() for i in gpus.split(',') if i != '']
-    
+        gpu_ids = [i.strip() for i in gpus.split(",") if i != ""]
+
     # setup environment variables
-    visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
     if visible_devices is None:
         os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
-        visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
-        logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
+        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
+        visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
+        logger.warning(
+            f"[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}"
+        )
     else:
-        logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
+        logger.warning(
+            "[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process."
+        )
     return len(gpu_ids)
 
 
@@ -70,11 +73,11 @@ def flattenList(x):
 @contextlib.contextmanager
 def tqdm_joblib(tqdm_object):
     """Context manager to patch joblib to report into tqdm progress bar given as argument
-    
+
     Usage:
         with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
             Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
-            
+
     When iterating over a generator, direct use of tqdm is also a solution (but it monitors task queuing instead of completion)
         ret_vals = Parallel(n_jobs=args.world_size)(
                     delayed(lambda x: _compute_cov_score(pid, *x))(param)
@@ -83,6 +86,7 @@ def tqdm_joblib(tqdm_object):
                                           total=len(image_ids)*(len(image_ids)-1)/2))
     Src: https://stackoverflow.com/a/58936697
     """
+
     class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
@@ -98,4 +102,3 @@ def tqdm_joblib(tqdm_object):
     finally:
         joblib.parallel.BatchCompletionCallBack = old_batch_callback
         tqdm_object.close()
-
diff --git a/third_party/TopicFM/src/utils/plotting.py b/third_party/TopicFM/src/utils/plotting.py
index 89b22ef27e6152225d07ab24bb3e62718d180b59..189045409c822f2e1d79610b29ea7e2825ae4bbd 100644
--- a/third_party/TopicFM/src/utils/plotting.py
+++ b/third_party/TopicFM/src/utils/plotting.py
@@ -9,37 +9,49 @@ import torch
 
 
 def _compute_conf_thresh(data):
-    dataset_name = data['dataset_name'][0].lower()
-    if dataset_name == 'scannet':
+    dataset_name = data["dataset_name"][0].lower()
+    if dataset_name == "scannet":
         thr = 5e-4
-    elif dataset_name == 'megadepth':
+    elif dataset_name == "megadepth":
         thr = 1e-4
     else:
-        raise ValueError(f'Unknown dataset: {dataset_name}')
+        raise ValueError(f"Unknown dataset: {dataset_name}")
     return thr
 
 
 # --- VISUALIZATION --- #
 
+
 def make_matching_figure(
-        img0, img1, mkpts0, mkpts1, color,
-        kpts0=None, kpts1=None, text=[], dpi=75, path=None):
+    img0,
+    img1,
+    mkpts0,
+    mkpts1,
+    color,
+    kpts0=None,
+    kpts1=None,
+    text=[],
+    dpi=75,
+    path=None,
+):
     # draw image pair
-    assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
+    assert (
+        mkpts0.shape[0] == mkpts1.shape[0]
+    ), f"mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}"
     fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
     axes[0].imshow(img0)  # , cmap='gray')
     axes[1].imshow(img1)  # , cmap='gray')
-    for i in range(2):   # clear all frames
+    for i in range(2):  # clear all frames
         axes[i].get_yaxis().set_ticks([])
         axes[i].get_xaxis().set_ticks([])
         for spine in axes[i].spines.values():
             spine.set_visible(False)
     plt.tight_layout(pad=1)
-    
+
     if kpts0 is not None:
         assert kpts1 is not None
-        axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=5)
-        axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=5)
+        axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=5)
+        axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
 
     # draw matches
     if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
@@ -47,99 +59,112 @@ def make_matching_figure(
         transFigure = fig.transFigure.inverted()
         fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
         fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
-        fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
-                                            (fkpts0[i, 1], fkpts1[i, 1]),
-                                            transform=fig.transFigure, c=color[i], linewidth=2)
-                                        for i in range(len(mkpts0))]
-        
+        fig.lines = [
+            matplotlib.lines.Line2D(
+                (fkpts0[i, 0], fkpts1[i, 0]),
+                (fkpts0[i, 1], fkpts1[i, 1]),
+                transform=fig.transFigure,
+                c=color[i],
+                linewidth=2,
+            )
+            for i in range(len(mkpts0))
+        ]
+
         axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4)
         axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4)
 
     # put txts
-    txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
+    txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
     fig.text(
-        0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
-        fontsize=15, va='top', ha='left', color=txt_color)
+        0.01,
+        0.99,
+        "\n".join(text),
+        transform=fig.axes[0].transAxes,
+        fontsize=15,
+        va="top",
+        ha="left",
+        color=txt_color,
+    )
 
     # save or return figure
     if path:
-        plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+        plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
         plt.close()
     else:
         return fig
 
 
-def _make_evaluation_figure(data, b_id, alpha='dynamic'):
-    b_mask = data['m_bids'] == b_id
+def _make_evaluation_figure(data, b_id, alpha="dynamic"):
+    b_mask = data["m_bids"] == b_id
     conf_thr = _compute_conf_thresh(data)
-    
-    img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
-    img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
-    kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
-    kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
-    
+
+    img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
+    kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
+
     # for megadepth, we visualize matches on the resized image
-    if 'scale0' in data:
-        kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
-        kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
+    if "scale0" in data:
+        kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]]
+        kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]]
 
-    epi_errs = data['epi_errs'][b_mask].cpu().numpy()
+    epi_errs = data["epi_errs"][b_mask].cpu().numpy()
     correct_mask = epi_errs < conf_thr
     precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
     n_correct = np.sum(correct_mask)
-    n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
+    n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
     recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
     # recall might be larger than 1, since the calculation of conf_matrix_gt
     # uses groundtruth depths and camera poses, but epipolar distance is used here.
 
     # matching info
-    if alpha == 'dynamic':
+    if alpha == "dynamic":
         alpha = dynamic_alpha(len(correct_mask))
     color = error_colormap(epi_errs, conf_thr, alpha=alpha)
-    
+
     text = [
-        f'#Matches {len(kpts0)}',
-        f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
-        f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
+        f"#Matches {len(kpts0)}",
+        f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
+        f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
     ]
-    
+
     # make the figure
-    figure = make_matching_figure(img0, img1, kpts0, kpts1,
-                                  color, text=text)
+    figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
     return figure
 
+
 def _make_confidence_figure(data, b_id):
     # TODO: Implement confidence figure
     raise NotImplementedError()
 
 
-def make_matching_figures(data, config, mode='evaluation'):
-    """ Make matching figures for a batch.
-    
+def make_matching_figures(data, config, mode="evaluation"):
+    """Make matching figures for a batch.
+
     Args:
         data (Dict): a batch updated by PL_LoFTR.
         config (Dict): matcher config
     Returns:
         figures (Dict[str, List[plt.figure]])
     """
-    assert mode in ['evaluation', 'confidence']  # 'confidence'
+    assert mode in ["evaluation", "confidence"]  # 'confidence'
     figures = {mode: []}
-    for b_id in range(data['image0'].size(0)):
-        if mode == 'evaluation':
+    for b_id in range(data["image0"].size(0)):
+        if mode == "evaluation":
             fig = _make_evaluation_figure(
-                data, b_id,
-                alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
-        elif mode == 'confidence':
+                data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA
+            )
+        elif mode == "confidence":
             fig = _make_confidence_figure(data, b_id)
         else:
-            raise ValueError(f'Unknown plot mode: {mode}')
+            raise ValueError(f"Unknown plot mode: {mode}")
     figures[mode].append(fig)
     return figures
 
 
-def dynamic_alpha(n_matches,
-                  milestones=[0, 300, 1000, 2000],
-                  alphas=[1.0, 0.8, 0.4, 0.2]):
+def dynamic_alpha(
+    n_matches, milestones=[0, 300, 1000, 2000], alphas=[1.0, 0.8, 0.4, 0.2]
+):
     if n_matches == 0:
         return 1.0
     ranges = list(zip(alphas, alphas[1:] + [None]))
@@ -148,14 +173,18 @@ def dynamic_alpha(n_matches,
     if _range[1] is None:
         return _range[0]
     return _range[1] + (milestones[loc + 1] - n_matches) / (
-        milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
+        milestones[loc + 1] - milestones[loc]
+    ) * (_range[0] - _range[1])
 
 
 def error_colormap(err, thr, alpha=1.0):
     assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
     x = 1 - np.clip(err / (thr * 2), 0, 1)
     return np.clip(
-        np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
+        np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1),
+        0,
+        1,
+    )
 
 
 np.random.seed(1995)
@@ -163,7 +192,9 @@ color_map = np.arange(100)
 np.random.shuffle(color_map)
 
 
-def draw_topics(data, img0, img1, saved_folder="viz_topics", show_n_topics=8, saved_name=None):
+def draw_topics(
+    data, img0, img1, saved_folder="viz_topics", show_n_topics=8, saved_name=None
+):
 
     topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"]
     hw0_c, hw1_c = data["hw0_c"], data["hw1_c"]
@@ -188,27 +219,38 @@ def draw_topics(data, img0, img1, saved_folder="viz_topics", show_n_topics=8, sa
     theta1 /= theta1.sum().float()
     # top_topic0 = torch.argsort(theta0, descending=True)[:show_n_topics]
     # top_topic1 = torch.argsort(theta1, descending=True)[:show_n_topics]
-    top_topics = torch.argsort(theta0*theta1, descending=True)[:show_n_topics]
+    top_topics = torch.argsort(theta0 * theta1, descending=True)[:show_n_topics]
     # print(sum_topic0, sum_topic1)
 
-    topic0 = topic0[0].argmax(dim=-1, keepdim=True) #.float() / (n_topics - 1) #* 255 + 1 #
+    topic0 = topic0[0].argmax(
+        dim=-1, keepdim=True
+    )  # .float() / (n_topics - 1) #* 255 + 1 #
     # topic0[~mask0_nonzero] = -1
-    topic1 = topic1[0].argmax(dim=-1, keepdim=True) #.float() / (n_topics - 1) #* 255 + 1
+    topic1 = topic1[0].argmax(
+        dim=-1, keepdim=True
+    )  # .float() / (n_topics - 1) #* 255 + 1
     # topic1[~mask1_nonzero] = -1
     label_img0, label_img1 = torch.zeros_like(topic0) - 1, torch.zeros_like(topic1) - 1
     for i, k in enumerate(top_topics):
         label_img0[topic0 == k] = color_map[k]
         label_img1[topic1 == k] = color_map[k]
 
-#     print(hw0_c, scale0)
-#     print(hw1_c, scale1)
+    #     print(hw0_c, scale0)
+    #     print(hw1_c, scale1)
     # map_topic0 = F.fold(label_img0.unsqueeze(0), hw0_i, kernel_size=scale0, stride=scale0)
-    map_topic0 = label_img0.float().view(hw0_c).cpu().numpy() #map_topic0.squeeze(0).squeeze(0).cpu().numpy()
-    map_topic0 = cv2.resize(map_topic0, (int(hw0_c[1] * scale0[0]), int(hw0_c[0] * scale0[1])))
+    map_topic0 = (
+        label_img0.float().view(hw0_c).cpu().numpy()
+    )  # map_topic0.squeeze(0).squeeze(0).cpu().numpy()
+    map_topic0 = cv2.resize(
+        map_topic0, (int(hw0_c[1] * scale0[0]), int(hw0_c[0] * scale0[1]))
+    )
     # map_topic1 = F.fold(label_img1.unsqueeze(0), hw1_i, kernel_size=scale1, stride=scale1)
-    map_topic1 = label_img1.float().view(hw1_c).cpu().numpy() #map_topic1.squeeze(0).squeeze(0).cpu().numpy()
-    map_topic1 = cv2.resize(map_topic1, (int(hw1_c[1] * scale1[0]), int(hw1_c[0] * scale1[1])))
-
+    map_topic1 = (
+        label_img1.float().view(hw1_c).cpu().numpy()
+    )  # map_topic1.squeeze(0).squeeze(0).cpu().numpy()
+    map_topic1 = cv2.resize(
+        map_topic1, (int(hw1_c[1] * scale1[0]), int(hw1_c[0] * scale1[1]))
+    )
 
     # show image0
     if saved_name is None:
@@ -219,28 +261,57 @@ def draw_topics(data, img0, img1, saved_folder="viz_topics", show_n_topics=8, sa
     path_saved_img0 = os.path.join(saved_folder, "{}_0.png".format(saved_name))
     plt.imshow(img0)
     masked_map_topic0 = np.ma.masked_where(map_topic0 < 0, map_topic0)
-    plt.imshow(masked_map_topic0, cmap=plt.cm.jet, vmin=0, vmax=n_topics-1, alpha=.3, interpolation='bilinear')
+    plt.imshow(
+        masked_map_topic0,
+        cmap=plt.cm.jet,
+        vmin=0,
+        vmax=n_topics - 1,
+        alpha=0.3,
+        interpolation="bilinear",
+    )
     # plt.show()
-    plt.axis('off')
-    plt.savefig(path_saved_img0, bbox_inches='tight', pad_inches=0, dpi=250)
+    plt.axis("off")
+    plt.savefig(path_saved_img0, bbox_inches="tight", pad_inches=0, dpi=250)
     plt.close()
 
     path_saved_img1 = os.path.join(saved_folder, "{}_1.png".format(saved_name))
     plt.imshow(img1)
     masked_map_topic1 = np.ma.masked_where(map_topic1 < 0, map_topic1)
-    plt.imshow(masked_map_topic1, cmap=plt.cm.jet, vmin=0, vmax=n_topics-1, alpha=.3, interpolation='bilinear')
-    plt.axis('off')
-    plt.savefig(path_saved_img1, bbox_inches='tight', pad_inches=0, dpi=250)
+    plt.imshow(
+        masked_map_topic1,
+        cmap=plt.cm.jet,
+        vmin=0,
+        vmax=n_topics - 1,
+        alpha=0.3,
+        interpolation="bilinear",
+    )
+    plt.axis("off")
+    plt.savefig(path_saved_img1, bbox_inches="tight", pad_inches=0, dpi=250)
     plt.close()
 
 
-def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_topics=8,
-                      topic_alpha=0.3, margin=5, path=None, opencv_display=False, opencv_title=''):
+def draw_topicfm_demo(
+    data,
+    img0,
+    img1,
+    mkpts0,
+    mkpts1,
+    mcolor,
+    text,
+    show_n_topics=8,
+    topic_alpha=0.3,
+    margin=5,
+    path=None,
+    opencv_display=False,
+    opencv_title="",
+):
     topic_map0, topic_map1 = draw_topics(data, img0, img1, show_n_topics=show_n_topics)
 
-    mask_tm0, mask_tm1 = np.expand_dims(topic_map0 >= 0, axis=-1), np.expand_dims(topic_map1 >= 0, axis=-1)
+    mask_tm0, mask_tm1 = np.expand_dims(topic_map0 >= 0, axis=-1), np.expand_dims(
+        topic_map1 >= 0, axis=-1
+    )
 
-    topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.), cm.jet(topic_map1 / 99.)
+    topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.0), cm.jet(topic_map1 / 99.0)
     topic_cm0 = cv2.cvtColor(topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
     topic_cm1 = cv2.cvtColor(topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
     overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32)
@@ -249,7 +320,9 @@ def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_top
     cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0)
     cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1)
 
-    overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (overlay1 * 255).astype(np.uint8)
+    overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (overlay1 * 255).astype(
+        np.uint8
+    )
 
     h0, w0 = img0.shape[:2]
     h1, w1 = img1.shape[:2]
@@ -258,19 +331,25 @@ def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_top
     out_fig[:h0, :w0] = overlay0
     if h0 >= h1:
         start = (h0 - h1) // 2
-        out_fig[start:(start+h1), (w0+margin):(w0+margin+w1)] = overlay1
+        out_fig[start : (start + h1), (w0 + margin) : (w0 + margin + w1)] = overlay1
     else:
         start = (h1 - h0) // 2
-        out_fig[:h0, (w0+margin):(w0+margin+w1)] = overlay1[start:(start+h0)]
+        out_fig[:h0, (w0 + margin) : (w0 + margin + w1)] = overlay1[
+            start : (start + h0)
+        ]
 
     step_h = h0 + margin * 2
-    out_fig[step_h:step_h+h0, :w0] = (img0 * 255).astype(np.uint8)
+    out_fig[step_h : step_h + h0, :w0] = (img0 * 255).astype(np.uint8)
     if h0 >= h1:
         start = step_h + (h0 - h1) // 2
-        out_fig[start:start+h1, (w0+margin):(w0+margin+w1)] = (img1 * 255).astype(np.uint8)
+        out_fig[start : start + h1, (w0 + margin) : (w0 + margin + w1)] = (
+            img1 * 255
+        ).astype(np.uint8)
     else:
         start = (h1 - h0) // 2
-        out_fig[step_h:step_h+h0, (w0+margin):(w0+margin+w1)] = (img1[start:start+h0] * 255).astype(np.uint8)
+        out_fig[step_h : step_h + h0, (w0 + margin) : (w0 + margin + w1)] = (
+            img1[start : start + h0] * 255
+        ).astype(np.uint8)
 
     # draw matching lines; this is inspired by https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py
     mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
@@ -278,24 +357,53 @@ def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_top
 
     for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, mcolor):
         c = c.tolist()
-        cv2.line(out_fig, (x0, y0+step_h), (x1+margin+w0, y1+step_h+(h0-h1)//2),
-                 color=c, thickness=1, lineType=cv2.LINE_AA)
+        cv2.line(
+            out_fig,
+            (x0, y0 + step_h),
+            (x1 + margin + w0, y1 + step_h + (h0 - h1) // 2),
+            color=c,
+            thickness=1,
+            lineType=cv2.LINE_AA,
+        )
         # display line end-points as circles
-        cv2.circle(out_fig, (x0, y0+step_h), 2, c, -1, lineType=cv2.LINE_AA)
-        cv2.circle(out_fig, (x1+margin+w0, y1+step_h+(h0-h1)//2), 2, c, -1, lineType=cv2.LINE_AA)
+        cv2.circle(out_fig, (x0, y0 + step_h), 2, c, -1, lineType=cv2.LINE_AA)
+        cv2.circle(
+            out_fig,
+            (x1 + margin + w0, y1 + step_h + (h0 - h1) // 2),
+            2,
+            c,
+            -1,
+            lineType=cv2.LINE_AA,
+        )
 
         # Scale factor for consistent visualization across scales.
-    sc = min(h / 960., 2.0)
+    sc = min(h / 960.0, 2.0)
 
     # Big text.
     Ht = int(30 * sc)  # text height
     txt_color_fg = (255, 255, 255)
     txt_color_bg = (0, 0, 0)
     for i, t in enumerate(text):
-        cv2.putText(out_fig, t, (int(8 * sc), Ht + step_h*i), cv2.FONT_HERSHEY_DUPLEX,
-                    1.0 * sc, txt_color_bg, 2, cv2.LINE_AA)
-        cv2.putText(out_fig, t, (int(8 * sc), Ht + step_h*i), cv2.FONT_HERSHEY_DUPLEX,
-                    1.0 * sc, txt_color_fg, 1, cv2.LINE_AA)
+        cv2.putText(
+            out_fig,
+            t,
+            (int(8 * sc), Ht + step_h * i),
+            cv2.FONT_HERSHEY_DUPLEX,
+            1.0 * sc,
+            txt_color_bg,
+            2,
+            cv2.LINE_AA,
+        )
+        cv2.putText(
+            out_fig,
+            t,
+            (int(8 * sc), Ht + step_h * i),
+            cv2.FONT_HERSHEY_DUPLEX,
+            1.0 * sc,
+            txt_color_fg,
+            1,
+            cv2.LINE_AA,
+        )
 
     if path is not None:
         cv2.imwrite(str(path), out_fig)
@@ -305,9 +413,3 @@ def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_top
         cv2.waitKey(1)
 
     return out_fig
-
-
-
-
-
-
diff --git a/third_party/TopicFM/src/utils/profiler.py b/third_party/TopicFM/src/utils/profiler.py
index 6d21ed79fb506ef09c75483355402c48a195aaa9..0275ea34e3eb9cceb4ed809bebeda209749f5bc5 100644
--- a/third_party/TopicFM/src/utils/profiler.py
+++ b/third_party/TopicFM/src/utils/profiler.py
@@ -7,7 +7,7 @@ from pytorch_lightning.utilities import rank_zero_only
 class InferenceProfiler(SimpleProfiler):
     """
     This profiler records duration of actions with cuda.synchronize()
-    Use this in test time. 
+    Use this in test time.
     """
 
     def __init__(self):
@@ -28,12 +28,13 @@ class InferenceProfiler(SimpleProfiler):
 
 
 def build_profiler(name):
-    if name == 'inference':
+    if name == "inference":
         return InferenceProfiler()
-    elif name == 'pytorch':
+    elif name == "pytorch":
         from pytorch_lightning.profiler import PyTorchProfiler
+
         return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
     elif name is None:
         return PassThroughProfiler()
     else:
-        raise ValueError(f'Invalid profiler: {name}')
+        raise ValueError(f"Invalid profiler: {name}")
diff --git a/third_party/TopicFM/test.py b/third_party/TopicFM/test.py
index aeb451cde3674b70b0d2e02f37ff1fd391004d30..7b941ea4f6529c2206d527be85a23523dcf0e148 100644
--- a/third_party/TopicFM/test.py
+++ b/third_party/TopicFM/test.py
@@ -13,29 +13,43 @@ from src.lightning_trainer.trainer import PL_Trainer
 def parse_args():
     # init a custom parser which will be added into pl.Trainer parser
     # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("data_cfg_path", type=str, help="data config path")
+    parser.add_argument("main_cfg_path", type=str, help="main config path")
     parser.add_argument(
-        'data_cfg_path', type=str, help='data config path')
+        "--ckpt_path",
+        type=str,
+        default="weights/indoor_ds.ckpt",
+        help="path to the checkpoint",
+    )
     parser.add_argument(
-        'main_cfg_path', type=str, help='main config path')
+        "--dump_dir",
+        type=str,
+        default=None,
+        help="if set, the matching results will be dump to dump_dir",
+    )
     parser.add_argument(
-        '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint')
+        "--profiler_name",
+        type=str,
+        default=None,
+        help="options: [inference, pytorch], or leave it unset",
+    )
+    parser.add_argument("--batch_size", type=int, default=1, help="batch_size per gpu")
+    parser.add_argument("--num_workers", type=int, default=2)
     parser.add_argument(
-        '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir")
-    parser.add_argument(
-        '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset')
-    parser.add_argument(
-        '--batch_size', type=int, default=1, help='batch_size per gpu')
-    parser.add_argument(
-        '--num_workers', type=int, default=2)
-    parser.add_argument(
-        '--thr', type=float, default=None, help='modify the coarse-level matching threshold.')
+        "--thr",
+        type=float,
+        default=None,
+        help="modify the coarse-level matching threshold.",
+    )
 
     parser = pl.Trainer.add_argparse_args(parser)
     return parser.parse_args()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # parse arguments
     args = parse_args()
     pprint.pprint(vars(args))
@@ -54,7 +68,12 @@ if __name__ == '__main__':
 
     # lightning module
     profiler = build_profiler(args.profiler_name)
-    model = PL_Trainer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir)
+    model = PL_Trainer(
+        config,
+        pretrained_ckpt=args.ckpt_path,
+        profiler=profiler,
+        dump_dir=args.dump_dir,
+    )
     loguru_logger.info(f"Model-lightning initialized!")
 
     # lightning data
@@ -62,7 +81,9 @@ if __name__ == '__main__':
     loguru_logger.info(f"DataModule initialized!")
 
     # lightning trainer
-    trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False)
+    trainer = pl.Trainer.from_argparse_args(
+        args, replace_sampler_ddp=False, logger=False
+    )
 
     loguru_logger.info(f"Start testing!")
     trainer.test(model, datamodule=data_module, verbose=False)
diff --git a/third_party/TopicFM/train.py b/third_party/TopicFM/train.py
index a552c23718b81ddcb282cedbfe3ceb45e50b3f29..9188b80a3fb407f4871b8147a2c90fa382380e25 100644
--- a/third_party/TopicFM/train.py
+++ b/third_party/TopicFM/train.py
@@ -23,32 +23,43 @@ loguru_logger = get_rank_zero_only_logger(loguru_logger)
 def parse_args():
     # init a custom parser which will be added into pl.Trainer parser
     # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("data_cfg_path", type=str, help="data config path")
+    parser.add_argument("main_cfg_path", type=str, help="main config path")
+    parser.add_argument("--exp_name", type=str, default="default_exp_name")
+    parser.add_argument("--batch_size", type=int, default=4, help="batch_size per gpu")
+    parser.add_argument("--num_workers", type=int, default=4)
     parser.add_argument(
-        'data_cfg_path', type=str, help='data config path')
+        "--pin_memory",
+        type=lambda x: bool(strtobool(x)),
+        nargs="?",
+        default=True,
+        help="whether loading data to pinned memory or not",
+    )
     parser.add_argument(
-        'main_cfg_path', type=str, help='main config path')
+        "--ckpt_path",
+        type=str,
+        default=None,
+        help="pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR",
+    )
     parser.add_argument(
-        '--exp_name', type=str, default='default_exp_name')
+        "--disable_ckpt",
+        action="store_true",
+        help="disable checkpoint saving (useful for debugging).",
+    )
     parser.add_argument(
-        '--batch_size', type=int, default=4, help='batch_size per gpu')
+        "--profiler_name",
+        type=str,
+        default=None,
+        help="options: [inference, pytorch], or leave it unset",
+    )
     parser.add_argument(
-        '--num_workers', type=int, default=4)
-    parser.add_argument(
-        '--pin_memory', type=lambda x: bool(strtobool(x)),
-        nargs='?', default=True, help='whether loading data to pinned memory or not')
-    parser.add_argument(
-        '--ckpt_path', type=str, default=None,
-        help='pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR')
-    parser.add_argument(
-        '--disable_ckpt', action='store_true',
-        help='disable checkpoint saving (useful for debugging).')
-    parser.add_argument(
-        '--profiler_name', type=str, default=None,
-        help='options: [inference, pytorch], or leave it unset')
-    parser.add_argument(
-        '--parallel_load_data', action='store_true',
-        help='load datasets in with multiple processes.')
+        "--parallel_load_data",
+        action="store_true",
+        help="load datasets in with multiple processes.",
+    )
 
     parser = pl.Trainer.add_argparse_args(parser)
     return parser.parse_args()
@@ -66,7 +77,7 @@ def main():
     pl.seed_everything(config.TRAINER.SEED)  # reproducibility
     # TODO: Use different seeds for each dataloader workers
     # This is needed for data augmentation
-    
+
     # scale lr and warmup-step automatically
     args.gpus = _n_gpus = setup_gpus(args.gpus)
     config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
@@ -75,49 +86,59 @@ def main():
     config.TRAINER.SCALING = _scaling
     config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
     config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)
-    
+
     # lightning module
     profiler = build_profiler(args.profiler_name)
     model = PL_Trainer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
     loguru_logger.info(f"Model LightningModule initialized!")
-    
+
     # lightning data
     data_module = MultiSceneDataModule(args, config)
     loguru_logger.info(f"Model DataModule initialized!")
-    
+
     # TensorBoard Logger
-    logger = TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)
-    ckpt_dir = Path(logger.log_dir) / 'checkpoints'
-    
+    logger = TensorBoardLogger(
+        save_dir="logs/tb_logs", name=args.exp_name, default_hp_metric=False
+    )
+    ckpt_dir = Path(logger.log_dir) / "checkpoints"
+
     # Callbacks
     # TODO: update ModelCheckpoint to monitor multiple metrics
-    ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max',
-                                    save_last=True,
-                                    dirpath=str(ckpt_dir),
-                                    filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
-    lr_monitor = LearningRateMonitor(logging_interval='step')
+    ckpt_callback = ModelCheckpoint(
+        monitor="auc@10",
+        verbose=True,
+        save_top_k=5,
+        mode="max",
+        save_last=True,
+        dirpath=str(ckpt_dir),
+        filename="{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}",
+    )
+    lr_monitor = LearningRateMonitor(logging_interval="step")
     callbacks = [lr_monitor]
     if not args.disable_ckpt:
         callbacks.append(ckpt_callback)
-    
+
     # Lightning Trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        plugins=DDPPlugin(find_unused_parameters=False,
-                          num_nodes=args.num_nodes,
-                          sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
+        plugins=DDPPlugin(
+            find_unused_parameters=False,
+            num_nodes=args.num_nodes,
+            sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
+        ),
         gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
         callbacks=callbacks,
         logger=logger,
         sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
         replace_sampler_ddp=False,  # use custom sampler
         reload_dataloaders_every_epoch=False,  # avoid repeated samples!
-        weights_summary='full',
-        profiler=profiler)
+        weights_summary="full",
+        profiler=profiler,
+    )
     loguru_logger.info(f"Trainer initialized!")
     loguru_logger.info(f"Start training!")
     trainer.fit(model, datamodule=data_module)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/third_party/TopicFM/visualization.py b/third_party/TopicFM/visualization.py
index 279b41cd88f61ce3414e2f3077fec642b2c8333a..73ec7dd74e21ac72204484cf8d4f3c6fd56a72a2 100644
--- a/third_party/TopicFM/visualization.py
+++ b/third_party/TopicFM/visualization.py
@@ -15,9 +15,9 @@ from configs.data.base import cfg as data_cfg
 import viz
 
 
-def get_model_config(method_name, dataset_name, root_dir='viz'):
-    config_file = f'{root_dir}/configs/{method_name}.yml'
-    with open(config_file, 'r') as f:
+def get_model_config(method_name, dataset_name, root_dir="viz"):
+    config_file = f"{root_dir}/configs/{method_name}.yml"
+    with open(config_file, "r") as f:
         model_conf = yaml.load(f, Loader=yaml.FullLoader)[dataset_name]
     return model_conf
 
@@ -30,7 +30,10 @@ class DemoDataset(Dataset):
             self.list_img_files.sort()
         else:
             with open(img_file) as f:
-                self.list_img_files = [os.path.join(dataset_dir, img_file.strip()) for img_file in f.readlines()]
+                self.list_img_files = [
+                    os.path.join(dataset_dir, img_file.strip())
+                    for img_file in f.readlines()
+                ]
         self.resize = resize
         self.down_factor = down_factor
 
@@ -38,24 +41,31 @@ class DemoDataset(Dataset):
         return len(self.list_img_files)
 
     def __getitem__(self, idx):
-        img_path = self.list_img_files[idx] #os.path.join(self.dataset_dir, self.list_img_files[idx])
-        img, scale = read_img_gray(img_path, resize=self.resize, down_factor=self.down_factor)
+        img_path = self.list_img_files[
+            idx
+        ]  # os.path.join(self.dataset_dir, self.list_img_files[idx])
+        img, scale = read_img_gray(
+            img_path, resize=self.resize, down_factor=self.down_factor
+        )
         return {"img": img, "id": idx, "img_path": img_path}
 
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Visualize matches')
-    parser.add_argument('--gpu', '-gpu', type=str, default='0')
-    parser.add_argument('--method', type=str, default=None)
-    parser.add_argument('--dataset_dir', type=str, default='data/aachen-day-night')
-    parser.add_argument('--pair_dir', type=str, default=None)
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Visualize matches")
+    parser.add_argument("--gpu", "-gpu", type=str, default="0")
+    parser.add_argument("--method", type=str, default=None)
+    parser.add_argument("--dataset_dir", type=str, default="data/aachen-day-night")
+    parser.add_argument("--pair_dir", type=str, default=None)
     parser.add_argument(
-        '--dataset_name', type=str, choices=['megadepth', 'scannet', 'aachen_v1.1', 'inloc'], default='megadepth'
+        "--dataset_name",
+        type=str,
+        choices=["megadepth", "scannet", "aachen_v1.1", "inloc"],
+        default="megadepth",
     )
-    parser.add_argument('--measure_time', action="store_true")
-    parser.add_argument('--no_viz', action="store_true")
-    parser.add_argument('--compute_eval_metrics', action="store_true")
-    parser.add_argument('--run_demo', action="store_true")
+    parser.add_argument("--measure_time", action="store_true")
+    parser.add_argument("--no_viz", action="store_true")
+    parser.add_argument("--compute_eval_metrics", action="store_true")
+    parser.add_argument("--run_demo", action="store_true")
 
     args = parser.parse_args()
 
@@ -64,26 +74,45 @@ if __name__ == '__main__':
     model = viz.__dict__[class_name](model_cfg)
     # all_args = Namespace(**vars(args), **model_cfg)
     if not args.run_demo:
-        if args.dataset_name == 'megadepth':
+        if args.dataset_name == "megadepth":
             from configs.data.megadepth_test_1500 import cfg
 
             data_cfg.merge_from_other_cfg(cfg)
-        elif args.dataset_name == 'scannet':
+        elif args.dataset_name == "scannet":
             from configs.data.scannet_test_1500 import cfg
 
             data_cfg.merge_from_other_cfg(cfg)
-        elif args.dataset_name == 'aachen_v1.1':
-            data_cfg.merge_from_list(["DATASET.TEST_DATA_SOURCE", "aachen_v1.1",
-                                      "DATASET.TEST_DATA_ROOT", os.path.join(args.dataset_dir, "images/images_upright"),
-                                      "DATASET.TEST_LIST_PATH", args.pair_dir,
-                                      "DATASET.TEST_IMGSIZE", model_cfg["imsize"]])
-        elif args.dataset_name == 'inloc':
-            data_cfg.merge_from_list(["DATASET.TEST_DATA_SOURCE", "inloc",
-                                      "DATASET.TEST_DATA_ROOT", args.dataset_dir,
-                                      "DATASET.TEST_LIST_PATH", args.pair_dir,
-                                      "DATASET.TEST_IMGSIZE", model_cfg["imsize"]])
-
-        has_ground_truth = str(data_cfg.DATASET.TEST_DATA_SOURCE).lower() in ["megadepth", "scannet"]
+        elif args.dataset_name == "aachen_v1.1":
+            data_cfg.merge_from_list(
+                [
+                    "DATASET.TEST_DATA_SOURCE",
+                    "aachen_v1.1",
+                    "DATASET.TEST_DATA_ROOT",
+                    os.path.join(args.dataset_dir, "images/images_upright"),
+                    "DATASET.TEST_LIST_PATH",
+                    args.pair_dir,
+                    "DATASET.TEST_IMGSIZE",
+                    model_cfg["imsize"],
+                ]
+            )
+        elif args.dataset_name == "inloc":
+            data_cfg.merge_from_list(
+                [
+                    "DATASET.TEST_DATA_SOURCE",
+                    "inloc",
+                    "DATASET.TEST_DATA_ROOT",
+                    args.dataset_dir,
+                    "DATASET.TEST_LIST_PATH",
+                    args.pair_dir,
+                    "DATASET.TEST_IMGSIZE",
+                    model_cfg["imsize"],
+                ]
+            )
+
+        has_ground_truth = str(data_cfg.DATASET.TEST_DATA_SOURCE).lower() in [
+            "megadepth",
+            "scannet",
+        ]
         dataloader = TestDataLoader(data_cfg)
         with torch.no_grad():
             for data_dict in tqdm(dataloader):
@@ -91,11 +120,20 @@ if __name__ == '__main__':
                     if isinstance(v, torch.Tensor):
                         data_dict[k] = v.cuda() if torch.cuda.is_available() else v
                 img_root_dir = data_cfg.DATASET.TEST_DATA_ROOT
-                model.match_and_draw(data_dict, root_dir=img_root_dir, ground_truth=has_ground_truth,
-                                     measure_time=args.measure_time, viz_matches=(not args.no_viz))
+                model.match_and_draw(
+                    data_dict,
+                    root_dir=img_root_dir,
+                    ground_truth=has_ground_truth,
+                    measure_time=args.measure_time,
+                    viz_matches=(not args.no_viz),
+                )
 
         if args.measure_time:
-            print("Running time for each image is {} miliseconds".format(model.measure_time()))
+            print(
+                "Running time for each image is {} miliseconds".format(
+                    model.measure_time()
+                )
+            )
         if args.compute_eval_metrics and has_ground_truth:
             model.compute_eval_metrics()
     else:
@@ -103,6 +141,13 @@ if __name__ == '__main__':
         sampler = SequentialSampler(demo_dataset)
         dataloader = DataLoader(demo_dataset, batch_size=1, sampler=sampler)
 
-        writer = cv2.VideoWriter('topicfm_demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 15, (640 * 2 + 5, 480 * 2 + 10))
+        writer = cv2.VideoWriter(
+            "topicfm_demo.mp4",
+            cv2.VideoWriter_fourcc(*"mp4v"),
+            15,
+            (640 * 2 + 5, 480 * 2 + 10),
+        )
 
-        model.run_demo(iter(dataloader), writer) #, output_dir="demo", no_display=True)
+        model.run_demo(
+            iter(dataloader), writer
+        )  # , output_dir="demo", no_display=True)
diff --git a/third_party/TopicFM/viz/methods/base.py b/third_party/TopicFM/viz/methods/base.py
index 377e95134f339459bff3c5a0d30b3bfbc122d978..1dfc23efb5fb49bbf510364599489c9acf1df263 100644
--- a/third_party/TopicFM/viz/methods/base.py
+++ b/third_party/TopicFM/viz/methods/base.py
@@ -14,7 +14,9 @@ def flatten_list(x):
 class Viz(metaclass=ABCMeta):
     def __init__(self):
         super().__init__()
-        self.device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu')
+        self.device = torch.device(
+            "cuda:{}".format(0) if torch.cuda.is_available() else "cpu"
+        )
         torch.set_grad_enabled(False)
 
         # for evaluation metrics of MegaDepth and ScanNet
@@ -33,11 +35,15 @@ class Viz(metaclass=ABCMeta):
             f"{self.name}",
             f"#Matches: {len(mkpts0)}",
         ]
-        if 'R_errs' in kwargs:
-            text.append(f"$\\Delta$R:{kwargs['R_errs']:.2f}°,  $\\Delta$t:{kwargs['t_errs']:.2f}°",)
+        if "R_errs" in kwargs:
+            text.append(
+                f"$\\Delta$R:{kwargs['R_errs']:.2f}°,  $\\Delta$t:{kwargs['t_errs']:.2f}°",
+            )
 
         if path:
-            make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text, path=path, dpi=150)
+            make_matching_figure(
+                img0, img1, mkpts0, mkpts1, color, text=text, path=path, dpi=150
+            )
         else:
             return make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text)
 
@@ -47,11 +53,11 @@ class Viz(metaclass=ABCMeta):
 
     def compute_eval_metrics(self, epi_err_thr=5e-4):
         # metrics: dict of list, numpy
-        _metrics = [o['metrics'] for o in self.eval_stats]
+        _metrics = [o["metrics"] for o in self.eval_stats]
         metrics = {k: flatten_list([_me[k] for _me in _metrics]) for k in _metrics[0]}
 
         val_metrics_4tb = aggregate_metrics(metrics, epi_err_thr)
-        print('\n' + pprint.pformat(val_metrics_4tb))
+        print("\n" + pprint.pformat(val_metrics_4tb))
 
     def measure_time(self):
         if len(self.time_stats) == 0:
diff --git a/third_party/TopicFM/viz/methods/loftr.py b/third_party/TopicFM/viz/methods/loftr.py
index 53d0c00c1a067cee10bf1587197e4780ac8b2eda..29046a2aa95596cbfe9656c3bda6dafcb1a55058 100644
--- a/third_party/TopicFM/viz/methods/loftr.py
+++ b/third_party/TopicFM/viz/methods/loftr.py
@@ -19,20 +19,27 @@ class VizLoFTR(Viz):
 
         # Load model
         conf = dict(default_cfg)
-        conf['match_coarse']['thr'] = self.match_threshold
+        conf["match_coarse"]["thr"] = self.match_threshold
         print(conf)
         self.model = LoFTR(config=conf)
         ckpt_dict = torch.load(args.ckpt)
-        self.model.load_state_dict(ckpt_dict['state_dict'])
+        self.model.load_state_dict(ckpt_dict["state_dict"])
         self.model = self.model.eval().to(self.device)
 
         # Name the method
         # self.ckpt_name = args.ckpt.split('/')[-1].split('.')[0]
-        self.name = 'LoFTR'
+        self.name = "LoFTR"
 
-        print(f'Initialize {self.name}')
+        print(f"Initialize {self.name}")
 
-    def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, measure_time=False, viz_matches=True):
+    def match_and_draw(
+        self,
+        data_dict,
+        root_dir=None,
+        ground_truth=False,
+        measure_time=False,
+        viz_matches=True,
+    ):
         if measure_time:
             torch.cuda.synchronize()
             start = torch.cuda.Event(enable_timing=True)
@@ -45,41 +52,72 @@ class VizLoFTR(Viz):
             torch.cuda.synchronize()
             self.time_stats.append(start.elapsed_time(end))
 
-        kpts0 = data_dict['mkpts0_f'].cpu().numpy()
-        kpts1 = data_dict['mkpts1_f'].cpu().numpy()
+        kpts0 = data_dict["mkpts0_f"].cpu().numpy()
+        kpts1 = data_dict["mkpts1_f"].cpu().numpy()
 
-        img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0]
+        img_name0, img_name1 = list(zip(*data_dict["pair_names"]))[0]
         img0 = cv2.imread(os.path.join(root_dir, img_name0))
         img1 = cv2.imread(os.path.join(root_dir, img_name1))
-        if str(data_dict["dataset_name"][0]).lower() == 'scannet':
+        if str(data_dict["dataset_name"][0]).lower() == "scannet":
             img0 = cv2.resize(img0, (640, 480))
             img1 = cv2.resize(img1, (640, 480))
 
         if viz_matches:
-            saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]])
+            saved_name = "_".join(
+                [
+                    img_name0.split("/")[-1].split(".")[0],
+                    img_name1.split("/")[-1].split(".")[0],
+                ]
+            )
             folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name))
             if not os.path.exists(folder_matches):
                 os.makedirs(folder_matches)
-            path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name))
+            path_to_save_matches = os.path.join(
+                folder_matches, "{}.png".format(saved_name)
+            )
             if ground_truth:
-                compute_symmetrical_epipolar_errors(data_dict)  # compute epi_errs for each match
-                compute_pose_errors(data_dict)  # compute R_errs, t_errs, pose_errs for each pair
-                epi_errors = data_dict['epi_errs'].cpu().numpy()
-                R_errors, t_errors = data_dict['R_errs'][0], data_dict['t_errs'][0]
+                compute_symmetrical_epipolar_errors(
+                    data_dict
+                )  # compute epi_errs for each match
+                compute_pose_errors(
+                    data_dict
+                )  # compute R_errs, t_errs, pose_errs for each pair
+                epi_errors = data_dict["epi_errs"].cpu().numpy()
+                R_errors, t_errors = data_dict["R_errs"][0], data_dict["t_errs"][0]
 
-                self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches,
-                                  R_errs=R_errors, t_errs=t_errors)
+                self.draw_matches(
+                    kpts0,
+                    kpts1,
+                    img0,
+                    img1,
+                    epi_errors,
+                    path=path_to_save_matches,
+                    R_errs=R_errors,
+                    t_errs=t_errors,
+                )
 
-                rel_pair_names = list(zip(*data_dict['pair_names']))
-                bs = data_dict['image0'].size(0)
+                rel_pair_names = list(zip(*data_dict["pair_names"]))
+                bs = data_dict["image0"].size(0)
                 metrics = {
                     # to filter duplicate pairs caused by DistributedSampler
-                    'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
-                    'epi_errs': [data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)],
-                    'R_errs': data_dict['R_errs'],
-                    't_errs': data_dict['t_errs'],
-                    'inliers': data_dict['inliers']}
-                self.eval_stats.append({'metrics': metrics})
+                    "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)],
+                    "epi_errs": [
+                        data_dict["epi_errs"][data_dict["m_bids"] == b].cpu().numpy()
+                        for b in range(bs)
+                    ],
+                    "R_errs": data_dict["R_errs"],
+                    "t_errs": data_dict["t_errs"],
+                    "inliers": data_dict["inliers"],
+                }
+                self.eval_stats.append({"metrics": metrics})
             else:
                 m_conf = 1 - data_dict["mconf"].cpu().numpy()
-                self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4)
+                self.draw_matches(
+                    kpts0,
+                    kpts1,
+                    img0,
+                    img1,
+                    m_conf,
+                    path=path_to_save_matches,
+                    conf_thr=0.4,
+                )
diff --git a/third_party/TopicFM/viz/methods/patch2pix.py b/third_party/TopicFM/viz/methods/patch2pix.py
index 14a1d345881e2021be97dc5dde91d8bbe1cd18fa..4d2df36f35c5b06ea8d45980e0b6b91e7482c718 100644
--- a/third_party/TopicFM/viz/methods/patch2pix.py
+++ b/third_party/TopicFM/viz/methods/patch2pix.py
@@ -7,7 +7,7 @@ from pathlib import Path
 from .base import Viz
 from src.utils.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors
 
-patch2pix_path = Path(__file__).parent / '../../third_party/patch2pix'
+patch2pix_path = Path(__file__).parent / "../../third_party/patch2pix"
 sys.path.append(str(patch2pix_path))
 from third_party.patch2pix.utils.eval.model_helper import load_model, estimate_matches
 
@@ -21,25 +21,39 @@ class VizPatch2Pix(Viz):
         self.imsize = args.imsize
         self.match_threshold = args.match_threshold
         self.ksize = args.ksize
-        self.model = load_model(args.ckpt, method='patch2pix')
-        self.name = 'Patch2Pix'
-        print(f'Initialize {self.name} with image size {self.imsize}')
+        self.model = load_model(args.ckpt, method="patch2pix")
+        self.name = "Patch2Pix"
+        print(f"Initialize {self.name} with image size {self.imsize}")
 
-    def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, measure_time=False, viz_matches=True):
-        img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0]
+    def match_and_draw(
+        self,
+        data_dict,
+        root_dir=None,
+        ground_truth=False,
+        measure_time=False,
+        viz_matches=True,
+    ):
+        img_name0, img_name1 = list(zip(*data_dict["pair_names"]))[0]
         path_img0 = os.path.join(root_dir, img_name0)
         path_img1 = os.path.join(root_dir, img_name1)
         img0, img1 = cv2.imread(path_img0), cv2.imread(path_img1)
         return_m_upscale = True
-        if str(data_dict["dataset_name"][0]).lower() == 'scannet':
+        if str(data_dict["dataset_name"][0]).lower() == "scannet":
             # self.imsize = 640
             img0 = cv2.resize(img0, tuple(self.imsize))  # (640, 480))
             img1 = cv2.resize(img1, tuple(self.imsize))  # (640, 480))
             return_m_upscale = False
-        outputs = estimate_matches(self.model, path_img0, path_img1,
-                                   ksize=self.ksize, io_thres=self.match_threshold,
-                                   eval_type='fine', imsize=self.imsize,
-                                   return_upscale=return_m_upscale, measure_time=measure_time)
+        outputs = estimate_matches(
+            self.model,
+            path_img0,
+            path_img1,
+            ksize=self.ksize,
+            io_thres=self.match_threshold,
+            eval_type="fine",
+            imsize=self.imsize,
+            return_upscale=return_m_upscale,
+            measure_time=measure_time,
+        )
         if measure_time:
             self.time_stats.append(outputs[-1])
         matches, mconf = outputs[0], outputs[1]
@@ -47,34 +61,71 @@ class VizPatch2Pix(Viz):
         kpts1 = matches[:, 2:4]
 
         if viz_matches:
-            saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]])
+            saved_name = "_".join(
+                [
+                    img_name0.split("/")[-1].split(".")[0],
+                    img_name1.split("/")[-1].split(".")[0],
+                ]
+            )
             folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name))
             if not os.path.exists(folder_matches):
                 os.makedirs(folder_matches)
-            path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name))
+            path_to_save_matches = os.path.join(
+                folder_matches, "{}.png".format(saved_name)
+            )
 
             if ground_truth:
-                data_dict["mkpts0_f"] = torch.from_numpy(matches[:, :2]).float().to(self.device)
-                data_dict["mkpts1_f"] = torch.from_numpy(matches[:, 2:4]).float().to(self.device)
-                data_dict["m_bids"] = torch.zeros(matches.shape[0], device=self.device, dtype=torch.float32)
-                compute_symmetrical_epipolar_errors(data_dict)  # compute epi_errs for each match
-                compute_pose_errors(data_dict)  # compute R_errs, t_errs, pose_errs for each pair
-                epi_errors = data_dict['epi_errs'].cpu().numpy()
-                R_errors, t_errors = data_dict['R_errs'][0], data_dict['t_errs'][0]
+                data_dict["mkpts0_f"] = (
+                    torch.from_numpy(matches[:, :2]).float().to(self.device)
+                )
+                data_dict["mkpts1_f"] = (
+                    torch.from_numpy(matches[:, 2:4]).float().to(self.device)
+                )
+                data_dict["m_bids"] = torch.zeros(
+                    matches.shape[0], device=self.device, dtype=torch.float32
+                )
+                compute_symmetrical_epipolar_errors(
+                    data_dict
+                )  # compute epi_errs for each match
+                compute_pose_errors(
+                    data_dict
+                )  # compute R_errs, t_errs, pose_errs for each pair
+                epi_errors = data_dict["epi_errs"].cpu().numpy()
+                R_errors, t_errors = data_dict["R_errs"][0], data_dict["t_errs"][0]
 
-                self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches,
-                                  R_errs=R_errors, t_errs=t_errors)
+                self.draw_matches(
+                    kpts0,
+                    kpts1,
+                    img0,
+                    img1,
+                    epi_errors,
+                    path=path_to_save_matches,
+                    R_errs=R_errors,
+                    t_errs=t_errors,
+                )
 
-                rel_pair_names = list(zip(*data_dict['pair_names']))
-                bs = data_dict['image0'].size(0)
+                rel_pair_names = list(zip(*data_dict["pair_names"]))
+                bs = data_dict["image0"].size(0)
                 metrics = {
                     # to filter duplicate pairs caused by DistributedSampler
-                    'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
-                    'epi_errs': [data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)],
-                    'R_errs': data_dict['R_errs'],
-                    't_errs': data_dict['t_errs'],
-                    'inliers': data_dict['inliers']}
-                self.eval_stats.append({'metrics': metrics})
+                    "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)],
+                    "epi_errs": [
+                        data_dict["epi_errs"][data_dict["m_bids"] == b].cpu().numpy()
+                        for b in range(bs)
+                    ],
+                    "R_errs": data_dict["R_errs"],
+                    "t_errs": data_dict["t_errs"],
+                    "inliers": data_dict["inliers"],
+                }
+                self.eval_stats.append({"metrics": metrics})
             else:
                 m_conf = 1 - mconf
-                self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4)
+                self.draw_matches(
+                    kpts0,
+                    kpts1,
+                    img0,
+                    img1,
+                    m_conf,
+                    path=path_to_save_matches,
+                    conf_thr=0.4,
+                )
diff --git a/third_party/TopicFM/viz/methods/topicfm.py b/third_party/TopicFM/viz/methods/topicfm.py
index cd8b1485d5296947a38480cc031c5d7439bf163d..e066dc4e031d47b295c4c14db774643ba0a2f25c 100644
--- a/third_party/TopicFM/viz/methods/topicfm.py
+++ b/third_party/TopicFM/viz/methods/topicfm.py
@@ -26,21 +26,28 @@ class VizTopicFM(Viz):
 
         # Load model
         conf = dict(get_model_cfg())
-        conf['match_coarse']['thr'] = self.match_threshold
-        conf['coarse']['n_samples'] = self.n_sampling_topics
+        conf["match_coarse"]["thr"] = self.match_threshold
+        conf["coarse"]["n_samples"] = self.n_sampling_topics
         print("model config: ", conf)
         self.model = TopicFM(config=conf)
         ckpt_dict = torch.load(args.ckpt)
-        self.model.load_state_dict(ckpt_dict['state_dict'])
+        self.model.load_state_dict(ckpt_dict["state_dict"])
         self.model = self.model.eval().to(self.device)
 
         # Name the method
         # self.ckpt_name = args.ckpt.split('/')[-1].split('.')[0]
-        self.name = 'TopicFM'
-
-        print(f'Initialize {self.name}')
-
-    def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, measure_time=False, viz_matches=True):
+        self.name = "TopicFM"
+
+        print(f"Initialize {self.name}")
+
+    def match_and_draw(
+        self,
+        data_dict,
+        root_dir=None,
+        ground_truth=False,
+        measure_time=False,
+        viz_matches=True,
+    ):
         if measure_time:
             torch.cuda.synchronize()
             start = torch.cuda.Event(enable_timing=True)
@@ -53,86 +60,133 @@ class VizTopicFM(Viz):
             torch.cuda.synchronize()
             self.time_stats.append(start.elapsed_time(end))
 
-        kpts0 = data_dict['mkpts0_f'].cpu().numpy()
-        kpts1 = data_dict['mkpts1_f'].cpu().numpy()
+        kpts0 = data_dict["mkpts0_f"].cpu().numpy()
+        kpts1 = data_dict["mkpts1_f"].cpu().numpy()
 
-        img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0]
+        img_name0, img_name1 = list(zip(*data_dict["pair_names"]))[0]
         img0 = cv2.imread(os.path.join(root_dir, img_name0))
         img1 = cv2.imread(os.path.join(root_dir, img_name1))
-        if str(data_dict["dataset_name"][0]).lower() == 'scannet':
+        if str(data_dict["dataset_name"][0]).lower() == "scannet":
             img0 = cv2.resize(img0, (640, 480))
             img1 = cv2.resize(img1, (640, 480))
 
         if viz_matches:
-            saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]])
+            saved_name = "_".join(
+                [
+                    img_name0.split("/")[-1].split(".")[0],
+                    img_name1.split("/")[-1].split(".")[0],
+                ]
+            )
             folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name))
             if not os.path.exists(folder_matches):
                 os.makedirs(folder_matches)
-            path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name))
+            path_to_save_matches = os.path.join(
+                folder_matches, "{}.png".format(saved_name)
+            )
 
             if ground_truth:
-                compute_symmetrical_epipolar_errors(data_dict)  # compute epi_errs for each match
-                compute_pose_errors(data_dict)  # compute R_errs, t_errs, pose_errs for each pair
-                epi_errors = data_dict['epi_errs'].cpu().numpy()
-                R_errors, t_errors = data_dict['R_errs'][0], data_dict['t_errs'][0]
-
-                self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches,
-                                  R_errs=R_errors, t_errs=t_errors)
+                compute_symmetrical_epipolar_errors(
+                    data_dict
+                )  # compute epi_errs for each match
+                compute_pose_errors(
+                    data_dict
+                )  # compute R_errs, t_errs, pose_errs for each pair
+                epi_errors = data_dict["epi_errs"].cpu().numpy()
+                R_errors, t_errors = data_dict["R_errs"][0], data_dict["t_errs"][0]
+
+                self.draw_matches(
+                    kpts0,
+                    kpts1,
+                    img0,
+                    img1,
+                    epi_errors,
+                    path=path_to_save_matches,
+                    R_errs=R_errors,
+                    t_errs=t_errors,
+                )
 
                 # compute evaluation metrics
-                rel_pair_names = list(zip(*data_dict['pair_names']))
-                bs = data_dict['image0'].size(0)
+                rel_pair_names = list(zip(*data_dict["pair_names"]))
+                bs = data_dict["image0"].size(0)
                 metrics = {
                     # to filter duplicate pairs caused by DistributedSampler
-                    'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
-                    'epi_errs': [data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)],
-                    'R_errs': data_dict['R_errs'],
-                    't_errs': data_dict['t_errs'],
-                    'inliers': data_dict['inliers']}
-                self.eval_stats.append({'metrics': metrics})
+                    "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)],
+                    "epi_errs": [
+                        data_dict["epi_errs"][data_dict["m_bids"] == b].cpu().numpy()
+                        for b in range(bs)
+                    ],
+                    "R_errs": data_dict["R_errs"],
+                    "t_errs": data_dict["t_errs"],
+                    "inliers": data_dict["inliers"],
+                }
+                self.eval_stats.append({"metrics": metrics})
             else:
                 m_conf = 1 - data_dict["mconf"].cpu().numpy()
-                self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4)
+                self.draw_matches(
+                    kpts0,
+                    kpts1,
+                    img0,
+                    img1,
+                    m_conf,
+                    path=path_to_save_matches,
+                    conf_thr=0.4,
+                )
             if self.show_n_topics > 0:
-                folder_topics = os.path.join(root_dir, "{}_viz_topics".format(self.name))
+                folder_topics = os.path.join(
+                    root_dir, "{}_viz_topics".format(self.name)
+                )
                 if not os.path.exists(folder_topics):
                     os.makedirs(folder_topics)
-                draw_topics(data_dict, img0, img1, saved_folder=folder_topics, show_n_topics=self.show_n_topics,
-                            saved_name=saved_name)
-
-    def run_demo(self, dataloader, writer=None, output_dir=None, no_display=False, skip_frames=1):
+                draw_topics(
+                    data_dict,
+                    img0,
+                    img1,
+                    saved_folder=folder_topics,
+                    show_n_topics=self.show_n_topics,
+                    saved_name=saved_name,
+                )
+
+    def run_demo(
+        self, dataloader, writer=None, output_dir=None, no_display=False, skip_frames=1
+    ):
         data_dict = next(dataloader)
 
         frame_id = 0
         last_image_id = 0
-        img0 = np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32) / 255
+        img0 = (
+            np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32) / 255
+        )
         frame_tensor = data_dict["img"].to(self.device)
-        pair_data = {'image0': frame_tensor}
-        last_frame = cv2.resize(img0, (frame_tensor.shape[-1], frame_tensor.shape[-2]), cv2.INTER_LINEAR)
+        pair_data = {"image0": frame_tensor}
+        last_frame = cv2.resize(
+            img0, (frame_tensor.shape[-1], frame_tensor.shape[-2]), interpolation=cv2.INTER_LINEAR
+        )
 
         if output_dir is not None:
-            print('==> Will write outputs to {}'.format(output_dir))
+            print("==> Will write outputs to {}".format(output_dir))
             Path(output_dir).mkdir(exist_ok=True)
 
         # Create a window to display the demo.
         if not no_display:
-            window_name = 'Topic-assisted Feature Matching'
+            window_name = "Topic-assisted Feature Matching"
             cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
             cv2.resizeWindow(window_name, (640 * 2, 480 * 2))
         else:
-            print('Skipping visualization, will not show a GUI.')
+            print("Skipping visualization, will not show a GUI.")
 
         # Print the keyboard help menu.
-        print('==> Keyboard control:\n'
-              '\tn: select the current frame as the reference image (left)\n'
-              '\tq: quit')
+        print(
+            "==> Keyboard control:\n"
+            "\tn: select the current frame as the reference image (left)\n"
+            "\tq: quit"
+        )
 
         # vis_range = [kwargs["bottom_k"], kwargs["top_k"]]
 
         while True:
             frame_id += 1
             if frame_id == len(dataloader):
-                print('Finished demo_loftr.py')
+                print("Finished demo_loftr.py")
                 break
             data_dict = next(dataloader)
             if frame_id % skip_frames != 0:
@@ -140,17 +194,24 @@ class VizTopicFM(Viz):
                 continue
 
             stem0, stem1 = last_image_id, data_dict["id"][0].item() - 1
-            frame = np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32) / 255
+            frame = (
+                np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32)
+                / 255
+            )
 
             frame_tensor = data_dict["img"].to(self.device)
-            frame = cv2.resize(frame, (frame_tensor.shape[-1], frame_tensor.shape[-2]), interpolation=cv2.INTER_LINEAR)
-            pair_data = {**pair_data, 'image1': frame_tensor}
+            frame = cv2.resize(
+                frame,
+                (frame_tensor.shape[-1], frame_tensor.shape[-2]),
+                interpolation=cv2.INTER_LINEAR,
+            )
+            pair_data = {**pair_data, "image1": frame_tensor}
             self.model(pair_data)
 
-            total_n_matches = len(pair_data['mkpts0_f'])
-            mkpts0 = pair_data['mkpts0_f'].cpu().numpy()  # [vis_range[0]:vis_range[1]]
-            mkpts1 = pair_data['mkpts1_f'].cpu().numpy()  # [vis_range[0]:vis_range[1]]
-            mconf = pair_data['mconf'].cpu().numpy()  # [vis_range[0]:vis_range[1]]
+            total_n_matches = len(pair_data["mkpts0_f"])
+            mkpts0 = pair_data["mkpts0_f"].cpu().numpy()  # [vis_range[0]:vis_range[1]]
+            mkpts1 = pair_data["mkpts1_f"].cpu().numpy()  # [vis_range[0]:vis_range[1]]
+            mconf = pair_data["mconf"].cpu().numpy()  # [vis_range[0]:vis_range[1]]
 
             # Normalize confidence.
             if len(mconf) > 0:
@@ -161,33 +222,42 @@ class VizTopicFM(Viz):
             color = error_colormap(mconf, thr=0.4, alpha=0.1)
 
             text = [
-                f'Topics',
-                '#Matches: {}'.format(total_n_matches),
+                f"Topics",
+                "#Matches: {}".format(total_n_matches),
             ]
 
-            out = draw_topicfm_demo(pair_data, last_frame, frame, mkpts0, mkpts1, color, text,
-                                    show_n_topics=4, path=None)
+            out = draw_topicfm_demo(
+                pair_data,
+                last_frame,
+                frame,
+                mkpts0,
+                mkpts1,
+                color,
+                text,
+                show_n_topics=4,
+                path=None,
+            )
 
             if not no_display:
                 if writer is not None:
                     writer.write(out)
-                cv2.imshow('TopicFM Matches', out)
+                cv2.imshow("TopicFM Matches", out)
                 key = chr(cv2.waitKey(10) & 0xFF)
-                if key == 'q':
+                if key == "q":
                     if writer is not None:
                         writer.release()
-                    print('Exiting...')
+                    print("Exiting...")
                     break
-                elif key == 'n':
-                    pair_data['image0'] = frame_tensor
+                elif key == "n":
+                    pair_data["image0"] = frame_tensor
                     last_frame = frame
-                    last_image_id = (data_dict["id"][0].item() - 1)
+                    last_image_id = data_dict["id"][0].item() - 1
                     frame_id_left = frame_id
 
             elif output_dir is not None:
-                stem = 'matches_{:06}_{:06}'.format(stem0, stem1)
-                out_file = str(Path(output_dir, stem + '.png'))
-                print('\nWriting image to {}'.format(out_file))
+                stem = "matches_{:06}_{:06}".format(stem0, stem1)
+                out_file = str(Path(output_dir, stem + ".png"))
+                print("\nWriting image to {}".format(out_file))
                 cv2.imwrite(out_file, out)
             else:
                 raise ValueError("output_dir is required when no display is given.")
@@ -195,4 +265,3 @@ class VizTopicFM(Viz):
         cv2.destroyAllWindows()
         if writer is not None:
             writer.release()
-
diff --git a/third_party/d2net/extract_features.py b/third_party/d2net/extract_features.py
index 628463a7d042a90b5cadea8a317237cde86f5ae4..ebcac0889d084c59d86bb21ed80d1e1ed8f17d8d 100644
--- a/third_party/d2net/extract_features.py
+++ b/third_party/d2net/extract_features.py
@@ -21,49 +21,55 @@ use_cuda = torch.cuda.is_available()
 device = torch.device("cuda:0" if use_cuda else "cpu")
 
 # Argument parsing
-parser = argparse.ArgumentParser(description='Feature extraction script')
+parser = argparse.ArgumentParser(description="Feature extraction script")
 
 parser.add_argument(
-    '--image_list_file', type=str, required=True,
-    help='path to a file containing a list of images to process'
+    "--image_list_file",
+    type=str,
+    required=True,
+    help="path to a file containing a list of images to process",
 )
 
 parser.add_argument(
-    '--preprocessing', type=str, default='caffe',
-    help='image preprocessing (caffe or torch)'
+    "--preprocessing",
+    type=str,
+    default="caffe",
+    help="image preprocessing (caffe or torch)",
 )
 parser.add_argument(
-    '--model_file', type=str, default='models/d2_tf.pth',
-    help='path to the full model'
+    "--model_file", type=str, default="models/d2_tf.pth", help="path to the full model"
 )
 
 parser.add_argument(
-    '--max_edge', type=int, default=1600,
-    help='maximum image size at network input'
+    "--max_edge", type=int, default=1600, help="maximum image size at network input"
 )
 parser.add_argument(
-    '--max_sum_edges', type=int, default=2800,
-    help='maximum sum of image sizes at network input'
+    "--max_sum_edges",
+    type=int,
+    default=2800,
+    help="maximum sum of image sizes at network input",
 )
 
 parser.add_argument(
-    '--output_extension', type=str, default='.d2-net',
-    help='extension for the output'
+    "--output_extension", type=str, default=".d2-net", help="extension for the output"
 )
 parser.add_argument(
-    '--output_type', type=str, default='npz',
-    help='output file type (npz or mat)'
+    "--output_type", type=str, default="npz", help="output file type (npz or mat)"
 )
 
 parser.add_argument(
-    '--multiscale', dest='multiscale', action='store_true',
-    help='extract multiscale features'
+    "--multiscale",
+    dest="multiscale",
+    action="store_true",
+    help="extract multiscale features",
 )
 parser.set_defaults(multiscale=False)
 
 parser.add_argument(
-    '--no-relu', dest='use_relu', action='store_false',
-    help='remove ReLU after the dense feature extraction module'
+    "--no-relu",
+    dest="use_relu",
+    action="store_false",
+    help="remove ReLU after the dense feature extraction module",
 )
 parser.set_defaults(use_relu=True)
 
@@ -72,14 +78,10 @@ args = parser.parse_args()
 print(args)
 
 # Creating CNN model
-model = D2Net(
-    model_file=args.model_file,
-    use_relu=args.use_relu,
-    use_cuda=use_cuda
-)
+model = D2Net(model_file=args.model_file, use_relu=args.use_relu, use_cuda=use_cuda)
 
 # Process the file
-with open(args.image_list_file, 'r') as f:
+with open(args.image_list_file, "r") as f:
     lines = f.readlines()
 for line in tqdm(lines, total=len(lines)):
     path = line.strip()
@@ -93,39 +95,32 @@ for line in tqdm(lines, total=len(lines)):
     resized_image = image
     if max(resized_image.shape) > args.max_edge:
         resized_image = scipy.misc.imresize(
-            resized_image,
-            args.max_edge / max(resized_image.shape)
-        ).astype('float')
-    if sum(resized_image.shape[: 2]) > args.max_sum_edges:
+            resized_image, args.max_edge / max(resized_image.shape)
+        ).astype("float")
+    if sum(resized_image.shape[:2]) > args.max_sum_edges:
         resized_image = scipy.misc.imresize(
-            resized_image,
-            args.max_sum_edges / sum(resized_image.shape[: 2])
-        ).astype('float')
+            resized_image, args.max_sum_edges / sum(resized_image.shape[:2])
+        ).astype("float")
 
     fact_i = image.shape[0] / resized_image.shape[0]
     fact_j = image.shape[1] / resized_image.shape[1]
 
-    input_image = preprocess_image(
-        resized_image,
-        preprocessing=args.preprocessing
-    )
+    input_image = preprocess_image(resized_image, preprocessing=args.preprocessing)
     with torch.no_grad():
         if args.multiscale:
             keypoints, scores, descriptors = process_multiscale(
                 torch.tensor(
-                    input_image[np.newaxis, :, :, :].astype(np.float32),
-                    device=device
+                    input_image[np.newaxis, :, :, :].astype(np.float32), device=device
                 ),
-                model
+                model,
             )
         else:
             keypoints, scores, descriptors = process_multiscale(
                 torch.tensor(
-                    input_image[np.newaxis, :, :, :].astype(np.float32),
-                    device=device
+                    input_image[np.newaxis, :, :, :].astype(np.float32), device=device
                 ),
                 model,
-                scales=[1]
+                scales=[1],
             )
 
     # Input image coordinates
@@ -134,23 +129,16 @@ for line in tqdm(lines, total=len(lines)):
     # i, j -> u, v
     keypoints = keypoints[:, [1, 0, 2]]
 
-    if args.output_type == 'npz':
-        with open(path + args.output_extension, 'wb') as output_file:
+    if args.output_type == "npz":
+        with open(path + args.output_extension, "wb") as output_file:
             np.savez(
-                output_file,
-                keypoints=keypoints,
-                scores=scores,
-                descriptors=descriptors
+                output_file, keypoints=keypoints, scores=scores, descriptors=descriptors
             )
-    elif args.output_type == 'mat':
-        with open(path + args.output_extension, 'wb') as output_file:
+    elif args.output_type == "mat":
+        with open(path + args.output_extension, "wb") as output_file:
             scipy.io.savemat(
                 output_file,
-                {
-                    'keypoints': keypoints,
-                    'scores': scores,
-                    'descriptors': descriptors
-                }
+                {"keypoints": keypoints, "scores": scores, "descriptors": descriptors},
             )
     else:
-        raise ValueError('Unknown output type.')
+        raise ValueError("Unknown output type.")
diff --git a/third_party/d2net/extract_kapture.py b/third_party/d2net/extract_kapture.py
index 23198b978229c699dbe24cd3bc0400d62bcab030..bad6ad4254238b9c9425243ff80f830bc4f02198 100644
--- a/third_party/d2net/extract_kapture.py
+++ b/third_party/d2net/extract_kapture.py
@@ -13,9 +13,21 @@ from os import path
 import kapture
 from kapture.io.records import get_image_fullpath
 from kapture.io.csv import kapture_from_dir, get_all_tar_handlers
-from kapture.io.csv import get_feature_csv_fullpath, keypoints_to_file, descriptors_to_file
-from kapture.io.features import get_keypoints_fullpath, keypoints_check_dir, image_keypoints_to_file
-from kapture.io.features import get_descriptors_fullpath, descriptors_check_dir, image_descriptors_to_file
+from kapture.io.csv import (
+    get_feature_csv_fullpath,
+    keypoints_to_file,
+    descriptors_to_file,
+)
+from kapture.io.features import (
+    get_keypoints_fullpath,
+    keypoints_check_dir,
+    image_keypoints_to_file,
+)
+from kapture.io.features import (
+    get_descriptors_fullpath,
+    descriptors_check_dir,
+    image_descriptors_to_file,
+)
 
 from lib.model_test import D2Net
 from lib.utils import preprocess_image
@@ -28,68 +40,89 @@ use_cuda = torch.cuda.is_available()
 device = torch.device("cuda:0" if use_cuda else "cpu")
 
 # Argument parsing
-parser = argparse.ArgumentParser(description='Feature extraction script')
+parser = argparse.ArgumentParser(description="Feature extraction script")
 
 parser.add_argument(
-    '--kapture-root', type=str, required=True,
-    help='path to kapture root directory'
+    "--kapture-root", type=str, required=True, help="path to kapture root directory"
 )
 
 parser.add_argument(
-    '--preprocessing', type=str, default='caffe',
-    help='image preprocessing (caffe or torch)'
+    "--preprocessing",
+    type=str,
+    default="caffe",
+    help="image preprocessing (caffe or torch)",
 )
 parser.add_argument(
-    '--model_file', type=str, default='models/d2_tf.pth',
-    help='path to the full model'
+    "--model_file", type=str, default="models/d2_tf.pth", help="path to the full model"
 )
 parser.add_argument(
-    '--keypoints-type', type=str, default=None,
-    help='keypoint type_name, default is filename of model'
+    "--keypoints-type",
+    type=str,
+    default=None,
+    help="keypoint type_name, default is filename of model",
 )
 parser.add_argument(
-    '--descriptors-type', type=str, default=None,
-    help='descriptors type_name, default is filename of model'
+    "--descriptors-type",
+    type=str,
+    default=None,
+    help="descriptors type_name, default is filename of model",
 )
 
 parser.add_argument(
-    '--max_edge', type=int, default=1600,
-    help='maximum image size at network input'
+    "--max_edge", type=int, default=1600, help="maximum image size at network input"
 )
 parser.add_argument(
-    '--max_sum_edges', type=int, default=2800,
-    help='maximum sum of image sizes at network input'
+    "--max_sum_edges",
+    type=int,
+    default=2800,
+    help="maximum sum of image sizes at network input",
 )
 
 parser.add_argument(
-    '--multiscale', dest='multiscale', action='store_true',
-    help='extract multiscale features'
+    "--multiscale",
+    dest="multiscale",
+    action="store_true",
+    help="extract multiscale features",
 )
 parser.set_defaults(multiscale=False)
 
 parser.add_argument(
-    '--no-relu', dest='use_relu', action='store_false',
-    help='remove ReLU after the dense feature extraction module'
+    "--no-relu",
+    dest="use_relu",
+    action="store_false",
+    help="remove ReLU after the dense feature extraction module",
 )
 parser.set_defaults(use_relu=True)
 
-parser.add_argument("--max-keypoints", type=int, default=float("+inf"),
-                    help='max number of keypoints save to disk')
+parser.add_argument(
+    "--max-keypoints",
+    type=int,
+    default=float("+inf"),
+    help="max number of keypoints save to disk",
+)
 
 args = parser.parse_args()
 
 print(args)
-with get_all_tar_handlers(args.kapture_root,
-                          mode={kapture.Keypoints: 'a',
-                                kapture.Descriptors: 'a',
-                                kapture.GlobalFeatures: 'r',
-                                kapture.Matches: 'r'}) as tar_handlers:
-    kdata = kapture_from_dir(args.kapture_root,
-                             skip_list=[kapture.GlobalFeatures,
-                                        kapture.Matches,
-                                        kapture.Points3d,
-                                        kapture.Observations],
-                             tar_handlers=tar_handlers)
+with get_all_tar_handlers(
+    args.kapture_root,
+    mode={
+        kapture.Keypoints: "a",
+        kapture.Descriptors: "a",
+        kapture.GlobalFeatures: "r",
+        kapture.Matches: "r",
+    },
+) as tar_handlers:
+    kdata = kapture_from_dir(
+        args.kapture_root,
+        skip_list=[
+            kapture.GlobalFeatures,
+            kapture.Matches,
+            kapture.Points3d,
+            kapture.Observations,
+        ],
+        tar_handlers=tar_handlers,
+    )
     if kdata.keypoints is None:
         kdata.keypoints = {}
     if kdata.descriptors is None:
@@ -99,28 +132,29 @@ with get_all_tar_handlers(args.kapture_root,
     image_list = [filename for _, _, filename in kapture.flatten(kdata.records_camera)]
     if args.keypoints_type is None:
         args.keypoints_type = path.splitext(path.basename(args.model_file))[0]
-        print(f'keypoints_type set to {args.keypoints_type}')
+        print(f"keypoints_type set to {args.keypoints_type}")
     if args.descriptors_type is None:
         args.descriptors_type = path.splitext(path.basename(args.model_file))[0]
-        print(f'descriptors_type set to {args.descriptors_type}')
-    if args.keypoints_type in kdata.keypoints and args.descriptors_type in kdata.descriptors:
-        image_list = [name
-                      for name in image_list
-                      if name not in kdata.keypoints[args.keypoints_type] or
-                      name not in kdata.descriptors[args.descriptors_type]]
+        print(f"descriptors_type set to {args.descriptors_type}")
+    if (
+        args.keypoints_type in kdata.keypoints
+        and args.descriptors_type in kdata.descriptors
+    ):
+        image_list = [
+            name
+            for name in image_list
+            if name not in kdata.keypoints[args.keypoints_type]
+            or name not in kdata.descriptors[args.descriptors_type]
+        ]
 
     if len(image_list) == 0:
-        print('All features were already extracted')
+        print("All features were already extracted")
         exit(0)
     else:
-        print(f'Extracting d2net features for {len(image_list)} images')
+        print(f"Extracting d2net features for {len(image_list)} images")
 
     # Creating CNN model
-    model = D2Net(
-        model_file=args.model_file,
-        use_relu=args.use_relu,
-        use_cuda=use_cuda
-    )
+    model = D2Net(model_file=args.model_file, use_relu=args.use_relu, use_cuda=use_cuda)
 
     if args.keypoints_type not in kdata.keypoints:
         keypoints_dtype = None
@@ -138,7 +172,7 @@ with get_all_tar_handlers(args.kapture_root,
     # Process the files
     for image_name in tqdm(image_list, total=len(image_list)):
         img_path = get_image_fullpath(args.kapture_root, image_name)
-        image = Image.open(img_path).convert('RGB')
+        image = Image.open(img_path).convert("RGB")
 
         width, height = image.size
 
@@ -162,30 +196,27 @@ with get_all_tar_handlers(args.kapture_root,
         fact_i = width / resized_width
         fact_j = height / resized_height
 
-        resized_image = np.array(resized_image).astype('float')
+        resized_image = np.array(resized_image).astype("float")
 
-        input_image = preprocess_image(
-            resized_image,
-            preprocessing=args.preprocessing
-        )
+        input_image = preprocess_image(resized_image, preprocessing=args.preprocessing)
 
         with torch.no_grad():
             if args.multiscale:
                 keypoints, scores, descriptors = process_multiscale(
                     torch.tensor(
                         input_image[np.newaxis, :, :, :].astype(np.float32),
-                        device=device
+                        device=device,
                     ),
-                    model
+                    model,
                 )
             else:
                 keypoints, scores, descriptors = process_multiscale(
                     torch.tensor(
                         input_image[np.newaxis, :, :, :].astype(np.float32),
-                        device=device
+                        device=device,
                     ),
                     model,
-                    scales=[1]
+                    scales=[1],
                 )
 
         # Input image coordinates
@@ -196,7 +227,7 @@ with get_all_tar_handlers(args.kapture_root,
 
         if args.max_keypoints != float("+inf"):
             # keep the last (the highest) indexes
-            idx_keep = scores.argsort()[-min(len(keypoints), args.max_keypoints):]
+            idx_keep = scores.argsort()[-min(len(keypoints), args.max_keypoints) :]
             keypoints = keypoints[idx_keep]
             descriptors = descriptors[idx_keep]
 
@@ -207,42 +238,65 @@ with get_all_tar_handlers(args.kapture_root,
             keypoints_dsize = keypoints.shape[1]
             descriptors_dsize = descriptors.shape[1]
 
-            kdata.keypoints[args.keypoints_type] = kapture.Keypoints('d2net', keypoints_dtype, keypoints_dsize)
-            kdata.descriptors[args.descriptors_type] = kapture.Descriptors('d2net', descriptors_dtype,
-                                                                           descriptors_dsize,
-                                                                           args.keypoints_type, 'L2')
-
-            keypoints_config_absolute_path = get_feature_csv_fullpath(kapture.Keypoints,
-                                                                      args.keypoints_type,
-                                                                      args.kapture_root)
-            descriptors_config_absolute_path = get_feature_csv_fullpath(kapture.Descriptors,
-                                                                        args.descriptors_type,
-                                                                        args.kapture_root)
-
-            keypoints_to_file(keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type])
-            descriptors_to_file(descriptors_config_absolute_path, kdata.descriptors[args.descriptors_type])
+            kdata.keypoints[args.keypoints_type] = kapture.Keypoints(
+                "d2net", keypoints_dtype, keypoints_dsize
+            )
+            kdata.descriptors[args.descriptors_type] = kapture.Descriptors(
+                "d2net", descriptors_dtype, descriptors_dsize, args.keypoints_type, "L2"
+            )
+
+            keypoints_config_absolute_path = get_feature_csv_fullpath(
+                kapture.Keypoints, args.keypoints_type, args.kapture_root
+            )
+            descriptors_config_absolute_path = get_feature_csv_fullpath(
+                kapture.Descriptors, args.descriptors_type, args.kapture_root
+            )
+
+            keypoints_to_file(
+                keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type]
+            )
+            descriptors_to_file(
+                descriptors_config_absolute_path,
+                kdata.descriptors[args.descriptors_type],
+            )
         else:
             assert kdata.keypoints[args.keypoints_type].dtype == keypoints.dtype
             assert kdata.descriptors[args.descriptors_type].dtype == descriptors.dtype
             assert kdata.keypoints[args.keypoints_type].dsize == keypoints.shape[1]
-            assert kdata.descriptors[args.descriptors_type].dsize == descriptors.shape[1]
-            assert kdata.descriptors[args.descriptors_type].keypoints_type == args.keypoints_type
-            assert kdata.descriptors[args.descriptors_type].metric_type == 'L2'
-
-        keypoints_fullpath = get_keypoints_fullpath(args.keypoints_type, args.kapture_root,
-                                                    image_name, tar_handlers)
+            assert (
+                kdata.descriptors[args.descriptors_type].dsize == descriptors.shape[1]
+            )
+            assert (
+                kdata.descriptors[args.descriptors_type].keypoints_type
+                == args.keypoints_type
+            )
+            assert kdata.descriptors[args.descriptors_type].metric_type == "L2"
+
+        keypoints_fullpath = get_keypoints_fullpath(
+            args.keypoints_type, args.kapture_root, image_name, tar_handlers
+        )
         print(f"Saving {keypoints.shape[0]} keypoints to {keypoints_fullpath}")
         image_keypoints_to_file(keypoints_fullpath, keypoints)
         kdata.keypoints[args.keypoints_type].add(image_name)
 
-        descriptors_fullpath = get_descriptors_fullpath(args.descriptors_type, args.kapture_root,
-                                                        image_name, tar_handlers)
+        descriptors_fullpath = get_descriptors_fullpath(
+            args.descriptors_type, args.kapture_root, image_name, tar_handlers
+        )
         print(f"Saving {descriptors.shape[0]} descriptors to {descriptors_fullpath}")
         image_descriptors_to_file(descriptors_fullpath, descriptors)
         kdata.descriptors[args.descriptors_type].add(image_name)
 
-    if not keypoints_check_dir(kdata.keypoints[args.keypoints_type], args.keypoints_type,
-                               args.kapture_root, tar_handlers) or \
-        not descriptors_check_dir(kdata.descriptors[args.descriptors_type], args.descriptors_type,
-                                  args.kapture_root, tar_handlers):
-        print('local feature extraction ended successfully but not all files were saved')
+    if not keypoints_check_dir(
+        kdata.keypoints[args.keypoints_type],
+        args.keypoints_type,
+        args.kapture_root,
+        tar_handlers,
+    ) or not descriptors_check_dir(
+        kdata.descriptors[args.descriptors_type],
+        args.descriptors_type,
+        args.kapture_root,
+        tar_handlers,
+    ):
+        print(
+            "local feature extraction ended successfully but not all files were saved"
+        )
diff --git a/third_party/d2net/megadepth_utils/preprocess_scene.py b/third_party/d2net/megadepth_utils/preprocess_scene.py
index fc68a403795e7cddce88dfcb74b38d19ab09e133..5364058829b7e45eabd61a32a591711645fc1ded 100644
--- a/third_party/d2net/megadepth_utils/preprocess_scene.py
+++ b/third_party/d2net/megadepth_utils/preprocess_scene.py
@@ -6,78 +6,63 @@ import numpy as np
 
 import os
 
-parser = argparse.ArgumentParser(description='MegaDepth preprocessing script')
+parser = argparse.ArgumentParser(description="MegaDepth preprocessing script")
 
-parser.add_argument(
-    '--base_path', type=str, required=True,
-    help='path to MegaDepth'
-)
-parser.add_argument(
-    '--scene_id', type=str, required=True,
-    help='scene ID'
-)
+parser.add_argument("--base_path", type=str, required=True, help="path to MegaDepth")
+parser.add_argument("--scene_id", type=str, required=True, help="scene ID")
 
 parser.add_argument(
-    '--output_path', type=str, required=True,
-    help='path to the output directory'
+    "--output_path", type=str, required=True, help="path to the output directory"
 )
 
 args = parser.parse_args()
 
 base_path = args.base_path
 # Remove the trailing / if need be.
-if base_path[-1] in ['/', '\\']:
-    base_path = base_path[: - 1]
+if base_path[-1] in ["/", "\\"]:
+    base_path = base_path[:-1]
 scene_id = args.scene_id
 
-base_depth_path = os.path.join(
-    base_path, 'phoenix/S6/zl548/MegaDepth_v1'
-)
-base_undistorted_sfm_path = os.path.join(
-    base_path, 'Undistorted_SfM'
-)
+base_depth_path = os.path.join(base_path, "phoenix/S6/zl548/MegaDepth_v1")
+base_undistorted_sfm_path = os.path.join(base_path, "Undistorted_SfM")
 
 undistorted_sparse_path = os.path.join(
-    base_undistorted_sfm_path, scene_id, 'sparse-txt'
+    base_undistorted_sfm_path, scene_id, "sparse-txt"
 )
 if not os.path.exists(undistorted_sparse_path):
     exit()
 
-depths_path = os.path.join(
-    base_depth_path, scene_id, 'dense0', 'depths'
-)
+depths_path = os.path.join(base_depth_path, scene_id, "dense0", "depths")
 if not os.path.exists(depths_path):
     exit()
 
-images_path = os.path.join(
-    base_undistorted_sfm_path, scene_id, 'images'
-)
+images_path = os.path.join(base_undistorted_sfm_path, scene_id, "images")
 if not os.path.exists(images_path):
     exit()
 
 # Process cameras.txt
-with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f:
-    raw = f.readlines()[3 :]  # skip the header
+with open(os.path.join(undistorted_sparse_path, "cameras.txt"), "r") as f:
+    raw = f.readlines()[3:]  # skip the header
 
 camera_intrinsics = {}
 for camera in raw:
-    camera = camera.split(' ')
-    camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]]
+    camera = camera.split(" ")
+    camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2:]]
 
 # Process points3D.txt
-with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f:
-    raw = f.readlines()[3 :]  # skip the header
+with open(os.path.join(undistorted_sparse_path, "points3D.txt"), "r") as f:
+    raw = f.readlines()[3:]  # skip the header
 
 points3D = {}
 for point3D in raw:
-    point3D = point3D.split(' ')
-    points3D[int(point3D[0])] = np.array([
-        float(point3D[1]), float(point3D[2]), float(point3D[3])
-    ])
-    
+    point3D = point3D.split(" ")
+    points3D[int(point3D[0])] = np.array(
+        [float(point3D[1]), float(point3D[2]), float(point3D[3])]
+    )
+
 # Process images.txt
-with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f:
-    raw = f.readlines()[4 :]  # skip the header
+with open(os.path.join(undistorted_sparse_path, "images.txt"), "r") as f:
+    raw = f.readlines()[4:]  # skip the header
 
 image_id_to_idx = {}
 image_names = []
@@ -85,19 +70,19 @@ raw_pose = []
 camera = []
 points3D_id_to_2D = []
 n_points3D = []
-for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])):
-    image = image.split(' ')
-    points = points.split(' ')
+for idx, (image, points) in enumerate(zip(raw[::2], raw[1::2])):
+    image = image.split(" ")
+    points = points.split(" ")
 
     image_id_to_idx[int(image[0])] = idx
 
-    image_name = image[-1].strip('\n')
+    image_name = image[-1].strip("\n")
     image_names.append(image_name)
 
-    raw_pose.append([float(elem) for elem in image[1 : -2]])
+    raw_pose.append([float(elem) for elem in image[1:-2]])
     camera.append(int(image[-2]))
     current_points3D_id_to_2D = {}
-    for x, y, point3D_id in zip(points[:: 3], points[1 :: 3], points[2 :: 3]):
+    for x, y, point3D_id in zip(points[::3], points[1::3], points[2::3]):
         if int(point3D_id) == -1:
             continue
         current_points3D_id_to_2D[int(point3D_id)] = [float(x), float(y)]
@@ -110,12 +95,10 @@ image_paths = []
 depth_paths = []
 for image_name in image_names:
     image_path = os.path.join(images_path, image_name)
-   
+
     # Path to the depth file
-    depth_path = os.path.join(
-        depths_path, '%s.h5' % os.path.splitext(image_name)[0]
-    )
-    
+    depth_path = os.path.join(depths_path, "%s.h5" % os.path.splitext(image_name)[0])
+
     if os.path.exists(depth_path):
         # Check if depth map or background / foreground mask
         file_size = os.stat(depth_path).st_size
@@ -152,32 +135,22 @@ for idx, image_name in enumerate(image_names):
     intrinsics.append(K)
 
     image_pose = raw_pose[idx]
-    qvec = image_pose[: 4]
+    qvec = image_pose[:4]
     qvec = qvec / np.linalg.norm(qvec)
     w, x, y, z = qvec
-    R = np.array([
-        [
-            1 - 2 * y * y - 2 * z * z,
-            2 * x * y - 2 * z * w,
-            2 * x * z + 2 * y * w
-        ],
+    R = np.array(
         [
-            2 * x * y + 2 * z * w,
-            1 - 2 * x * x - 2 * z * z,
-            2 * y * z - 2 * x * w
-        ],
-        [
-            2 * x * z - 2 * y * w,
-            2 * y * z + 2 * x * w,
-            1 - 2 * x * x - 2 * y * y
+            [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w],
+            [2 * x * y + 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * x * w],
+            [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y],
         ]
-    ])
+    )
     principal_axis.append(R[2, :])
-    t = image_pose[4 : 7]
+    t = image_pose[4:7]
     # World-to-Camera pose
     current_pose = np.zeros([4, 4])
-    current_pose[: 3, : 3] = R
-    current_pose[: 3, 3] = t
+    current_pose[:3, :3] = R
+    current_pose[:3, 3] = t
     current_pose[3, 3] = 1
     # Camera-to-World pose
     # pose = np.zeros([4, 4])
@@ -185,38 +158,38 @@ for idx, image_name in enumerate(image_names):
     # pose[: 3, 3] = -np.matmul(np.transpose(R), t)
     # pose[3, 3] = 1
     poses.append(current_pose)
-    
+
     current_points3D_id_to_ndepth = {}
     for point3D_id in points3D_id_to_2D[idx].keys():
         p3d = points3D[point3D_id]
-        current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (.5 * (K[0, 0] + K[1, 1])) 
+        current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (
+            0.5 * (K[0, 0] + K[1, 1])
+        )
     points3D_id_to_ndepth.append(current_points3D_id_to_ndepth)
 principal_axis = np.array(principal_axis)
-angles = np.rad2deg(np.arccos(
-    np.clip(
-        np.dot(principal_axis, np.transpose(principal_axis)),
-        -1, 1
-    )
-))
+angles = np.rad2deg(
+    np.arccos(np.clip(np.dot(principal_axis, np.transpose(principal_axis)), -1, 1))
+)
 
 # Compute overlap score
-overlap_matrix = np.full([n_images, n_images], -1.)
-scale_ratio_matrix = np.full([n_images, n_images], -1.)
+overlap_matrix = np.full([n_images, n_images], -1.0)
+scale_ratio_matrix = np.full([n_images, n_images], -1.0)
 for idx1 in range(n_images):
     if image_paths[idx1] is None or depth_paths[idx1] is None:
         continue
     for idx2 in range(idx1 + 1, n_images):
         if image_paths[idx2] is None or depth_paths[idx2] is None:
             continue
-        matches = (
-            points3D_id_to_2D[idx1].keys() &
-            points3D_id_to_2D[idx2].keys()
-        )
+        matches = points3D_id_to_2D[idx1].keys() & points3D_id_to_2D[idx2].keys()
         min_num_points3D = min(
             len(points3D_id_to_2D[idx1]), len(points3D_id_to_2D[idx2])
         )
-        overlap_matrix[idx1, idx2] = len(matches) / len(points3D_id_to_2D[idx1])  # min_num_points3D
-        overlap_matrix[idx2, idx1] = len(matches) / len(points3D_id_to_2D[idx2])  # min_num_points3D
+        overlap_matrix[idx1, idx2] = len(matches) / len(
+            points3D_id_to_2D[idx1]
+        )  # min_num_points3D
+        overlap_matrix[idx2, idx1] = len(matches) / len(
+            points3D_id_to_2D[idx2]
+        )  # min_num_points3D
         if len(matches) == 0:
             continue
         points3D_id_to_ndepth1 = points3D_id_to_ndepth[idx1]
@@ -228,7 +201,7 @@ for idx1 in range(n_images):
         scale_ratio_matrix[idx2, idx1] = min_scale_ratio
 
 np.savez(
-    os.path.join(args.output_path, '%s.npz' % scene_id),
+    os.path.join(args.output_path, "%s.npz" % scene_id),
     image_paths=image_paths,
     depth_paths=depth_paths,
     intrinsics=intrinsics,
@@ -238,5 +211,5 @@ np.savez(
     angles=angles,
     n_points3D=n_points3D,
     points3D_id_to_2D=points3D_id_to_2D,
-    points3D_id_to_ndepth=points3D_id_to_ndepth
+    points3D_id_to_ndepth=points3D_id_to_ndepth,
 )
diff --git a/third_party/d2net/megadepth_utils/undistort_reconstructions.py b/third_party/d2net/megadepth_utils/undistort_reconstructions.py
index a6b99a72f81206e6fbefae9daa9aa683c8754051..822c9abd3fc75fd8fc1e8d9ada75aa76802c6798 100644
--- a/third_party/d2net/megadepth_utils/undistort_reconstructions.py
+++ b/third_party/d2net/megadepth_utils/undistort_reconstructions.py
@@ -6,28 +6,18 @@ import os
 
 import subprocess
 
-parser = argparse.ArgumentParser(description='MegaDepth Undistortion')
+parser = argparse.ArgumentParser(description="MegaDepth Undistortion")
 
 parser.add_argument(
-    '--colmap_path', type=str, required=True,
-    help='path to colmap executable'
-)
-parser.add_argument(
-    '--base_path', type=str, required=True,
-    help='path to MegaDepth'
+    "--colmap_path", type=str, required=True, help="path to colmap executable"
 )
+parser.add_argument("--base_path", type=str, required=True, help="path to MegaDepth")
 
 args = parser.parse_args()
 
-sfm_path = os.path.join(
-    args.base_path, 'MegaDepth_v1_SfM'
-)
-base_depth_path = os.path.join(
-    args.base_path, 'phoenix/S6/zl548/MegaDepth_v1'
-)
-output_path = os.path.join(
-    args.base_path, 'Undistorted_SfM'
-)
+sfm_path = os.path.join(args.base_path, "MegaDepth_v1_SfM")
+base_depth_path = os.path.join(args.base_path, "phoenix/S6/zl548/MegaDepth_v1")
+output_path = os.path.join(args.base_path, "Undistorted_SfM")
 
 os.mkdir(output_path)
 
@@ -35,35 +25,45 @@ for scene_name in os.listdir(base_depth_path):
     current_output_path = os.path.join(output_path, scene_name)
     os.mkdir(current_output_path)
 
-    image_path = os.path.join(
-        base_depth_path, scene_name, 'dense0', 'imgs'
-    )
+    image_path = os.path.join(base_depth_path, scene_name, "dense0", "imgs")
     if not os.path.exists(image_path):
         continue
-    
+
     # Find the maximum image size in scene.
     max_image_size = 0
     for image_name in os.listdir(image_path):
         max_image_size = max(
-            max_image_size,
-            max(imagesize.get(os.path.join(image_path, image_name)))
+            max_image_size, max(imagesize.get(os.path.join(image_path, image_name)))
         )
 
     # Undistort the images and update the reconstruction.
-    subprocess.call([
-        os.path.join(args.colmap_path, 'colmap'), 'image_undistorter', 
-        '--image_path', os.path.join(sfm_path, scene_name, 'images'),
-        '--input_path', os.path.join(sfm_path, scene_name, 'sparse', 'manhattan', '0'),
-        '--output_path',  current_output_path,
-        '--max_image_size', str(max_image_size)
-    ])
+    subprocess.call(
+        [
+            os.path.join(args.colmap_path, "colmap"),
+            "image_undistorter",
+            "--image_path",
+            os.path.join(sfm_path, scene_name, "images"),
+            "--input_path",
+            os.path.join(sfm_path, scene_name, "sparse", "manhattan", "0"),
+            "--output_path",
+            current_output_path,
+            "--max_image_size",
+            str(max_image_size),
+        ]
+    )
 
     # Transform the reconstruction to raw text format.
-    sparse_txt_path = os.path.join(current_output_path, 'sparse-txt')
+    sparse_txt_path = os.path.join(current_output_path, "sparse-txt")
     os.mkdir(sparse_txt_path)
-    subprocess.call([
-        os.path.join(args.colmap_path, 'colmap'), 'model_converter',
-        '--input_path', os.path.join(current_output_path, 'sparse'),
-        '--output_path', sparse_txt_path, 
-        '--output_type', 'TXT'
-    ])
\ No newline at end of file
+    subprocess.call(
+        [
+            os.path.join(args.colmap_path, "colmap"),
+            "model_converter",
+            "--input_path",
+            os.path.join(current_output_path, "sparse"),
+            "--output_path",
+            sparse_txt_path,
+            "--output_type",
+            "TXT",
+        ]
+    )
diff --git a/third_party/d2net/train.py b/third_party/d2net/train.py
index 5817f1712bda0779175fb18437d1f8c263f29f3b..5ca584e131c14930f86c3252f93b89f1aea40713 100644
--- a/third_party/d2net/train.py
+++ b/third_party/d2net/train.py
@@ -32,72 +32,64 @@ if use_cuda:
 np.random.seed(1)
 
 # Argument parsing
-parser = argparse.ArgumentParser(description='Training script')
+parser = argparse.ArgumentParser(description="Training script")
 
 parser.add_argument(
-    '--dataset_path', type=str, required=True,
-    help='path to the dataset'
+    "--dataset_path", type=str, required=True, help="path to the dataset"
 )
 parser.add_argument(
-    '--scene_info_path', type=str, required=True,
-    help='path to the processed scenes'
+    "--scene_info_path", type=str, required=True, help="path to the processed scenes"
 )
 
 parser.add_argument(
-    '--preprocessing', type=str, default='caffe',
-    help='image preprocessing (caffe or torch)'
+    "--preprocessing",
+    type=str,
+    default="caffe",
+    help="image preprocessing (caffe or torch)",
 )
 parser.add_argument(
-    '--model_file', type=str, default='models/d2_ots.pth',
-    help='path to the full model'
+    "--model_file", type=str, default="models/d2_ots.pth", help="path to the full model"
 )
 
 parser.add_argument(
-    '--num_epochs', type=int, default=10,
-    help='number of training epochs'
+    "--num_epochs", type=int, default=10, help="number of training epochs"
 )
+parser.add_argument("--lr", type=float, default=1e-3, help="initial learning rate")
+parser.add_argument("--batch_size", type=int, default=1, help="batch size")
 parser.add_argument(
-    '--lr', type=float, default=1e-3,
-    help='initial learning rate'
-)
-parser.add_argument(
-    '--batch_size', type=int, default=1,
-    help='batch size'
-)
-parser.add_argument(
-    '--num_workers', type=int, default=4,
-    help='number of workers for data loading'
+    "--num_workers", type=int, default=4, help="number of workers for data loading"
 )
 
 parser.add_argument(
-    '--use_validation', dest='use_validation', action='store_true',
-    help='use the validation split'
+    "--use_validation",
+    dest="use_validation",
+    action="store_true",
+    help="use the validation split",
 )
 parser.set_defaults(use_validation=False)
 
 parser.add_argument(
-    '--log_interval', type=int, default=250,
-    help='loss logging interval'
+    "--log_interval", type=int, default=250, help="loss logging interval"
 )
 
-parser.add_argument(
-    '--log_file', type=str, default='log.txt',
-    help='loss logging file'
-)
+parser.add_argument("--log_file", type=str, default="log.txt", help="loss logging file")
 
 parser.add_argument(
-    '--plot', dest='plot', action='store_true',
-    help='plot training pairs'
+    "--plot", dest="plot", action="store_true", help="plot training pairs"
 )
 parser.set_defaults(plot=False)
 
 parser.add_argument(
-    '--checkpoint_directory', type=str, default='checkpoints',
-    help='directory for training checkpoints'
+    "--checkpoint_directory",
+    type=str,
+    default="checkpoints",
+    help="directory for training checkpoints",
 )
 parser.add_argument(
-    '--checkpoint_prefix', type=str, default='d2',
-    help='prefix for training checkpoints'
+    "--checkpoint_prefix",
+    type=str,
+    default="d2",
+    help="prefix for training checkpoints",
 )
 
 args = parser.parse_args()
@@ -106,17 +98,14 @@ print(args)
 
 # Create the folders for plotting if need be
 if args.plot:
-    plot_path = 'train_vis'
+    plot_path = "train_vis"
     if os.path.isdir(plot_path):
-        print('[Warning] Plotting directory already exists.')
+        print("[Warning] Plotting directory already exists.")
     else:
         os.mkdir(plot_path)
 
 # Creating CNN model
-model = D2Net(
-    model_file=args.model_file,
-    use_cuda=use_cuda
-)
+model = D2Net(model_file=args.model_file, use_cuda=use_cuda)
 
 # Optimizer
 optimizer = optim.Adam(
@@ -126,37 +115,39 @@ optimizer = optim.Adam(
 # Dataset
 if args.use_validation:
     validation_dataset = MegaDepthDataset(
-        scene_list_path='megadepth_utils/valid_scenes.txt',
+        scene_list_path="megadepth_utils/valid_scenes.txt",
         scene_info_path=args.scene_info_path,
         base_path=args.dataset_path,
         train=False,
         preprocessing=args.preprocessing,
-        pairs_per_scene=25
+        pairs_per_scene=25,
     )
     validation_dataloader = DataLoader(
-        validation_dataset,
-        batch_size=args.batch_size,
-        num_workers=args.num_workers
+        validation_dataset, batch_size=args.batch_size, num_workers=args.num_workers
     )
 
 training_dataset = MegaDepthDataset(
-    scene_list_path='megadepth_utils/train_scenes.txt',
+    scene_list_path="megadepth_utils/train_scenes.txt",
     scene_info_path=args.scene_info_path,
     base_path=args.dataset_path,
-    preprocessing=args.preprocessing
+    preprocessing=args.preprocessing,
 )
 training_dataloader = DataLoader(
-    training_dataset,
-    batch_size=args.batch_size,
-    num_workers=args.num_workers
+    training_dataset, batch_size=args.batch_size, num_workers=args.num_workers
 )
 
 
 # Define epoch function
 def process_epoch(
-        epoch_idx,
-        model, loss_function, optimizer, dataloader, device,
-        log_file, args, train=True
+    epoch_idx,
+    model,
+    loss_function,
+    optimizer,
+    dataloader,
+    device,
+    log_file,
+    args,
+    train=True,
 ):
     epoch_losses = []
 
@@ -167,12 +158,12 @@ def process_epoch(
         if train:
             optimizer.zero_grad()
 
-        batch['train'] = train
-        batch['epoch_idx'] = epoch_idx
-        batch['batch_idx'] = batch_idx
-        batch['batch_size'] = args.batch_size
-        batch['preprocessing'] = args.preprocessing
-        batch['log_interval'] = args.log_interval
+        batch["train"] = train
+        batch["epoch_idx"] = epoch_idx
+        batch["batch_idx"] = batch_idx
+        batch["batch_size"] = args.batch_size
+        batch["preprocessing"] = args.preprocessing
+        batch["log_interval"] = args.log_interval
 
         try:
             loss = loss_function(model, batch, device, plot=args.plot)
@@ -182,23 +173,28 @@ def process_epoch(
         current_loss = loss.data.cpu().numpy()[0]
         epoch_losses.append(current_loss)
 
-        progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses)))
+        progress_bar.set_postfix(loss=("%.4f" % np.mean(epoch_losses)))
 
         if batch_idx % args.log_interval == 0:
-            log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % (
-                'train' if train else 'valid',
-                epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses)
-            ))
+            log_file.write(
+                "[%s] epoch %d - batch %d / %d - avg_loss: %f\n"
+                % (
+                    "train" if train else "valid",
+                    epoch_idx,
+                    batch_idx,
+                    len(dataloader),
+                    np.mean(epoch_losses),
+                )
+            )
 
         if train:
             loss.backward()
             optimizer.step()
 
-    log_file.write('[%s] epoch %d - avg_loss: %f\n' % (
-        'train' if train else 'valid',
-        epoch_idx,
-        np.mean(epoch_losses)
-    ))
+    log_file.write(
+        "[%s] epoch %d - avg_loss: %f\n"
+        % ("train" if train else "valid", epoch_idx, np.mean(epoch_losses))
+    )
     log_file.flush()
 
     return np.mean(epoch_losses)
@@ -206,15 +202,15 @@ def process_epoch(
 
 # Create the checkpoint directory
 if os.path.isdir(args.checkpoint_directory):
-    print('[Warning] Checkpoint directory already exists.')
+    print("[Warning] Checkpoint directory already exists.")
 else:
     os.mkdir(args.checkpoint_directory)
-    
+
 
 # Open the log file for writing
 if os.path.exists(args.log_file):
-    print('[Warning] Log file already exists.')
-log_file = open(args.log_file, 'a+')
+    print("[Warning] Log file already exists.")
+log_file = open(args.log_file, "a+")
 
 # Initialize the history
 train_loss_history = []
@@ -223,9 +219,14 @@ if args.use_validation:
     validation_dataset.build_dataset()
     min_validation_loss = process_epoch(
         0,
-        model, loss_function, optimizer, validation_dataloader, device,
-        log_file, args,
-        train=False
+        model,
+        loss_function,
+        optimizer,
+        validation_dataloader,
+        device,
+        log_file,
+        args,
+        train=False,
     )
 
 # Start the training
@@ -235,8 +236,13 @@ for epoch_idx in range(1, args.num_epochs + 1):
     train_loss_history.append(
         process_epoch(
             epoch_idx,
-            model, loss_function, optimizer, training_dataloader, device,
-            log_file, args
+            model,
+            loss_function,
+            optimizer,
+            training_dataloader,
+            device,
+            log_file,
+            args,
         )
     )
 
@@ -244,34 +250,34 @@ for epoch_idx in range(1, args.num_epochs + 1):
         validation_loss_history.append(
             process_epoch(
                 epoch_idx,
-                model, loss_function, optimizer, validation_dataloader, device,
-                log_file, args,
-                train=False
+                model,
+                loss_function,
+                optimizer,
+                validation_dataloader,
+                device,
+                log_file,
+                args,
+                train=False,
             )
         )
 
     # Save the current checkpoint
     checkpoint_path = os.path.join(
-        args.checkpoint_directory,
-        '%s.%02d.pth' % (args.checkpoint_prefix, epoch_idx)
+        args.checkpoint_directory, "%s.%02d.pth" % (args.checkpoint_prefix, epoch_idx)
     )
     checkpoint = {
-        'args': args,
-        'epoch_idx': epoch_idx,
-        'model': model.state_dict(),
-        'optimizer': optimizer.state_dict(),
-        'train_loss_history': train_loss_history,
-        'validation_loss_history': validation_loss_history
+        "args": args,
+        "epoch_idx": epoch_idx,
+        "model": model.state_dict(),
+        "optimizer": optimizer.state_dict(),
+        "train_loss_history": train_loss_history,
+        "validation_loss_history": validation_loss_history,
     }
     torch.save(checkpoint, checkpoint_path)
-    if (
-        args.use_validation and
-        validation_loss_history[-1] < min_validation_loss
-    ):
+    if args.use_validation and validation_loss_history[-1] < min_validation_loss:
         min_validation_loss = validation_loss_history[-1]
         best_checkpoint_path = os.path.join(
-            args.checkpoint_directory,
-            '%s.best.pth' % args.checkpoint_prefix
+            args.checkpoint_directory, "%s.best.pth" % args.checkpoint_prefix
         )
         shutil.copy(checkpoint_path, best_checkpoint_path)
 
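For reference, a minimal sketch of the checkpoint round trip used by the training loop above, with a toy model standing in for the real network; the file name is a placeholder and only the dictionary layout mirrors what the script saves.

import torch

# Toy stand-ins for the real model/optimizer; only the dict layout matters here.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

torch.save(
    {"epoch_idx": 1, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
    "toy.01.pth",  # placeholder; the script names files "<prefix>.<epoch>.pth"
)
state = torch.load("toy.01.pth")
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
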
diff --git a/third_party/lanet/augmentations.py b/third_party/lanet/augmentations.py
index f4e4496c77ce8fc8cdadb230dd0d0750166152a9..c39b7bfee0b42730f81e8f614352a58c25187b59 100644
--- a/third_party/lanet/augmentations.py
+++ b/third_party/lanet/augmentations.py
@@ -54,110 +54,163 @@ def resize_sample(sample, image_shape, image_interpolation=Image.ANTIALIAS):
     """
     # image
     image_transform = transforms.Resize(image_shape, interpolation=image_interpolation)
-    sample['image'] = image_transform(sample['image'])
+    sample["image"] = image_transform(sample["image"])
     return sample
 
+
 def spatial_augment_sample(sample):
-    """ Apply spatial augmentation to an image (flipping and random affine transformation)."""
-    augment_image = transforms.Compose([
-        transforms.RandomVerticalFlip(p=0.5),
-        transforms.RandomHorizontalFlip(p=0.5),
-        transforms.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1))
-        
-    ])
-    sample['image'] = augment_image(sample['image'])
+    """Apply spatial augmentation to an image (flipping and random affine transformation)."""
+    augment_image = transforms.Compose(
+        [
+            transforms.RandomVerticalFlip(p=0.5),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
+        ]
+    )
+    sample["image"] = augment_image(sample["image"])
 
     return sample
 
+
 def unnormalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
-    """ Counterpart method of torchvision.transforms.Normalize."""
+    """Counterpart method of torchvision.transforms.Normalize."""
     for t, m, s in zip(tensor, mean, std):
         t.div_(1 / s).sub_(-m)
     return tensor
 
 
 def sample_homography(
-        shape, perspective=True, scaling=True, rotation=True, translation=True,
-        n_scales=100, n_angles=100, scaling_amplitude=0.1, perspective_amplitude=0.4,
-        patch_ratio=0.8, max_angle=pi/4):
-    """ Sample a random homography that includes perspective, scale, translation and rotation operations."""
+    shape,
+    perspective=True,
+    scaling=True,
+    rotation=True,
+    translation=True,
+    n_scales=100,
+    n_angles=100,
+    scaling_amplitude=0.1,
+    perspective_amplitude=0.4,
+    patch_ratio=0.8,
+    max_angle=pi / 4,
+):
+    """Sample a random homography that includes perspective, scale, translation and rotation operations."""
 
     width = float(shape[1])
     hw_ratio = float(shape[0]) / float(shape[1])
 
-    pts1 = np.stack([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]], axis=0)
+    pts1 = np.stack([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]], axis=0)
     pts2 = pts1.copy() * patch_ratio
-    pts2[:,1] *= hw_ratio
+    pts2[:, 1] *= hw_ratio
 
     if perspective:
 
-        perspective_amplitude_x = np.random.normal(0., perspective_amplitude/2, (2))
-        perspective_amplitude_y = np.random.normal(0., hw_ratio * perspective_amplitude/2, (2))
+        perspective_amplitude_x = np.random.normal(0.0, perspective_amplitude / 2, (2))
+        perspective_amplitude_y = np.random.normal(
+            0.0, hw_ratio * perspective_amplitude / 2, (2)
+        )
 
-        perspective_amplitude_x = np.clip(perspective_amplitude_x, -perspective_amplitude/2, perspective_amplitude/2)
-        perspective_amplitude_y = np.clip(perspective_amplitude_y, hw_ratio * -perspective_amplitude/2, hw_ratio * perspective_amplitude/2)
+        perspective_amplitude_x = np.clip(
+            perspective_amplitude_x,
+            -perspective_amplitude / 2,
+            perspective_amplitude / 2,
+        )
+        perspective_amplitude_y = np.clip(
+            perspective_amplitude_y,
+            hw_ratio * -perspective_amplitude / 2,
+            hw_ratio * perspective_amplitude / 2,
+        )
 
-        pts2[0,0] -= perspective_amplitude_x[1]
-        pts2[0,1] -= perspective_amplitude_y[1]
+        pts2[0, 0] -= perspective_amplitude_x[1]
+        pts2[0, 1] -= perspective_amplitude_y[1]
 
-        pts2[1,0] -= perspective_amplitude_x[0]
-        pts2[1,1] += perspective_amplitude_y[1]
+        pts2[1, 0] -= perspective_amplitude_x[0]
+        pts2[1, 1] += perspective_amplitude_y[1]
 
-        pts2[2,0] += perspective_amplitude_x[1]
-        pts2[2,1] -= perspective_amplitude_y[0]
+        pts2[2, 0] += perspective_amplitude_x[1]
+        pts2[2, 1] -= perspective_amplitude_y[0]
 
-        pts2[3,0] += perspective_amplitude_x[0]
-        pts2[3,1] += perspective_amplitude_y[0]
+        pts2[3, 0] += perspective_amplitude_x[0]
+        pts2[3, 1] += perspective_amplitude_y[0]
 
     if scaling:
 
-        random_scales = np.random.normal(1, scaling_amplitude/2, (n_scales))
-        random_scales = np.clip(random_scales, 1-scaling_amplitude/2, 1+scaling_amplitude/2)
+        random_scales = np.random.normal(1, scaling_amplitude / 2, (n_scales))
+        random_scales = np.clip(
+            random_scales, 1 - scaling_amplitude / 2, 1 + scaling_amplitude / 2
+        )
 
-        scales = np.concatenate([[1.], random_scales], 0)
+        scales = np.concatenate([[1.0], random_scales], 0)
         center = np.mean(pts2, axis=0, keepdims=True)
-        scaled = np.expand_dims(pts2 - center, axis=0) * np.expand_dims(
-                np.expand_dims(scales, 1), 1) + center
+        scaled = (
+            np.expand_dims(pts2 - center, axis=0)
+            * np.expand_dims(np.expand_dims(scales, 1), 1)
+            + center
+        )
         valid = np.arange(n_scales)  # all scales are valid except scale=1
         idx = valid[np.random.randint(valid.shape[0])]
         pts2 = scaled[idx]
 
     if translation:
-        t_min, t_max = np.min(pts2 - [-1., -hw_ratio], axis=0), np.min([1., hw_ratio] - pts2, axis=0)
-        pts2 += np.expand_dims(np.stack([np.random.uniform(-t_min[0], t_max[0]),
-                                         np.random.uniform(-t_min[1], t_max[1])]),
-                               axis=0)
+        t_min, t_max = np.min(pts2 - [-1.0, -hw_ratio], axis=0), np.min(
+            [1.0, hw_ratio] - pts2, axis=0
+        )
+        pts2 += np.expand_dims(
+            np.stack(
+                [
+                    np.random.uniform(-t_min[0], t_max[0]),
+                    np.random.uniform(-t_min[1], t_max[1]),
+                ]
+            ),
+            axis=0,
+        )
 
     if rotation:
         angles = np.linspace(-max_angle, max_angle, n_angles)
-        angles = np.concatenate([[0.], angles], axis=0) 
+        angles = np.concatenate([[0.0], angles], axis=0)
 
         center = np.mean(pts2, axis=0, keepdims=True)
-        rot_mat = np.reshape(np.stack([np.cos(angles), -np.sin(angles), np.sin(angles),
-                                       np.cos(angles)], axis=1), [-1, 2, 2])
-        rotated = np.matmul(
-                np.tile(np.expand_dims(pts2 - center, axis=0), [n_angles+1, 1, 1]),
-                rot_mat) + center
+        rot_mat = np.reshape(
+            np.stack(
+                [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)],
+                axis=1,
+            ),
+            [-1, 2, 2],
+        )
+        rotated = (
+            np.matmul(
+                np.tile(np.expand_dims(pts2 - center, axis=0), [n_angles + 1, 1, 1]),
+                rot_mat,
+            )
+            + center
+        )
 
-        valid = np.where(np.all((rotated >= [-1.,-hw_ratio]) & (rotated < [1.,hw_ratio]),
-                                        axis=(1, 2)))[0]
+        valid = np.where(
+            np.all(
+                (rotated >= [-1.0, -hw_ratio]) & (rotated < [1.0, hw_ratio]),
+                axis=(1, 2),
+            )
+        )[0]
 
         idx = valid[np.random.randint(valid.shape[0])]
         pts2 = rotated[idx]
 
-    pts2[:,1] /= hw_ratio
+    pts2[:, 1] /= hw_ratio
+
+    def ax(p, q):
+        return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
 
-    def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
-    def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
+    def ay(p, q):
+        return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
 
     a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
-    p_mat = np.transpose(np.stack(
-        [[pts2[i][j] for i in range(4) for j in range(2)]], axis=0))
+    p_mat = np.transpose(
+        np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
+    )
 
     homography = np.matmul(np.linalg.pinv(a_mat), p_mat).squeeze()
-    homography = np.concatenate([homography, [1.]]).reshape(3,3)
+    homography = np.concatenate([homography, [1.0]]).reshape(3, 3)
     return homography
 
+
 def warp_homography(sources, homography):
     """Warp features given a homography
 
@@ -175,12 +228,15 @@ def warp_homography(sources, homography):
     """
     _, H, W, _ = sources.shape
     warped_sources = sources.clone().squeeze()
-    warped_sources = warped_sources.view(-1,2)
-    warped_sources = torch.addmm(homography[:,2], warped_sources, homography[:,:2].t())
-    warped_sources.mul_(1/warped_sources[:,2].unsqueeze(1))
-    warped_sources = warped_sources[:,:2].contiguous().view(1,H,W,2)
+    warped_sources = warped_sources.view(-1, 2)
+    warped_sources = torch.addmm(
+        homography[:, 2], warped_sources, homography[:, :2].t()
+    )
+    warped_sources.mul_(1 / warped_sources[:, 2].unsqueeze(1))
+    warped_sources = warped_sources[:, :2].contiguous().view(1, H, W, 2)
     return warped_sources
 
+
 def add_noise(img, mode="gaussian", percent=0.02):
     """Add image noise
 
@@ -259,36 +315,40 @@ def add_noise(img, mode="gaussian", percent=0.02):
     return noisy
 
 
-def non_spatial_augmentation(img_warp_ori, jitter_paramters, color_order=[0,1,2], to_gray=False):
-    """ Apply non-spatial augmentation to an image (jittering, color swap, convert to gray scale, Gaussian blur)."""
+def non_spatial_augmentation(
+    img_warp_ori, jitter_paramters, color_order=[0, 1, 2], to_gray=False
+):
+    """Apply non-spatial augmentation to an image (jittering, color swap, convert to gray scale, Gaussian blur)."""
 
     brightness, contrast, saturation, hue = jitter_paramters
     color_augmentation = transforms.ColorJitter(brightness, contrast, saturation, hue)
-    '''
+    """
     augment_image = color_augmentation.get_params(brightness=[max(0, 1 - brightness), 1 + brightness],
                                                     contrast=[max(0, 1 - contrast), 1 + contrast],
                                                     saturation=[max(0, 1 - saturation), 1 + saturation],
                                                     hue=[-hue, hue])
-    '''
+    """
 
     B = img_warp_ori.shape[0]
     img_warp = []
-    kernel_sizes = [0,1,3,5]
+    kernel_sizes = [0, 1, 3, 5]
     for b in range(B):
         img_warp_sub = img_warp_ori[b].cpu()
         img_warp_sub = torchvision.transforms.functional.to_pil_image(img_warp_sub)
 
-        img_warp_sub_np = np.array(img_warp_sub) 
-        img_warp_sub_np = img_warp_sub_np[:,:,color_order]
-        
+        img_warp_sub_np = np.array(img_warp_sub)
+        img_warp_sub_np = img_warp_sub_np[:, :, color_order]
+
         if np.random.rand() > 0.5:
             img_warp_sub_np = add_noise(img_warp_sub_np)
 
         rand_index = np.random.randint(4)
         kernel_size = kernel_sizes[rand_index]
-        if kernel_size >0:
-            img_warp_sub_np = cv2.GaussianBlur(img_warp_sub_np, (kernel_size, kernel_size), sigmaX=0)
-        
+        if kernel_size > 0:
+            img_warp_sub_np = cv2.GaussianBlur(
+                img_warp_sub_np, (kernel_size, kernel_size), sigmaX=0
+            )
+
         if to_gray:
             img_warp_sub_np = cv2.cvtColor(img_warp_sub_np, cv2.COLOR_RGB2GRAY)
             img_warp_sub_np = cv2.cvtColor(img_warp_sub_np, cv2.COLOR_GRAY2RGB)
@@ -296,35 +356,54 @@ def non_spatial_augmentation(img_warp_ori, jitter_paramters, color_order=[0,1,2]
         img_warp_sub = Image.fromarray(img_warp_sub_np)
         img_warp_sub = color_augmentation(img_warp_sub)
 
-        img_warp_sub = torchvision.transforms.functional.to_tensor(img_warp_sub).to(img_warp_ori.device)
+        img_warp_sub = torchvision.transforms.functional.to_tensor(img_warp_sub).to(
+            img_warp_ori.device
+        )
 
         img_warp.append(img_warp_sub)
 
     img_warp = torch.stack(img_warp, dim=0)
     return img_warp
 
-def ha_augment_sample(data, jitter_paramters=[0.5, 0.5, 0.2, 0.05], patch_ratio=0.7, scaling_amplitude=0.2, max_angle=pi/4):
+
+def ha_augment_sample(
+    data,
+    jitter_paramters=[0.5, 0.5, 0.2, 0.05],
+    patch_ratio=0.7,
+    scaling_amplitude=0.2,
+    max_angle=pi / 4,
+):
     """Apply Homography Adaptation image augmentation."""
-    input_img = data['image'].unsqueeze(0)
+    input_img = data["image"].unsqueeze(0)
     _, _, H, W = input_img.shape
     device = input_img.device
-    
-    homography = torch.from_numpy(
-        sample_homography([H, W], 
-        patch_ratio=patch_ratio, 
-        scaling_amplitude=scaling_amplitude, 
-        max_angle=max_angle)).float().to(device)
+
+    homography = (
+        torch.from_numpy(
+            sample_homography(
+                [H, W],
+                patch_ratio=patch_ratio,
+                scaling_amplitude=scaling_amplitude,
+                max_angle=max_angle,
+            )
+        )
+        .float()
+        .to(device)
+    )
     homography_inv = torch.inverse(homography)
 
-    source = image_grid(1, H, W,
-                    dtype=input_img.dtype,
-                    device=device,
-                    ones=False, normalized=True).clone().permute(0, 2, 3, 1)
+    source = (
+        image_grid(
+            1, H, W, dtype=input_img.dtype, device=device, ones=False, normalized=True
+        )
+        .clone()
+        .permute(0, 2, 3, 1)
+    )
 
     target_warped = warp_homography(source, homography)
     img_warp = torch.nn.functional.grid_sample(input_img, target_warped)
 
-    color_order = [0,1,2]
+    color_order = [0, 1, 2]
     if np.random.rand() > 0.5:
         random.shuffle(color_order)
 
@@ -332,11 +411,21 @@ def ha_augment_sample(data, jitter_paramters=[0.5, 0.5, 0.2, 0.05], patch_ratio=
     if np.random.rand() > 0.5:
         to_gray = True
 
-    input_img = non_spatial_augmentation(input_img, jitter_paramters=jitter_paramters, color_order=color_order, to_gray=to_gray)
-    img_warp = non_spatial_augmentation(img_warp, jitter_paramters=jitter_paramters, color_order=color_order, to_gray=to_gray)
-
-    data['image'] = input_img.squeeze()
-    data['image_aug'] = img_warp.squeeze()
-    data['homography'] = homography
-    data['homography_inv'] = homography_inv
+    input_img = non_spatial_augmentation(
+        input_img,
+        jitter_paramters=jitter_paramters,
+        color_order=color_order,
+        to_gray=to_gray,
+    )
+    img_warp = non_spatial_augmentation(
+        img_warp,
+        jitter_paramters=jitter_paramters,
+        color_order=color_order,
+        to_gray=to_gray,
+    )
+
+    data["image"] = input_img.squeeze()
+    data["image_aug"] = img_warp.squeeze()
+    data["homography"] = homography
+    data["homography_inv"] = homography_inv
     return data
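For context, a minimal usage sketch of sample_homography as defined in augmentations.py above; the import path is an assumption about how the module is exposed, and the arguments mirror the defaults used by ha_augment_sample.

import numpy as np
from third_party.lanet.augmentations import sample_homography  # import path assumed

# Draw one random 3x3 homography for a 240x320 image, as ha_augment_sample does.
H = sample_homography(
    [240, 320], patch_ratio=0.7, scaling_amplitude=0.2, max_angle=np.pi / 4
)
print(H.shape)  # (3, 3); the bottom-right entry is fixed to 1 by construction
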
diff --git a/third_party/lanet/config.py b/third_party/lanet/config.py
index baa3aedc95410b231c29ab64b31ea5a2bd3266d7..84419d0a1f7199e8bec1afc7b046e674a629d886 100644
--- a/third_party/lanet/config.py
+++ b/third_party/lanet/config.py
@@ -1,78 +1,94 @@
 import argparse
 
 arg_lists = []
-parser = argparse.ArgumentParser(description='LANet')
+parser = argparse.ArgumentParser(description="LANet")
+
 
 def str2bool(v):
-    return v.lower() in ('true', '1')
+    return v.lower() in ("true", "1")
+
 
 def add_argument_group(name):
     arg = parser.add_argument_group(name)
     arg_lists.append(arg)
     return arg
 
+
 # train data params
-traindata_arg = add_argument_group('Traindata Params')
-traindata_arg.add_argument('--train_txt', type=str, default='',
-                            help='Train set.')
-traindata_arg.add_argument('--train_root', type=str, default='',
-                            help='Where the train images are.')
-traindata_arg.add_argument('--batch_size', type=int, default=8,
-                            help='# of images in each batch of data')
-traindata_arg.add_argument('--num_workers', type=int, default=4,
-                            help='# of subprocesses to use for data loading')
-traindata_arg.add_argument('--pin_memory', type=str2bool, default=True,
-                            help='# of subprocesses to use for data loading')
-traindata_arg.add_argument('--shuffle', type=str2bool, default=True,
-                            help='Whether to shuffle the train and valid indices')
-traindata_arg.add_argument('--image_shape', type=tuple, default=(240, 320),
-                            help='')
-traindata_arg.add_argument('--jittering', type=tuple, default=(0.5, 0.5, 0.2, 0.05),
-                            help='')
+traindata_arg = add_argument_group("Traindata Params")
+traindata_arg.add_argument("--train_txt", type=str, default="", help="Train set.")
+traindata_arg.add_argument(
+    "--train_root", type=str, default="", help="Where the train images are."
+)
+traindata_arg.add_argument(
+    "--batch_size", type=int, default=8, help="# of images in each batch of data"
+)
+traindata_arg.add_argument(
+    "--num_workers",
+    type=int,
+    default=4,
+    help="# of subprocesses to use for data loading",
+)
+traindata_arg.add_argument(
+    "--pin_memory",
+    type=str2bool,
+    default=True,
+    help="# of subprocesses to use for data loading",
+)
+traindata_arg.add_argument(
+    "--shuffle",
+    type=str2bool,
+    default=True,
+    help="Whether to shuffle the train and valid indices",
+)
+traindata_arg.add_argument("--image_shape", type=tuple, default=(240, 320), help="")
+traindata_arg.add_argument(
+    "--jittering", type=tuple, default=(0.5, 0.5, 0.2, 0.05), help=""
+)
 
 # data storage
-storage_arg = add_argument_group('Storage')
-storage_arg.add_argument('--ckpt_name', type=str, default='PointModel',
-                            help='')
+storage_arg = add_argument_group("Storage")
+storage_arg.add_argument("--ckpt_name", type=str, default="PointModel", help="")
 
 # training params
-train_arg = add_argument_group('Training Params')
-train_arg.add_argument('--start_epoch', type=int, default=0,
-                        help='')
-train_arg.add_argument('--max_epoch', type=int, default=12,
-                        help='')
-train_arg.add_argument('--init_lr', type=float, default=3e-4,
-                        help='Initial learning rate value.')
-train_arg.add_argument('--lr_factor', type=float, default=0.5,
-                        help='Reduce learning rate value.')	
-train_arg.add_argument('--momentum', type=float, default=0.9,
-                        help='Nesterov momentum value.')			   
-train_arg.add_argument('--display', type=int, default=50,
-                        help='')
+train_arg = add_argument_group("Training Params")
+train_arg.add_argument("--start_epoch", type=int, default=0, help="")
+train_arg.add_argument("--max_epoch", type=int, default=12, help="")
+train_arg.add_argument(
+    "--init_lr", type=float, default=3e-4, help="Initial learning rate value."
+)
+train_arg.add_argument(
+    "--lr_factor", type=float, default=0.5, help="Reduce learning rate value."
+)
+train_arg.add_argument(
+    "--momentum", type=float, default=0.9, help="Nesterov momentum value."
+)
+train_arg.add_argument("--display", type=int, default=50, help="")
 
 # loss function params
-loss_arg = add_argument_group('Loss function Params')
-loss_arg.add_argument('--score_weight', type=float, default=1.,
-                        help='')
-loss_arg.add_argument('--loc_weight', type=float, default=1.,
-                        help='')
-loss_arg.add_argument('--desc_weight', type=float, default=4.,
-                        help='')
-loss_arg.add_argument('--corres_weight', type=float, default=.5,
-                        help='')
-loss_arg.add_argument('--corres_threshold', type=int, default=4.,
-                        help='')
-					   
+loss_arg = add_argument_group("Loss function Params")
+loss_arg.add_argument("--score_weight", type=float, default=1.0, help="")
+loss_arg.add_argument("--loc_weight", type=float, default=1.0, help="")
+loss_arg.add_argument("--desc_weight", type=float, default=4.0, help="")
+loss_arg.add_argument("--corres_weight", type=float, default=0.5, help="")
+loss_arg.add_argument("--corres_threshold", type=int, default=4.0, help="")
+
 # other params
-misc_arg = add_argument_group('Misc.')
-misc_arg.add_argument('--use_gpu', type=str2bool, default=True,
-                        help="Whether to run on the GPU.")
-misc_arg.add_argument('--gpu', type=int, default=0,
-                        help="Which GPU to run on.")										  
-misc_arg.add_argument('--seed', type=int, default=1001,
-                        help='Seed to ensure reproducibility.')					  
-misc_arg.add_argument('--ckpt_dir', type=str, default='./checkpoints',
-                        help='Directory in which to save model checkpoints.')					  
+misc_arg = add_argument_group("Misc.")
+misc_arg.add_argument(
+    "--use_gpu", type=str2bool, default=True, help="Whether to run on the GPU."
+)
+misc_arg.add_argument("--gpu", type=int, default=0, help="Which GPU to run on.")
+misc_arg.add_argument(
+    "--seed", type=int, default=1001, help="Seed to ensure reproducibility."
+)
+misc_arg.add_argument(
+    "--ckpt_dir",
+    type=str,
+    default="./checkpoints",
+    help="Directory in which to save model checkpoints.",
+)
+
 
 def get_config():
     config, unparsed = parser.parse_known_args()
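As a usage sketch, the parser defined above can be exercised directly with explicit overrides; the import path is an assumption and the override values are illustrative.

from third_party.lanet.config import parser  # import path assumed

# Parse known args only, as get_config() does, with a couple of overrides.
config, unparsed = parser.parse_known_args(["--batch_size", "16", "--gpu", "1"])
print(config.batch_size, config.init_lr, config.ckpt_dir)  # 16 0.0003 ./checkpoints
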
diff --git a/third_party/lanet/data_loader.py b/third_party/lanet/data_loader.py
index e694e39bb5f3e7ad6763a5cfcce3ca4804071262..d8e7bcac2274a512127920e1695a8923fd009f8a 100644
--- a/third_party/lanet/data_loader.py
+++ b/third_party/lanet/data_loader.py
@@ -4,6 +4,7 @@ from torch.utils.data import Dataset, DataLoader
 from augmentations import ha_augment_sample, resize_sample, spatial_augment_sample
 from utils import to_tensor_sample
 
+
 def image_transforms(shape, jittering):
     def train_transforms(sample):
         sample = resize_sample(sample, image_shape=shape)
@@ -12,14 +13,15 @@ def image_transforms(shape, jittering):
         sample = ha_augment_sample(sample, jitter_paramters=jittering)
         return sample
 
-    return {'train': train_transforms}
+    return {"train": train_transforms}
+
 
 class GetData(Dataset):
     def __init__(self, config, transforms=None):
         """
         Get the list containing all images and labels.
         """
-        datafile = open(config.train_txt, 'r')
+        datafile = open(config.train_txt, "r")
         lines = datafile.readlines()
 
         dataset = []
@@ -31,9 +33,9 @@ class GetData(Dataset):
         self.config = config
         self.dataset = dataset
         self.root = config.train_root
-        
+
         self.transforms = transforms
-	
+
     def __getitem__(self, index):
         """
         Return the image data and its label.
@@ -41,14 +43,14 @@ class GetData(Dataset):
         img_path = self.dataset[index]
         img_file = self.root + img_path
         img = Image.open(img_file)
-        
-        # image.mode == 'L' means the image is in gray scale 
-        if img.mode == 'L':
+
+        # image.mode == 'L' means the image is in gray scale
+        if img.mode == "L":
             img_new = Image.new("RGB", img.size)
             img_new.paste(img)
-            sample = {'image': img_new, 'idx': index}
+            sample = {"image": img_new, "idx": index}
         else:
-            sample = {'image': img, 'idx': index}
+            sample = {"image": img, "idx": index}
 
         if self.transforms:
             sample = self.transforms(sample)
@@ -61,26 +63,27 @@ class GetData(Dataset):
         """
         return len(self.dataset)
 
+
 def get_data_loader(
-                config,
-                transforms=None,
-                sampler=None,
-                drop_last=True,
-                ):
+    config,
+    transforms=None,
+    sampler=None,
+    drop_last=True,
+):
     """
     Return batch data for training.
     """
     transforms = image_transforms(shape=config.image_shape, jittering=config.jittering)
-    dataset = GetData(config, transforms=transforms['train'])
+    dataset = GetData(config, transforms=transforms["train"])
 
     train_loader = DataLoader(
-                        dataset,
-                        batch_size=config.batch_size,
-                        shuffle=config.shuffle,
-                        sampler=sampler,
-                        num_workers=config.num_workers,
-                        pin_memory=config.pin_memory,
-                        drop_last=drop_last
-                        )
+        dataset,
+        batch_size=config.batch_size,
+        shuffle=config.shuffle,
+        sampler=sampler,
+        num_workers=config.num_workers,
+        pin_memory=config.pin_memory,
+        drop_last=drop_last,
+    )
 
     return train_loader
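A hedged sketch of wiring the loader above to the argparse config; the import paths and data locations are placeholders, and the text file is the one written by prepare_coco.py below.

from third_party.lanet.config import parser            # import paths assumed
from third_party.lanet.data_loader import get_data_loader

config, _ = parser.parse_known_args([])
config.train_txt = "/data/coco/train2017.txt"            # placeholder paths
config.train_root = "/data/coco/train2017/"
train_loader = get_data_loader(config)
batch = next(iter(train_loader))  # dict with "image", "image_aug", "homography", ...
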
diff --git a/third_party/lanet/datasets/hp_loader.py b/third_party/lanet/datasets/hp_loader.py
index b4c1d8f3c33fd51bfa928c529544a77c06ed73f0..f255c87dac6e06e56b67ad0f04f7da5c131f0189 100644
--- a/third_party/lanet/datasets/hp_loader.py
+++ b/third_party/lanet/datasets/hp_loader.py
@@ -30,7 +30,15 @@ class PatchesDataset(Dataset):
         v - viewpoint sequences
         all - all sequences
     """
-    def __init__(self, root_dir, use_color=True, data_transform=None, output_shape=None, type='all'):
+
+    def __init__(
+        self,
+        root_dir,
+        use_color=True,
+        data_transform=None,
+        output_shape=None,
+        type="all",
+    ):
         super().__init__()
         self.type = type
         self.root_dir = root_dir
@@ -43,33 +51,36 @@ class PatchesDataset(Dataset):
         warped_image_paths = []
         homographies = []
         for path in folder_paths:
-            if self.type == 'i' and path.stem[0] != 'i':
+            if self.type == "i" and path.stem[0] != "i":
                 continue
-            if self.type == 'v' and path.stem[0] != 'v':
+            if self.type == "v" and path.stem[0] != "v":
                 continue
             num_images = 5
-            file_ext = '.ppm'
+            file_ext = ".ppm"
             for i in range(2, 2 + num_images):
                 image_paths.append(str(Path(path, "1" + file_ext)))
                 warped_image_paths.append(str(Path(path, str(i) + file_ext)))
                 homographies.append(np.loadtxt(str(Path(path, "H_1_" + str(i)))))
-        self.files = {'image_paths': image_paths, 'warped_image_paths': warped_image_paths, 'homography': homographies}
+        self.files = {
+            "image_paths": image_paths,
+            "warped_image_paths": warped_image_paths,
+            "homography": homographies,
+        }
 
     def scale_homography(self, homography, original_scale, new_scale, pre):
         scales = np.divide(new_scale, original_scale)
         if pre:
-            s = np.diag(np.append(scales, 1.))
+            s = np.diag(np.append(scales, 1.0))
             homography = np.matmul(s, homography)
         else:
-            sinv = np.diag(np.append(1. / scales, 1.))
+            sinv = np.diag(np.append(1.0 / scales, 1.0))
             homography = np.matmul(homography, sinv)
         return homography
 
     def __len__(self):
-        return len(self.files['image_paths'])
+        return len(self.files["image_paths"])
 
     def __getitem__(self, idx):
-
         def _read_image(path):
             img = cv2.imread(path, cv2.IMREAD_COLOR)
             if self.use_color:
@@ -77,30 +88,39 @@ class PatchesDataset(Dataset):
             gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
             return gray
 
-        image = _read_image(self.files['image_paths'][idx])
+        image = _read_image(self.files["image_paths"][idx])
 
-        warped_image = _read_image(self.files['warped_image_paths'][idx])
-        homography = np.array(self.files['homography'][idx])
-        sample = {'image': image, 'warped_image': warped_image, 'homography': homography, 'index' : idx}
+        warped_image = _read_image(self.files["warped_image_paths"][idx])
+        homography = np.array(self.files["homography"][idx])
+        sample = {
+            "image": image,
+            "warped_image": warped_image,
+            "homography": homography,
+            "index": idx,
+        }
 
         # Apply transformations
         if self.output_shape is not None:
-            sample['homography'] = self.scale_homography(sample['homography'],
-                                                         sample['image'].shape[:2][::-1],
-                                                         self.output_shape,
-                                                         pre=False)
-            sample['homography'] = self.scale_homography(sample['homography'],
-                                                         sample['warped_image'].shape[:2][::-1],
-                                                         self.output_shape,
-                                                         pre=True)
+            sample["homography"] = self.scale_homography(
+                sample["homography"],
+                sample["image"].shape[:2][::-1],
+                self.output_shape,
+                pre=False,
+            )
+            sample["homography"] = self.scale_homography(
+                sample["homography"],
+                sample["warped_image"].shape[:2][::-1],
+                self.output_shape,
+                pre=True,
+            )
 
-            for key in ['image', 'warped_image']:
+            for key in ["image", "warped_image"]:
                 sample[key] = cv2.resize(sample[key], self.output_shape)
                 if self.use_color is False:
                     sample[key] = np.expand_dims(sample[key], axis=2)
 
         transform = transforms.ToTensor()
 
-        for key in ['image', 'warped_image']:
-            sample[key] = transform(sample[key]).type('torch.FloatTensor')
+        for key in ["image", "warped_image"]:
+            sample[key] = transform(sample[key]).type("torch.FloatTensor")
         return sample
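A minimal sketch of iterating the HPatches loader above; the dataset root is a placeholder and the import path is an assumption.

from torch.utils.data import DataLoader
from third_party.lanet.datasets.hp_loader import PatchesDataset  # import path assumed

dataset = PatchesDataset(
    root_dir="/data/hpatches-sequences-release",  # placeholder path
    use_color=True,
    output_shape=(320, 240),
    type="all",
)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
sample = next(iter(loader))  # keys: "image", "warped_image", "homography", "index"
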
diff --git a/third_party/lanet/datasets/prepare_coco.py b/third_party/lanet/datasets/prepare_coco.py
index 0468aba19c6c2c76bda1a1af2b86dc7f20176fdb..612fb400000c66476a3be796d4dcceea8bc331d4 100644
--- a/third_party/lanet/datasets/prepare_coco.py
+++ b/third_party/lanet/datasets/prepare_coco.py
@@ -1,26 +1,24 @@
 import os
 import argparse
 
+
 def prepare_coco(args):
-    train_file = open(os.path.join(args.saved_dir, args.saved_txt), 'w')
+    train_file = open(os.path.join(args.saved_dir, args.saved_txt), "w")
     dirs = os.listdir(args.raw_dir)
 
     for file in dirs:
         # Write training files
-        train_file.write('%s\n' % (file))
+        train_file.write("%s\n" % (file))
+
+    print("Data Preparation Finished.")
 
-    print('Data Preparation Finished.')
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser(description="coco preparing.")
-    arg_parser.add_argument('--dataset', type=str, default='coco',
-                             help='')
-    arg_parser.add_argument('--raw_dir', type=str, default='',
-                             help='')
-    arg_parser.add_argument('--saved_dir', type=str, default='',
-                             help='')
-    arg_parser.add_argument('--saved_txt', type=str, default='train2017.txt',
-                             help='')
-    args = arg_parser.parse_args() 
+    arg_parser.add_argument("--dataset", type=str, default="coco", help="")
+    arg_parser.add_argument("--raw_dir", type=str, default="", help="")
+    arg_parser.add_argument("--saved_dir", type=str, default="", help="")
+    arg_parser.add_argument("--saved_txt", type=str, default="train2017.txt", help="")
+    args = arg_parser.parse_args()
 
-    prepare_coco(args)
\ No newline at end of file
+    prepare_coco(args)
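A hedged sketch of invoking prepare_coco.py from Python; all paths are placeholders for wherever the raw COCO images live.

import subprocess

subprocess.run(
    [
        "python", "third_party/lanet/datasets/prepare_coco.py",
        "--raw_dir", "/data/coco/train2017/",  # placeholder paths
        "--saved_dir", "/data/coco/",
        "--saved_txt", "train2017.txt",
    ],
    check=True,
)
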
diff --git a/third_party/lanet/evaluation/descriptor_evaluation.py b/third_party/lanet/evaluation/descriptor_evaluation.py
index c0e1f84199d353ac5858641c8f68bc298f9d6413..924918a64e769e0b4e661366a0b7d59a2f819ec5 100644
--- a/third_party/lanet/evaluation/descriptor_evaluation.py
+++ b/third_party/lanet/evaluation/descriptor_evaluation.py
@@ -12,7 +12,7 @@ from utils import warp_keypoints
 
 
 def select_k_best(points, descriptors, k):
-    """ Select the k most probable points (and strip their probability).
+    """Select the k most probable points (and strip their probability).
     points has shape (num_points, 3) where the last coordinate is the probability.
 
     Parameters
@@ -25,7 +25,7 @@ def select_k_best(points, descriptors, k):
         Number of keypoints to select, based on probability.
     Returns
     -------
-    
+
     selected_points: numpy.ndarray (k,2)
         k most probable keypoints.
     selected_descriptors: numpy.ndarray (k,256)
@@ -44,7 +44,7 @@ def keep_shared_points(keypoints, descriptors, H, shape, keep_k_points=1000):
     Compute a list of keypoints from the map, filter the list of points by keeping
     only the points that once mapped by H are still inside the shape of the map
     and keep at most 'keep_k_points' keypoints in the image.
-    
+
     Parameters
     ----------
     keypoints: numpy.ndarray (N,3)
@@ -53,36 +53,44 @@ def keep_shared_points(keypoints, descriptors, H, shape, keep_k_points=1000):
         Keypoint descriptors.
     H: numpy.ndarray (3,3)
         Homography.
-    shape: tuple 
+    shape: tuple
         Image shape.
     keep_k_points: int
         Number of keypoints to select, based on probability.
 
     Returns
-    -------    
+    -------
     selected_points: numpy.ndarray (k,2)
         k most probable keypoints.
     selected_descriptors: numpy.ndarray (k,256)
         Descriptors corresponding to the k most probable keypoints.
     """
-    
+
     def keep_true_keypoints(points, descriptors, H, shape):
-        """ Keep only the points whose warped coordinates by H are still inside shape. """
+        """Keep only the points whose warped coordinates by H are still inside shape."""
         warped_points = warp_keypoints(points[:, [1, 0]], H)
         warped_points[:, [0, 1]] = warped_points[:, [1, 0]]
-        mask = (warped_points[:, 0] >= 0) & (warped_points[:, 0] < shape[0]) &\
-               (warped_points[:, 1] >= 0) & (warped_points[:, 1] < shape[1])
+        mask = (
+            (warped_points[:, 0] >= 0)
+            & (warped_points[:, 0] < shape[0])
+            & (warped_points[:, 1] >= 0)
+            & (warped_points[:, 1] < shape[1])
+        )
         return points[mask, :], descriptors[mask, :]
 
-    selected_keypoints, selected_descriptors = keep_true_keypoints(keypoints, descriptors, H, shape)
-    selected_keypoints, selected_descriptors = select_k_best(selected_keypoints, selected_descriptors, keep_k_points)
+    selected_keypoints, selected_descriptors = keep_true_keypoints(
+        keypoints, descriptors, H, shape
+    )
+    selected_keypoints, selected_descriptors = select_k_best(
+        selected_keypoints, selected_descriptors, keep_k_points
+    )
     return selected_keypoints, selected_descriptors
 
 
 def compute_matching_score(data, keep_k_points=1000):
     """
     Compute the matching score between two sets of keypoints with associated descriptors.
-    
+
     Parameters
     ----------
     data: dict
@@ -103,31 +111,35 @@ def compute_matching_score(data, keep_k_points=1000):
         Number of keypoints to select, based on probability.
 
     Returns
-    -------    
+    -------
     ms: float
         Matching score.
     """
-    shape = data['image_shape']
-    real_H = data['homography']
+    shape = data["image_shape"]
+    real_H = data["homography"]
 
     # Filter out predictions
-    keypoints = data['prob'][:, :2].T
+    keypoints = data["prob"][:, :2].T
     keypoints = keypoints[::-1]
-    prob = data['prob'][:, 2]
+    prob = data["prob"][:, 2]
     keypoints = np.stack([keypoints[0], keypoints[1], prob], axis=-1)
 
-    warped_keypoints = data['warped_prob'][:, :2].T
+    warped_keypoints = data["warped_prob"][:, :2].T
     warped_keypoints = warped_keypoints[::-1]
-    warped_prob = data['warped_prob'][:, 2]
-    warped_keypoints = np.stack([warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1)
+    warped_prob = data["warped_prob"][:, 2]
+    warped_keypoints = np.stack(
+        [warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1
+    )
+
+    desc = data["desc"]
+    warped_desc = data["warped_desc"]
 
-    desc = data['desc']
-    warped_desc = data['warped_desc']
-    
     # Keeps all points for the next frame. The matching for calculating M.Score shouldn't use only in-view points.
-    keypoints,        desc        = select_k_best(keypoints,               desc, keep_k_points)
-    warped_keypoints, warped_desc = select_k_best(warped_keypoints, warped_desc, keep_k_points)
-    
+    keypoints, desc = select_k_best(keypoints, desc, keep_k_points)
+    warped_keypoints, warped_desc = select_k_best(
+        warped_keypoints, warped_desc, keep_k_points
+    )
+
     # Match the keypoints with the warped_keypoints with nearest neighbor search
     # This part needs to be done with crossCheck=False.
     # All the matched pairs need to be evaluated without any selection.
@@ -139,11 +151,16 @@ def compute_matching_score(data, keep_k_points=1000):
     matches_idx = np.array([m.trainIdx for m in matches])
     m_warped_keypoints = warped_keypoints[matches_idx, :]
 
-    true_warped_keypoints = warp_keypoints(m_warped_keypoints[:, [1, 0]], np.linalg.inv(real_H))[:,::-1]
-    vis_warped = np.all((true_warped_keypoints >= 0) & (true_warped_keypoints <= (np.array(shape)-1)), axis=-1)
+    true_warped_keypoints = warp_keypoints(
+        m_warped_keypoints[:, [1, 0]], np.linalg.inv(real_H)
+    )[:, ::-1]
+    vis_warped = np.all(
+        (true_warped_keypoints >= 0) & (true_warped_keypoints <= (np.array(shape) - 1)),
+        axis=-1,
+    )
     norm1 = np.linalg.norm(true_warped_keypoints - m_keypoints, axis=-1)
 
-    correct1 = (norm1 < 3)
+    correct1 = norm1 < 3
     count1 = np.sum(correct1 * vis_warped)
     score1 = count1 / np.maximum(np.sum(vis_warped), 1.0)
 
@@ -153,11 +170,13 @@ def compute_matching_score(data, keep_k_points=1000):
     matches_idx = np.array([m.trainIdx for m in matches])
     m_keypoints = keypoints[matches_idx, :]
 
-    true_keypoints = warp_keypoints(m_keypoints[:, [1, 0]], real_H)[:,::-1]
-    vis = np.all((true_keypoints >= 0) & (true_keypoints <= (np.array(shape)-1)), axis=-1)
+    true_keypoints = warp_keypoints(m_keypoints[:, [1, 0]], real_H)[:, ::-1]
+    vis = np.all(
+        (true_keypoints >= 0) & (true_keypoints <= (np.array(shape) - 1)), axis=-1
+    )
     norm2 = np.linalg.norm(true_keypoints - m_warped_keypoints, axis=-1)
 
-    correct2 = (norm2 < 3)
+    correct2 = norm2 < 3
     count2 = np.sum(correct2 * vis)
     score2 = count2 / np.maximum(np.sum(vis), 1.0)
 
@@ -165,9 +184,10 @@ def compute_matching_score(data, keep_k_points=1000):
 
     return ms
 
+
 def compute_homography(data, keep_k_points=1000):
     """
-    Compute the homography between 2 sets of Keypoints and descriptors inside data. 
+    Compute the homography between 2 sets of Keypoints and descriptors inside data.
     Use the homography to compute the correctness metrics (1,3,5).
 
     Parameters
@@ -190,7 +210,7 @@ def compute_homography(data, keep_k_points=1000):
         Number of keypoints to select, based on probability.
 
     Returns
-    -------    
+    -------
     correctness1: float
         correctness1 metric.
     correctness3: float
@@ -198,27 +218,30 @@ def compute_homography(data, keep_k_points=1000):
     correctness5: float
         correctness5 metric.
     """
-    shape = data['image_shape']
-    real_H = data['homography']
+    shape = data["image_shape"]
+    real_H = data["homography"]
 
     # Filter out predictions
-    keypoints = data['prob'][:, :2].T
+    keypoints = data["prob"][:, :2].T
     keypoints = keypoints[::-1]
-    prob = data['prob'][:, 2]
+    prob = data["prob"][:, 2]
     keypoints = np.stack([keypoints[0], keypoints[1], prob], axis=-1)
 
-    warped_keypoints = data['warped_prob'][:, :2].T
+    warped_keypoints = data["warped_prob"][:, :2].T
     warped_keypoints = warped_keypoints[::-1]
-    warped_prob = data['warped_prob'][:, 2]
-    warped_keypoints = np.stack([warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1)
+    warped_prob = data["warped_prob"][:, 2]
+    warped_keypoints = np.stack(
+        [warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1
+    )
+
+    desc = data["desc"]
+    warped_desc = data["warped_desc"]
 
-    desc = data['desc']
-    warped_desc = data['warped_desc']
-    
     # Keeps only the points shared between the two views
     keypoints, desc = keep_shared_points(keypoints, desc, real_H, shape, keep_k_points)
-    warped_keypoints, warped_desc = keep_shared_points(warped_keypoints, warped_desc, np.linalg.inv(real_H), shape,
-                                                       keep_k_points)
+    warped_keypoints, warped_desc = keep_shared_points(
+        warped_keypoints, warped_desc, np.linalg.inv(real_H), shape, keep_k_points
+    )
 
     bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
     matches = bf.match(desc, warped_desc)
@@ -228,8 +251,13 @@ def compute_homography(data, keep_k_points=1000):
     m_warped_keypoints = warped_keypoints[matches_idx, :]
 
     # Estimate the homography between the matches using RANSAC
-    H, _ = cv2.findHomography(m_keypoints[:, [1, 0]],
-                              m_warped_keypoints[:, [1, 0]], cv2.RANSAC, 3, maxIters=5000)
+    H, _ = cv2.findHomography(
+        m_keypoints[:, [1, 0]],
+        m_warped_keypoints[:, [1, 0]],
+        cv2.RANSAC,
+        3,
+        maxIters=5000,
+    )
 
     if H is None:
         return 0, 0, 0
@@ -237,15 +265,19 @@ def compute_homography(data, keep_k_points=1000):
     shape = shape[::-1]
 
     # Compute correctness
-    corners = np.array([[0, 0, 1],
-                        [0, shape[1] - 1, 1],
-                        [shape[0] - 1, 0, 1],
-                        [shape[0] - 1, shape[1] - 1, 1]])
+    corners = np.array(
+        [
+            [0, 0, 1],
+            [0, shape[1] - 1, 1],
+            [shape[0] - 1, 0, 1],
+            [shape[0] - 1, shape[1] - 1, 1],
+        ]
+    )
     real_warped_corners = np.dot(corners, np.transpose(real_H))
     real_warped_corners = real_warped_corners[:, :2] / real_warped_corners[:, 2:]
     warped_corners = np.dot(corners, np.transpose(H))
     warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
-    
+
     mean_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1))
     correctness1 = float(mean_dist <= 1)
     correctness3 = float(mean_dist <= 3)
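To make the correctness metric above concrete, a small worked example: warp the four image corners with the ground-truth and estimated homographies and threshold the mean corner distance at 1, 3 and 5 pixels. Both homographies here are made up.

import numpy as np

shape = (240, 320)[::-1]  # image_shape reversed, as in compute_homography
corners = np.array(
    [
        [0, 0, 1],
        [0, shape[1] - 1, 1],
        [shape[0] - 1, 0, 1],
        [shape[0] - 1, shape[1] - 1, 1],
    ]
)
real_H = np.eye(3)
H = np.array([[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])  # 2 px shift in x

real_warped = corners @ real_H.T
real_warped = real_warped[:, :2] / real_warped[:, 2:]
warped = corners @ H.T
warped = warped[:, :2] / warped[:, 2:]

mean_dist = np.mean(np.linalg.norm(real_warped - warped, axis=1))
print(mean_dist, float(mean_dist <= 1), float(mean_dist <= 3), float(mean_dist <= 5))
# 2.0 0.0 1.0 1.0
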
diff --git a/third_party/lanet/evaluation/detector_evaluation.py b/third_party/lanet/evaluation/detector_evaluation.py
index ccc8792d17a6fbb6b446f0f9f84a2b82e3cdb57c..7198eaec0e6042baf111208885f4040311cc605e 100644
--- a/third_party/lanet/evaluation/detector_evaluation.py
+++ b/third_party/lanet/evaluation/detector_evaluation.py
@@ -33,7 +33,7 @@ def compute_repeatability(data, keep_k_points=300, distance_thresh=3):
         Distance threshold in pixels for a corresponding keypoint to be considered a correct match.
 
     Returns
-    -------    
+    -------
     N1: int
         Number of true keypoints in the first image.
     N2: int
@@ -43,47 +43,59 @@ def compute_repeatability(data, keep_k_points=300, distance_thresh=3):
     loc_err: float
         Keypoint localization error.
     """
+
     def filter_keypoints(points, shape):
-        """ Keep only the points whose coordinates are inside the dimensions of shape. """
-        mask = (points[:, 0] >= 0) & (points[:, 0] < shape[0]) &\
-               (points[:, 1] >= 0) & (points[:, 1] < shape[1])
+        """Keep only the points whose coordinates are inside the dimensions of shape."""
+        mask = (
+            (points[:, 0] >= 0)
+            & (points[:, 0] < shape[0])
+            & (points[:, 1] >= 0)
+            & (points[:, 1] < shape[1])
+        )
         return points[mask, :]
 
     def keep_true_keypoints(points, H, shape):
-        """ Keep only the points whose warped coordinates by H are still inside shape. """
+        """Keep only the points whose warped coordinates by H are still inside shape."""
         warped_points = warp_keypoints(points[:, [1, 0]], H)
         warped_points[:, [0, 1]] = warped_points[:, [1, 0]]
-        mask = (warped_points[:, 0] >= 0) & (warped_points[:, 0] < shape[0]) &\
-               (warped_points[:, 1] >= 0) & (warped_points[:, 1] < shape[1])
+        mask = (
+            (warped_points[:, 0] >= 0)
+            & (warped_points[:, 0] < shape[0])
+            & (warped_points[:, 1] >= 0)
+            & (warped_points[:, 1] < shape[1])
+        )
         return points[mask, :]
 
-
     def select_k_best(points, k):
-        """ Select the k most probable points (and strip their probability).
-        points has shape (num_points, 3) where the last coordinate is the probability. """
+        """Select the k most probable points (and strip their probability).
+        points has shape (num_points, 3) where the last coordinate is the probability."""
         sorted_prob = points[points[:, 2].argsort(), :2]
         start = min(k, points.shape[0])
         return sorted_prob[-start:, :]
 
-    H = data['homography']
-    shape = data['image_shape']
+    H = data["homography"]
+    shape = data["image_shape"]
 
     # # Filter out predictions
-    keypoints = data['prob'][:, :2].T
+    keypoints = data["prob"][:, :2].T
     keypoints = keypoints[::-1]
-    prob = data['prob'][:, 2]
+    prob = data["prob"][:, 2]
 
-    warped_keypoints = data['warped_prob'][:, :2].T
+    warped_keypoints = data["warped_prob"][:, :2].T
     warped_keypoints = warped_keypoints[::-1]
-    warped_prob = data['warped_prob'][:, 2]
+    warped_prob = data["warped_prob"][:, 2]
 
     keypoints = np.stack([keypoints[0], keypoints[1]], axis=-1)
-    warped_keypoints = np.stack([warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1)
+    warped_keypoints = np.stack(
+        [warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1
+    )
     warped_keypoints = keep_true_keypoints(warped_keypoints, np.linalg.inv(H), shape)
 
     # Warp the original keypoints with the true homography
     true_warped_keypoints = warp_keypoints(keypoints[:, [1, 0]], H)
-    true_warped_keypoints = np.stack([true_warped_keypoints[:, 1], true_warped_keypoints[:, 0], prob], axis=-1)
+    true_warped_keypoints = np.stack(
+        [true_warped_keypoints[:, 1], true_warped_keypoints[:, 0], prob], axis=-1
+    )
     true_warped_keypoints = filter_keypoints(true_warped_keypoints, shape)
 
     # Keep only the keep_k_points best predictions
@@ -103,12 +115,12 @@ def compute_repeatability(data, keep_k_points=300, distance_thresh=3):
     le2 = 0
     if N2 != 0:
         min1 = np.min(norm, axis=1)
-        correct1 = (min1 <= distance_thresh)
+        correct1 = min1 <= distance_thresh
         count1 = np.sum(correct1)
         le1 = min1[correct1].sum()
     if N1 != 0:
         min2 = np.min(norm, axis=0)
-        correct2 = (min2 <= distance_thresh)
+        correct2 = min2 <= distance_thresh
         count2 = np.sum(correct2)
         le2 = min2[correct2].sum()
     if N1 + N2 > 0:
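A toy numeric example of the repeatability count above: a detection counts as repeated if some keypoint in the other image lies within distance_thresh pixels, and repeatability is the matched fraction over both sets. The coordinates are made up.

import numpy as np

true_warped = np.array([[10.0, 10.0], [50.0, 50.0]])  # keypoints warped by the true H
warped = np.array([[11.0, 10.0], [200.0, 200.0]])     # detections in the warped image
distance_thresh = 3

norm = np.linalg.norm(true_warped[:, None] - warped[None], axis=2)
count1 = np.sum(np.min(norm, axis=1) <= distance_thresh)
count2 = np.sum(np.min(norm, axis=0) <= distance_thresh)
repeatability = (count1 + count2) / (len(true_warped) + len(warped))
print(repeatability)  # 0.5
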
diff --git a/third_party/lanet/evaluation/evaluate.py b/third_party/lanet/evaluation/evaluate.py
index fa9e91ee6d9cc0142ebbe8f2a3f904f6fae8434c..06bec8e5e01b8d285622e6c1eca9000f2a0541cb 100644
--- a/third_party/lanet/evaluation/evaluate.py
+++ b/third_party/lanet/evaluation/evaluate.py
@@ -5,24 +5,25 @@ import torch
 import torchvision.transforms as transforms
 from tqdm import tqdm
 
-from evaluation.descriptor_evaluation import (compute_homography,
-                                                   compute_matching_score)
+from evaluation.descriptor_evaluation import compute_homography, compute_matching_score
 from evaluation.detector_evaluation import compute_repeatability
 
 
-def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), top_k=300):
-    """Keypoint net evaluation script. 
+def evaluate_keypoint_net(
+    data_loader, keypoint_net, output_shape=(320, 240), top_k=300
+):
+    """Keypoint net evaluation script.
 
     Parameters
     ----------
     data_loader: torch.utils.data.DataLoader
-        Dataset loader. 
+        Dataset loader.
     keypoint_net: torch.nn.module
         Keypoint network.
     output_shape: tuple
         Original image shape.
     top_k: int
-        Number of keypoints to use to compute metrics, selected based on probability.    
+        Number of keypoints to use to compute metrics, selected based on probability.
     use_color: bool
         Use color or grayscale images.
     """
@@ -36,8 +37,8 @@ def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), to
     with torch.no_grad():
         for i, sample in tqdm(enumerate(data_loader), desc="Evaluate point model"):
 
-            image = sample['image'].cuda()
-            warped_image = sample['warped_image'].cuda()
+            image = sample["image"].cuda()
+            warped_image = sample["warped_image"].cuda()
 
             score_1, coord_1, desc1 = keypoint_net(image)
             score_2, coord_2, desc2 = keypoint_net(warped_image)
@@ -48,7 +49,7 @@ def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), to
             score_2 = torch.cat([coord_2, score_2], dim=1).view(3, -1).t().cpu().numpy()
             desc1 = desc1.view(256, Hc, Wc).view(256, -1).t().cpu().numpy()
             desc2 = desc2.view(256, Hc, Wc).view(256, -1).t().cpu().numpy()
-            
+
             # Filter based on confidence threshold
             desc1 = desc1[score_1[:, 2] > conf_threshold, :]
             desc2 = desc2[score_2[:, 2] > conf_threshold, :]
@@ -56,17 +57,21 @@ def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), to
             score_2 = score_2[score_2[:, 2] > conf_threshold, :]
 
             # Prepare data for eval
-            data = {'image': sample['image'].numpy().squeeze(),
-                    'image_shape' : output_shape[::-1],
-                    'warped_image': sample['warped_image'].numpy().squeeze(),
-                    'homography': sample['homography'].squeeze().numpy(),
-                    'prob': score_1, 
-                    'warped_prob': score_2,
-                    'desc': desc1,
-                    'warped_desc': desc2}
-            
+            data = {
+                "image": sample["image"].numpy().squeeze(),
+                "image_shape": output_shape[::-1],
+                "warped_image": sample["warped_image"].numpy().squeeze(),
+                "homography": sample["homography"].squeeze().numpy(),
+                "prob": score_1,
+                "warped_prob": score_2,
+                "desc": desc1,
+                "warped_desc": desc2,
+            }
+
             # Compute repeatabilty and localization error
-            _, _, rep, loc_err = compute_repeatability(data, keep_k_points=top_k, distance_thresh=3)
+            _, _, rep, loc_err = compute_repeatability(
+                data, keep_k_points=top_k, distance_thresh=3
+            )
             repeatability.append(rep)
             localization_err.append(loc_err)
 
@@ -80,5 +85,11 @@ def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), to
             mscore = compute_matching_score(data, keep_k_points=top_k)
             MScore.append(mscore)
 
-    return np.mean(repeatability), np.mean(localization_err), \
-           np.mean(correctness1), np.mean(correctness3), np.mean(correctness5), np.mean(MScore)
+    return (
+        np.mean(repeatability),
+        np.mean(localization_err),
+        np.mean(correctness1),
+        np.mean(correctness3),
+        np.mean(correctness5),
+        np.mean(MScore),
+    )
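A hedged sketch of driving evaluate_keypoint_net above on HPatches; the import paths, checkpoint handling and file locations are assumptions, and a CUDA device is required because the function moves samples to the GPU.

import torch
from torch.utils.data import DataLoader
from third_party.lanet.datasets.hp_loader import PatchesDataset       # import paths assumed
from third_party.lanet.evaluation.evaluate import evaluate_keypoint_net

dataset = PatchesDataset("/data/hpatches-sequences-release", output_shape=(320, 240))
loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Placeholder: assumes the checkpoint stores a whole serialized keypoint network.
keypoint_net = torch.load("checkpoints/PointModel.pth").cuda().eval()

rep, loc_err, c1, c3, c5, mscore = evaluate_keypoint_net(
    loader, keypoint_net, output_shape=(320, 240), top_k=300
)
print(rep, loc_err, c1, c3, c5, mscore)
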
diff --git a/third_party/lanet/loss_function.py b/third_party/lanet/loss_function.py
index 2e74cf2b53af3c3fc26c34394df4cfe538b3b49c..b5a40c3a969f8e7725e2f30d453762a0eca6b062 100644
--- a/third_party/lanet/loss_function.py
+++ b/third_party/lanet/loss_function.py
@@ -1,6 +1,9 @@
 import torch
 
-def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, relax_field=4, eval_only=False):
+
+def build_descriptor_loss(
+    source_des, target_des, tar_points_un, top_kk=None, relax_field=4, eval_only=False
+):
     """
     Desc Head Loss, per-pixel level triplet loss from https://arxiv.org/pdf/1902.11046.pdf.
 
@@ -10,12 +13,12 @@ def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, re
         Source image descriptors.
     target_des: torch.Tensor (B,256,H/8,W/8)
         Target image descriptors.
-    source_points: torch.Tensor (B,H/8,W/8,2) 
-        Source image keypoints 
+    source_points: torch.Tensor (B,H/8,W/8,2)
+        Source image keypoints
     tar_points: torch.Tensor (B,H/8,W/8,2)
-        Target image keypoints 
+        Target image keypoints
     tar_points_un: torch.Tensor (B,2,H/8,W/8)
-        Target image keypoints unnormalized 
+        Target image keypoints unnormalized
     eval_only: bool
         Computes only recall without the loss.
     Returns
@@ -28,11 +31,11 @@ def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, re
     device = source_des.device
     loss = 0
     batch_size = source_des.size(0)
-    recall = 0.
+    recall = 0.0
 
     relax_field_size = [relax_field]
-    margins          = [1.0]
-    weights          = [1.0]
+    margins = [1.0]
+    weights = [1.0]
 
     isource_dense = top_kk is None
 
@@ -50,7 +53,7 @@ def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, re
                 continue
 
             ref_desc = source_des[b_id].squeeze()[:, top_k]
-            tar_desc = target_des[b_id].squeeze()[:, top_k]         
+            tar_desc = target_des[b_id].squeeze()[:, top_k]
             tar_points_raw = tar_points_un[b_id][:, top_k]
 
         # Compute dense descriptor distance matrix and find nearest neighbor
@@ -61,7 +64,6 @@ def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, re
         dmat = torch.sqrt(2 - 2 * torch.clamp(dmat, min=-1, max=1))
         _, idx = torch.sort(dmat, dim=1)
 
-
         # Compute triplet loss and recall
         for pyramid in range(len(relax_field_size)):
 
@@ -74,24 +76,41 @@ def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, re
             tru_y = tar_points_raw[1]
 
             if pyramid == 0:
-                correct2 = (abs(match_k_x[0]-tru_x) == 0) & (abs(match_k_y[0]-tru_y) == 0)
+                correct2 = (abs(match_k_x[0] - tru_x) == 0) & (
+                    abs(match_k_y[0] - tru_y) == 0
+                )
                 correct2_cnt = correct2.float().sum()
-                recall += float(1.0 / batch_size) * (float(correct2_cnt) / float( ref_desc.size(1)))
+                recall += float(1.0 / batch_size) * (
+                    float(correct2_cnt) / float(ref_desc.size(1))
+                )
 
             if eval_only:
                 continue
-            correct_k = (abs(match_k_x - tru_x) <= relax_field_size[pyramid]) & (abs(match_k_y - tru_y) <= relax_field_size[pyramid])
-
-            incorrect_index = torch.arange(start=correct_k.shape[0]-1, end=-1, step=-1).unsqueeze(1).repeat(1,correct_k.shape[1]).to(device)
-            incorrect_first = torch.argmax(incorrect_index * (1 - correct_k.long()), dim=0)
-
-            incorrect_first_index = candidates.gather(0, incorrect_first.unsqueeze(0)).squeeze()
+            correct_k = (abs(match_k_x - tru_x) <= relax_field_size[pyramid]) & (
+                abs(match_k_y - tru_y) <= relax_field_size[pyramid]
+            )
+
+            incorrect_index = (
+                torch.arange(start=correct_k.shape[0] - 1, end=-1, step=-1)
+                .unsqueeze(1)
+                .repeat(1, correct_k.shape[1])
+                .to(device)
+            )
+            incorrect_first = torch.argmax(
+                incorrect_index * (1 - correct_k.long()), dim=0
+            )
+
+            incorrect_first_index = candidates.gather(
+                0, incorrect_first.unsqueeze(0)
+            ).squeeze()
 
             anchor_var = ref_desc
             posource_var = tar_desc
             neg_var = tar_desc[:, incorrect_first_index]
 
-            loss += float(1.0 / batch_size) * torch.nn.functional.triplet_margin_loss(anchor_var.t(), posource_var.t(), neg_var.t(), margin=margins[pyramid]).mul(weights[pyramid])
+            loss += float(1.0 / batch_size) * torch.nn.functional.triplet_margin_loss(
+                anchor_var.t(), posource_var.t(), neg_var.t(), margin=margins[pyramid]
+            ).mul(weights[pyramid])
 
     return loss, recall
 
@@ -100,57 +119,108 @@ class KeypointLoss(object):
     """
     Loss function class encapsulating the location loss, the descriptor loss, and the score loss.
     """
+
     def __init__(self, config):
         self.score_weight = config.score_weight
         self.loc_weight = config.loc_weight
         self.desc_weight = config.desc_weight
         self.corres_weight = config.corres_weight
         self.corres_threshold = config.corres_threshold
-        
+
     def __call__(self, data):
-        B, _, hc, wc = data['source_score'].shape
-        
-        loc_mat_abs = torch.abs(data['target_coord_warped'].view(B, 2, -1).unsqueeze(3) - data['target_coord'].view(B, 2, -1).unsqueeze(2))
+        B, _, hc, wc = data["source_score"].shape
+
+        loc_mat_abs = torch.abs(
+            data["target_coord_warped"].view(B, 2, -1).unsqueeze(3)
+            - data["target_coord"].view(B, 2, -1).unsqueeze(2)
+        )
         l2_dist_loc_mat = torch.norm(loc_mat_abs, p=2, dim=1)
         l2_dist_loc_min, l2_dist_loc_min_index = l2_dist_loc_mat.min(dim=2)
 
         # construct pseudo ground truth matching matrix
-        loc_min_mat = torch.repeat_interleave(l2_dist_loc_min.unsqueeze(dim=-1), repeats=l2_dist_loc_mat.shape[-1], dim=-1)
-        pos_mask = l2_dist_loc_mat.eq(loc_min_mat) & l2_dist_loc_mat.le(1.)
-        neg_mask = l2_dist_loc_mat.ge(4.)
-
-        pos_corres = - torch.log(data['confidence_matrix'][pos_mask])
-        neg_corres = - torch.log(1.0 - data['confidence_matrix'][neg_mask])
+        loc_min_mat = torch.repeat_interleave(
+            l2_dist_loc_min.unsqueeze(dim=-1), repeats=l2_dist_loc_mat.shape[-1], dim=-1
+        )
+        pos_mask = l2_dist_loc_mat.eq(loc_min_mat) & l2_dist_loc_mat.le(1.0)
+        neg_mask = l2_dist_loc_mat.ge(4.0)
+
+        pos_corres = -torch.log(data["confidence_matrix"][pos_mask])
+        neg_corres = -torch.log(1.0 - data["confidence_matrix"][neg_mask])
         corres_loss = pos_corres.mean() + 5e5 * neg_corres.mean()
 
         # corresponding distance threshold is 4
-        dist_norm_valid_mask = l2_dist_loc_min.lt(self.corres_threshold) & data['border_mask'].view(B, hc * wc)
-        
+        dist_norm_valid_mask = l2_dist_loc_min.lt(self.corres_threshold) & data[
+            "border_mask"
+        ].view(B, hc * wc)
+
         # location loss
         loc_loss = l2_dist_loc_min[dist_norm_valid_mask].mean()
-        
+
         # desc Head Loss, per-pixel level triplet loss from https://arxiv.org/pdf/1902.11046.pdf.
-        desc_loss, _ = build_descriptor_loss(data['source_desc'], data['target_desc_warped'], data['target_coord_warped'].detach(), top_kk=data['border_mask'], relax_field=8)
-        
+        desc_loss, _ = build_descriptor_loss(
+            data["source_desc"],
+            data["target_desc_warped"],
+            data["target_coord_warped"].detach(),
+            top_kk=data["border_mask"],
+            relax_field=8,
+        )
+
         # score loss
-        target_score_associated = data['target_score'].view(B, hc * wc).gather(1, l2_dist_loc_min_index).view(B, hc, wc).unsqueeze(1)
-        dist_norm_valid_mask = dist_norm_valid_mask.view(B, hc, wc).unsqueeze(1) & data['border_mask'].unsqueeze(1) 
+        target_score_associated = (
+            data["target_score"]
+            .view(B, hc * wc)
+            .gather(1, l2_dist_loc_min_index)
+            .view(B, hc, wc)
+            .unsqueeze(1)
+        )
+        dist_norm_valid_mask = dist_norm_valid_mask.view(B, hc, wc).unsqueeze(1) & data[
+            "border_mask"
+        ].unsqueeze(1)
         l2_dist_loc_min = l2_dist_loc_min.view(B, hc, wc).unsqueeze(1)
         loc_err = l2_dist_loc_min[dist_norm_valid_mask]
-        
+
         # repeatable_constrain in score loss
-        repeatable_constrain = ((target_score_associated[dist_norm_valid_mask] + data['source_score'][dist_norm_valid_mask]) * (loc_err - loc_err.mean())).mean()
+        repeatable_constrain = (
+            (
+                target_score_associated[dist_norm_valid_mask]
+                + data["source_score"][dist_norm_valid_mask]
+            )
+            * (loc_err - loc_err.mean())
+        ).mean()
 
         # consistent_constrain in score_loss
-        consistent_constrain = torch.nn.functional.mse_loss(data['target_score_warped'][data['border_mask'].unsqueeze(1)], data['source_score'][data['border_mask'].unsqueeze(1)]).mean() * 2
-        aware_consistent_loss = torch.nn.functional.mse_loss(data['target_aware_warped'][data['border_mask'].unsqueeze(1).repeat(1, 2, 1, 1)], data['source_aware'][data['border_mask'].unsqueeze(1).repeat(1, 2, 1, 1)]).mean() * 2
-        
-        score_loss = repeatable_constrain + consistent_constrain + aware_consistent_loss
-        
-        loss = self.loc_weight * loc_loss + self.desc_weight * desc_loss + self.score_weight * score_loss + self.corres_weight * corres_loss
-        
-        return loss, self.loc_weight * loc_loss, self.desc_weight * desc_loss, self.score_weight * score_loss, self.corres_weight * corres_loss
-
-        
+        consistent_constrain = (
+            torch.nn.functional.mse_loss(
+                data["target_score_warped"][data["border_mask"].unsqueeze(1)],
+                data["source_score"][data["border_mask"].unsqueeze(1)],
+            ).mean()
+            * 2
+        )
+        aware_consistent_loss = (
+            torch.nn.functional.mse_loss(
+                data["target_aware_warped"][
+                    data["border_mask"].unsqueeze(1).repeat(1, 2, 1, 1)
+                ],
+                data["source_aware"][
+                    data["border_mask"].unsqueeze(1).repeat(1, 2, 1, 1)
+                ],
+            ).mean()
+            * 2
+        )
 
+        score_loss = repeatable_constrain + consistent_constrain + aware_consistent_loss
 
+        loss = (
+            self.loc_weight * loc_loss
+            + self.desc_weight * desc_loss
+            + self.score_weight * score_loss
+            + self.corres_weight * corres_loss
+        )
+
+        return (
+            loss,
+            self.loc_weight * loc_loss,
+            self.desc_weight * desc_loss,
+            self.score_weight * score_loss,
+            self.corres_weight * corres_loss,
+        )
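For reference, the descriptor term reformatted above reduces to a hard-negative triplet margin loss over sampled descriptors: anchors come from the source image, positives are the corresponding target descriptors, and negatives are the closest non-matching target descriptors. A minimal, self-contained sketch of that core step (synthetic tensors, not part of the patch):

    import torch
    import torch.nn.functional as F

    N, C = 128, 256                                   # sampled keypoints, descriptor dim
    anchor = F.normalize(torch.randn(N, C), dim=1)    # source descriptors
    positive = F.normalize(torch.randn(N, C), dim=1)  # matching target descriptors
    negative = F.normalize(torch.randn(N, C), dim=1)  # hardest non-matching descriptors

    # Same call used in build_descriptor_loss, here on (N, C) tensors
    loss = F.triplet_margin_loss(anchor, positive, negative, margin=1.0)
    print(float(loss))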
diff --git a/third_party/lanet/main.py b/third_party/lanet/main.py
index 2aa81d8104c19ea1d8c4ce7d1dd547f8b35a4a72..b48dc074a2fd6d4240e126268bcd8e0d8d313d1c 100644
--- a/third_party/lanet/main.py
+++ b/third_party/lanet/main.py
@@ -5,6 +5,7 @@ from config import get_config
 from utils import prepare_dirs
 from data_loader import get_data_loader
 
+
 def main(config):
     # ensure directories are setup
     prepare_dirs(config)
@@ -20,6 +21,7 @@ def main(config):
     trainer = Trainer(config, train_loader=train_loader)
     trainer.train()
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     config, unparsed = get_config()
-    main(config)
\ No newline at end of file
+    main(config)
diff --git a/third_party/lanet/network_v0/model.py b/third_party/lanet/network_v0/model.py
index 564000330ddd5e9f18821e8606d23cd12dc847bc..6f22e015449dd7bcc8e060a2cd72a794befd2ccb 100644
--- a/third_party/lanet/network_v0/model.py
+++ b/third_party/lanet/network_v0/model.py
@@ -4,6 +4,7 @@ import torchvision.transforms as tvf
 
 from .modules import InterestPointModule, CorrespondenceModule
 
+
 def warp_homography_batch(sources, homographies):
     """
     Batch warp keypoints given homographies. From https://github.com/TRI-ML/KP2D.
@@ -24,18 +25,19 @@ def warp_homography_batch(sources, homographies):
     warped_sources = []
     for b in range(B):
         source = sources[b].clone()
-        source = source.view(-1,2)
-        '''
+        source = source.view(-1, 2)
+        """
         [X,    [M11, M12, M13    [x,    M11*x + M12*y + M13           [M11, M12      [M13,
          Y,  =  M21, M22, M23  *  y, =  M21*x + M22*y + M23 = [x, y] * M21, M22    +  M23,
          Z]     M31, M32, M33]    1]    M31*x + M32*y + M33            M31, M32].T    M33]
-        '''
-        source = torch.addmm(homographies[b,:,2], source, homographies[b,:,:2].t())
-        source.mul_(1/source[:,2].unsqueeze(1))
-        source = source[:,:2].contiguous().view(H,W,2)
+        """
+        source = torch.addmm(homographies[b, :, 2], source, homographies[b, :, :2].t())
+        source.mul_(1 / source[:, 2].unsqueeze(1))
+        source = source[:, :2].contiguous().view(H, W, 2)
         warped_sources.append(source)
     return torch.stack(warped_sources, dim=0)
- 
+
+
 class PointModel(nn.Module):
     def __init__(self, is_test=True):
         super(PointModel, self).__init__()
@@ -43,7 +45,7 @@ class PointModel(nn.Module):
         self.interestpoint_module = InterestPointModule(is_test=self.is_test)
         self.correspondence_module = CorrespondenceModule()
         self.norm_rgb = tvf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225])
-  
+
     def forward(self, *args):
         if self.is_test:
             img = args[0]
@@ -51,8 +53,12 @@ class PointModel(nn.Module):
             score, coord, desc = self.interestpoint_module(img)
             return score, coord, desc
         else:
-            source_score, source_coord, source_desc_block = self.interestpoint_module(args[0])
-            target_score, target_coord, target_desc_block = self.interestpoint_module(args[1])
+            source_score, source_coord, source_desc_block = self.interestpoint_module(
+                args[0]
+            )
+            target_score, target_coord, target_desc_block = self.interestpoint_module(
+                args[1]
+            )
 
             B, _, H, W = args[0].shape
             B, _, hc, wc = source_score.shape
@@ -60,21 +66,33 @@ class PointModel(nn.Module):
 
             # Normalize the coordinates from ([0, h], [0, w]) to ([-1, 1], [-1, 1]).
             source_coord_norm = source_coord.clone()
-            source_coord_norm[:, 0] = (source_coord_norm[:, 0] / (float(W - 1) / 2.)) - 1.
-            source_coord_norm[:, 1] = (source_coord_norm[:, 1] / (float(H - 1) / 2.)) - 1.
+            source_coord_norm[:, 0] = (
+                source_coord_norm[:, 0] / (float(W - 1) / 2.0)
+            ) - 1.0
+            source_coord_norm[:, 1] = (
+                source_coord_norm[:, 1] / (float(H - 1) / 2.0)
+            ) - 1.0
             source_coord_norm = source_coord_norm.permute(0, 2, 3, 1)
 
             target_coord_norm = target_coord.clone()
-            target_coord_norm[:, 0] = (target_coord_norm[:, 0] / (float(W - 1) / 2.)) - 1.
-            target_coord_norm[:, 1] = (target_coord_norm[:, 1] / (float(H - 1) / 2.)) - 1.
+            target_coord_norm[:, 0] = (
+                target_coord_norm[:, 0] / (float(W - 1) / 2.0)
+            ) - 1.0
+            target_coord_norm[:, 1] = (
+                target_coord_norm[:, 1] / (float(H - 1) / 2.0)
+            ) - 1.0
             target_coord_norm = target_coord_norm.permute(0, 2, 3, 1)
-            
+
             target_coord_warped_norm = warp_homography_batch(source_coord_norm, args[2])
             target_coord_warped = target_coord_warped_norm.clone()
-        
+
             # de-normalize the coordinates
-            target_coord_warped[:, :, :, 0] = (target_coord_warped[:, :, :, 0] + 1) * (float(W - 1) / 2.)
-            target_coord_warped[:, :, :, 1] = (target_coord_warped[:, :, :, 1] + 1) * (float(H - 1) / 2.)
+            target_coord_warped[:, :, :, 0] = (target_coord_warped[:, :, :, 0] + 1) * (
+                float(W - 1) / 2.0
+            )
+            target_coord_warped[:, :, :, 1] = (target_coord_warped[:, :, :, 1] + 1) * (
+                float(H - 1) / 2.0
+            )
             target_coord_warped = target_coord_warped.permute(0, 3, 1, 2)
 
             # Border mask
@@ -85,44 +103,79 @@ class PointModel(nn.Module):
             border_mask_ori[:, :, wc - 1] = 0
             border_mask_ori = border_mask_ori.gt(1e-3).to(device)
 
-            oob_mask2 = target_coord_warped_norm[:, :, :, 0].lt(1) & target_coord_warped_norm[:, :, :, 0].gt(-1) & target_coord_warped_norm[:, :, :, 1].lt(1) & target_coord_warped_norm[:, :, :, 1].gt(-1)
+            oob_mask2 = (
+                target_coord_warped_norm[:, :, :, 0].lt(1)
+                & target_coord_warped_norm[:, :, :, 0].gt(-1)
+                & target_coord_warped_norm[:, :, :, 1].lt(1)
+                & target_coord_warped_norm[:, :, :, 1].gt(-1)
+            )
             border_mask = border_mask_ori & oob_mask2
 
             # score
-            target_score_warped = torch.nn.functional.grid_sample(target_score, target_coord_warped_norm.detach(), align_corners=False)
+            target_score_warped = torch.nn.functional.grid_sample(
+                target_score, target_coord_warped_norm.detach(), align_corners=False
+            )
 
             # descriptor
-            source_desc2 = torch.nn.functional.grid_sample(source_desc_block[0], source_coord_norm.detach())
-            source_desc3 = torch.nn.functional.grid_sample(source_desc_block[1], source_coord_norm.detach())
+            source_desc2 = torch.nn.functional.grid_sample(
+                source_desc_block[0], source_coord_norm.detach()
+            )
+            source_desc3 = torch.nn.functional.grid_sample(
+                source_desc_block[1], source_coord_norm.detach()
+            )
             source_aware = source_desc_block[2]
-            source_desc = torch.mul(source_desc2, source_aware[:, 0, :, :].unsqueeze(1).contiguous()) + torch.mul(source_desc3, source_aware[:, 1, :, :].unsqueeze(1).contiguous())
+            source_desc = torch.mul(
+                source_desc2, source_aware[:, 0, :, :].unsqueeze(1).contiguous()
+            ) + torch.mul(
+                source_desc3, source_aware[:, 1, :, :].unsqueeze(1).contiguous()
+            )
 
-            target_desc2 = torch.nn.functional.grid_sample(target_desc_block[0], target_coord_norm.detach())
-            target_desc3 = torch.nn.functional.grid_sample(target_desc_block[1], target_coord_norm.detach())
+            target_desc2 = torch.nn.functional.grid_sample(
+                target_desc_block[0], target_coord_norm.detach()
+            )
+            target_desc3 = torch.nn.functional.grid_sample(
+                target_desc_block[1], target_coord_norm.detach()
+            )
             target_aware = target_desc_block[2]
-            target_desc = torch.mul(target_desc2, target_aware[:, 0, :, :].unsqueeze(1).contiguous()) + torch.mul(target_desc3, target_aware[:, 1, :, :].unsqueeze(1).contiguous())
-            
-            target_desc2_warped = torch.nn.functional.grid_sample(target_desc_block[0], target_coord_warped_norm.detach())
-            target_desc3_warped = torch.nn.functional.grid_sample(target_desc_block[1], target_coord_warped_norm.detach())
-            target_aware_warped = torch.nn.functional.grid_sample(target_desc_block[2], target_coord_warped_norm.detach())
-            target_desc_warped = torch.mul(target_desc2_warped, target_aware_warped[:, 0, :, :].unsqueeze(1).contiguous()) + torch.mul(target_desc3_warped, target_aware_warped[:, 1, :, :].unsqueeze(1).contiguous())
-            
+            target_desc = torch.mul(
+                target_desc2, target_aware[:, 0, :, :].unsqueeze(1).contiguous()
+            ) + torch.mul(
+                target_desc3, target_aware[:, 1, :, :].unsqueeze(1).contiguous()
+            )
+
+            target_desc2_warped = torch.nn.functional.grid_sample(
+                target_desc_block[0], target_coord_warped_norm.detach()
+            )
+            target_desc3_warped = torch.nn.functional.grid_sample(
+                target_desc_block[1], target_coord_warped_norm.detach()
+            )
+            target_aware_warped = torch.nn.functional.grid_sample(
+                target_desc_block[2], target_coord_warped_norm.detach()
+            )
+            target_desc_warped = torch.mul(
+                target_desc2_warped,
+                target_aware_warped[:, 0, :, :].unsqueeze(1).contiguous(),
+            ) + torch.mul(
+                target_desc3_warped,
+                target_aware_warped[:, 1, :, :].unsqueeze(1).contiguous(),
+            )
+
             confidence_matrix = self.correspondence_module(source_desc, target_desc)
             confidence_matrix = torch.clamp(confidence_matrix, 1e-12, 1 - 1e-12)
-            
+
             output = {
-                'source_score': source_score,
-                'source_coord': source_coord,
-                'source_desc': source_desc,
-                'source_aware': source_aware,
-                'target_score': target_score,
-                'target_coord': target_coord,
-                'target_score_warped': target_score_warped,
-                'target_coord_warped': target_coord_warped,
-                'target_desc_warped': target_desc_warped,
-                'target_aware_warped': target_aware_warped,
-                'border_mask': border_mask,
-                'confidence_matrix': confidence_matrix
+                "source_score": source_score,
+                "source_coord": source_coord,
+                "source_desc": source_desc,
+                "source_aware": source_aware,
+                "target_score": target_score,
+                "target_coord": target_coord,
+                "target_score_warped": target_score_warped,
+                "target_coord_warped": target_coord_warped,
+                "target_desc_warped": target_desc_warped,
+                "target_aware_warped": target_aware_warped,
+                "border_mask": border_mask,
+                "confidence_matrix": confidence_matrix,
             }
-        
+
             return output
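The warp_homography_batch helper reformatted above applies a 3x3 homography to normalized keypoint coordinates via torch.addmm followed by a perspective divide. A minimal sketch of the same arithmetic for a single set of points (the homography here is an arbitrary example, not taken from the code):

    import torch

    H = torch.eye(3)
    H[0, 2] = 0.1                                     # small x-translation in normalized coords
    pts = torch.tensor([[0.0, 0.0], [0.5, -0.5]])     # (N, 2) keypoints in [-1, 1]

    # torch.addmm(bias, mat1, mat2) = bias + mat1 @ mat2, as in the model code
    warped = torch.addmm(H[:, 2], pts, H[:, :2].t())  # (N, 3) homogeneous coordinates
    warped = warped[:, :2] / warped[:, 2:3]           # perspective divide
    print(warped)                                     # points shifted by 0.1 in x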
diff --git a/third_party/lanet/network_v0/modules.py b/third_party/lanet/network_v0/modules.py
index a38c53133aff8769f267cc054174361296cb3e7d..1e5410d4340369e1d701cfc65cf6e168e776d1f9 100644
--- a/third_party/lanet/network_v0/modules.py
+++ b/third_party/lanet/network_v0/modules.py
@@ -4,30 +4,53 @@ import torch.nn.functional as F
 
 from utils import image_grid
 
+
 class ConvBlock(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(ConvBlock, self).__init__()
-        
+
         self.conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+            ),
             nn.BatchNorm2d(out_channels),
             nn.ReLU(inplace=True),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.Conv2d(
+                out_channels,
+                out_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+            ),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU(inplace=True)
+            nn.ReLU(inplace=True),
         )
-    
+
     def forward(self, x):
         return self.conv(x)
 
-       
+
 class DilationConv3x3(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(DilationConv3x3, self).__init__()
-        
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=2, dilation=2, bias=False)
+
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=2,
+            dilation=2,
+            bias=False,
+        )
         self.bn = nn.BatchNorm2d(out_channels)
-    
+
     def forward(self, x):
         x = self.conv(x)
         x = self.bn(x)
@@ -38,22 +61,26 @@ class InterestPointModule(nn.Module):
     def __init__(self, is_test=False):
         super(InterestPointModule, self).__init__()
         self.is_test = is_test
-        
+
         self.conv1 = ConvBlock(3, 32)
         self.conv2 = ConvBlock(32, 64)
         self.conv3 = ConvBlock(64, 128)
         self.conv4 = ConvBlock(128, 256)
-        
+
         self.maxpool2x2 = nn.MaxPool2d(2, 2)
-        
+
         # score head
-        self.score_conv = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
+        self.score_conv = nn.Conv2d(
+            256, 256, kernel_size=3, stride=1, padding=1, bias=False
+        )
         self.score_norm = nn.BatchNorm2d(256)
         self.score_out = nn.Conv2d(256, 3, kernel_size=3, stride=1, padding=1)
         self.softmax = nn.Softmax(dim=1)
-        
+
         # location head
-        self.loc_conv = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
+        self.loc_conv = nn.Conv2d(
+            256, 256, kernel_size=3, stride=1, padding=1, bias=False
+        )
         self.loc_norm = nn.BatchNorm2d(256)
         self.loc_out = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
 
@@ -63,9 +90,9 @@ class InterestPointModule(nn.Module):
 
         # cross_head:
         self.shift_out = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)
-                    
+
         self.relu = nn.ReLU(inplace=True)
-        
+
     def forward(self, x):
         B, _, H, W = x.shape
 
@@ -78,12 +105,12 @@ class InterestPointModule(nn.Module):
         x = self.conv4(x)
 
         B, _, Hc, Wc = x.shape
-        
+
         # score head
         score_x = self.score_out(self.relu(self.score_norm(self.score_conv(x))))
         aware = self.softmax(score_x[:, 0:2, :, :])
         score = score_x[:, 2, :, :].unsqueeze(1).sigmoid()
-        
+
         border_mask = torch.ones(B, Hc, Wc)
         border_mask[:, 0] = 0
         border_mask[:, Hc - 1] = 0
@@ -91,23 +118,31 @@ class InterestPointModule(nn.Module):
         border_mask[:, :, Wc - 1] = 0
         border_mask = border_mask.unsqueeze(1)
         score = score * border_mask.to(score.device)
-        
-        # location head        
+
+        # location head
         coord_x = self.relu(self.loc_norm(self.loc_conv(x)))
         coord_cell = self.loc_out(coord_x).tanh()
-        
+
         shift_ratio = self.shift_out(coord_x).sigmoid() * 2.0
 
-        step = ((H/Hc)-1) / 2.
-        center_base = image_grid(B, Hc, Wc,
-                                 dtype=coord_cell.dtype,
-                                 device=coord_cell.device,
-                                 ones=False, normalized=False).mul(H/Hc) + step
+        step = ((H / Hc) - 1) / 2.0
+        center_base = (
+            image_grid(
+                B,
+                Hc,
+                Wc,
+                dtype=coord_cell.dtype,
+                device=coord_cell.device,
+                ones=False,
+                normalized=False,
+            ).mul(H / Hc)
+            + step
+        )
 
         coord_un = center_base.add(coord_cell.mul(shift_ratio * step))
         coord = coord_un.clone()
-        coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W-1)
-        coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H-1)
+        coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W - 1)
+        coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H - 1)
 
         # descriptor block
         desc_block = []
@@ -117,16 +152,20 @@ class InterestPointModule(nn.Module):
 
         if self.is_test:
             coord_norm = coord[:, :2].clone()
-            coord_norm[:, 0] = (coord_norm[:, 0] / (float(W-1)/2.)) - 1.
-            coord_norm[:, 1] = (coord_norm[:, 1] / (float(H-1)/2.)) - 1.
+            coord_norm[:, 0] = (coord_norm[:, 0] / (float(W - 1) / 2.0)) - 1.0
+            coord_norm[:, 1] = (coord_norm[:, 1] / (float(H - 1) / 2.0)) - 1.0
             coord_norm = coord_norm.permute(0, 2, 3, 1)
 
-            desc2 = torch.nn.functional.grid_sample(desc_block[0], coord_norm)         
+            desc2 = torch.nn.functional.grid_sample(desc_block[0], coord_norm)
             desc3 = torch.nn.functional.grid_sample(desc_block[1], coord_norm)
             aware = desc_block[2]
-            
-            desc = torch.mul(desc2, aware[:, 0, :, :]) + torch.mul(desc3, aware[:, 1, :, :])         
-            desc = desc.div(torch.unsqueeze(torch.norm(desc, p=2, dim=1), 1))  # Divide by norm to normalize.
+
+            desc = torch.mul(desc2, aware[:, 0, :, :]) + torch.mul(
+                desc3, aware[:, 1, :, :]
+            )
+            desc = desc.div(
+                torch.unsqueeze(torch.norm(desc, p=2, dim=1), 1)
+            )  # Divide by norm to normalize.
 
             return score, coord, desc
 
@@ -134,25 +173,32 @@ class InterestPointModule(nn.Module):
 
 
 class CorrespondenceModule(nn.Module):
-    def __init__(self, match_type='dual_softmax'):
+    def __init__(self, match_type="dual_softmax"):
         super(CorrespondenceModule, self).__init__()
         self.match_type = match_type
 
-        if self.match_type == 'dual_softmax':
+        if self.match_type == "dual_softmax":
             self.temperature = 0.1
         else:
             raise NotImplementedError()
- 
-    def forward(self, source_desc, target_desc):
-        b, c, h, w = source_desc.size()       
-     
-        source_desc = source_desc.div(torch.unsqueeze(torch.norm(source_desc, p=2, dim=1), 1)).view(b, -1, h*w)
-        target_desc = target_desc.div(torch.unsqueeze(torch.norm(target_desc, p=2, dim=1), 1)).view(b, -1, h*w)
 
-        if self.match_type == 'dual_softmax':
-            sim_mat = torch.einsum("bcm, bcn -> bmn", source_desc, target_desc) / self.temperature
+    def forward(self, source_desc, target_desc):
+        b, c, h, w = source_desc.size()
+
+        source_desc = source_desc.div(
+            torch.unsqueeze(torch.norm(source_desc, p=2, dim=1), 1)
+        ).view(b, -1, h * w)
+        target_desc = target_desc.div(
+            torch.unsqueeze(torch.norm(target_desc, p=2, dim=1), 1)
+        ).view(b, -1, h * w)
+
+        if self.match_type == "dual_softmax":
+            sim_mat = (
+                torch.einsum("bcm, bcn -> bmn", source_desc, target_desc)
+                / self.temperature
+            )
             confidence_matrix = F.softmax(sim_mat, 1) * F.softmax(sim_mat, 2)
         else:
             raise NotImplementedError()
-        
-        return confidence_matrix
\ No newline at end of file
+
+        return confidence_matrix
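The CorrespondenceModule above implements dual-softmax matching: L2-normalized descriptors give a cosine-similarity matrix, which is divided by a temperature and soft-maxed independently over rows and columns; the elementwise product is the mutual-confidence matrix. A minimal sketch with random descriptors (shapes are illustrative, not taken from the patch):

    import torch
    import torch.nn.functional as F

    b, c, n = 1, 256, 300                             # batch, descriptor dim, keypoints per image
    src = F.normalize(torch.randn(b, c, n), dim=1)    # L2-normalized source descriptors
    tgt = F.normalize(torch.randn(b, c, n), dim=1)    # L2-normalized target descriptors

    sim = torch.einsum("bcm,bcn->bmn", src, tgt) / 0.1    # temperature 0.1, as in the module
    confidence = F.softmax(sim, dim=1) * F.softmax(sim, dim=2)
    print(confidence.shape)                           # torch.Size([1, 300, 300])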
diff --git a/third_party/lanet/network_v1/model.py b/third_party/lanet/network_v1/model.py
index baeb37c563852340fe9278ed5c2dccea4b3b693a..51ca366db1d8afd76722f5c51ccfbf8b081c61e2 100644
--- a/third_party/lanet/network_v1/model.py
+++ b/third_party/lanet/network_v1/model.py
@@ -4,6 +4,7 @@ import torchvision.transforms as tvf
 
 from .modules import InterestPointModule, CorrespondenceModule
 
+
 def warp_homography_batch(sources, homographies):
     """
     Batch warp keypoints given homographies. From https://github.com/TRI-ML/KP2D.
@@ -24,27 +25,29 @@ def warp_homography_batch(sources, homographies):
     warped_sources = []
     for b in range(B):
         source = sources[b].clone()
-        source = source.view(-1,2)
-        '''
+        source = source.view(-1, 2)
+        """
         [X,    [M11, M12, M13    [x,    M11*x + M12*y + M13           [M11, M12      [M13,
          Y,  =  M21, M22, M23  *  y, =  M21*x + M22*y + M23 = [x, y] * M21, M22    +  M23,
          Z]     M31, M32, M33]    1]    M31*x + M32*y + M33            M31, M32].T    M33]
-        '''
-        source = torch.addmm(homographies[b,:,2], source, homographies[b,:,:2].t())
-        source.mul_(1/source[:,2].unsqueeze(1))
-        source = source[:,:2].contiguous().view(H,W,2)
+        """
+        source = torch.addmm(homographies[b, :, 2], source, homographies[b, :, :2].t())
+        source.mul_(1 / source[:, 2].unsqueeze(1))
+        source = source[:, :2].contiguous().view(H, W, 2)
         warped_sources.append(source)
     return torch.stack(warped_sources, dim=0)
 
- 
+
 class PointModel(nn.Module):
     def __init__(self, is_test=False):
         super(PointModel, self).__init__()
         self.is_test = is_test
         self.interestpoint_module = InterestPointModule(is_test=self.is_test)
         self.correspondence_module = CorrespondenceModule()
-        self.norm_rgb = tvf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-  
+        self.norm_rgb = tvf.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
     def forward(self, *args):
         img = args[0]
         img = self.norm_rgb(img)
diff --git a/third_party/lanet/network_v1/modules.py b/third_party/lanet/network_v1/modules.py
index 4daed5f12c40e40f6fc8347f701235e141839ada..583076eba72ea6f79f4ca55ffcef82ebbdecd91c 100644
--- a/third_party/lanet/network_v1/modules.py
+++ b/third_party/lanet/network_v1/modules.py
@@ -6,29 +6,53 @@ import torch.nn.functional as F
 from torchvision import models
 from utils import image_grid
 
+
 class ConvBlock(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(ConvBlock, self).__init__()
-        
+
         self.conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+            ),
             nn.BatchNorm2d(out_channels),
             nn.ReLU(inplace=True),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.Conv2d(
+                out_channels,
+                out_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+            ),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU(inplace=True)
+            nn.ReLU(inplace=True),
         )
-    
+
     def forward(self, x):
         return self.conv(x)
 
+
 class DilationConv3x3(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(DilationConv3x3, self).__init__()
-        
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=2, dilation=2, bias=False)
+
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=2,
+            dilation=2,
+            bias=False,
+        )
         self.bn = nn.BatchNorm2d(out_channels)
-    
+
     def forward(self, x):
         x = self.conv(x)
         x = self.bn(x)
@@ -43,19 +67,17 @@ class InterestPointModule(nn.Module):
         model = models.vgg16_bn(pretrained=True)
 
         # use the first 33 layers as encoder
-        self.encoder = nn.Sequential(
-            *list(model.features.children())[: 33]
-        )
-        
+        self.encoder = nn.Sequential(*list(model.features.children())[:33])
+
         # score head
         self.score_head = nn.Sequential(
             nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False),
             nn.BatchNorm2d(256),
             nn.ReLU(inplace=True),
-            nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
+            nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1),
         )
         self.softmax = nn.Softmax(dim=1)
-        
+
         # location head
         self.loc_head = nn.Sequential(
             nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False),
@@ -65,18 +87,18 @@ class InterestPointModule(nn.Module):
         # location out
         self.loc_out = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
         self.shift_out = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)
-        
+
         # descriptor out
         self.des_out2 = DilationConv3x3(128, 256)
         self.des_out3 = DilationConv3x3(256, 256)
         self.des_out4 = DilationConv3x3(512, 256)
-        
+
     def forward(self, x):
         B, _, H, W = x.shape
 
         x = self.encoder[2](self.encoder[1](self.encoder[0](x)))
         x = self.encoder[5](self.encoder[4](self.encoder[3](x)))
-        
+
         x = self.encoder[6](x)
         x = self.encoder[9](self.encoder[8](self.encoder[7](x)))
         x2 = self.encoder[12](self.encoder[11](self.encoder[10](x)))
@@ -85,20 +107,19 @@ class InterestPointModule(nn.Module):
         x = self.encoder[16](self.encoder[15](self.encoder[14](x)))
         x = self.encoder[19](self.encoder[18](self.encoder[17](x)))
         x3 = self.encoder[22](self.encoder[21](self.encoder[20](x)))
-        
+
         x = self.encoder[23](x3)
         x = self.encoder[26](self.encoder[25](self.encoder[24](x)))
         x = self.encoder[29](self.encoder[28](self.encoder[27](x)))
         x = self.encoder[32](self.encoder[31](self.encoder[30](x)))
-        
 
         B, _, Hc, Wc = x.shape
-        
+
         # score head
         score_x = self.score_head(x)
         aware = self.softmax(score_x[:, 0:3, :, :])
         score = score_x[:, 3, :, :].unsqueeze(1).sigmoid()
-        
+
         border_mask = torch.ones(B, Hc, Wc)
         border_mask[:, 0] = 0
         border_mask[:, Hc - 1] = 0
@@ -106,23 +127,31 @@ class InterestPointModule(nn.Module):
         border_mask[:, :, Wc - 1] = 0
         border_mask = border_mask.unsqueeze(1)
         score = score * border_mask.to(score.device)
-        
+
         # location head
-        coord_x = self.loc_head(x)        
+        coord_x = self.loc_head(x)
         coord_cell = self.loc_out(coord_x).tanh()
-        
+
         shift_ratio = self.shift_out(coord_x).sigmoid() * 2.0
 
-        step = ((H/Hc)-1) / 2.
-        center_base = image_grid(B, Hc, Wc,
-                                 dtype=coord_cell.dtype,
-                                 device=coord_cell.device,
-                                 ones=False, normalized=False).mul(H/Hc) + step
+        step = ((H / Hc) - 1) / 2.0
+        center_base = (
+            image_grid(
+                B,
+                Hc,
+                Wc,
+                dtype=coord_cell.dtype,
+                device=coord_cell.device,
+                ones=False,
+                normalized=False,
+            ).mul(H / Hc)
+            + step
+        )
 
         coord_un = center_base.add(coord_cell.mul(shift_ratio * step))
         coord = coord_un.clone()
-        coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W-1)
-        coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H-1)
+        coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W - 1)
+        coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H - 1)
 
         # descriptor block
         desc_block = []
@@ -133,42 +162,56 @@ class InterestPointModule(nn.Module):
 
         if self.is_test:
             coord_norm = coord[:, :2].clone()
-            coord_norm[:, 0] = (coord_norm[:, 0] / (float(W-1)/2.)) - 1.
-            coord_norm[:, 1] = (coord_norm[:, 1] / (float(H-1)/2.)) - 1.
+            coord_norm[:, 0] = (coord_norm[:, 0] / (float(W - 1) / 2.0)) - 1.0
+            coord_norm[:, 1] = (coord_norm[:, 1] / (float(H - 1) / 2.0)) - 1.0
             coord_norm = coord_norm.permute(0, 2, 3, 1)
 
-            desc2 = torch.nn.functional.grid_sample(desc_block[0], coord_norm)         
+            desc2 = torch.nn.functional.grid_sample(desc_block[0], coord_norm)
             desc3 = torch.nn.functional.grid_sample(desc_block[1], coord_norm)
             desc4 = torch.nn.functional.grid_sample(desc_block[2], coord_norm)
             aware = desc_block[3]
-            
-            desc = torch.mul(desc2, aware[:, 0, :, :]) + torch.mul(desc3, aware[:, 1, :, :]) + torch.mul(desc4, aware[:, 2, :, :])         
-            desc = desc.div(torch.unsqueeze(torch.norm(desc, p=2, dim=1), 1))  # Divide by norm to normalize.
+
+            desc = (
+                torch.mul(desc2, aware[:, 0, :, :])
+                + torch.mul(desc3, aware[:, 1, :, :])
+                + torch.mul(desc4, aware[:, 2, :, :])
+            )
+            desc = desc.div(
+                torch.unsqueeze(torch.norm(desc, p=2, dim=1), 1)
+            )  # Divide by norm to normalize.
 
             return score, coord, desc
 
         return score, coord, desc_block
 
+
 class CorrespondenceModule(nn.Module):
-    def __init__(self, match_type='dual_softmax'):
+    def __init__(self, match_type="dual_softmax"):
         super(CorrespondenceModule, self).__init__()
         self.match_type = match_type
 
-        if self.match_type == 'dual_softmax':
+        if self.match_type == "dual_softmax":
             self.temperature = 0.1
         else:
             raise NotImplementedError()
- 
-    def forward(self, source_desc, target_desc):
-        b, c, h, w = source_desc.size()       
-     
-        source_desc = source_desc.div(torch.unsqueeze(torch.norm(source_desc, p=2, dim=1), 1)).view(b, -1, h*w)
-        target_desc = target_desc.div(torch.unsqueeze(torch.norm(target_desc, p=2, dim=1), 1)).view(b, -1, h*w)
 
-        if self.match_type == 'dual_softmax':
-            sim_mat = torch.einsum("bcm, bcn -> bmn", source_desc, target_desc) / self.temperature
+    def forward(self, source_desc, target_desc):
+        b, c, h, w = source_desc.size()
+
+        source_desc = source_desc.div(
+            torch.unsqueeze(torch.norm(source_desc, p=2, dim=1), 1)
+        ).view(b, -1, h * w)
+        target_desc = target_desc.div(
+            torch.unsqueeze(torch.norm(target_desc, p=2, dim=1), 1)
+        ).view(b, -1, h * w)
+
+        if self.match_type == "dual_softmax":
+            sim_mat = (
+                torch.einsum("bcm, bcn -> bmn", source_desc, target_desc)
+                / self.temperature
+            )
             confidence_matrix = F.softmax(sim_mat, 1) * F.softmax(sim_mat, 2)
         else:
             raise NotImplementedError()
-        
+
         return confidence_matrix
diff --git a/third_party/lanet/test.py b/third_party/lanet/test.py
index cc9365f5c92cbd69c3ee9250ff66b07bd1eed1c6..d54b60f6669ac02ca16aacd94bb9145050a99a05 100644
--- a/third_party/lanet/test.py
+++ b/third_party/lanet/test.py
@@ -14,9 +14,9 @@ from evaluation.evaluate import evaluate_keypoint_net
 
 
 def main():
-    parser = argparse.ArgumentParser(description='Testing')
-    parser.add_argument('--device', default=0, type=int, help='which gpu to run on.')
-    parser.add_argument('--test_dir', required=True, type=str, help='Test data path.')
+    parser = argparse.ArgumentParser(description="Testing")
+    parser.add_argument("--device", default=0, type=int, help="which gpu to run on.")
+    parser.add_argument("--test_dir", required=True, type=str, help="Test data path.")
     opt = parser.parse_args()
 
     torch.manual_seed(0)
@@ -25,63 +25,67 @@ def main():
         torch.cuda.set_device(opt.device)
 
     # Load data in 320x240
-    hp_dataset_320x240 = PatchesDataset(root_dir=opt.test_dir, use_color=True, output_shape=(320, 240), type='all')
-    data_loader_320x240 = DataLoader(hp_dataset_320x240,
-                             batch_size=1,
-                             pin_memory=False,
-                             shuffle=False,
-                             num_workers=4,
-                             worker_init_fn=None,
-                             sampler=None)
+    hp_dataset_320x240 = PatchesDataset(
+        root_dir=opt.test_dir, use_color=True, output_shape=(320, 240), type="all"
+    )
+    data_loader_320x240 = DataLoader(
+        hp_dataset_320x240,
+        batch_size=1,
+        pin_memory=False,
+        shuffle=False,
+        num_workers=4,
+        worker_init_fn=None,
+        sampler=None,
+    )
 
     # Load data in 640x480
-    hp_dataset_640x480 = PatchesDataset(root_dir=opt.test_dir, use_color=True, output_shape=(640, 480), type='all')
-    data_loader_640x480 = DataLoader(hp_dataset_640x480,
-                             batch_size=1,
-                             pin_memory=False,
-                             shuffle=False,
-                             num_workers=4,
-                             worker_init_fn=None,
-                             sampler=None)
+    hp_dataset_640x480 = PatchesDataset(
+        root_dir=opt.test_dir, use_color=True, output_shape=(640, 480), type="all"
+    )
+    data_loader_640x480 = DataLoader(
+        hp_dataset_640x480,
+        batch_size=1,
+        pin_memory=False,
+        shuffle=False,
+        num_workers=4,
+        worker_init_fn=None,
+        sampler=None,
+    )
 
     # Load model
     model = PointModel(is_test=True)
-    ckpt = torch.load('./checkpoints/PointModel_v0.pth')
-    model.load_state_dict(ckpt['model_state'])
+    ckpt = torch.load("./checkpoints/PointModel_v0.pth")
+    model.load_state_dict(ckpt["model_state"])
     model = model.eval()
     if use_gpu:
         model = model.cuda()
 
-
-    print('Evaluating in 320x240, 300 points')
+    print("Evaluating in 320x240, 300 points")
     rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net(
-        data_loader_320x240,
-        model,
-        output_shape=(320, 240),
-        top_k=300)
+        data_loader_320x240, model, output_shape=(320, 240), top_k=300
+    )
 
-    print('Repeatability: {0:.3f}'.format(rep))
-    print('Localization Error: {0:.3f}'.format(loc))
-    print('H-1 Accuracy: {:.3f}'.format(c1))
-    print('H-3 Accuracy: {:.3f}'.format(c3))
-    print('H-5 Accuracy: {:.3f}'.format(c5))
-    print('Matching Score: {:.3f}'.format(mscore))
-    print('\n')
+    print("Repeatability: {0:.3f}".format(rep))
+    print("Localization Error: {0:.3f}".format(loc))
+    print("H-1 Accuracy: {:.3f}".format(c1))
+    print("H-3 Accuracy: {:.3f}".format(c3))
+    print("H-5 Accuracy: {:.3f}".format(c5))
+    print("Matching Score: {:.3f}".format(mscore))
+    print("\n")
 
-    print('Evaluating in 640x480, 1000 points')
+    print("Evaluating in 640x480, 1000 points")
     rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net(
-        data_loader_640x480,
-        model,
-        output_shape=(640, 480),
-        top_k=1000)
+        data_loader_640x480, model, output_shape=(640, 480), top_k=1000
+    )
+
+    print("Repeatability: {0:.3f}".format(rep))
+    print("Localization Error: {0:.3f}".format(loc))
+    print("H-1 Accuracy: {:.3f}".format(c1))
+    print("H-3 Accuracy: {:.3f}".format(c3))
+    print("H-5 Accuracy: {:.3f}".format(c5))
+    print("Matching Score: {:.3f}".format(mscore))
+    print("\n")
 
-    print('Repeatability: {0:.3f}'.format(rep))
-    print('Localization Error: {0:.3f}'.format(loc))
-    print('H-1 Accuracy: {:.3f}'.format(c1))
-    print('H-3 Accuracy: {:.3f}'.format(c3))
-    print('H-5 Accuracy: {:.3f}'.format(c5))
-    print('Matching Score: {:.3f}'.format(mscore))
-    print('\n')
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/third_party/lanet/train.py b/third_party/lanet/train.py
index 3076a0fdb78a59bfd64367399c0f2b0de1297653..e82900a3b27f8954c65f7bf4127f38a65ac76fff 100644
--- a/third_party/lanet/train.py
+++ b/third_party/lanet/train.py
@@ -8,6 +8,7 @@ from torch.autograd import Variable
 from network_v0.model import PointModel
 from loss_function import KeypointLoss
 
+
 class Trainer(object):
     def __init__(self, config, train_loader=None):
         self.config = config
@@ -28,56 +29,76 @@ class Trainer(object):
         self.random_seed = config.seed
         self.gpu = config.gpu
         self.ckpt_dir = config.ckpt_dir
-        self.ckpt_name = '{}-{}'.format(config.ckpt_name, config.seed)
-		
+        self.ckpt_name = "{}-{}".format(config.ckpt_name, config.seed)
+
         # build model
         self.model = PointModel(is_test=False)
-        
+
         # training on GPU
         if self.use_gpu:
             torch.cuda.set_device(self.gpu)
             self.model.cuda()
 
-        print('Number of model parameters: {:,}'.format(sum([p.data.nelement() for p in self.model.parameters()])))	
-        
+        print(
+            "Number of model parameters: {:,}".format(
+                sum([p.data.nelement() for p in self.model.parameters()])
+            )
+        )
+
         # build loss functional
         self.loss_func = KeypointLoss(config)
-        
+
         # build optimizer and scheduler
         self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
-        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[4, 8], gamma=self.lr_factor)
+        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
+            self.optimizer, milestones=[4, 8], gamma=self.lr_factor
+        )
 
         # resume
         if int(self.config.start_epoch) > 0:
-            self.config.start_epoch, self.model, self.optimizer, self.lr_scheduler = self.load_checkpoint(int(self.config.start_epoch), self.model, self.optimizer, self.lr_scheduler)    
-    
+            (
+                self.config.start_epoch,
+                self.model,
+                self.optimizer,
+                self.lr_scheduler,
+            ) = self.load_checkpoint(
+                int(self.config.start_epoch),
+                self.model,
+                self.optimizer,
+                self.lr_scheduler,
+            )
+
     def train(self):
         print("\nTrain on {} samples".format(self.num_train))
         self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
         for epoch in range(self.start_epoch, self.max_epoch):
-            print("\nEpoch: {}/{} --lr: {:.6f}".format(epoch+1, self.max_epoch, self.lr))
+            print(
+                "\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, self.lr)
+            )
             # train for one epoch
             self.train_one_epoch(epoch)
             if self.lr_scheduler:
                 self.lr_scheduler.step()
-            self.save_checkpoint(epoch+1, self.model, self.optimizer, self.lr_scheduler)
-            
+            self.save_checkpoint(
+                epoch + 1, self.model, self.optimizer, self.lr_scheduler
+            )
+
     def train_one_epoch(self, epoch):
         self.model.train()
         for (i, data) in enumerate(tqdm(self.train_loader)):
 
             if self.use_gpu:
-                source_img = data['image_aug'].cuda()
-                target_img = data['image'].cuda()
-                homography = data['homography'].cuda()
-            
+                source_img = data["image_aug"].cuda()
+                target_img = data["image"].cuda()
+                homography = data["homography"].cuda()
+
             source_img = Variable(source_img)
             target_img = Variable(target_img)
             homography = Variable(homography)
-            
+
             # forward propagation
             output = self.model(source_img, target_img, homography)
-            
+
             # compute loss
             loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)
 
@@ -87,43 +108,45 @@ class Trainer(object):
             self.optimizer.step()
 
             # print training info
-            msg_batch = "Epoch:{} Iter:{} lr:{:.4f} "\
-                        "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "\
-                        "loss={:.4f} "\
-                        .format((epoch + 1), i, self.lr, loc_loss.data, desc_loss.data, score_loss.data, corres_loss.data, loss.data)
+            msg_batch = (
+                "Epoch:{} Iter:{} lr:{:.4f} "
+                "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
+                "loss={:.4f} ".format(
+                    (epoch + 1),
+                    i,
+                    self.lr,
+                    loc_loss.data,
+                    desc_loss.data,
+                    score_loss.data,
+                    corres_loss.data,
+                    loss.data,
+                )
+            )
 
-            if((i % self.display) == 0):
+            if (i % self.display) == 0:
                 print(msg_batch)
         return
 
     def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
-        filename = self.ckpt_name + '_' + str(epoch) + '.pth'
+        filename = self.ckpt_name + "_" + str(epoch) + ".pth"
         torch.save(
-            {'epoch': epoch,
-            'model_state': model.state_dict(),
-            'optimizer_state': optimizer.state_dict(),
-            'lr_scheduler': lr_scheduler.state_dict()},
-            os.path.join(self.ckpt_dir, filename))
+            {
+                "epoch": epoch,
+                "model_state": model.state_dict(),
+                "optimizer_state": optimizer.state_dict(),
+                "lr_scheduler": lr_scheduler.state_dict(),
+            },
+            os.path.join(self.ckpt_dir, filename),
+        )
 
     def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
-        filename = self.ckpt_name + '_' + str(epoch) + '.pth'
+        filename = self.ckpt_name + "_" + str(epoch) + ".pth"
         ckpt = torch.load(os.path.join(self.ckpt_dir, filename))
-        epoch = ckpt['epoch']
-        model.load_state_dict(ckpt['model_state'])
-        optimizer.load_state_dict(ckpt['optimizer_state'])
-        lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
-
-        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt['epoch']))
-
-        return epoch, model, optimizer, lr_scheduler				        
-        
-        
-        
-        
-        
-        
-        
-        
-        
-        
-        
\ No newline at end of file
+        epoch = ckpt["epoch"]
+        model.load_state_dict(ckpt["model_state"])
+        optimizer.load_state_dict(ckpt["optimizer_state"])
+        lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
+
+        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))
+
+        return epoch, model, optimizer, lr_scheduler
diff --git a/third_party/lanet/utils.py b/third_party/lanet/utils.py
index d5422ebcfc2847be047391791d891a09388ca7d1..6f1ead467c166a95e6782a8112bafe363f948f9b 100644
--- a/third_party/lanet/utils.py
+++ b/third_party/lanet/utils.py
@@ -4,6 +4,7 @@ import torch
 import torchvision.transforms as transforms
 from functools import lru_cache
 
+
 @lru_cache(maxsize=None)
 def meshgrid(B, H, W, dtype, device, normalized=False):
     """
@@ -35,8 +36,8 @@ def meshgrid(B, H, W, dtype, device, normalized=False):
         xs = torch.linspace(-1, 1, W, device=device, dtype=dtype)
         ys = torch.linspace(-1, 1, H, device=device, dtype=dtype)
     else:
-        xs = torch.linspace(0, W-1, W, device=device, dtype=dtype)
-        ys = torch.linspace(0, H-1, H, device=device, dtype=dtype)
+        xs = torch.linspace(0, W - 1, W, device=device, dtype=dtype)
+        ys = torch.linspace(0, H - 1, H, device=device, dtype=dtype)
     ys, xs = torch.meshgrid([ys, xs])
     return xs.repeat([B, 1, 1]), ys.repeat([B, 1, 1])
 
@@ -75,7 +76,8 @@ def image_grid(B, H, W, dtype, device, ones=True, normalized=False):
     grid = torch.stack(coords, dim=1)  # B3HW
     return grid
 
-def to_tensor_sample(sample, tensor_type='torch.FloatTensor'):
+
+def to_tensor_sample(sample, tensor_type="torch.FloatTensor"):
     """
     Casts the keys of sample to tensors. From https://github.com/TRI-ML/KP2D.
 
@@ -92,11 +94,11 @@ def to_tensor_sample(sample, tensor_type='torch.FloatTensor'):
         Sample with keys cast as tensors
     """
     transform = transforms.ToTensor()
-    sample['image'] = transform(sample['image']).type(tensor_type)
+    sample["image"] = transform(sample["image"]).type(tensor_type)
     return sample
 
+
 def prepare_dirs(config):
     for path in [config.ckpt_dir]:
         if not os.path.exists(path):
             os.makedirs(path)
-
diff --git a/third_party/r2d2/datasets/__init__.py b/third_party/r2d2/datasets/__init__.py
index 8f11df21be72856ea365f6efd7a389aba267562b..f538fb5372197bcdba9db28c861af39c541539ee 100644
--- a/third_party/r2d2/datasets/__init__.py
+++ b/third_party/r2d2/datasets/__init__.py
@@ -10,6 +10,7 @@ from .aachen import *
 
 # try to instantiate datasets
 import sys
+
 try:
     web_images = RandomWebImages(0, 52)
 except AssertionError as e:
@@ -23,11 +24,12 @@ except AssertionError as e:
 try:
     aachen_style_transfer_pairs = AachenPairs_StyleTransferDayNight()
 except AssertionError as e:
-    print(f"Dataset aachen_style_transfer_pairs not available, reason: {e}", file=sys.stderr)
+    print(
+        f"Dataset aachen_style_transfer_pairs not available, reason: {e}",
+        file=sys.stderr,
+    )
 
 try:
     aachen_flow_pairs = AachenPairs_OpticalFlow()
 except AssertionError as e:
     print(f"Dataset aachen_flow_pairs not available, reason: {e}", file=sys.stderr)
-
-
diff --git a/third_party/r2d2/datasets/aachen.py b/third_party/r2d2/datasets/aachen.py
index 4ddb324cea01da2430ee89b32c7627b34c01a41f..fbe2364a51c648ee48989f1725cf0033cd0c0547 100644
--- a/third_party/r2d2/datasets/aachen.py
+++ b/third_party/r2d2/datasets/aachen.py
@@ -10,61 +10,61 @@ from .dataset import Dataset
 from .pair_dataset import PairDataset, StillPairDataset
 
 
-class AachenImages (Dataset):
-    """ Loads all images from the Aachen Day-Night dataset 
-    """
-    def __init__(self, select='db day night', root='data/aachen'):
+class AachenImages(Dataset):
+    """Loads all images from the Aachen Day-Night dataset"""
+
+    def __init__(self, select="db day night", root="data/aachen"):
         Dataset.__init__(self)
         self.root = root
-        self.img_dir = 'images_upright'
+        self.img_dir = "images_upright"
         self.select = set(select.split())
-        assert self.select, 'Nothing was selected'
-        
+        assert self.select, "Nothing was selected"
+
         self.imgs = []
         root = os.path.join(root, self.img_dir)
         for dirpath, _, filenames in os.walk(root):
-            r = dirpath[len(root)+1:]
-            if not(self.select & set(r.split('/'))): continue
-            self.imgs += [os.path.join(r,f) for f in filenames if f.endswith('.jpg')]
-        
+            r = dirpath[len(root) + 1 :]
+            if not (self.select & set(r.split("/"))):
+                continue
+            self.imgs += [os.path.join(r, f) for f in filenames if f.endswith(".jpg")]
+
         self.nimg = len(self.imgs)
-        assert self.nimg, 'Empty Aachen dataset'
+        assert self.nimg, "Empty Aachen dataset"
 
     def get_key(self, idx):
         return self.imgs[idx]
 
 
+class AachenImages_DB(AachenImages):
+    """Only database (db) images."""
 
-class AachenImages_DB (AachenImages):
-    """ Only database (db) images.
-    """
     def __init__(self, **kw):
-        AachenImages.__init__(self, select='db', **kw)
-        self.db_image_idxs = {self.get_tag(i) : i for i,f in enumerate(self.imgs)}
-    
-    def get_tag(self, idx): 
-        # returns image tag == img number (name)
-        return os.path.split( self.imgs[idx][:-4] )[1]
+        AachenImages.__init__(self, select="db", **kw)
+        self.db_image_idxs = {self.get_tag(i): i for i, f in enumerate(self.imgs)}
 
+    def get_tag(self, idx):
+        # returns image tag == img number (name)
+        return os.path.split(self.imgs[idx][:-4])[1]
 
 
-class AachenPairs_StyleTransferDayNight (AachenImages_DB, StillPairDataset):
-    """ synthetic day-night pairs of images 
-        (night images obtained using autoamtic style transfer from web night images)
+class AachenPairs_StyleTransferDayNight(AachenImages_DB, StillPairDataset):
+    """synthetic day-night pairs of images
+    (night images obtained using automatic style transfer from web night images)
     """
-    def __init__(self, root='data/aachen/style_transfer', **kw):
+
+    def __init__(self, root="data/aachen/style_transfer", **kw):
         StillPairDataset.__init__(self)
         AachenImages_DB.__init__(self, **kw)
         old_root = os.path.join(self.root, self.img_dir)
         self.root = os.path.commonprefix((old_root, root))
-        self.img_dir = ''
+        self.img_dir = ""
 
-        newpath = lambda folder, f: os.path.join(folder, f)[len(self.root):]
+        newpath = lambda folder, f: os.path.join(folder, f)[len(self.root) :]
         self.imgs = [newpath(old_root, f) for f in self.imgs]
 
         self.image_pairs = []
         for fname in os.listdir(root):
-            tag = fname.split('.jpg.st_')[0]
+            tag = fname.split(".jpg.st_")[0]
             self.image_pairs.append((self.db_image_idxs[tag], len(self.imgs)))
             self.imgs.append(newpath(root, fname))
 
@@ -73,42 +73,45 @@ class AachenPairs_StyleTransferDayNight (AachenImages_DB, StillPairDataset):
         assert self.nimg and self.npairs
 
 
+class AachenPairs_OpticalFlow(AachenImages_DB, PairDataset):
+    """Image pairs from Aachen db with optical flow."""
 
-class AachenPairs_OpticalFlow (AachenImages_DB, PairDataset):
-    """ Image pairs from Aachen db with optical flow.
-    """
-    def __init__(self, root='data/aachen/optical_flow', **kw):
+    def __init__(self, root="data/aachen/optical_flow", **kw):
         PairDataset.__init__(self)
         AachenImages_DB.__init__(self, **kw)
         self.root_flow = root
 
         # find out the subsest of valid pairs from the list of flow files
-        flows = {f for f in os.listdir(os.path.join(root, 'flow')) if f.endswith('.png')}
-        masks = {f for f in os.listdir(os.path.join(root, 'mask')) if f.endswith('.png')}
-        assert flows == masks, 'Missing flow or mask pairs'
-        
-        make_pair = lambda f: tuple(self.db_image_idxs[v] for v in f[:-4].split('_'))
+        flows = {
+            f for f in os.listdir(os.path.join(root, "flow")) if f.endswith(".png")
+        }
+        masks = {
+            f for f in os.listdir(os.path.join(root, "mask")) if f.endswith(".png")
+        }
+        assert flows == masks, "Missing flow or mask pairs"
+
+        make_pair = lambda f: tuple(self.db_image_idxs[v] for v in f[:-4].split("_"))
         self.image_pairs = [make_pair(f) for f in flows]
         self.npairs = len(self.image_pairs)
         assert self.nimg and self.npairs
 
     def get_mask_filename(self, pair_idx):
         tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx])
-        return os.path.join(self.root_flow, 'mask', f'{tag_a}_{tag_b}.png')
+        return os.path.join(self.root_flow, "mask", f"{tag_a}_{tag_b}.png")
 
     def get_mask(self, pair_idx):
         return np.asarray(Image.open(self.get_mask_filename(pair_idx)))
 
     def get_flow_filename(self, pair_idx):
         tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx])
-        return os.path.join(self.root_flow, 'flow', f'{tag_a}_{tag_b}.png')
+        return os.path.join(self.root_flow, "flow", f"{tag_a}_{tag_b}.png")
 
     def get_flow(self, pair_idx):
         fname = self.get_flow_filename(pair_idx)
         try:
             return self._png2flow(fname)
         except IOError:
-            flow = open(fname[:-4], 'rb')
+            flow = open(fname[:-4], "rb")
             help = np.fromfile(flow, np.float32, 1)
             assert help == 202021.25
             W, H = np.fromfile(flow, np.int32, 2)
@@ -116,30 +119,28 @@ class AachenPairs_OpticalFlow (AachenImages_DB, PairDataset):
             return self._flow2png(flow, fname)
 
     def get_pair(self, idx, output=()):
-        if isinstance(output, str): 
+        if isinstance(output, str):
             output = output.split()
 
         img1, img2 = map(self.get_image, self.image_pairs[idx])
         meta = {}
-        
-        if 'flow' in output or 'aflow' in output:
+
+        if "flow" in output or "aflow" in output:
             flow = self.get_flow(idx)
             assert flow.shape[:2] == img1.size[::-1]
-            meta['flow'] = flow
+            meta["flow"] = flow
             H, W = flow.shape[:2]
-            meta['aflow'] = flow + np.mgrid[:H,:W][::-1].transpose(1,2,0)
-        
-        if 'mask' in output:
+            meta["aflow"] = flow + np.mgrid[:H, :W][::-1].transpose(1, 2, 0)
+
+        if "mask" in output:
             mask = self.get_mask(idx)
             assert mask.shape[:2] == img1.size[::-1]
-            meta['mask'] = mask
-        
-        return img1, img2, meta
-
+            meta["mask"] = mask
 
+        return img1, img2, meta
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     print(aachen_db_images)
     print(aachen_style_transfer_pairs)
     print(aachen_flow_pairs)
diff --git a/third_party/r2d2/datasets/dataset.py b/third_party/r2d2/datasets/dataset.py
index 80d893b8ea4ead7845f35c4fe82c9f5a9b849de3..5f4474e7dc8b81f091cac1e13f431c5c9f1840f3 100644
--- a/third_party/r2d2/datasets/dataset.py
+++ b/third_party/r2d2/datasets/dataset.py
@@ -9,10 +9,10 @@ import numpy as np
 
 
 class Dataset(object):
-    ''' Base class for a dataset. To be overloaded.
-    '''
-    root = ''
-    img_dir = ''
+    """Base class for a dataset. To be overloaded."""
+
+    root = ""
+    img_dir = ""
     nimg = 0
 
     def __len__(self):
@@ -26,23 +26,23 @@ class Dataset(object):
 
     def get_image(self, img_idx):
         from PIL import Image
+
         fname = self.get_filename(img_idx)
         try:
-            return Image.open(fname).convert('RGB')
+            return Image.open(fname).convert("RGB")
         except Exception as e:
             raise IOError("Could not load image %s (reason: %s)" % (fname, str(e)))
 
     def __repr__(self):
-        res =  'Dataset: %s\n' % self.__class__.__name__
-        res += '  %d images' % self.nimg
-        res += '\n  root: %s...\n' % self.root
+        res = "Dataset: %s\n" % self.__class__.__name__
+        res += "  %d images" % self.nimg
+        res += "\n  root: %s...\n" % self.root
         return res
 
 
+class CatDataset(Dataset):
+    """Concatenation of several datasets."""
 
-class CatDataset (Dataset):
-    ''' Concatenation of several datasets.
-    '''
     def __init__(self, *datasets):
         assert len(datasets) >= 1
         self.datasets = datasets
@@ -54,8 +54,8 @@ class CatDataset (Dataset):
         self.root = None
 
     def which(self, i):
-        pos = np.searchsorted(self.offsets, i, side='right')-1
-        assert pos < self.nimg, 'Bad image index %d >= %d' % (i, self.nimg)
+        pos = np.searchsorted(self.offsets, i, side="right") - 1
+        assert pos < self.nimg, "Bad image index %d >= %d" % (i, self.nimg)
         return pos, i - self.offsets[pos]
 
     def get_key(self, i):
@@ -69,9 +69,5 @@ class CatDataset (Dataset):
     def __repr__(self):
         fmt_str = "CatDataset("
         for db in self.datasets:
-            fmt_str += str(db).replace("\n"," ") + ', '
-        return fmt_str[:-2] + ')'
-
-
-
-
+            fmt_str += str(db).replace("\n", " ") + ", "
+        return fmt_str[:-2] + ")"
diff --git a/third_party/r2d2/datasets/imgfolder.py b/third_party/r2d2/datasets/imgfolder.py
index 45f7bc9ee4c3ba5f04380dbc02ad17b6463cf32f..40168f00e8ad177f3d94f75578dba2e640944c4c 100644
--- a/third_party/r2d2/datasets/imgfolder.py
+++ b/third_party/r2d2/datasets/imgfolder.py
@@ -8,10 +8,10 @@ from .dataset import Dataset
 from .pair_dataset import SyntheticPairDataset
 
 
-class ImgFolder (Dataset):
-    """ load all images in a folder (no recursion).
-    """
-    def __init__(self, root, imgs=None, exts=('.jpg','.png','.ppm')):
+class ImgFolder(Dataset):
+    """load all images in a folder (no recursion)."""
+
+    def __init__(self, root, imgs=None, exts=(".jpg", ".png", ".ppm")):
         Dataset.__init__(self)
         self.root = root
         self.imgs = imgs or [f for f in os.listdir(root) if f.endswith(exts)]
@@ -19,5 +19,3 @@ class ImgFolder (Dataset):
 
     def get_key(self, idx):
         return self.imgs[idx]
-
-
diff --git a/third_party/r2d2/datasets/pair_dataset.py b/third_party/r2d2/datasets/pair_dataset.py
index aeed98b6700e0ba108bb44abccc20351d16f3295..ba178c18a0a6fbb1decfe4a797dbcab0636dbeaf 100644
--- a/third_party/r2d2/datasets/pair_dataset.py
+++ b/third_party/r2d2/datasets/pair_dataset.py
@@ -11,20 +11,24 @@ from tools.transforms import instanciate_transformation
 from tools.transforms_tools import persp_apply
 
 
-class PairDataset (Dataset):
-    """ A dataset that serves image pairs with ground-truth pixel correspondences.
-    """
+class PairDataset(Dataset):
+    """A dataset that serves image pairs with ground-truth pixel correspondences."""
+
     def __init__(self):
         Dataset.__init__(self)
         self.npairs = 0
 
     def get_filename(self, img_idx, root=None):
-        if is_pair(img_idx): # if img_idx is a pair of indices, we return a pair of filenames
+        if is_pair(
+            img_idx
+        ):  # if img_idx is a pair of indices, we return a pair of filenames
             return tuple(Dataset.get_filename(self, i, root) for i in img_idx)
         return Dataset.get_filename(self, img_idx, root)
 
     def get_image(self, img_idx):
-        if is_pair(img_idx): # if img_idx is a pair of indices, we return a pair of images
+        if is_pair(
+            img_idx
+        ):  # if img_idx is a pair of indices, we return a pair of images
             return tuple(Dataset.get_image(self, i) for i in img_idx)
         return Dataset.get_image(self, img_idx)
 
@@ -41,8 +45,8 @@ class PairDataset (Dataset):
         raise NotImplementedError()
 
     def get_pair(self, idx, output=()):
-        """ returns (img1, img2, `metadata`)
-        
+        """returns (img1, img2, `metadata`)
+
         `metadata` is a dict() that can contain:
             flow: optical flow
             aflow: absolute flow
@@ -55,24 +59,24 @@ class PairDataset (Dataset):
     def get_paired_images(self):
         fns = set()
         for i in range(self.npairs):
-            a,b = self.image_pairs[i]
+            a, b = self.image_pairs[i]
             fns.add(self.get_filename(a))
             fns.add(self.get_filename(b))
         return fns
 
     def __len__(self):
-        return self.npairs # size should correspond to the number of pairs, not images
-    
+        return self.npairs  # size should correspond to the number of pairs, not images
+
     def __repr__(self):
-        res =  'Dataset: %s\n' % self.__class__.__name__
-        res += '  %d images,' % self.nimg
-        res += ' %d image pairs' % self.npairs
-        res += '\n  root: %s...\n' % self.root
+        res = "Dataset: %s\n" % self.__class__.__name__
+        res += "  %d images," % self.nimg
+        res += " %d image pairs" % self.npairs
+        res += "\n  root: %s...\n" % self.root
         return res
 
     @staticmethod
     def _flow2png(flow, path):
-        flow = np.clip(np.around(16*flow), -2**15, 2**15-1)
+        flow = np.clip(np.around(16 * flow), -(2**15), 2**15 - 1)
         bytes = np.int16(flow).view(np.uint8)
         Image.fromarray(bytes).save(path)
         return flow / 16
@@ -86,41 +90,42 @@ class PairDataset (Dataset):
             raise IOError("Error loading flow for %s" % path)
 
 
-
-class StillPairDataset (PairDataset):
-    """ A dataset of 'still' image pairs.
-        By overloading a normal image dataset, it appends the get_pair(i) function
-        that serves trivial image pairs (img1, img2) where img1 == img2 == get_image(i).
+class StillPairDataset(PairDataset):
+    """A dataset of 'still' image pairs.
+    By overloading a normal image dataset, it appends the get_pair(i) function
+    that serves trivial image pairs (img1, img2) where img1 == img2 == get_image(i).
     """
+
     def get_pair(self, pair_idx, output=()):
-        if isinstance(output, str): output = output.split()
+        if isinstance(output, str):
+            output = output.split()
         img1, img2 = map(self.get_image, self.image_pairs[pair_idx])
 
-        W,H = img1.size
+        W, H = img1.size
         sx = img2.size[0] / float(W)
         sy = img2.size[1] / float(H)
 
         meta = {}
-        if 'aflow' in output or 'flow' in output:
-            mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1,2,0).astype(np.float32)
-            meta['aflow'] = mgrid * (sx,sy)
-            meta['flow'] = meta['aflow'] - mgrid
+        if "aflow" in output or "flow" in output:
+            mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1, 2, 0).astype(np.float32)
+            meta["aflow"] = mgrid * (sx, sy)
+            meta["flow"] = meta["aflow"] - mgrid
 
-        if 'mask' in output:
-            meta['mask'] = np.ones((H,W), np.uint8)
+        if "mask" in output:
+            meta["mask"] = np.ones((H, W), np.uint8)
 
-        if 'homography' in output:
-            meta['homography'] = np.diag(np.float32([sx, sy, 1]))
+        if "homography" in output:
+            meta["homography"] = np.diag(np.float32([sx, sy, 1]))
 
         return img1, img2, meta
 
 
-
-class SyntheticPairDataset (PairDataset):
-    """ A synthetic generator of image pairs.
-        Given a normal image dataset, it constructs pairs using random homographies & noise.
+class SyntheticPairDataset(PairDataset):
+    """A synthetic generator of image pairs.
+    Given a normal image dataset, it constructs pairs using random homographies & noise.
     """
-    def __init__(self, dataset, scale='', distort=''):
+
+    def __init__(self, dataset, scale="", distort=""):
         self.attach_dataset(dataset)
         self.distort = instanciate_transformation(distort)
         self.scale = instanciate_transformation(scale)
@@ -133,56 +138,57 @@ class SyntheticPairDataset (PairDataset):
         self.get_key = dataset.get_key
         self.get_filename = dataset.get_filename
         self.root = None
-        
+
     def make_pair(self, img):
         return img, img
 
-    def get_pair(self, i, output=('aflow')):
-        """ Procedure:
-        This function applies a series of random transformations to one original image 
+    def get_pair(self, i, output=("aflow")):
+        """Procedure:
+        This function applies a series of random transformations to one original image
         to form a synthetic image pairs with perfect ground-truth.
         """
-        if isinstance(output, str): 
+        if isinstance(output, str):
             output = output.split()
-            
+
         original_img = self.dataset.get_image(i)
-        
+
         scaled_image = self.scale(original_img)
         scaled_image, scaled_image2 = self.make_pair(scaled_image)
         scaled_and_distorted_image = self.distort(
-            dict(img=scaled_image2, persp=(1,0,0,0,1,0,0,0)))
+            dict(img=scaled_image2, persp=(1, 0, 0, 0, 1, 0, 0, 0))
+        )
         W, H = scaled_image.size
-        trf = scaled_and_distorted_image['persp']
+        trf = scaled_and_distorted_image["persp"]
 
         meta = dict()
-        if 'aflow' in output or 'flow' in output:
+        if "aflow" in output or "flow" in output:
             # compute optical flow
-            xy = np.mgrid[0:H,0:W][::-1].reshape(2,H*W).T
-            aflow = np.float32(persp_apply(trf, xy).reshape(H,W,2))
-            meta['flow'] = aflow - xy.reshape(H,W,2)
-            meta['aflow'] = aflow
-        
-        if 'homography' in output:
-            meta['homography'] = np.float32(trf+(1,)).reshape(3,3)
-
-        return scaled_image, scaled_and_distorted_image['img'], meta
-    
-    def __repr__(self):
-        res =  'Dataset: %s\n' % self.__class__.__name__
-        res += '  %d images and pairs' % self.npairs
-        res += '\n  root: %s...' % self.dataset.root
-        res += '\n  Scale: %s' % (repr(self.scale).replace('\n',''))
-        res += '\n  Distort: %s' % (repr(self.distort).replace('\n',''))
-        return res + '\n'
+            xy = np.mgrid[0:H, 0:W][::-1].reshape(2, H * W).T
+            aflow = np.float32(persp_apply(trf, xy).reshape(H, W, 2))
+            meta["flow"] = aflow - xy.reshape(H, W, 2)
+            meta["aflow"] = aflow
 
+        if "homography" in output:
+            meta["homography"] = np.float32(trf + (1,)).reshape(3, 3)
 
+        return scaled_image, scaled_and_distorted_image["img"], meta
 
-class TransformedPairs (PairDataset):
-    """ Automatic data augmentation for pre-existing image pairs.
-        Given an image pair dataset, it generates synthetically jittered pairs
-        using random transformations (e.g. homographies & noise).
+    def __repr__(self):
+        res = "Dataset: %s\n" % self.__class__.__name__
+        res += "  %d images and pairs" % self.npairs
+        res += "\n  root: %s..." % self.dataset.root
+        res += "\n  Scale: %s" % (repr(self.scale).replace("\n", ""))
+        res += "\n  Distort: %s" % (repr(self.distort).replace("\n", ""))
+        return res + "\n"
+
+
+class TransformedPairs(PairDataset):
+    """Automatic data augmentation for pre-existing image pairs.
+    Given an image pair dataset, it generates synthetically jittered pairs
+    using random transformations (e.g. homographies & noise).
     """
-    def __init__(self, dataset, trf=''):
+
+    def __init__(self, dataset, trf=""):
         self.attach_dataset(dataset)
         self.trf = instanciate_transformation(trf)
 
@@ -195,48 +201,47 @@ class TransformedPairs (PairDataset):
         self.get_key = dataset.get_key
         self.get_filename = dataset.get_filename
         self.root = None
-        
-    def get_pair(self, i, output=''):
-        """ Procedure:
-        This function applies a series of random transformations to one original image 
+
+    def get_pair(self, i, output=""):
+        """Procedure:
+        This function applies a series of random transformations to one original image
         to form a synthetic image pairs with perfect ground-truth.
         """
         img_a, img_b_, metadata = self.dataset.get_pair(i, output)
 
-        img_b = self.trf({'img': img_b_, 'persp':(1,0,0,0,1,0,0,0)})
-        trf = img_b['persp']
+        img_b = self.trf({"img": img_b_, "persp": (1, 0, 0, 0, 1, 0, 0, 0)})
+        trf = img_b["persp"]
 
-        if 'aflow' in metadata or 'flow' in metadata:
-            aflow = metadata['aflow']
-            aflow[:] = persp_apply(trf, aflow.reshape(-1,2)).reshape(aflow.shape)
+        if "aflow" in metadata or "flow" in metadata:
+            aflow = metadata["aflow"]
+            aflow[:] = persp_apply(trf, aflow.reshape(-1, 2)).reshape(aflow.shape)
             W, H = img_a.size
-            flow = metadata['flow']
-            mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1,2,0).astype(np.float32)
+            flow = metadata["flow"]
+            mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1, 2, 0).astype(np.float32)
             flow[:] = aflow - mgrid
 
-        if 'corres' in metadata:
-            corres = metadata['corres']
-            corres[:,1] = persp_apply(trf, corres[:,1])
-        
-        if 'homography' in metadata:
+        if "corres" in metadata:
+            corres = metadata["corres"]
+            corres[:, 1] = persp_apply(trf, corres[:, 1])
+
+        if "homography" in metadata:
             # p_b = homography * p_a
-            trf_ = np.float32(trf+(1,)).reshape(3,3)
-            metadata['homography'] = np.float32(trf_ @ metadata['homography'])
+            trf_ = np.float32(trf + (1,)).reshape(3, 3)
+            metadata["homography"] = np.float32(trf_ @ metadata["homography"])
 
-        return img_a, img_b['img'], metadata
+        return img_a, img_b["img"], metadata
 
     def __repr__(self):
-        res =  'Transformed Pairs from %s\n' % type(self.dataset).__name__
-        res += '  %d images and pairs' % self.npairs
-        res += '\n  root: %s...' % self.dataset.root
-        res += '\n  transform: %s' % (repr(self.trf).replace('\n',''))
-        return res + '\n'
+        res = "Transformed Pairs from %s\n" % type(self.dataset).__name__
+        res += "  %d images and pairs" % self.npairs
+        res += "\n  root: %s..." % self.dataset.root
+        res += "\n  transform: %s" % (repr(self.trf).replace("\n", ""))
+        return res + "\n"
 
 
+class CatPairDataset(CatDataset):
+    """Concatenation of several pair datasets."""
 
-class CatPairDataset (CatDataset):
-    ''' Concatenation of several pair datasets.
-    '''
     def __init__(self, *datasets):
         CatDataset.__init__(self, *datasets)
         pair_offsets = [0]
@@ -251,12 +256,12 @@ class CatPairDataset (CatDataset):
     def __repr__(self):
         fmt_str = "CatPairDataset("
         for db in self.datasets:
-            fmt_str += str(db).replace("\n"," ") + ', '
-        return fmt_str[:-2] + ')'
+            fmt_str += str(db).replace("\n", " ") + ", "
+        return fmt_str[:-2] + ")"
 
     def pair_which(self, i):
-        pos = np.searchsorted(self.pair_offsets, i, side='right')-1
-        assert pos < self.npairs, 'Bad pair index %d >= %d' % (i, self.npairs)
+        pos = np.searchsorted(self.pair_offsets, i, side="right") - 1
+        assert pos < self.npairs, "Bad pair index %d >= %d" % (i, self.npairs)
         return pos, i - self.pair_offsets[pos]
 
     def pair_call(self, func, i, *args, **kwargs):
@@ -268,20 +273,18 @@ class CatPairDataset (CatDataset):
         return self.datasets[b].get_pair(i, output)
 
     def get_flow_filename(self, pair_idx, *args, **kwargs):
-        return self.pair_call('get_flow_filename', pair_idx, *args, **kwargs)
+        return self.pair_call("get_flow_filename", pair_idx, *args, **kwargs)
 
     def get_mask_filename(self, pair_idx, *args, **kwargs):
-        return self.pair_call('get_mask_filename', pair_idx, *args, **kwargs)
+        return self.pair_call("get_mask_filename", pair_idx, *args, **kwargs)
 
     def get_corres_filename(self, pair_idx, *args, **kwargs):
-        return self.pair_call('get_corres_filename', pair_idx, *args, **kwargs)
-
+        return self.pair_call("get_corres_filename", pair_idx, *args, **kwargs)
 
 
 def is_pair(x):
-    if isinstance(x, (tuple,list)) and len(x) == 2:
+    if isinstance(x, (tuple, list)) and len(x) == 2:
         return True
     if isinstance(x, np.ndarray) and x.ndim == 1 and x.shape[0] == 2:
         return True
     return False
-
diff --git a/third_party/r2d2/datasets/web_images.py b/third_party/r2d2/datasets/web_images.py
index 7c17fbe956f3b4db25d9a4148e8f7c615f122478..f22580f44a9b2488980ab88b656073d8531c3362 100644
--- a/third_party/r2d2/datasets/web_images.py
+++ b/third_party/r2d2/datasets/web_images.py
@@ -8,42 +8,47 @@ from tqdm import trange
 from .dataset import Dataset
 
 
-class RandomWebImages (Dataset):
-    """ 1 million distractors from Oxford and Paris Revisited
-        see http://ptak.felk.cvut.cz/revisitop/revisitop1m/
+class RandomWebImages(Dataset):
+    """1 million distractors from Oxford and Paris Revisited
+    see http://ptak.felk.cvut.cz/revisitop/revisitop1m/
     """
+
     def __init__(self, start=0, end=1024, root="data/revisitop1m"):
         Dataset.__init__(self)
         self.root = root
-        
+
         bar = None
-        self.imgs  = []
+        self.imgs = []
         for i in range(start, end):
-            try: 
+            try:
                 # read cached list
-                img_list_path = os.path.join(self.root, "image_list_%d.txt"%i) 
+                img_list_path = os.path.join(self.root, "image_list_%d.txt" % i)
                 cached_imgs = [e.strip() for e in open(img_list_path)]
                 assert cached_imgs, f"Cache '{img_list_path}' is empty!"
                 self.imgs += cached_imgs
 
             except IOError:
-                if bar is None: 
-                    bar = trange(start, 4*end, desc='Caching')
-                    bar.update(4*i)
-                
+                if bar is None:
+                    bar = trange(start, 4 * end, desc="Caching")
+                    bar.update(4 * i)
+
                 # create it
                 imgs = []
-                for d in range(i*4,(i+1)*4): # 4096 folders in total, on average 256 each
+                for d in range(
+                    i * 4, (i + 1) * 4
+                ):  # 4096 folders in total, on average 256 each
                     key = hex(d)[2:].zfill(3)
                     folder = os.path.join(self.root, key)
-                    if not os.path.isdir(folder): continue
-                    imgs += [f for f in os.listdir(folder) if verify_img(folder,f)]
+                    if not os.path.isdir(folder):
+                        continue
+                    imgs += [f for f in os.listdir(folder) if verify_img(folder, f)]
                     bar.update(1)
                 assert imgs, f"No images found in {folder}/"
-                open(img_list_path,'w').write('\n'.join(imgs))
+                open(img_list_path, "w").write("\n".join(imgs))
                 self.imgs += imgs
 
-        if bar: bar.update(bar.total - bar.n)
+        if bar:
+            bar.update(bar.total - bar.n)
         self.nimg = len(self.imgs)
 
     def get_key(self, i):
@@ -53,12 +58,12 @@ class RandomWebImages (Dataset):
 
 def verify_img(folder, f):
     path = os.path.join(folder, f)
-    if not f.endswith('.jpg'): return False
-    try: 
+    if not f.endswith(".jpg"):
+        return False
+    try:
         from PIL import Image
-        Image.open(path).convert('RGB') # try to open it
+
+        Image.open(path).convert("RGB")  # try to open it
         return True
-    except: 
+    except:
         return False
-
-
diff --git a/third_party/r2d2/extract.py b/third_party/r2d2/extract.py
index c3fea02f87c0615504e3648bfd590e413ab13898..14f6d5cf4899bb5abccbb91ca324d264d4c27d7f 100644
--- a/third_party/r2d2/extract.py
+++ b/third_party/r2d2/extract.py
@@ -13,97 +13,105 @@ from tools.dataloader import norm_RGB
 from nets.patchnet import *
 
 
-def load_network(model_fn): 
+def load_network(model_fn):
     checkpoint = torch.load(model_fn)
-    print("\n>> Creating net = " + checkpoint['net']) 
-    net = eval(checkpoint['net'])
+    print("\n>> Creating net = " + checkpoint["net"])
+    net = eval(checkpoint["net"])
     nb_of_weights = common.model_size(net)
     print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )")
 
     # initialization
-    weights = checkpoint['state_dict']
-    net.load_state_dict({k.replace('module.',''):v for k,v in weights.items()})
+    weights = checkpoint["state_dict"]
+    net.load_state_dict({k.replace("module.", ""): v for k, v in weights.items()})
     return net.eval()
 
 
-class NonMaxSuppression (torch.nn.Module):
+class NonMaxSuppression(torch.nn.Module):
     def __init__(self, rel_thr=0.7, rep_thr=0.7):
         nn.Module.__init__(self)
         self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
         self.rel_thr = rel_thr
         self.rep_thr = rep_thr
-    
+
     def forward(self, reliability, repeatability, **kw):
         assert len(reliability) == len(repeatability) == 1
         reliability, repeatability = reliability[0], repeatability[0]
 
         # local maxima
-        maxima = (repeatability == self.max_filter(repeatability))
+        maxima = repeatability == self.max_filter(repeatability)
 
         # remove low peaks
-        maxima *= (repeatability >= self.rep_thr)
-        maxima *= (reliability   >= self.rel_thr)
+        maxima *= repeatability >= self.rep_thr
+        maxima *= reliability >= self.rel_thr
 
         return maxima.nonzero().t()[2:4]
 
 
-def extract_multiscale( net, img, detector, scale_f=2**0.25, 
-                        min_scale=0.0, max_scale=1, 
-                        min_size=256, max_size=1024, 
-                        verbose=False):
-    old_bm = torch.backends.cudnn.benchmark 
-    torch.backends.cudnn.benchmark = False # speedup
-    
+def extract_multiscale(
+    net,
+    img,
+    detector,
+    scale_f=2**0.25,
+    min_scale=0.0,
+    max_scale=1,
+    min_size=256,
+    max_size=1024,
+    verbose=False,
+):
+    old_bm = torch.backends.cudnn.benchmark
+    torch.backends.cudnn.benchmark = False  # speedup
+
     # extract keypoints at multiple scales
     B, three, H, W = img.shape
     assert B == 1 and three == 3, "should be a batch with a single RGB image"
-    
+
     assert max_scale <= 1
-    s = 1.0 # current scale factor
-    
-    X,Y,S,C,Q,D = [],[],[],[],[],[]
-    while  s+0.001 >= max(min_scale, min_size / max(H,W)):
-        if s-0.001 <= min(max_scale, max_size / max(H,W)):
+    s = 1.0  # current scale factor
+
+    X, Y, S, C, Q, D = [], [], [], [], [], []
+    while s + 0.001 >= max(min_scale, min_size / max(H, W)):
+        if s - 0.001 <= min(max_scale, max_size / max(H, W)):
             nh, nw = img.shape[2:]
-            if verbose: print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")
+            if verbose:
+                print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")
             # extract descriptors
             with torch.no_grad():
                 res = net(imgs=[img])
-                
+
             # get output and reliability map
-            descriptors = res['descriptors'][0]
-            reliability = res['reliability'][0]
-            repeatability = res['repeatability'][0]
+            descriptors = res["descriptors"][0]
+            reliability = res["reliability"][0]
+            repeatability = res["repeatability"][0]
 
             # normalize the reliability for nms
             # extract maxima and descs
-            y,x = detector(**res) # nms
-            c = reliability[0,0,y,x]
-            q = repeatability[0,0,y,x]
-            d = descriptors[0,:,y,x].t()
+            y, x = detector(**res)  # nms
+            c = reliability[0, 0, y, x]
+            q = repeatability[0, 0, y, x]
+            d = descriptors[0, :, y, x].t()
             n = d.shape[0]
 
             # accumulate multiple scales
-            X.append(x.float() * W/nw)
-            Y.append(y.float() * H/nh)
-            S.append((32/s) * torch.ones(n, dtype=torch.float32, device=d.device))
+            X.append(x.float() * W / nw)
+            Y.append(y.float() * H / nh)
+            S.append((32 / s) * torch.ones(n, dtype=torch.float32, device=d.device))
             C.append(c)
             Q.append(q)
             D.append(d)
         s /= scale_f
 
         # down-scale the image for next iteration
-        nh, nw = round(H*s), round(W*s)
-        img = F.interpolate(img, (nh,nw), mode='bilinear', align_corners=False)
+        nh, nw = round(H * s), round(W * s)
+        img = F.interpolate(img, (nh, nw), mode="bilinear", align_corners=False)
 
     # restore value
     torch.backends.cudnn.benchmark = old_bm
 
     Y = torch.cat(Y)
     X = torch.cat(X)
-    S = torch.cat(S) # scale
-    scores = torch.cat(C) * torch.cat(Q) # scores = reliability * repeatability
-    XYS = torch.stack([X,Y,S], dim=-1)
+    S = torch.cat(S)  # scale
+    scores = torch.cat(C) * torch.cat(Q)  # scores = reliability * repeatability
+    XYS = torch.stack([X, Y, S], dim=-1)
     D = torch.cat(D)
     return XYS, D, scores
 
@@ -113,71 +121,82 @@ def extract_keypoints(args):
 
     # load the network...
     net = load_network(args.model)
-    if iscuda: net = net.cuda()
+    if iscuda:
+        net = net.cuda()
 
     # create the non-maxima detector
     detector = NonMaxSuppression(
-        rel_thr = args.reliability_thr, 
-        rep_thr = args.repeatability_thr)
+        rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
+    )
 
     while args.images:
         img_path = args.images.pop(0)
-        
-        if img_path.endswith('.txt'):
+
+        if img_path.endswith(".txt"):
             args.images = open(img_path).read().splitlines() + args.images
             continue
-        
+
         print(f"\nExtracting features for {img_path}")
-        img = Image.open(img_path).convert('RGB')
+        img = Image.open(img_path).convert("RGB")
         W, H = img.size
-        img = norm_RGB(img)[None] 
-        if iscuda: img = img.cuda()
-        
+        img = norm_RGB(img)[None]
+        if iscuda:
+            img = img.cuda()
+
         # extract keypoints/descriptors for a single image
-        xys, desc, scores = extract_multiscale(net, img, detector,
-            scale_f   = args.scale_f, 
-            min_scale = args.min_scale, 
-            max_scale = args.max_scale,
-            min_size  = args.min_size, 
-            max_size  = args.max_size, 
-            verbose = True)
+        xys, desc, scores = extract_multiscale(
+            net,
+            img,
+            detector,
+            scale_f=args.scale_f,
+            min_scale=args.min_scale,
+            max_scale=args.max_scale,
+            min_size=args.min_size,
+            max_size=args.max_size,
+            verbose=True,
+        )
 
         xys = xys.cpu().numpy()
         desc = desc.cpu().numpy()
         scores = scores.cpu().numpy()
-        idxs = scores.argsort()[-args.top_k or None:]
-        
-        outpath = img_path + '.' + args.tag
-        print(f"Saving {len(idxs)} keypoints to {outpath}")
-        np.savez(open(outpath,'wb'), 
-            imsize = (W,H),
-            keypoints = xys[idxs], 
-            descriptors = desc[idxs], 
-            scores = scores[idxs])
+        idxs = scores.argsort()[-args.top_k or None :]
 
+        outpath = img_path + "." + args.tag
+        print(f"Saving {len(idxs)} keypoints to {outpath}")
+        np.savez(
+            open(outpath, "wb"),
+            imsize=(W, H),
+            keypoints=xys[idxs],
+            descriptors=desc[idxs],
+            scores=scores[idxs],
+        )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser("Extract keypoints for a given image")
-    parser.add_argument("--model", type=str, required=True, help='model path')
-    
-    parser.add_argument("--images", type=str, required=True, nargs='+', help='images / list')
-    parser.add_argument("--tag", type=str, default='r2d2', help='output file tag')
-    
-    parser.add_argument("--top-k", type=int, default=5000, help='number of keypoints')
+    parser.add_argument("--model", type=str, required=True, help="model path")
+
+    parser.add_argument(
+        "--images", type=str, required=True, nargs="+", help="images / list"
+    )
+    parser.add_argument("--tag", type=str, default="r2d2", help="output file tag")
+
+    parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints")
 
     parser.add_argument("--scale-f", type=float, default=2**0.25)
     parser.add_argument("--min-size", type=int, default=256)
     parser.add_argument("--max-size", type=int, default=1024)
     parser.add_argument("--min-scale", type=float, default=0)
     parser.add_argument("--max-scale", type=float, default=1)
-    
+
     parser.add_argument("--reliability-thr", type=float, default=0.7)
     parser.add_argument("--repeatability-thr", type=float, default=0.7)
 
-    parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='use -1 for CPU')
+    parser.add_argument(
+        "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU"
+    )
     args = parser.parse_args()
 
     extract_keypoints(args)
-
diff --git a/third_party/r2d2/extract_kapture.py b/third_party/r2d2/extract_kapture.py
index 51b2403b8a1730eaee32d099d0b6dd5d091ccdda..8e46bb5306c943ce985a13168934105b1978deb9 100644
--- a/third_party/r2d2/extract_kapture.py
+++ b/third_party/r2d2/extract_kapture.py
@@ -20,9 +20,21 @@ from extract import load_network, NonMaxSuppression, extract_multiscale
 import kapture
 from kapture.io.records import get_image_fullpath
 from kapture.io.csv import kapture_from_dir
-from kapture.io.csv import get_feature_csv_fullpath, keypoints_to_file, descriptors_to_file
-from kapture.io.features import get_keypoints_fullpath, keypoints_check_dir, image_keypoints_to_file
-from kapture.io.features import get_descriptors_fullpath, descriptors_check_dir, image_descriptors_to_file
+from kapture.io.csv import (
+    get_feature_csv_fullpath,
+    keypoints_to_file,
+    descriptors_to_file,
+)
+from kapture.io.features import (
+    get_keypoints_fullpath,
+    keypoints_check_dir,
+    image_keypoints_to_file,
+)
+from kapture.io.features import (
+    get_descriptors_fullpath,
+    descriptors_check_dir,
+    image_descriptors_to_file,
+)
 from kapture.io.csv import get_all_tar_handlers
 
 
@@ -30,41 +42,60 @@ def extract_kapture_keypoints(args):
     """
     Extract r2d2 keypoints and descritors to the kapture format directly
     """
-    print('extract_kapture_keypoints...')
-    with get_all_tar_handlers(args.kapture_root,
-                              mode={kapture.Keypoints: 'a',
-                                    kapture.Descriptors: 'a',
-                                    kapture.GlobalFeatures: 'r',
-                                    kapture.Matches: 'r'}) as tar_handlers:
-        kdata = kapture_from_dir(args.kapture_root, None,
-                                 skip_list=[kapture.GlobalFeatures,
-                                            kapture.Matches,
-                                            kapture.Points3d,
-                                            kapture.Observations],
-                                 tar_handlers=tar_handlers)
+    print("extract_kapture_keypoints...")
+    with get_all_tar_handlers(
+        args.kapture_root,
+        mode={
+            kapture.Keypoints: "a",
+            kapture.Descriptors: "a",
+            kapture.GlobalFeatures: "r",
+            kapture.Matches: "r",
+        },
+    ) as tar_handlers:
+        kdata = kapture_from_dir(
+            args.kapture_root,
+            None,
+            skip_list=[
+                kapture.GlobalFeatures,
+                kapture.Matches,
+                kapture.Points3d,
+                kapture.Observations,
+            ],
+            tar_handlers=tar_handlers,
+        )
 
         assert kdata.records_camera is not None
-        image_list = [filename for _, _, filename in kapture.flatten(kdata.records_camera)]
+        image_list = [
+            filename for _, _, filename in kapture.flatten(kdata.records_camera)
+        ]
         if args.keypoints_type is None:
             args.keypoints_type = path.splitext(path.basename(args.model))[0]
-            print(f'keypoints_type set to {args.keypoints_type}')
+            print(f"keypoints_type set to {args.keypoints_type}")
         if args.descriptors_type is None:
             args.descriptors_type = path.splitext(path.basename(args.model))[0]
-            print(f'descriptors_type set to {args.descriptors_type}')
-
-        if kdata.keypoints is not None and args.keypoints_type in kdata.keypoints \
-                and kdata.descriptors is not None and args.descriptors_type in kdata.descriptors:
-            print('detected already computed features of same keypoints_type/descriptors_type, resuming extraction...')
-            image_list = [name
-                          for name in image_list
-                          if name not in kdata.keypoints[args.keypoints_type] or
-                          name not in kdata.descriptors[args.descriptors_type]]
+            print(f"descriptors_type set to {args.descriptors_type}")
+
+        if (
+            kdata.keypoints is not None
+            and args.keypoints_type in kdata.keypoints
+            and kdata.descriptors is not None
+            and args.descriptors_type in kdata.descriptors
+        ):
+            print(
+                "detected already computed features of same keypoints_type/descriptors_type, resuming extraction..."
+            )
+            image_list = [
+                name
+                for name in image_list
+                if name not in kdata.keypoints[args.keypoints_type]
+                or name not in kdata.descriptors[args.descriptors_type]
+            ]
 
         if len(image_list) == 0:
-            print('All features were already extracted')
+            print("All features were already extracted")
             return
         else:
-            print(f'Extracting r2d2 features for {len(image_list)} images')
+            print(f"Extracting r2d2 features for {len(image_list)} images")
 
         iscuda = common.torch_set_gpu(args.gpu)
 
@@ -75,8 +106,8 @@ def extract_kapture_keypoints(args):
 
         # create the non-maxima detector
         detector = NonMaxSuppression(
-            rel_thr=args.reliability_thr,
-            rep_thr=args.repeatability_thr)
+            rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
+        )
 
         if kdata.keypoints is None:
             kdata.keypoints = {}
@@ -99,25 +130,29 @@ def extract_kapture_keypoints(args):
         for image_name in image_list:
             img_path = get_image_fullpath(args.kapture_root, image_name)
             print(f"\nExtracting features for {img_path}")
-            img = Image.open(img_path).convert('RGB')
+            img = Image.open(img_path).convert("RGB")
             W, H = img.size
             img = norm_RGB(img)[None]
             if iscuda:
                 img = img.cuda()
 
             # extract keypoints/descriptors for a single image
-            xys, desc, scores = extract_multiscale(net, img, detector,
-                                                   scale_f=args.scale_f,
-                                                   min_scale=args.min_scale,
-                                                   max_scale=args.max_scale,
-                                                   min_size=args.min_size,
-                                                   max_size=args.max_size,
-                                                   verbose=True)
+            xys, desc, scores = extract_multiscale(
+                net,
+                img,
+                detector,
+                scale_f=args.scale_f,
+                min_scale=args.min_scale,
+                max_scale=args.max_scale,
+                min_size=args.min_size,
+                max_size=args.max_size,
+                verbose=True,
+            )
 
             xys = xys.cpu().numpy()
             desc = desc.cpu().numpy()
             scores = scores.cpu().numpy()
-            idxs = scores.argsort()[-args.top_k or None:]
+            idxs = scores.argsort()[-args.top_k or None :]
 
             xys = xys[idxs]
             desc = desc[idxs]
@@ -128,56 +163,93 @@ def extract_kapture_keypoints(args):
                 keypoints_dsize = xys.shape[1]
                 descriptors_dsize = desc.shape[1]
 
-                kdata.keypoints[args.keypoints_type] = kapture.Keypoints('r2d2', keypoints_dtype, keypoints_dsize)
-                kdata.descriptors[args.descriptors_type] = kapture.Descriptors('r2d2', descriptors_dtype,
-                                                                               descriptors_dsize,
-                                                                               args.keypoints_type, 'L2')
-                keypoints_config_absolute_path = get_feature_csv_fullpath(kapture.Keypoints,
-                                                                          args.keypoints_type,
-                                                                          args.kapture_root)
-                descriptors_config_absolute_path = get_feature_csv_fullpath(kapture.Descriptors,
-                                                                            args.descriptors_type,
-                                                                            args.kapture_root)
-                keypoints_to_file(keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type])
-                descriptors_to_file(descriptors_config_absolute_path, kdata.descriptors[args.descriptors_type])
+                kdata.keypoints[args.keypoints_type] = kapture.Keypoints(
+                    "r2d2", keypoints_dtype, keypoints_dsize
+                )
+                kdata.descriptors[args.descriptors_type] = kapture.Descriptors(
+                    "r2d2",
+                    descriptors_dtype,
+                    descriptors_dsize,
+                    args.keypoints_type,
+                    "L2",
+                )
+                keypoints_config_absolute_path = get_feature_csv_fullpath(
+                    kapture.Keypoints, args.keypoints_type, args.kapture_root
+                )
+                descriptors_config_absolute_path = get_feature_csv_fullpath(
+                    kapture.Descriptors, args.descriptors_type, args.kapture_root
+                )
+                keypoints_to_file(
+                    keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type]
+                )
+                descriptors_to_file(
+                    descriptors_config_absolute_path,
+                    kdata.descriptors[args.descriptors_type],
+                )
             else:
                 assert kdata.keypoints[args.keypoints_type].dtype == xys.dtype
                 assert kdata.descriptors[args.descriptors_type].dtype == desc.dtype
                 assert kdata.keypoints[args.keypoints_type].dsize == xys.shape[1]
                 assert kdata.descriptors[args.descriptors_type].dsize == desc.shape[1]
-                assert kdata.descriptors[args.descriptors_type].keypoints_type == args.keypoints_type
-                assert kdata.descriptors[args.descriptors_type].metric_type == 'L2'
-
-            keypoints_fullpath = get_keypoints_fullpath(args.keypoints_type, args.kapture_root,
-                                                        image_name, tar_handlers)
+                assert (
+                    kdata.descriptors[args.descriptors_type].keypoints_type
+                    == args.keypoints_type
+                )
+                assert kdata.descriptors[args.descriptors_type].metric_type == "L2"
+
+            keypoints_fullpath = get_keypoints_fullpath(
+                args.keypoints_type, args.kapture_root, image_name, tar_handlers
+            )
             print(f"Saving {xys.shape[0]} keypoints to {keypoints_fullpath}")
             image_keypoints_to_file(keypoints_fullpath, xys)
             kdata.keypoints[args.keypoints_type].add(image_name)
 
-            descriptors_fullpath = get_descriptors_fullpath(args.descriptors_type, args.kapture_root,
-                                                            image_name, tar_handlers)
+            descriptors_fullpath = get_descriptors_fullpath(
+                args.descriptors_type, args.kapture_root, image_name, tar_handlers
+            )
             print(f"Saving {desc.shape[0]} descriptors to {descriptors_fullpath}")
             image_descriptors_to_file(descriptors_fullpath, desc)
             kdata.descriptors[args.descriptors_type].add(image_name)
 
-        if not keypoints_check_dir(kdata.keypoints[args.keypoints_type], args.keypoints_type,
-                                   args.kapture_root, tar_handlers) or \
-                not descriptors_check_dir(kdata.descriptors[args.descriptors_type], args.descriptors_type,
-                                          args.kapture_root, tar_handlers):
-            print('local feature extraction ended successfully but not all files were saved')
-
-
-if __name__ == '__main__':
+        if not keypoints_check_dir(
+            kdata.keypoints[args.keypoints_type],
+            args.keypoints_type,
+            args.kapture_root,
+            tar_handlers,
+        ) or not descriptors_check_dir(
+            kdata.descriptors[args.descriptors_type],
+            args.descriptors_type,
+            args.kapture_root,
+            tar_handlers,
+        ):
+            print(
+                "local feature extraction ended successfully but not all files were saved"
+            )
+
+
+if __name__ == "__main__":
     import argparse
-    parser = argparse.ArgumentParser(
-        "Extract r2d2 local features for all images in a dataset stored in the kapture format")
-    parser.add_argument("--model", type=str, required=True, help='model path')
-    parser.add_argument('--keypoints-type', default=None,  help='keypoint type_name, default is filename of model')
-    parser.add_argument('--descriptors-type', default=None,  help='descriptors type_name, default is filename of model')
-
-    parser.add_argument("--kapture-root", type=str, required=True, help='path to kapture root directory')
 
-    parser.add_argument("--top-k", type=int, default=5000, help='number of keypoints')
+    parser = argparse.ArgumentParser(
+        "Extract r2d2 local features for all images in a dataset stored in the kapture format"
+    )
+    parser.add_argument("--model", type=str, required=True, help="model path")
+    parser.add_argument(
+        "--keypoints-type",
+        default=None,
+        help="keypoint type_name, default is filename of model",
+    )
+    parser.add_argument(
+        "--descriptors-type",
+        default=None,
+        help="descriptors type_name, default is filename of model",
+    )
+
+    parser.add_argument(
+        "--kapture-root", type=str, required=True, help="path to kapture root directory"
+    )
+
+    parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints")
 
     parser.add_argument("--scale-f", type=float, default=2**0.25)
     parser.add_argument("--min-size", type=int, default=256)
@@ -188,7 +260,9 @@ if __name__ == '__main__':
     parser.add_argument("--reliability-thr", type=float, default=0.7)
     parser.add_argument("--repeatability-thr", type=float, default=0.7)
 
-    parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='use -1 for CPU')
+    parser.add_argument(
+        "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU"
+    )
     args = parser.parse_args()
 
     extract_kapture_keypoints(args)
diff --git a/third_party/r2d2/nets/ap_loss.py b/third_party/r2d2/nets/ap_loss.py
index 251815cd97009a5feb6a815c20caca0c40daaccd..deb59e4c067aa25c834caf4d0a3c06f9d470ecd4 100644
--- a/third_party/r2d2/nets/ap_loss.py
+++ b/third_party/r2d2/nets/ap_loss.py
@@ -8,15 +8,16 @@ import torch
 import torch.nn as nn
 
 
-class APLoss (nn.Module):
-    """ differentiable AP loss, through quantization.
-        
-        Input: (N, M)   values in [min, max]
-        label: (N, M)   values in {0, 1}
-        
-        Returns: list of query AP (for each n in {1..N})
-                 Note: typically, you want to minimize 1 - mean(AP)
+class APLoss(nn.Module):
+    """differentiable AP loss, through quantization.
+
+    Input: (N, M)   values in [min, max]
+    label: (N, M)   values in {0, 1}
+
+    Returns: list of query AP (for each n in {1..N})
+             Note: typically, you want to minimize 1 - mean(AP)
     """
+
     def __init__(self, nq=25, min=0, max=1, euc=False):
         nn.Module.__init__(self)
         assert isinstance(nq, int) and 2 <= nq <= 100
@@ -26,16 +27,20 @@ class APLoss (nn.Module):
         self.euc = euc
         gap = max - min
         assert gap > 0
-        
+
         # init quantizer = non-learnable (fixed) convolution
-        self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True)
-        a = (nq-1) / gap
-        #1st half = lines passing to (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1)
+        self.quantizer = q = nn.Conv1d(1, 2 * nq, kernel_size=1, bias=True)
+        a = (nq - 1) / gap
+        # 1st half = lines passing to (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1)
         q.weight.data[:nq] = -a
-        q.bias.data[:nq] = torch.from_numpy(a*min + np.arange(nq, 0, -1)) # b = 1 + a*(min+x)
-        #2nd half = lines passing to (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1)
+        q.bias.data[:nq] = torch.from_numpy(
+            a * min + np.arange(nq, 0, -1)
+        )  # b = 1 + a*(min+x)
+        # 2nd half = lines passing to (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1)
         q.weight.data[nq:] = a
-        q.bias.data[nq:] = torch.from_numpy(np.arange(2-nq, 2, 1) - a*min) # b = 1 - a*(min+x)
+        q.bias.data[nq:] = torch.from_numpy(
+            np.arange(2 - nq, 2, 1) - a * min
+        )  # b = 1 - a*(min+x)
         # first and last one are special: just horizontal straight line
         q.weight.data[0] = q.weight.data[-1] = 0
         q.bias.data[0] = q.bias.data[-1] = 1
@@ -43,25 +48,22 @@ class APLoss (nn.Module):
     def compute_AP(self, x, label):
         N, M = x.shape
         if self.euc:  # euclidean distance in same range than similarities
-            x = 1 - torch.sqrt(2.001 - 2*x)
+            x = 1 - torch.sqrt(2.001 - 2 * x)
 
         # quantize all predictions
         q = self.quantizer(x.unsqueeze(1))
-        q = torch.min(q[:,:self.nq], q[:,self.nq:]).clamp(min=0) # N x Q x M
+        q = torch.min(q[:, : self.nq], q[:, self.nq :]).clamp(min=0)  # N x Q x M
 
-        nbs = q.sum(dim=-1) # number of samples  N x Q = c
-        rec = (q * label.view(N,1,M).float()).sum(dim=-1) # nb of correct samples = c+ N x Q
-        prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision
-        rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1]
+        nbs = q.sum(dim=-1)  # number of samples  N x Q = c
+        rec = (q * label.view(N, 1, M).float()).sum(
+            dim=-1
+        )  # nb of correct samples = c+ N x Q
+        prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1))  # precision
+        rec /= rec.sum(dim=-1).unsqueeze(1)  # norm in [0,1]
 
-        ap = (prec * rec).sum(dim=-1) # per-image AP
+        ap = (prec * rec).sum(dim=-1)  # per-image AP
         return ap
 
     def forward(self, x, label):
-        assert x.shape == label.shape # N x M
+        assert x.shape == label.shape  # N x M
         return self.compute_AP(x, label)
-
-
-
-
-
diff --git a/third_party/r2d2/nets/losses.py b/third_party/r2d2/nets/losses.py
index f8eea8f6e82835e22d2bb445125f7dc722db85b2..973c592aab3f8f1c69b4001d1d324f1ad46ebe2d 100644
--- a/third_party/r2d2/nets/losses.py
+++ b/third_party/r2d2/nets/losses.py
@@ -13,44 +13,40 @@ from nets.repeatability_loss import *
 from nets.reliability_loss import *
 
 
-class MultiLoss (nn.Module):
-    """ Combines several loss functions for convenience.
+class MultiLoss(nn.Module):
+    """Combines several loss functions for convenience.
     *args: [loss weight (float), loss creator, ... ]
-    
+
     Example:
         loss = MultiLoss( 1, MyFirstLoss(), 0.5, MySecondLoss() )
     """
+
     def __init__(self, *args, dbg=()):
         nn.Module.__init__(self)
-        assert len(args) % 2 == 0, 'args must be a list of (float, loss)'
+        assert len(args) % 2 == 0, "args must be a list of (float, loss)"
         self.weights = []
         self.losses = nn.ModuleList()
-        for i in range(len(args)//2):
-            weight = float(args[2*i+0])
-            loss = args[2*i+1]
+        for i in range(len(args) // 2):
+            weight = float(args[2 * i + 0])
+            loss = args[2 * i + 1]
             assert isinstance(loss, nn.Module), "%s is not a loss!" % loss
             self.weights.append(weight)
             self.losses.append(loss)
 
     def forward(self, select=None, **variables):
-        assert not select or all(1<=n<=len(self.losses) for n in select)
+        assert not select or all(1 <= n <= len(self.losses) for n in select)
         d = dict()
         cum_loss = 0
-        for num, (weight, loss_func) in enumerate(zip(self.weights, self.losses),1):
-            if select is not None and num not in select: continue
-            l = loss_func(**{k:v for k,v in variables.items()})
+        for num, (weight, loss_func) in enumerate(zip(self.weights, self.losses), 1):
+            if select is not None and num not in select:
+                continue
+            l = loss_func(**{k: v for k, v in variables.items()})
             if isinstance(l, tuple):
                 assert len(l) == 2 and isinstance(l[1], dict)
             else:
-                l = l, {loss_func.name:l}
+                l = l, {loss_func.name: l}
             cum_loss = cum_loss + weight * l[0]
-            for key,val in l[1].items():
-                d['loss_'+key] = float(val)
-        d['loss'] = float(cum_loss)
+            for key, val in l[1].items():
+                d["loss_" + key] = float(val)
+        d["loss"] = float(cum_loss)
         return cum_loss, d
-
-
-
-
-
-
diff --git a/third_party/r2d2/nets/patchnet.py b/third_party/r2d2/nets/patchnet.py
index 854c61ecf9b879fa7f420255296c4fbbfd665181..8ed3fdbd55ccbbd58f0cea3dad9384a402ec5e9d 100644
--- a/third_party/r2d2/nets/patchnet.py
+++ b/third_party/r2d2/nets/patchnet.py
@@ -8,22 +8,25 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-class BaseNet (nn.Module):
-    """ Takes a list of images as input, and returns for each image:
-        - a pixelwise descriptor
-        - a pixelwise confidence
+class BaseNet(nn.Module):
+    """Takes a list of images as input, and returns for each image:
+    - a pixelwise descriptor
+    - a pixelwise confidence
     """
+
     def softmax(self, ux):
         if ux.shape[1] == 1:
             x = F.softplus(ux)
             return x / (1 + x)  # for sure in [0,1], much less plateaus than softmax
         elif ux.shape[1] == 2:
-            return F.softmax(ux, dim=1)[:,1:2]
+            return F.softmax(ux, dim=1)[:, 1:2]
 
     def normalize(self, x, ureliability, urepeatability):
-        return dict(descriptors = F.normalize(x, p=2, dim=1),
-                    repeatability = self.softmax( urepeatability ),
-                    reliability = self.softmax( ureliability ))
+        return dict(
+            descriptors=F.normalize(x, p=2, dim=1),
+            repeatability=self.softmax(urepeatability),
+            reliability=self.softmax(ureliability),
+        )
 
     def forward_one(self, x):
         raise NotImplementedError()
@@ -31,15 +34,15 @@ class BaseNet (nn.Module):
     def forward(self, imgs, **kw):
         res = [self.forward_one(img) for img in imgs]
         # merge all dictionaries into one
-        res = {k:[r[k] for r in res if k in r] for k in {k for r in res for k in r}}
+        res = {k: [r[k] for r in res if k in r] for k in {k for r in res for k in r}}
         return dict(res, imgs=imgs, **kw)
 
 
-
-class PatchNet (BaseNet):
-    """ Helper class to construct a fully-convolutional network that
-        extract a l2-normalized patch descriptor.
+class PatchNet(BaseNet):
+    """Helper class to construct a fully-convolutional network that
+    extract a l2-normalized patch descriptor.
     """
+
     def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
         BaseNet.__init__(self)
         self.inchan = inchan
@@ -53,41 +56,54 @@ class PatchNet (BaseNet):
     def _make_bn(self, outd):
         return nn.BatchNorm2d(outd, affine=self.bn_affine)
 
-    def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max'):
+    def _add_conv(
+        self,
+        outd,
+        k=3,
+        stride=1,
+        dilation=1,
+        bn=True,
+        relu=True,
+        k_pool=1,
+        pool_type="max",
+    ):
         # as in the original implementation, dilation is applied at the end of layer, so it will have impact only from next layer
         d = self.dilation * dilation
-        if self.dilated: 
-            conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=1)
+        if self.dilated:
+            conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=1)
             self.dilation *= stride
         else:
-            conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride)
-        self.ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) )
-        if bn and self.bn: self.ops.append( self._make_bn(outd) )
-        if relu: self.ops.append( nn.ReLU(inplace=True) )
+            conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=stride)
+        self.ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params))
+        if bn and self.bn:
+            self.ops.append(self._make_bn(outd))
+        if relu:
+            self.ops.append(nn.ReLU(inplace=True))
         self.curchan = outd
-        
+
         if k_pool > 1:
-            if pool_type == 'avg':
+            if pool_type == "avg":
                 self.ops.append(torch.nn.AvgPool2d(kernel_size=k_pool))
-            elif pool_type == 'max':
+            elif pool_type == "max":
                 self.ops.append(torch.nn.MaxPool2d(kernel_size=k_pool))
             else:
                 print(f"Error, unknown pooling type {pool_type}...")
-    
+
     def forward_one(self, x):
         assert self.ops, "You need to add convolutions first"
-        for n,op in enumerate(self.ops):
+        for n, op in enumerate(self.ops):
             x = op(x)
         return self.normalize(x)
 
 
-class L2_Net (PatchNet):
-    """ Compute a 128D descriptor for all overlapping 32x32 patches.
-        From the L2Net paper (CVPR'17).
+class L2_Net(PatchNet):
+    """Compute a 128D descriptor for all overlapping 32x32 patches.
+    From the L2Net paper (CVPR'17).
     """
-    def __init__(self, dim=128, **kw ):
+
+    def __init__(self, dim=128, **kw):
         PatchNet.__init__(self, **kw)
-        add_conv = lambda n,**kw: self._add_conv((n*dim)//128,**kw)
+        add_conv = lambda n, **kw: self._add_conv((n * dim) // 128, **kw)
         add_conv(32)
         add_conv(32)
         add_conv(64, stride=2)
@@ -98,35 +114,34 @@ class L2_Net (PatchNet):
         self.out_dim = dim
 
 
-class Quad_L2Net (PatchNet):
-    """ Same than L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs.
-    """
-    def __init__(self, dim=128, mchan=4, relu22=False, **kw ):
+class Quad_L2Net(PatchNet):
+    """Same than L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs."""
+
+    def __init__(self, dim=128, mchan=4, relu22=False, **kw):
         PatchNet.__init__(self, **kw)
-        self._add_conv(  8*mchan)
-        self._add_conv(  8*mchan)
-        self._add_conv( 16*mchan, stride=2)
-        self._add_conv( 16*mchan)
-        self._add_conv( 32*mchan, stride=2)
-        self._add_conv( 32*mchan)
+        self._add_conv(8 * mchan)
+        self._add_conv(8 * mchan)
+        self._add_conv(16 * mchan, stride=2)
+        self._add_conv(16 * mchan)
+        self._add_conv(32 * mchan, stride=2)
+        self._add_conv(32 * mchan)
         # replace last 8x8 convolution with 3 2x2 convolutions
-        self._add_conv( 32*mchan, k=2, stride=2, relu=relu22)
-        self._add_conv( 32*mchan, k=2, stride=2, relu=relu22)
+        self._add_conv(32 * mchan, k=2, stride=2, relu=relu22)
+        self._add_conv(32 * mchan, k=2, stride=2, relu=relu22)
         self._add_conv(dim, k=2, stride=2, bn=False, relu=False)
         self.out_dim = dim
 
 
+class Quad_L2Net_ConfCFS(Quad_L2Net):
+    """Same than Quad_L2Net, with 2 confidence maps for repeatability and reliability."""
 
-class Quad_L2Net_ConfCFS (Quad_L2Net):
-    """ Same than Quad_L2Net, with 2 confidence maps for repeatability and reliability.
-    """
-    def __init__(self, **kw ):
+    def __init__(self, **kw):
         Quad_L2Net.__init__(self, **kw)
         # reliability classifier
         self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1)
         # repeatability classifier: for some reasons it's a softplus, not a softmax!
         # Why? I guess it's a mistake that was left unnoticed in the code for a long time...
-        self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1) 
+        self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1)
 
     def forward_one(self, x):
         assert self.ops, "You need to add convolutions first"
@@ -138,44 +153,51 @@ class Quad_L2Net_ConfCFS (Quad_L2Net):
         return self.normalize(x, ureliability, urepeatability)
 
 
-class Fast_Quad_L2Net (PatchNet):
-    """ Faster version of Quad l2 net, replacing one dilated conv with one pooling to diminish image resolution thus increase inference time
+class Fast_Quad_L2Net(PatchNet):
+    """Faster version of Quad l2 net, replacing one dilated conv with one pooling to diminish image resolution thus increase inference time
     Dilation  factors and pooling:
         1,1,1, pool2, 1,1, 2,2, 4, 8, upsample2
     """
-    def __init__(self, dim=128, mchan=4, relu22=False, downsample_factor=2, **kw ):
+
+    def __init__(self, dim=128, mchan=4, relu22=False, downsample_factor=2, **kw):
 
         PatchNet.__init__(self, **kw)
-        self._add_conv(  8*mchan)
-        self._add_conv(  8*mchan)
-        self._add_conv( 16*mchan, k_pool = downsample_factor) # added avg pooling to decrease img resolution
-        self._add_conv( 16*mchan)
-        self._add_conv( 32*mchan, stride=2)
-        self._add_conv( 32*mchan)
-        
+        self._add_conv(8 * mchan)
+        self._add_conv(8 * mchan)
+        self._add_conv(
+            16 * mchan, k_pool=downsample_factor
+        )  # added avg pooling to decrease img resolution
+        self._add_conv(16 * mchan)
+        self._add_conv(32 * mchan, stride=2)
+        self._add_conv(32 * mchan)
+
         # replace last 8x8 convolution with 3 2x2 convolutions
-        self._add_conv( 32*mchan, k=2, stride=2, relu=relu22)
-        self._add_conv( 32*mchan, k=2, stride=2, relu=relu22)
+        self._add_conv(32 * mchan, k=2, stride=2, relu=relu22)
+        self._add_conv(32 * mchan, k=2, stride=2, relu=relu22)
         self._add_conv(dim, k=2, stride=2, bn=False, relu=False)
-        
+
         # Go back to initial image resolution with upsampling
-        self.ops.append(torch.nn.Upsample(scale_factor=downsample_factor, mode='bilinear', align_corners=False))
-        
+        self.ops.append(
+            torch.nn.Upsample(
+                scale_factor=downsample_factor, mode="bilinear", align_corners=False
+            )
+        )
+
         self.out_dim = dim
-        
-        
-class Fast_Quad_L2Net_ConfCFS (Fast_Quad_L2Net):
-    """ Fast r2d2 architecture
-    """
-    def __init__(self, **kw ):
+
+
+class Fast_Quad_L2Net_ConfCFS(Fast_Quad_L2Net):
+    """Fast r2d2 architecture"""
+
+    def __init__(self, **kw):
         Fast_Quad_L2Net.__init__(self, **kw)
         # reliability classifier
         self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1)
-        
+
         # repeatability classifier: for some reasons it's a softplus, not a softmax!
         # Why? I guess it's a mistake that was left unnoticed in the code for a long time...
-        self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1) 
-        
+        self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1)
+
     def forward_one(self, x):
         assert self.ops, "You need to add convolutions first"
         for op in self.ops:
@@ -183,4 +205,4 @@ class Fast_Quad_L2Net_ConfCFS (Fast_Quad_L2Net):
         # compute the confidence maps
         ureliability = self.clf(x**2)
         urepeatability = self.sal(x**2)
-        return self.normalize(x, ureliability, urepeatability)
\ No newline at end of file
+        return self.normalize(x, ureliability, urepeatability)
diff --git a/third_party/r2d2/nets/reliability_loss.py b/third_party/r2d2/nets/reliability_loss.py
index 52d5383b0eaa52bcf2111eabb4b45e39b63b976f..e560d1ea1b4dc27d81031c62cc4c0aed9161cc67 100644
--- a/third_party/r2d2/nets/reliability_loss.py
+++ b/third_party/r2d2/nets/reliability_loss.py
@@ -9,18 +9,19 @@ import torch.nn.functional as F
 from nets.ap_loss import APLoss
 
 
-class PixelAPLoss (nn.Module):
-    """ Computes the pixel-wise AP loss:
-        Given two images and ground-truth optical flow, computes the AP per pixel.
-        
-        feat1:  (B, C, H, W)   pixel-wise features extracted from img1
-        feat2:  (B, C, H, W)   pixel-wise features extracted from img2
-        aflow:  (B, 2, H, W)   absolute flow: aflow[...,y1,x1] = x2,y2
+class PixelAPLoss(nn.Module):
+    """Computes the pixel-wise AP loss:
+    Given two images and ground-truth optical flow, computes the AP per pixel.
+
+    feat1:  (B, C, H, W)   pixel-wise features extracted from img1
+    feat2:  (B, C, H, W)   pixel-wise features extracted from img2
+    aflow:  (B, 2, H, W)   absolute flow: aflow[...,y1,x1] = x2,y2
     """
+
     def __init__(self, sampler, nq=20):
         nn.Module.__init__(self)
         self.aploss = APLoss(nq, min=0, max=1, euc=False)
-        self.name = 'pixAP'
+        self.name = "pixAP"
         self.sampler = sampler
 
     def loss_from_ap(self, ap, rel):
@@ -28,32 +29,31 @@ class PixelAPLoss (nn.Module):
 
     def forward(self, descriptors, aflow, **kw):
         # subsample things
-        scores, gt, msk, qconf = self.sampler(descriptors, kw.get('reliability'), aflow)
-        
+        scores, gt, msk, qconf = self.sampler(descriptors, kw.get("reliability"), aflow)
+
         # compute pixel-wise AP
         n = qconf.numel()
-        if n == 0: return 0
-        scores, gt = scores.view(n,-1), gt.view(n,-1)
+        if n == 0:
+            return 0
+        scores, gt = scores.view(n, -1), gt.view(n, -1)
         ap = self.aploss(scores, gt).view(msk.shape)
 
         pixel_loss = self.loss_from_ap(ap, qconf)
-        
+
         loss = pixel_loss[msk].mean()
         return loss
 
 
-class ReliabilityLoss (PixelAPLoss):
-    """ same than PixelAPLoss, but also train a pixel-wise confidence
-        that this pixel is going to have a good AP.
+class ReliabilityLoss(PixelAPLoss):
+    """same than PixelAPLoss, but also train a pixel-wise confidence
+    that this pixel is going to have a good AP.
     """
+
     def __init__(self, sampler, base=0.5, **kw):
         PixelAPLoss.__init__(self, sampler, **kw)
         assert 0 <= base < 1
         self.base = base
-        self.name = 'reliability'
+        self.name = "reliability"
 
     def loss_from_ap(self, ap, rel):
-        return 1 - ap*rel - (1-rel)*self.base
-
-
-
+        return 1 - ap * rel - (1 - rel) * self.base
diff --git a/third_party/r2d2/nets/repeatability_loss.py b/third_party/r2d2/nets/repeatability_loss.py
index 5cda0b6d036f98af88a88780fe39da0c5c0b610e..af49e77f444c5b4b035cd43d0c065096e8dd7c1b 100644
--- a/third_party/r2d2/nets/repeatability_loss.py
+++ b/third_party/r2d2/nets/repeatability_loss.py
@@ -10,27 +10,28 @@ import torch.nn.functional as F
 
 from nets.sampler import FullSampler
 
-class CosimLoss (nn.Module):
-    """ Try to make the repeatability repeatable from one image to the other.
-    """
+
+class CosimLoss(nn.Module):
+    """Try to make the repeatability repeatable from one image to the other."""
+
     def __init__(self, N=16):
         nn.Module.__init__(self)
-        self.name = f'cosim{N}'
-        self.patches = nn.Unfold(N, padding=0, stride=N//2)
+        self.name = f"cosim{N}"
+        self.patches = nn.Unfold(N, padding=0, stride=N // 2)
 
     def extract_patches(self, sal):
-        patches = self.patches(sal).transpose(1,2) # flatten
-        patches = F.normalize(patches, p=2, dim=2) # norm
+        patches = self.patches(sal).transpose(1, 2)  # flatten
+        patches = F.normalize(patches, p=2, dim=2)  # norm
         return patches
-        
+
     def forward(self, repeatability, aflow, **kw):
-        B,two,H,W = aflow.shape
+        B, two, H, W = aflow.shape
         assert two == 2
 
         # normalize
         sali1, sali2 = repeatability
         grid = FullSampler._aflow_to_grid(aflow)
-        sali2 = F.grid_sample(sali2, grid, mode='bilinear', padding_mode='border')
+        sali2 = F.grid_sample(sali2, grid, mode="bilinear", padding_mode="border")
 
         patches1 = self.extract_patches(sali1)
         patches2 = self.extract_patches(sali2)
@@ -38,29 +39,25 @@ class CosimLoss (nn.Module):
         return 1 - cosim.mean()
 
 
-class PeakyLoss (nn.Module):
-    """ Try to make the repeatability locally peaky.
+class PeakyLoss(nn.Module):
+    """Try to make the repeatability locally peaky.
 
     Mechanism: we maximize, for each pixel, the difference between the local mean
                and the local max.
     """
+
     def __init__(self, N=16):
         nn.Module.__init__(self)
-        self.name = f'peaky{N}'
-        assert N % 2 == 0, 'N must be pair'
+        self.name = f"peaky{N}"
+        assert N % 2 == 0, "N must be pair"
         self.preproc = nn.AvgPool2d(3, stride=1, padding=1)
-        self.maxpool = nn.MaxPool2d(N+1, stride=1, padding=N//2)
-        self.avgpool = nn.AvgPool2d(N+1, stride=1, padding=N//2)
+        self.maxpool = nn.MaxPool2d(N + 1, stride=1, padding=N // 2)
+        self.avgpool = nn.AvgPool2d(N + 1, stride=1, padding=N // 2)
 
     def forward_one(self, sali):
-        sali = self.preproc(sali) # remove super high frequency
+        sali = self.preproc(sali)  # remove super high frequency
         return 1 - (self.maxpool(sali) - self.avgpool(sali)).mean()
 
     def forward(self, repeatability, **kw):
         sali1, sali2 = repeatability
-        return (self.forward_one(sali1) + self.forward_one(sali2)) /2
-
-
-
-
-
+        return (self.forward_one(sali1) + self.forward_one(sali2)) / 2
diff --git a/third_party/r2d2/nets/sampler.py b/third_party/r2d2/nets/sampler.py
index 9fede70d3a04d7f31a1d414eace0aaf3729e8235..3f2e5a276a80b997561549ed3e8466da3876e382 100644
--- a/third_party/r2d2/nets/sampler.py
+++ b/third_party/r2d2/nets/sampler.py
@@ -15,65 +15,69 @@ import torch.nn.functional as F
 
 
 class FullSampler(nn.Module):
-    """ all pixels are selected
-        - feats: keypoint descriptors
-        - confs: reliability values
+    """all pixels are selected
+    - feats: keypoint descriptors
+    - confs: reliability values
     """
+
     def __init__(self):
         nn.Module.__init__(self)
-        self.mode = 'bilinear'
-        self.padding = 'zeros'
+        self.mode = "bilinear"
+        self.padding = "zeros"
 
     @staticmethod
     def _aflow_to_grid(aflow):
         H, W = aflow.shape[2:]
-        grid = aflow.permute(0,2,3,1).clone()
-        grid[:,:,:,0] *= 2/(W-1)
-        grid[:,:,:,1] *= 2/(H-1)
+        grid = aflow.permute(0, 2, 3, 1).clone()
+        grid[:, :, :, 0] *= 2 / (W - 1)
+        grid[:, :, :, 1] *= 2 / (H - 1)
         grid -= 1
-        grid[torch.isnan(grid)] = 9e9 # invalids
+        grid[torch.isnan(grid)] = 9e9  # invalids
         return grid
-    
+
     def _warp(self, feats, confs, aflow):
-        if isinstance(aflow, tuple): return aflow # result was precomputed
+        if isinstance(aflow, tuple):
+            return aflow  # result was precomputed
         feat1, feat2 = feats
-        conf1, conf2 = confs if confs else (None,None)
-    
+        conf1, conf2 = confs if confs else (None, None)
+
         B, two, H, W = aflow.shape
         D = feat1.shape[1]
-        assert feat1.shape == feat2.shape == (B, D, H, W) # D = 128, B = batch
+        assert feat1.shape == feat2.shape == (B, D, H, W)  # D = 128, B = batch
         assert conf1.shape == conf2.shape == (B, 1, H, W) if confs else True
 
         # warp img2 to img1
         grid = self._aflow_to_grid(aflow)
-        ones2 = feat2.new_ones(feat2[:,0:1].shape)
+        ones2 = feat2.new_ones(feat2[:, 0:1].shape)
         feat2to1 = F.grid_sample(feat2, grid, mode=self.mode, padding_mode=self.padding)
-        mask2to1 = F.grid_sample(ones2, grid, mode='nearest', padding_mode='zeros')
-        conf2to1 = F.grid_sample(conf2, grid, mode=self.mode, padding_mode=self.padding) \
-                   if confs else None
+        mask2to1 = F.grid_sample(ones2, grid, mode="nearest", padding_mode="zeros")
+        conf2to1 = (
+            F.grid_sample(conf2, grid, mode=self.mode, padding_mode=self.padding)
+            if confs
+            else None
+        )
         return feat2to1, mask2to1.byte(), conf2to1
 
     def _warp_positions(self, aflow):
         B, two, H, W = aflow.shape
         assert two == 2
-        
+
         Y = torch.arange(H, device=aflow.device)
         X = torch.arange(W, device=aflow.device)
-        XY = torch.stack(torch.meshgrid(Y,X)[::-1], dim=0)
+        XY = torch.stack(torch.meshgrid(Y, X)[::-1], dim=0)
         XY = XY[None].expand(B, 2, H, W).float()
-        
+
         grid = self._aflow_to_grid(aflow)
-        XY2 = F.grid_sample(XY, grid, mode='bilinear', padding_mode='zeros')
+        XY2 = F.grid_sample(XY, grid, mode="bilinear", padding_mode="zeros")
         return XY, XY2
 
 
+class SubSampler(FullSampler):
+    """pixels are selected in an uniformly spaced grid"""
 
-class SubSampler (FullSampler):
-    """ pixels are selected in an uniformly spaced grid
-    """
     def __init__(self, border, subq, subd, perimage=False):
         FullSampler.__init__(self)
-        assert subq % subd == 0, 'subq must be multiple of subd'
+        assert subq % subd == 0, "subq must be multiple of subd"
         self.sub_q = subq
         self.sub_d = subd
         self.border = border
@@ -81,13 +85,17 @@ class SubSampler (FullSampler):
 
     def __repr__(self):
         return "SubSampler(border=%d, subq=%d, subd=%d, perimage=%d)" % (
-            self.border, self.sub_q, self.sub_d, self.perimage)
+            self.border,
+            self.sub_q,
+            self.sub_d,
+            self.perimage,
+        )
 
     def __call__(self, feats, confs, aflow):
         feat1, conf1 = feats[0], (confs[0] if confs else None)
         # warp with optical flow in img1 coords
         feat2, mask2, conf2 = self._warp(feats, confs, aflow)
-        
+
         # subsample img1
         slq = slice(self.border, -self.border or None, self.sub_q)
         feat1 = feat1[:, :, slq, slq]
@@ -97,47 +105,50 @@ class SubSampler (FullSampler):
         feat2 = feat2[:, :, sld, sld]
         mask2 = mask2[:, :, sld, sld]
         conf2 = conf2[:, :, sld, sld] if confs else None
-        
+
         B, D, Hq, Wq = feat1.shape
         B, D, Hd, Wd = feat2.shape
-        
+
         # compute gt
         if self.perimage or self.sub_q != self.sub_d:
             # compute ground-truth by comparing pixel indices
-            f = feats[0][0:1,0] if self.perimage else feats[0][:,0]
-            idxs = torch.arange(f.numel(), dtype=torch.int64, device=feat1.device).view(f.shape)
-            idxs1 = idxs[:, slq, slq].reshape(-1,Hq*Wq)
-            idxs2 = idxs[:, sld, sld].reshape(-1,Hd*Wd)
+            f = feats[0][0:1, 0] if self.perimage else feats[0][:, 0]
+            idxs = torch.arange(f.numel(), dtype=torch.int64, device=feat1.device).view(
+                f.shape
+            )
+            idxs1 = idxs[:, slq, slq].reshape(-1, Hq * Wq)
+            idxs2 = idxs[:, sld, sld].reshape(-1, Hd * Wd)
             if self.perimage:
-                gt = (idxs1[0].view(-1,1) == idxs2[0].view(1,-1))
-                gt = gt[None,:,:].expand(B, Hq*Wq, Hd*Wd)
-            else :
-                gt = (idxs1.view(-1,1) == idxs2.view(1,-1)) 
+                gt = idxs1[0].view(-1, 1) == idxs2[0].view(1, -1)
+                gt = gt[None, :, :].expand(B, Hq * Wq, Hd * Wd)
+            else:
+                gt = idxs1.view(-1, 1) == idxs2.view(1, -1)
         else:
-            gt = torch.eye(feat1[:,0].numel(), dtype=torch.uint8, device=feat1.device) # always binary for AP loss
-        
+            gt = torch.eye(
+                feat1[:, 0].numel(), dtype=torch.uint8, device=feat1.device
+            )  # always binary for AP loss
+
         # compute all images together
-        queries  =  feat1.reshape(B,D,-1) # B x D x (Hq x Wq)
-        database =  feat2.reshape(B,D,-1) # B x D x (Hd x Wd)
+        queries = feat1.reshape(B, D, -1)  # B x D x (Hq x Wq)
+        database = feat2.reshape(B, D, -1)  # B x D x (Hd x Wd)
         if self.perimage:
-            queries  =  queries.transpose(1,2) # B x (Hd x Wd) x D
-            scores = torch.bmm(queries, database) # B x (Hq x Wq) x (Hd x Wd)
+            queries = queries.transpose(1, 2)  # B x (Hd x Wd) x D
+            scores = torch.bmm(queries, database)  # B x (Hq x Wq) x (Hd x Wd)
         else:
-            queries  =  queries .transpose(1,2).reshape(-1,D) # (B x Hq x Wq) x D
-            database =  database.transpose(1,0).reshape(D,-1) # D x (B x Hd x Wd)
-            scores = torch.matmul(queries, database) # (B x Hq x Wq) x (B x Hd x Wd)
+            queries = queries.transpose(1, 2).reshape(-1, D)  # (B x Hq x Wq) x D
+            database = database.transpose(1, 0).reshape(D, -1)  # D x (B x Hd x Wd)
+            scores = torch.matmul(queries, database)  # (B x Hq x Wq) x (B x Hd x Wd)
 
         # compute reliability
-        qconf = (conf1 + conf2)/2 if confs else None
+        qconf = (conf1 + conf2) / 2 if confs else None
 
         assert gt.shape == scores.shape
         return scores, gt, mask2, qconf
 
 
+class NghSampler(FullSampler):
+    """all pixels in a small neighborhood"""
 
-class NghSampler (FullSampler):
-    """ all pixels in a small neighborhood
-    """
     def __init__(self, ngh, subq=1, subd=1, ignore=1, border=None):
         FullSampler.__init__(self)
         assert 0 <= ignore < ngh
@@ -146,86 +157,96 @@ class NghSampler (FullSampler):
         assert subd <= ngh
         self.sub_q = subq
         self.sub_d = subd
-        if border is None: border = ngh
-        assert border >= ngh, 'border has to be larger than ngh'
+        if border is None:
+            border = ngh
+        assert border >= ngh, "border has to be larger than ngh"
         self.border = border
 
     def __repr__(self):
         return "NghSampler(ngh=%d, subq=%d, subd=%d, ignore=%d, border=%d)" % (
-            self.ngh, self.sub_q, self.sub_d, self.ignore, self.border)
+            self.ngh,
+            self.sub_q,
+            self.sub_d,
+            self.ignore,
+            self.border,
+        )
 
     def trans(self, arr, i, j):
-        s = lambda i: slice(self.border+i, i-self.border or None, self.sub_q)
-        return arr[:,:,s(j),s(i)]
+        s = lambda i: slice(self.border + i, i - self.border or None, self.sub_q)
+        return arr[:, :, s(j), s(i)]
 
     def __call__(self, feats, confs, aflow):
         feat1, conf1 = feats[0], (confs[0] if confs else None)
         # warp with optical flow in img1 coords
         feat2, mask2, conf2 = self._warp(feats, confs, aflow)
-        
-        qfeat = self.trans(feat1,0,0)
-        qconf = (self.trans(conf1,0,0) + self.trans(conf2,0,0)) / 2 if confs else None
-        mask2 = self.trans(mask2,0,0)
-        scores_at = lambda i,j: (qfeat * self.trans(feat2,i,j)).sum(dim=1)
-        
+
+        qfeat = self.trans(feat1, 0, 0)
+        qconf = (
+            (self.trans(conf1, 0, 0) + self.trans(conf2, 0, 0)) / 2 if confs else None
+        )
+        mask2 = self.trans(mask2, 0, 0)
+        scores_at = lambda i, j: (qfeat * self.trans(feat2, i, j)).sum(dim=1)
+
         # compute scores for all neighbors
         B, D = feat1.shape[:2]
         min_d = self.ignore**2
         max_d = self.ngh**2
-        rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple
+        rad = (self.ngh // self.sub_d) * self.ngh  # make an integer multiple
         negs = []
         offsets = []
-        for j in range(-rad, rad+1, self.sub_d):
-          for i in range(-rad, rad+1, self.sub_d):
-            if not(min_d < i*i + j*j <= max_d): 
-                continue # out of scope
-            offsets.append((i,j)) # Note: this list is just for debug
-            negs.append( scores_at(i,j) )
-        
-        scores = torch.stack([scores_at(0,0)] + negs, dim=-1)
+        for j in range(-rad, rad + 1, self.sub_d):
+            for i in range(-rad, rad + 1, self.sub_d):
+                if not (min_d < i * i + j * j <= max_d):
+                    continue  # out of scope
+                offsets.append((i, j))  # Note: this list is just for debug
+                negs.append(scores_at(i, j))
+
+        scores = torch.stack([scores_at(0, 0)] + negs, dim=-1)
         gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
-        gt[..., 0] = 1 # only the center point is positive
+        gt[..., 0] = 1  # only the center point is positive
 
         return scores, gt, mask2, qconf
 
 
+class FarNearSampler(FullSampler):
+    """Sample pixels from *both* a small neighborhood *and* far-away pixels.
 
-class FarNearSampler (FullSampler):
-    """ Sample pixels from *both* a small neighborhood *and* far-away pixels.
-        
     How it works?
         1) Queries are sampled from img1,
-            - at least `border` pixels from borders and 
+            - at least `border` pixels from borders and
             - on a grid with step = `subq`
-            
-        2) Close database pixels 
+
+        2) Close database pixels
             - from the corresponding image (img2),
-            - within a `ngh` distance radius 
+            - within a `ngh` distance radius
             - on a grid with step = `subd_ngh`
             - ignored if distance to query is >0 and <=`ignore`
-            
+
         3) Far-away database pixels from ,
             - from all batch images in `img2`
             - at least `border` pixels from borders
             - on a grid with step = `subd_far`
     """
-    def __init__(self, subq, ngh, subd_ngh, subd_far, border=None, ignore=1, 
-                       maxpool_ngh=False ):
+
+    def __init__(
+        self, subq, ngh, subd_ngh, subd_far, border=None, ignore=1, maxpool_ngh=False
+    ):
         FullSampler.__init__(self)
         border = border or ngh
-        assert ignore < ngh < subd_far, 'neighborhood needs to be smaller than far step'
-        self.close_sampler = NghSampler(ngh=ngh, subq=subq, subd=subd_ngh, 
-                ignore=not(maxpool_ngh), border=border)
+        assert ignore < ngh < subd_far, "neighborhood needs to be smaller than far step"
+        self.close_sampler = NghSampler(
+            ngh=ngh, subq=subq, subd=subd_ngh, ignore=not (maxpool_ngh), border=border
+        )
         self.faraway_sampler = SubSampler(border=border, subq=subq, subd=subd_far)
         self.maxpool_ngh = maxpool_ngh
 
     def __repr__(self):
-        c,f = self.close_sampler, self.faraway_sampler
+        c, f = self.close_sampler, self.faraway_sampler
         res = "FarNearSampler(subq=%d, ngh=%d" % (c.sub_q, c.ngh)
         res += ", subd_ngh=%d, subd_far=%d" % (c.sub_d, f.sub_d)
         res += ", border=%d, ign=%d" % (f.border, c.ignore)
         res += ", maxpool_ngh=%d" % self.maxpool_ngh
-        return res+')'
+        return res + ")"
 
     def __call__(self, feats, confs, aflow):
         # warp with optical flow in img1 coords
@@ -233,10 +254,10 @@ class FarNearSampler (FullSampler):
 
         # sample ngh pixels
         scores1, gt1, msk1, conf1 = self.close_sampler(feats, confs, aflow)
-        scores1, gt1 = scores1.view(-1,scores1.shape[-1]), gt1.view(-1,gt1.shape[-1])
+        scores1, gt1 = scores1.view(-1, scores1.shape[-1]), gt1.view(-1, gt1.shape[-1])
         if self.maxpool_ngh:
             # we consider all scores from ngh as potential positives
-            scores1, self._cached_maxpool_ngh = scores1.max(dim=1,keepdim=True)
+            scores1, self._cached_maxpool_ngh = scores1.max(dim=1, keepdim=True)
             gt1 = gt1[:, 0:1]
 
         # sample far pixels
@@ -244,22 +265,35 @@ class FarNearSampler (FullSampler):
         # assert (msk1 == msk2).all()
         # assert (conf1 == conf2).all()
 
-        return (torch.cat((scores1,scores2),dim=1), 
-                torch.cat((gt1,    gt2),    dim=1), 
-                msk1, conf1 if confs else None)
+        return (
+            torch.cat((scores1, scores2), dim=1),
+            torch.cat((gt1, gt2), dim=1),
+            msk1,
+            conf1 if confs else None,
+        )
 
 
-class NghSampler2 (nn.Module):
-    """ Similar to NghSampler, but doesnt warp the 2nd image.
+class NghSampler2(nn.Module):
+    """Similar to NghSampler, but doesnt warp the 2nd image.
     Distance to GT =>  0 ... pos_d ... neg_d ... ngh
     Pixel label    =>  + + + + + + 0 0 - - - - - - -
-    
+
     Subsample on query side: if > 0, regular grid
-                                < 0, random points 
+                                < 0, random points
     In both cases, the number of query points is = W*H/subq**2
     """
-    def __init__(self, ngh, subq=1, subd=1, pos_d=0, neg_d=2, border=None,
-                       maxpool_pos=True, subd_neg=0):
+
+    def __init__(
+        self,
+        ngh,
+        subq=1,
+        subd=1,
+        pos_d=0,
+        neg_d=2,
+        border=None,
+        maxpool_pos=True,
+        subd_neg=0,
+    ):
         nn.Module.__init__(self)
         assert 0 <= pos_d < neg_d <= (ngh if ngh else 99)
         self.ngh = ngh
@@ -270,8 +304,9 @@ class NghSampler2 (nn.Module):
         self.sub_q = subq
         self.sub_d = subd
         self.sub_d_neg = subd_neg
-        if border is None: border = ngh
-        assert border >= ngh, 'border has to be larger than ngh'
+        if border is None:
+            border = ngh
+        assert border >= ngh, "border has to be larger than ngh"
         self.border = border
         self.maxpool_pos = maxpool_pos
         self.precompute_offsets()
@@ -280,19 +315,19 @@ class NghSampler2 (nn.Module):
         pos_d2 = self.pos_d**2
         neg_d2 = self.neg_d**2
         rad2 = self.ngh**2
-        rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple
+        rad = (self.ngh // self.sub_d) * self.ngh  # make an integer multiple
         pos = []
         neg = []
-        for j in range(-rad, rad+1, self.sub_d):
-          for i in range(-rad, rad+1, self.sub_d):
-            d2 = i*i + j*j
-            if d2 <= pos_d2:
-                pos.append( (i,j) )
-            elif neg_d2 <= d2 <= rad2: 
-                neg.append( (i,j) )
+        for j in range(-rad, rad + 1, self.sub_d):
+            for i in range(-rad, rad + 1, self.sub_d):
+                d2 = i * i + j * j
+                if d2 <= pos_d2:
+                    pos.append((i, j))
+                elif neg_d2 <= d2 <= rad2:
+                    neg.append((i, j))
 
-        self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t())
-        self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t())
+        self.register_buffer("pos_offsets", torch.LongTensor(pos).view(-1, 2).t())
+        self.register_buffer("neg_offsets", torch.LongTensor(neg).view(-1, 2).t())
 
     def gen_grid(self, step, aflow):
         B, two, H, W = aflow.shape
@@ -300,21 +335,21 @@ class NghSampler2 (nn.Module):
         b1 = torch.arange(B, device=dev)
         if step > 0:
             # regular grid
-            x1 = torch.arange(self.border, W-self.border, step, device=dev)
-            y1 = torch.arange(self.border, H-self.border, step, device=dev)
+            x1 = torch.arange(self.border, W - self.border, step, device=dev)
+            y1 = torch.arange(self.border, H - self.border, step, device=dev)
             H1, W1 = len(y1), len(x1)
-            x1 = x1[None,None,:].expand(B,H1,W1).reshape(-1)
-            y1 = y1[None,:,None].expand(B,H1,W1).reshape(-1)
-            b1 = b1[:,None,None].expand(B,H1,W1).reshape(-1)
+            x1 = x1[None, None, :].expand(B, H1, W1).reshape(-1)
+            y1 = y1[None, :, None].expand(B, H1, W1).reshape(-1)
+            b1 = b1[:, None, None].expand(B, H1, W1).reshape(-1)
             shape = (B, H1, W1)
         else:
             # randomly spread
-            n = (H - 2*self.border) * (W - 2*self.border) // step**2
-            x1 = torch.randint(self.border, W-self.border, (n,), device=dev)
-            y1 = torch.randint(self.border, H-self.border, (n,), device=dev)
-            x1 = x1[None,:].expand(B,n).reshape(-1)
-            y1 = y1[None,:].expand(B,n).reshape(-1)
-            b1 = b1[:,None].expand(B,n).reshape(-1)
+            n = (H - 2 * self.border) * (W - 2 * self.border) // step**2
+            x1 = torch.randint(self.border, W - self.border, (n,), device=dev)
+            y1 = torch.randint(self.border, H - self.border, (n,), device=dev)
+            x1 = x1[None, :].expand(B, n).reshape(-1)
+            y1 = y1[None, :].expand(B, n).reshape(-1)
+            b1 = b1[:, None].expand(B, n).reshape(-1)
             shape = (B, n)
         return b1, y1, x1, shape
 
@@ -323,41 +358,41 @@ class NghSampler2 (nn.Module):
         assert two == 2
         feat1, conf1 = feats[0], (confs[0] if confs else None)
         feat2, conf2 = feats[1], (confs[1] if confs else None)
-        
+
         # positions in the first image
         b1, y1, x1, shape = self.gen_grid(self.sub_q, aflow)
 
         # sample features from first image
         feat1 = feat1[b1, :, y1, x1]
         qconf = conf1[b1, :, y1, x1].view(shape) if confs else None
-        
-        #sample GT from second image
+
+        # sample GT from second image
         b2 = b1
         xy2 = (aflow[b1, :, y1, x1] + 0.5).long().t()
         mask = (0 <= xy2[0]) * (0 <= xy2[1]) * (xy2[0] < W) * (xy2[1] < H)
         mask = mask.view(shape)
-        
+
         def clamp(xy):
-            torch.clamp(xy[0], 0, W-1, out=xy[0])
-            torch.clamp(xy[1], 0, H-1, out=xy[1])
+            torch.clamp(xy[0], 0, W - 1, out=xy[0])
+            torch.clamp(xy[1], 0, H - 1, out=xy[1])
             return xy
-        
+
         # compute positive scores
-        xy2p = clamp(xy2[:,None,:] + self.pos_offsets[:,:,None])
-        pscores = (feat1[None,:,:] * feat2[b2, :, xy2p[1], xy2p[0]]).sum(dim=-1).t()
-#        xy1p = clamp(torch.stack((x1,y1))[:,None,:] + self.pos_offsets[:,:,None])
-#        grid = FullSampler._aflow_to_grid(aflow)
-#        feat2p = F.grid_sample(feat2, grid, mode='bilinear', padding_mode='border')
-#        pscores = (feat1[None,:,:] * feat2p[b1,:,xy1p[1], xy1p[0]]).sum(dim=-1).t()
+        xy2p = clamp(xy2[:, None, :] + self.pos_offsets[:, :, None])
+        pscores = (feat1[None, :, :] * feat2[b2, :, xy2p[1], xy2p[0]]).sum(dim=-1).t()
+        #        xy1p = clamp(torch.stack((x1,y1))[:,None,:] + self.pos_offsets[:,:,None])
+        #        grid = FullSampler._aflow_to_grid(aflow)
+        #        feat2p = F.grid_sample(feat2, grid, mode='bilinear', padding_mode='border')
+        #        pscores = (feat1[None,:,:] * feat2p[b1,:,xy1p[1], xy1p[0]]).sum(dim=-1).t()
         if self.maxpool_pos:
             pscores, pos = pscores.max(dim=1, keepdim=True)
-            if confs: 
-                sel = clamp(xy2 + self.pos_offsets[:,pos.view(-1)])
-                qconf = (qconf + conf2[b2, :, sel[1], sel[0]].view(shape))/2
-        
+            if confs:
+                sel = clamp(xy2 + self.pos_offsets[:, pos.view(-1)])
+                qconf = (qconf + conf2[b2, :, sel[1], sel[0]].view(shape)) / 2
+
         # compute negative scores
-        xy2n = clamp(xy2[:,None,:] + self.neg_offsets[:,:,None])
-        nscores = (feat1[None,:,:] * feat2[b2, :, xy2n[1], xy2n[0]]).sum(dim=-1).t()
+        xy2n = clamp(xy2[:, None, :] + self.neg_offsets[:, :, None])
+        nscores = (feat1[None, :, :] * feat2[b2, :, xy2n[1], xy2n[0]]).sum(dim=-1).t()
 
         if self.sub_d_neg:
             # add distractors from a grid
@@ -365,26 +400,18 @@ class NghSampler2 (nn.Module):
             distractors = feat2[b3, :, y3, x3]
             dscores = torch.matmul(feat1, distractors.t())
             del distractors
-            
+
             # remove scores that corresponds to positives or nulls
-            dis2 = (x3 - xy2[0][:,None])**2 + (y3 - xy2[1][:,None])**2
-            dis2 += (b3 != b2[:,None]).long() * self.neg_d**2
+            dis2 = (x3 - xy2[0][:, None]) ** 2 + (y3 - xy2[1][:, None]) ** 2
+            dis2 += (b3 != b2[:, None]).long() * self.neg_d**2
             dscores[dis2 < self.neg_d**2] = 0
-            
+
             scores = torch.cat((pscores, nscores, dscores), dim=1)
         else:
             # concat everything
             scores = torch.cat((pscores, nscores), dim=1)
 
         gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
-        gt[:, :pscores.shape[1]] = 1
+        gt[:, : pscores.shape[1]] = 1
 
         return scores, gt, mask, qconf
-
-
-
-
-
-
-
-
diff --git a/third_party/r2d2/tools/common.py b/third_party/r2d2/tools/common.py
index a7875ddd714b1d08efb0d1369c3a856490796288..be5137c60e3fb71cbbf180d0058de20a508ff140 100644
--- a/third_party/r2d2/tools/common.py
+++ b/third_party/r2d2/tools/common.py
@@ -2,7 +2,7 @@
 # CC BY-NC-SA 3.0
 # Available only for non-commercial use
 
-import os, pdb#, shutil
+import os, pdb  # , shutil
 import numpy as np
 import torch
 
@@ -12,8 +12,7 @@ def mkdir_for(file_path):
 
 
 def model_size(model):
-    ''' Computes the number of parameters of the model 
-    '''
+    """Computes the number of parameters of the model"""
     size = 0
     for weights in model.state_dict().values():
         size += np.prod(weights.shape)
@@ -24,18 +23,19 @@ def torch_set_gpu(gpus):
     if type(gpus) is int:
         gpus = [gpus]
 
-    cuda = all(gpu>=0 for gpu in gpus)
+    cuda = all(gpu >= 0 for gpu in gpus)
 
     if cuda:
-        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in gpus])
+        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in gpus])
         assert cuda and torch.cuda.is_available(), "%s has GPUs %s unavailable" % (
-            os.environ['HOSTNAME'],os.environ['CUDA_VISIBLE_DEVICES'])
-        torch.backends.cudnn.benchmark = True # speed-up cudnn
-        torch.backends.cudnn.fastest = True # even more speed-up?
-        print( 'Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES'] )
+            os.environ["HOSTNAME"],
+            os.environ["CUDA_VISIBLE_DEVICES"],
+        )
+        torch.backends.cudnn.benchmark = True  # speed-up cudnn
+        torch.backends.cudnn.fastest = True  # even more speed-up?
+        print("Launching on GPUs " + os.environ["CUDA_VISIBLE_DEVICES"])
 
     else:
-        print( 'Launching on CPU' )
+        print("Launching on CPU")
 
     return cuda
-
diff --git a/third_party/r2d2/tools/dataloader.py b/third_party/r2d2/tools/dataloader.py
index f6d9fff5f8dfb8d9d3b243a57555779de33d0818..a0fc97d8085c1e7c5fc5c14cc4e0818bd343595f 100644
--- a/third_party/r2d2/tools/dataloader.py
+++ b/third_party/r2d2/tools/dataloader.py
@@ -14,99 +14,113 @@ from tools.transforms_tools import persp_apply
 
 
 RGB_mean = [0.485, 0.456, 0.406]
-RGB_std  = [0.229, 0.224, 0.225]
+RGB_std = [0.229, 0.224, 0.225]
 
 norm_RGB = tvf.Compose([tvf.ToTensor(), tvf.Normalize(mean=RGB_mean, std=RGB_std)])
 
 
 class PairLoader:
-    """ On-the-fly jittering of pairs of image with dense pixel ground-truth correspondences.
-    
+    """On-the-fly jittering of pairs of image with dense pixel ground-truth correspondences.
+
     crop:   random crop applied to both images
     scale:  random scaling applied to img2
     distort: random ditorsion applied to img2
-    
+
     self[idx] returns a dictionary with keys: img1, img2, aflow, mask
      - img1: cropped original
      - img2: distorted cropped original
      - aflow: 'absolute' optical flow = (x,y) position of each pixel from img1 in img2
      - mask: (binary image) valid pixels of img1
     """
-    def __init__(self, dataset, crop='', scale='', distort='', norm = norm_RGB, 
-                       what = 'aflow mask', idx_as_rng_seed = False):
-        assert hasattr(dataset, 'npairs')
-        assert hasattr(dataset, 'get_pair')
+
+    def __init__(
+        self,
+        dataset,
+        crop="",
+        scale="",
+        distort="",
+        norm=norm_RGB,
+        what="aflow mask",
+        idx_as_rng_seed=False,
+    ):
+        assert hasattr(dataset, "npairs")
+        assert hasattr(dataset, "get_pair")
         self.dataset = dataset
         self.distort = instanciate_transformation(distort)
         self.crop = instanciate_transformation(crop)
         self.norm = instanciate_transformation(norm)
         self.scale = instanciate_transformation(scale)
-        self.idx_as_rng_seed = idx_as_rng_seed # to remove randomness
+        self.idx_as_rng_seed = idx_as_rng_seed  # to remove randomness
         self.what = what.split() if isinstance(what, str) else what
-        self.n_samples = 5 # number of random trials per image
+        self.n_samples = 5  # number of random trials per image
 
     def __len__(self):
-        assert len(self.dataset) == self.dataset.npairs, pdb.set_trace() # and not nimg
+        assert len(self.dataset) == self.dataset.npairs, pdb.set_trace()  # and not nimg
         return len(self.dataset)
 
     def __repr__(self):
-        fmt_str = 'PairLoader\n'
+        fmt_str = "PairLoader\n"
         fmt_str += repr(self.dataset)
-        fmt_str += '  npairs: %d\n' % self.dataset.npairs
-        short_repr = lambda s: repr(s).strip().replace('\n',', ')[14:-1].replace('    ',' ')
-        fmt_str += '  Distort: %s\n' % short_repr(self.distort)
-        fmt_str += '  Crop: %s\n' % short_repr(self.crop)
-        fmt_str += '  Norm: %s\n' % short_repr(self.norm)
+        fmt_str += "  npairs: %d\n" % self.dataset.npairs
+        short_repr = (
+            lambda s: repr(s).strip().replace("\n", ", ")[14:-1].replace("    ", " ")
+        )
+        fmt_str += "  Distort: %s\n" % short_repr(self.distort)
+        fmt_str += "  Crop: %s\n" % short_repr(self.crop)
+        fmt_str += "  Norm: %s\n" % short_repr(self.norm)
         return fmt_str
 
     def __getitem__(self, i):
-        #from time import time as now; t0 = now()
+        # from time import time as now; t0 = now()
         if self.idx_as_rng_seed:
             import random
+
             random.seed(i)
             np.random.seed(i)
 
         # Retrieve an image pair and their absolute flow
         img_a, img_b, metadata = self.dataset.get_pair(i, self.what)
-        
-        # aflow contains pixel coordinates indicating where each 
+
+        # aflow contains pixel coordinates indicating where each
         # pixel from the left image ended up in the right image
         # as (x,y) pairs, but its shape is (H,W,2)
-        aflow = np.float32(metadata['aflow'])
-        mask = metadata.get('mask', np.ones(aflow.shape[:2],np.uint8))
+        aflow = np.float32(metadata["aflow"])
+        mask = metadata.get("mask", np.ones(aflow.shape[:2], np.uint8))
 
         # apply transformations to the second image
-        img_b = {'img': img_b, 'persp':(1,0,0,0,1,0,0,0)}
+        img_b = {"img": img_b, "persp": (1, 0, 0, 0, 1, 0, 0, 0)}
         if self.scale:
             img_b = self.scale(img_b)
         if self.distort:
             img_b = self.distort(img_b)
-        
+
         # apply the same transformation to the flow
-        aflow[:] = persp_apply(img_b['persp'], aflow.reshape(-1,2)).reshape(aflow.shape)
+        aflow[:] = persp_apply(img_b["persp"], aflow.reshape(-1, 2)).reshape(
+            aflow.shape
+        )
         corres = None
-        if 'corres' in metadata:
-            corres = np.float32(metadata['corres'])
-            corres[:,1] = persp_apply(img_b['persp'], corres[:,1])
-        
+        if "corres" in metadata:
+            corres = np.float32(metadata["corres"])
+            corres[:, 1] = persp_apply(img_b["persp"], corres[:, 1])
+
         # apply the same transformation to the homography
         homography = None
-        if 'homography' in metadata:
-            homography = np.float32(metadata['homography'])
+        if "homography" in metadata:
+            homography = np.float32(metadata["homography"])
             # p_b = homography * p_a
-            persp = np.float32(img_b['persp']+(1,)).reshape(3,3)
+            persp = np.float32(img_b["persp"] + (1,)).reshape(3, 3)
             homography = persp @ homography
 
         # determine crop size
-        img_b = img_b['img']
-        crop_size = self.crop({'imsize':(10000,10000)})['imsize']
+        img_b = img_b["img"]
+        crop_size = self.crop({"imsize": (10000, 10000)})["imsize"]
         output_size_a = min(img_a.size, crop_size)
         output_size_b = min(img_b.size, crop_size)
         img_a = np.array(img_a)
         img_b = np.array(img_b)
 
-        ah,aw,p1 = img_a.shape
-        bh,bw,p2 = img_b.shape
+        ah, aw, p1 = img_a.shape
+        bh, bw, p2 = img_b.shape
         assert p1 == 3
         assert p2 == 3
         assert aflow.shape == (ah, aw, 2)
@@ -114,68 +128,82 @@ class PairLoader:
 
         # Let's start by computing the scale of the
         # optical flow and applying a median filter:
-        dx = np.gradient(aflow[:,:,0])
-        dy = np.gradient(aflow[:,:,1])
-        scale = np.sqrt(np.clip(np.abs(dx[1]*dy[0] - dx[0]*dy[1]), 1e-16, 1e16))
+        dx = np.gradient(aflow[:, :, 0])
+        dy = np.gradient(aflow[:, :, 1])
+        scale = np.sqrt(np.clip(np.abs(dx[1] * dy[0] - dx[0] * dy[1]), 1e-16, 1e16))
 
-        accu2 = np.zeros((16,16), bool)
+        accu2 = np.zeros((16, 16), bool)
         Q = lambda x, w: np.int32(16 * (x - w.start) / (w.stop - w.start))
-        
+
         def window1(x, size, w):
             l = x - int(0.5 + size / 2)
             r = l + int(0.5 + size)
-            if l < 0: l,r = (0, r - l)
-            if r > w: l,r = (l + w - r, w)
-            if l < 0: l,r = 0,w # larger than width
-            return slice(l,r)
+            if l < 0:
+                l, r = (0, r - l)
+            if r > w:
+                l, r = (l + w - r, w)
+            if l < 0:
+                l, r = 0, w  # larger than width
+            return slice(l, r)
+
         def window(cx, cy, win_size, scale, img_shape):
-            return (window1(cy, win_size[1]*scale, img_shape[0]), 
-                    window1(cx, win_size[0]*scale, img_shape[1]))
+            return (
+                window1(cy, win_size[1] * scale, img_shape[0]),
+                window1(cx, win_size[0] * scale, img_shape[1]),
+            )
 
         n_valid_pixel = mask.sum()
         sample_w = mask / (1e-16 + n_valid_pixel)
+
         def sample_valid_pixel():
             n = np.random.choice(sample_w.size, p=sample_w.ravel())
             y, x = np.unravel_index(n, sample_w.shape)
             return x, y
-        
+
         # Find suitable left and right windows
-        trials = 0 # take the best out of few trials
+        trials = 0  # take the best out of few trials
         best = -np.inf, None
-        for _ in range(50*self.n_samples):
-            if trials >= self.n_samples: break # finished!
+        for _ in range(50 * self.n_samples):
+            if trials >= self.n_samples:
+                break  # finished!
 
             # pick a random valid point from the first image
-            if n_valid_pixel == 0: break
+            if n_valid_pixel == 0:
+                break
             c1x, c1y = sample_valid_pixel()
-            
+
             # Find in which position the center of the left
             # window ended up being placed in the right image
             c2x, c2y = (aflow[c1y, c1x] + 0.5).astype(np.int32)
-            if not(0 <= c2x < bw and 0 <= c2y < bh): continue
+            if not (0 <= c2x < bw and 0 <= c2y < bh):
+                continue
 
             # Get the flow scale
             sigma = scale[c1y, c1x]
 
             # Determine sampling windows
-            if 0.2 < sigma < 1: 
-                win1 = window(c1x, c1y, output_size_a, 1/sigma, img_a.shape)
+            if 0.2 < sigma < 1:
+                win1 = window(c1x, c1y, output_size_a, 1 / sigma, img_a.shape)
                 win2 = window(c2x, c2y, output_size_b, 1, img_b.shape)
             elif 1 <= sigma < 5:
                 win1 = window(c1x, c1y, output_size_a, 1, img_a.shape)
                 win2 = window(c2x, c2y, output_size_b, sigma, img_b.shape)
             else:
-                continue # bad scale
+                continue  # bad scale
 
             # compute a score based on the flow
-            x2,y2 = aflow[win1].reshape(-1, 2).T.astype(np.int32)
+            x2, y2 = aflow[win1].reshape(-1, 2).T.astype(np.int32)
             # Check the proportion of valid flow vectors
-            valid = (win2[1].start <= x2) & (x2 < win2[1].stop) \
-                  & (win2[0].start <= y2) & (y2 < win2[0].stop)
+            valid = (
+                (win2[1].start <= x2)
+                & (x2 < win2[1].stop)
+                & (win2[0].start <= y2)
+                & (y2 < win2[0].stop)
+            )
             score1 = (valid * mask[win1].ravel()).mean()
             # check the coverage of the second window
             accu2[:] = False
-            accu2[Q(y2[valid],win2[0]), Q(x2[valid],win2[1])] = True
+            accu2[Q(y2[valid], win2[0]), Q(x2[valid], win2[1])] = True
             score2 = accu2.mean()
             # Check how many hits we got
             score = min(score1, score2)
@@ -183,12 +211,12 @@ class PairLoader:
             trials += 1
             if score > best[0]:
                 best = score, win1, win2
-        
-        if None in best: # counldn't find a good window
-            img_a = np.zeros(output_size_a[::-1]+(3,), dtype=np.uint8)
-            img_b = np.zeros(output_size_b[::-1]+(3,), dtype=np.uint8)
-            aflow = np.nan * np.ones((2,)+output_size_a[::-1], dtype=np.float32)
-            homography = np.nan * np.ones((3,3), dtype=np.float32)
+
+        if None in best:  # counldn't find a good window
+            img_a = np.zeros(output_size_a[::-1] + (3,), dtype=np.uint8)
+            img_b = np.zeros(output_size_b[::-1] + (3,), dtype=np.uint8)
+            aflow = np.nan * np.ones((2,) + output_size_a[::-1], dtype=np.float32)
+            homography = np.nan * np.ones((3, 3), dtype=np.float32)
 
         else:
             win1, win2 = best[1:]
@@ -196,92 +224,103 @@ class PairLoader:
             img_b = img_b[win2]
             aflow = aflow[win1] - np.float32([[[win2[1].start, win2[0].start]]])
             mask = mask[win1]
-            aflow[~mask.view(bool)] = np.nan # mask bad pixels!
-            aflow = aflow.transpose(2,0,1) # --> (2,H,W)
-            
+            aflow[~mask.view(bool)] = np.nan  # mask bad pixels!
+            aflow = aflow.transpose(2, 0, 1)  # --> (2,H,W)
+
             if corres is not None:
-                corres[:,0] -= (win1[1].start, win1[0].start)
-                corres[:,1] -= (win2[1].start, win2[0].start)
-            
+                corres[:, 0] -= (win1[1].start, win1[0].start)
+                corres[:, 1] -= (win2[1].start, win2[0].start)
+
             if homography is not None:
                 trans1 = np.eye(3, dtype=np.float32)
-                trans1[:2,2] = (win1[1].start, win1[0].start)
+                trans1[:2, 2] = (win1[1].start, win1[0].start)
                 trans2 = np.eye(3, dtype=np.float32)
-                trans2[:2,2] = (-win2[1].start, -win2[0].start)
+                trans2[:2, 2] = (-win2[1].start, -win2[0].start)
                 homography = trans2 @ homography @ trans1
-                homography /= homography[2,2]
-            
+                homography /= homography[2, 2]
+
             # rescale if necessary
             if img_a.shape[:2][::-1] != output_size_a:
-                sx, sy = (np.float32(output_size_a)-1)/(np.float32(img_a.shape[:2][::-1])-1)
-                img_a = np.asarray(Image.fromarray(img_a).resize(output_size_a, Image.ANTIALIAS))
-                mask = np.asarray(Image.fromarray(mask).resize(output_size_a, Image.NEAREST))
+                sx, sy = (np.float32(output_size_a) - 1) / (
+                    np.float32(img_a.shape[:2][::-1]) - 1
+                )
+                img_a = np.asarray(
+                    Image.fromarray(img_a).resize(output_size_a, Image.ANTIALIAS)
+                )
+                mask = np.asarray(
+                    Image.fromarray(mask).resize(output_size_a, Image.NEAREST)
+                )
                 afx = Image.fromarray(aflow[0]).resize(output_size_a, Image.NEAREST)
                 afy = Image.fromarray(aflow[1]).resize(output_size_a, Image.NEAREST)
                 aflow = np.stack((np.float32(afx), np.float32(afy)))
-                
+
                 if corres is not None:
-                    corres[:,0] *= (sx, sy)
-                
+                    corres[:, 0] *= (sx, sy)
+
                 if homography is not None:
-                    homography = homography @ np.diag(np.float32([1/sx,1/sy,1]))
-                    homography /= homography[2,2]
+                    homography = homography @ np.diag(np.float32([1 / sx, 1 / sy, 1]))
+                    homography /= homography[2, 2]
 
             if img_b.shape[:2][::-1] != output_size_b:
-                sx, sy = (np.float32(output_size_b)-1)/(np.float32(img_b.shape[:2][::-1])-1)
-                img_b = np.asarray(Image.fromarray(img_b).resize(output_size_b, Image.ANTIALIAS))
+                sx, sy = (np.float32(output_size_b) - 1) / (
+                    np.float32(img_b.shape[:2][::-1]) - 1
+                )
+                img_b = np.asarray(
+                    Image.fromarray(img_b).resize(output_size_b, Image.ANTIALIAS)
+                )
                 aflow *= [[[sx]], [[sy]]]
-                
+
                 if corres is not None:
-                    corres[:,1] *= (sx, sy)
-                
+                    corres[:, 1] *= (sx, sy)
+
                 if homography is not None:
-                    homography = np.diag(np.float32([sx,sy,1])) @ homography
-                    homography /= homography[2,2]
-    
+                    homography = np.diag(np.float32([sx, sy, 1])) @ homography
+                    homography /= homography[2, 2]
+
         assert aflow.dtype == np.float32, pdb.set_trace()
         assert homography is None or homography.dtype == np.float32, pdb.set_trace()
-        if 'flow' in self.what:
+        if "flow" in self.what:
             H, W = img_a.shape[:2]
             mgrid = np.mgrid[0:H, 0:W][::-1].astype(np.float32)
             flow = aflow - mgrid
-        
+
         result = dict(img1=self.norm(img_a), img2=self.norm(img_b))
         for what in self.what:
-            try: result[what] = eval(what)
-            except NameError: pass
+            try:
+                result[what] = eval(what)
+            except NameError:
+                pass
         return result
 
 
+def threaded_loader(loader, iscuda, threads, batch_size=1, shuffle=True):
+    """Get a data loader, given the dataset and some parameters.
 
-def threaded_loader( loader, iscuda, threads, batch_size=1, shuffle=True):
-    """ Get a data loader, given the dataset and some parameters.
-    
     Parameters
     ----------
     loader : object[i] returns the i-th training example.
-    
+
     iscuda : bool
-        
+
     batch_size : int
-    
+
     threads : int
-    
+
     shuffle : int
-    
+
     Returns
     -------
         a multi-threaded pytorch loader.
     """
     return torch.utils.data.DataLoader(
         loader,
-        batch_size = batch_size,
-        shuffle = shuffle,
-        sampler = None,
-        num_workers = threads,
-        pin_memory = iscuda,
-        collate_fn=collate)
-
+        batch_size=batch_size,
+        shuffle=shuffle,
+        sampler=None,
+        num_workers=threads,
+        pin_memory=iscuda,
+        collate_fn=collate,
+    )
 
 
 def collate(batch, _use_shared_memory=True):
@@ -289,6 +328,7 @@ def collate(batch, _use_shared_memory=True):
     Copied from https://github.com/pytorch in torch/utils/data/_utils/collate.py
     """
     import re
+
     error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
     elem_type = type(batch[0])
     if isinstance(batch[0], torch.Tensor):
@@ -300,12 +340,15 @@ def collate(batch, _use_shared_memory=True):
             storage = batch[0].storage()._new_shared(numel)
             out = batch[0].new(storage)
         return torch.stack(batch, 0, out=out)
-    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
-            and elem_type.__name__ != 'string_':
+    elif (
+        elem_type.__module__ == "numpy"
+        and elem_type.__name__ != "str_"
+        and elem_type.__name__ != "string_"
+    ):
         elem = batch[0]
-        assert elem_type.__name__ == 'ndarray'
+        assert elem_type.__name__ == "ndarray"
         # array of string classes and object
-        if re.search('[SaUO]', elem.dtype.str) is not None:
+        if re.search("[SaUO]", elem.dtype.str) is not None:
             raise TypeError(error_msg.format(elem.dtype))
         batch = [torch.from_numpy(b) for b in batch]
         try:
@@ -322,46 +365,52 @@ def collate(batch, _use_shared_memory=True):
         return batch
     elif isinstance(batch[0], dict):
         return {key: collate([d[key] for d in batch]) for key in batch[0]}
-    elif isinstance(batch[0], (tuple,list)):
+    elif isinstance(batch[0], (tuple, list)):
         transposed = zip(*batch)
         return [collate(samples) for samples in transposed]
 
     raise TypeError((error_msg.format(type(batch[0]))))
 
 
-
 def tensor2img(tensor, model=None):
-    """ convert back a torch/numpy tensor to a PIL Image
-        by undoing the ToTensor() and Normalize() transforms.
+    """convert back a torch/numpy tensor to a PIL Image
+    by undoing the ToTensor() and Normalize() transforms.
     """
     mean = norm_RGB.transforms[1].mean
-    std =  norm_RGB.transforms[1].std
+    std = norm_RGB.transforms[1].std
     if isinstance(tensor, torch.Tensor):
         tensor = tensor.detach().cpu().numpy()
-    
-    res = np.uint8(np.clip(255*((tensor.transpose(1,2,0) * std) + mean), 0, 255))
+
+    res = np.uint8(np.clip(255 * ((tensor.transpose(1, 2, 0) * std) + mean), 0, 255))
     from PIL import Image
+
     return Image.fromarray(res)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser("Tool to debug/visualize the data loader")
-    parser.add_argument("dataloader", type=str, help="command to create the data loader")
+    parser.add_argument(
+        "dataloader", type=str, help="command to create the data loader"
+    )
     args = parser.parse_args()
 
     from datasets import *
-    auto_pairs = lambda db: SyntheticPairDataset(db,
-        'RandomScale(256,1024,can_upscale=True)', 
-        'RandomTilting(0.5), PixelNoise(25)')
-        
+
+    auto_pairs = lambda db: SyntheticPairDataset(
+        db,
+        "RandomScale(256,1024,can_upscale=True)",
+        "RandomTilting(0.5), PixelNoise(25)",
+    )
+
     loader = eval(args.dataloader)
     print("Data loader =", loader)
 
     from tools.viz import show_flow
+
     for data in loader:
-        aflow = data['aflow']
+        aflow = data["aflow"]
         H, W = aflow.shape[-2:]
-        flow = (aflow - np.mgrid[:H, :W][::-1]).transpose(1,2,0)
-        show_flow(tensor2img(data['img1']), tensor2img(data['img2']), flow)
-
+        flow = (aflow - np.mgrid[:H, :W][::-1]).transpose(1, 2, 0)
+        show_flow(tensor2img(data["img1"]), tensor2img(data["img2"]), flow)
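
The threaded_loader helper above wraps a dataset in a multi-worker PyTorch DataLoader using the custom collate function. A minimal usage sketch, assuming the script runs from the r2d2 root so tools.dataloader is importable; my_dataset is a hypothetical stand-in for any indexable dataset whose items are dicts of numpy arrays (for instance a PairLoader):

    from tools.dataloader import threaded_loader

    loader = threaded_loader(
        my_dataset,       # hypothetical dataset: my_dataset[i] is one training example (a dict)
        iscuda=True,      # pins host memory when training on GPU
        threads=8,        # number of worker processes
        batch_size=4,
        shuffle=True,
    )
    for batch in loader:
        # tensors are batched by the custom collate() defined above
        print(batch["img1"].shape)  # e.g. (4, 3, H, W) for PairLoader samples
        break
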
diff --git a/third_party/r2d2/tools/trainer.py b/third_party/r2d2/tools/trainer.py
index 9f893395efdeb8e13cc00539325572553168c5ce..d71ef137f556b7709ebed37a6ea4c865e5ab6c37 100644
--- a/third_party/r2d2/tools/trainer.py
+++ b/third_party/r2d2/tools/trainer.py
@@ -10,15 +10,16 @@ import torch
 import torch.nn as nn
 
 
-class Trainer (nn.Module):
-    """ Helper class to train a deep network.
+class Trainer(nn.Module):
+    """Helper class to train a deep network.
         Overload this class `forward_backward` for your actual needs.
-    
-    Usage: 
+
+    Usage:
         train = Trainer(net, loader, loss, optimizer)
         for epoch in range(n_epochs):
             train()
     """
+
     def __init__(self, net, loader, loss, optimizer):
         nn.Module.__init__(self)
         self.net = net
@@ -27,50 +28,48 @@ class Trainer (nn.Module):
         self.optimizer = optimizer
 
     def iscuda(self):
-        return next(self.net.parameters()).device != torch.device('cpu')
+        return next(self.net.parameters()).device != torch.device("cpu")
 
     def todevice(self, x):
         if isinstance(x, dict):
-            return {k:self.todevice(v) for k,v in x.items()}
-        if isinstance(x, (tuple,list)):
-            return [self.todevice(v)  for v in x]
-        
-        if self.iscuda(): 
+            return {k: self.todevice(v) for k, v in x.items()}
+        if isinstance(x, (tuple, list)):
+            return [self.todevice(v) for v in x]
+
+        if self.iscuda():
             return x.contiguous().cuda(non_blocking=True)
         else:
             return x.cpu()
 
     def __call__(self):
         self.net.train()
-        
+
         stats = defaultdict(list)
-        
-        for iter,inputs in enumerate(tqdm(self.loader)):
+
+        for iter, inputs in enumerate(tqdm(self.loader)):
             inputs = self.todevice(inputs)
-            
+
             # compute gradient and do model update
             self.optimizer.zero_grad()
-            
+
             loss, details = self.forward_backward(inputs)
             if torch.isnan(loss):
-                raise RuntimeError('Loss is NaN')
-            
+                raise RuntimeError("Loss is NaN")
+
             self.optimizer.step()
-            
+
             for key, val in details.items():
-                stats[key].append( val )
-        
+                stats[key].append(val)
+
         print(" Summary of losses during this epoch:")
         mean = lambda lis: sum(lis) / len(lis)
         for loss_name, vals in stats.items():
-            N = 1 + len(vals)//10
-            print(f"  - {loss_name:20}:", end='')
-            print(f" {mean(vals[:N]):.3f} --> {mean(vals[-N:]):.3f} (avg: {mean(vals):.3f})")
-        return mean(stats['loss']) # return average loss
+            N = 1 + len(vals) // 10
+            print(f"  - {loss_name:20}:", end="")
+            print(
+                f" {mean(vals[:N]):.3f} --> {mean(vals[-N:]):.3f} (avg: {mean(vals):.3f})"
+            )
+        return mean(stats["loss"])  # return average loss
 
     def forward_backward(self, inputs):
         raise NotImplementedError()
-
-
-
-
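
The Trainer base class runs one epoch per call and leaves forward_backward to subclasses. A minimal sketch of such a subclass, under the assumption that the network maps img1 to a prediction compared against img2 with an L2 loss; both the loss and the input keys are illustrative, not the actual r2d2 loss:

    import torch
    import torch.nn.functional as F
    from tools import trainer

    class L2Trainer(trainer.Trainer):
        def forward_backward(self, inputs):
            # inputs has already been moved to the right device by __call__
            pred = self.net(inputs["img1"])
            loss = F.mse_loss(pred, inputs["img2"])
            if torch.is_grad_enabled():
                loss.backward()  # __call__ then runs optimizer.step()
            return loss, dict(loss=loss.item())  # details are averaged per epoch

    # train = L2Trainer(net, loader, loss_func, optimizer)
    # for epoch in range(n_epochs):
    #     train()
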
diff --git a/third_party/r2d2/tools/transforms.py b/third_party/r2d2/tools/transforms.py
index 87275276310191a7da3fc14f606345d9616208e0..604a7c2a3ec6da955c1e85b7505103c694232458 100644
--- a/third_party/r2d2/tools/transforms.py
+++ b/third_party/r2d2/tools/transforms.py
@@ -11,23 +11,23 @@ from math import ceil
 
 from . import transforms_tools as F
 
-'''
+"""
 Example command to try out some transformation chain:
 
 python -m tools.transforms --trfs "Scale(384), ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1), RandomRotation(10), RandomTilting(0.5, 'all'), RandomScale(240,320), RandomCrop(224)"
-'''
+"""
 
 
 def instanciate_transformation(cmd_line):
-    ''' Create a sequence of transformations.
-    
+    """Create a sequence of transformations.
+
     cmd_line: (str)
         Comma-separated list of transformations.
         Ex: "Rotate(10), Scale(256)"
-    '''
+    """
     if not isinstance(cmd_line, str):
-        return cmd_line # already instanciated
-    
+        return cmd_line  # already instantiated
+
     cmd_line = "tvf.Compose([%s])" % cmd_line
     try:
         return eval(cmd_line)
@@ -35,19 +35,26 @@ def instanciate_transformation(cmd_line):
         print("Cannot interpret this transform list: %s\nReason: %s" % (cmd_line, e))
 
 
-class Scale (object):
-    """ Rescale the input PIL.Image to a given size.
+class Scale(object):
+    """Rescale the input PIL.Image to a given size.
     Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
-    
+
     The smallest dimension of the resulting image will be = size.
-    
+
     if largest == True: same behaviour for the largest dimension.
-    
+
     if not can_upscale: don't upscale
     if not can_downscale: don't downscale
     """
-    def __init__(self, size, interpolation=Image.BILINEAR, largest=False, 
-                 can_upscale=True, can_downscale=True):
+
+    def __init__(
+        self,
+        size,
+        interpolation=Image.BILINEAR,
+        largest=False,
+        can_upscale=True,
+        can_downscale=True,
+    ):
         assert isinstance(size, int) or (len(size) == 2)
         self.size = size
         self.interpolation = interpolation
@@ -57,15 +64,18 @@ class Scale (object):
 
     def __repr__(self):
         fmt_str = "RandomScale(%s" % str(self.size)
-        if self.largest: fmt_str += ', largest=True'
-        if not self.can_upscale: fmt_str += ', can_upscale=False'
-        if not self.can_downscale: fmt_str += ', can_downscale=False'
-        return fmt_str+')'
+        if self.largest:
+            fmt_str += ", largest=True"
+        if not self.can_upscale:
+            fmt_str += ", can_upscale=False"
+        if not self.can_downscale:
+            fmt_str += ", can_downscale=False"
+        return fmt_str + ")"
 
     def get_params(self, imsize):
-        w,h = imsize
+        w, h = imsize
         if isinstance(self.size, int):
-            cmp = lambda a,b: (a>=b) if self.largest else (a<=b)
+            cmp = lambda a, b: (a >= b) if self.largest else (a <= b)
             if (cmp(w, h) and w == self.size) or (cmp(h, w) and h == self.size):
                 ow, oh = w, h
             elif cmp(w, h):
@@ -81,19 +91,22 @@ class Scale (object):
     def __call__(self, inp):
         img = F.grab_img(inp)
         w, h = img.size
-        
+
         size2 = ow, oh = self.get_params(img.size)
-        
+
         if size2 != img.size:
             a1, a2 = img.size, size2
-            if (self.can_upscale and min(a1) < min(a2)) or (self.can_downscale and min(a1) > min(a2)):
+            if (self.can_upscale and min(a1) < min(a2)) or (
+                self.can_downscale and min(a1) > min(a2)
+            ):
                 img = img.resize(size2, self.interpolation)
 
-        return F.update_img_and_labels(inp, img, persp=(ow/w,0,0,0,oh/h,0,0,0))
-
+        return F.update_img_and_labels(
+            inp, img, persp=(ow / w, 0, 0, 0, oh / h, 0, 0, 0)
+        )
 
 
-class RandomScale (Scale):
+class RandomScale(Scale):
     """Rescale the input PIL.Image to a random size.
     Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
 
@@ -108,53 +121,79 @@ class RandomScale (Scale):
             ``PIL.Image.BILINEAR``
     """
 
-    def __init__(self, min_size, max_size, ar=1, 
-                 can_upscale=False, can_downscale=True, interpolation=Image.BILINEAR):
-        Scale.__init__(self, 0, can_upscale=can_upscale, can_downscale=can_downscale, interpolation=interpolation)
-        assert type(min_size) == type(max_size), 'min_size and max_size can only be 2 ints or 2 floats'
-        assert isinstance(min_size, int) and min_size >= 1 or isinstance(min_size, float) and min_size>0
-        assert isinstance(max_size, (int,float)) and min_size <= max_size
+    def __init__(
+        self,
+        min_size,
+        max_size,
+        ar=1,
+        can_upscale=False,
+        can_downscale=True,
+        interpolation=Image.BILINEAR,
+    ):
+        Scale.__init__(
+            self,
+            0,
+            can_upscale=can_upscale,
+            can_downscale=can_downscale,
+            interpolation=interpolation,
+        )
+        assert type(min_size) == type(
+            max_size
+        ), "min_size and max_size can only be 2 ints or 2 floats"
+        assert (
+            isinstance(min_size, int)
+            and min_size >= 1
+            or isinstance(min_size, float)
+            and min_size > 0
+        )
+        assert isinstance(max_size, (int, float)) and min_size <= max_size
         self.min_size = min_size
         self.max_size = max_size
-        if type(ar) in (float,int): ar = (min(1/ar,ar),max(1/ar,ar))
+        if type(ar) in (float, int):
+            ar = (min(1 / ar, ar), max(1 / ar, ar))
         assert 0.2 < ar[0] <= ar[1] < 5
         self.ar = ar
 
     def get_params(self, imsize):
-        w,h = imsize
+        w, h = imsize
         if isinstance(self.min_size, float):
-            min_size = int(self.min_size*min(w,h) + 0.5)
+            min_size = int(self.min_size * min(w, h) + 0.5)
         if isinstance(self.max_size, float):
-            max_size = int(self.max_size*min(w,h) + 0.5)
+            max_size = int(self.max_size * min(w, h) + 0.5)
         if isinstance(self.min_size, int):
             min_size = self.min_size
         if isinstance(self.max_size, int):
             max_size = self.max_size
-        
+
         if not self.can_upscale:
-            max_size = min(max_size,min(w,h))
-        
-        size = int(0.5 + F.rand_log_uniform(min_size,max_size))
-        ar = F.rand_log_uniform(*self.ar) # change of aspect ratio
+            max_size = min(max_size, min(w, h))
+
+        size = int(0.5 + F.rand_log_uniform(min_size, max_size))
+        ar = F.rand_log_uniform(*self.ar)  # change of aspect ratio
 
-        if w < h: # image is taller
+        if w < h:  # image is taller
             ow = size
             oh = int(0.5 + size * h / w / ar)
             if oh < min_size:
-                ow,oh = int(0.5 + ow*float(min_size)/oh),min_size
-        else: # image is wider
+                ow, oh = int(0.5 + ow * float(min_size) / oh), min_size
+        else:  # image is wider
             oh = size
             ow = int(0.5 + size * w / h * ar)
             if ow < min_size:
-                ow,oh = min_size,int(0.5 + oh*float(min_size)/ow)
-                
-        assert ow >= min_size, 'image too small (width=%d < min_size=%d)' % (ow, min_size)
-        assert oh >= min_size, 'image too small (height=%d < min_size=%d)' % (oh, min_size)
+                ow, oh = min_size, int(0.5 + oh * float(min_size) / ow)
+
+        assert ow >= min_size, "image too small (width=%d < min_size=%d)" % (
+            ow,
+            min_size,
+        )
+        assert oh >= min_size, "image too small (height=%d < min_size=%d)" % (
+            oh,
+            min_size,
+        )
         return ow, oh
 
 
-
-class RandomCrop (object):
+class RandomCrop(object):
     """Crop the given PIL Image at a random location.
     Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
 
@@ -182,7 +221,12 @@ class RandomCrop (object):
     def get_params(img, output_size):
         w, h = img.size
         th, tw = output_size
-        assert h >= th and w >= tw, "Image of %dx%d is too small for crop %dx%d" % (w,h,tw,th)
+        assert h >= th and w >= tw, "Image of %dx%d is too small for crop %dx%d" % (
+            w,
+            h,
+            tw,
+            th,
+        )
 
         y = np.random.randint(0, h - th) if h > th else 0
         x = np.random.randint(0, w - tw) if w > tw else 0
@@ -204,12 +248,14 @@ class RandomCrop (object):
                 padl, padt = self.padding[0:2]
 
         i, j, tw, th = self.get_params(img, self.size)
-        img = img.crop((i, j, i+tw, j+th))
-        
-        return F.update_img_and_labels(inp, img, persp=(1,0,padl-i,0,1,padt-j,0,0))
+        img = img.crop((i, j, i + tw, j + th))
 
+        return F.update_img_and_labels(
+            inp, img, persp=(1, 0, padl - i, 0, 1, padt - j, 0, 0)
+        )
 
-class CenterCrop (RandomCrop):
+
+class CenterCrop(RandomCrop):
     """Crops the given PIL Image at the center.
     Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
 
@@ -218,16 +264,16 @@ class CenterCrop (RandomCrop):
             int instead of sequence like (h, w), a square crop (size, size) is
             made.
     """
+
     @staticmethod
     def get_params(img, output_size):
         w, h = img.size
         th, tw = output_size
-        y = int(0.5 +((h - th) / 2.))
-        x = int(0.5 +((w - tw) / 2.))
+        y = int(0.5 + ((h - th) / 2.0))
+        x = int(0.5 + ((w - tw) / 2.0))
         return x, y, tw, th
 
 
-
 class RandomRotation(object):
     """Rescale the input PIL.Image to a random size.
     Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
@@ -247,19 +293,18 @@ class RandomRotation(object):
     def __call__(self, inp):
         img = F.grab_img(inp)
         w, h = img.size
-        
+
         angle = np.random.uniform(-self.degrees, self.degrees)
-        
+
         img = img.rotate(angle, resample=self.interpolation)
         w2, h2 = img.size
 
-        trf = F.translate(-w/2,-h/2)
-        trf = F.persp_mul(trf, F.rotate(-angle * np.pi/180))
-        trf = F.persp_mul(trf, F.translate(w2/2,h2/2))
+        trf = F.translate(-w / 2, -h / 2)
+        trf = F.persp_mul(trf, F.rotate(-angle * np.pi / 180))
+        trf = F.persp_mul(trf, F.translate(w2 / 2, h2 / 2))
         return F.update_img_and_labels(inp, img, persp=trf)
 
 
-
 class RandomTilting(object):
     """Apply a random tilting (left, right, up, down) to the input PIL.Image
     Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
@@ -272,34 +317,34 @@ class RandomTilting(object):
             examples: "all", "left,right", "up-down-right"
     """
 
-    def __init__(self, magnitude, directions='all'):
+    def __init__(self, magnitude, directions="all"):
         self.magnitude = magnitude
-        self.directions = directions.lower().replace(',',' ').replace('-',' ')
+        self.directions = directions.lower().replace(",", " ").replace("-", " ")
 
     def __repr__(self):
-        return "RandomTilt(%g, '%s')" % (self.magnitude,self.directions)
+        return "RandomTilt(%g, '%s')" % (self.magnitude, self.directions)
 
     def __call__(self, inp):
         img = F.grab_img(inp)
         w, h = img.size
 
-        x1,y1,x2,y2 = 0,0,h,w
+        x1, y1, x2, y2 = 0, 0, h, w
         original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)]
 
         max_skew_amount = max(w, h)
         max_skew_amount = int(ceil(max_skew_amount * self.magnitude))
         skew_amount = random.randint(1, max_skew_amount)
 
-        if self.directions == 'all':
-            choices = [0,1,2,3]
+        if self.directions == "all":
+            choices = [0, 1, 2, 3]
         else:
-            dirs = ['left', 'right', 'up', 'down']
+            dirs = ["left", "right", "up", "down"]
             choices = []
             for d in self.directions.split():
                 try:
                     choices.append(dirs.index(d))
                 except:
-                    raise ValueError('Tilting direction %s not recognized' % d)
+                    raise ValueError("Tilting direction %s not recognized" % d)
 
         skew_direction = random.choice(choices)
 
@@ -307,28 +352,36 @@ class RandomTilting(object):
 
         if skew_direction == 0:
             # Left Tilt
-            new_plane = [(y1, x1 - skew_amount),  # Top Left
-                         (y2, x1),                # Top Right
-                         (y2, x2),                # Bottom Right
-                         (y1, x2 + skew_amount)]  # Bottom Left
+            new_plane = [
+                (y1, x1 - skew_amount),  # Top Left
+                (y2, x1),  # Top Right
+                (y2, x2),  # Bottom Right
+                (y1, x2 + skew_amount),
+            ]  # Bottom Left
         elif skew_direction == 1:
             # Right Tilt
-            new_plane = [(y1, x1),                # Top Left
-                         (y2, x1 - skew_amount),  # Top Right
-                         (y2, x2 + skew_amount),  # Bottom Right
-                         (y1, x2)]                # Bottom Left
+            new_plane = [
+                (y1, x1),  # Top Left
+                (y2, x1 - skew_amount),  # Top Right
+                (y2, x2 + skew_amount),  # Bottom Right
+                (y1, x2),
+            ]  # Bottom Left
         elif skew_direction == 2:
             # Forward Tilt
-            new_plane = [(y1 - skew_amount, x1),  # Top Left
-                         (y2 + skew_amount, x1),  # Top Right
-                         (y2, x2),                # Bottom Right
-                         (y1, x2)]                # Bottom Left
+            new_plane = [
+                (y1 - skew_amount, x1),  # Top Left
+                (y2 + skew_amount, x1),  # Top Right
+                (y2, x2),  # Bottom Right
+                (y1, x2),
+            ]  # Bottom Left
         elif skew_direction == 3:
             # Backward Tilt
-            new_plane = [(y1, x1),                # Top Left
-                         (y2, x1),                # Top Right
-                         (y2 + skew_amount, x2),  # Bottom Right
-                         (y1 - skew_amount, x2)]  # Bottom Left
+            new_plane = [
+                (y1, x1),  # Top Left
+                (y2, x1),  # Top Right
+                (y2 + skew_amount, x2),  # Bottom Right
+                (y1 - skew_amount, x2),
+            ]  # Bottom Left
 
         # To calculate the coefficients required by PIL for the perspective skew,
         # see the following Stack Overflow discussion: https://goo.gl/sSgJdj
@@ -343,42 +396,49 @@ class RandomTilting(object):
 
         homography = np.dot(np.linalg.pinv(A), B)
         homography = tuple(np.array(homography).reshape(8))
-        #print(homography)
+        # print(homography)
 
-        img =  img.transform(img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC)
+        img = img.transform(
+            img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC
+        )
 
-        homography = np.linalg.pinv(np.float32(homography+(1,)).reshape(3,3)).ravel()[:8]
+        homography = np.linalg.pinv(
+            np.float32(homography + (1,)).reshape(3, 3)
+        ).ravel()[:8]
         return F.update_img_and_labels(inp, img, persp=tuple(homography))
 
 
-RandomTilt = RandomTilting # redefinition
+RandomTilt = RandomTilting  # alias
 
 
 class Tilt(object):
-    """Apply a known tilting to an image
-    """
+    """Apply a known tilting to an image"""
+
     def __init__(self, *homography):
         assert len(homography) == 8
         self.homography = homography
-    
+
     def __call__(self, inp):
         img = F.grab_img(inp)
         homography = self.homography
-        #print(homography)
-        
-        img =  img.transform(img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC)
-        
-        homography = np.linalg.pinv(np.float32(homography+(1,)).reshape(3,3)).ravel()[:8]
+        # print(homography)
+
+        img = img.transform(
+            img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC
+        )
+
+        homography = np.linalg.pinv(
+            np.float32(homography + (1,)).reshape(3, 3)
+        ).ravel()[:8]
         return F.update_img_and_labels(inp, img, persp=tuple(homography))
 
 
+class StillTransform(object):
+    """Takes and return an image, without changing its shape or geometry."""
 
-class StillTransform (object):
-    """ Takes and return an image, without changing its shape or geometry.
-    """
     def _transform(self, img):
         raise NotImplementedError()
-        
+
     def __call__(self, inp):
         img = F.grab_img(inp)
 
@@ -388,13 +448,12 @@ class StillTransform (object):
         except TypeError:
             pass
 
-        return F.update_img_and_labels(inp, img, persp=(1,0,0,0,1,0,0,0))
+        return F.update_img_and_labels(inp, img, persp=(1, 0, 0, 0, 1, 0, 0, 0))
 
 
+class PixelNoise(StillTransform):
+    """Takes an image, and add random white noise."""
 
-class PixelNoise (StillTransform):
-    """ Takes an image, and add random white noise.
-    """
     def __init__(self, ampl=20):
         StillTransform.__init__(self)
         assert 0 <= ampl < 255
@@ -405,12 +464,13 @@ class PixelNoise (StillTransform):
 
     def _transform(self, img):
         img = np.float32(img)
-        img += np.random.uniform(0.5-self.ampl/2, 0.5+self.ampl/2, size=img.shape)
-        return Image.fromarray(np.uint8(img.clip(0,255)))
-
+        img += np.random.uniform(
+            0.5 - self.ampl / 2, 0.5 + self.ampl / 2, size=img.shape
+        )
+        return Image.fromarray(np.uint8(img.clip(0, 255)))
 
 
-class ColorJitter (StillTransform):
+class ColorJitter(StillTransform):
     """Randomly change the brightness, contrast and saturation of an image.
     Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
 
@@ -424,6 +484,7 @@ class ColorJitter (StillTransform):
     hue(float): How much to jitter hue. hue_factor is chosen uniformly from
     [-hue, hue]. Should be >=0 and <= 0.5.
     """
+
     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
         self.brightness = brightness
         self.contrast = contrast
@@ -432,8 +493,12 @@ class ColorJitter (StillTransform):
 
     def __repr__(self):
         return "ColorJitter(%g,%g,%g,%g)" % (
-            self.brightness, self.contrast, self.saturation, self.hue)
-    
+            self.brightness,
+            self.contrast,
+            self.saturation,
+            self.hue,
+        )
+
     @staticmethod
     def get_params(brightness, contrast, saturation, hue):
         """Get a randomized transform to be applied on image.
@@ -444,16 +509,26 @@ class ColorJitter (StillTransform):
         """
         transforms = []
         if brightness > 0:
-            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
-            transforms.append(tvf.Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
+            brightness_factor = np.random.uniform(
+                max(0, 1 - brightness), 1 + brightness
+            )
+            transforms.append(
+                tvf.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))
+            )
 
         if contrast > 0:
             contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
-            transforms.append(tvf.Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
+            transforms.append(
+                tvf.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))
+            )
 
         if saturation > 0:
-            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
-            transforms.append(tvf.Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
+            saturation_factor = np.random.uniform(
+                max(0, 1 - saturation), 1 + saturation
+            )
+            transforms.append(
+                tvf.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))
+            )
 
         if hue > 0:
             hue_factor = np.random.uniform(-hue, hue)
@@ -467,47 +542,52 @@ class ColorJitter (StillTransform):
         return transform
 
     def _transform(self, img):
-        transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
+        transform = self.get_params(
+            self.brightness, self.contrast, self.saturation, self.hue
+        )
         return transform(img)
 
 
-
-if __name__ == '__main__':
+if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser("Script to try out and visualize transformations")
-    parser.add_argument('--img', type=str, default='imgs/test.png', help='input image')
-    parser.add_argument('--trfs', type=str, required=True, help='list of transformations')
-    parser.add_argument('--layout', type=int, nargs=2, default=(3,3), help='nb of rows,cols')
+    parser.add_argument("--img", type=str, default="imgs/test.png", help="input image")
+    parser.add_argument(
+        "--trfs", type=str, required=True, help="list of transformations"
+    )
+    parser.add_argument(
+        "--layout", type=int, nargs=2, default=(3, 3), help="nb of rows,cols"
+    )
     args = parser.parse_args()
-    
+
     import os
-    args.img = args.img.replace('$HERE',os.path.dirname(__file__))
+
+    args.img = args.img.replace("$HERE", os.path.dirname(__file__))
     img = Image.open(args.img)
     img = dict(img=img)
-    
+
     trfs = instanciate_transformation(args.trfs)
-    
+
     from matplotlib import pyplot as pl
+
     pl.ion()
-    pl.subplots_adjust(0,0,1,1)
-    
-    nr,nc = args.layout
-    
+    pl.subplots_adjust(0, 0, 1, 1)
+
+    nr, nc = args.layout
+
     while True:
         for j in range(nr):
             for i in range(nc):
-                pl.subplot(nr,nc,i+j*nc+1)
-                if i==j==0:
+                pl.subplot(nr, nc, i + j * nc + 1)
+                if i == j == 0:
                     img2 = img
                 else:
                     img2 = trfs(img.copy())
                 if isinstance(img2, dict):
-                    img2 = img2['img']
+                    img2 = img2["img"]
                 pl.imshow(img2)
                 pl.xlabel("%d x %d" % img2.size)
                 pl.xticks(())
                 pl.yticks(())
         pdb.set_trace()
-    
-
-
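
instanciate_transformation turns the comma-separated string syntax from the module docstring into a tvf.Compose pipeline. A hedged usage sketch (run from the r2d2 root; the transform chain and image path are illustrative only):

    from PIL import Image
    from tools.transforms import instanciate_transformation

    trfs = instanciate_transformation(
        "RandomScale(256, 512, can_upscale=True), ColorJitter(brightness=0.3), RandomCrop(224)"
    )
    sample = dict(img=Image.open("imgs/test.png"))  # illustrative path
    out = trfs(sample)
    # out["img"] is the transformed PIL image; out["persp"] accumulates the
    # 8-tuple homography of the geometric transforms applied so far
    print(out["imsize"], out["persp"])
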
diff --git a/third_party/r2d2/tools/transforms_tools.py b/third_party/r2d2/tools/transforms_tools.py
index 294c22228a88f70480af52f79a77d73f9e5b3e1a..77eb1da2306116d789cdcf6b957a6c144a746a4f 100644
--- a/third_party/r2d2/tools/transforms_tools.py
+++ b/third_party/r2d2/tools/transforms_tools.py
@@ -8,31 +8,31 @@ from PIL import Image, ImageOps, ImageEnhance
 
 
 class DummyImg:
-    ''' This class is a dummy image only defined by its size.
-    '''
+    """This class is a dummy image only defined by its size."""
+
     def __init__(self, size):
         self.size = size
-        
+
     def resize(self, size, *args, **kwargs):
         return DummyImg(size)
-        
+
     def expand(self, border):
         w, h = self.size
         if isinstance(border, int):
-            size = (w+2*border, h+2*border)
+            size = (w + 2 * border, h + 2 * border)
         else:
-            l,t,r,b = border
-            size = (w+l+r, h+t+b)
+            l, t, r, b = border
+            size = (w + l + r, h + t + b)
         return DummyImg(size)
 
     def crop(self, border):
         w, h = self.size
-        l,t,r,b = border
+        l, t, r, b = border
         assert 0 <= l <= r <= w
         assert 0 <= t <= b <= h
-        size = (r-l, b-t)
+        size = (r - l, b - t)
         return DummyImg(size)
-    
+
     def rotate(self, angle):
         raise NotImplementedError
 
@@ -40,89 +40,85 @@ class DummyImg:
         return DummyImg(size)
 
 
-def grab_img( img_and_label ):
-    ''' Called to extract the image from an img_and_label input
+def grab_img(img_and_label):
+    """Called to extract the image from an img_and_label input
     (a dictionary). Also compatible with old-style PIL images.
-    '''
+    """
     if isinstance(img_and_label, dict):
         # if input is a dictionary, then
         # it must contains the img or its size.
         try:
-            return img_and_label['img']
+            return img_and_label["img"]
         except KeyError:
-            return DummyImg(img_and_label['imsize'])
-            
+            return DummyImg(img_and_label["imsize"])
+
     else:
         # or it must be the img directly
         return img_and_label
 
 
 def update_img_and_labels(img_and_label, img, persp=None):
-    ''' Called to update the img_and_label
-    '''
+    """Called to update the img_and_label"""
     if isinstance(img_and_label, dict):
-        img_and_label['img'] = img
-        img_and_label['imsize'] = img.size
+        img_and_label["img"] = img
+        img_and_label["imsize"] = img.size
 
         if persp:
-            if 'persp' not in img_and_label:
-                img_and_label['persp'] = (1,0,0,0,1,0,0,0)
-            img_and_label['persp'] = persp_mul(persp, img_and_label['persp'])
-        
+            if "persp" not in img_and_label:
+                img_and_label["persp"] = (1, 0, 0, 0, 1, 0, 0, 0)
+            img_and_label["persp"] = persp_mul(persp, img_and_label["persp"])
+
         return img_and_label
-        
+
     else:
         # or it must be the img directly
         return img
 
 
 def rand_log_uniform(a, b):
-    return np.exp(np.random.uniform(np.log(a),np.log(b)))
+    return np.exp(np.random.uniform(np.log(a), np.log(b)))
 
 
 def translate(tx, ty):
-    return (1,0,tx,
-            0,1,ty,
-            0,0)
+    return (1, 0, tx, 0, 1, ty, 0, 0)
+
 
 def rotate(angle):
-    return (np.cos(angle),-np.sin(angle), 0,
-            np.sin(angle), np.cos(angle), 0,
-            0, 0)
+    return (np.cos(angle), -np.sin(angle), 0, np.sin(angle), np.cos(angle), 0, 0, 0)
 
 
 def persp_mul(mat, mat2):
-    ''' homography (perspective) multiplication.
+    """homography (perspective) multiplication.
     mat: 8-tuple (homography transform)
     mat2: 8-tuple (homography transform) or 2-tuple (point)
-    '''
+    """
     assert isinstance(mat, tuple)
     assert isinstance(mat2, tuple)
 
-    mat = np.float32(mat+(1,)).reshape(3,3)
-    mat2 = np.array(mat2+(1,)).reshape(3,3)
+    mat = np.float32(mat + (1,)).reshape(3, 3)
+    mat2 = np.array(mat2 + (1,)).reshape(3, 3)
     res = np.dot(mat, mat2)
-    return tuple((res/res[2,2]).ravel()[:8])
+    return tuple((res / res[2, 2]).ravel()[:8])
 
 
 def persp_apply(mat, pts):
-    ''' homography (perspective) transformation.
+    """homography (perspective) transformation.
     mat: 8-tuple (homography transform)
     pts: numpy array
-    '''
+    """
     assert isinstance(mat, tuple)
     assert isinstance(pts, np.ndarray)
     assert pts.shape[-1] == 2
-    mat = np.float32(mat+(1,)).reshape(3,3)
+    mat = np.float32(mat + (1,)).reshape(3, 3)
 
     if pts.ndim == 1:
-        pt = np.dot(pts, mat[:,:2].T).ravel() + mat[:,2]
-        pt /= pt[2] # homogeneous coordinates
+        pt = np.dot(pts, mat[:, :2].T).ravel() + mat[:, 2]
+        pt /= pt[2]  # homogeneous coordinates
         return tuple(pt[:2])
     else:
-        pt = np.dot(pts, mat[:,:2].T) + mat[:,2]
-        pt[:,:2] /= pt[:,2:3] # homogeneous coordinates
-        return pt[:,:2]
+        pt = np.dot(pts, mat[:, :2].T) + mat[:, 2]
+        pt[:, :2] /= pt[:, 2:3]  # homogeneous coordinates
+        return pt[:, :2]
 
 
 def is_pil_image(img):
@@ -141,7 +137,7 @@ def adjust_brightness(img, brightness_factor):
     Copied from https://github.com/pytorch in torchvision/transforms/functional.py
     """
     if not is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
 
     enhancer = ImageEnhance.Brightness(img)
     img = enhancer.enhance(brightness_factor)
@@ -160,7 +156,7 @@ def adjust_contrast(img, contrast_factor):
     Copied from https://github.com/pytorch in torchvision/transforms/functional.py
     """
     if not is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
 
     enhancer = ImageEnhance.Contrast(img)
     img = enhancer.enhance(contrast_factor)
@@ -179,7 +175,7 @@ def adjust_saturation(img, saturation_factor):
     Copied from https://github.com/pytorch in torchvision/transforms/functional.py
     """
     if not is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
 
     enhancer = ImageEnhance.Color(img)
     img = enhancer.enhance(saturation_factor)
@@ -205,26 +201,23 @@ def adjust_hue(img, hue_factor):
     PIL Image: Hue adjusted image.
     Copied from https://github.com/pytorch in torchvision/transforms/functional.py
     """
-    if not(-0.5 <= hue_factor <= 0.5):
-        raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor))
+    if not (-0.5 <= hue_factor <= 0.5):
+        raise ValueError("hue_factor is not in [-0.5, 0.5].".format(hue_factor))
 
     if not is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
 
     input_mode = img.mode
-    if input_mode in {'L', '1', 'I', 'F'}:
+    if input_mode in {"L", "1", "I", "F"}:
         return img
 
-    h, s, v = img.convert('HSV').split()
+    h, s, v = img.convert("HSV").split()
 
     np_h = np.array(h, dtype=np.uint8)
     # uint8 addition take cares of rotation across boundaries
-    with np.errstate(over='ignore'):
+    with np.errstate(over="ignore"):
         np_h += np.uint8(hue_factor * 255)
-        h = Image.fromarray(np_h, 'L')
+        h = Image.fromarray(np_h, "L")
 
-    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
+    img = Image.merge("HSV", (h, s, v)).convert(input_mode)
     return img
-
-
-
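
The 8-tuple homography helpers above compose and apply perspective transforms with an implicit last element of 1. A small sketch composing a rotation and a translation, assuming it runs from the r2d2 root:

    import numpy as np
    from tools.transforms_tools import translate, rotate, persp_mul, persp_apply

    # persp_mul(A, B) applied to a point p gives A(B(p)): here rotate first, then translate
    h = persp_mul(translate(10, 5), rotate(np.pi / 2))
    pts = np.float32([[0, 0], [1, 0]])
    print(persp_apply(h, pts))  # approximately [[10, 5], [10, 6]]
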
diff --git a/third_party/r2d2/tools/viz.py b/third_party/r2d2/tools/viz.py
index c86103f3aeb468fca8b0ac9a412f22b85239361b..4cf4b90a670ee448d9d6d1ba4137abae32def005 100644
--- a/third_party/r2d2/tools/viz.py
+++ b/third_party/r2d2/tools/viz.py
@@ -8,16 +8,16 @@ import matplotlib.pyplot as pl
 
 
 def make_colorwheel():
-    '''
+    """
     Generates a color wheel for optical flow visualization as presented in:
         Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
         URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
     According to the C++ source code of Daniel Scharstein
     According to the Matlab source code of Deqing Sun
-    
+
     Copied from https://github.com/tomrunia/OpticalFlow_Visualization/blob/master/flow_vis.py
     Copyright (c) 2018 Tom Runia
-    '''
+    """
 
     RY = 15
     YG = 6
@@ -32,32 +32,32 @@ def make_colorwheel():
 
     # RY
     colorwheel[0:RY, 0] = 255
-    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
-    col = col+RY
+    colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
+    col = col + RY
     # YG
-    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
-    colorwheel[col:col+YG, 1] = 255
-    col = col+YG
+    colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
+    colorwheel[col : col + YG, 1] = 255
+    col = col + YG
     # GC
-    colorwheel[col:col+GC, 1] = 255
-    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
-    col = col+GC
+    colorwheel[col : col + GC, 1] = 255
+    colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
+    col = col + GC
     # CB
-    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
-    colorwheel[col:col+CB, 2] = 255
-    col = col+CB
+    colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
+    colorwheel[col : col + CB, 2] = 255
+    col = col + CB
     # BM
-    colorwheel[col:col+BM, 2] = 255
-    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
-    col = col+BM
+    colorwheel[col : col + BM, 2] = 255
+    colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
+    col = col + BM
     # MR
-    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
-    colorwheel[col:col+MR, 0] = 255
+    colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
+    colorwheel[col : col + MR, 0] = 255
     return colorwheel
 
 
 def flow_compute_color(u, v, convert_to_bgr=False):
-    '''
+    """
     Applies the flow color wheel to (possibly clipped) flow components u and v.
     According to the C++ source code of Daniel Scharstein
     According to the Matlab source code of Deqing Sun
@@ -65,10 +65,10 @@ def flow_compute_color(u, v, convert_to_bgr=False):
     :param v: np.ndarray, input vertical flow
     :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
     :return:
-    
+
     Copied from https://github.com/tomrunia/OpticalFlow_Visualization/blob/master/flow_vis.py
     Copyright (c) 2018 Tom Runia
-    '''
+    """
 
     flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
 
@@ -76,9 +76,9 @@ def flow_compute_color(u, v, convert_to_bgr=False):
     ncols = colorwheel.shape[0]
 
     rad = np.sqrt(np.square(u) + np.square(v))
-    a = np.arctan2(-v, -u)/np.pi
+    a = np.arctan2(-v, -u) / np.pi
 
-    fk = (a+1) / 2*(ncols-1)
+    fk = (a + 1) / 2 * (ncols - 1)
     k0 = np.floor(fk).astype(np.int32)
     k1 = k0 + 1
     k1[k1 == ncols] = 0
@@ -86,43 +86,43 @@ def flow_compute_color(u, v, convert_to_bgr=False):
 
     for i in range(colorwheel.shape[1]):
 
-        tmp = colorwheel[:,i]
+        tmp = colorwheel[:, i]
         col0 = tmp[k0] / 255.0
         col1 = tmp[k1] / 255.0
-        col = (1-f)*col0 + f*col1
+        col = (1 - f) * col0 + f * col1
 
-        idx = (rad <= 1)
-        col[idx]  = 1 - rad[idx] * (1-col[idx])
-        col[~idx] = col[~idx] * 0.75   # out of range?
+        idx = rad <= 1
+        col[idx] = 1 - rad[idx] * (1 - col[idx])
+        col[~idx] = col[~idx] * 0.75  # out of range?
 
         # Note the 2-i => BGR instead of RGB
-        ch_idx = 2-i if convert_to_bgr else i
-        flow_image[:,:,ch_idx] = np.floor(255 * col)
+        ch_idx = 2 - i if convert_to_bgr else i
+        flow_image[:, :, ch_idx] = np.floor(255 * col)
 
     return flow_image
 
 
 def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
-    '''
+    """
     Expects a two dimensional flow image of shape [H,W,2]
     According to the C++ source code of Daniel Scharstein
     According to the Matlab source code of Deqing Sun
     :param flow_uv: np.ndarray of shape [H,W,2]
     :param clip_flow: float, maximum clipping value for flow
     :return:
-    
+
     Copied from https://github.com/tomrunia/OpticalFlow_Visualization/blob/master/flow_vis.py
     Copyright (c) 2018 Tom Runia
-    '''
+    """
 
-    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
-    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+    assert flow_uv.ndim == 3, "input flow must have three dimensions"
+    assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]"
 
     if clip_flow is not None:
         flow_uv = np.clip(flow_uv, 0, clip_flow)
 
-    u = flow_uv[:,:,0]
-    v = flow_uv[:,:,1]
+    u = flow_uv[:, :, 0]
+    v = flow_uv[:, :, 1]
 
     rad = np.sqrt(np.square(u) + np.square(v))
     rad_max = np.max(rad)
@@ -134,58 +134,59 @@ def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
     return flow_compute_color(u, v, convert_to_bgr)
 
 
-
-def show_flow( img0, img1, flow, mask=None ):
+def show_flow(img0, img1, flow, mask=None):
     img0 = np.asarray(img0)
     img1 = np.asarray(img1)
-    if mask is None: mask = 1
+    if mask is None:
+        mask = 1
     mask = np.asarray(mask)
-    if mask.ndim == 2: mask = mask[:,:,None]
+    if mask.ndim == 2:
+        mask = mask[:, :, None]
     assert flow.ndim == 3
     assert flow.shape[:2] == img0.shape[:2] and flow.shape[2] == 2
-    
+
     def noticks():
-      pl.xticks([])
-      pl.yticks([])
+        pl.xticks([])
+        pl.yticks([])
+
     fig = pl.figure("showing correspondences")
     ax1 = pl.subplot(221)
     ax1.numaxis = 0
-    pl.imshow(img0*mask)
+    pl.imshow(img0 * mask)
     noticks()
     ax2 = pl.subplot(222)
     ax2.numaxis = 1
     pl.imshow(img1)
     noticks()
-    
+
     ax = pl.subplot(212)
     ax.numaxis = 0
     flow_img = flow_to_color(np.where(np.isnan(flow), 0, flow))
     pl.imshow(flow_img * mask)
     noticks()
-    
+
     pl.subplots_adjust(0.01, 0.01, 0.99, 0.99, wspace=0.02, hspace=0.02)
-    
+
     def motion_notify_callback(event):
-      if event.inaxes is None: return
-      x,y = event.xdata, event.ydata
-      ax1.lines = []
-      ax2.lines = []
-      try:
-        x,y = int(x+0.5), int(y+0.5)
-        ax1.plot(x,y,'+',ms=10,mew=2,color='blue',scalex=False,scaley=False)
-        x,y = flow[y,x] + (x,y)
-        ax2.plot(x,y,'+',ms=10,mew=2,color='red',scalex=False,scaley=False)
-        # we redraw only the concerned axes
-        renderer = fig.canvas.get_renderer()
-        ax1.draw(renderer)
-        ax2.draw(renderer)
-        fig.canvas.blit(ax1.bbox)
-        fig.canvas.blit(ax2.bbox)
-      except IndexError:
-        return
-  
-    cid_move = fig.canvas.mpl_connect('motion_notify_event',motion_notify_callback)
+        if event.inaxes is None:
+            return
+        x, y = event.xdata, event.ydata
+        ax1.lines = []
+        ax2.lines = []
+        try:
+            x, y = int(x + 0.5), int(y + 0.5)
+            ax1.plot(x, y, "+", ms=10, mew=2, color="blue", scalex=False, scaley=False)
+            x, y = flow[y, x] + (x, y)
+            ax2.plot(x, y, "+", ms=10, mew=2, color="red", scalex=False, scaley=False)
+            # we redraw only the concerned axes
+            renderer = fig.canvas.get_renderer()
+            ax1.draw(renderer)
+            ax2.draw(renderer)
+            fig.canvas.blit(ax1.bbox)
+            fig.canvas.blit(ax2.bbox)
+        except IndexError:
+            return
+
+    cid_move = fig.canvas.mpl_connect("motion_notify_event", motion_notify_callback)
     print("Move your mouse over the images to show matches (ctrl-C to quit)")
     pl.show()
-
-
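
flow_to_color expects a dense flow field of shape (H, W, 2) and returns an RGB visualization of flow directions and magnitudes. A hedged sketch using a synthetic radial flow field:

    import numpy as np
    from tools.viz import flow_to_color

    H, W = 64, 64
    ys, xs = np.mgrid[0:H, 0:W].astype(np.float32)
    flow = np.stack([xs - W / 2, ys - H / 2], axis=-1)  # (H, W, 2): u = dx, v = dy
    rgb = flow_to_color(flow)                           # uint8 array of shape (H, W, 3)
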
diff --git a/third_party/r2d2/train.py b/third_party/r2d2/train.py
index 10d23d9e40ebe8cb10c4d548b7fcb5c1c0fd7739..232d61d0eb830454b4f785cfb82536b6cfba7071 100644
--- a/third_party/r2d2/train.py
+++ b/third_party/r2d2/train.py
@@ -35,12 +35,12 @@ db_aachen_style_transfer = """TransformedPairs(
 db_aachen_flow = "aachen_flow_pairs"
 
 data_sources = dict(
-    D = toy_db_debug,
-    W = db_web_images,
-    A = db_aachen_images,
-    F = db_aachen_flow,
-    S = db_aachen_style_transfer,
-    )
+    D=toy_db_debug,
+    W=db_web_images,
+    A=db_aachen_images,
+    F=db_aachen_flow,
+    S=db_aachen_style_transfer,
+)
 
 default_dataloader = """PairLoader(CatPairDataset(`data`),
     scale   = 'RandomScale(256,1024,can_upscale=True)',
@@ -57,75 +57,101 @@ default_loss = """MultiLoss(
 
 
 class MyTrainer(trainer.Trainer):
-    """ This class implements the network training.
-        Below is the function I need to overload to explain how to do the backprop.
+    """This class implements the network training.
+    Overload forward_backward below to define how the loss and backprop are computed.
     """
+
     def forward_backward(self, inputs):
-        output = self.net(imgs=[inputs.pop('img1'),inputs.pop('img2')])
+        output = self.net(imgs=[inputs.pop("img1"), inputs.pop("img2")])
         allvars = dict(inputs, **output)
         loss, details = self.loss_func(**allvars)
-        if torch.is_grad_enabled(): loss.backward()
+        if torch.is_grad_enabled():
+            loss.backward()
         return loss, details
 
 
-
-if __name__ == '__main__':
+if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser("Train R2D2")
 
     parser.add_argument("--data-loader", type=str, default=default_dataloader)
-    parser.add_argument("--train-data", type=str, default=list('WASF'), nargs='+', 
-        choices = set(data_sources.keys()))
-    parser.add_argument("--net", type=str, default=default_net, help='network architecture')
+    parser.add_argument(
+        "--train-data",
+        type=str,
+        default=list("WASF"),
+        nargs="+",
+        choices=set(data_sources.keys()),
+    )
+    parser.add_argument(
+        "--net", type=str, default=default_net, help="network architecture"
+    )
+
+    parser.add_argument(
+        "--pretrained", type=str, default="", help="pretrained model path"
+    )
+    parser.add_argument(
+        "--save-path", type=str, required=True, help="model save_path path"
+    )
 
-    parser.add_argument("--pretrained", type=str, default="", help='pretrained model path')
-    parser.add_argument("--save-path", type=str, required=True, help='model save_path path')
-    
     parser.add_argument("--loss", type=str, default=default_loss, help="loss function")
-    parser.add_argument("--sampler", type=str, default=default_sampler, help="AP sampler")
-    parser.add_argument("--N", type=int, default=16, help="patch size for repeatability")
+    parser.add_argument(
+        "--sampler", type=str, default=default_sampler, help="AP sampler"
+    )
+    parser.add_argument(
+        "--N", type=int, default=16, help="patch size for repeatability"
+    )
 
-    parser.add_argument("--epochs", type=int, default=25, help='number of training epochs')
+    parser.add_argument(
+        "--epochs", type=int, default=25, help="number of training epochs"
+    )
     parser.add_argument("--batch-size", "--bs", type=int, default=8, help="batch size")
     parser.add_argument("--learning-rate", "--lr", type=str, default=1e-4)
     parser.add_argument("--weight-decay", "--wd", type=float, default=5e-4)
-    
-    parser.add_argument("--threads", type=int, default=8, help='number of worker threads')
-    parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='-1 for CPU')
-    
+
+    parser.add_argument(
+        "--threads", type=int, default=8, help="number of worker threads"
+    )
+    parser.add_argument("--gpu", type=int, nargs="+", default=[0], help="-1 for CPU")
+
     args = parser.parse_args()
-    
+
     iscuda = common.torch_set_gpu(args.gpu)
     common.mkdir_for(args.save_path)
 
     # Create data loader
     from datasets import *
+
     db = [data_sources[key] for key in args.train_data]
-    db = eval(args.data_loader.replace('`data`',','.join(db)).replace('\n',''))
+    db = eval(args.data_loader.replace("`data`", ",".join(db)).replace("\n", ""))
     print("Training image database =", db)
     loader = threaded_loader(db, iscuda, args.threads, args.batch_size, shuffle=True)
 
     # create network
-    print("\n>> Creating net = " + args.net) 
+    print("\n>> Creating net = " + args.net)
     net = eval(args.net)
     print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )")
 
     # initialization
     if args.pretrained:
-        checkpoint = torch.load(args.pretrained, lambda a,b:a)
-        net.load_pretrained(checkpoint['state_dict'])
-        
+        checkpoint = torch.load(args.pretrained, lambda a, b: a)
+        net.load_pretrained(checkpoint["state_dict"])
+
     # create losses
-    loss = args.loss.replace('`sampler`',args.sampler).replace('`N`',str(args.N))
+    loss = args.loss.replace("`sampler`", args.sampler).replace("`N`", str(args.N))
     print("\n>> Creating loss = " + loss)
-    loss = eval(loss.replace('\n',''))
-    
+    loss = eval(loss.replace("\n", ""))
+
     # create optimizer
-    optimizer = optim.Adam( [p for p in net.parameters() if p.requires_grad], 
-                            lr=args.learning_rate, weight_decay=args.weight_decay)
+    optimizer = optim.Adam(
+        [p for p in net.parameters() if p.requires_grad],
+        lr=args.learning_rate,
+        weight_decay=args.weight_decay,
+    )
 
     train = MyTrainer(net, loader, loss, optimizer)
-    if iscuda: train = train.cuda()
+    if iscuda:
+        train = train.cuda()
 
     # Training loop #
     for epoch in range(args.epochs):
@@ -133,6 +159,4 @@ if __name__ == '__main__':
         train()
 
     print(f"\n>> Saving model to {args.save_path}")
-    torch.save({'net': args.net, 'state_dict': net.state_dict()}, args.save_path)
-
-
+    torch.save({"net": args.net, "state_dict": net.state_dict()}, args.save_path)
diff --git a/third_party/r2d2/viz_heatmaps.py b/third_party/r2d2/viz_heatmaps.py
index 42705e70ecea82696a0d784b274f7f387fdf6595..e5cb8b3bb502ce4d9e5169c55be3f479f8f8fce4 100644
--- a/third_party/r2d2/viz_heatmaps.py
+++ b/third_party/r2d2/viz_heatmaps.py
@@ -7,116 +7,134 @@ import numpy as np
 import torch
 
 from PIL import Image
-from matplotlib import pyplot as pl; pl.ion()
+from matplotlib import pyplot as pl
+
+pl.ion()
 from scipy.ndimage import uniform_filter
+
 smooth = lambda arr: uniform_filter(arr, 3)
 
+
 def transparent(img, alpha, cmap, **kw):
     from matplotlib.colors import Normalize
-    colored_img = cmap(Normalize(clip=True,**kw)(img))
-    colored_img[:,:,-1] = alpha
+
+    colored_img = cmap(Normalize(clip=True, **kw)(img))
+    colored_img[:, :, -1] = alpha
     return colored_img
 
+
 from tools import common
 from tools.dataloader import norm_RGB
 from nets.patchnet import *
 from extract import NonMaxSuppression
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser("Visualize the patch detector and descriptor")
-    
+
     parser.add_argument("--img", type=str, default="imgs/brooklyn.png")
     parser.add_argument("--resize", type=int, default=512)
     parser.add_argument("--out", type=str, default="viz.png")
 
-    parser.add_argument("--checkpoint", type=str, required=True, help='network path')
-    parser.add_argument("--net", type=str, default="", help='network command')
+    parser.add_argument("--checkpoint", type=str, required=True, help="network path")
+    parser.add_argument("--net", type=str, default="", help="network command")
 
     parser.add_argument("--max-kpts", type=int, default=200)
     parser.add_argument("--reliability-thr", type=float, default=0.8)
     parser.add_argument("--repeatability-thr", type=float, default=0.7)
-    parser.add_argument("--border", type=int, default=20,help='rm keypoints close to border')
+    parser.add_argument(
+        "--border", type=int, default=20, help="rm keypoints close to border"
+    )
+
+    parser.add_argument("--gpu", type=int, nargs="+", required=True, help="-1 for CPU")
+    parser.add_argument("--dbg", type=str, nargs="+", default=(), help="debug options")
 
-    parser.add_argument("--gpu", type=int, nargs='+', required=True, help='-1 for CPU')
-    parser.add_argument("--dbg", type=str, nargs='+', default=(), help='debug options')
-    
     args = parser.parse_args()
     args.dbg = set(args.dbg)
-    
+
     iscuda = common.torch_set_gpu(args.gpu)
-    device = torch.device('cuda' if iscuda else 'cpu')
+    device = torch.device("cuda" if iscuda else "cpu")
 
     # create network
-    checkpoint = torch.load(args.checkpoint, lambda a,b:a)
-    args.net = args.net or checkpoint['net']
-    print("\n>> Creating net = " + args.net) 
+    checkpoint = torch.load(args.checkpoint, lambda a, b: a)
+    args.net = args.net or checkpoint["net"]
+    print("\n>> Creating net = " + args.net)
     net = eval(args.net)
-    net.load_state_dict({k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()})
-    if iscuda: net = net.cuda()
+    net.load_state_dict(
+        {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
+    )
+    if iscuda:
+        net = net.cuda()
     print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )")
 
-    img = Image.open(args.img).convert('RGB')
-    if args.resize: img.thumbnail((args.resize,args.resize))
+    img = Image.open(args.img).convert("RGB")
+    if args.resize:
+        img.thumbnail((args.resize, args.resize))
     img = np.asarray(img)
-        
+
     detector = NonMaxSuppression(
-        rel_thr = args.reliability_thr, 
-        rep_thr = args.repeatability_thr)
+        rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
+    )
 
     with torch.no_grad():
         print(">> computing features...")
         res = net(imgs=[norm_RGB(img).unsqueeze(0).to(device)])
-        rela = res.get('reliability')
-        repe = res.get('repeatability')
-        kpts = detector(**res).T[:,[1,0]]
-        kpts = kpts[repe[0][0,0][kpts[:,1],kpts[:,0]].argsort()[-args.max_kpts:]]
+        rela = res.get("reliability")
+        repe = res.get("repeatability")
+        kpts = detector(**res).T[:, [1, 0]]
+        kpts = kpts[repe[0][0, 0][kpts[:, 1], kpts[:, 0]].argsort()[-args.max_kpts :]]
 
     fig = pl.figure("viz")
     kw = dict(cmap=pl.cm.RdYlGn, vmax=1)
-    crop = (slice(args.border,-args.border or 1),)*2
-    
-    if 'reliability' in args.dbg:
-    
+    crop = (slice(args.border, -args.border or 1),) * 2
+
+    if "reliability" in args.dbg:
+
         ax1 = pl.subplot(131)
         pl.imshow(img[crop], cmap=pl.cm.gray)
-        pl.xticks(()); pl.yticks(())
+        pl.xticks(())
+        pl.yticks(())
 
         pl.subplot(132)
         pl.imshow(img[crop], cmap=pl.cm.gray, alpha=0)
-        pl.xticks(()); pl.yticks(())
+        pl.xticks(())
+        pl.yticks(())
 
-        x,y = kpts[:,0:2].cpu().numpy().T - args.border
-        pl.plot(x,y,'+',c=(0,1,0),ms=10, scalex=0, scaley=0)
+        x, y = kpts[:, 0:2].cpu().numpy().T - args.border
+        pl.plot(x, y, "+", c=(0, 1, 0), ms=10, scalex=0, scaley=0)
 
         ax1 = pl.subplot(133)
-        rela = rela[0][0,0].cpu().numpy()
+        rela = rela[0][0, 0].cpu().numpy()
         pl.imshow(rela[crop], cmap=pl.cm.RdYlGn, vmax=1, vmin=0.9)
-        pl.xticks(()); pl.yticks(())
+        pl.xticks(())
+        pl.yticks(())
 
     else:
         ax1 = pl.subplot(131)
         pl.imshow(img[crop], cmap=pl.cm.gray)
-        pl.xticks(()); pl.yticks(())
+        pl.xticks(())
+        pl.yticks(())
 
-        x,y = kpts[:,0:2].cpu().numpy().T - args.border
-        pl.plot(x,y,'+',c=(0,1,0),ms=10, scalex=0, scaley=0)
+        x, y = kpts[:, 0:2].cpu().numpy().T - args.border
+        pl.plot(x, y, "+", c=(0, 1, 0), ms=10, scalex=0, scaley=0)
 
         pl.subplot(132)
         pl.imshow(img[crop], cmap=pl.cm.gray)
-        pl.xticks(()); pl.yticks(())
-        c = repe[0][0,0].cpu().numpy()
+        pl.xticks(())
+        pl.yticks(())
+        c = repe[0][0, 0].cpu().numpy()
         pl.imshow(transparent(smooth(c)[crop], 0.5, vmin=0, **kw))
 
         ax1 = pl.subplot(133)
         pl.imshow(img[crop], cmap=pl.cm.gray)
-        pl.xticks(()); pl.yticks(())
-        rela = rela[0][0,0].cpu().numpy()
+        pl.xticks(())
+        pl.yticks(())
+        rela = rela[0][0, 0].cpu().numpy()
         pl.imshow(transparent(rela[crop], 0.5, vmin=0.9, **kw))
 
     pl.gcf().set_size_inches(9, 2.73)
-    pl.subplots_adjust(0.01,0.01,0.99,0.99,hspace=0.1)
+    pl.subplots_adjust(0.01, 0.01, 0.99, 0.99, hspace=0.1)
     pl.savefig(args.out)
     pdb.set_trace()
-