init
- .gitignore +60 -0
- README.md +5 -8
- app.py +52 -0
- configs/Base-RCNN-FPN-OPENDET.yaml +25 -0
- configs/Base-RCNN-FPN.yaml +44 -0
- configs/Base-RetinaNet.yaml +26 -0
- configs/faster_rcnn_R_50_FPN_3x_baseline.yaml +16 -0
- configs/faster_rcnn_R_50_FPN_3x_ds.yaml +18 -0
- configs/faster_rcnn_R_50_FPN_3x_opendet.yaml +16 -0
- configs/faster_rcnn_R_50_FPN_3x_proser.yaml +16 -0
- configs/faster_rcnn_Swin_T_FPN_3x_opendet.yaml +25 -0
- configs/retinanet_R_50_FPN_3x_baseline.yaml +17 -0
- configs/retinanet_R_50_FPN_3x_opendet.yaml +25 -0
- datasets/README.md +51 -0
- demo/demo.py +194 -0
- demo/predictor.py +224 -0
- opendet2/__init__.py +7 -0
- opendet2/config/__init__.py +3 -0
- opendet2/config/defaults.py +50 -0
- opendet2/data/__init__.py +4 -0
- opendet2/data/build.py +299 -0
- opendet2/data/builtin.py +31 -0
- opendet2/data/voc_coco.py +35 -0
- opendet2/engine/__init__.py +3 -0
- opendet2/engine/defaults.py +441 -0
- opendet2/evaluation/__init__.py +3 -0
- opendet2/evaluation/pascal_voc_evaluation.py +377 -0
- opendet2/modeling/__init__.py +5 -0
- opendet2/modeling/backbone/__init__.py +3 -0
- opendet2/modeling/backbone/swin_transformer.py +726 -0
- opendet2/modeling/layers/__init__.py +3 -0
- opendet2/modeling/layers/mlp.py +46 -0
- opendet2/modeling/losses/__init__.py +4 -0
- opendet2/modeling/losses/instance_contrastive_loss.py +40 -0
- opendet2/modeling/losses/unknown_probability_loss.py +93 -0
- opendet2/modeling/meta_arch/__init__.py +3 -0
- opendet2/modeling/meta_arch/retinanet.py +483 -0
- opendet2/modeling/roi_heads/__init__.py +4 -0
- opendet2/modeling/roi_heads/box_head.py +163 -0
- opendet2/modeling/roi_heads/fast_rcnn.py +645 -0
- opendet2/modeling/roi_heads/roi_heads.py +150 -0
- opendet2/solver/__init__.py +3 -0
- opendet2/solver/build.py +57 -0
- setup.py +14 -0
- tools/convert_swin_to_d2.py +36 -0
- tools/train_net.py +79 -0
.gitignore
ADDED
@@ -0,0 +1,60 @@
# output dir
output*
instant_test_output
inference_test_output


*.png
*.json
*.diff
*.jpg
!/projects/DensePose/doc/images/*.jpg

# compilation and distribution
__pycache__
_ext
*.pyc
*.pyd
*.so
*.dll
*.egg-info/
build/
dist/
wheels/

# pytorch/python/numpy formats
*.pth
*.pkl
*.npy
*.ts
model_ts*.txt

# ipython/jupyter notebooks
*.ipynb
**/.ipynb_checkpoints/

# Editor temporaries
*.swn
*.swo
*.swp
*~

# editor settings
.idea
.vscode
_darcs

# project dirs
/detectron2/model_zoo/configs
/datasets/*
!/datasets/*.*
/projects/*/datasets
/models
/snippet
res*
/checkpoints
detectron2
results
checkpoints
demo_images
exps

README.md
CHANGED
@@ -1,12 +1,9 @@
 ---
-title:
-emoji
-colorFrom:
-colorTo:
+title: OpenDet2
+emoji:🦓
+colorFrom: gray
+colorTo: pink
 sdk: gradio
-sdk_version: 2.8.12
 app_file: app.py
 pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference

app.py
ADDED
@@ -0,0 +1,52 @@
import os
os.system('pip install torch==1.9 torchvision')
os.system('pip install detectron2==0.5 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
os.system('pip install timm opencv-python-headless')


import gradio as gr

from demo.predictor import VisualizationDemo
from detectron2.config import get_cfg
from opendet2 import add_opendet_config


model_cfgs = {
    "FR-CNN": ["configs/faster_rcnn_R_50_FPN_3x_baseline.yaml", "frcnn_r50.pth"],
    "OpenDet-R50": ["configs/faster_rcnn_R_50_FPN_3x_opendet.yaml", "opendet2_r50.pth"],
    "OpenDet-SwinT": ["configs/faster_rcnn_Swin_T_FPN_18e_opendet_voc.yaml", "opendet2_swint.pth"],
}


def setup_cfg(model):
    cfg = get_cfg()
    add_opendet_config(cfg)
    model_cfg = model_cfgs[model]
    cfg.merge_from_file(model_cfg[0])
    cfg.MODEL.WEIGHTS = model_cfg[1]
    cfg.MODEL.DEVICE = "cpu"
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.ROI_HEADS.VIS_IOU_THRESH = 0.8
    cfg.freeze()
    return cfg


def inference(input, model):
    cfg = setup_cfg(model)
    demo = VisualizationDemo(cfg)
    # use PIL, to be consistent with evaluation
    predictions, visualized_output = demo.run_on_image(input)
    output = visualized_output.get_image()[:, :, ::-1]
    return output


iface = gr.Interface(
    inference,
    [
        "image",
        gr.inputs.Radio(
            ["FR-CNN", "OpenDet-R50", "OpenDet-SwinT"], default='OpenDet-R50'),
    ],
    "image")

iface.launch()

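Note (editor's sketch, not part of the commit): the same inference path can be exercised locally without launching the Gradio UI. This assumes a checkpoint such as "opendet2_r50.pth" has been downloaded next to the configs; the input image path is only illustrative.

# local_infer.py -- hedged sketch, mirrors what app.py's inference() does
import cv2
from detectron2.config import get_cfg
from opendet2 import add_opendet_config          # also registers the voc_coco splits and models
from demo.predictor import VisualizationDemo

cfg = get_cfg()
add_opendet_config(cfg)
cfg.merge_from_file("configs/faster_rcnn_R_50_FPN_3x_opendet.yaml")
cfg.MODEL.WEIGHTS = "opendet2_r50.pth"           # assumed to exist locally
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.ROI_HEADS.VIS_IOU_THRESH = 0.8
cfg.freeze()

demo = VisualizationDemo(cfg)
predictions, vis = demo.run_on_image(cv2.imread("demo_images/sample.jpg"))  # BGR input
cv2.imwrite("result.jpg", vis.get_image()[:, :, ::-1])                      # RGB -> BGR for imwrite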
configs/Base-RCNN-FPN-OPENDET.yaml
ADDED
@@ -0,0 +1,25 @@
_BASE_: "./Base-RCNN-FPN.yaml"
MODEL:
  MASK_ON: False
  ROI_HEADS:
    NAME: "OpenSetStandardROIHeads"
    NUM_CLASSES: 81
    NUM_KNOWN_CLASSES: 20
  ROI_BOX_HEAD:
    NAME: "FastRCNNSeparateConvFCHead"
    OUTPUT_LAYERS: "OpenDetFastRCNNOutputLayers"
    CLS_AGNOSTIC_BBOX_REG: True
UPLOSS:
  START_ITER: 100
  SAMPLING_METRIC: "min_score"
  TOPK: 3
  ALPHA: 1.0
  WEIGHT: 1.0
ICLOSS:
  OUT_DIM: 128
  QUEUE_SIZE: 256
  IN_QUEUE_SIZE: 16
  BATCH_IOU_THRESH: 0.5
  QUEUE_IOU_THRESH: 0.7
  TEMPERATURE: 0.1
  WEIGHT: 0.1

configs/Base-RCNN-FPN.yaml
ADDED
@@ -0,0 +1,44 @@
# The same as detectron2/configs/Base-RCNN-FPN.yaml
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  BACKBONE:
    NAME: "build_resnet_fpn_backbone"
  RESNETS:
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
  FPN:
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
  ANCHOR_GENERATOR:
    SIZES: [[32], [64], [128], [256], [512]]  # One size for each in feature map
    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]  # Three aspect ratios (same for all in feature maps)
  RPN:
    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
    PRE_NMS_TOPK_TRAIN: 2000  # Per FPN level
    PRE_NMS_TOPK_TEST: 1000  # Per FPN level
    # Detectron1 uses 2000 proposals per-batch,
    # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
    POST_NMS_TOPK_TRAIN: 1000
    POST_NMS_TOPK_TEST: 1000
  ROI_HEADS:
    NAME: "StandardROIHeads"
    IN_FEATURES: ["p2", "p3", "p4", "p5"]
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.02
  STEPS: (60000, 80000)
  MAX_ITER: 90000
INPUT:
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
  MIN_SIZE_TEST: 800
VERSION: 2

configs/Base-RetinaNet.yaml
ADDED
@@ -0,0 +1,26 @@
# The same as detectron2/configs/Base-RetinaNet.yaml
MODEL:
  META_ARCHITECTURE: "RetinaNet"
  BACKBONE:
    NAME: "build_retinanet_resnet_fpn_backbone"
  RESNETS:
    OUT_FEATURES: ["res3", "res4", "res5"]
  ANCHOR_GENERATOR:
    SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [32, 64, 128, 256, 512 ]]"]
  FPN:
    IN_FEATURES: ["res3", "res4", "res5"]
  RETINANET:
    IOU_THRESHOLDS: [0.4, 0.5]
    IOU_LABELS: [0, -1, 1]
    SMOOTH_L1_LOSS_BETA: 0.0
DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.01  # Note that RetinaNet uses a different default learning rate
  STEPS: (60000, 80000)
  MAX_ITER: 90000
INPUT:
  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
VERSION: 2

configs/faster_rcnn_R_50_FPN_3x_baseline.yaml
ADDED
@@ -0,0 +1,16 @@
_BASE_: "./Base-RCNN-FPN-OPENDET.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
  RESNETS:
    DEPTH: 50
  ROI_BOX_HEAD:
    OUTPUT_LAYERS: "CosineFastRCNNOutputLayers"  # baseline use a simple cosine FRCNN
DATASETS:
  TRAIN: ('voc_2007_train', 'voc_2012_trainval')
  TEST: ('voc_2007_test', 'voc_coco_20_40_test', 'voc_coco_20_60_test', 'voc_coco_20_80_test', 'voc_coco_2500_test', 'voc_coco_5000_test', 'voc_coco_10000_test', 'voc_coco_20000_test')
SOLVER:
  STEPS: (21000, 29000)
  MAX_ITER: 32000
  WARMUP_ITERS: 100
  AMP:
    ENABLED: True

configs/faster_rcnn_R_50_FPN_3x_ds.yaml
ADDED
@@ -0,0 +1,18 @@
_BASE_: "./Base-RCNN-FPN-OPENDET.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
  RESNETS:
    DEPTH: 50
  ROI_HEADS:
    NAME: "DropoutStandardROIHeads"
  ROI_BOX_HEAD:
    OUTPUT_LAYERS: "DropoutFastRCNNOutputLayers"
DATASETS:
  TRAIN: ('voc_2007_train', 'voc_2012_trainval')
  TEST: ('voc_2007_test', 'voc_coco_20_40_test', 'voc_coco_20_60_test', 'voc_coco_20_80_test', 'voc_coco_2500_test', 'voc_coco_5000_test', 'voc_coco_10000_test', 'voc_coco_20000_test')
SOLVER:
  STEPS: (21000, 29000)
  MAX_ITER: 32000
  WARMUP_ITERS: 100
  AMP:
    ENABLED: True

configs/faster_rcnn_R_50_FPN_3x_opendet.yaml
ADDED
@@ -0,0 +1,16 @@
_BASE_: "./Base-RCNN-FPN-OPENDET.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
  RESNETS:
    DEPTH: 50
DATASETS:
  TRAIN: ('voc_2007_train', 'voc_2012_trainval')
  TEST: ('voc_2007_test', 'voc_coco_20_40_test', 'voc_coco_20_60_test', 'voc_coco_20_80_test', 'voc_coco_2500_test', 'voc_coco_5000_test', 'voc_coco_10000_test', 'voc_coco_20000_test')
SOLVER:
  STEPS: (21000, 29000)
  MAX_ITER: 32000
  WARMUP_ITERS: 100
  AMP:
    ENABLED: True

# UPLOSS.WEIGHT: former two are 0.5, the last is 1.0

configs/faster_rcnn_R_50_FPN_3x_proser.yaml
ADDED
@@ -0,0 +1,16 @@
_BASE_: "./Base-RCNN-FPN-OPENDET.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
  RESNETS:
    DEPTH: 50
  ROI_BOX_HEAD:
    OUTPUT_LAYERS: "PROSERFastRCNNOutputLayers"
DATASETS:
  TRAIN: ('voc_2007_train', 'voc_2012_trainval')
  TEST: ('voc_2007_test', 'voc_coco_20_40_test', 'voc_coco_20_60_test', 'voc_coco_20_80_test', 'voc_coco_2500_test', 'voc_coco_5000_test', 'voc_coco_10000_test', 'voc_coco_20000_test')
SOLVER:
  STEPS: (21000, 29000)
  MAX_ITER: 32000
  WARMUP_ITERS: 100
  AMP:
    ENABLED: True

configs/faster_rcnn_Swin_T_FPN_3x_opendet.yaml
ADDED
@@ -0,0 +1,25 @@
_BASE_: "./Base-RCNN-FPN-OPENDET.yaml"
MODEL:
  WEIGHTS: "checkpoints/swin_tiny_patch4_window7_224_d2.pth"
  PIXEL_MEAN: [123.675, 116.28, 103.53]
  PIXEL_STD: [58.395, 57.12, 57.375]
  RESNETS:
    DEPTH: 50
  BACKBONE:
    NAME: "build_swint_fpn_backbone"
  SWINT:
    OUT_FEATURES: ["stage2", "stage3", "stage4", "stage5"]
  FPN:
    IN_FEATURES: ["stage2", "stage3", "stage4", "stage5"]
DATASETS:
  TRAIN: ('voc_2007_train', 'voc_2012_trainval')
  TEST: ('voc_2007_test', 'voc_coco_20_40_test', 'voc_coco_20_60_test', 'voc_coco_20_80_test', 'voc_coco_2500_test', 'voc_coco_5000_test', 'voc_coco_10000_test', 'voc_coco_20000_test')
SOLVER:
  STEPS: (21000, 29000)
  MAX_ITER: 32000
  WARMUP_ITERS: 100
  WEIGHT_DECAY: 0.05
  BASE_LR: 0.0001
  OPTIMIZER: "ADAMW"
  AMP:
    ENABLED: True

configs/retinanet_R_50_FPN_3x_baseline.yaml
ADDED
@@ -0,0 +1,17 @@
_BASE_: "./Base-RetinaNet.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
  RESNETS:
    DEPTH: 50
  RETINANET:
    NUM_CLASSES: 81
    NUM_KNOWN_CLASSES: 20
DATASETS:
  TRAIN: ('voc_2007_train', 'voc_2012_trainval')
  TEST: ('voc_2007_test', 'voc_coco_20_40_test', 'voc_coco_20_60_test', 'voc_coco_20_80_test', 'voc_coco_2500_test', 'voc_coco_5000_test', 'voc_coco_10000_test', 'voc_coco_20000_test')
SOLVER:
  STEPS: (21000, 29000)
  MAX_ITER: 32000
  WARMUP_ITERS: 1000
  AMP:
    ENABLED: True

configs/retinanet_R_50_FPN_3x_opendet.yaml
ADDED
@@ -0,0 +1,25 @@
_BASE_: "./Base-RetinaNet.yaml"
MODEL:
  META_ARCHITECTURE: "OpenSetRetinaNet"
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
  RESNETS:
    DEPTH: 50
  RETINANET:
    NUM_CLASSES: 81
    NUM_KNOWN_CLASSES: 20
DATASETS:
  TRAIN: ('voc_2007_train', 'voc_2012_trainval')
  TEST: ('voc_2007_test', 'voc_coco_20_40_test', 'voc_coco_20_60_test', 'voc_coco_20_80_test', 'voc_coco_2500_test', 'voc_coco_5000_test', 'voc_coco_10000_test', 'voc_coco_20000_test')
SOLVER:
  STEPS: (21000, 29000)
  MAX_ITER: 32000
  WARMUP_ITERS: 1000
  AMP:
    ENABLED: True
UPLOSS:
  TOPK: 10
  WEIGHT: 0.2
ICLOSS:
  QUEUE_SIZE: 1024
  IN_QUEUE_SIZE: 64
  WEIGHT: 0.2

datasets/README.md
ADDED
@@ -0,0 +1,51 @@
# Use Builtin Datasets

A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog)
for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc).
This document explains how to setup the builtin datasets so they can be used by the above APIs.
[Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`,
and how to add new datasets to them.

Detectron2 has builtin support for a few datasets.
The datasets are assumed to exist in a directory specified by the environment variable
`DETECTRON2_DATASETS`.
Under this directory, detectron2 will look for datasets in the structure described below, if needed.
```
$DETECTRON2_DATASETS/
  coco/
  VOC20{07,12}/
```

You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`.
If left unset, the default is `./datasets` relative to your current working directory.

The [model zoo](https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md)
contains configs and models that use these builtin datasets.

## Expected dataset structure for [COCO instance/keypoint detection](https://cocodataset.org/#download):

```
coco/
  annotations/
    instances_{train,val}2017.json
    person_keypoints_{train,val}2017.json
  {train,val}2017/
    # image files that are mentioned in the corresponding json
```

You can use the 2014 version of the dataset as well.

Some of the builtin tests (`dev/run_*_tests.sh`) uses a tiny version of the COCO dataset,
which you can download with `./datasets/prepare_for_tests.sh`.

## Expected dataset structure for [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/index.html):
```
VOC20{07,12}/
  Annotations/
  ImageSets/
    Main/
      trainval.txt
      test.txt
      # train.txt or val.txt, if you use these splits
  JPEGImages/
```

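Note (editor's sketch, not part of the commit): the README above describes the catalog APIs and the `DETECTRON2_DATASETS` layout; the snippet below shows how a split registered by this repo could be inspected. It assumes the voc_coco split files have been prepared on disk under `$DETECTRON2_DATASETS`.

import os
os.environ.setdefault("DETECTRON2_DATASETS", "datasets")  # default used by opendet2/data/builtin.py

# Importing the data module runs the voc_coco registrations (see opendet2/data/builtin.py).
from opendet2.data import builtin  # noqa: F401
from detectron2.data import DatasetCatalog, MetadataCatalog

dicts = DatasetCatalog.get("voc_coco_val")       # list[dict] in the standard dataset format
meta = MetadataCatalog.get("voc_coco_val")       # class names, dirname, split, ...
print(len(dicts), meta.thing_classes[:3], meta.thing_classes[-1])  # last class is "unknown"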
demo/demo.py
ADDED
@@ -0,0 +1,194 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import glob
import multiprocessing as mp
import numpy as np
import os
import tempfile
import time
import warnings
import cv2
import tqdm

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger

from predictor import VisualizationDemo

import sys
sys.path.insert(-1, "../")
from opendet2 import add_opendet_config, builtin, OpenDetTrainer

# constants
WINDOW_NAME = "COCO detections"


def setup_cfg(args):
    # load config from file and command-line arguments
    cfg = get_cfg()
    add_opendet_config(cfg)
    # To use demo for Panoptic-DeepLab, please uncomment the following two lines.
    # from detectron2.projects.panoptic_deeplab import add_panoptic_deeplab_config  # noqa
    # add_panoptic_deeplab_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Set score_threshold for builtin models
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
    cfg.MODEL.ROI_HEADS.VIS_IOU_THRESH = 0.8
    cfg.freeze()
    return cfg


def get_parser():
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    parser.add_argument(
        "--config-file",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
    parser.add_argument("--video-input", help="Path to video file.")
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        help="A file or directory to save output visualizations. "
        "If not given, will show output in an OpenCV window.",
    )

    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.5,
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser


def test_opencv_video_format(codec, file_ext):
    with tempfile.TemporaryDirectory(prefix="video_format_test") as dir:
        filename = os.path.join(dir, "test_file" + file_ext)
        writer = cv2.VideoWriter(
            filename=filename,
            fourcc=cv2.VideoWriter_fourcc(*codec),
            fps=float(30),
            frameSize=(10, 10),
            isColor=True,
        )
        [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)]
        writer.release()
        if os.path.isfile(filename):
            return True
        return False


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)

    demo = VisualizationDemo(cfg)

    if args.input:
        if len(args.input) == 1:
            args.input = glob.glob(os.path.expanduser(args.input[0]))
            assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input, disable=not args.output):
            # use PIL, to be consistent with evaluation
            img = read_image(path, format="BGR")
            start_time = time.time()
            predictions, visualized_output = demo.run_on_image(img)
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )

            if args.output:
                if os.path.isdir(args.output):
                    assert os.path.isdir(args.output), args.output
                    out_filename = os.path.join(args.output, os.path.basename(path))
                else:
                    assert len(args.input) == 1, "Please specify a directory with args.output"
                    out_filename = args.output
                visualized_output.save(out_filename)
            else:
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
                cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
                if cv2.waitKey(0) == 27:
                    break  # esc to quit
    elif args.webcam:
        assert args.input is None, "Cannot have both --input and --webcam!"
        assert args.output is None, "output not yet supported with --webcam!"
        cam = cv2.VideoCapture(0)
        for vis in tqdm.tqdm(demo.run_on_video(cam)):
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, vis)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
        cam.release()
        cv2.destroyAllWindows()
    elif args.video_input:
        video = cv2.VideoCapture(args.video_input)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        basename = os.path.basename(args.video_input)
        codec, file_ext = (
            ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
        )
        if codec == ".mp4v":
            warnings.warn("x264 codec not available, switching to mp4v")
        if args.output:
            if os.path.isdir(args.output):
                output_fname = os.path.join(args.output, basename)
                output_fname = os.path.splitext(output_fname)[0] + file_ext
            else:
                output_fname = args.output
            assert not os.path.isfile(output_fname), output_fname
            output_file = cv2.VideoWriter(
                filename=output_fname,
                # some installation of opencv may not support x264 (due to its license),
                # you can try other format (e.g. MPEG)
                fourcc=cv2.VideoWriter_fourcc(*codec),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )
        assert os.path.isfile(args.video_input)
        for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
                cv2.imshow(basename, vis_frame)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
        video.release()
        if args.output:
            output_file.release()
        else:
            cv2.destroyAllWindows()

demo/predictor.py
ADDED
@@ -0,0 +1,224 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import atexit
import bisect
import multiprocessing as mp
from collections import deque
import cv2
import torch

from detectron2.data import MetadataCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2.data.datasets.builtin_meta import _get_coco_instances_meta


class VisualizationDemo(object):
    def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
        """
        Args:
            cfg (CfgNode):
            instance_mode (ColorMode):
            parallel (bool): whether to run the model in different processes from visualization.
                Useful since the visualization logic can be slow.
        """
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[-1] if len(cfg.DATASETS.TEST) else "__unused"
        )
        thing_colors = _get_coco_instances_meta()["thing_colors"]
        thing_colors.append((0,0,0))
        self.metadata.set(thing_colors=thing_colors)
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode

        self.parallel = parallel
        if parallel:
            num_gpu = torch.cuda.device_count()
            self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
        else:
            self.predictor = DefaultPredictor(cfg)

    def run_on_image(self, image):
        """
        Args:
            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
                This is the format used by OpenCV.

        Returns:
            predictions (dict): the output of the model.
            vis_output (VisImage): the visualized image output.
        """
        vis_output = None
        predictions = self.predictor(image)
        # Convert image from OpenCV BGR format to Matplotlib RGB format.
        image = image[:, :, ::-1]
        visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
        if "panoptic_seg" in predictions:
            panoptic_seg, segments_info = predictions["panoptic_seg"]
            vis_output = visualizer.draw_panoptic_seg_predictions(
                panoptic_seg.to(self.cpu_device), segments_info
            )
        else:
            if "sem_seg" in predictions:
                vis_output = visualizer.draw_sem_seg(
                    predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )
            if "instances" in predictions:
                instances = predictions["instances"].to(self.cpu_device)
                vis_output = visualizer.draw_instance_predictions(predictions=instances)

        return predictions, vis_output

    def _frame_from_video(self, video):
        while video.isOpened():
            success, frame = video.read()
            if success:
                yield frame
            else:
                break

    def run_on_video(self, video):
        """
        Visualizes predictions on frames of the input video.

        Args:
            video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
                either a webcam or a video file.

        Yields:
            ndarray: BGR visualizations of each video frame.
        """
        video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)

        def process_predictions(frame, predictions):
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if "panoptic_seg" in predictions:
                panoptic_seg, segments_info = predictions["panoptic_seg"]
                vis_frame = video_visualizer.draw_panoptic_seg_predictions(
                    frame, panoptic_seg.to(self.cpu_device), segments_info
                )
            elif "instances" in predictions:
                predictions = predictions["instances"].to(self.cpu_device)
                vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
            elif "sem_seg" in predictions:
                vis_frame = video_visualizer.draw_sem_seg(
                    frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )

            # Converts Matplotlib RGB format to OpenCV BGR format
            vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
            return vis_frame

        frame_gen = self._frame_from_video(video)
        if self.parallel:
            buffer_size = self.predictor.default_buffer_size

            frame_data = deque()

            for cnt, frame in enumerate(frame_gen):
                frame_data.append(frame)
                self.predictor.put(frame)

                if cnt >= buffer_size:
                    frame = frame_data.popleft()
                    predictions = self.predictor.get()
                    yield process_predictions(frame, predictions)

            while len(frame_data):
                frame = frame_data.popleft()
                predictions = self.predictor.get()
                yield process_predictions(frame, predictions)
        else:
            for frame in frame_gen:
                yield process_predictions(frame, self.predictor(frame))


class AsyncPredictor:
    """
    A predictor that runs the model asynchronously, possibly on >1 GPUs.
    Because rendering the visualization takes considerably amount of time,
    this helps improve throughput a little bit when rendering videos.
    """

    class _StopToken:
        pass

    class _PredictWorker(mp.Process):
        def __init__(self, cfg, task_queue, result_queue):
            self.cfg = cfg
            self.task_queue = task_queue
            self.result_queue = result_queue
            super().__init__()

        def run(self):
            predictor = DefaultPredictor(self.cfg)

            while True:
                task = self.task_queue.get()
                if isinstance(task, AsyncPredictor._StopToken):
                    break
                idx, data = task
                result = predictor(data)
                self.result_queue.put((idx, result))

    def __init__(self, cfg, num_gpus: int = 1):
        """
        Args:
            cfg (CfgNode):
            num_gpus (int): if 0, will run on CPU
        """
        num_workers = max(num_gpus, 1)
        self.task_queue = mp.Queue(maxsize=num_workers * 3)
        self.result_queue = mp.Queue(maxsize=num_workers * 3)
        self.procs = []
        for gpuid in range(max(num_gpus, 1)):
            cfg = cfg.clone()
            cfg.defrost()
            cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
            self.procs.append(
                AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
            )

        self.put_idx = 0
        self.get_idx = 0
        self.result_rank = []
        self.result_data = []

        for p in self.procs:
            p.start()
        atexit.register(self.shutdown)

    def put(self, image):
        self.put_idx += 1
        self.task_queue.put((self.put_idx, image))

    def get(self):
        self.get_idx += 1  # the index needed for this request
        if len(self.result_rank) and self.result_rank[0] == self.get_idx:
            res = self.result_data[0]
            del self.result_data[0], self.result_rank[0]
            return res

        while True:
            # make sure the results are returned in the correct order
            idx, res = self.result_queue.get()
            if idx == self.get_idx:
                return res
            insert = bisect.bisect(self.result_rank, idx)
            self.result_rank.insert(insert, idx)
            self.result_data.insert(insert, res)

    def __len__(self):
        return self.put_idx - self.get_idx

    def __call__(self, image):
        self.put(image)
        return self.get()

    def shutdown(self):
        for _ in self.procs:
            self.task_queue.put(AsyncPredictor._StopToken())

    @property
    def default_buffer_size(self):
        return len(self.procs) * 5

opendet2/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .config import *
from .data import *
from .engine import *
from .evaluation import *
from .modeling import *

__all__ = [k for k in globals().keys() if not k.startswith("_")]

opendet2/config/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .defaults import add_opendet_config

__all__ = [k for k in globals().keys() if not k.startswith("_")]

opendet2/config/defaults.py
ADDED
@@ -0,0 +1,50 @@
from detectron2.config import CfgNode as CN


def add_opendet_config(cfg):
    _C = cfg

    # unknown probability loss
    _C.UPLOSS = CN()
    _C.UPLOSS.START_ITER = 100  # usually the same as warmup iter
    _C.UPLOSS.SAMPLING_METRIC = "min_score"
    _C.UPLOSS.TOPK = 3
    _C.UPLOSS.ALPHA = 1.0
    _C.UPLOSS.WEIGHT = 0.5

    # instance contrastive loss
    _C.ICLOSS = CN()
    _C.ICLOSS.OUT_DIM = 128
    _C.ICLOSS.QUEUE_SIZE = 256
    _C.ICLOSS.IN_QUEUE_SIZE = 16
    _C.ICLOSS.BATCH_IOU_THRESH = 0.5
    _C.ICLOSS.QUEUE_IOU_THRESH = 0.7
    _C.ICLOSS.TEMPERATURE = 0.1
    _C.ICLOSS.WEIGHT = 0.1

    # register RoI output layer
    _C.MODEL.ROI_BOX_HEAD.OUTPUT_LAYERS = "FastRCNNOutputLayers"
    # known classes
    _C.MODEL.ROI_HEADS.NUM_KNOWN_CLASSES = 20
    _C.MODEL.RETINANET.NUM_KNOWN_CLASSES = 20
    # thresh for visualization results.
    _C.MODEL.ROI_HEADS.VIS_IOU_THRESH = 1.0
    # scale for cosine classifier
    _C.MODEL.ROI_HEADS.COSINE_SCALE = 20

    # swin transformer
    _C.MODEL.SWINT = CN()
    _C.MODEL.SWINT.EMBED_DIM = 96
    _C.MODEL.SWINT.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"]
    _C.MODEL.SWINT.DEPTHS = [2, 2, 6, 2]
    _C.MODEL.SWINT.NUM_HEADS = [3, 6, 12, 24]
    _C.MODEL.SWINT.WINDOW_SIZE = 7
    _C.MODEL.SWINT.MLP_RATIO = 4
    _C.MODEL.SWINT.DROP_PATH_RATE = 0.2
    _C.MODEL.SWINT.APE = False
    _C.MODEL.BACKBONE.FREEZE_AT = -1
    _C.MODEL.FPN.TOP_LEVELS = 2

    # solver, e.g., adamw for swin
    _C.SOLVER.OPTIMIZER = 'SGD'
    _C.SOLVER.BETAS = (0.9, 0.999)

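Note (editor's sketch, not part of the commit): the keys above must be registered before merging any of the YAML configs in this repo, otherwise yacs rejects UPLOSS/ICLOSS/SWINT as non-existent config keys. A minimal usage sketch:

from detectron2.config import get_cfg
from opendet2.config import add_opendet_config

cfg = get_cfg()
add_opendet_config(cfg)   # registers UPLOSS, ICLOSS, SWINT, SOLVER.OPTIMIZER, ...
cfg.merge_from_file("configs/faster_rcnn_R_50_FPN_3x_opendet.yaml")
print(cfg.UPLOSS.WEIGHT, cfg.ICLOSS.QUEUE_SIZE, cfg.SOLVER.OPTIMIZER)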
opendet2/data/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .build import *
from . import builtin

__all__ = [k for k in globals().keys() if not k.startswith("_")]

opendet2/data/build.py
ADDED
@@ -0,0 +1,299 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import logging
import numpy as np
import copy
import torch.utils.data
import torch

from detectron2.config import configurable
from detectron2.utils.logger import _log_api_usage

from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.detection_utils import check_metadata_consistency
from detectron2.data.samplers import InferenceSampler, RepeatFactorTrainingSampler, TrainingSampler

from detectron2.data.build import trivial_batch_collator
from detectron2.data import (build_batch_data_loader,
                             print_instances_class_histogram,
                             load_proposals_into_dataset)
from detectron2.data.build import (filter_images_with_few_keypoints,
                                   filter_images_with_only_crowd_annotations)

"""
This file contains the default logic to build a dataloader for training or testing.
"""
__all__ = [
    "build_detection_train_loader",
    "build_detection_test_loader",
    "get_detection_dataset_dicts",
]


def get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, proposal_files=None, cfg=None):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `names`.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
    """
    if isinstance(names, str):
        names = [names]
    assert len(names), names
    dataset_dicts = [DatasetCatalog.get(dataset_name)
                     for dataset_name in names]
    for dataset_name, dicts in zip(names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(names) == len(proposal_files)
        # load precomputed proposals from proposal files
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(
            dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(
            dataset_dicts, min_keypoints)

    d_name = names[0]
    # if 'voc_coco' in d_name:
    if 'train' in d_name:
        dataset_dicts = remove_unk_instances(cfg, dataset_dicts)
    elif 'test' in d_name:
        dataset_dicts = label_known_class_and_unknown(cfg, dataset_dicts)

    if has_instances:
        try:
            class_names = MetadataCatalog.get(names[0]).thing_classes
            check_metadata_consistency("thing_classes", names)
            print_instances_class_histogram(dataset_dicts, class_names)
        except AttributeError:  # class names are not available for this dataset
            pass

    assert len(dataset_dicts), "No valid data found in {}.".format(
        ",".join(names))
    return dataset_dicts


def remove_unk_instances(cfg, dataset_dicts):
    num_known_classes = cfg.MODEL.ROI_HEADS.NUM_KNOWN_CLASSES
    valid_classes = range(0, num_known_classes)

    logger = logging.getLogger(__name__)
    logger.info("Valid classes: " + str(valid_classes))
    logger.info("Removing unknown objects...")

    for entry in copy.copy(dataset_dicts):
        annos = entry["annotations"]
        for annotation in copy.copy(annos):
            if annotation["category_id"] not in valid_classes:
                annos.remove(annotation)
        if len(annos) == 0:
            dataset_dicts.remove(entry)

    return dataset_dicts


def label_known_class_and_unknown(cfg, dataset_dicts):
    num_known_classes = cfg.MODEL.ROI_HEADS.NUM_KNOWN_CLASSES
    total_num_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES

    known_classes = range(0, num_known_classes)

    logger = logging.getLogger(__name__)
    logger.info("Known classes: " + str(known_classes))
    logger.info(
        "Labelling known instances the corresponding label, and unknown instances as unknown...")

    for entry in dataset_dicts:
        annos = entry["annotations"]
        for annotation in annos:
            if annotation["category_id"] not in known_classes:
                annotation["category_id"] = total_num_class - 1

    return dataset_dicts


def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    if dataset is None:
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON
            else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
            cfg=cfg
        )
        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])

    if mapper is None:
        mapper = DatasetMapper(cfg, True)
    else:
        mapper = mapper(cfg, True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = TrainingSampler(len(dataset))
        elif sampler_name == "RepeatFactorTrainingSampler":
            repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                dataset, cfg.DATALOADER.REPEAT_THRESHOLD
            )
            sampler = RepeatFactorTrainingSampler(repeat_factors)
        else:
            raise ValueError(
                "Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }


# TODO can allow dataset as an iterable or IterableDataset to make this function more general
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
):
    """
    Build a dataloader for object detection with some default features.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``. Default to :class:`TrainingSampler`,
            which coordinates an infinite random shuffle sequence across all workers.
        total_batch_size (int): total batch size across all workers. Batching
            simply puts data into a list.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers

    Returns:
        torch.utils.data.DataLoader:
            a dataloader. Each output from it is a ``list[mapped_element]`` of length
            ``total_batch_size / num_workers``, where ``mapped_element`` is produced
            by the ``mapper``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )


def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Uses the given `dataset_name` argument (instead of the names in cfg), because the
    standard practice is to evaluate each test set individually (not combining them).
    """
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]

    dataset = get_detection_dataset_dicts(
        dataset_name,
        filter_empty=False,
        proposal_files=[
            cfg.DATASETS.PROPOSAL_FILES_TEST[list(
                cfg.DATASETS.TEST).index(dataset_name)]
        ]
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
        cfg=cfg
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    return {"dataset": dataset, "mapper": mapper, "num_workers": cfg.DATALOADER.NUM_WORKERS}


@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0):
    """
    Similar to `build_detection_train_loader`, but uses a batch size of 1,
    and :class:`InferenceSampler`. This sampler coordinates all workers to
    produce the exact set of all samples.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
            which splits the dataset across all workers.
        num_workers (int): number of parallel data loading workers

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetRegistry.get("my_test"),
            mapper=DatasetMapper(...))

        # or, instantiate with a CfgNode:
        data_loader = build_detection_test_loader(cfg, "my_test")
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = InferenceSampler(len(dataset))
    # Always use 1 image per worker during inference since this is the
    # standard when reporting inference time in papers.
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, 1, drop_last=False)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
    )
    return data_loader

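Note (editor's sketch, not part of the commit): a toy illustration of the two open-set relabeling helpers above. Train-time entries drop annotations outside the known classes; test-time entries map them to the last class index, NUM_CLASSES - 1 (the "unknown" slot). The toy dicts are made up for illustration.

from detectron2.config import get_cfg
from opendet2.config import add_opendet_config
from opendet2.data.build import remove_unk_instances, label_known_class_and_unknown

cfg = get_cfg()
add_opendet_config(cfg)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 81       # as in Base-RCNN-FPN-OPENDET.yaml
cfg.MODEL.ROI_HEADS.NUM_KNOWN_CLASSES = 20

train = [{"annotations": [{"category_id": 3}, {"category_id": 45}]}]
test = [{"annotations": [{"category_id": 3}, {"category_id": 45}]}]
print(remove_unk_instances(cfg, train))            # category 45 is removed from the entry
print(label_known_class_and_unknown(cfg, test))    # category 45 is relabeled to 80 ("unknown")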
opendet2/data/builtin.py
ADDED
@@ -0,0 +1,31 @@
import os

from .voc_coco import register_voc_coco
from detectron2.data import MetadataCatalog


def register_all_voc_coco(root):
    SPLITS = [
        # VOC_COCO_openset
        ("voc_coco_20_40_test", "voc_coco", "voc_coco_20_40_test"),
        ("voc_coco_20_60_test", "voc_coco", "voc_coco_20_60_test"),
        ("voc_coco_20_80_test", "voc_coco", "voc_coco_20_80_test"),

        ("voc_coco_2500_test", "voc_coco", "voc_coco_2500_test"),
        ("voc_coco_5000_test", "voc_coco", "voc_coco_5000_test"),
        ("voc_coco_10000_test", "voc_coco", "voc_coco_10000_test"),
        ("voc_coco_20000_test", "voc_coco", "voc_coco_20000_test"),

        ("voc_coco_val", "voc_coco", "voc_coco_val"),

    ]
    for name, dirname, split in SPLITS:
        year = 2007 if "2007" in name else 2012
        register_voc_coco(name, os.path.join(root, dirname), split, year)
        MetadataCatalog.get(name).evaluator_type = "pascal_voc"


if __name__.endswith(".builtin"):
    # Register them all under "./datasets"
    _root = os.getenv("DETECTRON2_DATASETS", "datasets")
    register_all_voc_coco(_root)

opendet2/data/voc_coco.py
ADDED
@@ -0,0 +1,35 @@
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import load_voc_instances

VOC_COCO_CATEGORIES = [
    # VOC
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
    # COCO-20-40
    "truck", "traffic light", "fire hydrant", "stop sign", "parking meter",
    "bench", "elephant", "bear", "zebra", "giraffe",
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    "microwave", "oven", "toaster", "sink", "refrigerator",
    # COCO-40-60
    "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
    "banana", "apple", "sandwich", "orange", "broccoli",
    "carrot", "hot dog", "pizza", "donut", "cake",
    # COCO-60-80
    "bed", "toilet", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
    "wine glass", "cup", "fork", "knife", "spoon", "bowl",
    # Unknown
    "unknown",
]


def register_voc_coco(name, dirname, split, year):
    class_names = VOC_COCO_CATEGORIES
    DatasetCatalog.register(
        name, lambda: load_voc_instances(dirname, split, class_names))
    MetadataCatalog.get(name).set(
        thing_classes=list(class_names), dirname=dirname, year=year, split=split
    )

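Note (editor's sketch, not part of the commit): an extra split could be registered with the helper above. The split name "my_voc_coco_extra" and the directory path are hypothetical; the data must follow the VOC layout from datasets/README.md, with a matching ImageSets/Main/voc_coco_extra.txt file.

from opendet2.data.voc_coco import register_voc_coco
from detectron2.data import MetadataCatalog

register_voc_coco("my_voc_coco_extra", "datasets/voc_coco", "voc_coco_extra", 2012)
MetadataCatalog.get("my_voc_coco_extra").evaluator_type = "pascal_voc"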
opendet2/engine/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .defaults import OpenDetTrainer

__all__ = [k for k in globals().keys() if not k.startswith("_")]

opendet2/engine/defaults.py
ADDED
@@ -0,0 +1,441 @@
import logging
import os
import weakref
from collections import OrderedDict
from typing import Dict

import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode
from detectron2.data import MetadataCatalog
from detectron2.engine import (AMPTrainer, SimpleTrainer,
                               TrainerBase, create_ddp_model, hooks, create_ddp_model, default_writers)
from detectron2.evaluation import (DatasetEvaluator, DatasetEvaluators,
                                   inference_on_dataset, print_csv_format,
                                   verify_results)
from detectron2.modeling import GeneralizedRCNNWithTTA, build_model
from detectron2.solver import build_lr_scheduler
from detectron2.utils import comm
from detectron2.utils.logger import setup_logger
from fvcore.nn.precise_bn import get_bn_modules

from ..data import build_detection_test_loader, build_detection_train_loader
from ..evaluation import PascalVOCDetectionEvaluator
from ..solver import build_optimizer


class OpenDetTrainer(TrainerBase):
    """
    A trainer with default training logic. It does the following:

    1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
       defined by the given config. Create a LR scheduler defined by the config.
    2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
       `resume_or_load` is called.
    3. Register a few common hooks defined by the config.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
40 |
+
may easily become invalid in a new research. In fact, any assumptions beyond those made in the
|
41 |
+
:class:`SimpleTrainer` are too much for research.
|
42 |
+
|
43 |
+
The code of this class has been annotated about restrictive assumptions it makes.
|
44 |
+
When they do not work for you, you're encouraged to:
|
45 |
+
|
46 |
+
1. Overwrite methods of this class, OR:
|
47 |
+
2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
|
48 |
+
nothing else. You can then add your own hooks if needed. OR:
|
49 |
+
3. Write your own training loop similar to `tools/plain_train_net.py`.
|
50 |
+
|
51 |
+
See the :doc:`/tutorials/training` tutorials for more details.
|
52 |
+
|
53 |
+
Note that the behavior of this class, like other functions/classes in
|
54 |
+
this file, is not stable, since it is meant to represent the "common default behavior".
|
55 |
+
It is only guaranteed to work well with the standard models and training workflow in detectron2.
|
56 |
+
To obtain more stable behavior, write your own training logic with other public APIs.
|
57 |
+
|
58 |
+
Examples:
|
59 |
+
::
|
60 |
+
trainer = DefaultTrainer(cfg)
|
61 |
+
trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
|
62 |
+
trainer.train()
|
63 |
+
|
64 |
+
Attributes:
|
65 |
+
scheduler:
|
66 |
+
checkpointer (DetectionCheckpointer):
|
67 |
+
cfg (CfgNode):
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, cfg):
|
71 |
+
"""
|
72 |
+
Args:
|
73 |
+
cfg (CfgNode):
|
74 |
+
"""
|
75 |
+
super().__init__()
|
76 |
+
logger = logging.getLogger("detectron2")
|
77 |
+
# setup_logger is not called for d2
|
78 |
+
if not logger.isEnabledFor(logging.INFO):
|
79 |
+
setup_logger()
|
80 |
+
cfg = OpenDetTrainer.auto_scale_workers(cfg, comm.get_world_size())
|
81 |
+
|
82 |
+
# Assume these objects must be constructed in this order.
|
83 |
+
model = self.build_model(cfg)
|
84 |
+
optimizer = self.build_optimizer(cfg, model)
|
85 |
+
data_loader = self.build_train_loader(cfg)
|
86 |
+
|
87 |
+
model = create_ddp_model(
|
88 |
+
model, broadcast_buffers=False, find_unused_parameters=True)
|
89 |
+
# model = create_ddp_model(model, broadcast_buffers=False)
|
90 |
+
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
|
91 |
+
model, data_loader, optimizer
|
92 |
+
)
|
93 |
+
|
94 |
+
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
|
95 |
+
self.checkpointer = DetectionCheckpointer(
|
96 |
+
# Assume you want to save checkpoints together with logs/statistics
|
97 |
+
model,
|
98 |
+
cfg.OUTPUT_DIR,
|
99 |
+
trainer=weakref.proxy(self),
|
100 |
+
)
|
101 |
+
self.start_iter = 0
|
102 |
+
self.max_iter = cfg.SOLVER.MAX_ITER
|
103 |
+
self.cfg = cfg
|
104 |
+
|
105 |
+
self.register_hooks(self.build_hooks())
|
106 |
+
|
107 |
+
def resume_or_load(self, resume=True):
|
108 |
+
"""
|
109 |
+
If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
|
110 |
+
a `last_checkpoint` file), resume from the file. Resuming means loading all
|
111 |
+
available states (eg. optimizer and scheduler) and update iteration counter
|
112 |
+
from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
|
113 |
+
|
114 |
+
Otherwise, this is considered as an independent training. The method will load model
|
115 |
+
weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
|
116 |
+
from iteration 0.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
resume (bool): whether to do resume or not
|
120 |
+
"""
|
121 |
+
self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
|
122 |
+
if resume and self.checkpointer.has_checkpoint():
|
123 |
+
# The checkpoint stores the training iteration that just finished, thus we start
|
124 |
+
# at the next iteration
|
125 |
+
self.start_iter = self.iter + 1
|
126 |
+
|
127 |
+
def build_hooks(self):
|
128 |
+
"""
|
129 |
+
Build a list of default hooks, including timing, evaluation,
|
130 |
+
checkpointing, lr scheduling, precise BN, writing events.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
list[HookBase]:
|
134 |
+
"""
|
135 |
+
cfg = self.cfg.clone()
|
136 |
+
cfg.defrost()
|
137 |
+
cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
|
138 |
+
|
139 |
+
ret = [
|
140 |
+
hooks.IterationTimer(),
|
141 |
+
hooks.LRScheduler(),
|
142 |
+
hooks.PreciseBN(
|
143 |
+
# Run at the same freq as (but before) evaluation.
|
144 |
+
cfg.TEST.EVAL_PERIOD,
|
145 |
+
self.model,
|
146 |
+
# Build a new data loader to not affect training
|
147 |
+
self.build_train_loader(cfg),
|
148 |
+
cfg.TEST.PRECISE_BN.NUM_ITER,
|
149 |
+
)
|
150 |
+
if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
|
151 |
+
else None,
|
152 |
+
]
|
153 |
+
|
154 |
+
# Do PreciseBN before checkpointer, because it updates the model and need to
|
155 |
+
# be saved by checkpointer.
|
156 |
+
# This is not always the best: if checkpointing has a different frequency,
|
157 |
+
# some checkpoints may have more precise statistics than others.
|
158 |
+
if comm.is_main_process():
|
159 |
+
ret.append(hooks.PeriodicCheckpointer(
|
160 |
+
self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
|
161 |
+
|
162 |
+
def test_and_save_results():
|
163 |
+
self._last_eval_results = self.test(self.cfg, self.model)
|
164 |
+
return self._last_eval_results
|
165 |
+
|
166 |
+
# Do evaluation after checkpointer, because then if it fails,
|
167 |
+
# we can use the saved checkpoint to debug.
|
168 |
+
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
|
169 |
+
|
170 |
+
if comm.is_main_process():
|
171 |
+
# Here the default print/log frequency of each writer is used.
|
172 |
+
# run writers in the end, so that evaluation metrics are written
|
173 |
+
ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
|
174 |
+
return ret
|
175 |
+
|
176 |
+
def build_writers(self):
|
177 |
+
"""
|
178 |
+
Build a list of writers to be used using :func:`default_writers()`.
|
179 |
+
If you'd like a different list of writers, you can overwrite it in
|
180 |
+
your trainer.
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
list[EventWriter]: a list of :class:`EventWriter` objects.
|
184 |
+
"""
|
185 |
+
return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
|
186 |
+
|
187 |
+
def train(self):
|
188 |
+
"""
|
189 |
+
Run training.
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
OrderedDict of results, if evaluation is enabled. Otherwise None.
|
193 |
+
"""
|
194 |
+
super().train(self.start_iter, self.max_iter)
|
195 |
+
if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
|
196 |
+
assert hasattr(
|
197 |
+
self, "_last_eval_results"
|
198 |
+
), "No evaluation results obtained during training!"
|
199 |
+
verify_results(self.cfg, self._last_eval_results)
|
200 |
+
return self._last_eval_results
|
201 |
+
|
202 |
+
def run_step(self):
|
203 |
+
self._trainer.iter = self.iter
|
204 |
+
self._trainer.run_step()
|
205 |
+
|
206 |
+
@classmethod
|
207 |
+
def build_model(cls, cfg):
|
208 |
+
"""
|
209 |
+
Returns:
|
210 |
+
torch.nn.Module:
|
211 |
+
|
212 |
+
It now calls :func:`detectron2.modeling.build_model`.
|
213 |
+
Overwrite it if you'd like a different model.
|
214 |
+
"""
|
215 |
+
model = build_model(cfg)
|
216 |
+
logger = logging.getLogger(__name__)
|
217 |
+
logger.info("Model:\n{}".format(model))
|
218 |
+
return model
|
219 |
+
|
220 |
+
@classmethod
|
221 |
+
def build_optimizer(cls, cfg, model):
|
222 |
+
"""
|
223 |
+
Returns:
|
224 |
+
torch.optim.Optimizer:
|
225 |
+
|
226 |
+
It now calls :func:`detectron2.solver.build_optimizer`.
|
227 |
+
Overwrite it if you'd like a different optimizer.
|
228 |
+
"""
|
229 |
+
return build_optimizer(cfg, model)
|
230 |
+
|
231 |
+
@classmethod
|
232 |
+
def build_lr_scheduler(cls, cfg, optimizer):
|
233 |
+
"""
|
234 |
+
It now calls :func:`detectron2.solver.build_lr_scheduler`.
|
235 |
+
Overwrite it if you'd like a different scheduler.
|
236 |
+
"""
|
237 |
+
return build_lr_scheduler(cfg, optimizer)
|
238 |
+
|
239 |
+
@classmethod
|
240 |
+
def build_train_loader(cls, cfg):
|
241 |
+
"""
|
242 |
+
Returns:
|
243 |
+
iterable
|
244 |
+
|
245 |
+
It now calls :func:`detectron2.data.build_detection_train_loader`.
|
246 |
+
Overwrite it if you'd like a different data loader.
|
247 |
+
"""
|
248 |
+
return build_detection_train_loader(cfg)
|
249 |
+
|
250 |
+
@classmethod
|
251 |
+
def build_test_loader(cls, cfg, dataset_name):
|
252 |
+
"""
|
253 |
+
Returns:
|
254 |
+
iterable
|
255 |
+
|
256 |
+
It now calls :func:`detectron2.data.build_detection_test_loader`.
|
257 |
+
Overwrite it if you'd like a different data loader.
|
258 |
+
"""
|
259 |
+
return build_detection_test_loader(cfg, dataset_name)
|
260 |
+
|
261 |
+
@classmethod
|
262 |
+
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
|
263 |
+
if output_folder is None:
|
264 |
+
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
|
265 |
+
evaluator_list = []
|
266 |
+
evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
|
267 |
+
|
268 |
+
if evaluator_type == "pascal_voc":
|
269 |
+
return PascalVOCDetectionEvaluator(dataset_name, cfg)
|
270 |
+
|
271 |
+
if len(evaluator_list) == 0:
|
272 |
+
raise NotImplementedError(
|
273 |
+
"no Evaluator for the dataset {} with the type {}".format(
|
274 |
+
dataset_name, evaluator_type
|
275 |
+
)
|
276 |
+
)
|
277 |
+
elif len(evaluator_list) == 1:
|
278 |
+
return evaluator_list[0]
|
279 |
+
return DatasetEvaluators(evaluator_list)
|
280 |
+
|
281 |
+
@classmethod
|
282 |
+
def test_with_TTA(cls, cfg, model):
|
283 |
+
logger = logging.getLogger("detectron2.trainer")
|
284 |
+
# In the end of training, run an evaluation with TTA
|
285 |
+
# Only support some R-CNN models.
|
286 |
+
logger.info("Running inference with test-time augmentation ...")
|
287 |
+
model = GeneralizedRCNNWithTTA(cfg, model)
|
288 |
+
evaluators = [
|
289 |
+
cls.build_evaluator(
|
290 |
+
cfg, name, output_folder=os.path.join(
|
291 |
+
cfg.OUTPUT_DIR, "inference_TTA")
|
292 |
+
)
|
293 |
+
for name in cfg.DATASETS.TEST
|
294 |
+
]
|
295 |
+
res = cls.test(cfg, model, evaluators)
|
296 |
+
res = OrderedDict({k + "_TTA": v for k, v in res.items()})
|
297 |
+
return res
|
298 |
+
|
299 |
+
@classmethod
|
300 |
+
def test(cls, cfg, model, evaluators=None):
|
301 |
+
"""
|
302 |
+
Args:
|
303 |
+
cfg (CfgNode):
|
304 |
+
model (nn.Module):
|
305 |
+
evaluators (list[DatasetEvaluator] or None): if None, will call
|
306 |
+
:meth:`build_evaluator`. Otherwise, must have the same length as
|
307 |
+
``cfg.DATASETS.TEST``.
|
308 |
+
|
309 |
+
Returns:
|
310 |
+
dict: a dict of result metrics
|
311 |
+
"""
|
312 |
+
logger = logging.getLogger(__name__)
|
313 |
+
if isinstance(evaluators, DatasetEvaluator):
|
314 |
+
evaluators = [evaluators]
|
315 |
+
if evaluators is not None:
|
316 |
+
assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
|
317 |
+
len(cfg.DATASETS.TEST), len(evaluators)
|
318 |
+
)
|
319 |
+
|
320 |
+
results = OrderedDict()
|
321 |
+
for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
|
322 |
+
data_loader = cls.build_test_loader(cfg, dataset_name)
|
323 |
+
# When evaluators are passed in as arguments,
|
324 |
+
# implicitly assume that evaluators can be created before data_loader.
|
325 |
+
if evaluators is not None:
|
326 |
+
evaluator = evaluators[idx]
|
327 |
+
else:
|
328 |
+
try:
|
329 |
+
evaluator = cls.build_evaluator(cfg, dataset_name)
|
330 |
+
except NotImplementedError:
|
331 |
+
logger.warn(
|
332 |
+
"No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
|
333 |
+
"or implement its `build_evaluator` method."
|
334 |
+
)
|
335 |
+
results[dataset_name] = {}
|
336 |
+
continue
|
337 |
+
results_i = inference_on_dataset(model, data_loader, evaluator)
|
338 |
+
results[dataset_name] = results_i
|
339 |
+
if comm.is_main_process():
|
340 |
+
assert isinstance(
|
341 |
+
results_i, dict
|
342 |
+
), "Evaluator must return a dict on the main process. Got {} instead.".format(
|
343 |
+
results_i
|
344 |
+
)
|
345 |
+
logger.info(
|
346 |
+
"Evaluation results for {} in csv format:".format(dataset_name))
|
347 |
+
print_csv_format(results_i)
|
348 |
+
|
349 |
+
if len(results) == 1:
|
350 |
+
results = list(results.values())[0]
|
351 |
+
return results
|
352 |
+
|
353 |
+
@staticmethod
|
354 |
+
def auto_scale_workers(cfg, num_workers: int):
|
355 |
+
"""
|
356 |
+
When the config is defined for certain number of workers (according to
|
357 |
+
``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
|
358 |
+
workers currently in use, returns a new cfg where the total batch size
|
359 |
+
is scaled so that the per-GPU batch size stays the same as the
|
360 |
+
original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.
|
361 |
+
|
362 |
+
Other config options are also scaled accordingly:
|
363 |
+
* training steps and warmup steps are scaled inverse proportionally.
|
364 |
+
* learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.
|
365 |
+
|
366 |
+
For example, with the original config like the following:
|
367 |
+
|
368 |
+
.. code-block:: yaml
|
369 |
+
|
370 |
+
IMS_PER_BATCH: 16
|
371 |
+
BASE_LR: 0.1
|
372 |
+
REFERENCE_WORLD_SIZE: 8
|
373 |
+
MAX_ITER: 5000
|
374 |
+
STEPS: (4000,)
|
375 |
+
CHECKPOINT_PERIOD: 1000
|
376 |
+
|
377 |
+
When this config is used on 16 GPUs instead of the reference number 8,
|
378 |
+
calling this method will return a new config with:
|
379 |
+
|
380 |
+
.. code-block:: yaml
|
381 |
+
|
382 |
+
IMS_PER_BATCH: 32
|
383 |
+
BASE_LR: 0.2
|
384 |
+
REFERENCE_WORLD_SIZE: 16
|
385 |
+
MAX_ITER: 2500
|
386 |
+
STEPS: (2000,)
|
387 |
+
CHECKPOINT_PERIOD: 500
|
388 |
+
|
389 |
+
Note that both the original config and this new config can be trained on 16 GPUs.
|
390 |
+
It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).
|
391 |
+
|
392 |
+
Returns:
|
393 |
+
CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
|
394 |
+
"""
|
395 |
+
old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
|
396 |
+
if old_world_size == 0 or old_world_size == num_workers:
|
397 |
+
return cfg
|
398 |
+
cfg = cfg.clone()
|
399 |
+
frozen = cfg.is_frozen()
|
400 |
+
cfg.defrost()
|
401 |
+
|
402 |
+
assert (
|
403 |
+
cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
|
404 |
+
), "Invalid REFERENCE_WORLD_SIZE in config!"
|
405 |
+
scale = num_workers / old_world_size
|
406 |
+
bs = cfg.SOLVER.IMS_PER_BATCH = int(
|
407 |
+
round(cfg.SOLVER.IMS_PER_BATCH * scale))
|
408 |
+
lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
|
409 |
+
max_iter = cfg.SOLVER.MAX_ITER = int(
|
410 |
+
round(cfg.SOLVER.MAX_ITER / scale))
|
411 |
+
warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(
|
412 |
+
round(cfg.SOLVER.WARMUP_ITERS / scale))
|
413 |
+
cfg.SOLVER.STEPS = tuple(int(round(s / scale))
|
414 |
+
for s in cfg.SOLVER.STEPS)
|
415 |
+
cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
|
416 |
+
cfg.SOLVER.CHECKPOINT_PERIOD = int(
|
417 |
+
round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
|
418 |
+
cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers # maintain invariant
|
419 |
+
logger = logging.getLogger(__name__)
|
420 |
+
logger.info(
|
421 |
+
f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
|
422 |
+
f"max_iter={max_iter}, warmup={warmup_iter}."
|
423 |
+
)
|
424 |
+
|
425 |
+
if frozen:
|
426 |
+
cfg.freeze()
|
427 |
+
return cfg
|
428 |
+
|
429 |
+
|
430 |
+
# Access basic attributes from the underlying trainer
|
431 |
+
for _attr in ["model", "data_loader", "optimizer"]:
|
432 |
+
setattr(
|
433 |
+
OpenDetTrainer,
|
434 |
+
_attr,
|
435 |
+
property(
|
436 |
+
# getter
|
437 |
+
lambda self, x=_attr: getattr(self._trainer, x),
|
438 |
+
# setter
|
439 |
+
lambda self, value, x=_attr: setattr(self._trainer, x, value),
|
440 |
+
),
|
441 |
+
)
|
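Note (usage sketch, not part of this commit): how the trainer above is typically driven, in the spirit of tools/train_net.py. The config helper name `add_opendet_config` and the package-level re-export are assumptions based on the repository layout, not verified from this file.

# Hedged usage sketch of OpenDetTrainer (paths and helper names are assumptions).
from detectron2.config import get_cfg

from opendet2 import add_opendet_config   # assumed export of opendet2.config
from opendet2.engine import OpenDetTrainer

cfg = get_cfg()
add_opendet_config(cfg)                    # add OpenDet-specific config keys
cfg.merge_from_file("configs/faster_rcnn_R_50_FPN_3x_opendet.yaml")
cfg.freeze()

trainer = OpenDetTrainer(cfg)
trainer.resume_or_load(resume=False)       # load cfg.MODEL.WEIGHTS, start at iteration 0
trainer.train()                            # hooks handle checkpointing and evaluation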
opendet2/evaluation/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .pascal_voc_evaluation import *

__all__ = [k for k in globals().keys() if not k.startswith("_")]
opendet2/evaluation/pascal_voc_evaluation.py
ADDED
@@ -0,0 +1,377 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Code is modified from https://github.com/JosephKJ/OWOD

import logging
import os
import tempfile
import xml.etree.ElementTree as ET
from collections import OrderedDict, defaultdict
from functools import lru_cache
from tabulate import tabulate

import numpy as np
import torch
from detectron2.data import MetadataCatalog
from detectron2.evaluation import DatasetEvaluator
from detectron2.evaluation.pascal_voc_evaluation import voc_ap
from detectron2.utils import comm
from detectron2.utils.file_io import PathManager


class PascalVOCDetectionEvaluator(DatasetEvaluator):
    def __init__(self, dataset_name, cfg=None):
        """
        Args:
            dataset_name (str): name of the dataset, e.g., "voc_2007_test"
        """
        self._dataset_name = dataset_name
        meta = MetadataCatalog.get(dataset_name)

        # Too many tiny files, download all to local for speed.
        annotation_dir_local = PathManager.get_local_path(
            os.path.join(meta.dirname, "Annotations/")
        )
        self._anno_file_template = os.path.join(annotation_dir_local, "{}.xml")
        self._image_set_path = os.path.join(
            meta.dirname, "ImageSets", "Main", meta.split + ".txt")
        self._class_names = meta.thing_classes
        assert meta.year in [2007, 2012], meta.year
        self.logger = logging.getLogger(__name__)
        self._is_2007 = meta.year == 2007
        self._cpu_device = torch.device("cpu")
        if cfg is not None:
            self.output_dir = cfg.OUTPUT_DIR
            self.total_num_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES
            self.unknown_class_index = self.total_num_class - 1
            self.num_known_classes = cfg.MODEL.ROI_HEADS.NUM_KNOWN_CLASSES
            self.known_classes = self._class_names[:self.num_known_classes]

    def reset(self):
        # class name -> list of prediction strings
        self._predictions = defaultdict(list)

    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            image_id = input["image_id"]
            instances = output["instances"].to(self._cpu_device)
            boxes = instances.pred_boxes.tensor.numpy()
            scores = instances.scores.tolist()
            classes = instances.pred_classes.tolist()

            for box, score, cls in zip(boxes, scores, classes):
                xmin, ymin, xmax, ymax = box
                # The inverse of data loading logic in `datasets/pascal_voc.py`
                xmin += 1
                ymin += 1
                self._predictions[cls].append(
                    f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}"
                )

    def compute_WI_at_many_recall_level(self, recalls, tp_plus_fp_cs, fp_os):
        wi_at_recall = {}
        # for r in range(1, 10):
        for r in [8]:
            r = r/10
            wi = self.compute_WI_at_a_recall_level(
                recalls, tp_plus_fp_cs, fp_os, recall_level=r)
            wi_at_recall[r] = wi
        return wi_at_recall

    def compute_WI_at_a_recall_level(self, recalls, tp_plus_fp_cs, fp_os, recall_level=0.5):
        wi_at_iou = {}
        for iou, recall in recalls.items():
            tp_plus_fps = []
            fps = []
            for cls_id, rec in enumerate(recall):
                if cls_id in range(self.num_known_classes) and len(rec) > 0:
                    index = min(range(len(rec)), key=lambda i: abs(
                        rec[i] - recall_level))
                    tp_plus_fp = tp_plus_fp_cs[iou][cls_id][index]
                    tp_plus_fps.append(tp_plus_fp)
                    fp = fp_os[iou][cls_id][index]
                    fps.append(fp)
            if len(tp_plus_fps) > 0:
                wi_at_iou[iou] = np.mean(fps) / np.mean(tp_plus_fps)
            else:
                wi_at_iou[iou] = 0
        return wi_at_iou

    def evaluate(self):
        """
        Returns:
            dict: has a key "segm", whose value is a dict of "AP", "AP50", and "AP75".
        """
        all_predictions = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return
        predictions = defaultdict(list)
        for predictions_per_rank in all_predictions:
            for clsid, lines in predictions_per_rank.items():
                predictions[clsid].extend(lines)
        del all_predictions

        self.logger.info(
            "Evaluating {} using {} metric. "
            "Note that results do not use the official Matlab API.".format(
                self._dataset_name, 2007 if self._is_2007 else 2012
            )
        )

        dirname = os.path.join(self.output_dir, 'pascal_voc_eval')
        if not os.path.exists(dirname):
            os.mkdir(dirname)
        # with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname:
        res_file_template = os.path.join(dirname, "{}.txt")

        aps = defaultdict(list)  # iou -> ap per class
        recs = defaultdict(list)
        precs = defaultdict(list)
        all_recs = defaultdict(list)
        all_precs = defaultdict(list)
        unk_det_as_knowns = defaultdict(list)
        num_unks = defaultdict(list)
        tp_plus_fp_cs = defaultdict(list)
        fp_os = defaultdict(list)

        for cls_id, cls_name in enumerate(self._class_names):
            lines = predictions.get(cls_id, [""])

            with open(res_file_template.format(cls_name), "w") as f:
                f.write("\n".join(lines))

            for thresh in [50, ]:
                # for thresh in range(50, 100, 5):
                (rec, prec, ap, unk_det_as_known, num_unk,
                 tp_plus_fp_closed_set, fp_open_set) = voc_eval(
                    res_file_template,
                    self._anno_file_template,
                    self._image_set_path,
                    cls_name,
                    ovthresh=thresh / 100.0,
                    use_07_metric=self._is_2007,
                    known_classes=self.known_classes
                )
                aps[thresh].append(ap * 100)
                unk_det_as_knowns[thresh].append(unk_det_as_known)
                num_unks[thresh].append(num_unk)
                all_precs[thresh].append(prec)
                all_recs[thresh].append(rec)
                tp_plus_fp_cs[thresh].append(tp_plus_fp_closed_set)
                fp_os[thresh].append(fp_open_set)
                try:
                    recs[thresh].append(rec[-1] * 100)
                    precs[thresh].append(prec[-1] * 100)
                except:
                    recs[thresh].append(0)
                    precs[thresh].append(0)

        results_2d = {}
        mAP = {iou: np.mean(x) for iou, x in aps.items()}
        results_2d['mAP'] = mAP[50]

        wi = self.compute_WI_at_many_recall_level(
            all_recs, tp_plus_fp_cs, fp_os)
        results_2d['WI'] = wi[0.8][50] * 100

        total_num_unk_det_as_known = {iou: np.sum(
            x) for iou, x in unk_det_as_knowns.items()}
        # total_num_unk = num_unks[50][0]
        # self.logger.info('num_unk ' + str(total_num_unk))
        results_2d['AOSE'] = total_num_unk_det_as_known[50]

        # class-wise P-R
        # self.logger.info(self._class_names)
        # self.logger.info("AP50: " + str(['%.1f' % x for x in aps[50]]))
        # self.logger.info("P50: " + str(['%.1f' % x for x in precs[50]]))
        # self.logger.info("R50: " + str(['%.1f' % x for x in recs[50]]))

        # Known
        results_2d.update({
            "AP@K": np.mean(aps[50][:self.num_known_classes]),
            "P@K": np.mean(precs[50][:self.num_known_classes]),
            "R@K": np.mean(recs[50][:self.num_known_classes]),
        })

        # Unknown
        results_2d.update({
            "AP@U": np.mean(aps[50][-1]),
            "P@U": np.mean(precs[50][-1]),
            "R@U": np.mean(recs[50][-1]),
        })
        results_head = list(results_2d.keys())
        results_data = [[float(results_2d[k]) for k in results_2d]]
        table = tabulate(
            results_data,
            tablefmt="pipe",
            floatfmt=".2f",
            headers=results_head,
            numalign="left",
        )
        self.logger.info("\n" + table)

        return {",".join(results_head): ",".join([str(round(x, 2)) for x in results_data[0]])}


@lru_cache(maxsize=None)
def parse_rec(filename, known_classes):
    """Parse a PASCAL VOC xml file."""
    with PathManager.open(filename) as f:
        tree = ET.parse(f)
    objects = []
    for obj in tree.findall("object"):
        obj_struct = {}
        cls_name = obj.find("name").text
        # translate unseen classes to unknown
        if cls_name not in known_classes:
            cls_name = 'unknown'

        obj_struct["name"] = cls_name
        # obj_struct["pose"] = obj.find("pose").text
        # obj_struct["truncated"] = int(obj.find("truncated").text)
        obj_struct["difficult"] = int(obj.find("difficult").text)
        bbox = obj.find("bndbox")
        obj_struct["bbox"] = [
            int(bbox.find("xmin").text),
            int(bbox.find("ymin").text),
            int(bbox.find("xmax").text),
            int(bbox.find("ymax").text),
        ]
        objects.append(obj_struct)

    return objects


def compute_overlaps(BBGT, bb):
    # compute overlaps
    # intersection
    ixmin = np.maximum(BBGT[:, 0], bb[0])
    iymin = np.maximum(BBGT[:, 1], bb[1])
    ixmax = np.minimum(BBGT[:, 2], bb[2])
    iymax = np.minimum(BBGT[:, 3], bb[3])
    iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
    ih = np.maximum(iymax - iymin + 1.0, 0.0)
    inters = iw * ih

    # union
    uni = (
        (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
        + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
        - inters
    )

    return inters / uni


def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False, known_classes=None):
    # first load gt
    # read list of images
    with PathManager.open(imagesetfile, "r") as f:
        lines = f.readlines()
    imagenames = [x.strip() for x in lines]

    # load annots
    recs = {}
    for imagename in imagenames:
        recs[imagename] = parse_rec(
            annopath.format(imagename), tuple(known_classes))

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj["name"] == classname]
        bbox = np.array([x["bbox"] for x in R])
        difficult = np.array([x["difficult"] for x in R]).astype(np.bool)
        # difficult = np.array([False for x in R]).astype(np.bool)  # treat all "difficult" as GT
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {"bbox": bbox,
                                 "difficult": difficult, "det": det}

    # read dets
    detfile = detpath.format(classname)
    with open(detfile, "r") as f:
        lines = f.readlines()

    splitlines = [x.strip().split(" ") for x in lines]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]]
                   for x in splitlines]).reshape(-1, 4)

    # sort by confidence
    sorted_ind = np.argsort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R["bbox"].astype(float)

        if BBGT.size > 0:
            overlaps = compute_overlaps(BBGT, bb)
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            if not R["difficult"][jmax]:
                if not R["det"][jmax]:
                    tp[d] = 1.0
                    R["det"][jmax] = 1
                else:
                    fp[d] = 1.0
        else:
            fp[d] = 1.0

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    # compute unknown det as known
    unknown_class_recs = {}
    n_unk = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj["name"] == 'unknown']
        bbox = np.array([x["bbox"] for x in R])
        difficult = np.array([x["difficult"] for x in R]).astype(np.bool)
        det = [False] * len(R)
        n_unk = n_unk + sum(~difficult)
        unknown_class_recs[imagename] = {
            "bbox": bbox, "difficult": difficult, "det": det}

    if classname == 'unknown':
        return rec, prec, ap, 0, n_unk, None, None

    # Go down each detection and see if it has an overlap with an unknown object.
    # If so, it is an unknown object that was classified as known.
    is_unk = np.zeros(nd)
    for d in range(nd):
        R = unknown_class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R["bbox"].astype(float)

        if BBGT.size > 0:
            overlaps = compute_overlaps(BBGT, bb)
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            is_unk[d] = 1.0

    is_unk_sum = np.sum(is_unk)
    tp_plus_fp_closed_set = tp+fp
    fp_open_set = np.cumsum(is_unk)

    return rec, prec, ap, is_unk_sum, n_unk, tp_plus_fp_closed_set, fp_open_set
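Note (synthetic illustration, not part of this commit): the Wilderness Impact reported above as results_2d['WI'] is, per IoU threshold, mean(fp_open_set) / mean(tp_plus_fp_closed_set), with each per-class value read off at the point on that known class's recall curve closest to recall 0.8. With made-up numbers for a single class:

# Synthetic single-class walk-through of the WI computation above (made-up numbers).
import numpy as np

recall_level = 0.8
rec = np.array([0.2, 0.5, 0.7, 0.82, 0.9])                 # cumulative recall per detection rank
tp_plus_fp_closed = np.array([10., 30., 60., 90., 140.])   # cumulative closed-set detections
fp_open = np.array([0., 2., 5., 9., 20.])                  # detections overlapping unknown GT

idx = int(np.argmin(np.abs(rec - recall_level)))           # rank closest to the target recall (0.82 here)
wi = fp_open[idx] / tp_plus_fp_closed[idx]                 # 9 / 90 = 0.10 for this one class
print(f"WI@{recall_level} = {wi:.2f}")                     # evaluator uses mean(fp)/mean(tp+fp) over known classes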
opendet2/modeling/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .meta_arch import OpenSetRetinaNet
from .backbone import *
from .roi_heads import *

__all__ = list(globals().keys())
opendet2/modeling/backbone/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .swin_transformer import SwinTransformer

__all__ = [k for k in globals().keys() if not k.startswith("_")]
opendet2/modeling/backbone/swin_transformer.py
ADDED
@@ -0,0 +1,726 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Swin Transformer
|
3 |
+
# modified from https://github.com/xiaohu2015/SwinT_detectron2/blob/main/swint/swin_transformer.py
|
4 |
+
# --------------------------------------------------------
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch.utils.checkpoint as checkpoint
|
10 |
+
import numpy as np
|
11 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
12 |
+
|
13 |
+
from detectron2.modeling.backbone import Backbone
|
14 |
+
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
|
15 |
+
from detectron2.modeling.backbone.fpn import FPN, LastLevelMaxPool, LastLevelP6P7
|
16 |
+
from detectron2.layers import ShapeSpec
|
17 |
+
|
18 |
+
|
19 |
+
class Mlp(nn.Module):
|
20 |
+
""" Multilayer perceptron."""
|
21 |
+
|
22 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
23 |
+
super().__init__()
|
24 |
+
out_features = out_features or in_features
|
25 |
+
hidden_features = hidden_features or in_features
|
26 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
27 |
+
self.act = act_layer()
|
28 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
29 |
+
self.drop = nn.Dropout(drop)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
x = self.fc1(x)
|
33 |
+
x = self.act(x)
|
34 |
+
x = self.drop(x)
|
35 |
+
x = self.fc2(x)
|
36 |
+
x = self.drop(x)
|
37 |
+
return x
|
38 |
+
|
39 |
+
|
40 |
+
def window_partition(x, window_size):
|
41 |
+
"""
|
42 |
+
Args:
|
43 |
+
x: (B, H, W, C)
|
44 |
+
window_size (int): window size
|
45 |
+
Returns:
|
46 |
+
windows: (num_windows*B, window_size, window_size, C)
|
47 |
+
"""
|
48 |
+
B, H, W, C = x.shape
|
49 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
50 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
51 |
+
return windows
|
52 |
+
|
53 |
+
|
54 |
+
def window_reverse(windows, window_size, H, W):
|
55 |
+
"""
|
56 |
+
Args:
|
57 |
+
windows: (num_windows*B, window_size, window_size, C)
|
58 |
+
window_size (int): Window size
|
59 |
+
H (int): Height of image
|
60 |
+
W (int): Width of image
|
61 |
+
Returns:
|
62 |
+
x: (B, H, W, C)
|
63 |
+
"""
|
64 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
65 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
66 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
class WindowAttention(nn.Module):
|
71 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
72 |
+
It supports both of shifted and non-shifted window.
|
73 |
+
Args:
|
74 |
+
dim (int): Number of input channels.
|
75 |
+
window_size (tuple[int]): The height and width of the window.
|
76 |
+
num_heads (int): Number of attention heads.
|
77 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
78 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
79 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
80 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
81 |
+
"""
|
82 |
+
|
83 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
84 |
+
|
85 |
+
super().__init__()
|
86 |
+
self.dim = dim
|
87 |
+
self.window_size = window_size # Wh, Ww
|
88 |
+
self.num_heads = num_heads
|
89 |
+
head_dim = dim // num_heads
|
90 |
+
self.scale = qk_scale or head_dim ** -0.5
|
91 |
+
|
92 |
+
# define a parameter table of relative position bias
|
93 |
+
self.relative_position_bias_table = nn.Parameter(
|
94 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
95 |
+
|
96 |
+
# get pair-wise relative position index for each token inside the window
|
97 |
+
coords_h = torch.arange(self.window_size[0])
|
98 |
+
coords_w = torch.arange(self.window_size[1])
|
99 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
100 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
101 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
102 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
103 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
104 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
105 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
106 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
107 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
108 |
+
|
109 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
110 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
111 |
+
self.proj = nn.Linear(dim, dim)
|
112 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
113 |
+
|
114 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
115 |
+
self.softmax = nn.Softmax(dim=-1)
|
116 |
+
|
117 |
+
def forward(self, x, mask=None):
|
118 |
+
""" Forward function.
|
119 |
+
Args:
|
120 |
+
x: input features with shape of (num_windows*B, N, C)
|
121 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
122 |
+
"""
|
123 |
+
B_, N, C = x.shape
|
124 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
125 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
126 |
+
|
127 |
+
q = q * self.scale
|
128 |
+
attn = (q @ k.transpose(-2, -1))
|
129 |
+
|
130 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
131 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
132 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
133 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
134 |
+
|
135 |
+
if mask is not None:
|
136 |
+
nW = mask.shape[0]
|
137 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
138 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
139 |
+
attn = self.softmax(attn)
|
140 |
+
else:
|
141 |
+
attn = self.softmax(attn)
|
142 |
+
|
143 |
+
attn = self.attn_drop(attn)
|
144 |
+
|
145 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
146 |
+
x = self.proj(x)
|
147 |
+
x = self.proj_drop(x)
|
148 |
+
return x
|
149 |
+
|
150 |
+
|
151 |
+
class SwinTransformerBlock(nn.Module):
|
152 |
+
""" Swin Transformer Block.
|
153 |
+
Args:
|
154 |
+
dim (int): Number of input channels.
|
155 |
+
num_heads (int): Number of attention heads.
|
156 |
+
window_size (int): Window size.
|
157 |
+
shift_size (int): Shift size for SW-MSA.
|
158 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
159 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
160 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
161 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
162 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
163 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
164 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
165 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
166 |
+
"""
|
167 |
+
|
168 |
+
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
|
169 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
170 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
171 |
+
super().__init__()
|
172 |
+
self.dim = dim
|
173 |
+
self.num_heads = num_heads
|
174 |
+
self.window_size = window_size
|
175 |
+
self.shift_size = shift_size
|
176 |
+
self.mlp_ratio = mlp_ratio
|
177 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
178 |
+
|
179 |
+
self.norm1 = norm_layer(dim)
|
180 |
+
self.attn = WindowAttention(
|
181 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
182 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
183 |
+
|
184 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
185 |
+
self.norm2 = norm_layer(dim)
|
186 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
187 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
188 |
+
|
189 |
+
self.H = None
|
190 |
+
self.W = None
|
191 |
+
|
192 |
+
def forward(self, x, mask_matrix):
|
193 |
+
""" Forward function.
|
194 |
+
Args:
|
195 |
+
x: Input feature, tensor size (B, H*W, C).
|
196 |
+
H, W: Spatial resolution of the input feature.
|
197 |
+
mask_matrix: Attention mask for cyclic shift.
|
198 |
+
"""
|
199 |
+
B, L, C = x.shape
|
200 |
+
H, W = self.H, self.W
|
201 |
+
assert L == H * W, "input feature has wrong size"
|
202 |
+
|
203 |
+
shortcut = x
|
204 |
+
x = self.norm1(x)
|
205 |
+
x = x.view(B, H, W, C)
|
206 |
+
|
207 |
+
# pad feature maps to multiples of window size
|
208 |
+
pad_l = pad_t = 0
|
209 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
210 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
211 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
212 |
+
_, Hp, Wp, _ = x.shape
|
213 |
+
|
214 |
+
# cyclic shift
|
215 |
+
if self.shift_size > 0:
|
216 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
217 |
+
attn_mask = mask_matrix
|
218 |
+
else:
|
219 |
+
shifted_x = x
|
220 |
+
attn_mask = None
|
221 |
+
|
222 |
+
# partition windows
|
223 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
224 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
225 |
+
|
226 |
+
# W-MSA/SW-MSA
|
227 |
+
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
228 |
+
|
229 |
+
# merge windows
|
230 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
231 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
232 |
+
|
233 |
+
# reverse cyclic shift
|
234 |
+
if self.shift_size > 0:
|
235 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
236 |
+
else:
|
237 |
+
x = shifted_x
|
238 |
+
|
239 |
+
if pad_r > 0 or pad_b > 0:
|
240 |
+
x = x[:, :H, :W, :].contiguous()
|
241 |
+
|
242 |
+
x = x.view(B, H * W, C)
|
243 |
+
|
244 |
+
# FFN
|
245 |
+
x = shortcut + self.drop_path(x)
|
246 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
247 |
+
|
248 |
+
return x
|
249 |
+
|
250 |
+
|
251 |
+
class PatchMerging(nn.Module):
|
252 |
+
""" Patch Merging Layer
|
253 |
+
Args:
|
254 |
+
dim (int): Number of input channels.
|
255 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
256 |
+
"""
|
257 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
258 |
+
super().__init__()
|
259 |
+
self.dim = dim
|
260 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
261 |
+
self.norm = norm_layer(4 * dim)
|
262 |
+
|
263 |
+
def forward(self, x, H, W):
|
264 |
+
""" Forward function.
|
265 |
+
Args:
|
266 |
+
x: Input feature, tensor size (B, H*W, C).
|
267 |
+
H, W: Spatial resolution of the input feature.
|
268 |
+
"""
|
269 |
+
B, L, C = x.shape
|
270 |
+
assert L == H * W, "input feature has wrong size"
|
271 |
+
|
272 |
+
x = x.view(B, H, W, C)
|
273 |
+
|
274 |
+
# padding
|
275 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
276 |
+
if pad_input:
|
277 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
278 |
+
|
279 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
280 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
281 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
282 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
283 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
284 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
285 |
+
|
286 |
+
x = self.norm(x)
|
287 |
+
x = self.reduction(x)
|
288 |
+
|
289 |
+
return x
|
290 |
+
|
291 |
+
|
292 |
+
class BasicLayer(nn.Module):
|
293 |
+
""" A basic Swin Transformer layer for one stage.
|
294 |
+
Args:
|
295 |
+
dim (int): Number of feature channels
|
296 |
+
depth (int): Depths of this stage.
|
297 |
+
num_heads (int): Number of attention head.
|
298 |
+
window_size (int): Local window size. Default: 7.
|
299 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
300 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
301 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
302 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
303 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
304 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
305 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
306 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
307 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
308 |
+
"""
|
309 |
+
|
310 |
+
def __init__(self,
|
311 |
+
dim,
|
312 |
+
depth,
|
313 |
+
num_heads,
|
314 |
+
window_size=7,
|
315 |
+
mlp_ratio=4.,
|
316 |
+
qkv_bias=True,
|
317 |
+
qk_scale=None,
|
318 |
+
drop=0.,
|
319 |
+
attn_drop=0.,
|
320 |
+
drop_path=0.,
|
321 |
+
norm_layer=nn.LayerNorm,
|
322 |
+
downsample=None,
|
323 |
+
use_checkpoint=False):
|
324 |
+
super().__init__()
|
325 |
+
self.window_size = window_size
|
326 |
+
self.shift_size = window_size // 2
|
327 |
+
self.depth = depth
|
328 |
+
self.use_checkpoint = use_checkpoint
|
329 |
+
|
330 |
+
# build blocks
|
331 |
+
self.blocks = nn.ModuleList([
|
332 |
+
SwinTransformerBlock(
|
333 |
+
dim=dim,
|
334 |
+
num_heads=num_heads,
|
335 |
+
window_size=window_size,
|
336 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
337 |
+
mlp_ratio=mlp_ratio,
|
338 |
+
qkv_bias=qkv_bias,
|
339 |
+
qk_scale=qk_scale,
|
340 |
+
drop=drop,
|
341 |
+
attn_drop=attn_drop,
|
342 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
343 |
+
norm_layer=norm_layer)
|
344 |
+
for i in range(depth)])
|
345 |
+
|
346 |
+
# patch merging layer
|
347 |
+
if downsample is not None:
|
348 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
349 |
+
else:
|
350 |
+
self.downsample = None
|
351 |
+
|
352 |
+
def forward(self, x, H, W):
|
353 |
+
""" Forward function.
|
354 |
+
Args:
|
355 |
+
x: Input feature, tensor size (B, H*W, C).
|
356 |
+
H, W: Spatial resolution of the input feature.
|
357 |
+
"""
|
358 |
+
|
359 |
+
# calculate attention mask for SW-MSA
|
360 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
361 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
362 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
363 |
+
h_slices = (slice(0, -self.window_size),
|
364 |
+
slice(-self.window_size, -self.shift_size),
|
365 |
+
slice(-self.shift_size, None))
|
366 |
+
w_slices = (slice(0, -self.window_size),
|
367 |
+
slice(-self.window_size, -self.shift_size),
|
368 |
+
slice(-self.shift_size, None))
|
369 |
+
cnt = 0
|
370 |
+
for h in h_slices:
|
371 |
+
for w in w_slices:
|
372 |
+
img_mask[:, h, w, :] = cnt
|
373 |
+
cnt += 1
|
374 |
+
|
375 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
376 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
377 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
378 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
379 |
+
|
380 |
+
for blk in self.blocks:
|
381 |
+
blk.H, blk.W = H, W
|
382 |
+
if self.use_checkpoint:
|
383 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
384 |
+
else:
|
385 |
+
x = blk(x, attn_mask)
|
386 |
+
if self.downsample is not None:
|
387 |
+
x_down = self.downsample(x, H, W)
|
388 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
389 |
+
return x, H, W, x_down, Wh, Ww
|
390 |
+
else:
|
391 |
+
return x, H, W, x, H, W
|
392 |
+
|
393 |
+
|
394 |
+
class PatchEmbed(nn.Module):
|
395 |
+
""" Image to Patch Embedding
|
396 |
+
Args:
|
397 |
+
patch_size (int): Patch token size. Default: 4.
|
398 |
+
in_chans (int): Number of input image channels. Default: 3.
|
399 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
400 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
401 |
+
"""
|
402 |
+
|
403 |
+
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
404 |
+
super().__init__()
|
405 |
+
patch_size = to_2tuple(patch_size)
|
406 |
+
self.patch_size = patch_size
|
407 |
+
|
408 |
+
self.in_chans = in_chans
|
409 |
+
self.embed_dim = embed_dim
|
410 |
+
|
411 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
412 |
+
if norm_layer is not None:
|
413 |
+
self.norm = norm_layer(embed_dim)
|
414 |
+
else:
|
415 |
+
self.norm = None
|
416 |
+
|
417 |
+
def forward(self, x):
|
418 |
+
"""Forward function."""
|
419 |
+
# padding
|
420 |
+
_, _, H, W = x.size()
|
421 |
+
if W % self.patch_size[1] != 0:
|
422 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
423 |
+
if H % self.patch_size[0] != 0:
|
424 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
425 |
+
|
426 |
+
x = self.proj(x) # B C Wh Ww
|
427 |
+
if self.norm is not None:
|
428 |
+
Wh, Ww = x.size(2), x.size(3)
|
429 |
+
x = x.flatten(2).transpose(1, 2)
|
430 |
+
x = self.norm(x)
|
431 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
432 |
+
|
433 |
+
return x
|
434 |
+
|
435 |
+
|
436 |
+
class SwinTransformer(Backbone):
|
437 |
+
""" Swin Transformer backbone.
|
438 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
439 |
+
https://arxiv.org/pdf/2103.14030
|
440 |
+
Args:
|
441 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
442 |
+
used in absolute position embedding. Default 224.
|
443 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
444 |
+
in_chans (int): Number of input image channels. Default: 3.
|
445 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
446 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
447 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
448 |
+
window_size (int): Window size. Default: 7.
|
449 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
450 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
451 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
452 |
+
drop_rate (float): Dropout rate.
|
453 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
454 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
455 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
456 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
457 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
458 |
+
out_indices (Sequence[int]): Output from which stages.
|
459 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
460 |
+
-1 means not freezing any parameters.
|
461 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
462 |
+
"""
|
463 |
+
|
464 |
+
def __init__(self,
|
465 |
+
pretrain_img_size=224,
|
466 |
+
patch_size=4,
|
467 |
+
in_chans=3,
|
468 |
+
embed_dim=96,
|
469 |
+
depths=[2, 2, 6, 2],
|
470 |
+
num_heads=[3, 6, 12, 24],
|
471 |
+
window_size=7,
|
472 |
+
mlp_ratio=4.,
|
473 |
+
qkv_bias=True,
|
474 |
+
qk_scale=None,
|
475 |
+
drop_rate=0.,
|
476 |
+
attn_drop_rate=0.,
|
477 |
+
drop_path_rate=0.2,
|
478 |
+
norm_layer=nn.LayerNorm,
|
479 |
+
ape=False,
|
480 |
+
patch_norm=True,
|
481 |
+
frozen_stages=-1,
|
482 |
+
use_checkpoint=False,
|
483 |
+
out_features=None):
|
484 |
+
super(SwinTransformer, self).__init__()
|
485 |
+
|
486 |
+
self.pretrain_img_size = pretrain_img_size
|
487 |
+
self.num_layers = len(depths)
|
488 |
+
self.embed_dim = embed_dim
|
489 |
+
self.ape = ape
|
490 |
+
self.patch_norm = patch_norm
|
491 |
+
self.frozen_stages = frozen_stages
|
492 |
+
|
493 |
+
self.out_features = out_features
|
494 |
+
|
495 |
+
# split image into non-overlapping patches
|
496 |
+
self.patch_embed = PatchEmbed(
|
497 |
+
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
498 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
499 |
+
|
500 |
+
# absolute position embedding
|
501 |
+
if self.ape:
|
502 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
503 |
+
patch_size = to_2tuple(patch_size)
|
504 |
+
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
|
505 |
+
|
506 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
|
507 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
508 |
+
|
509 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
510 |
+
|
511 |
+
# stochastic depth
|
512 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
513 |
+
|
514 |
+
self._out_feature_strides = {}
|
515 |
+
self._out_feature_channels = {}
|
516 |
+
|
517 |
+
# build layers
|
518 |
+
self.layers = nn.ModuleList()
|
519 |
+
for i_layer in range(self.num_layers):
|
520 |
+
layer = BasicLayer(
|
521 |
+
dim=int(embed_dim * 2 ** i_layer),
|
522 |
+
depth=depths[i_layer],
|
523 |
+
num_heads=num_heads[i_layer],
|
524 |
+
window_size=window_size,
|
525 |
+
mlp_ratio=mlp_ratio,
|
526 |
+
qkv_bias=qkv_bias,
|
527 |
+
qk_scale=qk_scale,
|
528 |
+
drop=drop_rate,
|
529 |
+
attn_drop=attn_drop_rate,
|
530 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
531 |
+
norm_layer=norm_layer,
|
532 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
533 |
+
use_checkpoint=use_checkpoint)
|
534 |
+
self.layers.append(layer)
|
535 |
+
|
536 |
+
stage = f'stage{i_layer+2}'
|
537 |
+
if stage in self.out_features:
|
538 |
+
self._out_feature_channels[stage] = embed_dim * 2 ** i_layer
|
539 |
+
self._out_feature_strides[stage] = 4 * 2 ** i_layer
|
540 |
+
|
541 |
+
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
542 |
+
self.num_features = num_features
|
543 |
+
|
544 |
+
# add a norm layer for each output
|
545 |
+
for i_layer in range(self.num_layers):
|
546 |
+
stage = f'stage{i_layer+2}'
|
547 |
+
if stage in self.out_features:
|
548 |
+
layer = norm_layer(num_features[i_layer])
|
549 |
+
layer_name = f'norm{i_layer}'
|
550 |
+
self.add_module(layer_name, layer)
|
551 |
+
|
552 |
+
self._freeze_stages()
|
553 |
+
|
554 |
+
def _freeze_stages(self):
|
555 |
+
if self.frozen_stages >= 0:
|
556 |
+
self.patch_embed.eval()
|
557 |
+
for param in self.patch_embed.parameters():
|
558 |
+
param.requires_grad = False
|
559 |
+
|
560 |
+
if self.frozen_stages >= 1 and self.ape:
|
561 |
+
self.absolute_pos_embed.requires_grad = False
|
562 |
+
|
563 |
+
if self.frozen_stages >= 2:
|
564 |
+
self.pos_drop.eval()
|
565 |
+
for i in range(0, self.frozen_stages - 1):
|
566 |
+
m = self.layers[i]
|
567 |
+
m.eval()
|
568 |
+
for param in m.parameters():
|
569 |
+
param.requires_grad = False
|
570 |
+
|
571 |
+
def init_weights(self, pretrained=None):
|
572 |
+
"""Initialize the weights in backbone.
|
573 |
+
Args:
|
574 |
+
pretrained (str, optional): Path to pre-trained weights.
|
575 |
+
Defaults to None.
|
576 |
+
"""
|
577 |
+
|
578 |
+
def _init_weights(m):
|
579 |
+
if isinstance(m, nn.Linear):
|
580 |
+
trunc_normal_(m.weight, std=.02)
|
581 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
582 |
+
nn.init.constant_(m.bias, 0)
|
583 |
+
elif isinstance(m, nn.LayerNorm):
|
584 |
+
nn.init.constant_(m.bias, 0)
|
585 |
+
nn.init.constant_(m.weight, 1.0)
|
586 |
+
|
587 |
+
self.apply(_init_weights)
|
588 |
+
|
589 |
+
def forward(self, x):
|
590 |
+
"""Forward function."""
|
591 |
+
x = self.patch_embed(x)
|
592 |
+
|
593 |
+
Wh, Ww = x.size(2), x.size(3)
|
594 |
+
if self.ape:
|
595 |
+
# interpolate the position embedding to the corresponding size
|
596 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
597 |
+
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
598 |
+
else:
|
599 |
+
x = x.flatten(2).transpose(1, 2)
|
600 |
+
x = self.pos_drop(x)
|
601 |
+
|
602 |
+
outs = {}
|
603 |
+
for i in range(self.num_layers):
|
604 |
+
layer = self.layers[i]
|
605 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
606 |
+
name = f'stage{i+2}'
|
607 |
+
if name in self.out_features:
|
608 |
+
norm_layer = getattr(self, f'norm{i}')
|
609 |
+
x_out = norm_layer(x_out)
|
610 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
611 |
+
outs[name] = out
|
612 |
+
|
613 |
+
return outs #{"stage%d" % (i+2,): out for i, out in enumerate(outs)} #tuple(outs)
|
614 |
+
|
615 |
+
def train(self, mode=True):
|
616 |
+
"""Convert the model into training mode while keep layers freezed."""
|
617 |
+
super(SwinTransformer, self).train(mode)
|
618 |
+
self._freeze_stages()
|
619 |
+
|
620 |
+
def output_shape(self):
|
621 |
+
return {
|
622 |
+
name: ShapeSpec(
|
623 |
+
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
|
624 |
+
)
|
625 |
+
for name in self.out_features
|
626 |
+
}
|
627 |
+
|
628 |
+
@BACKBONE_REGISTRY.register()
|
629 |
+
def build_swint_backbone(cfg, input_shape):
|
630 |
+
"""
|
631 |
+
Create a SwinT instance from config.
|
632 |
+
|
633 |
+
Returns:
|
634 |
+
SwinTransformer: a :class:`SwinTransformer` instance.
|
635 |
+
"""
|
636 |
+
out_features = cfg.MODEL.SWINT.OUT_FEATURES
|
637 |
+
|
638 |
+
return SwinTransformer(
|
639 |
+
patch_size=4,
|
640 |
+
in_chans=input_shape.channels,
|
641 |
+
embed_dim=cfg.MODEL.SWINT.EMBED_DIM,
|
642 |
+
depths=cfg.MODEL.SWINT.DEPTHS,
|
643 |
+
num_heads=cfg.MODEL.SWINT.NUM_HEADS,
|
644 |
+
window_size=cfg.MODEL.SWINT.WINDOW_SIZE,
|
645 |
+
mlp_ratio=cfg.MODEL.SWINT.MLP_RATIO,
|
646 |
+
qkv_bias=True,
|
647 |
+
qk_scale=None,
|
648 |
+
drop_rate=0.,
|
649 |
+
attn_drop_rate=0.,
|
650 |
+
drop_path_rate=cfg.MODEL.SWINT.DROP_PATH_RATE,
|
651 |
+
norm_layer=nn.LayerNorm,
|
652 |
+
ape=cfg.MODEL.SWINT.APE,
|
653 |
+
patch_norm=True,
|
654 |
+
frozen_stages=cfg.MODEL.BACKBONE.FREEZE_AT,
|
655 |
+
out_features=out_features
|
656 |
+
)
|
657 |
+
|
658 |
+
|
659 |
+
@BACKBONE_REGISTRY.register()
|
660 |
+
def build_swint_fpn_backbone(cfg, input_shape: ShapeSpec):
|
661 |
+
"""
|
662 |
+
Args:
|
663 |
+
cfg: a detectron2 CfgNode
|
664 |
+
|
665 |
+
Returns:
|
666 |
+
backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
|
667 |
+
"""
|
668 |
+
bottom_up = build_swint_backbone(cfg, input_shape)
|
669 |
+
in_features = cfg.MODEL.FPN.IN_FEATURES
|
670 |
+
out_channels = cfg.MODEL.FPN.OUT_CHANNELS
|
671 |
+
backbone = FPN(
|
672 |
+
bottom_up=bottom_up,
|
673 |
+
in_features=in_features,
|
674 |
+
out_channels=out_channels,
|
675 |
+
norm=cfg.MODEL.FPN.NORM,
|
676 |
+
top_block=LastLevelMaxPool(),
|
677 |
+
fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
|
678 |
+
)
|
679 |
+
return backbone
|
680 |
+
|
681 |
+
class LastLevelP6(nn.Module):
|
682 |
+
"""
|
683 |
+
This module is used in FCOS to generate extra layers
|
684 |
+
"""
|
685 |
+
|
686 |
+
def __init__(self, in_channels, out_channels, in_features="res5"):
|
687 |
+
super().__init__()
|
688 |
+
self.num_levels = 1
|
689 |
+
self.in_feature = in_features
|
690 |
+
self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
|
691 |
+
for module in [self.p6]:
|
692 |
+
weight_init.c2_xavier_fill(module)
|
693 |
+
|
694 |
+
def forward(self, x):
|
695 |
+
p6 = self.p6(x)
|
696 |
+
return [p6]
|
697 |
+
|
698 |
+
@BACKBONE_REGISTRY.register()
|
699 |
+
def build_retinanet_swint_fpn_backbone(cfg, input_shape: ShapeSpec):
|
700 |
+
"""
|
701 |
+
Args:
|
702 |
+
cfg: a detectron2 CfgNode
|
703 |
+
|
704 |
+
Returns:
|
705 |
+
backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
|
706 |
+
"""
|
707 |
+
bottom_up = build_swint_backbone(cfg, input_shape)
|
708 |
+
in_features = cfg.MODEL.FPN.IN_FEATURES
|
709 |
+
out_channels = cfg.MODEL.FPN.OUT_CHANNELS
|
710 |
+
top_levels = cfg.MODEL.FPN.TOP_LEVELS
|
711 |
+
in_channels_top = out_channels
|
712 |
+
if top_levels == 2:
|
713 |
+
top_block = LastLevelP6P7(in_channels_top, out_channels, "p5")
|
714 |
+
elif top_levels == 1:
|
715 |
+
top_block = LastLevelP6(in_channels_top, out_channels, "p5")
|
716 |
+
elif top_levels == 0:
|
717 |
+
top_block = None
|
718 |
+
backbone = FPN(
|
719 |
+
bottom_up=bottom_up,
|
720 |
+
in_features=in_features,
|
721 |
+
out_channels=out_channels,
|
722 |
+
norm=cfg.MODEL.FPN.NORM,
|
723 |
+
top_block=top_block,
|
724 |
+
fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
|
725 |
+
)
|
726 |
+
return backbone
|
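A minimal smoke test for the backbone above (a sketch, not part of the committed file): it assumes the Swin-T defaults of this file and the "stage2"-"stage5" naming used by build_swint_backbone.

import torch

# Build a Swin-T backbone and inspect the multi-scale outputs on a 224x224 image.
model = SwinTransformer(out_features=["stage2", "stage3", "stage4", "stage5"]).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
for name, feat in feats.items():
    print(name, tuple(feat.shape))  # stage2 -> (1, 96, 56, 56), ..., stage5 -> (1, 768, 7, 7)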
opendet2/modeling/layers/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .mlp import *

__all__ = [k for k in globals().keys() if not k.startswith("_")]
opendet2/modeling/layers/mlp.py
ADDED
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import fvcore.nn.weight_init as weight_init


class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super().__init__()
        if not hidden_dim:
            hidden_dim = in_dim
        self.head = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )
        for layer in self.head:
            if isinstance(layer, nn.Linear):
                weight_init.c2_xavier_fill(layer)

    def forward(self, x):
        feat = self.head(x)
        feat_norm = F.normalize(feat, dim=1)
        return feat_norm


class ConvMLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super().__init__()
        if not hidden_dim:
            hidden_dim = in_dim
        self.head = nn.Sequential(
            nn.Conv2d(in_dim, hidden_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, out_dim, kernel_size=3, stride=1, padding=1),
        )
        # Initialization
        for layer in self.head:
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
                torch.nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        feat = self.head(x)
        feat_norm = F.normalize(feat, dim=1)
        return feat_norm
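Both heads return L2-normalized embeddings; a short usage sketch (the dimensions below are illustrative, not values from this repo's configs):

import torch

# MLP projects flattened RoI features; ConvMLP projects dense feature maps.
# Both outputs have unit L2 norm along dim=1.
proj = MLP(in_dim=1024, out_dim=128)
z = proj(torch.randn(8, 1024))
print(z.shape, z.norm(dim=1))            # torch.Size([8, 128]), all ~1.0

conv_proj = ConvMLP(in_dim=256, out_dim=128)
zc = conv_proj(torch.randn(2, 256, 32, 32))
print(zc.shape, zc.norm(dim=1).mean())   # torch.Size([2, 128, 32, 32]), ~1.0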
opendet2/modeling/losses/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .unknown_probability_loss import UPLoss
from .instance_contrastive_loss import ICLoss

__all__ = [k for k in globals().keys() if not k.startswith("_")]
opendet2/modeling/losses/instance_contrastive_loss.py
ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ICLoss(nn.Module):
    """ Instance Contrastive Loss
    """
    def __init__(self, tau=0.1):
        super().__init__()
        self.tau = tau

    def forward(self, features, labels, queue_features, queue_labels):
        device = features.device
        mask = torch.eq(labels[:, None], queue_labels[:, None].T).float().to(device)

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(features, queue_features.T), self.tau)

        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        logits_mask = torch.ones_like(logits)
        # mask itself
        logits_mask[logits == 0] = 0

        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
        # loss
        loss = - mean_log_prob_pos.mean()
        # trick: avoid loss nan
        return loss if not torch.isnan(loss) else features.new_tensor(0.0)
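Usage sketch for ICLoss (shapes and the 5-class setup are illustrative): features and queue entries are expected to be L2-normalized embeddings such as the MLP outputs above; the final guard returns 0 when the batch finds no positives in the queue.

import torch
import torch.nn.functional as F

# 16 RoI embeddings contrasted against a 64-slot feature queue.
feats = F.normalize(torch.randn(16, 128), dim=1)
labels = torch.randint(0, 5, (16,))
queue = F.normalize(torch.randn(64, 128), dim=1)
queue_labels = torch.randint(0, 5, (64,))
loss = ICLoss(tau=0.1)(feats, labels, queue, queue_labels)
print(loss)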
opendet2/modeling/losses/unknown_probability_loss.py
ADDED
@@ -0,0 +1,93 @@
import torch
import torch.distributions as dists
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class UPLoss(nn.Module):
    """Unknown Probability Loss
    """

    def __init__(self,
                 num_classes: int,
                 sampling_metric: str = "min_score",
                 topk: int = 3,
                 alpha: float = 1.0):
        super().__init__()
        self.num_classes = num_classes
        assert sampling_metric in ["min_score", "max_entropy", "random"]
        self.sampling_metric = sampling_metric
        # if topk==-1, sample len(fg)*2 examples
        self.topk = topk
        self.alpha = alpha

    def _soft_cross_entropy(self, input: Tensor, target: Tensor):
        logprobs = F.log_softmax(input, dim=1)
        return -(target * logprobs).sum() / input.shape[0]

    def _sampling(self, scores: Tensor, labels: Tensor):
        fg_inds = labels != self.num_classes
        fg_scores, fg_labels = scores[fg_inds], labels[fg_inds]
        bg_scores, bg_labels = scores[~fg_inds], labels[~fg_inds]

        # remove unknown classes
        _fg_scores = torch.cat(
            [fg_scores[:, :self.num_classes-1], fg_scores[:, -1:]], dim=1)
        _bg_scores = torch.cat(
            [bg_scores[:, :self.num_classes-1], bg_scores[:, -1:]], dim=1)

        num_fg = fg_scores.size(0)
        topk = num_fg if (self.topk == -1) or (num_fg < self.topk) else self.topk
        # use maximum entropy as a metric for uncertainty
        # we select topk proposals with maximum entropy
        if self.sampling_metric == "max_entropy":
            pos_metric = dists.Categorical(
                _fg_scores.softmax(dim=1)).entropy()
            neg_metric = dists.Categorical(
                _bg_scores.softmax(dim=1)).entropy()
        # use minimum score as a metric for uncertainty
        # we select topk proposals with minimum max-score
        elif self.sampling_metric == "min_score":
            pos_metric = -_fg_scores.max(dim=1)[0]
            neg_metric = -_bg_scores.max(dim=1)[0]
        # we randomly select topk proposals
        elif self.sampling_metric == "random":
            pos_metric = torch.rand(_fg_scores.size(0),).to(scores.device)
            neg_metric = torch.rand(_bg_scores.size(0),).to(scores.device)

        _, pos_inds = pos_metric.topk(topk)
        _, neg_inds = neg_metric.topk(topk)
        fg_scores, fg_labels = fg_scores[pos_inds], fg_labels[pos_inds]
        bg_scores, bg_labels = bg_scores[neg_inds], bg_labels[neg_inds]

        return fg_scores, bg_scores, fg_labels, bg_labels

    def forward(self, scores: Tensor, labels: Tensor):
        fg_scores, bg_scores, fg_labels, bg_labels = self._sampling(
            scores, labels)
        # sample both fg and bg
        scores = torch.cat([fg_scores, bg_scores])
        labels = torch.cat([fg_labels, bg_labels])

        num_sample, num_classes = scores.shape
        mask = torch.arange(num_classes).repeat(
            num_sample, 1).to(scores.device)
        inds = mask != labels[:, None].repeat(1, num_classes)
        mask = mask[inds].reshape(num_sample, num_classes-1)

        gt_scores = torch.gather(
            F.softmax(scores, dim=1), 1, labels[:, None]).squeeze(1)
        mask_scores = torch.gather(scores, 1, mask)

        gt_scores[gt_scores < 0] = 0.0
        targets = torch.zeros_like(mask_scores)
        num_fg = fg_scores.size(0)
        targets[:num_fg, self.num_classes-2] = gt_scores[:num_fg] * \
            (1-gt_scores[:num_fg]).pow(self.alpha)
        targets[num_fg:, self.num_classes-1] = gt_scores[num_fg:] * \
            (1-gt_scores[num_fg:]).pow(self.alpha)

        return self._soft_cross_entropy(mask_scores, targets.detach())
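Usage sketch for UPLoss (class counts are illustrative): with num_classes = 20 known + 1 unknown = 21, the classifier is expected to produce num_classes + 1 = 22 columns (background last), and background proposals carry the label 21.

import torch

torch.manual_seed(0)
scores = torch.randn(32, 22)                                   # (N, num_classes + 1) logits
labels = torch.cat([torch.randint(0, 20, (8,)),                # 8 foreground proposals
                    torch.full((24,), 21, dtype=torch.long)])  # 24 background proposals
up_loss = UPLoss(num_classes=21, sampling_metric="min_score", topk=3, alpha=1.0)
print(up_loss(scores, labels))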
opendet2/modeling/meta_arch/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .retinanet import OpenSetRetinaNet

__all__ = list(globals().keys())
opendet2/modeling/meta_arch/retinanet.py
ADDED
@@ -0,0 +1,483 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import logging
|
3 |
+
from typing import Dict, List, Tuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.distributions as dists
|
8 |
+
from detectron2.config import configurable
|
9 |
+
from detectron2.layers import ShapeSpec, cat, cross_entropy
|
10 |
+
from detectron2.modeling import META_ARCH_REGISTRY
|
11 |
+
from detectron2.modeling.box_regression import _dense_box_regression_loss
|
12 |
+
from detectron2.modeling.meta_arch.retinanet import RetinaNet, RetinaNetHead
|
13 |
+
from detectron2.modeling.postprocessing import detector_postprocess
|
14 |
+
from detectron2.structures import Boxes, Instances, pairwise_iou
|
15 |
+
from detectron2.utils.events import get_event_storage
|
16 |
+
from fvcore.nn import sigmoid_focal_loss_jit
|
17 |
+
from torch import Tensor, nn
|
18 |
+
from torch.nn import functional as F
|
19 |
+
|
20 |
+
from ..layers import ConvMLP
|
21 |
+
from ..losses import ICLoss
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
def permute_to_N_HWA_K(tensor, K: int):
|
27 |
+
"""
|
28 |
+
Transpose/reshape a tensor from (N, (Ai x K), H, W) to (N, (HxWxAi), K)
|
29 |
+
"""
|
30 |
+
assert tensor.dim() == 4, tensor.shape
|
31 |
+
N, _, H, W = tensor.shape
|
32 |
+
tensor = tensor.view(N, -1, K, H, W)
|
33 |
+
tensor = tensor.permute(0, 3, 4, 1, 2)
|
34 |
+
tensor = tensor.reshape(N, -1, K) # Size=(N,HWA,K)
|
35 |
+
return tensor
|
36 |
+
|
37 |
+
|
38 |
+
class UPLoss(nn.Module):
|
39 |
+
"""Unknown Probability Loss for RetinaNet
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self,
|
43 |
+
num_classes: int,
|
44 |
+
sampling_metric: str = "min_score",
|
45 |
+
topk: int = 3,
|
46 |
+
alpha: float = 1.0):
|
47 |
+
super().__init__()
|
48 |
+
self.num_classes = num_classes
|
49 |
+
assert sampling_metric in ["min_score", "max_entropy", "random"]
|
50 |
+
self.sampling_metric = sampling_metric
|
51 |
+
# if topk==-1, sample len(fg)*2 examples
|
52 |
+
self.topk = topk
|
53 |
+
self.alpha = alpha
|
54 |
+
|
55 |
+
def _soft_cross_entropy(self, input: Tensor, target: Tensor):
|
56 |
+
logprobs = F.log_softmax(input, dim=1)
|
57 |
+
return -(target * logprobs).sum() / input.shape[0]
|
58 |
+
|
59 |
+
def _sampling(self, scores: Tensor, labels: Tensor):
|
60 |
+
fg_inds = labels != self.num_classes
|
61 |
+
fg_scores, fg_labels = scores[fg_inds], labels[fg_inds]
|
62 |
+
|
63 |
+
# remove unknown classes
|
64 |
+
_fg_scores = torch.cat(
|
65 |
+
[fg_scores[:, :self.num_classes-1], fg_scores[:, -1:]], dim=1)
|
66 |
+
|
67 |
+
num_fg = fg_scores.size(0)
|
68 |
+
topk = num_fg if (self.topk == -1) or (num_fg <
|
69 |
+
self.topk) else self.topk
|
70 |
+
# use maximum entropy as a metric for uncertainty
|
71 |
+
# we select topk proposals with maximum entropy
|
72 |
+
if self.sampling_metric == "max_entropy":
|
73 |
+
pos_metric = dists.Categorical(
|
74 |
+
_fg_scores.softmax(dim=1)).entropy()
|
75 |
+
# use minimum score as a metric for uncertainty
|
76 |
+
# we select topk proposals with minimum max-score
|
77 |
+
elif self.sampling_metric == "min_score":
|
78 |
+
pos_metric = -_fg_scores.max(dim=1)[0]
|
79 |
+
# we randomly select topk proposals
|
80 |
+
elif self.sampling_metric == "random":
|
81 |
+
pos_metric = torch.rand(_fg_scores.size(0),).to(scores.device)
|
82 |
+
|
83 |
+
_, pos_inds = pos_metric.topk(topk)
|
84 |
+
fg_scores, fg_labels = fg_scores[pos_inds], fg_labels[pos_inds]
|
85 |
+
|
86 |
+
return fg_scores, fg_labels
|
87 |
+
|
88 |
+
def forward(self, scores: Tensor, labels: Tensor):
|
89 |
+
scores, labels = self._sampling(scores, labels)
|
90 |
+
|
91 |
+
num_sample, num_classes = scores.shape
|
92 |
+
mask = torch.arange(num_classes).repeat(
|
93 |
+
num_sample, 1).to(scores.device)
|
94 |
+
inds = mask != labels[:, None].repeat(1, num_classes)
|
95 |
+
mask = mask[inds].reshape(num_sample, num_classes-1)
|
96 |
+
|
97 |
+
gt_scores = torch.gather(
|
98 |
+
F.softmax(scores, dim=1), 1, labels[:, None]).squeeze(1)
|
99 |
+
mask_scores = torch.gather(scores, 1, mask)
|
100 |
+
|
101 |
+
gt_scores[gt_scores < 0] = 0.0
|
102 |
+
targets = torch.zeros_like(mask_scores)
|
103 |
+
targets[:, self.num_classes-2] = gt_scores * \
|
104 |
+
(1-gt_scores).pow(self.alpha)
|
105 |
+
|
106 |
+
return self._soft_cross_entropy(mask_scores, targets.detach())
|
107 |
+
|
108 |
+
|
109 |
+
@META_ARCH_REGISTRY.register()
|
110 |
+
class OpenSetRetinaNet(RetinaNet):
|
111 |
+
"""
|
112 |
+
Open-set RetinaNet: extends :paper:`RetinaNet` with an unknown probability loss and an instance contrastive loss.
|
113 |
+
"""
|
114 |
+
|
115 |
+
@configurable
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
num_known_classes,
|
119 |
+
max_iters,
|
120 |
+
up_loss_start_iter,
|
121 |
+
up_loss_sampling_metric,
|
122 |
+
up_loss_topk,
|
123 |
+
up_loss_alpha,
|
124 |
+
up_loss_weight,
|
125 |
+
ins_con_out_dim,
|
126 |
+
ins_con_queue_size,
|
127 |
+
ins_con_in_queue_size,
|
128 |
+
ins_con_batch_iou_thr,
|
129 |
+
ins_con_queue_iou_thr,
|
130 |
+
ins_con_queue_tau,
|
131 |
+
ins_con_loss_weight,
|
132 |
+
*args,
|
133 |
+
**kargs,
|
134 |
+
):
|
135 |
+
super().__init__(*args, **kargs)
|
136 |
+
self.num_known_classes = num_known_classes
|
137 |
+
self.max_iters = max_iters
|
138 |
+
|
139 |
+
self.up_loss = UPLoss(
|
140 |
+
self.num_classes,
|
141 |
+
sampling_metric=up_loss_sampling_metric,
|
142 |
+
topk=up_loss_topk,
|
143 |
+
alpha=up_loss_alpha
|
144 |
+
)
|
145 |
+
self.up_loss_start_iter = up_loss_start_iter
|
146 |
+
self.up_loss_weight = up_loss_weight
|
147 |
+
|
148 |
+
self.ins_con_loss = ICLoss(tau=ins_con_queue_tau)
|
149 |
+
self.ins_con_out_dim = ins_con_out_dim
|
150 |
+
self.ins_con_queue_size = ins_con_queue_size
|
151 |
+
self.ins_con_in_queue_size = ins_con_in_queue_size
|
152 |
+
self.ins_con_batch_iou_thr = ins_con_batch_iou_thr
|
153 |
+
self.ins_con_queue_iou_thr = ins_con_queue_iou_thr
|
154 |
+
self.ins_con_loss_weight = ins_con_loss_weight
|
155 |
+
|
156 |
+
self.register_buffer('queue', torch.zeros(
|
157 |
+
self.num_known_classes, ins_con_queue_size, ins_con_out_dim))
|
158 |
+
self.register_buffer('queue_label', torch.empty(
|
159 |
+
self.num_known_classes, ins_con_queue_size).fill_(-1).long())
|
160 |
+
self.register_buffer('queue_ptr', torch.zeros(
|
161 |
+
self.num_known_classes, dtype=torch.long))
|
162 |
+
|
163 |
+
@classmethod
|
164 |
+
def from_config(cls, cfg):
|
165 |
+
ret = super().from_config(cfg)
|
166 |
+
backbone_shape = ret["backbone"].output_shape()
|
167 |
+
feature_shapes = [backbone_shape[f] for f in cfg.MODEL.RETINANET.IN_FEATURES]
|
168 |
+
head = OpenSetRetinaNetHead(cfg, feature_shapes)
|
169 |
+
ret.update({
|
170 |
+
"head": head,
|
171 |
+
"num_known_classes": cfg.MODEL.ROI_HEADS.NUM_KNOWN_CLASSES,
|
172 |
+
"max_iters": cfg.SOLVER.MAX_ITER,
|
173 |
+
|
174 |
+
"up_loss_start_iter": cfg.UPLOSS.START_ITER,
|
175 |
+
"up_loss_sampling_metric": cfg.UPLOSS.SAMPLING_METRIC,
|
176 |
+
"up_loss_topk": cfg.UPLOSS.TOPK,
|
177 |
+
"up_loss_alpha": cfg.UPLOSS.ALPHA,
|
178 |
+
"up_loss_weight": cfg.UPLOSS.WEIGHT,
|
179 |
+
|
180 |
+
"ins_con_out_dim": cfg.ICLOSS.OUT_DIM,
|
181 |
+
"ins_con_queue_size": cfg.ICLOSS.QUEUE_SIZE,
|
182 |
+
"ins_con_in_queue_size": cfg.ICLOSS.IN_QUEUE_SIZE,
|
183 |
+
"ins_con_batch_iou_thr": cfg.ICLOSS.BATCH_IOU_THRESH,
|
184 |
+
"ins_con_queue_iou_thr": cfg.ICLOSS.QUEUE_IOU_THRESH,
|
185 |
+
"ins_con_queue_tau": cfg.ICLOSS.TEMPERATURE,
|
186 |
+
"ins_con_loss_weight": cfg.ICLOSS.WEIGHT,
|
187 |
+
})
|
188 |
+
return ret
|
189 |
+
|
190 |
+
def get_up_loss(self, scores, gt_classes):
|
191 |
+
# start up loss after warmup iters
|
192 |
+
storage = get_event_storage()
|
193 |
+
if storage.iter > self.up_loss_start_iter:
|
194 |
+
loss_cls_up = self.up_loss(scores, gt_classes)
|
195 |
+
else:
|
196 |
+
loss_cls_up = scores.new_tensor(0.0)
|
197 |
+
|
198 |
+
return self.up_loss_weight * loss_cls_up
|
199 |
+
|
200 |
+
def get_ins_con_loss(self, feat, gt_classes, ious):
|
201 |
+
# select foreground and iou > thr instance in a mini-batch
|
202 |
+
pos_inds = (ious > self.ins_con_batch_iou_thr) & (
|
203 |
+
gt_classes != self.num_classes)
|
204 |
+
|
205 |
+
if not pos_inds.sum():
|
206 |
+
return feat.new_tensor(0.0)
|
207 |
+
|
208 |
+
feat, gt_classes = feat[pos_inds], gt_classes[pos_inds]
|
209 |
+
|
210 |
+
queue = self.queue.reshape(-1, self.ins_con_out_dim)
|
211 |
+
queue_label = self.queue_label.reshape(-1)
|
212 |
+
queue_inds = queue_label != -1 # filter empty queue
|
213 |
+
queue, queue_label = queue[queue_inds], queue_label[queue_inds]
|
214 |
+
|
215 |
+
loss_ins_con = self.ins_con_loss(feat, gt_classes, queue, queue_label)
|
216 |
+
# loss decay
|
217 |
+
storage = get_event_storage()
|
218 |
+
decay_weight = 1.0 - storage.iter / self.max_iters
|
219 |
+
return self.ins_con_loss_weight * decay_weight * loss_ins_con
|
220 |
+
|
221 |
+
@torch.no_grad()
|
222 |
+
def _dequeue_and_enqueue(self, feat, gt_classes, ious, iou_thr=0.7):
|
223 |
+
# 1. gather variable
|
224 |
+
# feat = self.concat_all_gather(feat)
|
225 |
+
# gt_classes = self.concat_all_gather(gt_classes)
|
226 |
+
# ious = self.concat_all_gather(ious)
|
227 |
+
# 2. filter by iou and obj, remove bg
|
228 |
+
keep = (ious > iou_thr) & (gt_classes != self.num_classes)
|
229 |
+
feat, gt_classes = feat[keep], gt_classes[keep]
|
230 |
+
|
231 |
+
for i in range(self.num_known_classes):
|
232 |
+
ptr = int(self.queue_ptr[i])
|
233 |
+
cls_ind = gt_classes == i
|
234 |
+
cls_feat, cls_gt_classes = feat[cls_ind], gt_classes[cls_ind]
|
235 |
+
# 3. sort by similarity, low sim ranks first
|
236 |
+
cls_queue = self.queue[i, self.queue_label[i] != -1]
|
237 |
+
_, sim_inds = F.cosine_similarity(
|
238 |
+
cls_feat[:, None], cls_queue[None, :], dim=-1).mean(dim=1).sort()
|
239 |
+
top_sim_inds = sim_inds[:self.ins_con_in_queue_size]
|
240 |
+
cls_feat, cls_gt_classes = cls_feat[top_sim_inds], cls_gt_classes[top_sim_inds]
|
241 |
+
# 4. in queue
|
242 |
+
batch_size = cls_feat.size(
|
243 |
+
0) if ptr + cls_feat.size(0) <= self.ins_con_queue_size else self.ins_con_queue_size - ptr
|
244 |
+
self.queue[i, ptr:ptr+batch_size] = cls_feat[:batch_size]
|
245 |
+
self.queue_label[i, ptr:ptr +
|
246 |
+
batch_size] = cls_gt_classes[:batch_size]
|
247 |
+
|
248 |
+
ptr = ptr + batch_size if ptr + batch_size < self.ins_con_queue_size else 0
|
249 |
+
self.queue_ptr[i] = ptr
|
250 |
+
|
251 |
+
@torch.no_grad()
|
252 |
+
def concat_all_gather(self, tensor):
|
253 |
+
tensors_gather = [torch.ones_like(tensor) for _ in range(
|
254 |
+
torch.distributed.get_world_size())]
|
255 |
+
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
256 |
+
output = torch.cat(tensors_gather, dim=0)
|
257 |
+
return output
|
258 |
+
|
259 |
+
def forward(self, batched_inputs: List[Dict[str, Tensor]]):
|
260 |
+
"""
|
261 |
+
Args:
|
262 |
+
batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
|
263 |
+
Each item in the list contains the inputs for one image.
|
264 |
+
For now, each item in the list is a dict that contains:
|
265 |
+
|
266 |
+
* image: Tensor, image in (C, H, W) format.
|
267 |
+
* instances: Instances
|
268 |
+
|
269 |
+
Other information that's included in the original dicts, such as:
|
270 |
+
|
271 |
+
* "height", "width" (int): the output resolution of the model, used in inference.
|
272 |
+
See :meth:`postprocess` for details.
|
273 |
+
Returns:
|
274 |
+
In training, dict[str, Tensor]: mapping from a named loss to a tensor storing the
|
275 |
+
loss. Used during training only. In inference, the standard output format, described
|
276 |
+
in :doc:`/tutorials/models`.
|
277 |
+
"""
|
278 |
+
images = self.preprocess_image(batched_inputs)
|
279 |
+
features = self.backbone(images.tensor)
|
280 |
+
features = [features[f] for f in self.head_in_features]
|
281 |
+
|
282 |
+
anchors = self.anchor_generator(features)
|
283 |
+
pred_logits, pred_anchor_deltas, pred_mlp_feats = self.head(features)
|
284 |
+
# Transpose the Hi*Wi*A dimension to the middle:
|
285 |
+
pred_logits = [permute_to_N_HWA_K(
|
286 |
+
x, self.num_classes) for x in pred_logits]
|
287 |
+
pred_anchor_deltas = [permute_to_N_HWA_K(
|
288 |
+
x, 4) for x in pred_anchor_deltas]
|
289 |
+
pred_mlp_feats = [permute_to_N_HWA_K(
|
290 |
+
x, self.ins_con_out_dim) for x in pred_mlp_feats]
|
291 |
+
|
292 |
+
if self.training:
|
293 |
+
assert not torch.jit.is_scripting(), "Not supported"
|
294 |
+
assert "instances" in batched_inputs[0], "Instance annotations are missing in training!"
|
295 |
+
gt_instances = [x["instances"].to(
|
296 |
+
self.device) for x in batched_inputs]
|
297 |
+
|
298 |
+
gt_labels, gt_boxes, gt_ious = self.label_anchors(
|
299 |
+
anchors, gt_instances)
|
300 |
+
losses = self.losses(anchors, pred_logits, pred_mlp_feats,
|
301 |
+
gt_labels, pred_anchor_deltas, gt_boxes, gt_ious)
|
302 |
+
|
303 |
+
if self.vis_period > 0:
|
304 |
+
storage = get_event_storage()
|
305 |
+
if storage.iter % self.vis_period == 0:
|
306 |
+
results = self.inference(
|
307 |
+
anchors, pred_logits, pred_anchor_deltas, images.image_sizes
|
308 |
+
)
|
309 |
+
self.visualize_training(batched_inputs, results)
|
310 |
+
|
311 |
+
return losses
|
312 |
+
else:
|
313 |
+
results = self.inference(
|
314 |
+
anchors, pred_logits, pred_anchor_deltas, images.image_sizes)
|
315 |
+
if torch.jit.is_scripting():
|
316 |
+
return results
|
317 |
+
processed_results = []
|
318 |
+
for results_per_image, input_per_image, image_size in zip(
|
319 |
+
results, batched_inputs, images.image_sizes
|
320 |
+
):
|
321 |
+
height = input_per_image.get("height", image_size[0])
|
322 |
+
width = input_per_image.get("width", image_size[1])
|
323 |
+
r = detector_postprocess(results_per_image, height, width)
|
324 |
+
processed_results.append({"instances": r})
|
325 |
+
return processed_results
|
326 |
+
|
327 |
+
def losses(self, anchors, pred_logits, pred_mlp_feats, gt_labels, pred_anchor_deltas, gt_boxes, gt_ious):
|
328 |
+
"""
|
329 |
+
Args:
|
330 |
+
anchors (list[Boxes]): a list of #feature level Boxes
|
331 |
+
gt_labels, gt_boxes: see output of :meth:`RetinaNet.label_anchors`.
|
332 |
+
Their shapes are (N, R) and (N, R, 4), respectively, where R is
|
333 |
+
the total number of anchors across levels, i.e. sum(Hi x Wi x Ai)
|
334 |
+
pred_logits, pred_anchor_deltas: both are list[Tensor]. Each element in the
|
335 |
+
list corresponds to one level and has shape (N, Hi * Wi * Ai, K or 4).
|
336 |
+
Where K is the number of classes used in `pred_logits`.
|
337 |
+
|
338 |
+
Returns:
|
339 |
+
dict[str, Tensor]:
|
340 |
+
mapping from a named loss to a scalar tensor
|
341 |
+
storing the loss. Used during training only. The dict keys are:
|
342 |
+
"loss_cls" and "loss_box_reg"
|
343 |
+
"""
|
344 |
+
num_images = len(gt_labels)
|
345 |
+
gt_labels = torch.stack(gt_labels) # (N, R)
|
346 |
+
|
347 |
+
valid_mask = gt_labels >= 0
|
348 |
+
pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
|
349 |
+
num_pos_anchors = pos_mask.sum().item()
|
350 |
+
get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images)
|
351 |
+
self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
|
352 |
+
1 - self.loss_normalizer_momentum
|
353 |
+
) * max(num_pos_anchors, 1)
|
354 |
+
|
355 |
+
# classification and regression loss
|
356 |
+
gt_labels_target = F.one_hot(gt_labels[valid_mask], num_classes=self.num_classes + 1)[
|
357 |
+
:, :-1
|
358 |
+
] # no loss for the last (background) class
|
359 |
+
|
360 |
+
loss_cls_ce = sigmoid_focal_loss_jit(
|
361 |
+
cat(pred_logits, dim=1)[valid_mask],
|
362 |
+
gt_labels_target.to(pred_logits[0].dtype),
|
363 |
+
alpha=self.focal_loss_alpha,
|
364 |
+
gamma=self.focal_loss_gamma,
|
365 |
+
reduction="sum",
|
366 |
+
)
|
367 |
+
|
368 |
+
loss_cls_up = self.get_up_loss(cat(pred_logits, dim=1)[
|
369 |
+
valid_mask], gt_labels[valid_mask])
|
370 |
+
|
371 |
+
gt_ious = torch.stack(gt_ious)
|
372 |
+
# we first store feats in the queue, then compute the loss
|
373 |
+
pred_mlp_feats = cat(pred_mlp_feats, dim=1)[valid_mask] # [N, *, 128]
|
374 |
+
# [N*, 128]
|
375 |
+
pred_mlp_feats = pred_mlp_feats.reshape(-1, pred_mlp_feats.shape[-1])
|
376 |
+
self._dequeue_and_enqueue(
|
377 |
+
pred_mlp_feats, gt_labels[valid_mask], gt_ious[valid_mask], iou_thr=self.ins_con_queue_iou_thr)
|
378 |
+
loss_ins_con = self.get_ins_con_loss(
|
379 |
+
pred_mlp_feats, gt_labels[valid_mask], gt_ious[valid_mask])
|
380 |
+
|
381 |
+
loss_box_reg = _dense_box_regression_loss(
|
382 |
+
anchors,
|
383 |
+
self.box2box_transform,
|
384 |
+
pred_anchor_deltas,
|
385 |
+
gt_boxes,
|
386 |
+
pos_mask,
|
387 |
+
box_reg_loss_type=self.box_reg_loss_type,
|
388 |
+
smooth_l1_beta=self.smooth_l1_beta,
|
389 |
+
)
|
390 |
+
|
391 |
+
return {
|
392 |
+
"loss_cls_ce": loss_cls_ce / self.loss_normalizer,
|
393 |
+
"loss_box_reg": loss_box_reg / self.loss_normalizer,
|
394 |
+
"loss_ins_con": loss_ins_con,
|
395 |
+
"loss_cls_up": loss_cls_up,
|
396 |
+
}
|
397 |
+
|
398 |
+
@torch.no_grad()
|
399 |
+
def label_anchors(self, anchors, gt_instances):
|
400 |
+
|
401 |
+
anchors = Boxes.cat(anchors) # Rx4
|
402 |
+
|
403 |
+
gt_labels = []
|
404 |
+
matched_gt_boxes = []
|
405 |
+
matched_gt_ious = []
|
406 |
+
for gt_per_image in gt_instances:
|
407 |
+
match_quality_matrix = pairwise_iou(gt_per_image.gt_boxes, anchors)
|
408 |
+
matched_idxs, anchor_labels = self.anchor_matcher(
|
409 |
+
match_quality_matrix)
|
410 |
+
# del match_quality_matrix
|
411 |
+
|
412 |
+
if len(gt_per_image) > 0:
|
413 |
+
matched_gt_boxes_i = gt_per_image.gt_boxes.tensor[matched_idxs]
|
414 |
+
matched_gt_ious_i = match_quality_matrix.max(dim=1)[
|
415 |
+
0][matched_idxs]
|
416 |
+
|
417 |
+
gt_labels_i = gt_per_image.gt_classes[matched_idxs]
|
418 |
+
# Anchors with label 0 are treated as background.
|
419 |
+
gt_labels_i[anchor_labels == 0] = self.num_classes
|
420 |
+
# Anchors with label -1 are ignored.
|
421 |
+
gt_labels_i[anchor_labels == -1] = -1
|
422 |
+
else:
|
423 |
+
matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
|
424 |
+
matched_gt_ious_i = torch.zeros_like(matched_idxs)
|
425 |
+
gt_labels_i = torch.zeros_like(matched_idxs) + self.num_classes
|
426 |
+
|
427 |
+
gt_labels.append(gt_labels_i)
|
428 |
+
matched_gt_boxes.append(matched_gt_boxes_i)
|
429 |
+
matched_gt_ious.append(matched_gt_ious_i)
|
430 |
+
|
431 |
+
del match_quality_matrix
|
432 |
+
|
433 |
+
return gt_labels, matched_gt_boxes, matched_gt_ious
|
434 |
+
|
435 |
+
|
436 |
+
class OpenSetRetinaNetHead(RetinaNetHead):
|
437 |
+
"""
|
438 |
+
The head used in RetinaNet for object classification and box regression.
|
439 |
+
It has two subnets for the two tasks, with a common structure but separate parameters.
|
440 |
+
"""
|
441 |
+
|
442 |
+
@configurable
|
443 |
+
def __init__(
|
444 |
+
self,
|
445 |
+
*args,
|
446 |
+
ins_con_out_dim,
|
447 |
+
**kargs
|
448 |
+
):
|
449 |
+
super().__init__(*args, **kargs)
|
450 |
+
self.mlp = ConvMLP(kargs["conv_dims"][-1], ins_con_out_dim * kargs["num_anchors"])
|
451 |
+
|
452 |
+
@classmethod
|
453 |
+
def from_config(cls, cfg, input_shape: List[ShapeSpec]):
|
454 |
+
ret = super().from_config(cfg, input_shape)
|
455 |
+
ret["ins_con_out_dim"] = cfg.ICLOSS.OUT_DIM
|
456 |
+
return ret
|
457 |
+
|
458 |
+
def forward(self, features: List[Tensor]):
|
459 |
+
"""
|
460 |
+
Arguments:
|
461 |
+
features (list[Tensor]): FPN feature map tensors in high to low resolution.
|
462 |
+
Each tensor in the list correspond to different feature levels.
|
463 |
+
|
464 |
+
Returns:
|
465 |
+
logits (list[Tensor]): #lvl tensors, each has shape (N, AxK, Hi, Wi).
|
466 |
+
The tensor predicts the classification probability
|
467 |
+
at each spatial position for each of the A anchors and K object
|
468 |
+
classes.
|
469 |
+
bbox_reg (list[Tensor]): #lvl tensors, each has shape (N, Ax4, Hi, Wi).
|
470 |
+
The tensor predicts 4-vector (dx,dy,dw,dh) box
|
471 |
+
regression values for every anchor. These values are the
|
472 |
+
relative offset between the anchor and the ground truth box.
|
473 |
+
"""
|
474 |
+
logits = []
|
475 |
+
mlp_feats = []
|
476 |
+
bbox_reg = []
|
477 |
+
for feature in features:
|
478 |
+
cls_feat = self.cls_subnet(feature)
|
479 |
+
mlp_feats.append(self.mlp(cls_feat))
|
480 |
+
logits.append(self.cls_score(cls_feat))
|
481 |
+
|
482 |
+
bbox_reg.append(self.bbox_pred(self.bbox_subnet(feature)))
|
483 |
+
return logits, bbox_reg, mlp_feats
|
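The head returns per-level (N, AxK, Hi, Wi) maps, and permute_to_N_HWA_K (defined at the top of this file) flattens them to (N, HixWixA, K) before the losses; a small illustration with placeholder sizes (9 anchors, 21 classes, one 32x32 level):

import torch

logits = torch.randn(2, 9 * 21, 32, 32)   # (N, A*K, H, W)
flat = permute_to_N_HWA_K(logits, K=21)   # (N, H*W*A, K)
print(flat.shape)                         # torch.Size([2, 9216, 21]); 32*32*9 = 9216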
opendet2/modeling/roi_heads/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .roi_heads import OpenSetStandardROIHeads
from .box_head import FastRCNNSeparateConvFCHead, FastRCNNSeparateDropoutConvFCHead

__all__ = list(globals().keys())
opendet2/modeling/roi_heads/box_head.py
ADDED
@@ -0,0 +1,163 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import fvcore.nn.weight_init as weight_init
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from detectron2.config import configurable
|
8 |
+
from detectron2.layers import Conv2d, ShapeSpec, get_norm
|
9 |
+
from detectron2.modeling.roi_heads import ROI_BOX_HEAD_REGISTRY
|
10 |
+
from detectron2.utils.registry import Registry
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
@ROI_BOX_HEAD_REGISTRY.register()
|
15 |
+
class FastRCNNSeparateConvFCHead(nn.Module):
|
16 |
+
"""
|
17 |
+
FastRCNN with separate ConvFC layers
|
18 |
+
"""
|
19 |
+
|
20 |
+
@configurable
|
21 |
+
def __init__(
|
22 |
+
self, input_shape: ShapeSpec, *, conv_dims: List[int], fc_dims: List[int], conv_norm=""
|
23 |
+
):
|
24 |
+
"""
|
25 |
+
NOTE: this interface is experimental.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
input_shape (ShapeSpec): shape of the input feature.
|
29 |
+
conv_dims (list[int]): the output dimensions of the conv layers
|
30 |
+
fc_dims (list[int]): the output dimensions of the fc layers
|
31 |
+
conv_norm (str or callable): normalization for the conv layers.
|
32 |
+
See :func:`detectron2.layers.get_norm` for supported types.
|
33 |
+
"""
|
34 |
+
super().__init__()
|
35 |
+
assert len(conv_dims) + len(fc_dims) > 0
|
36 |
+
self.conv_dims = conv_dims
|
37 |
+
self.fc_dims = fc_dims
|
38 |
+
|
39 |
+
self._output_size = (input_shape.channels,
|
40 |
+
input_shape.height, input_shape.width)
|
41 |
+
|
42 |
+
self.reg_conv_norm_relus = self._add_conv_norm_relus(
|
43 |
+
self._output_size[0], conv_dims, conv_norm)
|
44 |
+
self.cls_conv_norm_relus = self._add_conv_norm_relus(
|
45 |
+
self._output_size[0], conv_dims, conv_norm)
|
46 |
+
conv_dim = self._output_size[0] if len(conv_dims) == 0 else conv_dims[-1]
|
47 |
+
self._output_size = (
|
48 |
+
conv_dim, self._output_size[1], self._output_size[2])
|
49 |
+
|
50 |
+
self.reg_fcs = self._add_fcs(np.prod(self._output_size), fc_dims)
|
51 |
+
self.cls_fcs = self._add_fcs(np.prod(self._output_size), fc_dims)
|
52 |
+
self._output_size = self._output_size if len(fc_dims)==0 else fc_dims[-1]
|
53 |
+
|
54 |
+
for layer in self.reg_conv_norm_relus:
|
55 |
+
weight_init.c2_msra_fill(layer)
|
56 |
+
for layer in self.cls_conv_norm_relus:
|
57 |
+
weight_init.c2_msra_fill(layer)
|
58 |
+
for layer in self.cls_fcs:
|
59 |
+
if isinstance(layer, nn.Linear):
|
60 |
+
weight_init.c2_xavier_fill(layer)
|
61 |
+
for layer in self.reg_fcs:
|
62 |
+
if isinstance(layer, nn.Linear):
|
63 |
+
weight_init.c2_xavier_fill(layer)
|
64 |
+
|
65 |
+
@classmethod
|
66 |
+
def from_config(cls, cfg, input_shape):
|
67 |
+
num_conv = cfg.MODEL.ROI_BOX_HEAD.NUM_CONV
|
68 |
+
conv_dim = cfg.MODEL.ROI_BOX_HEAD.CONV_DIM
|
69 |
+
num_fc = cfg.MODEL.ROI_BOX_HEAD.NUM_FC
|
70 |
+
fc_dim = cfg.MODEL.ROI_BOX_HEAD.FC_DIM
|
71 |
+
return {
|
72 |
+
"input_shape": input_shape,
|
73 |
+
"conv_dims": [conv_dim] * num_conv,
|
74 |
+
"fc_dims": [fc_dim] * num_fc,
|
75 |
+
"conv_norm": cfg.MODEL.ROI_BOX_HEAD.NORM,
|
76 |
+
}
|
77 |
+
|
78 |
+
def _add_conv_norm_relus(self, input_dim, conv_dims, conv_norm):
|
79 |
+
conv_norm_relus = []
|
80 |
+
for k, conv_dim in enumerate(conv_dims):
|
81 |
+
conv = Conv2d(
|
82 |
+
input_dim,
|
83 |
+
conv_dim,
|
84 |
+
kernel_size=3,
|
85 |
+
padding=1,
|
86 |
+
bias=not conv_norm,
|
87 |
+
norm=get_norm(conv_norm, conv_dim),
|
88 |
+
activation=nn.ReLU(),
|
89 |
+
)
|
90 |
+
input_dim = conv_dim
|
91 |
+
conv_norm_relus.append(conv)
|
92 |
+
|
93 |
+
return nn.Sequential(*conv_norm_relus)
|
94 |
+
|
95 |
+
def _add_fcs(self, input_dim, fc_dims):
|
96 |
+
fcs = []
|
97 |
+
for k, fc_dim in enumerate(fc_dims):
|
98 |
+
if k == 0:
|
99 |
+
fcs.append(nn.Flatten())
|
100 |
+
fc = nn.Linear(int(input_dim), fc_dim)
|
101 |
+
fcs.append(fc)
|
102 |
+
fcs.append(nn.ReLU())
|
103 |
+
input_dim = fc_dim
|
104 |
+
return nn.Sequential(*fcs)
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
reg_feat = x
|
108 |
+
cls_feat = x
|
109 |
+
if len(self.conv_dims) > 0:
|
110 |
+
reg_feat = self.reg_conv_norm_relus(x)
|
111 |
+
cls_feat = self.cls_conv_norm_relus(x)
|
112 |
+
if len(self.fc_dims) > 0:
|
113 |
+
reg_feat = self.reg_fcs(reg_feat)
|
114 |
+
cls_feat = self.cls_fcs(cls_feat)
|
115 |
+
return reg_feat, cls_feat
|
116 |
+
|
117 |
+
@property
|
118 |
+
@torch.jit.unused
|
119 |
+
def output_shape(self):
|
120 |
+
"""
|
121 |
+
Returns:
|
122 |
+
ShapeSpec: the output feature shape
|
123 |
+
"""
|
124 |
+
o = self._output_size
|
125 |
+
if isinstance(o, int):
|
126 |
+
return ShapeSpec(channels=o)
|
127 |
+
else:
|
128 |
+
return ShapeSpec(channels=o[0], height=o[1], width=o[2])
|
129 |
+
|
130 |
+
|
131 |
+
@ROI_BOX_HEAD_REGISTRY.register()
|
132 |
+
class FastRCNNSeparateDropoutConvFCHead(FastRCNNSeparateConvFCHead):
|
133 |
+
"""Add dropout before each conv/fc layer
|
134 |
+
"""
|
135 |
+
def _add_conv_norm_relus(self, input_dim, conv_dims, conv_norm):
|
136 |
+
conv_norm_relus = []
|
137 |
+
for k, conv_dim in enumerate(conv_dims):
|
138 |
+
conv = Conv2d(
|
139 |
+
input_dim,
|
140 |
+
conv_dim,
|
141 |
+
kernel_size=3,
|
142 |
+
padding=1,
|
143 |
+
bias=not conv_norm,
|
144 |
+
norm=get_norm(conv_norm, conv_dim),
|
145 |
+
activation=nn.ReLU(),
|
146 |
+
)
|
147 |
+
input_dim = conv_dim
|
148 |
+
conv_norm_relus.append(nn.Dropout2d(p=0.5))
|
149 |
+
conv_norm_relus.append(conv)
|
150 |
+
|
151 |
+
return nn.Sequential(*conv_norm_relus)
|
152 |
+
|
153 |
+
def _add_fcs(self, input_dim, fc_dims):
|
154 |
+
fcs = []
|
155 |
+
for k, fc_dim in enumerate(fc_dims):
|
156 |
+
if k == 0:
|
157 |
+
fcs.append(nn.Flatten())
|
158 |
+
fc = nn.Linear(int(input_dim), fc_dim)
|
159 |
+
fcs.append(nn.Dropout2d(p=0.5))
|
160 |
+
fcs.append(fc)
|
161 |
+
fcs.append(nn.ReLU())
|
162 |
+
input_dim = fc_dim
|
163 |
+
return nn.Sequential(*fcs)
|
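The separate-branch head maps one pooled RoI feature to independent regression and classification features; a sketch with an illustrative 256x7x7 RoI feature and two 1024-d FC layers (typical ROI_BOX_HEAD settings, not necessarily this repo's configs):

import torch
from detectron2.layers import ShapeSpec

head = FastRCNNSeparateConvFCHead(
    ShapeSpec(channels=256, height=7, width=7),
    conv_dims=[], fc_dims=[1024, 1024],
)
reg_feat, cls_feat = head(torch.randn(4, 256, 7, 7))
print(reg_feat.shape, cls_feat.shape)  # torch.Size([4, 1024]) torch.Size([4, 1024])
print(head.output_shape)               # ShapeSpec(channels=1024, ...)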
opendet2/modeling/roi_heads/fast_rcnn.py
ADDED
@@ -0,0 +1,645 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import itertools
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
from typing import Dict, List, Tuple, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.distributions as dists
|
12 |
+
from detectron2.config import configurable
|
13 |
+
from detectron2.layers import (ShapeSpec, batched_nms, cat, cross_entropy,
|
14 |
+
nonzero_tuple)
|
15 |
+
from detectron2.modeling.box_regression import Box2BoxTransform
|
16 |
+
from detectron2.modeling.roi_heads.fast_rcnn import (FastRCNNOutputLayers,
|
17 |
+
_log_classification_stats)
|
18 |
+
from detectron2.structures import Boxes, Instances, pairwise_iou
|
19 |
+
from detectron2.structures.boxes import matched_boxlist_iou
|
20 |
+
# fast_rcnn_inference)
|
21 |
+
from detectron2.utils import comm
|
22 |
+
from detectron2.utils.events import get_event_storage
|
23 |
+
from detectron2.utils.registry import Registry
|
24 |
+
from fvcore.nn import giou_loss, smooth_l1_loss
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import functional as F
|
27 |
+
|
28 |
+
from ..layers import MLP
|
29 |
+
from ..losses import ICLoss, UPLoss
+
+ROI_BOX_OUTPUT_LAYERS_REGISTRY = Registry("ROI_BOX_OUTPUT_LAYERS")
+ROI_BOX_OUTPUT_LAYERS_REGISTRY.__doc__ = """
+ROI_BOX_OUTPUT_LAYERS
+"""
+
+
+def fast_rcnn_inference(
+    boxes: List[torch.Tensor],
+    scores: List[torch.Tensor],
+    image_shapes: List[Tuple[int, int]],
+    score_thresh: float,
+    nms_thresh: float,
+    topk_per_image: int,
+    vis_iou_thr: float = 1.0,
+):
+    result_per_image = [
+        fast_rcnn_inference_single_image(
+            boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image, vis_iou_thr
+        )
+        for scores_per_image, boxes_per_image, image_shape in zip(scores, boxes, image_shapes)
+    ]
+    return [x[0] for x in result_per_image], [x[1] for x in result_per_image]
+
+
+def fast_rcnn_inference_single_image(
+    boxes,
+    scores,
+    image_shape: Tuple[int, int],
+    score_thresh: float,
+    nms_thresh: float,
+    topk_per_image: int,
+    vis_iou_thr: float,
+):
+    valid_mask = torch.isfinite(boxes).all(
+        dim=1) & torch.isfinite(scores).all(dim=1)
+    if not valid_mask.all():
+        boxes = boxes[valid_mask]
+        scores = scores[valid_mask]
+
+    scores = scores[:, :-1]
+    num_bbox_reg_classes = boxes.shape[1] // 4
+    # Convert to Boxes to use the `clip` function ...
+    boxes = Boxes(boxes.reshape(-1, 4))
+    boxes.clip(image_shape)
+    boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 4)  # R x C x 4
+
+    # 1. Filter results based on detection scores. It can make NMS more efficient
+    #    by filtering out low-confidence detections.
+    filter_mask = scores > score_thresh  # R x K
+    # R' x 2. First column contains indices of the R predictions;
+    # Second column contains indices of classes.
+    filter_inds = filter_mask.nonzero()
+    if num_bbox_reg_classes == 1:
+        boxes = boxes[filter_inds[:, 0], 0]
+    else:
+        boxes = boxes[filter_mask]
+    scores = scores[filter_mask]
+
+    # 2. Apply NMS for each class independently.
+    keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
+    if topk_per_image >= 0:
+        keep = keep[:topk_per_image]
+    boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]
+
+    # apply nms between known classes and unknown class for visualization.
+    if vis_iou_thr < 1.0:
+        boxes, scores, filter_inds = unknown_aware_nms(
+            boxes, scores, filter_inds, iou_thr=vis_iou_thr)
+
+    result = Instances(image_shape)
+    result.pred_boxes = Boxes(boxes)
+    result.scores = scores
+    result.pred_classes = filter_inds[:, 1]
+    return result, filter_inds[:, 0]
+
+
+def unknown_aware_nms(boxes, scores, labels, ukn_class_id=80, iou_thr=0.9):
+    u_inds = labels[:, 1] == ukn_class_id
+    k_inds = ~u_inds
+    if k_inds.sum() == 0 or u_inds.sum() == 0:
+        return boxes, scores, labels
+
+    k_boxes, k_scores, k_labels = boxes[k_inds], scores[k_inds], labels[k_inds]
+    u_boxes, u_scores, u_labels = boxes[u_inds], scores[u_inds], labels[u_inds]
+
+    ious = pairwise_iou(Boxes(k_boxes), Boxes(u_boxes))
+    mask = torch.ones((ious.size(0), ious.size(1), 2), device=ious.device)
+    inds = (ious > iou_thr).nonzero()
+    if not inds.numel():
+        return boxes, scores, labels
+
+    for [ind_x, ind_y] in inds:
+        if k_scores[ind_x] >= u_scores[ind_y]:
+            mask[ind_x, ind_y, 1] = 0
+        else:
+            mask[ind_x, ind_y, 0] = 0
+
+    k_inds = mask[..., 0].mean(dim=1) == 1
+    u_inds = mask[..., 1].mean(dim=0) == 1
+
+    k_boxes, k_scores, k_labels = k_boxes[k_inds], k_scores[k_inds], k_labels[k_inds]
+    u_boxes, u_scores, u_labels = u_boxes[u_inds], u_scores[u_inds], u_labels[u_inds]
+
+    boxes = torch.cat([k_boxes, u_boxes])
+    scores = torch.cat([k_scores, u_scores])
+    labels = torch.cat([k_labels, u_labels])
+
+    return boxes, scores, labels
+
+
+logger = logging.getLogger(__name__)
+
+
+def build_roi_box_output_layers(cfg, input_shape):
+    """
+    Build the box output layers defined by `cfg.MODEL.ROI_BOX_HEAD.OUTPUT_LAYERS`.
+    """
+    name = cfg.MODEL.ROI_BOX_HEAD.OUTPUT_LAYERS
+    return ROI_BOX_OUTPUT_LAYERS_REGISTRY.get(name)(cfg, input_shape)
+
+
+@ROI_BOX_OUTPUT_LAYERS_REGISTRY.register()
+class CosineFastRCNNOutputLayers(FastRCNNOutputLayers):
+
+    @configurable
+    def __init__(
+        self,
+        *args,
+        scale: int = 20,
+        vis_iou_thr: float = 1.0,
+        **kargs,
+    ):
+        super().__init__(*args, **kargs)
+        # prediction layer for num_classes foreground classes and one background class (hence + 1)
+        self.cls_score = nn.Linear(
+            self.cls_score.in_features, self.num_classes + 1, bias=False)
+        nn.init.normal_(self.cls_score.weight, std=0.01)
+        # scaling factor
+        self.scale = scale
+        self.vis_iou_thr = vis_iou_thr
+
+    @classmethod
+    def from_config(cls, cfg, input_shape):
+        ret = super().from_config(cfg, input_shape)
+        ret['scale'] = cfg.MODEL.ROI_HEADS.COSINE_SCALE
+        ret['vis_iou_thr'] = cfg.MODEL.ROI_HEADS.VIS_IOU_THRESH
+        return ret
+
+    def forward(self, feats):
+        # support shared & separate head
+        if isinstance(feats, tuple):
+            reg_x, cls_x = feats
+        else:
+            reg_x = cls_x = feats
+
+        if reg_x.dim() > 2:
+            reg_x = torch.flatten(reg_x, start_dim=1)
+            cls_x = torch.flatten(cls_x, start_dim=1)
+
+        x_norm = torch.norm(cls_x, p=2, dim=1).unsqueeze(1).expand_as(cls_x)
+        x_normalized = cls_x.div(x_norm + 1e-5)
+
+        # normalize weight
+        temp_norm = (
+            torch.norm(self.cls_score.weight.data, p=2, dim=1)
+            .unsqueeze(1)
+            .expand_as(self.cls_score.weight.data)
+        )
+        self.cls_score.weight.data = self.cls_score.weight.data.div(
+            temp_norm + 1e-5
+        )
+        cos_dist = self.cls_score(x_normalized)
+        scores = self.scale * cos_dist
+        proposal_deltas = self.bbox_pred(reg_x)
+
+        return scores, proposal_deltas
+
+    def inference(self, predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[Instances]):
+        boxes = self.predict_boxes(predictions, proposals)
+        scores = self.predict_probs(predictions, proposals)
+        image_shapes = [x.image_size for x in proposals]
+        return fast_rcnn_inference(
+            boxes,
+            scores,
+            image_shapes,
+            self.test_score_thresh,
+            self.test_nms_thresh,
+            self.test_topk_per_image,
+            self.vis_iou_thr,
+        )
+
+    def predict_boxes(
+        self, predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[Instances]
+    ):
+        if not len(proposals):
+            return []
+        proposal_deltas = predictions[1]
+        num_prop_per_image = [len(p) for p in proposals]
+        proposal_boxes = cat(
+            [p.proposal_boxes.tensor for p in proposals], dim=0)
+        predict_boxes = self.box2box_transform.apply_deltas(
+            proposal_deltas,
+            proposal_boxes,
+        )  # Nx(KxB)
+        return predict_boxes.split(num_prop_per_image)
+
+    def predict_probs(
+        self, predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[Instances]
+    ):
+        scores = predictions[0]
+        num_inst_per_image = [len(p) for p in proposals]
+        probs = F.softmax(scores, dim=-1)
+        return probs.split(num_inst_per_image, dim=0)
+
+
+@ROI_BOX_OUTPUT_LAYERS_REGISTRY.register()
+class OpenDetFastRCNNOutputLayers(CosineFastRCNNOutputLayers):
+    @configurable
+    def __init__(
+        self,
+        *args,
+        num_known_classes,
+        max_iters,
+        up_loss_start_iter,
+        up_loss_sampling_metric,
+        up_loss_topk,
+        up_loss_alpha,
+        up_loss_weight,
+        ic_loss_out_dim,
+        ic_loss_queue_size,
+        ic_loss_in_queue_size,
+        ic_loss_batch_iou_thr,
+        ic_loss_queue_iou_thr,
+        ic_loss_queue_tau,
+        ic_loss_weight,
+        **kargs
+    ):
+        super().__init__(*args, **kargs)
+        self.num_known_classes = num_known_classes
+        self.max_iters = max_iters
+
+        self.up_loss = UPLoss(
+            self.num_classes,
+            sampling_metric=up_loss_sampling_metric,
+            topk=up_loss_topk,
+            alpha=up_loss_alpha
+        )
+        self.up_loss_start_iter = up_loss_start_iter
+        self.up_loss_weight = up_loss_weight
+
+        self.encoder = MLP(self.cls_score.in_features, ic_loss_out_dim)
+        self.ic_loss_loss = ICLoss(tau=ic_loss_queue_tau)
+        self.ic_loss_out_dim = ic_loss_out_dim
+        self.ic_loss_queue_size = ic_loss_queue_size
+        self.ic_loss_in_queue_size = ic_loss_in_queue_size
+        self.ic_loss_batch_iou_thr = ic_loss_batch_iou_thr
+        self.ic_loss_queue_iou_thr = ic_loss_queue_iou_thr
+        self.ic_loss_weight = ic_loss_weight
+
+        self.register_buffer('queue', torch.zeros(
+            self.num_known_classes, ic_loss_queue_size, ic_loss_out_dim))
+        self.register_buffer('queue_label', torch.empty(
+            self.num_known_classes, ic_loss_queue_size).fill_(-1).long())
+        self.register_buffer('queue_ptr', torch.zeros(
+            self.num_known_classes, dtype=torch.long))
+
+    @classmethod
+    def from_config(cls, cfg, input_shape):
+        ret = super().from_config(cfg, input_shape)
+        ret.update({
+            'num_known_classes': cfg.MODEL.ROI_HEADS.NUM_KNOWN_CLASSES,
+            "max_iters": cfg.SOLVER.MAX_ITER,
+
+            "up_loss_start_iter": cfg.UPLOSS.START_ITER,
+            "up_loss_sampling_metric": cfg.UPLOSS.SAMPLING_METRIC,
+            "up_loss_topk": cfg.UPLOSS.TOPK,
+            "up_loss_alpha": cfg.UPLOSS.ALPHA,
+            "up_loss_weight": cfg.UPLOSS.WEIGHT,
+
+            "ic_loss_out_dim": cfg.ICLOSS.OUT_DIM,
+            "ic_loss_queue_size": cfg.ICLOSS.QUEUE_SIZE,
+            "ic_loss_in_queue_size": cfg.ICLOSS.IN_QUEUE_SIZE,
+            "ic_loss_batch_iou_thr": cfg.ICLOSS.BATCH_IOU_THRESH,
+            "ic_loss_queue_iou_thr": cfg.ICLOSS.QUEUE_IOU_THRESH,
+            "ic_loss_queue_tau": cfg.ICLOSS.TEMPERATURE,
+            "ic_loss_weight": cfg.ICLOSS.WEIGHT,
+        })
+        return ret
+
+    def forward(self, feats):
+        # support shared & separate head
+        if isinstance(feats, tuple):
+            reg_x, cls_x = feats
+        else:
+            reg_x = cls_x = feats
+
+        if reg_x.dim() > 2:
+            reg_x = torch.flatten(reg_x, start_dim=1)
+            cls_x = torch.flatten(cls_x, start_dim=1)
+
+        x_norm = torch.norm(cls_x, p=2, dim=1).unsqueeze(1).expand_as(cls_x)
+        x_normalized = cls_x.div(x_norm + 1e-5)
+
+        # normalize weight
+        temp_norm = (
+            torch.norm(self.cls_score.weight.data, p=2, dim=1)
+            .unsqueeze(1)
+            .expand_as(self.cls_score.weight.data)
+        )
+        self.cls_score.weight.data = self.cls_score.weight.data.div(
+            temp_norm + 1e-5
+        )
+        cos_dist = self.cls_score(x_normalized)
+        scores = self.scale * cos_dist
+        proposal_deltas = self.bbox_pred(reg_x)
+
+        # encode feature with MLP
+        mlp_feat = self.encoder(cls_x)
+
+        return scores, proposal_deltas, mlp_feat
+
+    def get_up_loss(self, scores, gt_classes):
+        # start up loss after several warmup iters
+        storage = get_event_storage()
+        if storage.iter > self.up_loss_start_iter:
+            loss_cls_up = self.up_loss(scores, gt_classes)
+        else:
+            loss_cls_up = scores.new_tensor(0.0)
+
+        return {"loss_cls_up": self.up_loss_weight * loss_cls_up}
+
+    def get_ic_loss(self, feat, gt_classes, ious):
+        # select foreground and iou > thr instance in a mini-batch
+        pos_inds = (ious > self.ic_loss_batch_iou_thr) & (
+            gt_classes != self.num_classes)
+        feat, gt_classes = feat[pos_inds], gt_classes[pos_inds]
+
+        queue = self.queue.reshape(-1, self.ic_loss_out_dim)
+        queue_label = self.queue_label.reshape(-1)
+        queue_inds = queue_label != -1  # filter empty queue
+        queue, queue_label = queue[queue_inds], queue_label[queue_inds]
+
+        loss_ic_loss = self.ic_loss_loss(feat, gt_classes, queue, queue_label)
+        # loss decay
+        storage = get_event_storage()
+        decay_weight = 1.0 - storage.iter / self.max_iters
+        return {"loss_cls_ic": self.ic_loss_weight * decay_weight * loss_ic_loss}
+
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self, feat, gt_classes, ious, iou_thr=0.7):
+        # 1. gather variable
+        feat = self.concat_all_gather(feat)
+        gt_classes = self.concat_all_gather(gt_classes)
+        ious = self.concat_all_gather(ious)
+        # 2. filter by iou and obj, remove bg
+        keep = (ious > iou_thr) & (gt_classes != self.num_classes)
+        feat, gt_classes = feat[keep], gt_classes[keep]
+
+        for i in range(self.num_known_classes):
+            ptr = int(self.queue_ptr[i])
+            cls_ind = gt_classes == i
+            cls_feat, cls_gt_classes = feat[cls_ind], gt_classes[cls_ind]
+            # 3. sort by similarity, low sim ranks first
+            cls_queue = self.queue[i, self.queue_label[i] != -1]
+            _, sim_inds = F.cosine_similarity(
+                cls_feat[:, None], cls_queue[None, :], dim=-1).mean(dim=1).sort()
+            top_sim_inds = sim_inds[:self.ic_loss_in_queue_size]
+            cls_feat, cls_gt_classes = cls_feat[top_sim_inds], cls_gt_classes[top_sim_inds]
+            # 4. in queue
+            batch_size = cls_feat.size(
+                0) if ptr + cls_feat.size(0) <= self.ic_loss_queue_size else self.ic_loss_queue_size - ptr
+            self.queue[i, ptr:ptr+batch_size] = cls_feat[:batch_size]
+            self.queue_label[i, ptr:ptr + batch_size] = cls_gt_classes[:batch_size]
+
+            ptr = ptr + batch_size if ptr + batch_size < self.ic_loss_queue_size else 0
+            self.queue_ptr[i] = ptr
+
+    @torch.no_grad()
+    def concat_all_gather(self, tensor):
+        world_size = comm.get_world_size()
+        # single GPU, directly return the tensor
+        if world_size == 1:
+            return tensor
+        # multiple GPUs, gather tensors
+        tensors_gather = [torch.ones_like(tensor) for _ in range(world_size)]
+        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+        output = torch.cat(tensors_gather, dim=0)
+        return output
+
+    def losses(self, predictions, proposals, input_features=None):
+        """
+        Args:
+            predictions: return values of :meth:`forward()`.
+            proposals (list[Instances]): proposals that match the features that were used
+                to compute predictions. The fields ``proposal_boxes``, ``gt_boxes``,
+                ``gt_classes`` are expected.
+
+        Returns:
+            Dict[str, Tensor]: dict of losses
+        """
+        scores, proposal_deltas, mlp_feat = predictions
+        # parse classification outputs
+        gt_classes = (
+            cat([p.gt_classes for p in proposals], dim=0) if len(
+                proposals) else torch.empty(0)
+        )
+        _log_classification_stats(scores, gt_classes)
+
+        # parse box regression outputs
+        if len(proposals):
+            proposal_boxes = cat(
+                [p.proposal_boxes.tensor for p in proposals], dim=0)  # Nx4
+            assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
+            # If "gt_boxes" does not exist, the proposals must be all negative and
+            # should not be included in regression loss computation.
+            # Here we just use proposal_boxes as an arbitrary placeholder because its
+            # value won't be used in self.box_reg_loss().
+            gt_boxes = cat(
+                [(p.gt_boxes if p.has("gt_boxes")
+                  else p.proposal_boxes).tensor for p in proposals],
+                dim=0,
+            )
+        else:
+            proposal_boxes = gt_boxes = torch.empty(
+                (0, 4), device=proposal_deltas.device)
+
+        losses = {
+            "loss_cls_ce": cross_entropy(scores, gt_classes, reduction="mean"),
+            "loss_box_reg": self.box_reg_loss(
+                proposal_boxes, gt_boxes, proposal_deltas, gt_classes
+            ),
+        }
+
+        # up loss
+        losses.update(self.get_up_loss(scores, gt_classes))
+
+        ious = cat([p.iou for p in proposals], dim=0)
+        # we first store feats in the queue, then compute the loss
+        self._dequeue_and_enqueue(
+            mlp_feat, gt_classes, ious, iou_thr=self.ic_loss_queue_iou_thr)
+        losses.update(self.get_ic_loss(mlp_feat, gt_classes, ious))
+
+        return {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}
+
+
+@ROI_BOX_OUTPUT_LAYERS_REGISTRY.register()
+class PROSERFastRCNNOutputLayers(CosineFastRCNNOutputLayers):
+    """PROSER
+    """
+    @configurable
+    def __init__(self, *args, **kargs):
+        super().__init__(*args, **kargs)
+        self.proser_weight = 0.1
+
+    def get_proser_loss(self, scores, gt_classes):
+        num_sample, num_classes = scores.shape
+        mask = torch.arange(num_classes).repeat(
+            num_sample, 1).to(scores.device)
+        inds = mask != gt_classes[:, None].repeat(1, num_classes)
+        mask = mask[inds].reshape(num_sample, num_classes-1)
+        mask_scores = torch.gather(scores, 1, mask)
+
+        targets = torch.zeros_like(gt_classes)
+        fg_inds = gt_classes != self.num_classes
+        targets[fg_inds] = self.num_classes-2
+        targets[~fg_inds] = self.num_classes-1
+
+        loss_cls_proser = cross_entropy(mask_scores, targets)
+        return {"loss_cls_proser": self.proser_weight * loss_cls_proser}
+
+    def losses(self, predictions, proposals, input_features=None):
+        """
+        Args:
+            predictions: return values of :meth:`forward()`.
+            proposals (list[Instances]): proposals that match the features that were used
+                to compute predictions. The fields ``proposal_boxes``, ``gt_boxes``,
+                ``gt_classes`` are expected.
+
+        Returns:
+            Dict[str, Tensor]: dict of losses
+        """
+        scores, proposal_deltas = predictions
+        # parse classification outputs
+        gt_classes = (
+            cat([p.gt_classes for p in proposals], dim=0) if len(
+                proposals) else torch.empty(0)
+        )
+        _log_classification_stats(scores, gt_classes)
+
+        # parse box regression outputs
+        if len(proposals):
+            proposal_boxes = cat(
+                [p.proposal_boxes.tensor for p in proposals], dim=0)  # Nx4
+            assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
+            # If "gt_boxes" does not exist, the proposals must be all negative and
+            # should not be included in regression loss computation.
+            # Here we just use proposal_boxes as an arbitrary placeholder because its
+            # value won't be used in self.box_reg_loss().
+            gt_boxes = cat(
+                [(p.gt_boxes if p.has("gt_boxes")
+                  else p.proposal_boxes).tensor for p in proposals],
+                dim=0,
+            )
+        else:
+            proposal_boxes = gt_boxes = torch.empty(
+                (0, 4), device=proposal_deltas.device)
+
+        losses = {
+            "loss_cls_ce": cross_entropy(scores, gt_classes, reduction="mean"),
+            "loss_box_reg": self.box_reg_loss(
+                proposal_boxes, gt_boxes, proposal_deltas, gt_classes
+            ),
+        }
+        losses.update(self.get_proser_loss(scores, gt_classes))
+
+        return {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}
+
+
+@ROI_BOX_OUTPUT_LAYERS_REGISTRY.register()
+class DropoutFastRCNNOutputLayers(CosineFastRCNNOutputLayers):
+
+    @configurable
+    def __init__(self, *args, **kargs):
+        super().__init__(*args, **kargs)
+        self.dropout = nn.Dropout(p=0.5)
+        self.entropy_thr = 0.25
+
+    def forward(self, feats, testing=False):
+        # support shared & separate head
+        if isinstance(feats, tuple):
+            reg_x, cls_x = feats
+        else:
+            reg_x = cls_x = feats
+
+        if reg_x.dim() > 2:
+            reg_x = torch.flatten(reg_x, start_dim=1)
+            cls_x = torch.flatten(cls_x, start_dim=1)
+
+        x_norm = torch.norm(cls_x, p=2, dim=1).unsqueeze(1).expand_as(cls_x)
+        x_normalized = cls_x.div(x_norm + 1e-5)
+
+        # normalize weight
+        temp_norm = (
+            torch.norm(self.cls_score.weight.data, p=2, dim=1)
+            .unsqueeze(1)
+            .expand_as(self.cls_score.weight.data)
+        )
+        self.cls_score.weight.data = self.cls_score.weight.data.div(
+            temp_norm + 1e-5
+        )
+        if testing:
+            self.dropout.train()
+            x_normalized = self.dropout(x_normalized)
+        cos_dist = self.cls_score(x_normalized)
+        scores = self.scale * cos_dist
+        proposal_deltas = self.bbox_pred(reg_x)
+
+        return scores, proposal_deltas
+
+    def inference(self, predictions: List[Tuple[torch.Tensor, torch.Tensor]], proposals: List[Instances]):
+        """
+        Args:
+            predictions: return values of :meth:`forward()`.
+            proposals (list[Instances]): proposals that match the features that were
+                used to compute predictions. The ``proposal_boxes`` field is expected.
+
+        Returns:
+            list[Instances]: same as `fast_rcnn_inference`.
+            list[Tensor]: same as `fast_rcnn_inference`.
+        """
+        boxes = self.predict_boxes(predictions[0], proposals)
+        scores = self.predict_probs(predictions, proposals)
+        image_shapes = [x.image_size for x in proposals]
+        return fast_rcnn_inference(
+            boxes,
+            scores,
+            image_shapes,
+            self.test_score_thresh,
+            self.test_nms_thresh,
+            self.test_topk_per_image,
+        )
+
+    def predict_probs(
+        self, predictions: List[Tuple[torch.Tensor, torch.Tensor]], proposals: List[Instances]
+    ):
+        """
+        Args:
+            predictions: return values of :meth:`forward()`.
+            proposals (list[Instances]): proposals that match the features that were
+                used to compute predictions.
+
+        Returns:
+            list[Tensor]:
+                A list of Tensors of predicted class probabilities for each image.
+                Element i has shape (Ri, K + 1), where Ri is the number of proposals for image i.
+        """
+        # mean of multiple observations
+        scores = torch.stack([pred[0] for pred in predictions], dim=-1)
+        scores = scores.mean(dim=-1)
+        # threshold by entropy
+        norm_entropy = dists.Categorical(scores.softmax(
+            dim=1)).entropy() / np.log(self.num_classes)
+        inds = norm_entropy > self.entropy_thr
+        max_scores = scores.max(dim=1)[0]
+        # treat those with high entropy as unknown objects
+        scores[inds, :] = 0.0
+        scores[inds, self.num_classes-1] = max_scores[inds]
+
+        num_inst_per_image = [len(p) for p in proposals]
+        probs = F.softmax(scores, dim=-1)
+        return probs.split(num_inst_per_image, dim=0)
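
A quick toy check of `unknown_aware_nms` above (illustrative only, not part of the commit; the boxes, scores, and the explicit `iou_thr=0.5` are made-up values):

import torch
from opendet2.modeling.roi_heads.fast_rcnn import unknown_aware_nms

# two heavily overlapping detections: one known class (id 0), one unknown (id 80)
boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 10., 10.]])
scores = torch.tensor([0.9, 0.6])
labels = torch.tensor([[0, 0], [1, 80]])  # (prediction index, class id)
kept_boxes, kept_scores, kept_labels = unknown_aware_nms(
    boxes, scores, labels, iou_thr=0.5)
# the lower-scoring unknown box is suppressed; only the known box (score 0.9) survives
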
opendet2/modeling/roi_heads/roi_heads.py
ADDED
@@ -0,0 +1,150 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from typing import Dict, List
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from detectron2.config import configurable
+from detectron2.layers import ShapeSpec
+from detectron2.modeling.poolers import ROIPooler
+from detectron2.modeling.roi_heads.box_head import build_box_head
+from detectron2.modeling.roi_heads.roi_heads import (
+    ROI_HEADS_REGISTRY, StandardROIHeads, add_ground_truth_to_proposals)
+from detectron2.structures import Boxes, Instances, pairwise_iou
+from detectron2.utils.events import get_event_storage
+from detectron2.utils.registry import Registry
+from torch import nn
+
+from .fast_rcnn import build_roi_box_output_layers
+
+logger = logging.getLogger(__name__)
+
+
+@ROI_HEADS_REGISTRY.register()
+class OpenSetStandardROIHeads(StandardROIHeads):
+
+    @torch.no_grad()
+    def label_and_sample_proposals(self, proposals: List[Instances], targets: List[Instances]) -> List[Instances]:
+        if self.proposal_append_gt:
+            proposals = add_ground_truth_to_proposals(targets, proposals)
+
+        proposals_with_gt = []
+
+        num_fg_samples = []
+        num_bg_samples = []
+        for proposals_per_image, targets_per_image in zip(proposals, targets):
+            has_gt = len(targets_per_image) > 0
+            match_quality_matrix = pairwise_iou(
+                targets_per_image.gt_boxes, proposals_per_image.proposal_boxes
+            )
+            matched_idxs, matched_labels = self.proposal_matcher(
+                match_quality_matrix)
+            sampled_idxs, gt_classes = self._sample_proposals(
+                matched_idxs, matched_labels, targets_per_image.gt_classes
+            )
+
+            # Set target attributes of the sampled proposals:
+            proposals_per_image = proposals_per_image[sampled_idxs]
+            proposals_per_image.gt_classes = gt_classes
+            # NOTE: add iou of each proposal
+            ious, _ = match_quality_matrix.max(dim=0)
+            proposals_per_image.iou = ious[sampled_idxs]
+
+            if has_gt:
+                sampled_targets = matched_idxs[sampled_idxs]
+                for (trg_name, trg_value) in targets_per_image.get_fields().items():
+                    if trg_name.startswith("gt_") and not proposals_per_image.has(trg_name):
+                        proposals_per_image.set(
+                            trg_name, trg_value[sampled_targets])
+
+            num_bg_samples.append(
+                (gt_classes == self.num_classes).sum().item())
+            num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1])
+            proposals_with_gt.append(proposals_per_image)
+
+        # Log the number of fg/bg samples that are selected for training ROI heads
+        storage = get_event_storage()
+        storage.put_scalar("roi_head/num_fg_samples", np.mean(num_fg_samples))
+        storage.put_scalar("roi_head/num_bg_samples", np.mean(num_bg_samples))
+
+        return proposals_with_gt
+
+    @classmethod
+    def _init_box_head(cls, cfg, input_shape):
+        # fmt: off
+        in_features       = cfg.MODEL.ROI_HEADS.IN_FEATURES
+        pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
+        pooler_scales     = tuple(1.0 / input_shape[k].stride for k in in_features)
+        sampling_ratio    = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
+        pooler_type       = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
+        # fmt: on
+
+        # If StandardROIHeads is applied on multiple feature maps (as in FPN),
+        # then we share the same predictors and therefore the channel counts must be the same
+        in_channels = [input_shape[f].channels for f in in_features]
+        # Check all channel counts are equal
+        assert len(set(in_channels)) == 1, in_channels
+        in_channels = in_channels[0]
+
+        box_pooler = ROIPooler(
+            output_size=pooler_resolution,
+            scales=pooler_scales,
+            sampling_ratio=sampling_ratio,
+            pooler_type=pooler_type,
+        )
+
+        box_head = build_box_head(
+            cfg, ShapeSpec(channels=in_channels,
+                           height=pooler_resolution, width=pooler_resolution)
+        )
+        # register output layers
+        box_predictor = build_roi_box_output_layers(cfg, box_head.output_shape)
+        return {
+            "box_in_features": in_features,
+            "box_pooler": box_pooler,
+            "box_head": box_head,
+            "box_predictor": box_predictor,
+        }
+
+
+@ROI_HEADS_REGISTRY.register()
+class DropoutStandardROIHeads(OpenSetStandardROIHeads):
+    @configurable
+    def __init__(self, *args, **kwargs,):
+        super().__init__(*args, **kwargs)
+        # num of sampling
+        self.num_sample = 30
+
+    def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instances], targets=None):
+        features = [features[f] for f in self.box_in_features]
+        box_features = self.box_pooler(
+            features, [x.proposal_boxes for x in proposals])
+        box_features = self.box_head(box_features)
+
+        # if testing, we run multiple inference for dropout sampling
+        if self.training:
+            predictions = self.box_predictor(box_features)
+        else:
+            predictions = [self.box_predictor(
+                box_features, testing=True) for _ in range(self.num_sample)]
+
+        del box_features
+
+        if self.training:
+            losses = self.box_predictor.losses(predictions, proposals)
+            # proposals is modified in-place below, so losses must be computed first.
+            if self.train_on_pred_boxes:
+                with torch.no_grad():
+                    pred_boxes = self.box_predictor.predict_boxes_for_gt_classes(
+                        predictions, proposals
+                    )
+                    for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes):
+                        proposals_per_image.proposal_boxes = Boxes(
+                            pred_boxes_per_image)
+            return losses
+        else:
+            pred_instances, _ = self.box_predictor.inference(
+                predictions, proposals)
+            return pred_instances
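
For reference, a minimal sketch of how these registered heads are selected purely through the config (not part of the commit; it assumes `add_opendet_config` is what registers the OpenDet2-specific keys such as `MODEL.ROI_BOX_HEAD.OUTPUT_LAYERS`):

from detectron2.config import get_cfg
from opendet2 import add_opendet_config

cfg = get_cfg()
add_opendet_config(cfg)
# pick the open-set ROI heads registered above ...
cfg.MODEL.ROI_HEADS.NAME = "OpenSetStandardROIHeads"
# ... and the box output layers registered in fast_rcnn.py
cfg.MODEL.ROI_BOX_HEAD.OUTPUT_LAYERS = "OpenDetFastRCNNOutputLayers"
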
opendet2/solver/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .build import *
+
+__all__ = list(globals().keys())
opendet2/solver/build.py
ADDED
@@ -0,0 +1,57 @@
+from typing import Any, Dict, List, Set
+
+import torch
+from detectron2.config import CfgNode
+from detectron2.solver.build import maybe_add_gradient_clipping
+
+
+def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
+    """
+    Build an optimizer from config.
+    """
+    norm_module_types = (
+        torch.nn.BatchNorm1d,
+        torch.nn.BatchNorm2d,
+        torch.nn.BatchNorm3d,
+        torch.nn.SyncBatchNorm,
+        # NaiveSyncBatchNorm inherits from BatchNorm2d
+        torch.nn.GroupNorm,
+        torch.nn.InstanceNorm1d,
+        torch.nn.InstanceNorm2d,
+        torch.nn.InstanceNorm3d,
+        torch.nn.LayerNorm,
+        torch.nn.LocalResponseNorm,
+    )
+    params: List[Dict[str, Any]] = []
+    memo: Set[torch.nn.parameter.Parameter] = set()
+    for module in model.modules():
+        for key, value in module.named_parameters(recurse=False):
+            if not value.requires_grad:
+                continue
+            # Avoid duplicating parameters
+            if value in memo:
+                continue
+            memo.add(value)
+            lr = cfg.SOLVER.BASE_LR
+            weight_decay = cfg.SOLVER.WEIGHT_DECAY
+            if isinstance(module, norm_module_types):
+                weight_decay = cfg.SOLVER.WEIGHT_DECAY_NORM
+            elif key == "bias":
+                # NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0
+                # and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer
+                # hyperparameters are by default exactly the same as for regular
+                # weights.
+                lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
+                weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
+            params += [{"params": [value], "lr": lr,
+                        "weight_decay": weight_decay}]
+
+    # To support AdamW for swin_transformer
+    if cfg.SOLVER.OPTIMIZER == "ADAMW":
+        optimizer = torch.optim.AdamW(
+            params, lr=cfg.SOLVER.BASE_LR, betas=cfg.SOLVER.BETAS, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
+    else:
+        optimizer = torch.optim.SGD(
+            params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM)
+    optimizer = maybe_add_gradient_clipping(cfg, optimizer)
+    return optimizer
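
A minimal usage sketch for `build_optimizer` (illustrative, not part of the commit; `SOLVER.OPTIMIZER` and `SOLVER.BETAS` are assumed to be OpenDet2-specific keys normally added by the project config, so they are set by hand here):

import torch
from detectron2.config import get_cfg
from opendet2.solver.build import build_optimizer

cfg = get_cfg()
cfg.SOLVER.OPTIMIZER = "ADAMW"      # assumed OpenDet2 key
cfg.SOLVER.BETAS = (0.9, 0.999)     # assumed OpenDet2 key
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, bias=False),  # plain weights -> SOLVER.WEIGHT_DECAY
    torch.nn.BatchNorm2d(8),               # norm params  -> SOLVER.WEIGHT_DECAY_NORM
)
optimizer = build_optimizer(cfg, model)    # AdamW here, SGD otherwise
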
setup.py
ADDED
@@ -0,0 +1,14 @@
+#!/usr/bin/env python
+from setuptools import setup
+
+setup(
+    name="opendet2",
+    version=0.1,
+    author="csuhan",
+    url="https://github.com/csuhan/opendet2",
+    description="Codebase for open set object detection",
+    python_requires=">=3.6",
+    install_requires=[
+        'timm', 'opencv-python'
+    ],
+)
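
Usage note (not part of the commit): the package is installed locally with `pip install -e .` from the repository root; detectron2 itself is not listed in `install_requires`, so it has to be installed separately beforehand.
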
tools/convert_swin_to_d2.py
ADDED
@@ -0,0 +1,36 @@
+import argparse
+import os
+
+import torch
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("Convert Swin Transformer to Detectron2")
+
+    parser.add_argument("source_model", default="", type=str,
+                        help="Source model")
+    parser.add_argument("output_model", default="", type=str,
+                        help="Output model")
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+
+    if os.path.splitext(args.source_model)[-1] != ".pth":
+        raise ValueError("You should save weights as pth file")
+
+    source_weights = torch.load(
+        args.source_model, map_location=torch.device('cpu'))["model"]
+    converted_weights = {}
+    keys = list(source_weights.keys())
+
+    prefix = 'backbone.bottom_up.'
+    for key in keys:
+        converted_weights[prefix+key] = source_weights[key]
+
+    torch.save(converted_weights, args.output_model)
+
+
+if __name__ == "__main__":
+    main()
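
Hypothetical invocation (file names are placeholders, not part of the commit): `python tools/convert_swin_to_d2.py swin_checkpoint.pth swin_d2.pth`. The script simply prefixes every key of the source state dict with `backbone.bottom_up.` so the checkpoint matches the Detectron2 FPN naming used by this codebase.
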
tools/train_net.py
ADDED
@@ -0,0 +1,79 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+
+import detectron2.utils.comm as comm
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import get_cfg
+from detectron2.engine import (default_argument_parser, default_setup, hooks,
+                               launch)
+from detectron2.evaluation import verify_results
+from detectron2.utils.logger import setup_logger
+from opendet2 import OpenDetTrainer, add_opendet_config, builtin
+
+
+def setup(args):
+    """
+    Create configs and perform basic setups.
+    """
+    cfg = get_cfg()
+    # add opendet config
+    add_opendet_config(cfg)
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+    # Note: we use the key ROI_HEADS.NUM_KNOWN_CLASSES
+    # for open-set data processing and evaluation.
+    if 'RetinaNet' in cfg.MODEL.META_ARCHITECTURE:
+        cfg.MODEL.ROI_HEADS.NUM_KNOWN_CLASSES = cfg.MODEL.RETINANET.NUM_KNOWN_CLASSES
+    # add output dir if it does not exist
+    if cfg.OUTPUT_DIR == "./output":
+        config_name = os.path.basename(args.config_file).split(".yaml")[0]
+        cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR, config_name)
+    cfg.freeze()
+    default_setup(cfg, args)
+    setup_logger(output=cfg.OUTPUT_DIR,
+                 distributed_rank=comm.get_rank(), name="opendet2")
+    return cfg
+
+
+def main(args):
+    cfg = setup(args)
+
+    if args.eval_only:
+        model = OpenDetTrainer.build_model(cfg)
+        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
+            cfg.MODEL.WEIGHTS, resume=args.resume
+        )
+        res = OpenDetTrainer.test(cfg, model)
+        if cfg.TEST.AUG.ENABLED:
+            res.update(OpenDetTrainer.test_with_TTA(cfg, model))
+        if comm.is_main_process():
+            verify_results(cfg, res)
+        return res
+
+    """
+    If you'd like to do anything fancier than the standard training logic,
+    consider writing your own training loop (see plain_train_net.py) or
+    subclassing the trainer.
+    """
+    trainer = OpenDetTrainer(cfg)
+    trainer.resume_or_load(resume=args.resume)
+    if cfg.TEST.AUG.ENABLED:
+        trainer.register_hooks(
+            [hooks.EvalHook(
+                0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
+        )
+    return trainer.train()
+
+
+if __name__ == "__main__":
+    args = default_argument_parser().parse_args()
+    print("Command Line Args:", args)
+    launch(
+        main,
+        args.num_gpus,
+        num_machines=args.num_machines,
+        machine_rank=args.machine_rank,
+        dist_url=args.dist_url,
+        args=(args,),
+    )
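
Typical launch (illustrative; the config path is a placeholder): `python tools/train_net.py --num-gpus 8 --config-file configs/<your_config>.yaml`. Evaluation-only runs add `--eval-only MODEL.WEIGHTS <path/to/model.pth>`; these flags come from detectron2's `default_argument_parser`, which `tools/train_net.py` reuses unchanged.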