import argparse
import os
import random

import torch
from PIL import Image, ImageDraw
from torchvision.transforms import transforms

from dataset.base import Base as DatasetBase
from backbone.base import Base as BackboneBase
from bbox import BBox
from model import Model
from roi.pooler import Pooler
from config.eval_config import EvalConfig as Config


def _infer(path_to_input_image: str, path_to_output_image: str, path_to_checkpoint: str,
           dataset_name: str, backbone_name: str, prob_thresh: float) -> None:
    """Run the detector on one image and save a copy annotated with its detections.

    Builds the model from the named dataset/backbone and the current ``Config``
    values, loads the checkpoint, runs a forward pass on the CUDA device, keeps
    the detections whose probability exceeds ``prob_thresh`` and draws each one
    (box outline plus ``"<category> <prob>"`` label) onto the input image.

    Args:
        path_to_input_image: Image file to run detection on.
        path_to_output_image: Where the annotated image is written.
        path_to_checkpoint: Checkpoint passed to ``Model.load``.
        dataset_name: Name resolved through ``DatasetBase.from_name``.
        backbone_name: Name resolved through ``BackboneBase.from_name``.
        prob_thresh: Detections with probability not strictly greater than
            this value are discarded.
    """
    dataset_class = DatasetBase.from_name(dataset_name)
    backbone = BackboneBase.from_name(backbone_name)(pretrained=False)
    model = Model(backbone, dataset_class.num_classes(),
                  pooler_mode=Config.POOLER_MODE,
                  anchor_ratios=Config.ANCHOR_RATIOS,
                  anchor_sizes=Config.ANCHOR_SIZES,
                  rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N,
                  rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda()
    model.load(path_to_checkpoint)

    # Open the image through PIL itself. Reaching ``Image`` via
    # ``torchvision.transforms.transforms`` only works while that module happens
    # to import it internally, and fails with AttributeError when it does not.
    image = Image.open(path_to_input_image)

    with torch.no_grad():
        image_tensor, scale = dataset_class.preprocess(image, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE)

        detection_bboxes, detection_classes, detection_probs, _ = \
            model.eval().forward(image_tensor.unsqueeze(dim=0).cuda())
        # NOTE(review): presumably `scale` is the resize factor applied by
        # `preprocess`, so this maps boxes back to the original image - confirm.
        detection_bboxes /= scale

        kept_indices = detection_probs > prob_thresh
        detection_bboxes = detection_bboxes[kept_indices]
        detection_classes = detection_classes[kept_indices]
        detection_probs = detection_probs[kept_indices]

    # Loop-invariant palette; one colour is picked at random per detection.
    palette = ['red', 'green', 'blue', 'yellow', 'purple', 'white']
    draw = ImageDraw.Draw(image)

    for bbox, cls, prob in zip(detection_bboxes.tolist(), detection_classes.tolist(), detection_probs.tolist()):
        color = random.choice(palette)
        # Each box row is read as [left, top, right, bottom].
        bbox = BBox(left=bbox[0], top=bbox[1], right=bbox[2], bottom=bbox[3])
        category = dataset_class.LABEL_TO_CATEGORY_DICT[cls]

        draw.rectangle(((bbox.left, bbox.top), (bbox.right, bbox.bottom)), outline=color)
        draw.text((bbox.left, bbox.top), text=f'{category:s} {prob:.3f}', fill=color)

    image.save(path_to_output_image)
    print(f'Output image is saved to {path_to_output_image}')

if __name__ == '__main__':
    def main():
        """Command-line entry point: parse options, configure ``Config``, run ``_infer``."""
        cli = argparse.ArgumentParser()

        # Model selection and detection threshold.
        cli.add_argument('-s', '--dataset', type=str, choices=DatasetBase.OPTIONS, required=True,
                         help='name of dataset')
        cli.add_argument('-b', '--backbone', type=str, choices=BackboneBase.OPTIONS, required=True,
                         help='name of backbone model')
        cli.add_argument('-c', '--checkpoint', type=str, required=True, help='path to checkpoint')
        cli.add_argument('-p', '--probability_threshold', type=float, default=0.6,
                         help='threshold of detection probability')

        # Optional overrides forwarded to Config.setup; argparse leaves an
        # omitted option as None.
        cli.add_argument('--image_min_side', type=float, help='default: {:g}'.format(Config.IMAGE_MIN_SIDE))
        cli.add_argument('--image_max_side', type=float, help='default: {:g}'.format(Config.IMAGE_MAX_SIDE))
        cli.add_argument('--anchor_ratios', type=str, help='default: "{!s}"'.format(Config.ANCHOR_RATIOS))
        cli.add_argument('--anchor_sizes', type=str, help='default: "{!s}"'.format(Config.ANCHOR_SIZES))
        cli.add_argument('--pooler_mode', type=str, choices=Pooler.OPTIONS,
                         help='default: {.value:s}'.format(Config.POOLER_MODE))
        cli.add_argument('--rpn_pre_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_PRE_NMS_TOP_N))
        cli.add_argument('--rpn_post_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_POST_NMS_TOP_N))

        # Positional input and output image paths.
        cli.add_argument('input', type=str, help='path to input image')
        cli.add_argument('output', type=str, help='path to output result image')

        args = cli.parse_args()

        # Make sure the directory that will hold the output image exists.
        output_dir = os.path.dirname(args.output)
        os.makedirs(os.path.join(os.path.curdir, output_dir), exist_ok=True)

        # Forward every Config override by name, exactly as parsed.
        override_names = ('image_min_side', 'image_max_side', 'anchor_ratios', 'anchor_sizes',
                          'pooler_mode', 'rpn_pre_nms_top_n', 'rpn_post_nms_top_n')
        Config.setup(**{name: getattr(args, name) for name in override_names})

        print('Arguments:')
        for name, value in vars(args).items():
            print(f'\t{name} = {value}')
        print(Config.describe())

        _infer(args.input, args.output, args.checkpoint,
               args.dataset, args.backbone, args.probability_threshold)

    main()