# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> extract_features.py
@IDE               PyCharm
@Author            [email protected]
@Date              07/02/2024 14:49
=================================================='''
import os
import os.path as osp
import argparse
import logging
import pprint
from pathlib import Path
from types import SimpleNamespace

import cv2
import h5py
import numpy as np
import torch
import torch.utils.data as Data
from tqdm import tqdm

from nets.sfd2 import ResNet4x, extract_sfd2_return
from nets.superpoint import SuperPoint, extract_sp_return

confs = {
    'superpoint-n4096': {
        'output': 'feats-superpoint-n4096',
        'model': {
            'name': 'superpoint',
            'outdim': 256,
            'use_stability': False,
            'nms_radius': 3,
            'max_keypoints': 4096,
            'conf_th': 0.005,
            'multiscale': False,
            'scales': [1.0],
            'model_fn': osp.join(os.getcwd(), "weights/superpoint_v1.pth"),
        },
        'preprocessing': {
            'grayscale': True,
            'resize_max': False,
        },
    },

    'resnet4x-20230511-210205-pho-0005': {
        'output': 'feats-resnet4x-20230511-210205-pho-0005',
        'model': {
            'outdim': 128,
            'name': 'resnet4x',
            'use_stability': False,
            'max_keypoints': 4096,
            'conf_th': 0.005,
            'multiscale': False,
            'scales': [1.0],
            'model_fn': osp.join(os.getcwd(), "weights/sfd2_20230511_210205_resnet4x.79.pth"),
        },
        'preprocessing': {
            'grayscale': False,
            'resize_max': False,
        },
        'mask': False,
    },

    'sfd2': {
        'output': 'feats-sfd2',
        'model': {
            'outdim': 128,
            'name': 'resnet4x',
            'use_stability': False,
            'max_keypoints': 4096,
            'conf_th': 0.005,
            'multiscale': False,
            'scales': [1.0],
            'model_fn': osp.join(os.getcwd(), "weights/sfd2_20230511_210205_resnet4x.79.pth"),
        },
        'preprocessing': {
            'grayscale': False,
            'resize_max': False,
        },
        'mask': False,
    },
}
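# A new extractor configuration can be registered by appending an entry with
# the same schema, e.g. (a sketch; 'superpoint-n2048' is a hypothetical name
# that reuses the SuperPoint weights above):
#
#   confs['superpoint-n2048'] = {
#       'output': 'feats-superpoint-n2048',
#       'model': {
#           'name': 'superpoint',
#           'outdim': 256,
#           'use_stability': False,
#           'nms_radius': 3,
#           'max_keypoints': 2048,
#           'conf_th': 0.005,
#           'multiscale': False,
#           'scales': [1.0],
#           'model_fn': osp.join(os.getcwd(), "weights/superpoint_v1.pth"),
#       },
#       'preprocessing': {'grayscale': True, 'resize_max': False},
#   }
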
class ImageDataset(Data.Dataset):
    default_conf = {
        'globs': ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG'],
        'grayscale': False,
        'resize_max': None,
        'resize_force': False,
    }

    def __init__(self, root, conf, image_list=None, mask_root=None):
        self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
        self.root = root
        self.paths = []

        if image_list is None:
            for g in conf.globs:
                self.paths += list(Path(root).glob('**/' + g))
            if len(self.paths) == 0:
                raise ValueError(f'Could not find any image in root: {root}.')
            self.paths = [i.relative_to(root) for i in self.paths]
        else:
            with open(image_list, "r") as f:
                for line in f.readlines():
                    self.paths.append(Path(line.strip()))
        logging.info(f'Found {len(self.paths)} images in root {root}.')

        self.mask_root = mask_root
    def __getitem__(self, idx):
        path = self.paths[idx]
        if self.conf.grayscale:
            mode = cv2.IMREAD_GRAYSCALE
        else:
            mode = cv2.IMREAD_COLOR
        image = cv2.imread(str(self.root / path), mode)
        if image is None:
            raise ValueError(f'Cannot read image {str(path)}.')
        if not self.conf.grayscale:
            image = image[:, :, ::-1]  # BGR to RGB
        image = image.astype(np.float32)
        size = image.shape[:2][::-1]
        w, h = size

        if self.conf.resize_max and (self.conf.resize_force
                                     or max(w, h) > self.conf.resize_max):
            scale = self.conf.resize_max / max(h, w)
            h_new, w_new = int(round(h * scale)), int(round(w * scale))
            image = cv2.resize(image, (w_new, h_new), interpolation=cv2.INTER_CUBIC)

        if self.conf.grayscale:
            image = image[None]
        else:
            image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
        image = image / 255.

        data = {
            'name': str(path),
            'image': image,
            'original_size': np.array(size),
        }
        if self.mask_root is not None:
            # Masks are expected to mirror the image layout under mask_root,
            # with the .jpg extension replaced by .png.
            mask_path = self.mask_root / Path(str(path).replace("jpg", "png"))
            if osp.exists(str(mask_path)):
                mask = cv2.imread(str(mask_path))
                mask = cv2.resize(mask, dsize=(image.shape[2], image.shape[1]),
                                  interpolation=cv2.INTER_NEAREST)
            else:
                mask = np.zeros(shape=(image.shape[1], image.shape[2], 3), dtype=np.uint8)
            data['mask'] = mask

        return data
    def __len__(self):
        return len(self.paths)
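# Minimal standalone use of the dataset (a sketch; 'datasets/my_scene' is a
# hypothetical image directory, not one shipped with the project):
#
#   dataset = ImageDataset(Path('datasets/my_scene'), confs['sfd2']['preprocessing'])
#   sample = dataset[0]
#   print(sample['name'], sample['image'].shape, sample['original_size'])
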
def get_model(model_name, weight_path, outdim=128, **kwargs):
    if model_name == 'superpoint':
        model = SuperPoint(config={
            'descriptor_dim': 256,
            'nms_radius': 4,
            'keypoint_threshold': 0.005,
            'max_keypoints': -1,
            'remove_borders': 4,
            'weight_path': weight_path,
        }).eval()
        extractor = extract_sp_return
    elif model_name == 'resnet4x':
        model = ResNet4x(outdim=outdim).eval()
        model.load_state_dict(torch.load(weight_path)['state_dict'], strict=True)
        extractor = extract_sfd2_return
    else:
        raise ValueError(f'Unsupported model name: {model_name}.')
    return model, extractor
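# The (model, extractor) pair can also be used outside the CLI (a sketch under
# the configs above; the .cuda() call assumes a CUDA device is available):
#
#   model, extractor = get_model('resnet4x', confs['sfd2']['model']['model_fn'], outdim=128)
#   model = model.cuda()
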
def main(conf, image_dir, export_dir, image_list=None):
    logging.info('Extracting local features with configuration:'
                 f'\n{pprint.pformat(conf)}')
    model, extractor = get_model(model_name=conf['model']['name'],
                                 weight_path=conf['model']['model_fn'],
                                 use_stability=conf['model']['use_stability'],
                                 outdim=conf['model']['outdim'])
    model = model.cuda()

    dataset = ImageDataset(image_dir,
                           conf['preprocessing'],
                           image_list=image_list,
                           mask_root=None)
    loader = Data.DataLoader(dataset, num_workers=4)

    os.makedirs(export_dir, exist_ok=True)
    feature_path = Path(export_dir, conf['output'] + '.h5')
    feature_path.parent.mkdir(exist_ok=True, parents=True)
    feature_file = h5py.File(str(feature_path), 'a')
    with tqdm(total=len(loader)) as t:
        for idx, data in enumerate(loader):
            t.update()
            pred = extractor(model, img=data["image"],
                             topK=conf["model"]["max_keypoints"],
                             mask=None,
                             conf_th=conf["model"]["conf_th"],
                             scales=conf["model"]["scales"],
                             )
            pred['descriptors'] = pred['descriptors'].transpose()
            t.set_postfix(npoints=pred['keypoints'].shape[0])

            pred['image_size'] = original_size = data['original_size'][0].numpy()
            if 'keypoints' in pred.keys():
                # Rescale keypoints from the (possibly resized) network input
                # back to the original image resolution.
                size = np.array(data['image'].shape[-2:][::-1])
                scales = (original_size / size).astype(np.float32)
                pred['keypoints'] = (pred['keypoints'] + .5) * scales[None] - .5

            grp = feature_file.create_group(data['name'][0])
            for k, v in pred.items():
                grp.create_dataset(k, data=v)
            del pred

    feature_file.close()
    logging.info('Finished exporting features.')
    return feature_path
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=Path, required=True)
    parser.add_argument('--image_list', type=str, default=None)
    parser.add_argument('--mask_dir', type=Path, default=None)
    parser.add_argument('--export_dir', type=Path, required=True)
    parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys()))
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO,
                        format='[%(asctime)s %(levelname)s] %(message)s')
    main(confs[args.conf], args.image_dir, args.export_dir, image_list=args.image_list)
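# Example invocation and a quick check of the exported file (a sketch; the
# directory names are placeholders, not paths shipped with the project):
#
#   python extract_features.py --image_dir datasets/my_scene/images \
#       --export_dir outputs/my_scene --conf sfd2
#
# This writes outputs/my_scene/feats-sfd2.h5 with one group per image name,
# holding whatever arrays the extractor returns (at least 'keypoints' and
# 'descriptors') plus 'image_size':
#
#   with h5py.File('outputs/my_scene/feats-sfd2.h5', 'r') as f:
#       f.visititems(lambda name, obj: print(name, getattr(obj, 'shape', '')))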