Upload 117 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- LICENSE +21 -0
- README.md +141 -3
- assets/teaser_figure.png +3 -0
- avism/__init__.py +12 -0
- avism/avism_model.py +460 -0
- avism/avism_model_coco.py +460 -0
- avism/config.py +56 -0
- avism/data/__init__.py +4 -0
- avism/data/augmentation.py +623 -0
- avism/data/avis_eval.py +203 -0
- avism/data/aviseval/__init__.py +5 -0
- avism/data/aviseval/_timing.py +65 -0
- avism/data/aviseval/datasets/__init__.py +1 -0
- avism/data/aviseval/datasets/_base_dataset.py +326 -0
- avism/data/aviseval/datasets/avis.py +367 -0
- avism/data/aviseval/eval.py +209 -0
- avism/data/aviseval/metrics/__init__.py +11 -0
- avism/data/aviseval/metrics/_base_metric.py +132 -0
- avism/data/aviseval/metrics/av_loc.py +191 -0
- avism/data/aviseval/metrics/avisa.py +190 -0
- avism/data/aviseval/metrics/clear.py +186 -0
- avism/data/aviseval/metrics/count.py +44 -0
- avism/data/aviseval/metrics/hota.py +202 -0
- avism/data/aviseval/metrics/identity.py +135 -0
- avism/data/aviseval/metrics/ideucl.py +135 -0
- avism/data/aviseval/metrics/j_and_f.py +310 -0
- avism/data/aviseval/metrics/track_map.py +462 -0
- avism/data/aviseval/metrics/vace.py +131 -0
- avism/data/aviseval/plotting.py +230 -0
- avism/data/aviseval/utils.py +146 -0
- avism/data/build.py +247 -0
- avism/data/dataset_mapper.py +272 -0
- avism/data/datasets/__init__.py +3 -0
- avism/data/datasets/avis.py +209 -0
- avism/data/datasets/avis_api/__init__.py +1 -0
- avism/data/datasets/avis_api/avos.py +277 -0
- avism/data/datasets/avis_api/avoseval.py +559 -0
- avism/data/datasets/builtin.py +29 -0
- avism/data/datasets/extract_audio_feat/audio_feature_extractor.py +77 -0
- avism/data/datasets/extract_audio_feat/mel_features.py +233 -0
- avism/data/datasets/extract_audio_feat/vggish_input.py +103 -0
- avism/data/datasets/extract_audio_feat/vggish_params.py +53 -0
- avism/data/datasets/extract_audio_feat/vggish_slim.py +134 -0
- avism/modeling/__init__.py +0 -0
- avism/modeling/avism_criterion.py +335 -0
- avism/modeling/avism_matcher.py +194 -0
- avism/modeling/transformer_decoder/__init__.py +1 -0
- avism/modeling/transformer_decoder/avism.py +675 -0
- avism/modeling/transformer_decoder/avism_coco.py +675 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/teaser_figure.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 GeWu-Lab
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,3 +1,141 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Audio-Visual Instance Segmentation
|
2 |
+
|
3 |
+
|
4 |
+
[](https://arxiv.org/abs/2412.03069)
|
5 |
+
[](https://ruohaoguo.github.io/avis/)
|
6 |
+
[](https://1drv.ms/u/c/3c9af704fb61931d/EVOs609SGMxLsbvVzVJHAa4Bmnu4GVZGjqYHQxDz0NKTew?e=WQU2Uf)
|
7 |
+
|
8 |
+
Ruohao Guo, Xianghua Ying*, Yaru Chen, Dantong Niu, Guangyao Li, Liao Qu, Yanyu Qi, Jinxing Zhou, Bowei Xing, Wenzhen Yue, Ji Shi, Qixun Wang, Peiliang Zhang, Buwen Liang
|
9 |
+
|
10 |
+
## 📰 News
|
11 |
+
|
12 |
+
🔥**2025.03.01**: Codes and checkpoints are released!
|
13 |
+
|
14 |
+
🔥**2025.02.27**: AVIS got accepted to **CVPR 2025**! 🎉🎉🎉
|
15 |
+
|
16 |
+
🔥**2024.11.12**: Our [project page](https://ruohaoguo.github.io/avis/) is now available!
|
17 |
+
|
18 |
+
🔥**2024.11.11**: The AVISeg dataset has been uploaded to [OneDrive](https://1drv.ms/u/c/3c9af704fb61931d/EVOs609SGMxLsbvVzVJHAa4Bmnu4GVZGjqYHQxDz0NKTew?e=WQU2Uf), welcome to download and use!
|
19 |
+
|
20 |
+
|
21 |
+
## 🌿 Introduction
|
22 |
+
|
23 |
+
In this paper, we propose a new multi-modal task, termed audio-visual instance segmentation (AVIS), which aims to simultaneously identify, segment and track individual sounding object instances in audible videos. To facilitate this research, we introduce a high-quality benchmark named AVISeg, containing over 90K instance masks from 26 semantic categories in 926 long videos. Additionally, we propose a strong baseline model for this task. Our model first localizes sound source within each frame, and condenses object-specific contexts into concise tokens. Then it builds long-range audio-visual dependencies between these tokens using window-based attention, and tracks sounding objects among the entire video sequences.
|
24 |
+
|
25 |
+
<div align='center'>
|
26 |
+
<img src="./assets/teaser_figure.png" class="interpolation-image" alt="radar." height="50%" width="100%" />
|
27 |
+
</div>
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
## ⚙️ Installation
|
32 |
+
|
33 |
+
```bash
|
34 |
+
conda create --name avism python=3.8 -y
|
35 |
+
conda activate avism
|
36 |
+
|
37 |
+
conda install pytorch==1.9.0 torchvision==0.10.0 cudatoolkit=11.1 -c pytorch -c nvidia
|
38 |
+
pip install -U opencv-python
|
39 |
+
|
40 |
+
cd ./AVISM
|
41 |
+
git clone https://github.com/facebookresearch/detectron2
|
42 |
+
cd detectron2
|
43 |
+
pip install -e .
|
44 |
+
|
45 |
+
cd ../
|
46 |
+
pip install -r requirements.txt
|
47 |
+
cd mask2former/modeling/pixel_decoder/ops
|
48 |
+
sh make.sh
|
49 |
+
```
|
50 |
+
|
51 |
+
## 🤗 Setup
|
52 |
+
|
53 |
+
### Datasets
|
54 |
+
|
55 |
+
Download and unzip datasets [OneDrive](https://1drv.ms/u/c/3c9af704fb61931d/EVOs609SGMxLsbvVzVJHAa4Bmnu4GVZGjqYHQxDz0NKTew?e=WQU2Uf) and put them in ```./datasets```.
|
56 |
+
|
57 |
+
### Pretrained Backbones
|
58 |
+
Download and unzip pre-trained backbones [OneDrive](https://1drv.ms/u/c/3c9af704fb61931d/ETDDliQ8zZFGmYxlLVPyi3sBis_fdjX0w8mJhyQnYVSdXA?e=Wt7pUb) and put them in ```./pre_models```.
|
59 |
+
|
60 |
+
### Checkpoints
|
61 |
+
|
62 |
+
Download the following checkpoints and put them in ```./checkpoints```.
|
63 |
+
|
64 |
+
<table>
|
65 |
+
<tr>
|
66 |
+
<th style="width: 150px;">Backbone</th>
|
67 |
+
<th>Pre-trained Datasets</th>
|
68 |
+
<th>FSLA</th>
|
69 |
+
<th>HOTA</th>
|
70 |
+
<th>mAP</th>
|
71 |
+
<th>Model Weight</th>
|
72 |
+
</tr>
|
73 |
+
<tr>
|
74 |
+
<td align="center">ResNet-50</td>
|
75 |
+
<td align="center">ImageNet</td>
|
76 |
+
<td align="center">42.78</td>
|
77 |
+
<td align="center">61.73</td>
|
78 |
+
<td align="center">40.57</td>
|
79 |
+
<td align="center"><a href="https://1drv.ms/u/c/3c9af704fb61931d/EYyAuCNpRjxDqEohJfoDLO0BYgw0lbwKqQ1lwVXe_kIPVQ?e=PeRlyx">AVISM_R50_IN.pth</a></td>
|
80 |
+
</tr>
|
81 |
+
<tr>
|
82 |
+
<td align="center">ResNet-50</td>
|
83 |
+
<td align="center">ImageNet & COCO</td>
|
84 |
+
<td align="center">44.42</td>
|
85 |
+
<td align="center">64.52</td>
|
86 |
+
<td align="center">45.04</td>
|
87 |
+
<td align="center"><a href="https://1drv.ms/u/c/3c9af704fb61931d/EX0snZsxQwdBswQFdG4sc9kBd-Bd7lw5zaTGR6FvrSxinQ?e=bdZF5G">AVISM_R50_COCO.pth</a></td>
|
88 |
+
</tr>
|
89 |
+
<tr>
|
90 |
+
<td align="center">Swin-L</td>
|
91 |
+
<td align="center">ImageNet</td>
|
92 |
+
<td align="center">49.15</td>
|
93 |
+
<td align="center">68.81</td>
|
94 |
+
<td align="center">49.06</td>
|
95 |
+
<td align="center"><a href="https://1drv.ms/u/c/3c9af704fb61931d/EV4V5Bh5AqVBhLVMM1ucdN0BuOZgHu17W3JDGjKDMLZ1bg?e=hF8umh">AVISM_SwinL_IN.pth</a></td>
|
96 |
+
</tr>
|
97 |
+
<tr>
|
98 |
+
<td align="center">Swin-L</td>
|
99 |
+
<td align="center">ImageNet & COCO</td>
|
100 |
+
<td align="center">52.49</td>
|
101 |
+
<td align="center">71.13</td>
|
102 |
+
<td align="center">53.46</td>
|
103 |
+
<td align="center"><a href="https://1drv.ms/u/c/3c9af704fb61931d/EXuM4cUxPTpEk1M7FoPqtNEBi47L7uR-ZlnqDCJscmNsiA?e=7prFiN">AVISM_SwinL_COCO.pth</a></td>
|
104 |
+
</tr>
|
105 |
+
</table>
|
106 |
+
|
107 |
+
|
108 |
+
## 📌 Getting Started
|
109 |
+
|
110 |
+
### Training
|
111 |
+
```
|
112 |
+
python train_net.py --num-gpus 2 --config-file configs/avism/R50/avism_R50_IN.yaml
|
113 |
+
```
|
114 |
+
|
115 |
+
### Evaluation
|
116 |
+
```
|
117 |
+
python train_net.py --config-file configs/avism/R50/avism_R50_IN.yaml --eval-only MODEL.WEIGHTS checkpoints/AVISM_R50_IN.pth
|
118 |
+
```
|
119 |
+
|
120 |
+
### Demo
|
121 |
+
```
|
122 |
+
python demo_video/demo.py --config-file configs/avism/R50/avism_R50_IN.yaml --opts MODEL.WEIGHTS checkpoints/AVISM_R50_IN.pth
|
123 |
+
```
|
124 |
+
|
125 |
+
## Acknowledgement
|
126 |
+
|
127 |
+
We thank the great work from [Detectron2](https://github.com/facebookresearch/detectron2), [Mask2Former](https://github.com/facebookresearch/MaskFormer) and [VITA](https://github.com/sukjunhwang/VITA).
|
128 |
+
|
129 |
+
|
130 |
+
## 📄 Citation
|
131 |
+
|
132 |
+
If our work assists your research, feel free to give us a star ⭐ or cite us using
|
133 |
+
|
134 |
+
```
|
135 |
+
@article{guo2023audio,
|
136 |
+
title={Audio-Visual Instance Segmentation},
|
137 |
+
author={Guo, Ruohao and Ying, Xianghua and Chen, Yaru and Niu, Dantong and Li, Guangyao and Qu, Liao and Qi, Yanyu and Zhou, Jinxing and Xing, Bowei and Yue, Wenzhen and Shi, Ji and Wang, Qixun and Zhang, Peiliang and Liang, Buwen},
|
138 |
+
journal={arXiv preprint arXiv:2310.18709},
|
139 |
+
year={2023}
|
140 |
+
}
|
141 |
+
```
|
assets/teaser_figure.png
ADDED
![]() |
Git LFS Details
|
avism/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# model code
|
2 |
+
from . import modeling
|
3 |
+
|
4 |
+
# config
|
5 |
+
from .config import add_avism_config
|
6 |
+
|
7 |
+
# models
|
8 |
+
from .avism_model import AVISM
|
9 |
+
from .avism_model_coco import AVISM_COCO
|
10 |
+
|
11 |
+
# video
|
12 |
+
from .data import *
|
avism/avism_model.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from detectron2.config import configurable
|
9 |
+
from detectron2.data import MetadataCatalog
|
10 |
+
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
|
11 |
+
from detectron2.modeling.backbone import Backbone
|
12 |
+
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
|
13 |
+
from detectron2.utils.memory import retry_if_cuda_oom
|
14 |
+
|
15 |
+
from mask2former.modeling.criterion import SetCriterion
|
16 |
+
from mask2former.modeling.matcher import HungarianMatcher
|
17 |
+
from .modeling.avism_criterion import AvismSetCriterion
|
18 |
+
from .modeling.avism_matcher import AvismHungarianMatcher
|
19 |
+
from .modeling.transformer_decoder.avism import Avism
|
20 |
+
|
21 |
+
|
22 |
+
@META_ARCH_REGISTRY.register()
|
23 |
+
class AVISM(nn.Module):
|
24 |
+
"""
|
25 |
+
Main class for mask classification semantic segmentation architectures.
|
26 |
+
"""
|
27 |
+
|
28 |
+
@configurable
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
*,
|
32 |
+
backbone: Backbone,
|
33 |
+
sem_seg_head: nn.Module,
|
34 |
+
criterion: nn.Module,
|
35 |
+
num_queries: int,
|
36 |
+
object_mask_threshold: float,
|
37 |
+
overlap_threshold: float,
|
38 |
+
metadata,
|
39 |
+
size_divisibility: int,
|
40 |
+
pixel_mean: Tuple[float],
|
41 |
+
pixel_std: Tuple[float],
|
42 |
+
# inference
|
43 |
+
test_topk_per_image: int,
|
44 |
+
# avism
|
45 |
+
avism_module: nn.Module,
|
46 |
+
avism_criterion: nn.Module,
|
47 |
+
num_frames: int,
|
48 |
+
num_classes: int,
|
49 |
+
is_multi_cls: bool,
|
50 |
+
apply_cls_thres: float,
|
51 |
+
freeze_detector: bool,
|
52 |
+
test_run_chunk_size: int,
|
53 |
+
test_interpolate_chunk_size: int,
|
54 |
+
is_coco: bool,
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
Args:
|
58 |
+
backbone: a backbone module, must follow detectron2's backbone interface
|
59 |
+
sem_seg_head: a module that predicts semantic segmentation from backbone features
|
60 |
+
criterion: a module that defines the loss
|
61 |
+
num_queries: int, number of queries
|
62 |
+
object_mask_threshold: float, threshold to filter query based on classification score
|
63 |
+
for panoptic segmentation inference
|
64 |
+
overlap_threshold: overlap threshold used in general inference for panoptic segmentation
|
65 |
+
metadata: dataset meta, get `thing` and `stuff` category names for panoptic
|
66 |
+
segmentation inference
|
67 |
+
size_divisibility: Some backbones require the input height and width to be divisible by a
|
68 |
+
specific integer. We can use this to override such requirement.
|
69 |
+
pixel_mean, pixel_std: list or tuple with #channels element, representing
|
70 |
+
the per-channel mean and std to be used to normalize the input image
|
71 |
+
test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
|
72 |
+
"""
|
73 |
+
super().__init__()
|
74 |
+
self.backbone = backbone
|
75 |
+
self.sem_seg_head = sem_seg_head
|
76 |
+
self.criterion = criterion
|
77 |
+
self.num_queries = num_queries
|
78 |
+
self.overlap_threshold = overlap_threshold
|
79 |
+
self.object_mask_threshold = object_mask_threshold
|
80 |
+
self.metadata = metadata
|
81 |
+
if size_divisibility < 0:
|
82 |
+
# use backbone size_divisibility if not set
|
83 |
+
size_divisibility = self.backbone.size_divisibility
|
84 |
+
self.size_divisibility = size_divisibility
|
85 |
+
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
86 |
+
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
87 |
+
|
88 |
+
# additional args
|
89 |
+
self.test_topk_per_image = test_topk_per_image
|
90 |
+
|
91 |
+
# avism hyper-parameters
|
92 |
+
self.num_frames = num_frames
|
93 |
+
self.num_classes = num_classes
|
94 |
+
self.avism_module = avism_module
|
95 |
+
self.avism_criterion = avism_criterion
|
96 |
+
self.is_multi_cls = is_multi_cls
|
97 |
+
self.apply_cls_thres = apply_cls_thres
|
98 |
+
|
99 |
+
if freeze_detector:
|
100 |
+
for name, p in self.named_parameters():
|
101 |
+
if not "avism_module" in name:
|
102 |
+
p.requires_grad_(False)
|
103 |
+
self.test_run_chunk_size = test_run_chunk_size
|
104 |
+
self.test_interpolate_chunk_size = test_interpolate_chunk_size
|
105 |
+
|
106 |
+
self.is_coco = is_coco
|
107 |
+
|
108 |
+
@classmethod
|
109 |
+
def from_config(cls, cfg):
|
110 |
+
backbone = build_backbone(cfg)
|
111 |
+
sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
|
112 |
+
|
113 |
+
# Loss parameters:
|
114 |
+
deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
|
115 |
+
no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
|
116 |
+
avism_deep_supervision = cfg.MODEL.AVISM.DEEP_SUPERVISION
|
117 |
+
|
118 |
+
# loss weights
|
119 |
+
class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
|
120 |
+
dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
|
121 |
+
mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
|
122 |
+
sim_weight = cfg.MODEL.AVISM.SIM_WEIGHT
|
123 |
+
|
124 |
+
# building criterion
|
125 |
+
matcher = HungarianMatcher(
|
126 |
+
cost_class=class_weight,
|
127 |
+
cost_mask=mask_weight,
|
128 |
+
cost_dice=dice_weight,
|
129 |
+
num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
|
130 |
+
)
|
131 |
+
|
132 |
+
weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
|
133 |
+
|
134 |
+
if deep_supervision:
|
135 |
+
dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
|
136 |
+
aux_weight_dict = {}
|
137 |
+
for i in range(dec_layers - 1):
|
138 |
+
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
139 |
+
weight_dict.update(aux_weight_dict)
|
140 |
+
|
141 |
+
losses = ["labels", "masks"]
|
142 |
+
|
143 |
+
criterion = SetCriterion(
|
144 |
+
sem_seg_head.num_classes,
|
145 |
+
matcher=matcher,
|
146 |
+
weight_dict=weight_dict,
|
147 |
+
eos_coef=no_object_weight,
|
148 |
+
losses=losses,
|
149 |
+
num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
|
150 |
+
oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
|
151 |
+
importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
|
152 |
+
avism_last_layer_num=cfg.MODEL.AVISM.LAST_LAYER_NUM,
|
153 |
+
)
|
154 |
+
|
155 |
+
# Avism
|
156 |
+
num_classes = sem_seg_head.num_classes
|
157 |
+
hidden_dim = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
|
158 |
+
avism_module = Avism(cfg=cfg, in_channels=hidden_dim, aux_loss=avism_deep_supervision)
|
159 |
+
|
160 |
+
# building criterion for avism inference
|
161 |
+
avism_matcher = AvismHungarianMatcher(
|
162 |
+
cost_class=class_weight,
|
163 |
+
cost_mask=mask_weight,
|
164 |
+
cost_dice=dice_weight,
|
165 |
+
num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
|
166 |
+
)
|
167 |
+
avism_weight_dict = {
|
168 |
+
"loss_avism_ce": class_weight, "loss_avism_mask": mask_weight, "loss_avism_dice": dice_weight
|
169 |
+
}
|
170 |
+
if sim_weight > 0.0:
|
171 |
+
avism_weight_dict["loss_avism_sim"] = sim_weight
|
172 |
+
|
173 |
+
if avism_deep_supervision:
|
174 |
+
avism_dec_layers = cfg.MODEL.AVISM.DEC_LAYERS
|
175 |
+
aux_weight_dict = {}
|
176 |
+
for i in range(avism_dec_layers - 1):
|
177 |
+
aux_weight_dict.update({k + f"_{i}": v for k, v in avism_weight_dict.items()})
|
178 |
+
avism_weight_dict.update(aux_weight_dict)
|
179 |
+
avism_losses = ["avism_labels", "avism_masks"]
|
180 |
+
if sim_weight > 0.0:
|
181 |
+
avism_losses.append("fg_sim")
|
182 |
+
|
183 |
+
avism_criterion = AvismSetCriterion(
|
184 |
+
num_classes,
|
185 |
+
matcher=avism_matcher,
|
186 |
+
weight_dict=avism_weight_dict,
|
187 |
+
eos_coef=cfg.MODEL.AVISM.NO_OBJECT_WEIGHT,
|
188 |
+
losses=avism_losses,
|
189 |
+
num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
|
190 |
+
oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
|
191 |
+
importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
|
192 |
+
sim_use_clip=cfg.MODEL.AVISM.SIM_USE_CLIP,
|
193 |
+
)
|
194 |
+
|
195 |
+
return {
|
196 |
+
"backbone": backbone,
|
197 |
+
"sem_seg_head": sem_seg_head,
|
198 |
+
"criterion": criterion,
|
199 |
+
"num_queries": cfg.MODEL.AVISM.NUM_OBJECT_QUERIES,
|
200 |
+
"object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
|
201 |
+
"overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
|
202 |
+
"metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
|
203 |
+
"size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
|
204 |
+
"pixel_mean": cfg.MODEL.PIXEL_MEAN,
|
205 |
+
"pixel_std": cfg.MODEL.PIXEL_STD,
|
206 |
+
# inference
|
207 |
+
"test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
|
208 |
+
# avism
|
209 |
+
"avism_module": avism_module,
|
210 |
+
"avism_criterion": avism_criterion,
|
211 |
+
"num_frames": cfg.INPUT.SAMPLING_FRAME_NUM,
|
212 |
+
"num_classes": num_classes,
|
213 |
+
"is_multi_cls": cfg.MODEL.AVISM.MULTI_CLS_ON,
|
214 |
+
"apply_cls_thres": cfg.MODEL.AVISM.APPLY_CLS_THRES,
|
215 |
+
"freeze_detector": cfg.MODEL.AVISM.FREEZE_DETECTOR,
|
216 |
+
"test_run_chunk_size": cfg.MODEL.AVISM.TEST_RUN_CHUNK_SIZE,
|
217 |
+
"test_interpolate_chunk_size": cfg.MODEL.AVISM.TEST_INTERPOLATE_CHUNK_SIZE,
|
218 |
+
"is_coco": cfg.DATASETS.TEST[0].startswith("coco"),
|
219 |
+
}
|
220 |
+
|
221 |
+
@property
|
222 |
+
def device(self):
|
223 |
+
return self.pixel_mean.device
|
224 |
+
|
225 |
+
def forward(self, batched_inputs):
|
226 |
+
"""
|
227 |
+
Args:
|
228 |
+
batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
|
229 |
+
Each item in the list contains the inputs for one image.
|
230 |
+
For now, each item in the list is a dict that contains:
|
231 |
+
* "image": Tensor, image in (C, H, W) format.
|
232 |
+
* "instances": per-region ground truth
|
233 |
+
* Other information that's included in the original dicts, such as:
|
234 |
+
"height", "width" (int): the output resolution of the model (may be different
|
235 |
+
from input resolution), used in inference.
|
236 |
+
Returns:
|
237 |
+
list[dict]:
|
238 |
+
each dict has the results for one image. The dict contains the following keys:
|
239 |
+
|
240 |
+
* "sem_seg":
|
241 |
+
A Tensor that represents the
|
242 |
+
per-pixel segmentation prediced by the head.
|
243 |
+
The prediction has shape KxHxW that represents the logits of
|
244 |
+
each class for each pixel.
|
245 |
+
* "panoptic_seg":
|
246 |
+
A tuple that represent panoptic output
|
247 |
+
panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
|
248 |
+
segments_info (list[dict]): Describe each segment in `panoptic_seg`.
|
249 |
+
Each dict contains keys "id", "category_id", "isthing".
|
250 |
+
"""
|
251 |
+
if self.training:
|
252 |
+
return self.train_model(batched_inputs)
|
253 |
+
else:
|
254 |
+
# NOTE consider only B=1 case.
|
255 |
+
return self.inference(batched_inputs[0])
|
256 |
+
|
257 |
+
def train_model(self, batched_inputs):
|
258 |
+
images = []
|
259 |
+
audio_features = []
|
260 |
+
for video in batched_inputs:
|
261 |
+
for frame in video["image"]:
|
262 |
+
images.append(frame.to(self.device))
|
263 |
+
for audio_feat in video["audio"]:
|
264 |
+
audio_features.append(torch.tensor(audio_feat).to(self.device))
|
265 |
+
|
266 |
+
audio_features = torch.stack(audio_features)
|
267 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
268 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
269 |
+
image_features = self.backbone(images.tensor)
|
270 |
+
|
271 |
+
BT = len(images)
|
272 |
+
T = self.num_frames if self.training else BT
|
273 |
+
B = BT // T
|
274 |
+
|
275 |
+
outputs, frame_queries, mask_features = self.sem_seg_head(image_features, audio_features)
|
276 |
+
|
277 |
+
mask_features = self.avism_module.avism_mask_features(mask_features)
|
278 |
+
mask_features = mask_features.view(B, self.num_frames, *mask_features.shape[-3:])
|
279 |
+
|
280 |
+
# mask classification target
|
281 |
+
frame_targets, clip_targets = self.prepare_targets(batched_inputs, images)
|
282 |
+
|
283 |
+
# bipartite matching-based loss
|
284 |
+
losses, fg_indices = self.criterion(outputs, frame_targets)
|
285 |
+
|
286 |
+
avism_outputs = self.avism_module(frame_queries, audio_features)
|
287 |
+
avism_outputs["pred_masks"] = torch.einsum("lbqc,btchw->lbqthw", avism_outputs["pred_mask_embed"], mask_features)
|
288 |
+
for out in avism_outputs["aux_outputs"]:
|
289 |
+
out["pred_masks"] = torch.einsum("lbqc,btchw->lbqthw", out["pred_mask_embed"], mask_features)
|
290 |
+
|
291 |
+
for k in list(losses.keys()):
|
292 |
+
if k in self.criterion.weight_dict:
|
293 |
+
losses[k] *= self.criterion.weight_dict[k]
|
294 |
+
else:
|
295 |
+
# remove this loss if not specified in `weight_dict`
|
296 |
+
losses.pop(k)
|
297 |
+
avism_loss_dict = self.avism_criterion(avism_outputs, clip_targets, frame_targets, fg_indices)
|
298 |
+
avism_weight_dict = self.avism_criterion.weight_dict
|
299 |
+
|
300 |
+
for k in avism_loss_dict.keys():
|
301 |
+
if k in avism_weight_dict:
|
302 |
+
avism_loss_dict[k] *= avism_weight_dict[k]
|
303 |
+
losses.update(avism_loss_dict)
|
304 |
+
return losses
|
305 |
+
|
306 |
+
def prepare_targets(self, targets, images):
|
307 |
+
h_pad, w_pad = images.tensor.shape[-2:]
|
308 |
+
frame_gt_instances = []
|
309 |
+
clip_gt_instances = []
|
310 |
+
for targets_per_video in targets:
|
311 |
+
_num_instance = len(targets_per_video["instances"][0])
|
312 |
+
mask_shape = [_num_instance, self.num_frames, h_pad, w_pad]
|
313 |
+
gt_masks_per_video = torch.zeros(mask_shape, dtype=torch.bool, device=self.device)
|
314 |
+
|
315 |
+
gt_classes_per_video = targets_per_video["instances"][0].gt_classes.to(self.device)
|
316 |
+
gt_ids_per_video = []
|
317 |
+
for f_i, targets_per_frame in enumerate(targets_per_video["instances"]):
|
318 |
+
targets_per_frame = targets_per_frame.to(self.device)
|
319 |
+
h, w = targets_per_frame.image_size
|
320 |
+
|
321 |
+
_update_cls = gt_classes_per_video == -1
|
322 |
+
gt_classes_per_video[_update_cls] = targets_per_frame.gt_classes[_update_cls]
|
323 |
+
gt_ids_per_video.append(targets_per_frame.gt_ids)
|
324 |
+
if isinstance(targets_per_frame.gt_masks, BitMasks):
|
325 |
+
gt_masks_per_video[:, f_i, :h, :w] = targets_per_frame.gt_masks.tensor
|
326 |
+
else: #polygon
|
327 |
+
gt_masks_per_video[:, f_i, :h, :w] = targets_per_frame.gt_masks
|
328 |
+
|
329 |
+
gt_ids_per_video = torch.stack(gt_ids_per_video, dim=1)
|
330 |
+
gt_ids_per_video[gt_masks_per_video.sum(dim=(2,3)) == 0] = -1
|
331 |
+
valid_bool_frame = (gt_ids_per_video != -1)
|
332 |
+
valid_bool_clip = valid_bool_frame.any(dim=-1)
|
333 |
+
|
334 |
+
gt_classes_per_video = gt_classes_per_video[valid_bool_clip].long() # N,
|
335 |
+
gt_ids_per_video = gt_ids_per_video[valid_bool_clip].long() # N, num_frames
|
336 |
+
gt_masks_per_video = gt_masks_per_video[valid_bool_clip].float() # N, num_frames, H, W
|
337 |
+
valid_bool_frame = valid_bool_frame[valid_bool_clip]
|
338 |
+
|
339 |
+
if len(gt_ids_per_video) > 0:
|
340 |
+
min_id = max(gt_ids_per_video[valid_bool_frame].min(), 0)
|
341 |
+
gt_ids_per_video[valid_bool_frame] -= min_id
|
342 |
+
|
343 |
+
clip_gt_instances.append(
|
344 |
+
{
|
345 |
+
"labels": gt_classes_per_video, "ids": gt_ids_per_video, "masks": gt_masks_per_video,
|
346 |
+
"video_len": targets_per_video["length"], "frame_idx": targets_per_video["frame_idx"],
|
347 |
+
}
|
348 |
+
)
|
349 |
+
|
350 |
+
for f_i in range(self.num_frames):
|
351 |
+
_cls = gt_classes_per_video.clone()
|
352 |
+
_ids = gt_ids_per_video[:, f_i].clone()
|
353 |
+
_mask = gt_masks_per_video[:, f_i].clone()
|
354 |
+
|
355 |
+
valid = _ids != -1
|
356 |
+
frame_gt_instances.append({
|
357 |
+
"labels": _cls[valid],
|
358 |
+
"ids": _ids[valid],
|
359 |
+
"masks": _mask[valid],
|
360 |
+
})
|
361 |
+
|
362 |
+
return frame_gt_instances, clip_gt_instances
|
363 |
+
|
364 |
+
def inference(self, batched_inputs):
|
365 |
+
frame_queries, mask_features = [], []
|
366 |
+
num_frames = len(batched_inputs["image"])
|
367 |
+
to_store = self.device if num_frames <= 36 else "cpu"
|
368 |
+
|
369 |
+
audio_features = torch.tensor(batched_inputs["audio"]).to(self.device)
|
370 |
+
|
371 |
+
with torch.no_grad():
|
372 |
+
for i in range(math.ceil(num_frames / self.test_run_chunk_size)):
|
373 |
+
images = batched_inputs["image"][i*self.test_run_chunk_size : (i+1)*self.test_run_chunk_size]
|
374 |
+
images = [(x.to(self.device) - self.pixel_mean) / self.pixel_std for x in images]
|
375 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
376 |
+
|
377 |
+
audio_features_chunk = audio_features[i*self.test_run_chunk_size : (i+1)*self.test_run_chunk_size]
|
378 |
+
|
379 |
+
features = self.backbone(images.tensor)
|
380 |
+
outputs, _frame_queries, _mask_features = self.sem_seg_head(features, audio_features_chunk)
|
381 |
+
|
382 |
+
_mask_features = self.avism_module.avism_mask_features(_mask_features)
|
383 |
+
|
384 |
+
# BT is 1 as runs per frame
|
385 |
+
frame_queries.append(_frame_queries[-1]) # T', fQ, C
|
386 |
+
mask_features.append(_mask_features.to(to_store)) # T', C, H, W
|
387 |
+
|
388 |
+
interim_size = images.tensor.shape[-2:]
|
389 |
+
image_size = images.image_sizes[0] # image size without padding after data augmentation
|
390 |
+
|
391 |
+
out_height = batched_inputs.get("height", image_size[0]) # raw image size before data augmentation
|
392 |
+
out_width = batched_inputs.get("width", image_size[1])
|
393 |
+
|
394 |
+
del outputs, images, batched_inputs
|
395 |
+
|
396 |
+
frame_queries = torch.cat(frame_queries)[None] # 1, T, fQ, C
|
397 |
+
mask_features = torch.cat(mask_features) # T, C, H, W
|
398 |
+
|
399 |
+
avism_outputs = self.avism_module(frame_queries, audio_features)
|
400 |
+
|
401 |
+
mask_cls = avism_outputs["pred_logits"][-1, 0] # cQ, K+1
|
402 |
+
mask_embed = avism_outputs["pred_mask_embed"][-1, 0] # cQ, C
|
403 |
+
|
404 |
+
del avism_outputs
|
405 |
+
|
406 |
+
scores = F.softmax(mask_cls, dim=-1)[:, :-1]
|
407 |
+
labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
|
408 |
+
|
409 |
+
num_topk = self.test_topk_per_image
|
410 |
+
scores_per_video, topk_indices = scores.flatten(0, 1).topk(num_topk, sorted=False)
|
411 |
+
labels_per_video = labels[topk_indices]
|
412 |
+
|
413 |
+
topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode='floor')
|
414 |
+
mask_embed = mask_embed[topk_indices]
|
415 |
+
|
416 |
+
masks_per_video = []
|
417 |
+
numerator = torch.zeros(len(mask_embed), dtype=torch.float, device=self.device)
|
418 |
+
denominator = torch.zeros(len(mask_embed), dtype=torch.float, device=self.device)
|
419 |
+
for i in range(math.ceil(len(mask_features) / self.test_interpolate_chunk_size)):
|
420 |
+
m_f = mask_features[i*self.test_interpolate_chunk_size : (i+1)*self.test_interpolate_chunk_size].to(self.device)
|
421 |
+
|
422 |
+
mask_pred = torch.einsum("qc,tchw->qthw", mask_embed, m_f)
|
423 |
+
|
424 |
+
# upsample masks
|
425 |
+
mask_pred = retry_if_cuda_oom(F.interpolate)(
|
426 |
+
mask_pred,
|
427 |
+
size=interim_size,
|
428 |
+
mode="bilinear",
|
429 |
+
align_corners=False,
|
430 |
+
) # cQ, T, H, W
|
431 |
+
|
432 |
+
mask_pred = mask_pred[:, :, : image_size[0], : image_size[1]]
|
433 |
+
|
434 |
+
interim_mask_soft = mask_pred.sigmoid()
|
435 |
+
interim_mask_hard = interim_mask_soft > 0.5
|
436 |
+
|
437 |
+
numerator += (interim_mask_soft.flatten(1) * interim_mask_hard.flatten(1)).sum(1)
|
438 |
+
denominator += interim_mask_hard.flatten(1).sum(1)
|
439 |
+
|
440 |
+
mask_pred = F.interpolate(
|
441 |
+
mask_pred, size=(out_height, out_width), mode="bilinear", align_corners=False
|
442 |
+
) > 0.
|
443 |
+
masks_per_video.append(mask_pred.to(to_store))
|
444 |
+
masks_per_video = torch.cat(masks_per_video, dim=1)
|
445 |
+
scores_per_video *= (numerator / (denominator + 1e-6))
|
446 |
+
|
447 |
+
confidence = 0.3
|
448 |
+
indices = torch.nonzero(scores_per_video > confidence).squeeze(-1)
|
449 |
+
scores_per_video = scores_per_video[indices]
|
450 |
+
labels_per_video = labels_per_video[indices]
|
451 |
+
masks_per_video = masks_per_video[indices]
|
452 |
+
|
453 |
+
processed_results = {
|
454 |
+
"image_size": (out_height, out_width),
|
455 |
+
"pred_scores": scores_per_video.tolist(),
|
456 |
+
"pred_labels": labels_per_video.tolist(),
|
457 |
+
"pred_masks": masks_per_video.cpu(),
|
458 |
+
}
|
459 |
+
|
460 |
+
return processed_results
|
avism/avism_model_coco.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from detectron2.config import configurable
|
9 |
+
from detectron2.data import MetadataCatalog
|
10 |
+
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
|
11 |
+
from detectron2.modeling.backbone import Backbone
|
12 |
+
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
|
13 |
+
from detectron2.utils.memory import retry_if_cuda_oom
|
14 |
+
|
15 |
+
from mask2former.modeling.criterion import SetCriterion
|
16 |
+
from mask2former.modeling.matcher import HungarianMatcher
|
17 |
+
from .modeling.avism_criterion import AvismSetCriterion
|
18 |
+
from .modeling.avism_matcher import AvismHungarianMatcher
|
19 |
+
from .modeling.transformer_decoder.avism_coco import Avism_COCO
|
20 |
+
|
21 |
+
|
22 |
+
@META_ARCH_REGISTRY.register()
|
23 |
+
class AVISM_COCO(nn.Module):
|
24 |
+
"""
|
25 |
+
Main class for mask classification semantic segmentation architectures.
|
26 |
+
"""
|
27 |
+
|
28 |
+
@configurable
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
*,
|
32 |
+
backbone: Backbone,
|
33 |
+
sem_seg_head: nn.Module,
|
34 |
+
criterion: nn.Module,
|
35 |
+
num_queries: int,
|
36 |
+
object_mask_threshold: float,
|
37 |
+
overlap_threshold: float,
|
38 |
+
metadata,
|
39 |
+
size_divisibility: int,
|
40 |
+
pixel_mean: Tuple[float],
|
41 |
+
pixel_std: Tuple[float],
|
42 |
+
# inference
|
43 |
+
test_topk_per_image: int,
|
44 |
+
# avism
|
45 |
+
avism_module: nn.Module,
|
46 |
+
avism_criterion: nn.Module,
|
47 |
+
num_frames: int,
|
48 |
+
num_classes: int,
|
49 |
+
is_multi_cls: bool,
|
50 |
+
apply_cls_thres: float,
|
51 |
+
freeze_detector: bool,
|
52 |
+
test_run_chunk_size: int,
|
53 |
+
test_interpolate_chunk_size: int,
|
54 |
+
is_coco: bool,
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
Args:
|
58 |
+
backbone: a backbone module, must follow detectron2's backbone interface
|
59 |
+
sem_seg_head: a module that predicts semantic segmentation from backbone features
|
60 |
+
criterion: a module that defines the loss
|
61 |
+
num_queries: int, number of queries
|
62 |
+
object_mask_threshold: float, threshold to filter query based on classification score
|
63 |
+
for panoptic segmentation inference
|
64 |
+
overlap_threshold: overlap threshold used in general inference for panoptic segmentation
|
65 |
+
metadata: dataset meta, get `thing` and `stuff` category names for panoptic
|
66 |
+
segmentation inference
|
67 |
+
size_divisibility: Some backbones require the input height and width to be divisible by a
|
68 |
+
specific integer. We can use this to override such requirement.
|
69 |
+
pixel_mean, pixel_std: list or tuple with #channels element, representing
|
70 |
+
the per-channel mean and std to be used to normalize the input image
|
71 |
+
test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
|
72 |
+
"""
|
73 |
+
super().__init__()
|
74 |
+
self.backbone = backbone
|
75 |
+
self.sem_seg_head = sem_seg_head
|
76 |
+
self.criterion = criterion
|
77 |
+
self.num_queries = num_queries
|
78 |
+
self.overlap_threshold = overlap_threshold
|
79 |
+
self.object_mask_threshold = object_mask_threshold
|
80 |
+
self.metadata = metadata
|
81 |
+
if size_divisibility < 0:
|
82 |
+
# use backbone size_divisibility if not set
|
83 |
+
size_divisibility = self.backbone.size_divisibility
|
84 |
+
self.size_divisibility = size_divisibility
|
85 |
+
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
86 |
+
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
87 |
+
|
88 |
+
# additional args
|
89 |
+
self.test_topk_per_image = test_topk_per_image
|
90 |
+
|
91 |
+
# avism hyper-parameters
|
92 |
+
self.num_frames = num_frames
|
93 |
+
self.num_classes = num_classes
|
94 |
+
self.vita_module = avism_module
|
95 |
+
self.vita_criterion = avism_criterion
|
96 |
+
self.is_multi_cls = is_multi_cls
|
97 |
+
self.apply_cls_thres = apply_cls_thres
|
98 |
+
|
99 |
+
if freeze_detector:
|
100 |
+
for name, p in self.named_parameters():
|
101 |
+
if not "vita_module" in name:
|
102 |
+
p.requires_grad_(False)
|
103 |
+
self.test_run_chunk_size = test_run_chunk_size
|
104 |
+
self.test_interpolate_chunk_size = test_interpolate_chunk_size
|
105 |
+
|
106 |
+
self.is_coco = is_coco
|
107 |
+
|
108 |
+
@classmethod
|
109 |
+
def from_config(cls, cfg):
|
110 |
+
backbone = build_backbone(cfg)
|
111 |
+
sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
|
112 |
+
|
113 |
+
# Loss parameters:
|
114 |
+
deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
|
115 |
+
no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
|
116 |
+
avism_deep_supervision = cfg.MODEL.AVISM.DEEP_SUPERVISION
|
117 |
+
|
118 |
+
# loss weights
|
119 |
+
class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
|
120 |
+
dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
|
121 |
+
mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
|
122 |
+
sim_weight = cfg.MODEL.AVISM.SIM_WEIGHT
|
123 |
+
|
124 |
+
# building criterion
|
125 |
+
matcher = HungarianMatcher(
|
126 |
+
cost_class=class_weight,
|
127 |
+
cost_mask=mask_weight,
|
128 |
+
cost_dice=dice_weight,
|
129 |
+
num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
|
130 |
+
)
|
131 |
+
|
132 |
+
weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
|
133 |
+
|
134 |
+
if deep_supervision:
|
135 |
+
dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
|
136 |
+
aux_weight_dict = {}
|
137 |
+
for i in range(dec_layers - 1):
|
138 |
+
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
139 |
+
weight_dict.update(aux_weight_dict)
|
140 |
+
|
141 |
+
losses = ["labels", "masks"]
|
142 |
+
|
143 |
+
criterion = SetCriterion(
|
144 |
+
sem_seg_head.num_classes,
|
145 |
+
matcher=matcher,
|
146 |
+
weight_dict=weight_dict,
|
147 |
+
eos_coef=no_object_weight,
|
148 |
+
losses=losses,
|
149 |
+
num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
|
150 |
+
oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
|
151 |
+
importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
|
152 |
+
avism_last_layer_num=cfg.MODEL.AVISM.LAST_LAYER_NUM,
|
153 |
+
)
|
154 |
+
|
155 |
+
# Avism
|
156 |
+
num_classes = sem_seg_head.num_classes
|
157 |
+
hidden_dim = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
|
158 |
+
avism_module = Avism_COCO(cfg=cfg, in_channels=hidden_dim, aux_loss=avism_deep_supervision)
|
159 |
+
|
160 |
+
# building criterion for avism inference
|
161 |
+
avism_matcher = AvismHungarianMatcher(
|
162 |
+
cost_class=class_weight,
|
163 |
+
cost_mask=mask_weight,
|
164 |
+
cost_dice=dice_weight,
|
165 |
+
num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
|
166 |
+
)
|
167 |
+
avism_weight_dict = {
|
168 |
+
"loss_avism_ce": class_weight, "loss_avism_mask": mask_weight, "loss_avism_dice": dice_weight
|
169 |
+
}
|
170 |
+
if sim_weight > 0.0:
|
171 |
+
avism_weight_dict["loss_avism_sim"] = sim_weight
|
172 |
+
|
173 |
+
if avism_deep_supervision:
|
174 |
+
avism_dec_layers = cfg.MODEL.AVISM.DEC_LAYERS
|
175 |
+
aux_weight_dict = {}
|
176 |
+
for i in range(avism_dec_layers - 1):
|
177 |
+
aux_weight_dict.update({k + f"_{i}": v for k, v in avism_weight_dict.items()})
|
178 |
+
avism_weight_dict.update(aux_weight_dict)
|
179 |
+
avism_losses = ["avism_labels", "avism_masks"]
|
180 |
+
if sim_weight > 0.0:
|
181 |
+
avism_losses.append("fg_sim")
|
182 |
+
|
183 |
+
avism_criterion = AvismSetCriterion(
|
184 |
+
num_classes,
|
185 |
+
matcher=avism_matcher,
|
186 |
+
weight_dict=avism_weight_dict,
|
187 |
+
eos_coef=cfg.MODEL.AVISM.NO_OBJECT_WEIGHT,
|
188 |
+
losses=avism_losses,
|
189 |
+
num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
|
190 |
+
oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
|
191 |
+
importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
|
192 |
+
sim_use_clip=cfg.MODEL.AVISM.SIM_USE_CLIP,
|
193 |
+
)
|
194 |
+
|
195 |
+
return {
|
196 |
+
"backbone": backbone,
|
197 |
+
"sem_seg_head": sem_seg_head,
|
198 |
+
"criterion": criterion,
|
199 |
+
"num_queries": cfg.MODEL.AVISM.NUM_OBJECT_QUERIES,
|
200 |
+
"object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
|
201 |
+
"overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
|
202 |
+
"metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
|
203 |
+
"size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
|
204 |
+
"pixel_mean": cfg.MODEL.PIXEL_MEAN,
|
205 |
+
"pixel_std": cfg.MODEL.PIXEL_STD,
|
206 |
+
# inference
|
207 |
+
"test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
|
208 |
+
# avism
|
209 |
+
"avism_module": avism_module,
|
210 |
+
"avism_criterion": avism_criterion,
|
211 |
+
"num_frames": cfg.INPUT.SAMPLING_FRAME_NUM,
|
212 |
+
"num_classes": num_classes,
|
213 |
+
"is_multi_cls": cfg.MODEL.AVISM.MULTI_CLS_ON,
|
214 |
+
"apply_cls_thres": cfg.MODEL.AVISM.APPLY_CLS_THRES,
|
215 |
+
"freeze_detector": cfg.MODEL.AVISM.FREEZE_DETECTOR,
|
216 |
+
"test_run_chunk_size": cfg.MODEL.AVISM.TEST_RUN_CHUNK_SIZE,
|
217 |
+
"test_interpolate_chunk_size": cfg.MODEL.AVISM.TEST_INTERPOLATE_CHUNK_SIZE,
|
218 |
+
"is_coco": cfg.DATASETS.TEST[0].startswith("coco"),
|
219 |
+
}
|
220 |
+
|
221 |
+
@property
|
222 |
+
def device(self):
|
223 |
+
return self.pixel_mean.device
|
224 |
+
|
225 |
+
def forward(self, batched_inputs):
|
226 |
+
"""
|
227 |
+
Args:
|
228 |
+
batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
|
229 |
+
Each item in the list contains the inputs for one image.
|
230 |
+
For now, each item in the list is a dict that contains:
|
231 |
+
* "image": Tensor, image in (C, H, W) format.
|
232 |
+
* "instances": per-region ground truth
|
233 |
+
* Other information that's included in the original dicts, such as:
|
234 |
+
"height", "width" (int): the output resolution of the model (may be different
|
235 |
+
from input resolution), used in inference.
|
236 |
+
Returns:
|
237 |
+
list[dict]:
|
238 |
+
each dict has the results for one image. The dict contains the following keys:
|
239 |
+
|
240 |
+
* "sem_seg":
|
241 |
+
A Tensor that represents the
|
242 |
+
per-pixel segmentation prediced by the head.
|
243 |
+
The prediction has shape KxHxW that represents the logits of
|
244 |
+
each class for each pixel.
|
245 |
+
* "panoptic_seg":
|
246 |
+
A tuple that represent panoptic output
|
247 |
+
panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
|
248 |
+
segments_info (list[dict]): Describe each segment in `panoptic_seg`.
|
249 |
+
Each dict contains keys "id", "category_id", "isthing".
|
250 |
+
"""
|
251 |
+
if self.training:
|
252 |
+
return self.train_model(batched_inputs)
|
253 |
+
else:
|
254 |
+
# NOTE consider only B=1 case.
|
255 |
+
return self.inference(batched_inputs[0])
|
256 |
+
|
257 |
+
def train_model(self, batched_inputs):
|
258 |
+
images = []
|
259 |
+
audio_features = []
|
260 |
+
for video in batched_inputs:
|
261 |
+
for frame in video["image"]:
|
262 |
+
images.append(frame.to(self.device))
|
263 |
+
for audio_feat in video["audio"]:
|
264 |
+
audio_features.append(torch.tensor(audio_feat).to(self.device))
|
265 |
+
|
266 |
+
audio_features = torch.stack(audio_features)
|
267 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
268 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
269 |
+
image_features = self.backbone(images.tensor)
|
270 |
+
|
271 |
+
BT = len(images)
|
272 |
+
T = self.num_frames if self.training else BT
|
273 |
+
B = BT // T
|
274 |
+
|
275 |
+
outputs, frame_queries, mask_features = self.sem_seg_head(image_features, audio_features)
|
276 |
+
|
277 |
+
mask_features = self.vita_module.vita_mask_features(mask_features)
|
278 |
+
mask_features = mask_features.view(B, self.num_frames, *mask_features.shape[-3:])
|
279 |
+
|
280 |
+
# mask classification target
|
281 |
+
frame_targets, clip_targets = self.prepare_targets(batched_inputs, images)
|
282 |
+
|
283 |
+
# bipartite matching-based loss
|
284 |
+
losses, fg_indices = self.criterion(outputs, frame_targets)
|
285 |
+
|
286 |
+
avism_outputs = self.vita_module(frame_queries, audio_features)
|
287 |
+
avism_outputs["pred_masks"] = torch.einsum("lbqc,btchw->lbqthw", avism_outputs["pred_mask_embed"], mask_features)
|
288 |
+
for out in avism_outputs["aux_outputs"]:
|
289 |
+
out["pred_masks"] = torch.einsum("lbqc,btchw->lbqthw", out["pred_mask_embed"], mask_features)
|
290 |
+
|
291 |
+
for k in list(losses.keys()):
|
292 |
+
if k in self.criterion.weight_dict:
|
293 |
+
losses[k] *= self.criterion.weight_dict[k]
|
294 |
+
else:
|
295 |
+
# remove this loss if not specified in `weight_dict`
|
296 |
+
losses.pop(k)
|
297 |
+
avism_loss_dict = self.vita_criterion(avism_outputs, clip_targets, frame_targets, fg_indices)
|
298 |
+
avism_weight_dict = self.vita_criterion.weight_dict
|
299 |
+
|
300 |
+
for k in avism_loss_dict.keys():
|
301 |
+
if k in avism_weight_dict:
|
302 |
+
avism_loss_dict[k] *= avism_weight_dict[k]
|
303 |
+
losses.update(avism_loss_dict)
|
304 |
+
return losses
|
305 |
+
|
306 |
+
def prepare_targets(self, targets, images):
|
307 |
+
h_pad, w_pad = images.tensor.shape[-2:]
|
308 |
+
frame_gt_instances = []
|
309 |
+
clip_gt_instances = []
|
310 |
+
for targets_per_video in targets:
|
311 |
+
_num_instance = len(targets_per_video["instances"][0])
|
312 |
+
mask_shape = [_num_instance, self.num_frames, h_pad, w_pad]
|
313 |
+
gt_masks_per_video = torch.zeros(mask_shape, dtype=torch.bool, device=self.device)
|
314 |
+
|
315 |
+
gt_classes_per_video = targets_per_video["instances"][0].gt_classes.to(self.device)
|
316 |
+
gt_ids_per_video = []
|
317 |
+
for f_i, targets_per_frame in enumerate(targets_per_video["instances"]):
|
318 |
+
targets_per_frame = targets_per_frame.to(self.device)
|
319 |
+
h, w = targets_per_frame.image_size
|
320 |
+
|
321 |
+
_update_cls = gt_classes_per_video == -1
|
322 |
+
gt_classes_per_video[_update_cls] = targets_per_frame.gt_classes[_update_cls]
|
323 |
+
gt_ids_per_video.append(targets_per_frame.gt_ids)
|
324 |
+
if isinstance(targets_per_frame.gt_masks, BitMasks):
|
325 |
+
gt_masks_per_video[:, f_i, :h, :w] = targets_per_frame.gt_masks.tensor
|
326 |
+
else: #polygon
|
327 |
+
gt_masks_per_video[:, f_i, :h, :w] = targets_per_frame.gt_masks
|
328 |
+
|
329 |
+
gt_ids_per_video = torch.stack(gt_ids_per_video, dim=1)
|
330 |
+
gt_ids_per_video[gt_masks_per_video.sum(dim=(2,3)) == 0] = -1
|
331 |
+
valid_bool_frame = (gt_ids_per_video != -1)
|
332 |
+
valid_bool_clip = valid_bool_frame.any(dim=-1)
|
333 |
+
|
334 |
+
gt_classes_per_video = gt_classes_per_video[valid_bool_clip].long() # N,
|
335 |
+
gt_ids_per_video = gt_ids_per_video[valid_bool_clip].long() # N, num_frames
|
336 |
+
gt_masks_per_video = gt_masks_per_video[valid_bool_clip].float() # N, num_frames, H, W
|
337 |
+
valid_bool_frame = valid_bool_frame[valid_bool_clip]
|
338 |
+
|
339 |
+
if len(gt_ids_per_video) > 0:
|
340 |
+
min_id = max(gt_ids_per_video[valid_bool_frame].min(), 0)
|
341 |
+
gt_ids_per_video[valid_bool_frame] -= min_id
|
342 |
+
|
343 |
+
clip_gt_instances.append(
|
344 |
+
{
|
345 |
+
"labels": gt_classes_per_video, "ids": gt_ids_per_video, "masks": gt_masks_per_video,
|
346 |
+
"video_len": targets_per_video["length"], "frame_idx": targets_per_video["frame_idx"],
|
347 |
+
}
|
348 |
+
)
|
349 |
+
|
350 |
+
for f_i in range(self.num_frames):
|
351 |
+
_cls = gt_classes_per_video.clone()
|
352 |
+
_ids = gt_ids_per_video[:, f_i].clone()
|
353 |
+
_mask = gt_masks_per_video[:, f_i].clone()
|
354 |
+
|
355 |
+
valid = _ids != -1
|
356 |
+
frame_gt_instances.append({
|
357 |
+
"labels": _cls[valid],
|
358 |
+
"ids": _ids[valid],
|
359 |
+
"masks": _mask[valid],
|
360 |
+
})
|
361 |
+
|
362 |
+
return frame_gt_instances, clip_gt_instances
|
363 |
+
|
364 |
+
def inference(self, batched_inputs):
|
365 |
+
frame_queries, mask_features = [], []
|
366 |
+
num_frames = len(batched_inputs["image"])
|
367 |
+
to_store = self.device if num_frames <= 36 else "cpu"
|
368 |
+
|
369 |
+
audio_features = torch.tensor(batched_inputs["audio"]).to(self.device)
|
370 |
+
|
371 |
+
with torch.no_grad():
|
372 |
+
for i in range(math.ceil(num_frames / self.test_run_chunk_size)):
|
373 |
+
images = batched_inputs["image"][i*self.test_run_chunk_size : (i+1)*self.test_run_chunk_size]
|
374 |
+
images = [(x.to(self.device) - self.pixel_mean) / self.pixel_std for x in images]
|
375 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
376 |
+
|
377 |
+
audio_features_chunk = audio_features[i*self.test_run_chunk_size : (i+1)*self.test_run_chunk_size]
|
378 |
+
|
379 |
+
features = self.backbone(images.tensor)
|
380 |
+
outputs, _frame_queries, _mask_features = self.sem_seg_head(features, audio_features_chunk)
|
381 |
+
|
382 |
+
_mask_features = self.vita_module.vita_mask_features(_mask_features)
|
383 |
+
|
384 |
+
# BT is 1 as runs per frame
|
385 |
+
frame_queries.append(_frame_queries[-1]) # T', fQ, C
|
386 |
+
mask_features.append(_mask_features.to(to_store)) # T', C, H, W
|
387 |
+
|
388 |
+
interim_size = images.tensor.shape[-2:]
|
389 |
+
image_size = images.image_sizes[0] # image size without padding after data augmentation
|
390 |
+
|
391 |
+
out_height = batched_inputs.get("height", image_size[0]) # raw image size before data augmentation
|
392 |
+
out_width = batched_inputs.get("width", image_size[1])
|
393 |
+
|
394 |
+
del outputs, images, batched_inputs
|
395 |
+
|
396 |
+
frame_queries = torch.cat(frame_queries)[None] # 1, T, fQ, C
|
397 |
+
mask_features = torch.cat(mask_features) # T, C, H, W
|
398 |
+
|
399 |
+
avism_outputs = self.vita_module(frame_queries, audio_features)
|
400 |
+
|
401 |
+
mask_cls = avism_outputs["pred_logits"][-1, 0] # cQ, K+1
|
402 |
+
mask_embed = avism_outputs["pred_mask_embed"][-1, 0] # cQ, C
|
403 |
+
|
404 |
+
del avism_outputs
|
405 |
+
|
406 |
+
scores = F.softmax(mask_cls, dim=-1)[:, :-1]
|
407 |
+
labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
|
408 |
+
|
409 |
+
num_topk = self.test_topk_per_image
|
410 |
+
scores_per_video, topk_indices = scores.flatten(0, 1).topk(num_topk, sorted=False)
|
411 |
+
labels_per_video = labels[topk_indices]
|
412 |
+
|
413 |
+
topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode='floor')
|
414 |
+
mask_embed = mask_embed[topk_indices]
|
415 |
+
|
416 |
+
masks_per_video = []
|
417 |
+
numerator = torch.zeros(len(mask_embed), dtype=torch.float, device=self.device)
|
418 |
+
denominator = torch.zeros(len(mask_embed), dtype=torch.float, device=self.device)
|
419 |
+
for i in range(math.ceil(len(mask_features) / self.test_interpolate_chunk_size)):
|
420 |
+
m_f = mask_features[i*self.test_interpolate_chunk_size : (i+1)*self.test_interpolate_chunk_size].to(self.device)
|
421 |
+
|
422 |
+
mask_pred = torch.einsum("qc,tchw->qthw", mask_embed, m_f)
|
423 |
+
|
424 |
+
# upsample masks
|
425 |
+
mask_pred = retry_if_cuda_oom(F.interpolate)(
|
426 |
+
mask_pred,
|
427 |
+
size=interim_size,
|
428 |
+
mode="bilinear",
|
429 |
+
align_corners=False,
|
430 |
+
) # cQ, T, H, W
|
431 |
+
|
432 |
+
mask_pred = mask_pred[:, :, : image_size[0], : image_size[1]]
|
433 |
+
|
434 |
+
interim_mask_soft = mask_pred.sigmoid()
|
435 |
+
interim_mask_hard = interim_mask_soft > 0.5
|
436 |
+
|
437 |
+
numerator += (interim_mask_soft.flatten(1) * interim_mask_hard.flatten(1)).sum(1)
|
438 |
+
denominator += interim_mask_hard.flatten(1).sum(1)
|
439 |
+
|
440 |
+
mask_pred = F.interpolate(
|
441 |
+
mask_pred, size=(out_height, out_width), mode="bilinear", align_corners=False
|
442 |
+
) > 0.
|
443 |
+
masks_per_video.append(mask_pred.to(to_store))
|
444 |
+
masks_per_video = torch.cat(masks_per_video, dim=1)
|
445 |
+
scores_per_video *= (numerator / (denominator + 1e-6))
|
446 |
+
|
447 |
+
confidence = 0.3
|
448 |
+
indices = torch.nonzero(scores_per_video > confidence).squeeze(-1)
|
449 |
+
scores_per_video = scores_per_video[indices]
|
450 |
+
labels_per_video = labels_per_video[indices]
|
451 |
+
masks_per_video = masks_per_video[indices]
|
452 |
+
|
453 |
+
processed_results = {
|
454 |
+
"image_size": (out_height, out_width),
|
455 |
+
"pred_scores": scores_per_video.tolist(),
|
456 |
+
"pred_labels": labels_per_video.tolist(),
|
457 |
+
"pred_masks": masks_per_video.cpu(),
|
458 |
+
}
|
459 |
+
|
460 |
+
return processed_results
|
avism/config.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from detectron2.config import CfgNode as CN
|
3 |
+
|
4 |
+
|
5 |
+
def add_avism_config(cfg):
|
6 |
+
cfg.DATASETS.DATASET_RATIO = []
|
7 |
+
|
8 |
+
# DataLoader
|
9 |
+
cfg.INPUT.SAMPLING_FRAME_NUM = 2
|
10 |
+
cfg.INPUT.SAMPLING_FRAME_RANGE = 20
|
11 |
+
cfg.INPUT.SAMPLING_FRAME_SHUFFLE = False
|
12 |
+
cfg.INPUT.AUGMENTATIONS = [] # "brightness", "contrast", "saturation", "rotation"
|
13 |
+
|
14 |
+
# Pseudo Data Use
|
15 |
+
cfg.INPUT.PSEUDO = CN()
|
16 |
+
cfg.INPUT.PSEUDO.AUGMENTATIONS = ['rotation']
|
17 |
+
cfg.INPUT.PSEUDO.MIN_SIZE_TRAIN = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768)
|
18 |
+
cfg.INPUT.PSEUDO.MAX_SIZE_TRAIN = 768
|
19 |
+
cfg.INPUT.PSEUDO.MIN_SIZE_TRAIN_SAMPLING = "choice_by_clip"
|
20 |
+
cfg.INPUT.PSEUDO.CROP = CN()
|
21 |
+
cfg.INPUT.PSEUDO.CROP.ENABLED = False
|
22 |
+
cfg.INPUT.PSEUDO.CROP.TYPE = "absolute_range"
|
23 |
+
cfg.INPUT.PSEUDO.CROP.SIZE = (384, 600)
|
24 |
+
|
25 |
+
# LSJ
|
26 |
+
cfg.INPUT.LSJ_AUG = CN()
|
27 |
+
cfg.INPUT.LSJ_AUG.ENABLED = False
|
28 |
+
cfg.INPUT.LSJ_AUG.IMAGE_SIZE = 1024
|
29 |
+
cfg.INPUT.LSJ_AUG.MIN_SCALE = 0.1
|
30 |
+
cfg.INPUT.LSJ_AUG.MAX_SCALE = 2.0
|
31 |
+
|
32 |
+
# AVISM
|
33 |
+
cfg.MODEL.AVISM = CN()
|
34 |
+
cfg.MODEL.AVISM.NHEADS = 8
|
35 |
+
cfg.MODEL.AVISM.DROPOUT = 0.0
|
36 |
+
cfg.MODEL.AVISM.DIM_FEEDFORWARD = 2048
|
37 |
+
cfg.MODEL.AVISM.ENC_LAYERS = 6
|
38 |
+
cfg.MODEL.AVISM.DEC_LAYERS = 3
|
39 |
+
cfg.MODEL.AVISM.ENC_WINDOW_SIZE = 0
|
40 |
+
cfg.MODEL.AVISM.PRE_NORM = False
|
41 |
+
cfg.MODEL.AVISM.HIDDEN_DIM = 256
|
42 |
+
cfg.MODEL.AVISM.NUM_OBJECT_QUERIES = 100
|
43 |
+
cfg.MODEL.AVISM.ENFORCE_INPUT_PROJ = True
|
44 |
+
|
45 |
+
cfg.MODEL.AVISM.NO_OBJECT_WEIGHT = 0.1
|
46 |
+
cfg.MODEL.AVISM.DEEP_SUPERVISION = True
|
47 |
+
cfg.MODEL.AVISM.LAST_LAYER_NUM = 3
|
48 |
+
cfg.MODEL.AVISM.MULTI_CLS_ON = True
|
49 |
+
cfg.MODEL.AVISM.APPLY_CLS_THRES = 0.01
|
50 |
+
|
51 |
+
cfg.MODEL.AVISM.SIM_USE_CLIP = True
|
52 |
+
cfg.MODEL.AVISM.SIM_WEIGHT = 0.5
|
53 |
+
|
54 |
+
cfg.MODEL.AVISM.FREEZE_DETECTOR = False
|
55 |
+
cfg.MODEL.AVISM.TEST_RUN_CHUNK_SIZE = 18
|
56 |
+
cfg.MODEL.AVISM.TEST_INTERPOLATE_CHUNK_SIZE = 5
|
avism/data/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .datasets import *
|
2 |
+
from .dataset_mapper import AVISDatasetMapper
|
3 |
+
from .build import *
|
4 |
+
from .avis_eval import AVISEvaluator
|
avism/data/augmentation.py
ADDED
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import numpy as np
|
3 |
+
import logging
|
4 |
+
import sys
|
5 |
+
from fvcore.transforms.transform import (
|
6 |
+
HFlipTransform,
|
7 |
+
NoOpTransform,
|
8 |
+
VFlipTransform,
|
9 |
+
)
|
10 |
+
from PIL import Image
|
11 |
+
from typing import Tuple
|
12 |
+
from fvcore.transforms.transform import (
|
13 |
+
BlendTransform,
|
14 |
+
CropTransform,
|
15 |
+
HFlipTransform,
|
16 |
+
NoOpTransform,
|
17 |
+
PadTransform,
|
18 |
+
Transform,
|
19 |
+
TransformList,
|
20 |
+
VFlipTransform,
|
21 |
+
)
|
22 |
+
|
23 |
+
from detectron2.data import transforms as T
|
24 |
+
|
25 |
+
|
26 |
+
class RandomApplyClip(T.Augmentation):
|
27 |
+
"""
|
28 |
+
Randomly apply an augmentation with a given probability.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, tfm_or_aug, prob=0.5, clip_frame_cnt=1):
|
32 |
+
"""
|
33 |
+
Args:
|
34 |
+
tfm_or_aug (Transform, Augmentation): the transform or augmentation
|
35 |
+
to be applied. It can either be a `Transform` or `Augmentation`
|
36 |
+
instance.
|
37 |
+
prob (float): probability between 0.0 and 1.0 that
|
38 |
+
the wrapper transformation is applied
|
39 |
+
"""
|
40 |
+
super().__init__()
|
41 |
+
self.aug = T.augmentation._transform_to_aug(tfm_or_aug)
|
42 |
+
assert 0.0 <= prob <= 1.0, f"Probablity must be between 0.0 and 1.0 (given: {prob})"
|
43 |
+
self.prob = prob
|
44 |
+
self._cnt = 0
|
45 |
+
self.clip_frame_cnt = clip_frame_cnt
|
46 |
+
|
47 |
+
def get_transform(self, *args):
|
48 |
+
if self._cnt % self.clip_frame_cnt == 0:
|
49 |
+
self.do = self._rand_range() < self.prob
|
50 |
+
self._cnt = 0 # avoiding overflow
|
51 |
+
self._cnt += 1
|
52 |
+
|
53 |
+
if self.do:
|
54 |
+
return self.aug.get_transform(*args)
|
55 |
+
else:
|
56 |
+
return NoOpTransform()
|
57 |
+
|
58 |
+
def __call__(self, aug_input):
|
59 |
+
if self._cnt % self.clip_frame_cnt == 0:
|
60 |
+
self.do = self._rand_range() < self.prob
|
61 |
+
self._cnt = 0 # avoiding overflow
|
62 |
+
self._cnt += 1
|
63 |
+
|
64 |
+
if self.do:
|
65 |
+
return self.aug(aug_input)
|
66 |
+
else:
|
67 |
+
return NoOpTransform()
|
68 |
+
|
69 |
+
|
70 |
+
class RandomRotationClip(T.Augmentation):
|
71 |
+
"""
|
72 |
+
This method returns a copy of this image, rotated the given
|
73 |
+
number of degrees counter clockwise around the given center.
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(self, angle, prob=0.5, expand=True, center=None, interp=None, clip_frame_cnt=1):
|
77 |
+
"""
|
78 |
+
Args:
|
79 |
+
angle (list[float]): If ``sample_style=="range"``,
|
80 |
+
a [min, max] interval from which to sample the angle (in degrees).
|
81 |
+
If ``sample_style=="choice"``, a list of angles to sample from
|
82 |
+
expand (bool): choose if the image should be resized to fit the whole
|
83 |
+
rotated image (default), or simply cropped
|
84 |
+
center (list[[float, float]]): If ``sample_style=="range"``,
|
85 |
+
a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center,
|
86 |
+
[0, 0] being the top left of the image and [1, 1] the bottom right.
|
87 |
+
If ``sample_style=="choice"``, a list of centers to sample from
|
88 |
+
Default: None, which means that the center of rotation is the center of the image
|
89 |
+
center has no effect if expand=True because it only affects shifting
|
90 |
+
"""
|
91 |
+
super().__init__()
|
92 |
+
if isinstance(angle, (float, int)):
|
93 |
+
angle = (angle, angle)
|
94 |
+
if center is not None and isinstance(center[0], (float, int)):
|
95 |
+
center = (center, center)
|
96 |
+
self.angle_save = None
|
97 |
+
self.center_save = None
|
98 |
+
self._cnt = 0
|
99 |
+
self._init(locals())
|
100 |
+
|
101 |
+
def get_transform(self, image):
|
102 |
+
h, w = image.shape[:2]
|
103 |
+
if self._cnt % self.clip_frame_cnt == 0:
|
104 |
+
center = None
|
105 |
+
angle = np.random.uniform(self.angle[0], self.angle[1], size=self.clip_frame_cnt)
|
106 |
+
if self.center is not None:
|
107 |
+
center = (
|
108 |
+
np.random.uniform(self.center[0][0], self.center[1][0]),
|
109 |
+
np.random.uniform(self.center[0][1], self.center[1][1]),
|
110 |
+
)
|
111 |
+
angle = np.sort(angle)
|
112 |
+
if self._rand_range() < self.prob:
|
113 |
+
angle = angle[::-1]
|
114 |
+
self.angle_save = angle
|
115 |
+
self.center_save = center
|
116 |
+
|
117 |
+
self._cnt = 0 # avoiding overflow
|
118 |
+
|
119 |
+
angle = self.angle_save[self._cnt]
|
120 |
+
center = self.center_save
|
121 |
+
|
122 |
+
self._cnt += 1
|
123 |
+
|
124 |
+
if center is not None:
|
125 |
+
center = (w * center[0], h * center[1]) # Convert to absolute coordinates
|
126 |
+
|
127 |
+
if angle % 360 == 0:
|
128 |
+
return NoOpTransform()
|
129 |
+
|
130 |
+
return T.RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)
|
131 |
+
|
132 |
+
|
133 |
+
class ResizeScaleClip(T.Augmentation):
|
134 |
+
"""
|
135 |
+
Takes target size as input and randomly scales the given target size between `min_scale`
|
136 |
+
and `max_scale`. It then scales the input image such that it fits inside the scaled target
|
137 |
+
box, keeping the aspect ratio constant.
|
138 |
+
This implements the resize part of the Google's 'resize_and_crop' data augmentation:
|
139 |
+
https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/input_utils.py#L127
|
140 |
+
"""
|
141 |
+
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
min_scale: float,
|
145 |
+
max_scale: float,
|
146 |
+
target_height: int,
|
147 |
+
target_width: int,
|
148 |
+
interp: int = Image.BILINEAR,
|
149 |
+
clip_frame_cnt=1,
|
150 |
+
):
|
151 |
+
"""
|
152 |
+
Args:
|
153 |
+
min_scale: minimum image scale range.
|
154 |
+
max_scale: maximum image scale range.
|
155 |
+
target_height: target image height.
|
156 |
+
target_width: target image width.
|
157 |
+
interp: image interpolation method.
|
158 |
+
"""
|
159 |
+
super().__init__()
|
160 |
+
self._init(locals())
|
161 |
+
self._cnt = 0
|
162 |
+
|
163 |
+
def _get_resize(self, image: np.ndarray, scale: float):
|
164 |
+
input_size = image.shape[:2]
|
165 |
+
|
166 |
+
# Compute new target size given a scale.
|
167 |
+
target_size = (self.target_height, self.target_width)
|
168 |
+
target_scale_size = np.multiply(target_size, scale)
|
169 |
+
|
170 |
+
# Compute actual rescaling applied to input image and output size.
|
171 |
+
output_scale = np.minimum(
|
172 |
+
target_scale_size[0] / input_size[0], target_scale_size[1] / input_size[1]
|
173 |
+
)
|
174 |
+
output_size = np.round(np.multiply(input_size, output_scale)).astype(int)
|
175 |
+
|
176 |
+
return T.ResizeTransform(
|
177 |
+
input_size[0], input_size[1], output_size[0], output_size[1], self.interp
|
178 |
+
)
|
179 |
+
|
180 |
+
def get_transform(self, image: np.ndarray):
|
181 |
+
if self._cnt % self.clip_frame_cnt == 0:
|
182 |
+
random_scale = np.random.uniform(self.min_scale, self.max_scale)
|
183 |
+
self.random_scale_save = random_scale
|
184 |
+
|
185 |
+
self._cnt = 0 # avoiding overflow
|
186 |
+
self._cnt += 1
|
187 |
+
random_scale = self.random_scale_save
|
188 |
+
|
189 |
+
return self._get_resize(image, random_scale)
|
190 |
+
|
191 |
+
|
192 |
+
class RandomCropClip(T.Augmentation):
|
193 |
+
"""
|
194 |
+
Randomly crop a rectangle region out of an image.
|
195 |
+
"""
|
196 |
+
|
197 |
+
def __init__(self, crop_type: str, crop_size, clip_frame_cnt=1):
|
198 |
+
"""
|
199 |
+
Args:
|
200 |
+
crop_type (str): one of "relative_range", "relative", "absolute", "absolute_range".
|
201 |
+
crop_size (tuple[float, float]): two floats, explained below.
|
202 |
+
- "relative": crop a (H * crop_size[0], W * crop_size[1]) region from an input image of
|
203 |
+
size (H, W). crop size should be in (0, 1]
|
204 |
+
- "relative_range": uniformly sample two values from [crop_size[0], 1]
|
205 |
+
and [crop_size[1]], 1], and use them as in "relative" crop type.
|
206 |
+
- "absolute" crop a (crop_size[0], crop_size[1]) region from input image.
|
207 |
+
crop_size must be smaller than the input image size.
|
208 |
+
- "absolute_range", for an input of size (H, W), uniformly sample H_crop in
|
209 |
+
[crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])].
|
210 |
+
Then crop a region (H_crop, W_crop).
|
211 |
+
"""
|
212 |
+
# TODO style of relative_range and absolute_range are not consistent:
|
213 |
+
# one takes (h, w) but another takes (min, max)
|
214 |
+
super().__init__()
|
215 |
+
assert crop_type in ["relative_range", "relative", "absolute", "absolute_range"]
|
216 |
+
self._init(locals())
|
217 |
+
self._cnt = 0
|
218 |
+
|
219 |
+
def get_transform(self, image):
|
220 |
+
h, w = image.shape[:2] # 667, 500
|
221 |
+
if self._cnt % self.clip_frame_cnt == 0:
|
222 |
+
croph, cropw = self.get_crop_size((h, w))
|
223 |
+
assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self)
|
224 |
+
|
225 |
+
h0 = np.random.randint(h - croph + 1) # rand(124) -> 5
|
226 |
+
w0 = np.random.randint(w - cropw + 1) # rand(111) -> 634
|
227 |
+
|
228 |
+
h1 = np.random.randint(h0, h - croph + 1)
|
229 |
+
w1 = np.random.randint(w0, w - cropw + 1)
|
230 |
+
|
231 |
+
x = np.sort(np.random.rand(self.clip_frame_cnt))
|
232 |
+
|
233 |
+
h = h0 * x + h1 * (1-x)
|
234 |
+
w = w0 * x + w1 * (1-x)
|
235 |
+
h = np.round_(h).astype(int)
|
236 |
+
w = np.round_(w).astype(int)
|
237 |
+
|
238 |
+
if self._rand_range() < 0.5:
|
239 |
+
h = h[::-1]
|
240 |
+
w = w[::-1]
|
241 |
+
|
242 |
+
self.hw_save = (h, w)
|
243 |
+
self.crop_h_save, self.crop_w_save = croph, cropw
|
244 |
+
self._cnt = 0 # avoiding overflow
|
245 |
+
_h, _w = self.hw_save[0][self._cnt], self.hw_save[1][self._cnt]
|
246 |
+
self._cnt += 1
|
247 |
+
|
248 |
+
return T.CropTransform(_w, _h, self.crop_w_save, self.crop_h_save)
|
249 |
+
|
250 |
+
def get_crop_size(self, image_size):
|
251 |
+
"""
|
252 |
+
Args:
|
253 |
+
image_size (tuple): height, width
|
254 |
+
Returns:
|
255 |
+
crop_size (tuple): height, width in absolute pixels
|
256 |
+
"""
|
257 |
+
h, w = image_size
|
258 |
+
if self.crop_type == "relative":
|
259 |
+
ch, cw = self.crop_size
|
260 |
+
return int(h * ch + 0.5), int(w * cw + 0.5)
|
261 |
+
elif self.crop_type == "relative_range":
|
262 |
+
crop_size = np.asarray(self.crop_size, dtype=float)
|
263 |
+
ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
|
264 |
+
return int(h * ch + 0.5), int(w * cw + 0.5)
|
265 |
+
elif self.crop_type == "absolute":
|
266 |
+
return (min(self.crop_size[0], h), min(self.crop_size[1], w))
|
267 |
+
elif self.crop_type == "absolute_range":
|
268 |
+
assert self.crop_size[0] <= self.crop_size[1]
|
269 |
+
ch = np.random.randint(min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1)
|
270 |
+
cw = np.random.randint(min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1)
|
271 |
+
return ch, cw
|
272 |
+
else:
|
273 |
+
raise NotImplementedError("Unknown crop type {}".format(self.crop_type))
|
274 |
+
|
275 |
+
|
276 |
+
class FixedSizeCropClip(T.Augmentation):
|
277 |
+
"""
|
278 |
+
If `crop_size` is smaller than the input image size, then it uses a random crop of
|
279 |
+
the crop size. If `crop_size` is larger than the input image size, then it pads
|
280 |
+
the right and the bottom of the image to the crop size if `pad` is True, otherwise
|
281 |
+
it returns the smaller image.
|
282 |
+
"""
|
283 |
+
|
284 |
+
def __init__(self, crop_size: Tuple[int], pad: bool = True, pad_value: float = 128.0, clip_frame_cnt=1):
|
285 |
+
"""
|
286 |
+
Args:
|
287 |
+
crop_size: target image (height, width).
|
288 |
+
pad: if True, will pad images smaller than `crop_size` up to `crop_size`
|
289 |
+
pad_value: the padding value.
|
290 |
+
"""
|
291 |
+
super().__init__()
|
292 |
+
self._init(locals())
|
293 |
+
self._cnt = 0
|
294 |
+
|
295 |
+
def _get_crop(self, image: np.ndarray):
|
296 |
+
# Compute the image scale and scaled size.
|
297 |
+
input_size = image.shape[:2]
|
298 |
+
output_size = self.crop_size
|
299 |
+
|
300 |
+
# Add random crop if the image is scaled up.
|
301 |
+
max_offset = np.subtract(input_size, output_size)
|
302 |
+
max_offset = np.maximum(max_offset, 0)
|
303 |
+
|
304 |
+
if self._cnt % self.clip_frame_cnt == 0:
|
305 |
+
offset = np.multiply(max_offset, np.random.uniform(0.0, 1.0))
|
306 |
+
offset = np.round(offset).astype(int)
|
307 |
+
self.offset_save = offset
|
308 |
+
self._cnt = 0 # avoiding overflow
|
309 |
+
self._cnt += 1
|
310 |
+
offset = self.offset_save
|
311 |
+
return CropTransform(
|
312 |
+
offset[1], offset[0], output_size[1], output_size[0], input_size[1], input_size[0]
|
313 |
+
)
|
314 |
+
|
315 |
+
def _get_pad(self, image: np.ndarray):
|
316 |
+
# Compute the image scale and scaled size.
|
317 |
+
input_size = image.shape[:2]
|
318 |
+
output_size = self.crop_size
|
319 |
+
|
320 |
+
# Add padding if the image is scaled down.
|
321 |
+
pad_size = np.subtract(output_size, input_size)
|
322 |
+
pad_size = np.maximum(pad_size, 0)
|
323 |
+
original_size = np.minimum(input_size, output_size)
|
324 |
+
return PadTransform(
|
325 |
+
0, 0, pad_size[1], pad_size[0], original_size[1], original_size[0], self.pad_value
|
326 |
+
)
|
327 |
+
|
328 |
+
def get_transform(self, image: np.ndarray):
|
329 |
+
transforms = [self._get_crop(image)]
|
330 |
+
if self.pad:
|
331 |
+
transforms.append(self._get_pad(image))
|
332 |
+
return TransformList(transforms)
|
333 |
+
|
334 |
+
|
335 |
+
class ResizeShortestEdgeClip(T.Augmentation):
|
336 |
+
"""
|
337 |
+
Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
|
338 |
+
If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
|
339 |
+
"""
|
340 |
+
|
341 |
+
def __init__(
|
342 |
+
self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR, clip_frame_cnt=1
|
343 |
+
):
|
344 |
+
"""
|
345 |
+
Args:
|
346 |
+
short_edge_length (list[int]): If ``sample_style=="range"``,
|
347 |
+
a [min, max] interval from which to sample the shortest edge length.
|
348 |
+
If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
|
349 |
+
max_size (int): maximum allowed longest edge length.
|
350 |
+
sample_style (str): either "range" or "choice".
|
351 |
+
"""
|
352 |
+
super().__init__()
|
353 |
+
assert sample_style in ["range", "choice", "range_by_clip", "choice_by_clip"], sample_style
|
354 |
+
|
355 |
+
self.is_range = ("range" in sample_style)
|
356 |
+
if isinstance(short_edge_length, int):
|
357 |
+
short_edge_length = (short_edge_length, short_edge_length)
|
358 |
+
if self.is_range:
|
359 |
+
assert len(short_edge_length) == 2, (
|
360 |
+
"short_edge_length must be two values using 'range' sample style."
|
361 |
+
f" Got {short_edge_length}!"
|
362 |
+
)
|
363 |
+
self._cnt = 0
|
364 |
+
self._init(locals())
|
365 |
+
|
366 |
+
def get_transform(self, image):
|
367 |
+
if self._cnt % self.clip_frame_cnt == 0:
|
368 |
+
if self.is_range:
|
369 |
+
self.size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
|
370 |
+
else:
|
371 |
+
self.size = np.random.choice(self.short_edge_length)
|
372 |
+
self._cnt = 0 # avoiding overflow
|
373 |
+
|
374 |
+
if self.size == 0:
|
375 |
+
return NoOpTransform()
|
376 |
+
self._cnt += 1
|
377 |
+
|
378 |
+
h, w = image.shape[:2]
|
379 |
+
|
380 |
+
scale = self.size * 1.0 / min(h, w)
|
381 |
+
if h < w:
|
382 |
+
newh, neww = self.size, scale * w
|
383 |
+
else:
|
384 |
+
newh, neww = scale * h, self.size
|
385 |
+
if max(newh, neww) > self.max_size:
|
386 |
+
scale = self.max_size * 1.0 / max(newh, neww)
|
387 |
+
newh = newh * scale
|
388 |
+
neww = neww * scale
|
389 |
+
neww = int(neww + 0.5)
|
390 |
+
newh = int(newh + 0.5)
|
391 |
+
return T.ResizeTransform(h, w, newh, neww, self.interp)
|
392 |
+
|
393 |
+
|
394 |
+
class RandomFlipClip(T.Augmentation):
|
395 |
+
"""
|
396 |
+
Flip the image horizontally or vertically with the given probability.
|
397 |
+
"""
|
398 |
+
|
399 |
+
def __init__(self, prob=0.5, *, horizontal=True, vertical=False, clip_frame_cnt=1):
|
400 |
+
"""
|
401 |
+
Args:
|
402 |
+
prob (float): probability of flip.
|
403 |
+
horizontal (boolean): whether to apply horizontal flipping
|
404 |
+
vertical (boolean): whether to apply vertical flipping
|
405 |
+
"""
|
406 |
+
super().__init__()
|
407 |
+
|
408 |
+
if horizontal and vertical:
|
409 |
+
raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
|
410 |
+
if not horizontal and not vertical:
|
411 |
+
raise ValueError("At least one of horiz or vert has to be True!")
|
412 |
+
self._cnt = 0
|
413 |
+
|
414 |
+
self._init(locals())
|
415 |
+
|
416 |
+
def get_transform(self, image):
|
417 |
+
if self._cnt % self.clip_frame_cnt == 0:
|
418 |
+
self.do = self._rand_range() < self.prob
|
419 |
+
self._cnt = 0 # avoiding overflow
|
420 |
+
self._cnt += 1
|
421 |
+
|
422 |
+
h, w = image.shape[:2]
|
423 |
+
|
424 |
+
if self.do:
|
425 |
+
if self.horizontal:
|
426 |
+
return HFlipTransform(w)
|
427 |
+
elif self.vertical:
|
428 |
+
return VFlipTransform(h)
|
429 |
+
else:
|
430 |
+
return NoOpTransform()
|
431 |
+
|
432 |
+
|
433 |
+
def build_augmentation(cfg, is_train):
|
434 |
+
logger = logging.getLogger(__name__)
|
435 |
+
aug_list = []
|
436 |
+
if is_train:
|
437 |
+
use_lsj = cfg.INPUT.LSJ_AUG.ENABLED
|
438 |
+
if use_lsj:
|
439 |
+
image_size = cfg.INPUT.LSJ_AUG.IMAGE_SIZE
|
440 |
+
min_scale = cfg.INPUT.LSJ_AUG.MIN_SCALE
|
441 |
+
max_scale = cfg.INPUT.LSJ_AUG.MAX_SCALE
|
442 |
+
|
443 |
+
if cfg.INPUT.RANDOM_FLIP != "none":
|
444 |
+
if cfg.INPUT.RANDOM_FLIP == "flip_by_clip":
|
445 |
+
flip_clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM
|
446 |
+
else:
|
447 |
+
flip_clip_frame_cnt = 1
|
448 |
+
|
449 |
+
aug_list.append(
|
450 |
+
# NOTE using RandomFlip modified for the support of flip maintenance
|
451 |
+
RandomFlipClip(
|
452 |
+
horizontal=(cfg.INPUT.RANDOM_FLIP == "horizontal") or (cfg.INPUT.RANDOM_FLIP == "flip_by_clip"),
|
453 |
+
vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
|
454 |
+
clip_frame_cnt=flip_clip_frame_cnt,
|
455 |
+
)
|
456 |
+
)
|
457 |
+
|
458 |
+
aug_list.extend([
|
459 |
+
T.ResizeScale(
|
460 |
+
min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
|
461 |
+
),
|
462 |
+
T.FixedSizeCrop(crop_size=(image_size, image_size)),
|
463 |
+
])
|
464 |
+
|
465 |
+
else:
|
466 |
+
min_size = cfg.INPUT.MIN_SIZE_TRAIN
|
467 |
+
max_size = cfg.INPUT.MAX_SIZE_TRAIN
|
468 |
+
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
|
469 |
+
clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM if "by_clip" in cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING else 1
|
470 |
+
|
471 |
+
# Crop
|
472 |
+
if cfg.INPUT.CROP.ENABLED:
|
473 |
+
crop_aug = RandomApplyClip(
|
474 |
+
T.AugmentationList([
|
475 |
+
ResizeShortestEdgeClip([400, 500, 600], 1333, sample_style, clip_frame_cnt=clip_frame_cnt),
|
476 |
+
RandomCropClip(cfg.INPUT.PSEUDO.CROP.TYPE, cfg.INPUT.PSEUDO.CROP.SIZE, clip_frame_cnt=clip_frame_cnt)
|
477 |
+
]),
|
478 |
+
clip_frame_cnt=clip_frame_cnt
|
479 |
+
)
|
480 |
+
aug_list.append(crop_aug)
|
481 |
+
|
482 |
+
# Resize
|
483 |
+
aug_list.append(ResizeShortestEdgeClip(min_size, max_size, sample_style, clip_frame_cnt=clip_frame_cnt))
|
484 |
+
|
485 |
+
# Flip
|
486 |
+
if cfg.INPUT.RANDOM_FLIP != "none":
|
487 |
+
if cfg.INPUT.RANDOM_FLIP == "flip_by_clip":
|
488 |
+
flip_clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM
|
489 |
+
else:
|
490 |
+
flip_clip_frame_cnt = 1
|
491 |
+
|
492 |
+
aug_list.append(
|
493 |
+
# NOTE using RandomFlip modified for the support of flip maintenance
|
494 |
+
RandomFlipClip(
|
495 |
+
horizontal=(cfg.INPUT.RANDOM_FLIP == "horizontal") or (cfg.INPUT.RANDOM_FLIP == "flip_by_clip"),
|
496 |
+
vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
|
497 |
+
clip_frame_cnt=flip_clip_frame_cnt,
|
498 |
+
)
|
499 |
+
)
|
500 |
+
|
501 |
+
# Additional augmentations : brightness, contrast, saturation, rotation
|
502 |
+
augmentations = cfg.INPUT.AUGMENTATIONS
|
503 |
+
if "brightness" in augmentations:
|
504 |
+
aug_list.append(T.RandomBrightness(0.9, 1.1))
|
505 |
+
if "contrast" in augmentations:
|
506 |
+
aug_list.append(T.RandomContrast(0.9, 1.1))
|
507 |
+
if "saturation" in augmentations:
|
508 |
+
aug_list.append(T.RandomSaturation(0.9, 1.1))
|
509 |
+
if "rotation" in augmentations:
|
510 |
+
aug_list.append(
|
511 |
+
T.RandomRotation(
|
512 |
+
[-15, 15], expand=False, center=[(0.4, 0.4), (0.6, 0.6)], sample_style="range"
|
513 |
+
)
|
514 |
+
)
|
515 |
+
else:
|
516 |
+
# Resize
|
517 |
+
min_size = cfg.INPUT.MIN_SIZE_TEST
|
518 |
+
max_size = cfg.INPUT.MAX_SIZE_TEST
|
519 |
+
sample_style = "choice"
|
520 |
+
aug_list.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
|
521 |
+
|
522 |
+
return aug_list
|
523 |
+
|
524 |
+
|
525 |
+
def build_pseudo_augmentation(cfg, is_train):
|
526 |
+
logger = logging.getLogger(__name__)
|
527 |
+
aug_list = []
|
528 |
+
if is_train:
|
529 |
+
use_lsj = cfg.INPUT.LSJ_AUG.ENABLED
|
530 |
+
if use_lsj:
|
531 |
+
image_size = cfg.INPUT.LSJ_AUG.IMAGE_SIZE
|
532 |
+
min_scale = cfg.INPUT.LSJ_AUG.MIN_SCALE
|
533 |
+
max_scale = cfg.INPUT.LSJ_AUG.MAX_SCALE
|
534 |
+
|
535 |
+
if cfg.INPUT.RANDOM_FLIP != "none":
|
536 |
+
if cfg.INPUT.RANDOM_FLIP == "flip_by_clip":
|
537 |
+
clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM
|
538 |
+
else:
|
539 |
+
clip_frame_cnt = 1
|
540 |
+
|
541 |
+
aug_list.append(
|
542 |
+
# NOTE using RandomFlip modified for the support of flip maintenance
|
543 |
+
RandomFlipClip(
|
544 |
+
horizontal=(cfg.INPUT.RANDOM_FLIP == "horizontal") or (cfg.INPUT.RANDOM_FLIP == "flip_by_clip"),
|
545 |
+
vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
|
546 |
+
clip_frame_cnt=clip_frame_cnt,
|
547 |
+
)
|
548 |
+
)
|
549 |
+
|
550 |
+
# Additional augmentations : brightness, contrast, saturation, rotation
|
551 |
+
augmentations = cfg.INPUT.PSEUDO.AUGMENTATIONS
|
552 |
+
if "brightness" in augmentations:
|
553 |
+
aug_list.append(T.RandomBrightness(0.9, 1.1))
|
554 |
+
if "contrast" in augmentations:
|
555 |
+
aug_list.append(T.RandomContrast(0.9, 1.1))
|
556 |
+
if "saturation" in augmentations:
|
557 |
+
aug_list.append(T.RandomSaturation(0.9, 1.1))
|
558 |
+
if "rotation" in augmentations:
|
559 |
+
aug_list.append(
|
560 |
+
RandomRotationClip(
|
561 |
+
[-15, 15], expand=False, center=[(0.4, 0.4), (0.6, 0.6)], clip_frame_cnt=clip_frame_cnt,
|
562 |
+
)
|
563 |
+
)
|
564 |
+
|
565 |
+
aug_list.extend([
|
566 |
+
ResizeScaleClip(
|
567 |
+
min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size,
|
568 |
+
clip_frame_cnt=clip_frame_cnt,
|
569 |
+
),
|
570 |
+
FixedSizeCropClip(crop_size=(image_size, image_size), clip_frame_cnt=clip_frame_cnt),
|
571 |
+
])
|
572 |
+
else:
|
573 |
+
min_size = cfg.INPUT.PSEUDO.MIN_SIZE_TRAIN
|
574 |
+
max_size = cfg.INPUT.PSEUDO.MAX_SIZE_TRAIN
|
575 |
+
sample_style = cfg.INPUT.PSEUDO.MIN_SIZE_TRAIN_SAMPLING
|
576 |
+
clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM
|
577 |
+
|
578 |
+
# Crop
|
579 |
+
if cfg.INPUT.PSEUDO.CROP.ENABLED:
|
580 |
+
crop_aug = RandomApplyClip(
|
581 |
+
T.AugmentationList([
|
582 |
+
ResizeShortestEdgeClip([400, 500, 600], 1333, sample_style, clip_frame_cnt=clip_frame_cnt),
|
583 |
+
RandomCropClip(cfg.INPUT.PSEUDO.CROP.TYPE, cfg.INPUT.PSEUDO.CROP.SIZE, clip_frame_cnt=clip_frame_cnt)
|
584 |
+
]),
|
585 |
+
clip_frame_cnt=clip_frame_cnt
|
586 |
+
)
|
587 |
+
aug_list.append(crop_aug)
|
588 |
+
|
589 |
+
# Resize
|
590 |
+
aug_list.append(ResizeShortestEdgeClip(min_size, max_size, sample_style, clip_frame_cnt=clip_frame_cnt))
|
591 |
+
|
592 |
+
# Flip
|
593 |
+
aug_list.append(
|
594 |
+
# NOTE using RandomFlip modified for the support of flip maintenance
|
595 |
+
RandomFlipClip(
|
596 |
+
horizontal=(cfg.INPUT.RANDOM_FLIP == "horizontal") or (cfg.INPUT.RANDOM_FLIP == "flip_by_clip"),
|
597 |
+
vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
|
598 |
+
clip_frame_cnt=clip_frame_cnt,
|
599 |
+
)
|
600 |
+
)
|
601 |
+
|
602 |
+
# Additional augmentations : brightness, contrast, saturation, rotation
|
603 |
+
augmentations = cfg.INPUT.PSEUDO.AUGMENTATIONS
|
604 |
+
if "brightness" in augmentations:
|
605 |
+
aug_list.append(T.RandomBrightness(0.9, 1.1))
|
606 |
+
if "contrast" in augmentations:
|
607 |
+
aug_list.append(T.RandomContrast(0.9, 1.1))
|
608 |
+
if "saturation" in augmentations:
|
609 |
+
aug_list.append(T.RandomSaturation(0.9, 1.1))
|
610 |
+
if "rotation" in augmentations:
|
611 |
+
aug_list.append(
|
612 |
+
RandomRotationClip(
|
613 |
+
[-15, 15], expand=False, center=[(0.4, 0.4), (0.6, 0.6)], clip_frame_cnt=clip_frame_cnt,
|
614 |
+
)
|
615 |
+
)
|
616 |
+
else:
|
617 |
+
# Resize
|
618 |
+
min_size = cfg.INPUT.MIN_SIZE_TEST
|
619 |
+
max_size = cfg.INPUT.MAX_SIZE_TEST
|
620 |
+
sample_style = "choice"
|
621 |
+
aug_list.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
|
622 |
+
|
623 |
+
return aug_list
|
avism/data/avis_eval.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import copy
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
import contextlib
|
7 |
+
from collections import OrderedDict
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
|
12 |
+
import pycocotools.mask as mask_util
|
13 |
+
from multiprocessing import freeze_support
|
14 |
+
from fvcore.common.file_io import PathManager
|
15 |
+
from detectron2.data import MetadataCatalog
|
16 |
+
from detectron2.utils.file_io import PathManager
|
17 |
+
from detectron2.evaluation import DatasetEvaluator
|
18 |
+
|
19 |
+
from .datasets.avis_api.avos import AVOS
|
20 |
+
|
21 |
+
import sys
|
22 |
+
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
|
23 |
+
import aviseval
|
24 |
+
|
25 |
+
|
26 |
+
def eval_track(out_dir, gt_file):
|
27 |
+
freeze_support()
|
28 |
+
|
29 |
+
# Command line interface:
|
30 |
+
default_eval_config = aviseval.Evaluator.get_default_eval_config()
|
31 |
+
default_dataset_config = aviseval.datasets.AVIS.get_default_dataset_config()
|
32 |
+
default_dataset_config['TRACKERS_FOLDER'] = out_dir
|
33 |
+
default_dataset_config['GT_File'] = gt_file
|
34 |
+
default_metrics_config = {'METRICS': ['TrackMAP', 'HOTA']} # 'CLEAR', 'Identity'
|
35 |
+
config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs
|
36 |
+
eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()}
|
37 |
+
dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()}
|
38 |
+
metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()}
|
39 |
+
|
40 |
+
# Run code
|
41 |
+
evaluator = aviseval.Evaluator(eval_config)
|
42 |
+
dataset_list = [aviseval.datasets.AVIS(dataset_config)]
|
43 |
+
metrics_list = []
|
44 |
+
for metric in [aviseval.metrics.TrackMAP, aviseval.metrics.HOTA]:
|
45 |
+
if metric.get_name() in metrics_config['METRICS']:
|
46 |
+
if metric == aviseval.metrics.TrackMAP:
|
47 |
+
default_track_map_config = metric.get_default_metric_config()
|
48 |
+
default_track_map_config['USE_TIME_RANGES'] = False
|
49 |
+
default_track_map_config['AREA_RANGES'] = [[0 ** 2, 128 ** 2],
|
50 |
+
[128 ** 2, 256 ** 2],
|
51 |
+
[256 ** 2, 1e5 ** 2]]
|
52 |
+
metrics_list.append(metric(default_track_map_config))
|
53 |
+
else:
|
54 |
+
metrics_list.append(metric())
|
55 |
+
if len(metrics_list) == 0:
|
56 |
+
raise Exception('No metrics selected for evaluation')
|
57 |
+
|
58 |
+
output_res, output_msg = evaluator.evaluate(dataset_list, metrics_list)
|
59 |
+
|
60 |
+
return output_res
|
61 |
+
|
62 |
+
|
63 |
+
def instances_to_coco_json_video(inputs, outputs):
|
64 |
+
"""
|
65 |
+
Dump an "Instances" object to a COCO-format json that's used for evaluation.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
instances (Instances):
|
69 |
+
video_id (int): the image id
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
list[dict]: list of json annotations in COCO format.
|
73 |
+
"""
|
74 |
+
assert len(inputs) == 1, "More than one inputs are loaded for inference!"
|
75 |
+
|
76 |
+
video_id = inputs[0]["video_id"]
|
77 |
+
video_length = inputs[0]["length"]
|
78 |
+
|
79 |
+
scores = outputs["pred_scores"]
|
80 |
+
labels = outputs["pred_labels"]
|
81 |
+
masks = outputs["pred_masks"]
|
82 |
+
|
83 |
+
avis_results = []
|
84 |
+
for instance_id, (s, l, m) in enumerate(zip(scores, labels, masks)):
|
85 |
+
segms = [
|
86 |
+
mask_util.encode(np.array(_mask[:, :, None], order="F", dtype="uint8"))[0]
|
87 |
+
for _mask in m
|
88 |
+
]
|
89 |
+
for rle in segms:
|
90 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
91 |
+
|
92 |
+
res = {
|
93 |
+
"video_id": video_id,
|
94 |
+
"score": s,
|
95 |
+
"category_id": l,
|
96 |
+
"segmentations": segms,
|
97 |
+
}
|
98 |
+
avis_results.append(res)
|
99 |
+
|
100 |
+
return avis_results
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
class AVISEvaluator(DatasetEvaluator):
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
dataset_name,
|
108 |
+
tasks=None,
|
109 |
+
distributed=True,
|
110 |
+
output_dir=None,
|
111 |
+
*,
|
112 |
+
use_fast_impl=True,
|
113 |
+
):
|
114 |
+
self._logger = logging.getLogger(__name__)
|
115 |
+
self._distributed = distributed
|
116 |
+
self._output_dir = output_dir
|
117 |
+
self._use_fast_impl = use_fast_impl
|
118 |
+
|
119 |
+
self._cpu_device = torch.device("cpu")
|
120 |
+
|
121 |
+
self.dataset_name = dataset_name
|
122 |
+
self._metadata = MetadataCatalog.get(dataset_name)
|
123 |
+
|
124 |
+
json_file = PathManager.get_local_path(self._metadata.json_file)
|
125 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
126 |
+
self._avis_api = AVOS(json_file)
|
127 |
+
|
128 |
+
self._do_evaluation = "annotations" in self._avis_api.dataset
|
129 |
+
|
130 |
+
|
131 |
+
def reset(self):
|
132 |
+
self._predictions = []
|
133 |
+
|
134 |
+
|
135 |
+
def process(self, inputs, outputs):
|
136 |
+
"""
|
137 |
+
Args:
|
138 |
+
inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
|
139 |
+
It is a list of dict. Each dict corresponds to an image and
|
140 |
+
contains keys like "height", "width", "file_name", "image_id".
|
141 |
+
outputs: the outputs of a COCO model. It is a list of dicts with key
|
142 |
+
"instances" that contains :class:`Instances`.
|
143 |
+
"""
|
144 |
+
prediction = instances_to_coco_json_video(inputs, outputs)
|
145 |
+
self._predictions.extend(prediction)
|
146 |
+
|
147 |
+
|
148 |
+
def evaluate(self):
|
149 |
+
"""
|
150 |
+
Args:
|
151 |
+
img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
|
152 |
+
"""
|
153 |
+
|
154 |
+
predictions = self._predictions
|
155 |
+
|
156 |
+
self._results = OrderedDict()
|
157 |
+
self._eval_predictions(predictions)
|
158 |
+
# Copy so the caller can do whatever with results
|
159 |
+
return copy.deepcopy(self._results)
|
160 |
+
|
161 |
+
|
162 |
+
def _eval_predictions(self, predictions):
|
163 |
+
"""
|
164 |
+
Evaluate predictions. Fill self._results with the metrics of the tasks.
|
165 |
+
"""
|
166 |
+
self._logger.info("Preparing results for AVIS format ...")
|
167 |
+
|
168 |
+
# unmap the category ids for COCO
|
169 |
+
if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
|
170 |
+
dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
|
171 |
+
|
172 |
+
all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
|
173 |
+
num_classes = len(all_contiguous_ids)
|
174 |
+
assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
|
175 |
+
|
176 |
+
reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
|
177 |
+
for result in predictions:
|
178 |
+
category_id = result["category_id"]
|
179 |
+
assert category_id < num_classes, (
|
180 |
+
f"A prediction has class={category_id}, "
|
181 |
+
f"but the dataset only has {num_classes} classes and "
|
182 |
+
f"predicted class id should be in [0, {num_classes - 1}]."
|
183 |
+
)
|
184 |
+
result["category_id"] = reverse_id_mapping[category_id]
|
185 |
+
|
186 |
+
o_d = None
|
187 |
+
if self._output_dir:
|
188 |
+
o_d = os.path.join(self._output_dir, "results")
|
189 |
+
os.makedirs(os.path.join(o_d, "model_final"), exist_ok=True)
|
190 |
+
file_path = os.path.join(o_d, "model_final", "results.json")
|
191 |
+
|
192 |
+
self._logger.info("Saving results to {}".format(file_path))
|
193 |
+
with PathManager.open(file_path, "w") as f:
|
194 |
+
f.write(json.dumps(predictions))
|
195 |
+
f.flush()
|
196 |
+
|
197 |
+
if not self._do_evaluation:
|
198 |
+
self._logger.info("Annotations are not available for evaluation.")
|
199 |
+
return
|
200 |
+
|
201 |
+
assert o_d != None
|
202 |
+
output_res = eval_track(o_d, "test.json")
|
203 |
+
self._results["segm"] = output_res['AVIS']['model_final']
|
avism/data/aviseval/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .eval import Evaluator
|
2 |
+
from . import datasets
|
3 |
+
from . import metrics
|
4 |
+
from . import plotting
|
5 |
+
from . import utils
|
avism/data/aviseval/_timing.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import wraps
|
2 |
+
from time import perf_counter
|
3 |
+
import inspect
|
4 |
+
|
5 |
+
DO_TIMING = False
|
6 |
+
DISPLAY_LESS_PROGRESS = False
|
7 |
+
timer_dict = {}
|
8 |
+
counter = 0
|
9 |
+
|
10 |
+
|
11 |
+
def time(f):
|
12 |
+
@wraps(f)
|
13 |
+
def wrap(*args, **kw):
|
14 |
+
if DO_TIMING:
|
15 |
+
# Run function with timing
|
16 |
+
ts = perf_counter()
|
17 |
+
result = f(*args, **kw)
|
18 |
+
te = perf_counter()
|
19 |
+
tt = te-ts
|
20 |
+
|
21 |
+
# Get function name
|
22 |
+
arg_names = inspect.getfullargspec(f)[0]
|
23 |
+
if arg_names[0] == 'self' and DISPLAY_LESS_PROGRESS:
|
24 |
+
return result
|
25 |
+
elif arg_names[0] == 'self':
|
26 |
+
method_name = type(args[0]).__name__ + '.' + f.__name__
|
27 |
+
else:
|
28 |
+
method_name = f.__name__
|
29 |
+
|
30 |
+
# Record accumulative time in each function for analysis
|
31 |
+
if method_name in timer_dict.keys():
|
32 |
+
timer_dict[method_name] += tt
|
33 |
+
else:
|
34 |
+
timer_dict[method_name] = tt
|
35 |
+
|
36 |
+
# If code is finished, display timing summary
|
37 |
+
if method_name == "Evaluator.evaluate":
|
38 |
+
print("")
|
39 |
+
print("Timing analysis:")
|
40 |
+
for key, value in timer_dict.items():
|
41 |
+
print('%-70s %2.4f sec' % (key, value))
|
42 |
+
else:
|
43 |
+
# Get function argument values for printing special arguments of interest
|
44 |
+
arg_titles = ['tracker', 'seq', 'cls']
|
45 |
+
arg_vals = []
|
46 |
+
for i, a in enumerate(arg_names):
|
47 |
+
if a in arg_titles:
|
48 |
+
arg_vals.append(args[i])
|
49 |
+
arg_text = '(' + ', '.join(arg_vals) + ')'
|
50 |
+
|
51 |
+
# Display methods and functions with different indentation.
|
52 |
+
if arg_names[0] == 'self':
|
53 |
+
print('%-74s %2.4f sec' % (' '*4 + method_name + arg_text, tt))
|
54 |
+
elif arg_names[0] == 'test':
|
55 |
+
pass
|
56 |
+
else:
|
57 |
+
global counter
|
58 |
+
counter += 1
|
59 |
+
print('%i %-70s %2.4f sec' % (counter, method_name + arg_text, tt))
|
60 |
+
|
61 |
+
return result
|
62 |
+
else:
|
63 |
+
# If config["TIME_PROGRESS"] is false, or config["USE_PARALLEL"] is true, run functions normally without timing.
|
64 |
+
return f(*args, **kw)
|
65 |
+
return wrap
|
avism/data/aviseval/datasets/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .avis import AVIS
|
avism/data/aviseval/datasets/_base_dataset.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import io
|
3 |
+
import zipfile
|
4 |
+
import os
|
5 |
+
import traceback
|
6 |
+
import numpy as np
|
7 |
+
from copy import deepcopy
|
8 |
+
from abc import ABC, abstractmethod
|
9 |
+
from .. import _timing
|
10 |
+
from ..utils import TrackEvalException
|
11 |
+
|
12 |
+
|
13 |
+
class _BaseDataset(ABC):
|
14 |
+
@abstractmethod
|
15 |
+
def __init__(self):
|
16 |
+
self.tracker_list = None
|
17 |
+
self.seq_list = None
|
18 |
+
self.class_list = None
|
19 |
+
self.output_fol = None
|
20 |
+
self.output_sub_fol = None
|
21 |
+
self.should_classes_combine = True
|
22 |
+
self.use_super_categories = False
|
23 |
+
|
24 |
+
# Functions to implement:
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
@abstractmethod
|
28 |
+
def get_default_dataset_config():
|
29 |
+
...
|
30 |
+
|
31 |
+
@abstractmethod
|
32 |
+
def _load_raw_file(self, tracker, seq, is_gt):
|
33 |
+
...
|
34 |
+
|
35 |
+
@_timing.time
|
36 |
+
@abstractmethod
|
37 |
+
def get_preprocessed_seq_data(self, raw_data, cls):
|
38 |
+
...
|
39 |
+
|
40 |
+
@abstractmethod
|
41 |
+
def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
|
42 |
+
...
|
43 |
+
|
44 |
+
# Helper functions for all datasets:
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def get_class_name(cls):
|
48 |
+
return cls.__name__
|
49 |
+
|
50 |
+
def get_name(self):
|
51 |
+
return self.get_class_name()
|
52 |
+
|
53 |
+
def get_output_fol(self, tracker):
|
54 |
+
return os.path.join(self.output_fol, tracker, self.output_sub_fol)
|
55 |
+
|
56 |
+
def get_display_name(self, tracker):
|
57 |
+
""" Can be overwritten if the trackers name (in files) is different to how it should be displayed.
|
58 |
+
By default this method just returns the trackers name as is.
|
59 |
+
"""
|
60 |
+
return tracker
|
61 |
+
|
62 |
+
def get_eval_info(self):
|
63 |
+
"""Return info about the dataset needed for the Evaluator"""
|
64 |
+
return self.tracker_list, self.seq_list, self.class_list
|
65 |
+
|
66 |
+
@_timing.time
|
67 |
+
def get_raw_seq_data(self, tracker, seq):
|
68 |
+
""" Loads raw data (tracker and ground-truth) for a single tracker on a single sequence.
|
69 |
+
Raw data includes all of the information needed for both preprocessing and evaluation, for all classes.
|
70 |
+
A later function (get_processed_seq_data) will perform such preprocessing and extract relevant information for
|
71 |
+
the evaluation of each class.
|
72 |
+
|
73 |
+
This returns a dict which contains the fields:
|
74 |
+
[num_timesteps]: integer
|
75 |
+
[gt_ids, tracker_ids, gt_classes, tracker_classes, tracker_confidences]:
|
76 |
+
list (for each timestep) of 1D NDArrays (for each det).
|
77 |
+
[gt_dets, tracker_dets, gt_crowd_ignore_regions]: list (for each timestep) of lists of detections.
|
78 |
+
[similarity_scores]: list (for each timestep) of 2D NDArrays.
|
79 |
+
[gt_extras]: dict (for each extra) of lists (for each timestep) of 1D NDArrays (for each det).
|
80 |
+
|
81 |
+
gt_extras contains dataset specific information used for preprocessing such as occlusion and truncation levels.
|
82 |
+
|
83 |
+
Note that similarities are extracted as part of the dataset and not the metric, because almost all metrics are
|
84 |
+
independent of the exact method of calculating the similarity. However datasets are not (e.g. segmentation
|
85 |
+
masks vs 2D boxes vs 3D boxes).
|
86 |
+
We calculate the similarity before preprocessing because often both preprocessing and evaluation require it and
|
87 |
+
we don't wish to calculate this twice.
|
88 |
+
We calculate similarity between all gt and tracker classes (not just each class individually) to allow for
|
89 |
+
calculation of metrics such as class confusion matrices. Typically the impact of this on performance is low.
|
90 |
+
"""
|
91 |
+
# Load raw data.
|
92 |
+
raw_gt_data = self._load_raw_file(tracker, seq, is_gt=True)
|
93 |
+
raw_tracker_data = self._load_raw_file(tracker, seq, is_gt=False)
|
94 |
+
raw_data = {**raw_tracker_data, **raw_gt_data} # Merges dictionaries
|
95 |
+
|
96 |
+
# Calculate similarities for each timestep.
|
97 |
+
similarity_scores = []
|
98 |
+
for t, (gt_dets_t, tracker_dets_t) in enumerate(zip(raw_data['gt_dets'], raw_data['tracker_dets'])):
|
99 |
+
ious = self._calculate_similarities(gt_dets_t, tracker_dets_t)
|
100 |
+
similarity_scores.append(ious)
|
101 |
+
raw_data['similarity_scores'] = similarity_scores
|
102 |
+
return raw_data
|
103 |
+
|
104 |
+
@staticmethod
|
105 |
+
def _load_simple_text_file(file, time_col=0, id_col=None, remove_negative_ids=False, valid_filter=None,
|
106 |
+
crowd_ignore_filter=None, convert_filter=None, is_zipped=False, zip_file=None,
|
107 |
+
force_delimiters=None):
|
108 |
+
""" Function that loads data which is in a commonly used text file format.
|
109 |
+
Assumes each det is given by one row of a text file.
|
110 |
+
There is no limit to the number or meaning of each column,
|
111 |
+
however one column needs to give the timestep of each det (time_col) which is default col 0.
|
112 |
+
|
113 |
+
The file dialect (deliminator, num cols, etc) is determined automatically.
|
114 |
+
This function automatically separates dets by timestep,
|
115 |
+
and is much faster than alternatives such as np.loadtext or pandas.
|
116 |
+
|
117 |
+
If remove_negative_ids is True and id_col is not None, dets with negative values in id_col are excluded.
|
118 |
+
These are not excluded from ignore data.
|
119 |
+
|
120 |
+
valid_filter can be used to only include certain classes.
|
121 |
+
It is a dict with ints as keys, and lists as values,
|
122 |
+
such that a row is included if "row[key].lower() is in value" for all key/value pairs in the dict.
|
123 |
+
If None, all classes are included.
|
124 |
+
|
125 |
+
crowd_ignore_filter can be used to read crowd_ignore regions separately. It has the same format as valid filter.
|
126 |
+
|
127 |
+
convert_filter can be used to convert value read to another format.
|
128 |
+
This is used most commonly to convert classes given as string to a class id.
|
129 |
+
This is a dict such that the key is the column to convert, and the value is another dict giving the mapping.
|
130 |
+
|
131 |
+
Optionally, input files could be a zip of multiple text files for storage efficiency.
|
132 |
+
|
133 |
+
Returns read_data and ignore_data.
|
134 |
+
Each is a dict (with keys as timesteps as strings) of lists (over dets) of lists (over column values).
|
135 |
+
Note that all data is returned as strings, and must be converted to float/int later if needed.
|
136 |
+
Note that timesteps will not be present in the returned dict keys if there are no dets for them
|
137 |
+
"""
|
138 |
+
|
139 |
+
if remove_negative_ids and id_col is None:
|
140 |
+
raise TrackEvalException('remove_negative_ids is True, but id_col is not given.')
|
141 |
+
if crowd_ignore_filter is None:
|
142 |
+
crowd_ignore_filter = {}
|
143 |
+
if convert_filter is None:
|
144 |
+
convert_filter = {}
|
145 |
+
try:
|
146 |
+
if is_zipped: # Either open file directly or within a zip.
|
147 |
+
if zip_file is None:
|
148 |
+
raise TrackEvalException('is_zipped set to True, but no zip_file is given.')
|
149 |
+
archive = zipfile.ZipFile(os.path.join(zip_file), 'r')
|
150 |
+
fp = io.TextIOWrapper(archive.open(file, 'r'))
|
151 |
+
else:
|
152 |
+
fp = open(file)
|
153 |
+
read_data = {}
|
154 |
+
crowd_ignore_data = {}
|
155 |
+
fp.seek(0, os.SEEK_END)
|
156 |
+
# check if file is empty
|
157 |
+
if fp.tell():
|
158 |
+
fp.seek(0)
|
159 |
+
dialect = csv.Sniffer().sniff(fp.readline(), delimiters=force_delimiters) # Auto determine structure.
|
160 |
+
dialect.skipinitialspace = True # Deal with extra spaces between columns
|
161 |
+
fp.seek(0)
|
162 |
+
reader = csv.reader(fp, dialect)
|
163 |
+
for row in reader:
|
164 |
+
try:
|
165 |
+
# Deal with extra trailing spaces at the end of rows
|
166 |
+
if row[-1] in '':
|
167 |
+
row = row[:-1]
|
168 |
+
timestep = str(int(float(row[time_col])))
|
169 |
+
# Read ignore regions separately.
|
170 |
+
is_ignored = False
|
171 |
+
for ignore_key, ignore_value in crowd_ignore_filter.items():
|
172 |
+
if row[ignore_key].lower() in ignore_value:
|
173 |
+
# Convert values in one column (e.g. string to id)
|
174 |
+
for convert_key, convert_value in convert_filter.items():
|
175 |
+
row[convert_key] = convert_value[row[convert_key].lower()]
|
176 |
+
# Save data separated by timestep.
|
177 |
+
if timestep in crowd_ignore_data.keys():
|
178 |
+
crowd_ignore_data[timestep].append(row)
|
179 |
+
else:
|
180 |
+
crowd_ignore_data[timestep] = [row]
|
181 |
+
is_ignored = True
|
182 |
+
if is_ignored: # if det is an ignore region, it cannot be a normal det.
|
183 |
+
continue
|
184 |
+
# Exclude some dets if not valid.
|
185 |
+
if valid_filter is not None:
|
186 |
+
for key, value in valid_filter.items():
|
187 |
+
if row[key].lower() not in value:
|
188 |
+
continue
|
189 |
+
if remove_negative_ids:
|
190 |
+
if int(float(row[id_col])) < 0:
|
191 |
+
continue
|
192 |
+
# Convert values in one column (e.g. string to id)
|
193 |
+
for convert_key, convert_value in convert_filter.items():
|
194 |
+
row[convert_key] = convert_value[row[convert_key].lower()]
|
195 |
+
# Save data separated by timestep.
|
196 |
+
if timestep in read_data.keys():
|
197 |
+
read_data[timestep].append(row)
|
198 |
+
else:
|
199 |
+
read_data[timestep] = [row]
|
200 |
+
except Exception:
|
201 |
+
exc_str_init = 'In file %s the following line cannot be read correctly: \n' % os.path.basename(
|
202 |
+
file)
|
203 |
+
exc_str = ' '.join([exc_str_init]+row)
|
204 |
+
raise TrackEvalException(exc_str)
|
205 |
+
fp.close()
|
206 |
+
except Exception:
|
207 |
+
print('Error loading file: %s, printing traceback.' % file)
|
208 |
+
traceback.print_exc()
|
209 |
+
raise TrackEvalException(
|
210 |
+
'File %s cannot be read because it is either not present or invalidly formatted' % os.path.basename(
|
211 |
+
file))
|
212 |
+
return read_data, crowd_ignore_data
|
213 |
+
|
214 |
+
@staticmethod
|
215 |
+
def _calculate_mask_ious(masks1, masks2, is_encoded=False, do_ioa=False):
|
216 |
+
""" Calculates the IOU (intersection over union) between two arrays of segmentation masks.
|
217 |
+
If is_encoded a run length encoding with pycocotools is assumed as input format, otherwise an input of numpy
|
218 |
+
arrays of the shape (num_masks, height, width) is assumed and the encoding is performed.
|
219 |
+
If do_ioa (intersection over area) , then calculates the intersection over the area of masks1 - this is commonly
|
220 |
+
used to determine if detections are within crowd ignore region.
|
221 |
+
:param masks1: first set of masks (numpy array of shape (num_masks, height, width) if not encoded,
|
222 |
+
else pycocotools rle encoded format)
|
223 |
+
:param masks2: second set of masks (numpy array of shape (num_masks, height, width) if not encoded,
|
224 |
+
else pycocotools rle encoded format)
|
225 |
+
:param is_encoded: whether the input is in pycocotools rle encoded format
|
226 |
+
:param do_ioa: whether to perform IoA computation
|
227 |
+
:return: the IoU/IoA scores
|
228 |
+
"""
|
229 |
+
|
230 |
+
# Only loaded when run to reduce minimum requirements
|
231 |
+
from pycocotools import mask as mask_utils
|
232 |
+
|
233 |
+
# use pycocotools for run length encoding of masks
|
234 |
+
if not is_encoded:
|
235 |
+
masks1 = mask_utils.encode(np.array(np.transpose(masks1, (1, 2, 0)), order='F'))
|
236 |
+
masks2 = mask_utils.encode(np.array(np.transpose(masks2, (1, 2, 0)), order='F'))
|
237 |
+
|
238 |
+
# use pycocotools for iou computation of rle encoded masks
|
239 |
+
ious = mask_utils.iou(masks1, masks2, [do_ioa]*len(masks2))
|
240 |
+
if len(masks1) == 0 or len(masks2) == 0:
|
241 |
+
ious = np.asarray(ious).reshape(len(masks1), len(masks2))
|
242 |
+
assert (ious >= 0 - np.finfo('float').eps).all()
|
243 |
+
assert (ious <= 1 + np.finfo('float').eps).all()
|
244 |
+
|
245 |
+
return ious
|
246 |
+
|
247 |
+
@staticmethod
|
248 |
+
def _calculate_box_ious(bboxes1, bboxes2, box_format='xywh', do_ioa=False):
|
249 |
+
""" Calculates the IOU (intersection over union) between two arrays of boxes.
|
250 |
+
Allows variable box formats ('xywh' and 'x0y0x1y1').
|
251 |
+
If do_ioa (intersection over area) , then calculates the intersection over the area of boxes1 - this is commonly
|
252 |
+
used to determine if detections are within crowd ignore region.
|
253 |
+
"""
|
254 |
+
if box_format in 'xywh':
|
255 |
+
# layout: (x0, y0, w, h)
|
256 |
+
bboxes1 = deepcopy(bboxes1)
|
257 |
+
bboxes2 = deepcopy(bboxes2)
|
258 |
+
|
259 |
+
bboxes1[:, 2] = bboxes1[:, 0] + bboxes1[:, 2]
|
260 |
+
bboxes1[:, 3] = bboxes1[:, 1] + bboxes1[:, 3]
|
261 |
+
bboxes2[:, 2] = bboxes2[:, 0] + bboxes2[:, 2]
|
262 |
+
bboxes2[:, 3] = bboxes2[:, 1] + bboxes2[:, 3]
|
263 |
+
elif box_format not in 'x0y0x1y1':
|
264 |
+
raise (TrackEvalException('box_format %s is not implemented' % box_format))
|
265 |
+
|
266 |
+
# layout: (x0, y0, x1, y1)
|
267 |
+
min_ = np.minimum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
|
268 |
+
max_ = np.maximum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
|
269 |
+
intersection = np.maximum(min_[..., 2] - max_[..., 0], 0) * np.maximum(min_[..., 3] - max_[..., 1], 0)
|
270 |
+
area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
|
271 |
+
|
272 |
+
if do_ioa:
|
273 |
+
ioas = np.zeros_like(intersection)
|
274 |
+
valid_mask = area1 > 0 + np.finfo('float').eps
|
275 |
+
ioas[valid_mask, :] = intersection[valid_mask, :] / area1[valid_mask][:, np.newaxis]
|
276 |
+
|
277 |
+
return ioas
|
278 |
+
else:
|
279 |
+
area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
|
280 |
+
union = area1[:, np.newaxis] + area2[np.newaxis, :] - intersection
|
281 |
+
intersection[area1 <= 0 + np.finfo('float').eps, :] = 0
|
282 |
+
intersection[:, area2 <= 0 + np.finfo('float').eps] = 0
|
283 |
+
intersection[union <= 0 + np.finfo('float').eps] = 0
|
284 |
+
union[union <= 0 + np.finfo('float').eps] = 1
|
285 |
+
ious = intersection / union
|
286 |
+
return ious
|
287 |
+
|
288 |
+
@staticmethod
|
289 |
+
def _calculate_euclidean_similarity(dets1, dets2, zero_distance=2.0):
|
290 |
+
""" Calculates the euclidean distance between two sets of detections, and then converts this into a similarity
|
291 |
+
measure with values between 0 and 1 using the following formula: sim = max(0, 1 - dist/zero_distance).
|
292 |
+
The default zero_distance of 2.0, corresponds to the default used in MOT15_3D, such that a 0.5 similarity
|
293 |
+
threshold corresponds to a 1m distance threshold for TPs.
|
294 |
+
"""
|
295 |
+
dist = np.linalg.norm(dets1[:, np.newaxis]-dets2[np.newaxis, :], axis=2)
|
296 |
+
sim = np.maximum(0, 1 - dist/zero_distance)
|
297 |
+
return sim
|
298 |
+
|
299 |
+
@staticmethod
|
300 |
+
def _check_unique_ids(data, after_preproc=False):
|
301 |
+
"""Check the requirement that the tracker_ids and gt_ids are unique per timestep"""
|
302 |
+
gt_ids = data['gt_ids']
|
303 |
+
tracker_ids = data['tracker_ids']
|
304 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(gt_ids, tracker_ids)):
|
305 |
+
if len(tracker_ids_t) > 0:
|
306 |
+
unique_ids, counts = np.unique(tracker_ids_t, return_counts=True)
|
307 |
+
if np.max(counts) != 1:
|
308 |
+
duplicate_ids = unique_ids[counts > 1]
|
309 |
+
exc_str_init = 'Tracker predicts the same ID more than once in a single timestep ' \
|
310 |
+
'(seq: %s, frame: %i, ids:' % (data['seq'], t+1)
|
311 |
+
exc_str = ' '.join([exc_str_init] + [str(d) for d in duplicate_ids]) + ')'
|
312 |
+
if after_preproc:
|
313 |
+
exc_str_init += '\n Note that this error occurred after preprocessing (but not before), ' \
|
314 |
+
'so ids may not be as in file, and something seems wrong with preproc.'
|
315 |
+
raise TrackEvalException(exc_str)
|
316 |
+
if len(gt_ids_t) > 0:
|
317 |
+
unique_ids, counts = np.unique(gt_ids_t, return_counts=True)
|
318 |
+
if np.max(counts) != 1:
|
319 |
+
duplicate_ids = unique_ids[counts > 1]
|
320 |
+
exc_str_init = 'Ground-truth has the same ID more than once in a single timestep ' \
|
321 |
+
'(seq: %s, frame: %i, ids:' % (data['seq'], t+1)
|
322 |
+
exc_str = ' '.join([exc_str_init] + [str(d) for d in duplicate_ids]) + ')'
|
323 |
+
if after_preproc:
|
324 |
+
exc_str_init += '\n Note that this error occurred after preprocessing (but not before), ' \
|
325 |
+
'so ids may not be as in file, and something seems wrong with preproc.'
|
326 |
+
raise TrackEvalException(exc_str)
|
avism/data/aviseval/datasets/avis.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import json
|
4 |
+
from ._base_dataset import _BaseDataset
|
5 |
+
from ..utils import TrackEvalException
|
6 |
+
from .. import utils
|
7 |
+
from .. import _timing
|
8 |
+
|
9 |
+
|
10 |
+
class AVIS(_BaseDataset):
|
11 |
+
"""Dataset class for AVIS tracking"""
|
12 |
+
|
13 |
+
@staticmethod
|
14 |
+
def get_default_dataset_config():
|
15 |
+
"""Default class config values"""
|
16 |
+
default_config = {
|
17 |
+
'GT_FOLDER': "./datasets/", # Location of GT data
|
18 |
+
'TRACKERS_FOLDER': "./outputs/avism_R50_IN/inference/",
|
19 |
+
'GT_File': "test.json",
|
20 |
+
# Trackers location
|
21 |
+
'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER)
|
22 |
+
'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder)
|
23 |
+
'CLASSES_TO_EVAL': None, # Classes to eval (if None, all classes)
|
24 |
+
'SPLIT_TO_EVAL': None, # Valid: 'train', 'val', 'train_sub_split'
|
25 |
+
'PRINT_CONFIG': False, # Whether to print current config
|
26 |
+
'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
|
27 |
+
'TRACKER_DISPLAY_NAMES': None, # Names of trackers to display, if None: TRACKERS_TO_EVAL
|
28 |
+
}
|
29 |
+
return default_config
|
30 |
+
|
31 |
+
def __init__(self, config=None):
|
32 |
+
"""Initialise dataset, checking that all required files are present"""
|
33 |
+
super().__init__()
|
34 |
+
# Fill non-given config values with defaults
|
35 |
+
self.config = utils.init_config(config, self.get_default_dataset_config(), self.get_name())
|
36 |
+
self.gt_fol = self.config['GT_FOLDER']
|
37 |
+
self.tracker_fol = self.config['TRACKERS_FOLDER']
|
38 |
+
self.use_super_categories = False
|
39 |
+
self.should_classes_combine = True
|
40 |
+
|
41 |
+
self.output_fol = self.config['OUTPUT_FOLDER']
|
42 |
+
if self.output_fol is None:
|
43 |
+
self.output_fol = self.tracker_fol
|
44 |
+
self.output_sub_fol = self.config['OUTPUT_SUB_FOLDER']
|
45 |
+
|
46 |
+
if not os.path.exists(self.gt_fol):
|
47 |
+
print("GT folder not found: " + self.gt_fol)
|
48 |
+
raise TrackEvalException("GT folder not found: " + os.path.basename(self.gt_fol))
|
49 |
+
gt_dir_files = [self.config['GT_File']]
|
50 |
+
if len(gt_dir_files) != 1:
|
51 |
+
raise TrackEvalException(self.gt_fol + ' does not contain exactly one json file.')
|
52 |
+
|
53 |
+
with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f:
|
54 |
+
self.gt_data = json.load(f)
|
55 |
+
|
56 |
+
# Get classes to eval
|
57 |
+
self.valid_classes = [cls['name'] for cls in self.gt_data['categories']]
|
58 |
+
cls_name_to_cls_id_map = {cls['name']: cls['id'] for cls in self.gt_data['categories']}
|
59 |
+
|
60 |
+
if self.config['CLASSES_TO_EVAL']:
|
61 |
+
self.class_list = [cls.lower() if cls.lower() in self.valid_classes else None
|
62 |
+
for cls in self.config['CLASSES_TO_EVAL']]
|
63 |
+
if not all(self.class_list):
|
64 |
+
raise TrackEvalException('Attempted to evaluate an invalid class. Only classes ' +
|
65 |
+
', '.join(self.valid_classes) + ' are valid.')
|
66 |
+
else:
|
67 |
+
self.class_list = [cls['name'] for cls in self.gt_data['categories']]
|
68 |
+
self.class_name_to_class_id = {k: v for k, v in cls_name_to_cls_id_map.items() if k in self.class_list}
|
69 |
+
|
70 |
+
# Get sequences to eval and check gt files exist
|
71 |
+
self.seq_list = [vid['file_names'][0].split('/')[0] for vid in self.gt_data['videos']]
|
72 |
+
self.seq_name_to_seq_id = {vid['file_names'][0].split('/')[0]: vid['id'] for vid in self.gt_data['videos']}
|
73 |
+
self.seq_lengths = {vid['id']: len(vid['file_names']) for vid in self.gt_data['videos']}
|
74 |
+
|
75 |
+
# encode masks and compute track areas
|
76 |
+
self._prepare_gt_annotations()
|
77 |
+
|
78 |
+
# Get trackers to eval
|
79 |
+
if self.config['TRACKERS_TO_EVAL'] is None:
|
80 |
+
self.tracker_list = os.listdir(self.tracker_fol)
|
81 |
+
else:
|
82 |
+
self.tracker_list = self.config['TRACKERS_TO_EVAL']
|
83 |
+
|
84 |
+
if self.config['TRACKER_DISPLAY_NAMES'] is None:
|
85 |
+
self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list))
|
86 |
+
elif (self.config['TRACKERS_TO_EVAL'] is not None) and (
|
87 |
+
len(self.config['TRACKER_DISPLAY_NAMES']) == len(self.tracker_list)):
|
88 |
+
self.tracker_to_disp = dict(zip(self.tracker_list, self.config['TRACKER_DISPLAY_NAMES']))
|
89 |
+
else:
|
90 |
+
raise TrackEvalException('List of tracker files and tracker display names do not match.')
|
91 |
+
|
92 |
+
# counter for globally unique track IDs
|
93 |
+
self.global_tid_counter = 0
|
94 |
+
|
95 |
+
self.tracker_data = dict()
|
96 |
+
for tracker in self.tracker_list:
|
97 |
+
tracker_dir_path = os.path.join(self.tracker_fol, tracker)
|
98 |
+
tr_dir_files = [file for file in os.listdir(tracker_dir_path) if file.endswith('.json')]
|
99 |
+
if len(tr_dir_files) != 1:
|
100 |
+
raise TrackEvalException(tracker_dir_path + ' does not contain exactly one json file.')
|
101 |
+
|
102 |
+
with open(os.path.join(tracker_dir_path, tr_dir_files[0])) as f:
|
103 |
+
curr_data = json.load(f)
|
104 |
+
|
105 |
+
self.tracker_data[tracker] = curr_data
|
106 |
+
|
107 |
+
def get_display_name(self, tracker):
|
108 |
+
return self.tracker_to_disp[tracker]
|
109 |
+
|
110 |
+
def _load_raw_file(self, tracker, seq, is_gt):
|
111 |
+
"""Load a file (gt or tracker) in the YouTubeVIS format
|
112 |
+
If is_gt, this returns a dict which contains the fields:
|
113 |
+
[gt_ids, gt_classes] : list (for each timestep) of 1D NDArrays (for each det).
|
114 |
+
[gt_dets]: list (for each timestep) of lists of detections.
|
115 |
+
[classes_to_gt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
|
116 |
+
keys and corresponding segmentations as values) for each track
|
117 |
+
[classes_to_gt_track_ids, classes_to_gt_track_areas, classes_to_gt_track_iscrowd]: dictionary with class values
|
118 |
+
as keys and lists (for each track) as values
|
119 |
+
|
120 |
+
if not is_gt, this returns a dict which contains the fields:
|
121 |
+
[tracker_ids, tracker_classes, tracker_confidences] : list (for each timestep) of 1D NDArrays (for each det).
|
122 |
+
[tracker_dets]: list (for each timestep) of lists of detections.
|
123 |
+
[classes_to_dt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
|
124 |
+
keys and corresponding segmentations as values) for each track
|
125 |
+
[classes_to_dt_track_ids, classes_to_dt_track_areas]: dictionary with class values as keys and lists as values
|
126 |
+
[classes_to_dt_track_scores]: dictionary with class values as keys and 1D numpy arrays as values
|
127 |
+
"""
|
128 |
+
# select sequence tracks
|
129 |
+
seq_id = self.seq_name_to_seq_id[seq]
|
130 |
+
if is_gt:
|
131 |
+
tracks = [ann for ann in self.gt_data['annotations'] if ann['video_id'] == seq_id]
|
132 |
+
else:
|
133 |
+
tracks = self._get_tracker_seq_tracks(tracker, seq_id)
|
134 |
+
|
135 |
+
# Convert data to required format
|
136 |
+
num_timesteps = self.seq_lengths[seq_id]
|
137 |
+
data_keys = ['ids', 'classes', 'dets']
|
138 |
+
if not is_gt:
|
139 |
+
data_keys += ['tracker_confidences']
|
140 |
+
raw_data = {key: [None] * num_timesteps for key in data_keys}
|
141 |
+
raw_data['raw_dets'] = [None] * num_timesteps
|
142 |
+
raw_data['raw_classes'] = [None] * num_timesteps
|
143 |
+
|
144 |
+
for t in range(num_timesteps):
|
145 |
+
raw_data['raw_dets'][t] = [track['segmentations'][t] for track in tracks]
|
146 |
+
raw_data['raw_classes'][t] = np.atleast_1d([track['category_id'] for track in tracks]).astype(int)
|
147 |
+
|
148 |
+
raw_data['dets'][t] = [track['segmentations'][t] for track in tracks if track['segmentations'][t]]
|
149 |
+
raw_data['ids'][t] = np.atleast_1d([track['id'] for track in tracks if track['segmentations'][t]]).astype(int)
|
150 |
+
raw_data['classes'][t] = np.atleast_1d([track['category_id'] for track in tracks if track['segmentations'][t]]).astype(int)
|
151 |
+
if not is_gt:
|
152 |
+
raw_data['tracker_confidences'][t] = np.atleast_1d([track['score'] for track in tracks if track['segmentations'][t]]).astype(float)
|
153 |
+
|
154 |
+
if is_gt:
|
155 |
+
key_map = {'ids': 'gt_ids',
|
156 |
+
'classes': 'gt_classes',
|
157 |
+
'dets': 'gt_dets'}
|
158 |
+
else:
|
159 |
+
key_map = {'ids': 'tracker_ids',
|
160 |
+
'classes': 'tracker_classes',
|
161 |
+
'dets': 'tracker_dets'}
|
162 |
+
for k, v in key_map.items():
|
163 |
+
raw_data[v] = raw_data.pop(k)
|
164 |
+
|
165 |
+
all_cls_ids = {self.class_name_to_class_id[cls] for cls in self.class_list}
|
166 |
+
classes_to_tracks = {cls: [track for track in tracks if track['category_id'] == cls] for cls in all_cls_ids}
|
167 |
+
|
168 |
+
# mapping from classes to track representations and track information
|
169 |
+
raw_data['classes_to_tracks'] = {cls: [{i: track['segmentations'][i]
|
170 |
+
for i in range(len(track['segmentations']))} for track in tracks]
|
171 |
+
for cls, tracks in classes_to_tracks.items()}
|
172 |
+
raw_data['classes_to_track_ids'] = {cls: [track['id'] for track in tracks]
|
173 |
+
for cls, tracks in classes_to_tracks.items()}
|
174 |
+
raw_data['classes_to_track_areas'] = {cls: [track['area'] for track in tracks]
|
175 |
+
for cls, tracks in classes_to_tracks.items()}
|
176 |
+
|
177 |
+
if is_gt:
|
178 |
+
raw_data['classes_to_gt_track_iscrowd'] = {cls: [track['iscrowd'] for track in tracks]
|
179 |
+
for cls, tracks in classes_to_tracks.items()}
|
180 |
+
else:
|
181 |
+
raw_data['classes_to_dt_track_scores'] = {cls: np.array([track['score'] for track in tracks])
|
182 |
+
for cls, tracks in classes_to_tracks.items()}
|
183 |
+
|
184 |
+
if is_gt:
|
185 |
+
key_map = {'classes_to_tracks': 'classes_to_gt_tracks',
|
186 |
+
'classes_to_track_ids': 'classes_to_gt_track_ids',
|
187 |
+
'classes_to_track_areas': 'classes_to_gt_track_areas'}
|
188 |
+
else:
|
189 |
+
key_map = {'classes_to_tracks': 'classes_to_dt_tracks',
|
190 |
+
'classes_to_track_ids': 'classes_to_dt_track_ids',
|
191 |
+
'classes_to_track_areas': 'classes_to_dt_track_areas'}
|
192 |
+
for k, v in key_map.items():
|
193 |
+
raw_data[v] = raw_data.pop(k)
|
194 |
+
|
195 |
+
raw_data['num_timesteps'] = num_timesteps
|
196 |
+
raw_data['seq'] = seq
|
197 |
+
return raw_data
|
198 |
+
|
199 |
+
@_timing.time
|
200 |
+
def get_preprocessed_seq_data(self, raw_data, cls):
|
201 |
+
""" Preprocess data for a single sequence for a single class ready for evaluation.
|
202 |
+
Inputs:
|
203 |
+
- raw_data is a dict containing the data for the sequence already read in by get_raw_seq_data().
|
204 |
+
- cls is the class to be evaluated.
|
205 |
+
Outputs:
|
206 |
+
- data is a dict containing all of the information that metrics need to perform evaluation.
|
207 |
+
It contains the following fields:
|
208 |
+
[num_timesteps, num_gt_ids, num_tracker_ids, num_gt_dets, num_tracker_dets] : integers.
|
209 |
+
[gt_ids, tracker_ids, tracker_confidences]: list (for each timestep) of 1D NDArrays (for each det).
|
210 |
+
[gt_dets, tracker_dets]: list (for each timestep) of lists of detections.
|
211 |
+
[similarity_scores]: list (for each timestep) of 2D NDArrays.
|
212 |
+
Notes:
|
213 |
+
General preprocessing (preproc) occurs in 4 steps. Some datasets may not use all of these steps.
|
214 |
+
1) Extract only detections relevant for the class to be evaluated (including distractor detections).
|
215 |
+
2) Match gt dets and tracker dets. Remove tracker dets that are matched to a gt det that is of a
|
216 |
+
distractor class, or otherwise marked as to be removed.
|
217 |
+
3) Remove unmatched tracker dets if they fall within a crowd ignore region or don't meet a certain
|
218 |
+
other criteria (e.g. are too small).
|
219 |
+
4) Remove gt dets that were only useful for preprocessing and not for actual evaluation.
|
220 |
+
After the above preprocessing steps, this function also calculates the number of gt and tracker detections
|
221 |
+
and unique track ids. It also relabels gt and tracker ids to be contiguous and checks that ids are
|
222 |
+
unique within each timestep.
|
223 |
+
YouTubeVIS:
|
224 |
+
In YouTubeVIS, the 4 preproc steps are as follow:
|
225 |
+
1) There are 40 classes which are evaluated separately.
|
226 |
+
2) No matched tracker dets are removed.
|
227 |
+
3) No unmatched tracker dets are removed.
|
228 |
+
4) No gt dets are removed.
|
229 |
+
Further, for TrackMAP computation track representations for the given class are accessed from a dictionary
|
230 |
+
and the tracks from the tracker data are sorted according to the tracker confidence.
|
231 |
+
"""
|
232 |
+
cls_id = self.class_name_to_class_id[cls]
|
233 |
+
|
234 |
+
data_keys = ['gt_ids', 'tracker_ids', 'gt_dets', 'tracker_dets', 'similarity_scores']
|
235 |
+
data = {key: [None] * raw_data['num_timesteps'] for key in data_keys}
|
236 |
+
unique_gt_ids = []
|
237 |
+
unique_tracker_ids = []
|
238 |
+
num_gt_dets = 0
|
239 |
+
num_tracker_dets = 0
|
240 |
+
|
241 |
+
for t in range(raw_data['num_timesteps']):
|
242 |
+
|
243 |
+
# Only extract relevant dets for this class for eval (cls)
|
244 |
+
gt_class_mask = np.atleast_1d(raw_data['gt_classes'][t] == cls_id)
|
245 |
+
gt_class_mask = gt_class_mask.astype(bool)
|
246 |
+
gt_ids = raw_data['gt_ids'][t][gt_class_mask]
|
247 |
+
gt_dets = [raw_data['gt_dets'][t][ind] for ind in range(len(gt_class_mask)) if gt_class_mask[ind]]
|
248 |
+
|
249 |
+
tracker_class_mask = np.atleast_1d(raw_data['tracker_classes'][t] == cls_id)
|
250 |
+
tracker_class_mask = tracker_class_mask.astype(bool)
|
251 |
+
tracker_ids = raw_data['tracker_ids'][t][tracker_class_mask]
|
252 |
+
tracker_dets = [raw_data['tracker_dets'][t][ind] for ind in range(len(tracker_class_mask)) if
|
253 |
+
tracker_class_mask[ind]]
|
254 |
+
similarity_scores = raw_data['similarity_scores'][t][gt_class_mask, :][:, tracker_class_mask]
|
255 |
+
|
256 |
+
data['tracker_ids'][t] = tracker_ids
|
257 |
+
data['tracker_dets'][t] = tracker_dets
|
258 |
+
data['gt_ids'][t] = gt_ids
|
259 |
+
data['gt_dets'][t] = gt_dets
|
260 |
+
data['similarity_scores'][t] = similarity_scores
|
261 |
+
|
262 |
+
unique_gt_ids += list(np.unique(data['gt_ids'][t]))
|
263 |
+
unique_tracker_ids += list(np.unique(data['tracker_ids'][t]))
|
264 |
+
num_tracker_dets += len(data['tracker_ids'][t])
|
265 |
+
num_gt_dets += len(data['gt_ids'][t])
|
266 |
+
|
267 |
+
# Re-label IDs such that there are no empty IDs
|
268 |
+
if len(unique_gt_ids) > 0:
|
269 |
+
unique_gt_ids = np.unique(unique_gt_ids)
|
270 |
+
gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
|
271 |
+
gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
|
272 |
+
for t in range(raw_data['num_timesteps']):
|
273 |
+
if len(data['gt_ids'][t]) > 0:
|
274 |
+
data['gt_ids'][t] = gt_id_map[data['gt_ids'][t]].astype(int)
|
275 |
+
if len(unique_tracker_ids) > 0:
|
276 |
+
unique_tracker_ids = np.unique(unique_tracker_ids)
|
277 |
+
tracker_id_map = np.nan * np.ones((np.max(unique_tracker_ids) + 1))
|
278 |
+
tracker_id_map[unique_tracker_ids] = np.arange(len(unique_tracker_ids))
|
279 |
+
for t in range(raw_data['num_timesteps']):
|
280 |
+
if len(data['tracker_ids'][t]) > 0:
|
281 |
+
data['tracker_ids'][t] = tracker_id_map[data['tracker_ids'][t]].astype(int)
|
282 |
+
|
283 |
+
# Ensure that ids are unique per timestep.
|
284 |
+
self._check_unique_ids(data)
|
285 |
+
|
286 |
+
# Record overview statistics.
|
287 |
+
data['num_tracker_dets'] = num_tracker_dets
|
288 |
+
data['num_gt_dets'] = num_gt_dets
|
289 |
+
data['num_tracker_ids'] = len(unique_tracker_ids)
|
290 |
+
data['num_gt_ids'] = len(unique_gt_ids)
|
291 |
+
data['num_timesteps'] = raw_data['num_timesteps']
|
292 |
+
data['seq'] = raw_data['seq']
|
293 |
+
|
294 |
+
# get track representations
|
295 |
+
data['gt_tracks'] = raw_data['classes_to_gt_tracks'][cls_id]
|
296 |
+
data['gt_track_ids'] = raw_data['classes_to_gt_track_ids'][cls_id]
|
297 |
+
data['gt_track_areas'] = raw_data['classes_to_gt_track_areas'][cls_id]
|
298 |
+
data['gt_track_iscrowd'] = raw_data['classes_to_gt_track_iscrowd'][cls_id]
|
299 |
+
data['dt_tracks'] = raw_data['classes_to_dt_tracks'][cls_id]
|
300 |
+
data['dt_track_ids'] = raw_data['classes_to_dt_track_ids'][cls_id]
|
301 |
+
data['dt_track_areas'] = raw_data['classes_to_dt_track_areas'][cls_id]
|
302 |
+
data['dt_track_scores'] = raw_data['classes_to_dt_track_scores'][cls_id]
|
303 |
+
data['iou_type'] = 'mask'
|
304 |
+
|
305 |
+
# sort tracker data tracks by tracker confidence scores
|
306 |
+
if data['dt_tracks']:
|
307 |
+
idx = np.argsort([-score for score in data['dt_track_scores']], kind="mergesort")
|
308 |
+
data['dt_track_scores'] = [data['dt_track_scores'][i] for i in idx]
|
309 |
+
data['dt_tracks'] = [data['dt_tracks'][i] for i in idx]
|
310 |
+
data['dt_track_ids'] = [data['dt_track_ids'][i] for i in idx]
|
311 |
+
data['dt_track_areas'] = [data['dt_track_areas'][i] for i in idx]
|
312 |
+
|
313 |
+
return data
|
314 |
+
|
315 |
+
def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
|
316 |
+
similarity_scores = self._calculate_mask_ious(gt_dets_t, tracker_dets_t, is_encoded=True, do_ioa=False)
|
317 |
+
return similarity_scores
|
318 |
+
|
319 |
+
def _prepare_gt_annotations(self):
|
320 |
+
"""
|
321 |
+
Prepares GT data by rle encoding segmentations and computing the average track area.
|
322 |
+
:return: None
|
323 |
+
"""
|
324 |
+
# only loaded when needed to reduce minimum requirements
|
325 |
+
from pycocotools import mask as mask_utils
|
326 |
+
|
327 |
+
for track in self.gt_data['annotations']:
|
328 |
+
h = track['height']
|
329 |
+
w = track['width']
|
330 |
+
for i, seg in enumerate(track['segmentations']):
|
331 |
+
if seg:
|
332 |
+
masks = mask_utils.frPyObjects(seg, h, w)
|
333 |
+
track['segmentations'][i] = mask_utils.merge(masks)
|
334 |
+
# track['segmentations'][i] = mask_utils.frPyObjects(seg, h, w)
|
335 |
+
areas = [a for a in track['areas'] if a]
|
336 |
+
if len(areas) == 0:
|
337 |
+
track['area'] = 0
|
338 |
+
else:
|
339 |
+
track['area'] = np.array(areas).mean()
|
340 |
+
|
341 |
+
def _get_tracker_seq_tracks(self, tracker, seq_id):
|
342 |
+
"""
|
343 |
+
Prepares tracker data for a given sequence. Extracts all annotations for given sequence ID, computes
|
344 |
+
average track area and assigns a track ID.
|
345 |
+
:param tracker: the given tracker
|
346 |
+
:param seq_id: the sequence ID
|
347 |
+
:return: the extracted tracks
|
348 |
+
"""
|
349 |
+
# only loaded when needed to reduce minimum requirements
|
350 |
+
from pycocotools import mask as mask_utils
|
351 |
+
|
352 |
+
tracks = [ann for ann in self.tracker_data[tracker] if ann['video_id'] == seq_id]
|
353 |
+
for track in tracks:
|
354 |
+
track['areas'] = []
|
355 |
+
for seg in track['segmentations']:
|
356 |
+
if seg:
|
357 |
+
track['areas'].append(mask_utils.area(seg))
|
358 |
+
else:
|
359 |
+
track['areas'].append(None)
|
360 |
+
areas = [a for a in track['areas'] if a]
|
361 |
+
if len(areas) == 0:
|
362 |
+
track['area'] = 0
|
363 |
+
else:
|
364 |
+
track['area'] = np.array(areas).mean()
|
365 |
+
track['id'] = self.global_tid_counter
|
366 |
+
self.global_tid_counter += 1
|
367 |
+
return tracks
|
avism/data/aviseval/eval.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tqdm
|
3 |
+
import traceback
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from . import utils
|
7 |
+
from . import _timing
|
8 |
+
from .metrics import Count
|
9 |
+
from .utils import TrackEvalException
|
10 |
+
from .metrics import compute_av_loc, combine_av_loc_sequences
|
11 |
+
|
12 |
+
|
13 |
+
class Evaluator:
|
14 |
+
"""Evaluator class for evaluating different metrics for different datasets"""
|
15 |
+
|
16 |
+
@staticmethod
|
17 |
+
def get_default_eval_config():
|
18 |
+
"""Returns the default config values for evaluation"""
|
19 |
+
code_path = utils.get_code_path()
|
20 |
+
default_config = {
|
21 |
+
'USE_PARALLEL': False,
|
22 |
+
'NUM_PARALLEL_CORES': 8,
|
23 |
+
'BREAK_ON_ERROR': True, # Raises exception and exits with error
|
24 |
+
'RETURN_ON_ERROR': False, # if not BREAK_ON_ERROR, then returns from function on error
|
25 |
+
'LOG_ON_ERROR': os.path.join(code_path, 'error_log.txt'), # if not None, save any errors into a log file.
|
26 |
+
|
27 |
+
'PRINT_RESULTS': False,
|
28 |
+
'PRINT_ONLY_COMBINED': False,
|
29 |
+
'PRINT_CONFIG': False,
|
30 |
+
'TIME_PROGRESS': False,
|
31 |
+
'DISPLAY_LESS_PROGRESS': True,
|
32 |
+
|
33 |
+
'OUTPUT_SUMMARY': False,
|
34 |
+
'OUTPUT_EMPTY_CLASSES': False,
|
35 |
+
'OUTPUT_DETAILED': False,
|
36 |
+
'PLOT_CURVES': False,
|
37 |
+
}
|
38 |
+
return default_config
|
39 |
+
|
40 |
+
def __init__(self, config=None):
|
41 |
+
"""Initialise the evaluator with a config file"""
|
42 |
+
self.config = utils.init_config(config, self.get_default_eval_config(), 'Eval')
|
43 |
+
# Only run timing analysis if not run in parallel.
|
44 |
+
if self.config['TIME_PROGRESS'] and not self.config['USE_PARALLEL']:
|
45 |
+
_timing.DO_TIMING = True
|
46 |
+
if self.config['DISPLAY_LESS_PROGRESS']:
|
47 |
+
_timing.DISPLAY_LESS_PROGRESS = True
|
48 |
+
|
49 |
+
@_timing.time
|
50 |
+
def evaluate(self, dataset_list, metrics_list):
|
51 |
+
"""Evaluate a set of metrics on a set of datasets"""
|
52 |
+
config = self.config
|
53 |
+
metrics_list = metrics_list + [Count()] # Count metrics are always run
|
54 |
+
metric_names = utils.validate_metrics_list(metrics_list)
|
55 |
+
dataset_names = [dataset.get_name() for dataset in dataset_list]
|
56 |
+
output_res = {}
|
57 |
+
output_msg = {}
|
58 |
+
|
59 |
+
for dataset, dataset_name in zip(dataset_list, dataset_names):
|
60 |
+
# Get dataset info about what to evaluate
|
61 |
+
output_res[dataset_name] = {}
|
62 |
+
output_msg[dataset_name] = {}
|
63 |
+
tracker_list, seq_list, class_list = dataset.get_eval_info()
|
64 |
+
|
65 |
+
# Evaluate each tracker
|
66 |
+
for tracker in tracker_list:
|
67 |
+
# if not config['BREAK_ON_ERROR'] then go to next tracker without breaking
|
68 |
+
try:
|
69 |
+
print('\nEvaluating model ...... \n')
|
70 |
+
res = {}
|
71 |
+
res_av_loc = {}
|
72 |
+
|
73 |
+
seq_list_sorted = sorted(seq_list)
|
74 |
+
for curr_seq in tqdm.tqdm(seq_list_sorted):
|
75 |
+
res[curr_seq] = eval_sequence(curr_seq, dataset, tracker, class_list, metrics_list, metric_names)
|
76 |
+
res_av_loc[curr_seq] = eval_av_loc_sequence(curr_seq, dataset, tracker)
|
77 |
+
|
78 |
+
# Combine results over all sequences and then over all classes
|
79 |
+
res_av_loc_all = combine_av_loc_sequences(res_av_loc)
|
80 |
+
|
81 |
+
# collecting combined cls keys (cls averaged, det averaged, super classes)
|
82 |
+
combined_cls_keys = []
|
83 |
+
res['COMBINED_SEQ'] = {}
|
84 |
+
# combine sequences for each class
|
85 |
+
for c_cls in class_list:
|
86 |
+
res['COMBINED_SEQ'][c_cls] = {}
|
87 |
+
for metric, metric_name in zip(metrics_list, metric_names):
|
88 |
+
curr_res = {seq_key: seq_value[c_cls][metric_name] for seq_key, seq_value in res.items() if
|
89 |
+
seq_key != 'COMBINED_SEQ'}
|
90 |
+
res['COMBINED_SEQ'][c_cls][metric_name] = metric.combine_sequences(curr_res)
|
91 |
+
# combine classes
|
92 |
+
if dataset.should_classes_combine:
|
93 |
+
combined_cls_keys += ['cls_comb_cls_av', 'cls_comb_det_av', 'all']
|
94 |
+
res['COMBINED_SEQ']['cls_comb_cls_av'] = {}
|
95 |
+
res['COMBINED_SEQ']['cls_comb_det_av'] = {}
|
96 |
+
for metric, metric_name in zip(metrics_list, metric_names):
|
97 |
+
cls_res = {cls_key: cls_value[metric_name] for cls_key, cls_value in
|
98 |
+
res['COMBINED_SEQ'].items() if cls_key not in combined_cls_keys}
|
99 |
+
res['COMBINED_SEQ']['cls_comb_cls_av'][metric_name] = \
|
100 |
+
metric.combine_classes_class_averaged(cls_res)
|
101 |
+
res['COMBINED_SEQ']['cls_comb_det_av'][metric_name] = \
|
102 |
+
metric.combine_classes_det_averaged(cls_res)
|
103 |
+
# combine classes to super classes
|
104 |
+
if dataset.use_super_categories:
|
105 |
+
for cat, sub_cats in dataset.super_categories.items():
|
106 |
+
combined_cls_keys.append(cat)
|
107 |
+
res['COMBINED_SEQ'][cat] = {}
|
108 |
+
for metric, metric_name in zip(metrics_list, metric_names):
|
109 |
+
cat_res = {cls_key: cls_value[metric_name] for cls_key, cls_value in
|
110 |
+
res['COMBINED_SEQ'].items() if cls_key in sub_cats}
|
111 |
+
res['COMBINED_SEQ'][cat][metric_name] = metric.combine_classes_det_averaged(cat_res)
|
112 |
+
|
113 |
+
# Print and output results in various formats
|
114 |
+
output_fol = dataset.get_output_fol(tracker)
|
115 |
+
tracker_display_name = dataset.get_display_name(tracker)
|
116 |
+
for c_cls in res['COMBINED_SEQ'].keys(): # class_list + combined classes if calculated
|
117 |
+
summaries = []
|
118 |
+
details = []
|
119 |
+
num_dets = res['COMBINED_SEQ'][c_cls]['Count']['Dets']
|
120 |
+
if config['OUTPUT_EMPTY_CLASSES'] or num_dets > 0:
|
121 |
+
for metric, metric_name in zip(metrics_list, metric_names):
|
122 |
+
# for combined classes there is no per sequence evaluation
|
123 |
+
if c_cls in combined_cls_keys:
|
124 |
+
table_res = {'COMBINED_SEQ': res['COMBINED_SEQ'][c_cls][metric_name]}
|
125 |
+
else:
|
126 |
+
table_res = {seq_key: seq_value[c_cls][metric_name] for seq_key, seq_value in res.items()}
|
127 |
+
if config['PLOT_CURVES']:
|
128 |
+
metric.plot_single_tracker_results(table_res, tracker_display_name, c_cls, output_fol)
|
129 |
+
if config['OUTPUT_SUMMARY']:
|
130 |
+
utils.write_summary_results(summaries, c_cls, output_fol)
|
131 |
+
if config['OUTPUT_DETAILED']:
|
132 |
+
utils.write_detailed_results(details, c_cls, output_fol)
|
133 |
+
|
134 |
+
# Output for returning from function
|
135 |
+
res_output = {}
|
136 |
+
|
137 |
+
res_output["AP_all"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['TrackMAP']['AP_all']), 2)
|
138 |
+
res_output["AP_s"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['TrackMAP']['AP_area_s']), 2)
|
139 |
+
res_output["AP_m"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['TrackMAP']['AP_area_m']), 2)
|
140 |
+
res_output["AP_l"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['TrackMAP']['AP_area_l']), 2)
|
141 |
+
res_output["AR_all"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['TrackMAP']['AR_all']), 2)
|
142 |
+
|
143 |
+
res_output["HOTA"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['HOTA']), 2)
|
144 |
+
res_output["DetA"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['DetA']), 2)
|
145 |
+
res_output["DetRe"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['DetRe']), 2)
|
146 |
+
res_output["DetPr"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['DetPr']), 2)
|
147 |
+
res_output["AssA"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['AssA']), 2)
|
148 |
+
res_output["AssRe"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['AssRe']), 2)
|
149 |
+
res_output["AssPr"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['AssPr']), 2)
|
150 |
+
res_output["LocA"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['LocA']), 2)
|
151 |
+
|
152 |
+
res_output["FA"] = round(100 * np.mean(res_av_loc_all['FA']), 2)
|
153 |
+
res_output["FAn"] = round(100 * np.mean(res_av_loc_all['FAn']), 2)
|
154 |
+
res_output['FAn_count'] = int(np.mean(res_av_loc_all['FAn_count']))
|
155 |
+
res_output['FAn_all'] = int(np.mean(res_av_loc_all['FAn_all']))
|
156 |
+
res_output["FAs"] = round(100 * np.mean(res_av_loc_all['FAs']), 2)
|
157 |
+
res_output['FAs_count'] = int(np.mean(res_av_loc_all['FAs_count']))
|
158 |
+
res_output['FAs_all'] = int(np.mean(res_av_loc_all['FAs_all']))
|
159 |
+
res_output["FAm"] = round(100 * np.mean(res_av_loc_all['FAm']), 2)
|
160 |
+
res_output['FAm_count'] = int(np.mean(res_av_loc_all['FAm_count']))
|
161 |
+
res_output['FAm_all'] = int(np.mean(res_av_loc_all['FAm_all']))
|
162 |
+
|
163 |
+
output_res[dataset_name][tracker] = res_output
|
164 |
+
output_msg[dataset_name][tracker] = 'Success'
|
165 |
+
|
166 |
+
except Exception as err:
|
167 |
+
output_res[dataset_name][tracker] = None
|
168 |
+
if type(err) == TrackEvalException:
|
169 |
+
output_msg[dataset_name][tracker] = str(err)
|
170 |
+
else:
|
171 |
+
output_msg[dataset_name][tracker] = 'Unknown error occurred.'
|
172 |
+
print('Tracker %s was unable to be evaluated.' % tracker)
|
173 |
+
print(err)
|
174 |
+
traceback.print_exc()
|
175 |
+
if config['LOG_ON_ERROR'] is not None:
|
176 |
+
with open(config['LOG_ON_ERROR'], 'a') as f:
|
177 |
+
print(dataset_name, file=f)
|
178 |
+
print(tracker, file=f)
|
179 |
+
print(traceback.format_exc(), file=f)
|
180 |
+
print('\n\n\n', file=f)
|
181 |
+
if config['BREAK_ON_ERROR']:
|
182 |
+
raise err
|
183 |
+
elif config['RETURN_ON_ERROR']:
|
184 |
+
return output_res, output_msg
|
185 |
+
|
186 |
+
return output_res, output_msg
|
187 |
+
|
188 |
+
|
189 |
+
@_timing.time
|
190 |
+
def eval_sequence(seq, dataset, tracker, class_list, metrics_list, metric_names):
|
191 |
+
"""Function for evaluating a single sequence"""
|
192 |
+
|
193 |
+
raw_data = dataset.get_raw_seq_data(tracker, seq)
|
194 |
+
seq_res = {}
|
195 |
+
for cls in class_list:
|
196 |
+
seq_res[cls] = {}
|
197 |
+
data = dataset.get_preprocessed_seq_data(raw_data, cls)
|
198 |
+
for metric, met_name in zip(metrics_list, metric_names):
|
199 |
+
seq_res[cls][met_name] = metric.eval_sequence(data)
|
200 |
+
return seq_res
|
201 |
+
|
202 |
+
|
203 |
+
def eval_av_loc_sequence(seq, dataset, tracker):
|
204 |
+
"""Function for evaluating a single sequence"""
|
205 |
+
|
206 |
+
raw_data = dataset.get_raw_seq_data(tracker, seq)
|
207 |
+
av_loc_res = compute_av_loc(raw_data)
|
208 |
+
|
209 |
+
return av_loc_res
|
avism/data/aviseval/metrics/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .hota import HOTA
|
2 |
+
from .clear import CLEAR
|
3 |
+
from .identity import Identity
|
4 |
+
from .count import Count
|
5 |
+
from .j_and_f import JAndF
|
6 |
+
from .track_map import TrackMAP
|
7 |
+
from .vace import VACE
|
8 |
+
from .ideucl import IDEucl
|
9 |
+
|
10 |
+
from .avisa import avisA
|
11 |
+
from .av_loc import compute_av_loc, combine_av_loc_sequences
|
avism/data/aviseval/metrics/_base_metric.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from .. import _timing
|
4 |
+
from ..utils import TrackEvalException
|
5 |
+
|
6 |
+
|
7 |
+
class _BaseMetric(ABC):
|
8 |
+
@abstractmethod
|
9 |
+
def __init__(self):
|
10 |
+
self.plottable = False
|
11 |
+
self.integer_fields = []
|
12 |
+
self.float_fields = []
|
13 |
+
self.array_labels = []
|
14 |
+
self.integer_array_fields = []
|
15 |
+
self.float_array_fields = []
|
16 |
+
self.fields = []
|
17 |
+
self.summary_fields = []
|
18 |
+
self.registered = False
|
19 |
+
|
20 |
+
#####################################################################
|
21 |
+
# Abstract functions for subclasses to implement
|
22 |
+
|
23 |
+
@_timing.time
|
24 |
+
@abstractmethod
|
25 |
+
def eval_sequence(self, data):
|
26 |
+
...
|
27 |
+
|
28 |
+
@abstractmethod
|
29 |
+
def combine_sequences(self, all_res):
|
30 |
+
...
|
31 |
+
|
32 |
+
@abstractmethod
|
33 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
|
34 |
+
...
|
35 |
+
|
36 |
+
@ abstractmethod
|
37 |
+
def combine_classes_det_averaged(self, all_res):
|
38 |
+
...
|
39 |
+
|
40 |
+
def plot_single_tracker_results(self, all_res, tracker, output_folder, cls):
|
41 |
+
"""Plot results of metrics, only valid for metrics with self.plottable"""
|
42 |
+
if self.plottable:
|
43 |
+
raise NotImplementedError('plot_results is not implemented for metric %s' % self.get_name())
|
44 |
+
else:
|
45 |
+
pass
|
46 |
+
|
47 |
+
#####################################################################
|
48 |
+
# Helper functions which are useful for all metrics:
|
49 |
+
|
50 |
+
@classmethod
|
51 |
+
def get_name(cls):
|
52 |
+
return cls.__name__
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def _combine_sum(all_res, field):
|
56 |
+
"""Combine sequence results via sum"""
|
57 |
+
return sum([all_res[k][field] for k in all_res.keys()])
|
58 |
+
|
59 |
+
@staticmethod
|
60 |
+
def _combine_weighted_av(all_res, field, comb_res, weight_field):
|
61 |
+
"""Combine sequence results via weighted average"""
|
62 |
+
return sum([all_res[k][field] * all_res[k][weight_field] for k in all_res.keys()]) / np.maximum(1.0, comb_res[
|
63 |
+
weight_field])
|
64 |
+
|
65 |
+
def print_table(self, table_res, tracker, cls):
|
66 |
+
"""Prints table of results for all sequences"""
|
67 |
+
print('')
|
68 |
+
metric_name = self.get_name()
|
69 |
+
self._row_print([metric_name + ': ' + tracker + '-' + cls] + self.summary_fields)
|
70 |
+
for seq, results in sorted(table_res.items()):
|
71 |
+
if seq == 'COMBINED_SEQ':
|
72 |
+
continue
|
73 |
+
summary_res = self._summary_row(results)
|
74 |
+
self._row_print([seq] + summary_res)
|
75 |
+
summary_res = self._summary_row(table_res['COMBINED_SEQ'])
|
76 |
+
self._row_print(['COMBINED'] + summary_res)
|
77 |
+
|
78 |
+
def _summary_row(self, results_):
|
79 |
+
vals = []
|
80 |
+
for h in self.summary_fields:
|
81 |
+
if h in self.float_array_fields:
|
82 |
+
vals.append("{0:1.5g}".format(100 * np.mean(results_[h])))
|
83 |
+
elif h in self.float_fields:
|
84 |
+
vals.append("{0:1.5g}".format(100 * float(results_[h])))
|
85 |
+
elif h in self.integer_fields:
|
86 |
+
vals.append("{0:d}".format(int(results_[h])))
|
87 |
+
else:
|
88 |
+
raise NotImplementedError("Summary function not implemented for this field type.")
|
89 |
+
return vals
|
90 |
+
|
91 |
+
@staticmethod
|
92 |
+
def _row_print(*argv):
|
93 |
+
"""Prints results in an evenly spaced rows, with more space in first row"""
|
94 |
+
if len(argv) == 1:
|
95 |
+
argv = argv[0]
|
96 |
+
to_print = '%-35s' % argv[0]
|
97 |
+
for v in argv[1:]:
|
98 |
+
to_print += '%-10s' % str(v)
|
99 |
+
print(to_print)
|
100 |
+
|
101 |
+
def summary_results(self, table_res):
|
102 |
+
"""Returns a simple summary of final results for a tracker"""
|
103 |
+
return dict(zip(self.summary_fields, self._summary_row(table_res['COMBINED_SEQ'])))
|
104 |
+
|
105 |
+
def detailed_results(self, table_res):
|
106 |
+
"""Returns detailed final results for a tracker"""
|
107 |
+
# Get detailed field information
|
108 |
+
detailed_fields = self.float_fields + self.integer_fields
|
109 |
+
for h in self.float_array_fields + self.integer_array_fields:
|
110 |
+
for alpha in [int(100*x) for x in self.array_labels]:
|
111 |
+
detailed_fields.append(h + '___' + str(alpha))
|
112 |
+
detailed_fields.append(h + '___AUC')
|
113 |
+
|
114 |
+
# Get detailed results
|
115 |
+
detailed_results = {}
|
116 |
+
for seq, res in table_res.items():
|
117 |
+
detailed_row = self._detailed_row(res)
|
118 |
+
if len(detailed_row) != len(detailed_fields):
|
119 |
+
raise TrackEvalException(
|
120 |
+
'Field names and data have different sizes (%i and %i)' % (len(detailed_row), len(detailed_fields)))
|
121 |
+
detailed_results[seq] = dict(zip(detailed_fields, detailed_row))
|
122 |
+
return detailed_results
|
123 |
+
|
124 |
+
def _detailed_row(self, res):
|
125 |
+
detailed_row = []
|
126 |
+
for h in self.float_fields + self.integer_fields:
|
127 |
+
detailed_row.append(res[h])
|
128 |
+
for h in self.float_array_fields + self.integer_array_fields:
|
129 |
+
for i, alpha in enumerate([int(100 * x) for x in self.array_labels]):
|
130 |
+
detailed_row.append(res[h][i])
|
131 |
+
detailed_row.append(np.mean(res[h]))
|
132 |
+
return detailed_row
|
avism/data/aviseval/metrics/av_loc.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from pycocotools import mask as mask_utils
|
3 |
+
from scipy.optimize import linear_sum_assignment
|
4 |
+
|
5 |
+
|
6 |
+
def compute_av_loc(data):
|
7 |
+
alphas = np.arange(0.05, 0.99, 0.05)
|
8 |
+
res = {}
|
9 |
+
res['FA'] = np.zeros((len(alphas)), dtype=float)
|
10 |
+
res['FAn'] = np.array([None] * len(alphas)) # frame accuracy in no sound source
|
11 |
+
res['FAs'] = np.array([None] * len(alphas)) # frame accuracy in single sound source
|
12 |
+
res['FAm'] = np.array([None] * len(alphas)) # frame accuracy in multi sound source
|
13 |
+
|
14 |
+
res['frame_num_n_all'] = np.zeros((len(alphas)), dtype=int)
|
15 |
+
res['frame_num_n_tp'] = np.zeros((len(alphas)), dtype=int)
|
16 |
+
res['frame_num_s_all'] = np.zeros((len(alphas)), dtype=int)
|
17 |
+
res['frame_num_s_tp'] = np.zeros((len(alphas)), dtype=int)
|
18 |
+
res['frame_num_m_all'] = np.zeros((len(alphas)), dtype=int)
|
19 |
+
res['frame_num_m_tp'] = np.zeros((len(alphas)), dtype=int)
|
20 |
+
|
21 |
+
frame_num_all = data['num_timesteps']
|
22 |
+
gt_classes = data['gt_classes']
|
23 |
+
gt_dets = data['gt_dets']
|
24 |
+
raw_classes = data['raw_classes']
|
25 |
+
raw_dets = data['raw_dets']
|
26 |
+
pred_classes = data['tracker_classes']
|
27 |
+
pred_dets = data['tracker_dets']
|
28 |
+
|
29 |
+
# 1. Find the best trajectory between gt and pred
|
30 |
+
unique_gt_ids = []
|
31 |
+
unique_tracker_ids = []
|
32 |
+
for t in range(data['num_timesteps']):
|
33 |
+
unique_gt_ids += list(np.unique(data['gt_ids'][t]))
|
34 |
+
unique_tracker_ids += list(np.unique(data['tracker_ids'][t]))
|
35 |
+
# Re-label IDs such that there are no empty IDs
|
36 |
+
if len(unique_gt_ids) > 0:
|
37 |
+
unique_gt_ids = np.unique(unique_gt_ids)
|
38 |
+
gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
|
39 |
+
gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
|
40 |
+
for t in range(data['num_timesteps']):
|
41 |
+
if len(data['gt_ids'][t]) > 0:
|
42 |
+
data['gt_ids'][t] = gt_id_map[data['gt_ids'][t]].astype(int)
|
43 |
+
if len(unique_tracker_ids) > 0:
|
44 |
+
unique_tracker_ids = np.unique(unique_tracker_ids)
|
45 |
+
tracker_id_map = np.nan * np.ones((np.max(unique_tracker_ids) + 1))
|
46 |
+
tracker_id_map[unique_tracker_ids] = np.arange(len(unique_tracker_ids))
|
47 |
+
for t in range(data['num_timesteps']):
|
48 |
+
if len(data['tracker_ids'][t]) > 0:
|
49 |
+
data['tracker_ids'][t] = tracker_id_map[data['tracker_ids'][t]].astype(int)
|
50 |
+
data['num_tracker_ids'] = len(unique_tracker_ids)
|
51 |
+
data['num_gt_ids'] = len(unique_gt_ids)
|
52 |
+
# Variables counting global association
|
53 |
+
potential_matches_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
|
54 |
+
gt_id_count = np.zeros((data['num_gt_ids'], 1))
|
55 |
+
tracker_id_count = np.zeros((1, data['num_tracker_ids']))
|
56 |
+
# First loop through each timestep and accumulate global track information.
|
57 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
|
58 |
+
# Count the potential matches between ids in each timestep
|
59 |
+
# These are normalised, weighted by the match similarity.
|
60 |
+
similarity = data['similarity_scores'][t]
|
61 |
+
sim_iou_denom = similarity.sum(0)[np.newaxis, :] + similarity.sum(1)[:, np.newaxis] - similarity
|
62 |
+
sim_iou = np.zeros_like(similarity)
|
63 |
+
sim_iou_mask = sim_iou_denom > 0 + np.finfo('float').eps
|
64 |
+
sim_iou[sim_iou_mask] = similarity[sim_iou_mask] / sim_iou_denom[sim_iou_mask]
|
65 |
+
potential_matches_count[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] += sim_iou
|
66 |
+
# Calculate the total number of dets for each gt_id and tracker_id.
|
67 |
+
gt_id_count[gt_ids_t] += 1
|
68 |
+
tracker_id_count[0, tracker_ids_t] += 1
|
69 |
+
# Calculate overall jaccard alignment score (before unique matching) between IDs
|
70 |
+
global_alignment_score = potential_matches_count / (gt_id_count + tracker_id_count - potential_matches_count)
|
71 |
+
# Hungarian algorithm to find best matches
|
72 |
+
match_rows, match_cols = linear_sum_assignment(-global_alignment_score)
|
73 |
+
|
74 |
+
# 2. Compute FSLA
|
75 |
+
for a, alpha in enumerate(alphas):
|
76 |
+
frame_num_n_all = 0 # total frames in no sound source
|
77 |
+
frame_num_s_all = 0 # total frames in single sound source
|
78 |
+
frame_num_m_all = 0 # total frames in multi sound source
|
79 |
+
frame_num_n_tp = 0 # true positive frames in no sound source
|
80 |
+
frame_num_s_tp = 0 # true positive frames in single sound source
|
81 |
+
frame_num_m_tp = 0 # true positive frames in multi sound source
|
82 |
+
|
83 |
+
for frame_id in range(frame_num_all):
|
84 |
+
# classes
|
85 |
+
gt_classes_per_frame = gt_classes[frame_id]
|
86 |
+
raw_classes_per_frame = raw_classes[frame_id]
|
87 |
+
pred_classes_per_frame = pred_classes[frame_id]
|
88 |
+
# masks
|
89 |
+
gt_dets_per_frame = gt_dets[frame_id]
|
90 |
+
raw_dets_per_frame = raw_dets[frame_id]
|
91 |
+
pred_dets_per_frame = pred_dets[frame_id]
|
92 |
+
|
93 |
+
if len(pred_dets_per_frame) > 0:
|
94 |
+
pred_dets_per_frame_f = [di for di in pred_dets_per_frame if di['counts'] != 'PPTl0']
|
95 |
+
else:
|
96 |
+
pred_dets_per_frame_f = pred_dets_per_frame
|
97 |
+
|
98 |
+
# Masks must have the same class and number
|
99 |
+
if (set(gt_classes_per_frame) == set(pred_classes_per_frame)) and (len(gt_dets_per_frame) == len(pred_dets_per_frame_f)):
|
100 |
+
# 1) no sound source
|
101 |
+
if len(gt_dets_per_frame) == 0:
|
102 |
+
frame_num_n_all += 1
|
103 |
+
frame_num_n_tp += 1
|
104 |
+
# 2) single sound source
|
105 |
+
elif len(gt_dets_per_frame) == 1:
|
106 |
+
frame_num_s_all += 1
|
107 |
+
index_gt = [index for index, value in enumerate(raw_dets_per_frame) if value is not None][0]
|
108 |
+
index_pred = [index for index, element in enumerate(match_cols) if element == index_gt]
|
109 |
+
if index_pred != []:
|
110 |
+
ious = mask_utils.iou(gt_dets_per_frame, [pred_dets_per_frame[index_pred[0]]], [False])
|
111 |
+
if np.all(ious > alpha):
|
112 |
+
frame_num_s_tp += 1
|
113 |
+
# 3) multi sound source
|
114 |
+
else:
|
115 |
+
frame_num_m_all += 1
|
116 |
+
flags = [0] * len(match_rows)
|
117 |
+
for tr in range(len(match_rows)):
|
118 |
+
if (raw_classes_per_frame[match_rows[tr]] == pred_classes_per_frame[match_cols[tr]]):
|
119 |
+
if raw_dets_per_frame[match_rows[tr]] == None:
|
120 |
+
if pred_dets_per_frame[match_cols[tr]]['counts'] == 'PPTl0':
|
121 |
+
flags[tr] = 1
|
122 |
+
else:
|
123 |
+
iou = mask_utils.iou([raw_dets_per_frame[match_rows[tr]]],
|
124 |
+
[pred_dets_per_frame[match_cols[tr]]], [False])
|
125 |
+
if np.all(iou > alpha):
|
126 |
+
flags[tr] = 1
|
127 |
+
if all(ff == 1 for ff in flags):
|
128 |
+
frame_num_m_tp += 1
|
129 |
+
else:
|
130 |
+
if len(gt_dets_per_frame) == 0:
|
131 |
+
frame_num_n_all += 1
|
132 |
+
elif len(gt_dets_per_frame) == 1:
|
133 |
+
frame_num_s_all += 1
|
134 |
+
else:
|
135 |
+
frame_num_m_all += 1
|
136 |
+
|
137 |
+
assert frame_num_all == (frame_num_n_all + frame_num_s_all + frame_num_m_all)
|
138 |
+
|
139 |
+
if frame_num_n_all > 0:
|
140 |
+
res['FAn'][a] = frame_num_n_tp / frame_num_n_all
|
141 |
+
res['frame_num_n_all'][a] = frame_num_n_all
|
142 |
+
res['frame_num_n_tp'][a] = frame_num_n_tp
|
143 |
+
else:
|
144 |
+
res['FAn'][a] = None
|
145 |
+
res['frame_num_n_all'][a] = 0
|
146 |
+
res['frame_num_n_tp'][a] = 0
|
147 |
+
|
148 |
+
if frame_num_s_all > 0:
|
149 |
+
res['FAs'][a] = frame_num_s_tp / frame_num_s_all
|
150 |
+
res['frame_num_s_all'][a] = frame_num_s_all
|
151 |
+
res['frame_num_s_tp'][a] = frame_num_s_tp
|
152 |
+
else:
|
153 |
+
res['FAs'][a] = None
|
154 |
+
res['frame_num_s_all'][a] = 0
|
155 |
+
res['frame_num_s_tp'][a] = 0
|
156 |
+
|
157 |
+
if frame_num_m_all > 0:
|
158 |
+
res['FAm'][a] = frame_num_m_tp / frame_num_m_all
|
159 |
+
res['frame_num_m_all'][a] = frame_num_m_all
|
160 |
+
res['frame_num_m_tp'][a] = frame_num_m_tp
|
161 |
+
else:
|
162 |
+
res['FAm'][a] = None
|
163 |
+
res['frame_num_m_all'][a] = 0
|
164 |
+
res['frame_num_m_tp'][a] = 0
|
165 |
+
|
166 |
+
res['FA'][a] = (frame_num_n_tp + frame_num_s_tp + frame_num_m_tp) / frame_num_all
|
167 |
+
|
168 |
+
return res
|
169 |
+
|
170 |
+
|
171 |
+
def combine_av_loc_sequences(all_res):
|
172 |
+
"""Combines metrics across all sequences"""
|
173 |
+
res = {}
|
174 |
+
fields_num = ['frame_num_n_all', 'frame_num_s_all', 'frame_num_m_all', 'frame_num_n_tp', 'frame_num_s_tp', 'frame_num_m_tp']
|
175 |
+
for field in fields_num:
|
176 |
+
res[field] = sum([all_res[k][field] for k in all_res.keys()])
|
177 |
+
|
178 |
+
res_final = {}
|
179 |
+
|
180 |
+
res_final['FAn'] = res['frame_num_n_tp'] / res['frame_num_n_all']
|
181 |
+
res_final['FAn_count'] = res['frame_num_n_tp']
|
182 |
+
res_final['FAn_all'] = res['frame_num_n_all']
|
183 |
+
res_final['FAs'] = res['frame_num_s_tp'] / res['frame_num_s_all']
|
184 |
+
res_final['FAs_count'] = res['frame_num_s_tp']
|
185 |
+
res_final['FAs_all'] = res['frame_num_s_all']
|
186 |
+
res_final['FAm'] = res['frame_num_m_tp'] / res['frame_num_m_all']
|
187 |
+
res_final['FAm_count'] = res['frame_num_m_tp']
|
188 |
+
res_final['FAm_all'] = res['frame_num_m_all']
|
189 |
+
res_final['FA'] = (res['frame_num_n_tp'] + res['frame_num_s_tp'] + res['frame_num_m_tp']) / (res['frame_num_n_all'] + res['frame_num_s_all'] + res['frame_num_m_all'])
|
190 |
+
|
191 |
+
return res_final
|
avism/data/aviseval/metrics/avisa.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from scipy.optimize import linear_sum_assignment
|
4 |
+
from ._base_metric import _BaseMetric
|
5 |
+
from .. import _timing
|
6 |
+
|
7 |
+
|
8 |
+
class avisA(_BaseMetric):
|
9 |
+
def __init__(self, config=None):
|
10 |
+
super().__init__()
|
11 |
+
self.plottable = True
|
12 |
+
self.array_labels = np.arange(0.05, 0.99, 0.05)
|
13 |
+
self.integer_array_fields = ['HOTA_TP', 'HOTA_FN', 'HOTA_FP']
|
14 |
+
self.float_array_fields = ['AssA', 'AssRe', 'AssPr', 'SegA']
|
15 |
+
self.float_fields = ['SegA(0)']
|
16 |
+
self.fields = self.float_array_fields + self.integer_array_fields + self.float_fields
|
17 |
+
self.summary_fields = self.float_array_fields + self.float_fields
|
18 |
+
|
19 |
+
@_timing.time
|
20 |
+
def eval_sequence(self, data):
|
21 |
+
"""Calculates the AssA and SegA metrics for one sequence"""
|
22 |
+
|
23 |
+
# Initialise results
|
24 |
+
res = {}
|
25 |
+
for field in self.float_array_fields + self.integer_array_fields:
|
26 |
+
res[field] = np.zeros((len(self.array_labels)), dtype=float)
|
27 |
+
for field in self.float_fields:
|
28 |
+
res[field] = 0
|
29 |
+
|
30 |
+
# Return result quickly if tracker or gt sequence is empty
|
31 |
+
if data['num_tracker_dets'] == 0:
|
32 |
+
res['HOTA_FN'] = data['num_gt_dets'] * np.ones((len(self.array_labels)), dtype=float)
|
33 |
+
res['SegA'] = np.ones((len(self.array_labels)), dtype=float)
|
34 |
+
res['SegA(0)'] = 1.0
|
35 |
+
return res
|
36 |
+
if data['num_gt_dets'] == 0:
|
37 |
+
res['HOTA_FP'] = data['num_tracker_dets'] * np.ones((len(self.array_labels)), dtype=float)
|
38 |
+
res['SegA'] = np.ones((len(self.array_labels)), dtype=float)
|
39 |
+
res['SegA(0)'] = 1.0
|
40 |
+
return res
|
41 |
+
|
42 |
+
# Variables counting global association
|
43 |
+
potential_matches_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
|
44 |
+
gt_id_count = np.zeros((data['num_gt_ids'], 1))
|
45 |
+
tracker_id_count = np.zeros((1, data['num_tracker_ids']))
|
46 |
+
|
47 |
+
# First loop through each timestep and accumulate global track information.
|
48 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
|
49 |
+
# Count the potential matches between ids in each timestep
|
50 |
+
# These are normalised, weighted by the match similarity.
|
51 |
+
similarity = data['similarity_scores'][t]
|
52 |
+
sim_iou_denom = similarity.sum(0)[np.newaxis, :] + similarity.sum(1)[:, np.newaxis] - similarity
|
53 |
+
sim_iou = np.zeros_like(similarity)
|
54 |
+
sim_iou_mask = sim_iou_denom > 0 + np.finfo('float').eps
|
55 |
+
sim_iou[sim_iou_mask] = similarity[sim_iou_mask] / sim_iou_denom[sim_iou_mask]
|
56 |
+
potential_matches_count[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] += sim_iou
|
57 |
+
|
58 |
+
# Calculate the total number of dets for each gt_id and tracker_id.
|
59 |
+
gt_id_count[gt_ids_t] += 1
|
60 |
+
tracker_id_count[0, tracker_ids_t] += 1
|
61 |
+
|
62 |
+
# Calculate overall jaccard alignment score (before unique matching) between IDs
|
63 |
+
global_alignment_score = potential_matches_count / (gt_id_count + tracker_id_count - potential_matches_count)
|
64 |
+
matches_counts = [np.zeros_like(potential_matches_count) for _ in self.array_labels]
|
65 |
+
|
66 |
+
# Calculate scores for each timestep
|
67 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
|
68 |
+
# Deal with the case that there are no gt_det/tracker_det in a timestep.
|
69 |
+
if len(gt_ids_t) == 0:
|
70 |
+
for a, alpha in enumerate(self.array_labels):
|
71 |
+
res['HOTA_FP'][a] += len(tracker_ids_t)
|
72 |
+
continue
|
73 |
+
if len(tracker_ids_t) == 0:
|
74 |
+
for a, alpha in enumerate(self.array_labels):
|
75 |
+
res['HOTA_FN'][a] += len(gt_ids_t)
|
76 |
+
continue
|
77 |
+
|
78 |
+
# Get matching scores between pairs of dets for optimizing HOTA
|
79 |
+
similarity = data['similarity_scores'][t]
|
80 |
+
score_mat = global_alignment_score[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] * similarity
|
81 |
+
|
82 |
+
# Hungarian algorithm to find best matches: 找出最优跟踪路线
|
83 |
+
match_rows, match_cols = linear_sum_assignment(-score_mat)
|
84 |
+
|
85 |
+
# Calculate and accumulate basic statistics
|
86 |
+
for a, alpha in enumerate(self.array_labels):
|
87 |
+
actually_matched_mask = similarity[match_rows, match_cols] >= alpha - np.finfo('float').eps
|
88 |
+
alpha_match_rows = match_rows[actually_matched_mask]
|
89 |
+
alpha_match_cols = match_cols[actually_matched_mask]
|
90 |
+
num_matches = len(alpha_match_rows)
|
91 |
+
res['HOTA_TP'][a] += num_matches
|
92 |
+
res['HOTA_FN'][a] += len(gt_ids_t) - num_matches
|
93 |
+
res['HOTA_FP'][a] += len(tracker_ids_t) - num_matches
|
94 |
+
if num_matches > 0:
|
95 |
+
res['SegA'][a] += sum(similarity[alpha_match_rows, alpha_match_cols])
|
96 |
+
matches_counts[a][gt_ids_t[alpha_match_rows], tracker_ids_t[alpha_match_cols]] += 1
|
97 |
+
|
98 |
+
# Calculate association scores (AssA, AssRe, AssPr) for the alpha value.
|
99 |
+
# First calculate scores per gt_id/tracker_id combo and then average over the number of detections.
|
100 |
+
for a, alpha in enumerate(self.array_labels):
|
101 |
+
matches_count = matches_counts[a]
|
102 |
+
ass_a = matches_count / np.maximum(1, gt_id_count + tracker_id_count - matches_count)
|
103 |
+
res['AssA'][a] = np.sum(matches_count * ass_a) / np.maximum(1, res['HOTA_TP'][a])
|
104 |
+
ass_re = matches_count / np.maximum(1, gt_id_count)
|
105 |
+
res['AssRe'][a] = np.sum(matches_count * ass_re) / np.maximum(1, res['HOTA_TP'][a])
|
106 |
+
ass_pr = matches_count / np.maximum(1, tracker_id_count)
|
107 |
+
res['AssPr'][a] = np.sum(matches_count * ass_pr) / np.maximum(1, res['HOTA_TP'][a])
|
108 |
+
|
109 |
+
# Calculate final scores
|
110 |
+
res['SegA'] = np.maximum(1e-10, res['SegA']) / np.maximum(1e-10, res['HOTA_TP'])
|
111 |
+
res = self._compute_final_fields(res)
|
112 |
+
return res
|
113 |
+
|
114 |
+
def combine_sequences(self, all_res):
|
115 |
+
"""Combines metrics across all sequences"""
|
116 |
+
res = {}
|
117 |
+
for field in self.integer_array_fields:
|
118 |
+
res[field] = self._combine_sum(all_res, field)
|
119 |
+
for field in ['AssRe', 'AssPr', 'AssA']:
|
120 |
+
res[field] = self._combine_weighted_av(all_res, field, res, weight_field='HOTA_TP')
|
121 |
+
sega_weighted_sum = sum([all_res[k]['SegA'] * all_res[k]['HOTA_TP'] for k in all_res.keys()])
|
122 |
+
res['SegA'] = np.maximum(1e-10, sega_weighted_sum) / np.maximum(1e-10, res['HOTA_TP'])
|
123 |
+
res = self._compute_final_fields(res)
|
124 |
+
return res
|
125 |
+
|
126 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
|
127 |
+
"""Combines metrics across all classes by averaging over the class values.
|
128 |
+
If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
|
129 |
+
"""
|
130 |
+
res = {}
|
131 |
+
for field in self.integer_array_fields:
|
132 |
+
if ignore_empty_classes:
|
133 |
+
res[field] = self._combine_sum(
|
134 |
+
{k: v for k, v in all_res.items()
|
135 |
+
if (v['HOTA_TP'] + v['HOTA_FN'] + v['HOTA_FP'] > 0 + np.finfo('float').eps).any()}, field)
|
136 |
+
else:
|
137 |
+
res[field] = self._combine_sum({k: v for k, v in all_res.items()}, field)
|
138 |
+
|
139 |
+
for field in self.float_fields + self.float_array_fields:
|
140 |
+
if ignore_empty_classes:
|
141 |
+
res[field] = np.mean([v[field] for v in all_res.values() if
|
142 |
+
(v['HOTA_TP'] + v['HOTA_FN'] + v['HOTA_FP'] > 0 + np.finfo('float').eps).any()],
|
143 |
+
axis=0)
|
144 |
+
else:
|
145 |
+
res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
|
146 |
+
return res
|
147 |
+
|
148 |
+
def combine_classes_det_averaged(self, all_res):
|
149 |
+
"""Combines metrics across all classes by averaging over the detection values"""
|
150 |
+
res = {}
|
151 |
+
for field in self.integer_array_fields:
|
152 |
+
res[field] = self._combine_sum(all_res, field)
|
153 |
+
for field in ['AssRe', 'AssPr', 'AssA']:
|
154 |
+
res[field] = self._combine_weighted_av(all_res, field, res, weight_field='HOTA_TP')
|
155 |
+
sega_weighted_sum = sum([all_res[k]['SegA'] * all_res[k]['HOTA_TP'] for k in all_res.keys()])
|
156 |
+
res['SegA'] = np.maximum(1e-10, sega_weighted_sum) / np.maximum(1e-10, res['HOTA_TP'])
|
157 |
+
res = self._compute_final_fields(res)
|
158 |
+
return res
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def _compute_final_fields(res):
|
162 |
+
"""Calculate sub-metric ('field') values which only depend on other sub-metric values.
|
163 |
+
This function is used both for both per-sequence calculation, and in combining values across sequences.
|
164 |
+
"""
|
165 |
+
res['SegA(0)'] = res['SegA'][0]
|
166 |
+
return res
|
167 |
+
|
168 |
+
def plot_single_tracker_results(self, table_res, tracker, cls, output_folder):
|
169 |
+
"""Create plot of results"""
|
170 |
+
|
171 |
+
# Only loaded when run to reduce minimum requirements
|
172 |
+
from matplotlib import pyplot as plt
|
173 |
+
|
174 |
+
res = table_res['COMBINED_SEQ']
|
175 |
+
styles_to_plot = ['r', 'b', 'g', 'b--', 'b:', 'g--', 'g:', 'm']
|
176 |
+
for name, style in zip(self.float_array_fields, styles_to_plot):
|
177 |
+
plt.plot(self.array_labels, res[name], style)
|
178 |
+
plt.xlabel('alpha')
|
179 |
+
plt.ylabel('score')
|
180 |
+
plt.title(tracker + ' - ' + cls)
|
181 |
+
plt.axis([0, 1, 0, 1])
|
182 |
+
legend = []
|
183 |
+
for name in self.float_array_fields:
|
184 |
+
legend += [name + ' (' + str(np.round(np.mean(res[name]), 2)) + ')']
|
185 |
+
plt.legend(legend, loc='lower left')
|
186 |
+
out_file = os.path.join(output_folder, cls + '_plot.pdf')
|
187 |
+
os.makedirs(os.path.dirname(out_file), exist_ok=True)
|
188 |
+
plt.savefig(out_file)
|
189 |
+
plt.savefig(out_file.replace('.pdf', '.png'))
|
190 |
+
plt.clf()
|
avism/data/aviseval/metrics/clear.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
from scipy.optimize import linear_sum_assignment
|
4 |
+
from ._base_metric import _BaseMetric
|
5 |
+
from .. import _timing
|
6 |
+
from .. import utils
|
7 |
+
|
8 |
+
class CLEAR(_BaseMetric):
|
9 |
+
"""Class which implements the CLEAR metrics"""
|
10 |
+
|
11 |
+
@staticmethod
|
12 |
+
def get_default_config():
|
13 |
+
"""Default class config values"""
|
14 |
+
default_config = {
|
15 |
+
'THRESHOLD': 0.5, # Similarity score threshold required for a TP match. Default 0.5.
|
16 |
+
'PRINT_CONFIG': True, # Whether to print the config information on init. Default: False.
|
17 |
+
}
|
18 |
+
return default_config
|
19 |
+
|
20 |
+
def __init__(self, config=None):
|
21 |
+
super().__init__()
|
22 |
+
main_integer_fields = ['CLR_TP', 'CLR_FN', 'CLR_FP', 'IDSW', 'MT', 'PT', 'ML', 'Frag']
|
23 |
+
extra_integer_fields = ['CLR_Frames']
|
24 |
+
self.integer_fields = main_integer_fields + extra_integer_fields
|
25 |
+
main_float_fields = ['MOTA', 'MOTP', 'MODA', 'CLR_Re', 'CLR_Pr', 'MTR', 'PTR', 'MLR', 'sMOTA']
|
26 |
+
extra_float_fields = ['CLR_F1', 'FP_per_frame', 'MOTAL', 'MOTP_sum']
|
27 |
+
self.float_fields = main_float_fields + extra_float_fields
|
28 |
+
self.fields = self.float_fields + self.integer_fields
|
29 |
+
self.summed_fields = self.integer_fields + ['MOTP_sum']
|
30 |
+
self.summary_fields = main_float_fields + main_integer_fields
|
31 |
+
|
32 |
+
# Configuration options:
|
33 |
+
self.config = utils.init_config(config, self.get_default_config(), self.get_name())
|
34 |
+
self.threshold = float(self.config['THRESHOLD'])
|
35 |
+
|
36 |
+
|
37 |
+
@_timing.time
|
38 |
+
def eval_sequence(self, data):
|
39 |
+
"""Calculates CLEAR metrics for one sequence"""
|
40 |
+
# Initialise results
|
41 |
+
res = {}
|
42 |
+
for field in self.fields:
|
43 |
+
res[field] = 0
|
44 |
+
|
45 |
+
# Return result quickly if tracker or gt sequence is empty
|
46 |
+
if data['num_tracker_dets'] == 0:
|
47 |
+
res['CLR_FN'] = data['num_gt_dets']
|
48 |
+
res['ML'] = data['num_gt_ids']
|
49 |
+
res['MLR'] = 1.0
|
50 |
+
return res
|
51 |
+
if data['num_gt_dets'] == 0:
|
52 |
+
res['CLR_FP'] = data['num_tracker_dets']
|
53 |
+
res['MLR'] = 1.0
|
54 |
+
return res
|
55 |
+
|
56 |
+
# Variables counting global association
|
57 |
+
num_gt_ids = data['num_gt_ids']
|
58 |
+
gt_id_count = np.zeros(num_gt_ids) # For MT/ML/PT
|
59 |
+
gt_matched_count = np.zeros(num_gt_ids) # For MT/ML/PT
|
60 |
+
gt_frag_count = np.zeros(num_gt_ids) # For Frag
|
61 |
+
|
62 |
+
# Note that IDSWs are counted based on the last time each gt_id was present (any number of frames previously),
|
63 |
+
# but are only used in matching to continue current tracks based on the gt_id in the single previous timestep.
|
64 |
+
prev_tracker_id = np.nan * np.zeros(num_gt_ids) # For scoring IDSW
|
65 |
+
prev_timestep_tracker_id = np.nan * np.zeros(num_gt_ids) # For matching IDSW
|
66 |
+
|
67 |
+
# Calculate scores for each timestep
|
68 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
|
69 |
+
# Deal with the case that there are no gt_det/tracker_det in a timestep.
|
70 |
+
if len(gt_ids_t) == 0:
|
71 |
+
res['CLR_FP'] += len(tracker_ids_t)
|
72 |
+
continue
|
73 |
+
if len(tracker_ids_t) == 0:
|
74 |
+
res['CLR_FN'] += len(gt_ids_t)
|
75 |
+
gt_id_count[gt_ids_t] += 1
|
76 |
+
continue
|
77 |
+
|
78 |
+
# Calc score matrix to first minimise IDSWs from previous frame, and then maximise MOTP secondarily
|
79 |
+
similarity = data['similarity_scores'][t]
|
80 |
+
score_mat = (tracker_ids_t[np.newaxis, :] == prev_timestep_tracker_id[gt_ids_t[:, np.newaxis]])
|
81 |
+
score_mat = 1000 * score_mat + similarity
|
82 |
+
score_mat[similarity < self.threshold - np.finfo('float').eps] = 0
|
83 |
+
|
84 |
+
# Hungarian algorithm to find best matches
|
85 |
+
match_rows, match_cols = linear_sum_assignment(-score_mat)
|
86 |
+
actually_matched_mask = score_mat[match_rows, match_cols] > 0 + np.finfo('float').eps
|
87 |
+
match_rows = match_rows[actually_matched_mask]
|
88 |
+
match_cols = match_cols[actually_matched_mask]
|
89 |
+
|
90 |
+
matched_gt_ids = gt_ids_t[match_rows]
|
91 |
+
matched_tracker_ids = tracker_ids_t[match_cols]
|
92 |
+
|
93 |
+
# Calc IDSW for MOTA
|
94 |
+
prev_matched_tracker_ids = prev_tracker_id[matched_gt_ids]
|
95 |
+
is_idsw = (np.logical_not(np.isnan(prev_matched_tracker_ids))) & (
|
96 |
+
np.not_equal(matched_tracker_ids, prev_matched_tracker_ids))
|
97 |
+
res['IDSW'] += np.sum(is_idsw)
|
98 |
+
|
99 |
+
# Update counters for MT/ML/PT/Frag and record for IDSW/Frag for next timestep
|
100 |
+
gt_id_count[gt_ids_t] += 1
|
101 |
+
gt_matched_count[matched_gt_ids] += 1
|
102 |
+
not_previously_tracked = np.isnan(prev_timestep_tracker_id)
|
103 |
+
prev_tracker_id[matched_gt_ids] = matched_tracker_ids
|
104 |
+
prev_timestep_tracker_id[:] = np.nan
|
105 |
+
prev_timestep_tracker_id[matched_gt_ids] = matched_tracker_ids
|
106 |
+
currently_tracked = np.logical_not(np.isnan(prev_timestep_tracker_id))
|
107 |
+
gt_frag_count += np.logical_and(not_previously_tracked, currently_tracked)
|
108 |
+
|
109 |
+
# Calculate and accumulate basic statistics
|
110 |
+
num_matches = len(matched_gt_ids)
|
111 |
+
res['CLR_TP'] += num_matches
|
112 |
+
res['CLR_FN'] += len(gt_ids_t) - num_matches
|
113 |
+
res['CLR_FP'] += len(tracker_ids_t) - num_matches
|
114 |
+
if num_matches > 0:
|
115 |
+
res['MOTP_sum'] += sum(similarity[match_rows, match_cols])
|
116 |
+
|
117 |
+
# Calculate MT/ML/PT/Frag/MOTP
|
118 |
+
tracked_ratio = gt_matched_count[gt_id_count > 0] / gt_id_count[gt_id_count > 0]
|
119 |
+
res['MT'] = np.sum(np.greater(tracked_ratio, 0.8))
|
120 |
+
res['PT'] = np.sum(np.greater_equal(tracked_ratio, 0.2)) - res['MT']
|
121 |
+
res['ML'] = num_gt_ids - res['MT'] - res['PT']
|
122 |
+
res['Frag'] = np.sum(np.subtract(gt_frag_count[gt_frag_count > 0], 1))
|
123 |
+
res['MOTP'] = res['MOTP_sum'] / np.maximum(1.0, res['CLR_TP'])
|
124 |
+
|
125 |
+
res['CLR_Frames'] = data['num_timesteps']
|
126 |
+
|
127 |
+
# Calculate final CLEAR scores
|
128 |
+
res = self._compute_final_fields(res)
|
129 |
+
return res
|
130 |
+
|
131 |
+
def combine_sequences(self, all_res):
|
132 |
+
"""Combines metrics across all sequences"""
|
133 |
+
res = {}
|
134 |
+
for field in self.summed_fields:
|
135 |
+
res[field] = self._combine_sum(all_res, field)
|
136 |
+
res = self._compute_final_fields(res)
|
137 |
+
return res
|
138 |
+
|
139 |
+
def combine_classes_det_averaged(self, all_res):
|
140 |
+
"""Combines metrics across all classes by averaging over the detection values"""
|
141 |
+
res = {}
|
142 |
+
for field in self.summed_fields:
|
143 |
+
res[field] = self._combine_sum(all_res, field)
|
144 |
+
res = self._compute_final_fields(res)
|
145 |
+
return res
|
146 |
+
|
147 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
|
148 |
+
"""Combines metrics across all classes by averaging over the class values.
|
149 |
+
If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
|
150 |
+
"""
|
151 |
+
res = {}
|
152 |
+
for field in self.integer_fields:
|
153 |
+
if ignore_empty_classes:
|
154 |
+
res[field] = self._combine_sum(
|
155 |
+
{k: v for k, v in all_res.items() if v['CLR_TP'] + v['CLR_FN'] + v['CLR_FP'] > 0}, field)
|
156 |
+
else:
|
157 |
+
res[field] = self._combine_sum({k: v for k, v in all_res.items()}, field)
|
158 |
+
for field in self.float_fields:
|
159 |
+
if ignore_empty_classes:
|
160 |
+
res[field] = np.mean(
|
161 |
+
[v[field] for v in all_res.values() if v['CLR_TP'] + v['CLR_FN'] + v['CLR_FP'] > 0], axis=0)
|
162 |
+
else:
|
163 |
+
res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
|
164 |
+
return res
|
165 |
+
|
166 |
+
@staticmethod
|
167 |
+
def _compute_final_fields(res):
|
168 |
+
"""Calculate sub-metric ('field') values which only depend on other sub-metric values.
|
169 |
+
This function is used both for both per-sequence calculation, and in combining values across sequences.
|
170 |
+
"""
|
171 |
+
num_gt_ids = res['MT'] + res['ML'] + res['PT']
|
172 |
+
res['MTR'] = res['MT'] / np.maximum(1.0, num_gt_ids)
|
173 |
+
res['MLR'] = res['ML'] / np.maximum(1.0, num_gt_ids)
|
174 |
+
res['PTR'] = res['PT'] / np.maximum(1.0, num_gt_ids)
|
175 |
+
res['CLR_Re'] = res['CLR_TP'] / np.maximum(1.0, res['CLR_TP'] + res['CLR_FN'])
|
176 |
+
res['CLR_Pr'] = res['CLR_TP'] / np.maximum(1.0, res['CLR_TP'] + res['CLR_FP'])
|
177 |
+
res['MODA'] = (res['CLR_TP'] - res['CLR_FP']) / np.maximum(1.0, res['CLR_TP'] + res['CLR_FN'])
|
178 |
+
res['MOTA'] = (res['CLR_TP'] - res['CLR_FP'] - res['IDSW']) / np.maximum(1.0, res['CLR_TP'] + res['CLR_FN'])
|
179 |
+
res['MOTP'] = res['MOTP_sum'] / np.maximum(1.0, res['CLR_TP'])
|
180 |
+
res['sMOTA'] = (res['MOTP_sum'] - res['CLR_FP'] - res['IDSW']) / np.maximum(1.0, res['CLR_TP'] + res['CLR_FN'])
|
181 |
+
|
182 |
+
res['CLR_F1'] = res['CLR_TP'] / np.maximum(1.0, res['CLR_TP'] + 0.5*res['CLR_FN'] + 0.5*res['CLR_FP'])
|
183 |
+
res['FP_per_frame'] = res['CLR_FP'] / np.maximum(1.0, res['CLR_Frames'])
|
184 |
+
safe_log_idsw = np.log10(res['IDSW']) if res['IDSW'] > 0 else res['IDSW']
|
185 |
+
res['MOTAL'] = (res['CLR_TP'] - res['CLR_FP'] - safe_log_idsw) / np.maximum(1.0, res['CLR_TP'] + res['CLR_FN'])
|
186 |
+
return res
|
avism/data/aviseval/metrics/count.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from ._base_metric import _BaseMetric
|
3 |
+
from .. import _timing
|
4 |
+
|
5 |
+
|
6 |
+
class Count(_BaseMetric):
|
7 |
+
"""Class which simply counts the number of tracker and gt detections and ids."""
|
8 |
+
def __init__(self, config=None):
|
9 |
+
super().__init__()
|
10 |
+
self.integer_fields = ['Dets', 'GT_Dets', 'IDs', 'GT_IDs']
|
11 |
+
self.fields = self.integer_fields
|
12 |
+
self.summary_fields = self.fields
|
13 |
+
|
14 |
+
@_timing.time
|
15 |
+
def eval_sequence(self, data):
|
16 |
+
"""Returns counts for one sequence"""
|
17 |
+
# Get results
|
18 |
+
res = {'Dets': data['num_tracker_dets'],
|
19 |
+
'GT_Dets': data['num_gt_dets'],
|
20 |
+
'IDs': data['num_tracker_ids'],
|
21 |
+
'GT_IDs': data['num_gt_ids'],
|
22 |
+
'Frames': data['num_timesteps']}
|
23 |
+
return res
|
24 |
+
|
25 |
+
def combine_sequences(self, all_res):
|
26 |
+
"""Combines metrics across all sequences"""
|
27 |
+
res = {}
|
28 |
+
for field in self.integer_fields:
|
29 |
+
res[field] = self._combine_sum(all_res, field)
|
30 |
+
return res
|
31 |
+
|
32 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=None):
|
33 |
+
"""Combines metrics across all classes by averaging over the class values"""
|
34 |
+
res = {}
|
35 |
+
for field in self.integer_fields:
|
36 |
+
res[field] = self._combine_sum(all_res, field)
|
37 |
+
return res
|
38 |
+
|
39 |
+
def combine_classes_det_averaged(self, all_res):
|
40 |
+
"""Combines metrics across all classes by averaging over the detection values"""
|
41 |
+
res = {}
|
42 |
+
for field in self.integer_fields:
|
43 |
+
res[field] = self._combine_sum(all_res, field)
|
44 |
+
return res
|
avism/data/aviseval/metrics/hota.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from scipy.optimize import linear_sum_assignment
|
4 |
+
from ._base_metric import _BaseMetric
|
5 |
+
from .. import _timing
|
6 |
+
|
7 |
+
|
8 |
+
class HOTA(_BaseMetric):
|
9 |
+
"""Class which implements the HOTA metrics.
|
10 |
+
See: https://link.springer.com/article/10.1007/s11263-020-01375-2
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, config=None):
|
14 |
+
super().__init__()
|
15 |
+
self.plottable = True
|
16 |
+
self.array_labels = np.arange(0.05, 0.99, 0.05)
|
17 |
+
self.integer_array_fields = ['HOTA_TP', 'HOTA_FN', 'HOTA_FP']
|
18 |
+
self.float_array_fields = ['HOTA', 'DetA', 'AssA', 'DetRe', 'DetPr', 'AssRe', 'AssPr', 'LocA', 'OWTA']
|
19 |
+
self.float_fields = ['HOTA(0)', 'LocA(0)', 'HOTALocA(0)']
|
20 |
+
self.fields = self.float_array_fields + self.integer_array_fields + self.float_fields
|
21 |
+
self.summary_fields = self.float_array_fields + self.float_fields
|
22 |
+
|
23 |
+
@_timing.time
|
24 |
+
def eval_sequence(self, data):
|
25 |
+
"""Calculates the HOTA metrics for one sequence"""
|
26 |
+
|
27 |
+
# Initialise results
|
28 |
+
res = {}
|
29 |
+
for field in self.float_array_fields + self.integer_array_fields:
|
30 |
+
res[field] = np.zeros((len(self.array_labels)), dtype=float)
|
31 |
+
for field in self.float_fields:
|
32 |
+
res[field] = 0
|
33 |
+
|
34 |
+
# Return result quickly if tracker or gt sequence is empty
|
35 |
+
if data['num_tracker_dets'] == 0:
|
36 |
+
res['HOTA_FN'] = data['num_gt_dets'] * np.ones((len(self.array_labels)), dtype=float)
|
37 |
+
res['LocA'] = np.ones((len(self.array_labels)), dtype=float)
|
38 |
+
res['LocA(0)'] = 1.0
|
39 |
+
return res
|
40 |
+
if data['num_gt_dets'] == 0:
|
41 |
+
res['HOTA_FP'] = data['num_tracker_dets'] * np.ones((len(self.array_labels)), dtype=float)
|
42 |
+
res['LocA'] = np.ones((len(self.array_labels)), dtype=float)
|
43 |
+
res['LocA(0)'] = 1.0
|
44 |
+
return res
|
45 |
+
|
46 |
+
# Variables counting global association
|
47 |
+
potential_matches_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
|
48 |
+
gt_id_count = np.zeros((data['num_gt_ids'], 1))
|
49 |
+
tracker_id_count = np.zeros((1, data['num_tracker_ids']))
|
50 |
+
|
51 |
+
# First loop through each timestep and accumulate global track information.
|
52 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
|
53 |
+
# Count the potential matches between ids in each timestep
|
54 |
+
# These are normalised, weighted by the match similarity.
|
55 |
+
similarity = data['similarity_scores'][t]
|
56 |
+
sim_iou_denom = similarity.sum(0)[np.newaxis, :] + similarity.sum(1)[:, np.newaxis] - similarity
|
57 |
+
sim_iou = np.zeros_like(similarity)
|
58 |
+
sim_iou_mask = sim_iou_denom > 0 + np.finfo('float').eps
|
59 |
+
sim_iou[sim_iou_mask] = similarity[sim_iou_mask] / sim_iou_denom[sim_iou_mask]
|
60 |
+
potential_matches_count[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] += sim_iou
|
61 |
+
|
62 |
+
# Calculate the total number of dets for each gt_id and tracker_id.
|
63 |
+
gt_id_count[gt_ids_t] += 1
|
64 |
+
tracker_id_count[0, tracker_ids_t] += 1
|
65 |
+
|
66 |
+
# Calculate overall jaccard alignment score (before unique matching) between IDs
|
67 |
+
global_alignment_score = potential_matches_count / (gt_id_count + tracker_id_count - potential_matches_count)
|
68 |
+
matches_counts = [np.zeros_like(potential_matches_count) for _ in self.array_labels]
|
69 |
+
|
70 |
+
# Calculate scores for each timestep
|
71 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
|
72 |
+
# Deal with the case that there are no gt_det/tracker_det in a timestep.
|
73 |
+
if len(gt_ids_t) == 0:
|
74 |
+
for a, alpha in enumerate(self.array_labels):
|
75 |
+
res['HOTA_FP'][a] += len(tracker_ids_t)
|
76 |
+
continue
|
77 |
+
if len(tracker_ids_t) == 0:
|
78 |
+
for a, alpha in enumerate(self.array_labels):
|
79 |
+
res['HOTA_FN'][a] += len(gt_ids_t)
|
80 |
+
continue
|
81 |
+
|
82 |
+
# Get matching scores between pairs of dets for optimizing HOTA
|
83 |
+
similarity = data['similarity_scores'][t]
|
84 |
+
score_mat = global_alignment_score[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] * similarity
|
85 |
+
|
86 |
+
# Hungarian algorithm to find best matches: 找出最优跟踪路线
|
87 |
+
match_rows, match_cols = linear_sum_assignment(-score_mat)
|
88 |
+
|
89 |
+
# Calculate and accumulate basic statistics
|
90 |
+
for a, alpha in enumerate(self.array_labels):
|
91 |
+
actually_matched_mask = similarity[match_rows, match_cols] >= alpha - np.finfo('float').eps
|
92 |
+
alpha_match_rows = match_rows[actually_matched_mask]
|
93 |
+
alpha_match_cols = match_cols[actually_matched_mask]
|
94 |
+
num_matches = len(alpha_match_rows)
|
95 |
+
res['HOTA_TP'][a] += num_matches
|
96 |
+
res['HOTA_FN'][a] += len(gt_ids_t) - num_matches
|
97 |
+
res['HOTA_FP'][a] += len(tracker_ids_t) - num_matches
|
98 |
+
if num_matches > 0:
|
99 |
+
res['LocA'][a] += sum(similarity[alpha_match_rows, alpha_match_cols])
|
100 |
+
matches_counts[a][gt_ids_t[alpha_match_rows], tracker_ids_t[alpha_match_cols]] += 1
|
101 |
+
|
102 |
+
# Calculate association scores (AssA, AssRe, AssPr) for the alpha value.
|
103 |
+
# First calculate scores per gt_id/tracker_id combo and then average over the number of detections.
|
104 |
+
for a, alpha in enumerate(self.array_labels):
|
105 |
+
matches_count = matches_counts[a]
|
106 |
+
ass_a = matches_count / np.maximum(1, gt_id_count + tracker_id_count - matches_count)
|
107 |
+
res['AssA'][a] = np.sum(matches_count * ass_a) / np.maximum(1, res['HOTA_TP'][a])
|
108 |
+
ass_re = matches_count / np.maximum(1, gt_id_count)
|
109 |
+
res['AssRe'][a] = np.sum(matches_count * ass_re) / np.maximum(1, res['HOTA_TP'][a])
|
110 |
+
ass_pr = matches_count / np.maximum(1, tracker_id_count)
|
111 |
+
res['AssPr'][a] = np.sum(matches_count * ass_pr) / np.maximum(1, res['HOTA_TP'][a])
|
112 |
+
|
113 |
+
# Calculate final scores
|
114 |
+
res['LocA'] = np.maximum(1e-10, res['LocA']) / np.maximum(1e-10, res['HOTA_TP'])
|
115 |
+
res = self._compute_final_fields(res)
|
116 |
+
return res
|
117 |
+
|
118 |
+
def combine_sequences(self, all_res):
|
119 |
+
"""Combines metrics across all sequences"""
|
120 |
+
res = {}
|
121 |
+
for field in self.integer_array_fields:
|
122 |
+
res[field] = self._combine_sum(all_res, field)
|
123 |
+
for field in ['AssRe', 'AssPr', 'AssA']:
|
124 |
+
res[field] = self._combine_weighted_av(all_res, field, res, weight_field='HOTA_TP')
|
125 |
+
loca_weighted_sum = sum([all_res[k]['LocA'] * all_res[k]['HOTA_TP'] for k in all_res.keys()])
|
126 |
+
res['LocA'] = np.maximum(1e-10, loca_weighted_sum) / np.maximum(1e-10, res['HOTA_TP'])
|
127 |
+
res = self._compute_final_fields(res)
|
128 |
+
return res
|
129 |
+
|
130 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
|
131 |
+
"""Combines metrics across all classes by averaging over the class values.
|
132 |
+
If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
|
133 |
+
"""
|
134 |
+
res = {}
|
135 |
+
for field in self.integer_array_fields:
|
136 |
+
if ignore_empty_classes:
|
137 |
+
res[field] = self._combine_sum(
|
138 |
+
{k: v for k, v in all_res.items()
|
139 |
+
if (v['HOTA_TP'] + v['HOTA_FN'] + v['HOTA_FP'] > 0 + np.finfo('float').eps).any()}, field)
|
140 |
+
else:
|
141 |
+
res[field] = self._combine_sum({k: v for k, v in all_res.items()}, field)
|
142 |
+
|
143 |
+
for field in self.float_fields + self.float_array_fields:
|
144 |
+
if ignore_empty_classes:
|
145 |
+
res[field] = np.mean([v[field] for v in all_res.values() if
|
146 |
+
(v['HOTA_TP'] + v['HOTA_FN'] + v['HOTA_FP'] > 0 + np.finfo('float').eps).any()],
|
147 |
+
axis=0)
|
148 |
+
else:
|
149 |
+
res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
|
150 |
+
return res
|
151 |
+
|
152 |
+
def combine_classes_det_averaged(self, all_res):
|
153 |
+
"""Combines metrics across all classes by averaging over the detection values"""
|
154 |
+
res = {}
|
155 |
+
for field in self.integer_array_fields:
|
156 |
+
res[field] = self._combine_sum(all_res, field)
|
157 |
+
for field in ['AssRe', 'AssPr', 'AssA']:
|
158 |
+
res[field] = self._combine_weighted_av(all_res, field, res, weight_field='HOTA_TP')
|
159 |
+
loca_weighted_sum = sum([all_res[k]['LocA'] * all_res[k]['HOTA_TP'] for k in all_res.keys()])
|
160 |
+
res['LocA'] = np.maximum(1e-10, loca_weighted_sum) / np.maximum(1e-10, res['HOTA_TP'])
|
161 |
+
res = self._compute_final_fields(res)
|
162 |
+
return res
|
163 |
+
|
164 |
+
@staticmethod
|
165 |
+
def _compute_final_fields(res):
|
166 |
+
"""Calculate sub-metric ('field') values which only depend on other sub-metric values.
|
167 |
+
This function is used both for both per-sequence calculation, and in combining values across sequences.
|
168 |
+
"""
|
169 |
+
res['DetRe'] = res['HOTA_TP'] / np.maximum(1, res['HOTA_TP'] + res['HOTA_FN'])
|
170 |
+
res['DetPr'] = res['HOTA_TP'] / np.maximum(1, res['HOTA_TP'] + res['HOTA_FP'])
|
171 |
+
res['DetA'] = res['HOTA_TP'] / np.maximum(1, res['HOTA_TP'] + res['HOTA_FN'] + res['HOTA_FP'])
|
172 |
+
res['HOTA'] = np.sqrt(res['DetA'] * res['AssA'])
|
173 |
+
res['OWTA'] = np.sqrt(res['DetRe'] * res['AssA'])
|
174 |
+
|
175 |
+
res['HOTA(0)'] = res['HOTA'][0]
|
176 |
+
res['LocA(0)'] = res['LocA'][0]
|
177 |
+
res['HOTALocA(0)'] = res['HOTA(0)']*res['LocA(0)']
|
178 |
+
return res
|
179 |
+
|
180 |
+
def plot_single_tracker_results(self, table_res, tracker, cls, output_folder):
|
181 |
+
"""Create plot of results"""
|
182 |
+
|
183 |
+
# Only loaded when run to reduce minimum requirements
|
184 |
+
from matplotlib import pyplot as plt
|
185 |
+
|
186 |
+
res = table_res['COMBINED_SEQ']
|
187 |
+
styles_to_plot = ['r', 'b', 'g', 'b--', 'b:', 'g--', 'g:', 'm']
|
188 |
+
for name, style in zip(self.float_array_fields, styles_to_plot):
|
189 |
+
plt.plot(self.array_labels, res[name], style)
|
190 |
+
plt.xlabel('alpha')
|
191 |
+
plt.ylabel('score')
|
192 |
+
plt.title(tracker + ' - ' + cls)
|
193 |
+
plt.axis([0, 1, 0, 1])
|
194 |
+
legend = []
|
195 |
+
for name in self.float_array_fields:
|
196 |
+
legend += [name + ' (' + str(np.round(np.mean(res[name]), 2)) + ')']
|
197 |
+
plt.legend(legend, loc='lower left')
|
198 |
+
out_file = os.path.join(output_folder, cls + '_plot.pdf')
|
199 |
+
os.makedirs(os.path.dirname(out_file), exist_ok=True)
|
200 |
+
plt.savefig(out_file)
|
201 |
+
plt.savefig(out_file.replace('.pdf', '.png'))
|
202 |
+
plt.clf()
|
avism/data/aviseval/metrics/identity.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.optimize import linear_sum_assignment
|
3 |
+
from ._base_metric import _BaseMetric
|
4 |
+
from .. import _timing
|
5 |
+
from .. import utils
|
6 |
+
|
7 |
+
|
8 |
+
class Identity(_BaseMetric):
|
9 |
+
"""Class which implements the ID metrics"""
|
10 |
+
|
11 |
+
@staticmethod
|
12 |
+
def get_default_config():
|
13 |
+
"""Default class config values"""
|
14 |
+
default_config = {
|
15 |
+
'THRESHOLD': 0.5, # Similarity score threshold required for a IDTP match. Default 0.5.
|
16 |
+
'PRINT_CONFIG': True, # Whether to print the config information on init. Default: False.
|
17 |
+
}
|
18 |
+
return default_config
|
19 |
+
|
20 |
+
def __init__(self, config=None):
|
21 |
+
super().__init__()
|
22 |
+
self.integer_fields = ['IDTP', 'IDFN', 'IDFP']
|
23 |
+
self.float_fields = ['IDF1', 'IDR', 'IDP']
|
24 |
+
self.fields = self.float_fields + self.integer_fields
|
25 |
+
self.summary_fields = self.fields
|
26 |
+
|
27 |
+
# Configuration options:
|
28 |
+
self.config = utils.init_config(config, self.get_default_config(), self.get_name())
|
29 |
+
self.threshold = float(self.config['THRESHOLD'])
|
30 |
+
|
31 |
+
@_timing.time
|
32 |
+
def eval_sequence(self, data):
|
33 |
+
"""Calculates ID metrics for one sequence"""
|
34 |
+
# Initialise results
|
35 |
+
res = {}
|
36 |
+
for field in self.fields:
|
37 |
+
res[field] = 0
|
38 |
+
|
39 |
+
# Return result quickly if tracker or gt sequence is empty
|
40 |
+
if data['num_tracker_dets'] == 0:
|
41 |
+
res['IDFN'] = data['num_gt_dets']
|
42 |
+
return res
|
43 |
+
if data['num_gt_dets'] == 0:
|
44 |
+
res['IDFP'] = data['num_tracker_dets']
|
45 |
+
return res
|
46 |
+
|
47 |
+
# Variables counting global association
|
48 |
+
potential_matches_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
|
49 |
+
gt_id_count = np.zeros(data['num_gt_ids'])
|
50 |
+
tracker_id_count = np.zeros(data['num_tracker_ids'])
|
51 |
+
|
52 |
+
# First loop through each timestep and accumulate global track information.
|
53 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
|
54 |
+
# Count the potential matches between ids in each timestep
|
55 |
+
matches_mask = np.greater_equal(data['similarity_scores'][t], self.threshold)
|
56 |
+
match_idx_gt, match_idx_tracker = np.nonzero(matches_mask)
|
57 |
+
potential_matches_count[gt_ids_t[match_idx_gt], tracker_ids_t[match_idx_tracker]] += 1
|
58 |
+
|
59 |
+
# Calculate the total number of dets for each gt_id and tracker_id.
|
60 |
+
gt_id_count[gt_ids_t] += 1
|
61 |
+
tracker_id_count[tracker_ids_t] += 1
|
62 |
+
|
63 |
+
# Calculate optimal assignment cost matrix for ID metrics
|
64 |
+
num_gt_ids = data['num_gt_ids']
|
65 |
+
num_tracker_ids = data['num_tracker_ids']
|
66 |
+
fp_mat = np.zeros((num_gt_ids + num_tracker_ids, num_gt_ids + num_tracker_ids))
|
67 |
+
fn_mat = np.zeros((num_gt_ids + num_tracker_ids, num_gt_ids + num_tracker_ids))
|
68 |
+
fp_mat[num_gt_ids:, :num_tracker_ids] = 1e10
|
69 |
+
fn_mat[:num_gt_ids, num_tracker_ids:] = 1e10
|
70 |
+
for gt_id in range(num_gt_ids):
|
71 |
+
fn_mat[gt_id, :num_tracker_ids] = gt_id_count[gt_id]
|
72 |
+
fn_mat[gt_id, num_tracker_ids + gt_id] = gt_id_count[gt_id]
|
73 |
+
for tracker_id in range(num_tracker_ids):
|
74 |
+
fp_mat[:num_gt_ids, tracker_id] = tracker_id_count[tracker_id]
|
75 |
+
fp_mat[tracker_id + num_gt_ids, tracker_id] = tracker_id_count[tracker_id]
|
76 |
+
fn_mat[:num_gt_ids, :num_tracker_ids] -= potential_matches_count
|
77 |
+
fp_mat[:num_gt_ids, :num_tracker_ids] -= potential_matches_count
|
78 |
+
|
79 |
+
# Hungarian algorithm
|
80 |
+
match_rows, match_cols = linear_sum_assignment(fn_mat + fp_mat)
|
81 |
+
|
82 |
+
# Accumulate basic statistics
|
83 |
+
res['IDFN'] = fn_mat[match_rows, match_cols].sum().astype(int)
|
84 |
+
res['IDFP'] = fp_mat[match_rows, match_cols].sum().astype(int)
|
85 |
+
res['IDTP'] = (gt_id_count.sum() - res['IDFN']).astype(int)
|
86 |
+
|
87 |
+
# Calculate final ID scores
|
88 |
+
res = self._compute_final_fields(res)
|
89 |
+
return res
|
90 |
+
|
91 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
|
92 |
+
"""Combines metrics across all classes by averaging over the class values.
|
93 |
+
If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
|
94 |
+
"""
|
95 |
+
res = {}
|
96 |
+
for field in self.integer_fields:
|
97 |
+
if ignore_empty_classes:
|
98 |
+
res[field] = self._combine_sum({k: v for k, v in all_res.items()
|
99 |
+
if v['IDTP'] + v['IDFN'] + v['IDFP'] > 0 + np.finfo('float').eps},
|
100 |
+
field)
|
101 |
+
else:
|
102 |
+
res[field] = self._combine_sum({k: v for k, v in all_res.items()}, field)
|
103 |
+
for field in self.float_fields:
|
104 |
+
if ignore_empty_classes:
|
105 |
+
res[field] = np.mean([v[field] for v in all_res.values()
|
106 |
+
if v['IDTP'] + v['IDFN'] + v['IDFP'] > 0 + np.finfo('float').eps], axis=0)
|
107 |
+
else:
|
108 |
+
res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
|
109 |
+
return res
|
110 |
+
|
111 |
+
def combine_classes_det_averaged(self, all_res):
|
112 |
+
"""Combines metrics across all classes by averaging over the detection values"""
|
113 |
+
res = {}
|
114 |
+
for field in self.integer_fields:
|
115 |
+
res[field] = self._combine_sum(all_res, field)
|
116 |
+
res = self._compute_final_fields(res)
|
117 |
+
return res
|
118 |
+
|
119 |
+
def combine_sequences(self, all_res):
|
120 |
+
"""Combines metrics across all sequences"""
|
121 |
+
res = {}
|
122 |
+
for field in self.integer_fields:
|
123 |
+
res[field] = self._combine_sum(all_res, field)
|
124 |
+
res = self._compute_final_fields(res)
|
125 |
+
return res
|
126 |
+
|
127 |
+
@staticmethod
|
128 |
+
def _compute_final_fields(res):
|
129 |
+
"""Calculate sub-metric ('field') values which only depend on other sub-metric values.
|
130 |
+
This function is used both for both per-sequence calculation, and in combining values across sequences.
|
131 |
+
"""
|
132 |
+
res['IDR'] = res['IDTP'] / np.maximum(1.0, res['IDTP'] + res['IDFN'])
|
133 |
+
res['IDP'] = res['IDTP'] / np.maximum(1.0, res['IDTP'] + res['IDFP'])
|
134 |
+
res['IDF1'] = res['IDTP'] / np.maximum(1.0, res['IDTP'] + 0.5 * res['IDFP'] + 0.5 * res['IDFN'])
|
135 |
+
return res
|
avism/data/aviseval/metrics/ideucl.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.optimize import linear_sum_assignment
|
3 |
+
from ._base_metric import _BaseMetric
|
4 |
+
from .. import _timing
|
5 |
+
from collections import defaultdict
|
6 |
+
from .. import utils
|
7 |
+
|
8 |
+
|
9 |
+
class IDEucl(_BaseMetric):
|
10 |
+
"""Class which implements the ID metrics"""
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def get_default_config():
|
14 |
+
"""Default class config values"""
|
15 |
+
default_config = {
|
16 |
+
'THRESHOLD': 0.4, # Similarity score threshold required for a IDTP match. 0.4 for IDEucl.
|
17 |
+
'PRINT_CONFIG': True, # Whether to print the config information on init. Default: False.
|
18 |
+
}
|
19 |
+
return default_config
|
20 |
+
|
21 |
+
def __init__(self, config=None):
|
22 |
+
super().__init__()
|
23 |
+
self.fields = ['IDEucl']
|
24 |
+
self.float_fields = self.fields
|
25 |
+
self.summary_fields = self.fields
|
26 |
+
|
27 |
+
# Configuration options:
|
28 |
+
self.config = utils.init_config(config, self.get_default_config(), self.get_name())
|
29 |
+
self.threshold = float(self.config['THRESHOLD'])
|
30 |
+
|
31 |
+
|
32 |
+
@_timing.time
|
33 |
+
def eval_sequence(self, data):
|
34 |
+
"""Calculates IDEucl metrics for all frames"""
|
35 |
+
# Initialise results
|
36 |
+
res = {'IDEucl' : 0}
|
37 |
+
|
38 |
+
# Return result quickly if tracker or gt sequence is empty
|
39 |
+
if data['num_tracker_dets'] == 0 or data['num_gt_dets'] == 0.:
|
40 |
+
return res
|
41 |
+
|
42 |
+
data['centroid'] = []
|
43 |
+
for t, gt_det in enumerate(data['gt_dets']):
|
44 |
+
# import pdb;pdb.set_trace()
|
45 |
+
data['centroid'].append(self._compute_centroid(gt_det))
|
46 |
+
|
47 |
+
oid_hid_cent = defaultdict(list)
|
48 |
+
oid_cent = defaultdict(list)
|
49 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
|
50 |
+
matches_mask = np.greater_equal(data['similarity_scores'][t], self.threshold)
|
51 |
+
|
52 |
+
# I hope the orders of ids and boxes are maintained in `data`
|
53 |
+
for ind, gid in enumerate(gt_ids_t):
|
54 |
+
oid_cent[gid].append(data['centroid'][t][ind])
|
55 |
+
|
56 |
+
match_idx_gt, match_idx_tracker = np.nonzero(matches_mask)
|
57 |
+
for m_gid, m_tid in zip(match_idx_gt, match_idx_tracker):
|
58 |
+
oid_hid_cent[gt_ids_t[m_gid], tracker_ids_t[m_tid]].append(data['centroid'][t][m_gid])
|
59 |
+
|
60 |
+
oid_hid_dist = {k : np.sum(np.linalg.norm(np.diff(np.array(v), axis=0), axis=1)) for k, v in oid_hid_cent.items()}
|
61 |
+
oid_dist = {int(k) : np.sum(np.linalg.norm(np.diff(np.array(v), axis=0), axis=1)) for k, v in oid_cent.items()}
|
62 |
+
|
63 |
+
unique_oid = np.unique([i[0] for i in oid_hid_dist.keys()]).tolist()
|
64 |
+
unique_hid = np.unique([i[1] for i in oid_hid_dist.keys()]).tolist()
|
65 |
+
o_len = len(unique_oid)
|
66 |
+
h_len = len(unique_hid)
|
67 |
+
dist_matrix = np.zeros((o_len, h_len))
|
68 |
+
for ((oid, hid), dist) in oid_hid_dist.items():
|
69 |
+
oid_ind = unique_oid.index(oid)
|
70 |
+
hid_ind = unique_hid.index(hid)
|
71 |
+
dist_matrix[oid_ind, hid_ind] = dist
|
72 |
+
|
73 |
+
# opt_hyp_dist contains GT ID : max dist covered by track
|
74 |
+
opt_hyp_dist = dict.fromkeys(oid_dist.keys(), 0.)
|
75 |
+
cost_matrix = np.max(dist_matrix) - dist_matrix
|
76 |
+
rows, cols = linear_sum_assignment(cost_matrix)
|
77 |
+
for (row, col) in zip(rows, cols):
|
78 |
+
value = dist_matrix[row, col]
|
79 |
+
opt_hyp_dist[int(unique_oid[row])] = value
|
80 |
+
|
81 |
+
assert len(opt_hyp_dist.keys()) == len(oid_dist.keys())
|
82 |
+
hyp_length = np.sum(list(opt_hyp_dist.values()))
|
83 |
+
gt_length = np.sum(list(oid_dist.values()))
|
84 |
+
id_eucl =np.mean([np.divide(a, b, out=np.zeros_like(a), where=b!=0) for a, b in zip(opt_hyp_dist.values(), oid_dist.values())])
|
85 |
+
res['IDEucl'] = np.divide(hyp_length, gt_length, out=np.zeros_like(hyp_length), where=gt_length!=0)
|
86 |
+
return res
|
87 |
+
|
88 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
|
89 |
+
"""Combines metrics across all classes by averaging over the class values.
|
90 |
+
If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
|
91 |
+
"""
|
92 |
+
res = {}
|
93 |
+
|
94 |
+
for field in self.float_fields:
|
95 |
+
if ignore_empty_classes:
|
96 |
+
res[field] = np.mean([v[field] for v in all_res.values()
|
97 |
+
if v['IDEucl'] > 0 + np.finfo('float').eps], axis=0)
|
98 |
+
else:
|
99 |
+
res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
|
100 |
+
return res
|
101 |
+
|
102 |
+
def combine_classes_det_averaged(self, all_res):
|
103 |
+
"""Combines metrics across all classes by averaging over the detection values"""
|
104 |
+
res = {}
|
105 |
+
for field in self.float_fields:
|
106 |
+
res[field] = self._combine_sum(all_res, field)
|
107 |
+
res = self._compute_final_fields(res, len(all_res))
|
108 |
+
return res
|
109 |
+
|
110 |
+
def combine_sequences(self, all_res):
|
111 |
+
"""Combines metrics across all sequences"""
|
112 |
+
res = {}
|
113 |
+
for field in self.float_fields:
|
114 |
+
res[field] = self._combine_sum(all_res, field)
|
115 |
+
res = self._compute_final_fields(res, len(all_res))
|
116 |
+
return res
|
117 |
+
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def _compute_centroid(box):
|
121 |
+
box = np.array(box)
|
122 |
+
if len(box.shape) == 1:
|
123 |
+
centroid = (box[0:2] + box[2:4])/2
|
124 |
+
else:
|
125 |
+
centroid = (box[:, 0:2] + box[:, 2:4])/2
|
126 |
+
return np.flip(centroid, axis=1)
|
127 |
+
|
128 |
+
|
129 |
+
@staticmethod
|
130 |
+
def _compute_final_fields(res, res_len):
|
131 |
+
"""
|
132 |
+
Exists only to match signature with the original Identiy class.
|
133 |
+
|
134 |
+
"""
|
135 |
+
return {k:v/res_len for k,v in res.items()}
|
avism/data/aviseval/metrics/j_and_f.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
from scipy.optimize import linear_sum_assignment
|
5 |
+
from ..utils import TrackEvalException
|
6 |
+
from ._base_metric import _BaseMetric
|
7 |
+
from .. import _timing
|
8 |
+
|
9 |
+
|
10 |
+
class JAndF(_BaseMetric):
|
11 |
+
"""Class which implements the J&F metrics"""
|
12 |
+
def __init__(self, config=None):
|
13 |
+
super().__init__()
|
14 |
+
self.integer_fields = ['num_gt_tracks']
|
15 |
+
self.float_fields = ['J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay', 'J&F']
|
16 |
+
self.fields = self.float_fields + self.integer_fields
|
17 |
+
self.summary_fields = self.float_fields
|
18 |
+
self.optim_type = 'J' # possible values J, J&F
|
19 |
+
|
20 |
+
@_timing.time
|
21 |
+
def eval_sequence(self, data):
|
22 |
+
"""Returns J&F metrics for one sequence"""
|
23 |
+
|
24 |
+
# Only loaded when run to reduce minimum requirements
|
25 |
+
from pycocotools import mask as mask_utils
|
26 |
+
|
27 |
+
num_timesteps = data['num_timesteps']
|
28 |
+
num_tracker_ids = data['num_tracker_ids']
|
29 |
+
num_gt_ids = data['num_gt_ids']
|
30 |
+
gt_dets = data['gt_dets']
|
31 |
+
tracker_dets = data['tracker_dets']
|
32 |
+
gt_ids = data['gt_ids']
|
33 |
+
tracker_ids = data['tracker_ids']
|
34 |
+
|
35 |
+
# get shape of frames
|
36 |
+
frame_shape = None
|
37 |
+
if num_gt_ids > 0:
|
38 |
+
for t in range(num_timesteps):
|
39 |
+
if len(gt_ids[t]) > 0:
|
40 |
+
frame_shape = gt_dets[t][0]['size']
|
41 |
+
break
|
42 |
+
elif num_tracker_ids > 0:
|
43 |
+
for t in range(num_timesteps):
|
44 |
+
if len(tracker_ids[t]) > 0:
|
45 |
+
frame_shape = tracker_dets[t][0]['size']
|
46 |
+
break
|
47 |
+
|
48 |
+
if frame_shape:
|
49 |
+
# append all zero masks for timesteps in which tracks do not have a detection
|
50 |
+
zero_padding = np.zeros((frame_shape), order= 'F').astype(np.uint8)
|
51 |
+
padding_mask = mask_utils.encode(zero_padding)
|
52 |
+
for t in range(num_timesteps):
|
53 |
+
gt_id_det_mapping = {gt_ids[t][i]: gt_dets[t][i] for i in range(len(gt_ids[t]))}
|
54 |
+
gt_dets[t] = [gt_id_det_mapping[index] if index in gt_ids[t] else padding_mask for index
|
55 |
+
in range(num_gt_ids)]
|
56 |
+
tracker_id_det_mapping = {tracker_ids[t][i]: tracker_dets[t][i] for i in range(len(tracker_ids[t]))}
|
57 |
+
tracker_dets[t] = [tracker_id_det_mapping[index] if index in tracker_ids[t] else padding_mask for index
|
58 |
+
in range(num_tracker_ids)]
|
59 |
+
# also perform zero padding if number of tracker IDs < number of ground truth IDs
|
60 |
+
if num_tracker_ids < num_gt_ids:
|
61 |
+
diff = num_gt_ids - num_tracker_ids
|
62 |
+
for t in range(num_timesteps):
|
63 |
+
tracker_dets[t] = tracker_dets[t] + [padding_mask for _ in range(diff)]
|
64 |
+
num_tracker_ids += diff
|
65 |
+
|
66 |
+
j = self._compute_j(gt_dets, tracker_dets, num_gt_ids, num_tracker_ids, num_timesteps)
|
67 |
+
|
68 |
+
# boundary threshold for F computation
|
69 |
+
bound_th = 0.008
|
70 |
+
|
71 |
+
# perform matching
|
72 |
+
if self.optim_type == 'J&F':
|
73 |
+
f = np.zeros_like(j)
|
74 |
+
for k in range(num_tracker_ids):
|
75 |
+
for i in range(num_gt_ids):
|
76 |
+
f[k, i, :] = self._compute_f(gt_dets, tracker_dets, k, i, bound_th)
|
77 |
+
optim_metrics = (np.mean(j, axis=2) + np.mean(f, axis=2)) / 2
|
78 |
+
row_ind, col_ind = linear_sum_assignment(- optim_metrics)
|
79 |
+
j_m = j[row_ind, col_ind, :]
|
80 |
+
f_m = f[row_ind, col_ind, :]
|
81 |
+
elif self.optim_type == 'J':
|
82 |
+
optim_metrics = np.mean(j, axis=2)
|
83 |
+
row_ind, col_ind = linear_sum_assignment(- optim_metrics)
|
84 |
+
j_m = j[row_ind, col_ind, :]
|
85 |
+
f_m = np.zeros_like(j_m)
|
86 |
+
for i, (tr_ind, gt_ind) in enumerate(zip(row_ind, col_ind)):
|
87 |
+
f_m[i] = self._compute_f(gt_dets, tracker_dets, tr_ind, gt_ind, bound_th)
|
88 |
+
else:
|
89 |
+
raise TrackEvalException('Unsupported optimization type %s for J&F metric.' % self.optim_type)
|
90 |
+
|
91 |
+
# append zeros for false negatives
|
92 |
+
if j_m.shape[0] < data['num_gt_ids']:
|
93 |
+
diff = data['num_gt_ids'] - j_m.shape[0]
|
94 |
+
j_m = np.concatenate((j_m, np.zeros((diff, j_m.shape[1]))), axis=0)
|
95 |
+
f_m = np.concatenate((f_m, np.zeros((diff, f_m.shape[1]))), axis=0)
|
96 |
+
|
97 |
+
# compute the metrics for each ground truth track
|
98 |
+
res = {
|
99 |
+
'J-Mean': [np.nanmean(j_m[i, :]) for i in range(j_m.shape[0])],
|
100 |
+
'J-Recall': [np.nanmean(j_m[i, :] > 0.5 + np.finfo('float').eps) for i in range(j_m.shape[0])],
|
101 |
+
'F-Mean': [np.nanmean(f_m[i, :]) for i in range(f_m.shape[0])],
|
102 |
+
'F-Recall': [np.nanmean(f_m[i, :] > 0.5 + np.finfo('float').eps) for i in range(f_m.shape[0])],
|
103 |
+
'J-Decay': [],
|
104 |
+
'F-Decay': []
|
105 |
+
}
|
106 |
+
n_bins = 4
|
107 |
+
ids = np.round(np.linspace(1, data['num_timesteps'], n_bins + 1) + 1e-10) - 1
|
108 |
+
ids = ids.astype(np.uint8)
|
109 |
+
|
110 |
+
for k in range(j_m.shape[0]):
|
111 |
+
d_bins_j = [j_m[k][ids[i]:ids[i + 1] + 1] for i in range(0, n_bins)]
|
112 |
+
res['J-Decay'].append(np.nanmean(d_bins_j[0]) - np.nanmean(d_bins_j[3]))
|
113 |
+
for k in range(f_m.shape[0]):
|
114 |
+
d_bins_f = [f_m[k][ids[i]:ids[i + 1] + 1] for i in range(0, n_bins)]
|
115 |
+
res['F-Decay'].append(np.nanmean(d_bins_f[0]) - np.nanmean(d_bins_f[3]))
|
116 |
+
|
117 |
+
# count number of tracks for weighting of the result
|
118 |
+
res['num_gt_tracks'] = len(res['J-Mean'])
|
119 |
+
for field in ['J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay']:
|
120 |
+
res[field] = np.mean(res[field])
|
121 |
+
res['J&F'] = (res['J-Mean'] + res['F-Mean']) / 2
|
122 |
+
return res
|
123 |
+
|
124 |
+
def combine_sequences(self, all_res):
|
125 |
+
"""Combines metrics across all sequences"""
|
126 |
+
res = {'num_gt_tracks': self._combine_sum(all_res, 'num_gt_tracks')}
|
127 |
+
for field in self.summary_fields:
|
128 |
+
res[field] = self._combine_weighted_av(all_res, field, res, weight_field='num_gt_tracks')
|
129 |
+
return res
|
130 |
+
|
131 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
|
132 |
+
"""Combines metrics across all classes by averaging over the class values
|
133 |
+
'ignore empty classes' is not yet implemented here.
|
134 |
+
"""
|
135 |
+
res = {'num_gt_tracks': self._combine_sum(all_res, 'num_gt_tracks')}
|
136 |
+
for field in self.float_fields:
|
137 |
+
res[field] = np.mean([v[field] for v in all_res.values()])
|
138 |
+
return res
|
139 |
+
|
140 |
+
def combine_classes_det_averaged(self, all_res):
|
141 |
+
"""Combines metrics across all classes by averaging over the detection values"""
|
142 |
+
res = {'num_gt_tracks': self._combine_sum(all_res, 'num_gt_tracks')}
|
143 |
+
for field in self.float_fields:
|
144 |
+
res[field] = np.mean([v[field] for v in all_res.values()])
|
145 |
+
return res
|
146 |
+
|
147 |
+
@staticmethod
|
148 |
+
def _seg2bmap(seg, width=None, height=None):
|
149 |
+
"""
|
150 |
+
From a segmentation, compute a binary boundary map with 1 pixel wide
|
151 |
+
boundaries. The boundary pixels are offset by 1/2 pixel towards the
|
152 |
+
origin from the actual segment boundary.
|
153 |
+
Arguments:
|
154 |
+
seg : Segments labeled from 1..k.
|
155 |
+
width : Width of desired bmap <= seg.shape[1]
|
156 |
+
height : Height of desired bmap <= seg.shape[0]
|
157 |
+
Returns:
|
158 |
+
bmap (ndarray): Binary boundary map.
|
159 |
+
David Martin <[email protected]>
|
160 |
+
January 2003
|
161 |
+
"""
|
162 |
+
|
163 |
+
seg = seg.astype(bool)
|
164 |
+
seg[seg > 0] = 1
|
165 |
+
|
166 |
+
assert np.atleast_3d(seg).shape[2] == 1
|
167 |
+
|
168 |
+
width = seg.shape[1] if width is None else width
|
169 |
+
height = seg.shape[0] if height is None else height
|
170 |
+
|
171 |
+
h, w = seg.shape[:2]
|
172 |
+
|
173 |
+
ar1 = float(width) / float(height)
|
174 |
+
ar2 = float(w) / float(h)
|
175 |
+
|
176 |
+
assert not (
|
177 |
+
width > w | height > h | abs(ar1 - ar2) > 0.01
|
178 |
+
), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
|
179 |
+
|
180 |
+
e = np.zeros_like(seg)
|
181 |
+
s = np.zeros_like(seg)
|
182 |
+
se = np.zeros_like(seg)
|
183 |
+
|
184 |
+
e[:, :-1] = seg[:, 1:]
|
185 |
+
s[:-1, :] = seg[1:, :]
|
186 |
+
se[:-1, :-1] = seg[1:, 1:]
|
187 |
+
|
188 |
+
b = seg ^ e | seg ^ s | seg ^ se
|
189 |
+
b[-1, :] = seg[-1, :] ^ e[-1, :]
|
190 |
+
b[:, -1] = seg[:, -1] ^ s[:, -1]
|
191 |
+
b[-1, -1] = 0
|
192 |
+
|
193 |
+
if w == width and h == height:
|
194 |
+
bmap = b
|
195 |
+
else:
|
196 |
+
bmap = np.zeros((height, width))
|
197 |
+
for x in range(w):
|
198 |
+
for y in range(h):
|
199 |
+
if b[y, x]:
|
200 |
+
j = 1 + math.floor((y - 1) + height / h)
|
201 |
+
i = 1 + math.floor((x - 1) + width / h)
|
202 |
+
bmap[j, i] = 1
|
203 |
+
|
204 |
+
return bmap
|
205 |
+
|
206 |
+
@staticmethod
|
207 |
+
def _compute_f(gt_data, tracker_data, tracker_data_id, gt_id, bound_th):
|
208 |
+
"""
|
209 |
+
Perform F computation for a given gt and a given tracker ID. Adapted from
|
210 |
+
https://github.com/davisvideochallenge/davis2017-evaluation
|
211 |
+
:param gt_data: the encoded gt masks
|
212 |
+
:param tracker_data: the encoded tracker masks
|
213 |
+
:param tracker_data_id: the tracker ID
|
214 |
+
:param gt_id: the ground truth ID
|
215 |
+
:param bound_th: boundary threshold parameter
|
216 |
+
:return: the F value for the given tracker and gt ID
|
217 |
+
"""
|
218 |
+
|
219 |
+
# Only loaded when run to reduce minimum requirements
|
220 |
+
from pycocotools import mask as mask_utils
|
221 |
+
from skimage.morphology import disk
|
222 |
+
import cv2
|
223 |
+
|
224 |
+
f = np.zeros(len(gt_data))
|
225 |
+
|
226 |
+
for t, (gt_masks, tracker_masks) in enumerate(zip(gt_data, tracker_data)):
|
227 |
+
curr_tracker_mask = mask_utils.decode(tracker_masks[tracker_data_id])
|
228 |
+
curr_gt_mask = mask_utils.decode(gt_masks[gt_id])
|
229 |
+
|
230 |
+
bound_pix = bound_th if bound_th >= 1 - np.finfo('float').eps else \
|
231 |
+
np.ceil(bound_th * np.linalg.norm(curr_tracker_mask.shape))
|
232 |
+
|
233 |
+
# Get the pixel boundaries of both masks
|
234 |
+
fg_boundary = JAndF._seg2bmap(curr_tracker_mask)
|
235 |
+
gt_boundary = JAndF._seg2bmap(curr_gt_mask)
|
236 |
+
|
237 |
+
# fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
|
238 |
+
fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
|
239 |
+
# gt_dil = binary_dilation(gt_boundary, disk(bound_pix))
|
240 |
+
gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
|
241 |
+
|
242 |
+
# Get the intersection
|
243 |
+
gt_match = gt_boundary * fg_dil
|
244 |
+
fg_match = fg_boundary * gt_dil
|
245 |
+
|
246 |
+
# Area of the intersection
|
247 |
+
n_fg = np.sum(fg_boundary)
|
248 |
+
n_gt = np.sum(gt_boundary)
|
249 |
+
|
250 |
+
# % Compute precision and recall
|
251 |
+
if n_fg == 0 and n_gt > 0:
|
252 |
+
precision = 1
|
253 |
+
recall = 0
|
254 |
+
elif n_fg > 0 and n_gt == 0:
|
255 |
+
precision = 0
|
256 |
+
recall = 1
|
257 |
+
elif n_fg == 0 and n_gt == 0:
|
258 |
+
precision = 1
|
259 |
+
recall = 1
|
260 |
+
else:
|
261 |
+
precision = np.sum(fg_match) / float(n_fg)
|
262 |
+
recall = np.sum(gt_match) / float(n_gt)
|
263 |
+
|
264 |
+
# Compute F measure
|
265 |
+
if precision + recall == 0:
|
266 |
+
f_val = 0
|
267 |
+
else:
|
268 |
+
f_val = 2 * precision * recall / (precision + recall)
|
269 |
+
|
270 |
+
f[t] = f_val
|
271 |
+
|
272 |
+
return f
|
273 |
+
|
274 |
+
@staticmethod
|
275 |
+
def _compute_j(gt_data, tracker_data, num_gt_ids, num_tracker_ids, num_timesteps):
|
276 |
+
"""
|
277 |
+
Computation of J value for all ground truth IDs and all tracker IDs in the given sequence. Adapted from
|
278 |
+
https://github.com/davisvideochallenge/davis2017-evaluation
|
279 |
+
:param gt_data: the ground truth masks
|
280 |
+
:param tracker_data: the tracker masks
|
281 |
+
:param num_gt_ids: the number of ground truth IDs
|
282 |
+
:param num_tracker_ids: the number of tracker IDs
|
283 |
+
:param num_timesteps: the number of timesteps
|
284 |
+
:return: the J values
|
285 |
+
"""
|
286 |
+
|
287 |
+
# Only loaded when run to reduce minimum requirements
|
288 |
+
from pycocotools import mask as mask_utils
|
289 |
+
|
290 |
+
j = np.zeros((num_tracker_ids, num_gt_ids, num_timesteps))
|
291 |
+
|
292 |
+
for t, (time_gt, time_data) in enumerate(zip(gt_data, tracker_data)):
|
293 |
+
# run length encoded masks with pycocotools
|
294 |
+
area_gt = mask_utils.area(time_gt)
|
295 |
+
time_data = list(time_data)
|
296 |
+
area_tr = mask_utils.area(time_data)
|
297 |
+
|
298 |
+
area_tr = np.repeat(area_tr[:, np.newaxis], len(area_gt), axis=1)
|
299 |
+
area_gt = np.repeat(area_gt[np.newaxis, :], len(area_tr), axis=0)
|
300 |
+
|
301 |
+
# mask iou computation with pycocotools
|
302 |
+
ious = np.atleast_2d(mask_utils.iou(time_data, time_gt, [0]*len(time_gt)))
|
303 |
+
# set iou to 1 if both masks are close to 0 (no ground truth and no predicted mask in timestep)
|
304 |
+
ious[np.isclose(area_tr, 0) & np.isclose(area_gt, 0)] = 1
|
305 |
+
assert (ious >= 0 - np.finfo('float').eps).all()
|
306 |
+
assert (ious <= 1 + np.finfo('float').eps).all()
|
307 |
+
|
308 |
+
j[..., t] = ious
|
309 |
+
|
310 |
+
return j
|
avism/data/aviseval/metrics/track_map.py
ADDED
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from ._base_metric import _BaseMetric
|
3 |
+
from .. import _timing
|
4 |
+
from functools import partial
|
5 |
+
from .. import utils
|
6 |
+
from ..utils import TrackEvalException
|
7 |
+
|
8 |
+
|
9 |
+
class TrackMAP(_BaseMetric):
|
10 |
+
"""Class which implements the TrackMAP metrics"""
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def get_default_metric_config():
|
14 |
+
"""Default class config values"""
|
15 |
+
default_config = {
|
16 |
+
'USE_AREA_RANGES': True, # whether to evaluate for certain area ranges
|
17 |
+
'AREA_RANGES': [[0 ** 2, 32 ** 2], # additional area range sets for which TrackMAP is evaluated
|
18 |
+
[32 ** 2, 96 ** 2], # (all area range always included), default values for TAO
|
19 |
+
[96 ** 2, 1e5 ** 2]], # evaluation
|
20 |
+
'AREA_RANGE_LABELS': ["area_s", "area_m", "area_l"], # the labels for the area ranges
|
21 |
+
'USE_TIME_RANGES': True, # whether to evaluate for certain time ranges (length of tracks)
|
22 |
+
'TIME_RANGES': [[0, 3], [3, 10], [10, 1e5]], # additional time range sets for which TrackMAP is evaluated
|
23 |
+
# (all time range always included) , default values for TAO evaluation
|
24 |
+
'TIME_RANGE_LABELS': ["time_s", "time_m", "time_l"], # the labels for the time ranges
|
25 |
+
'IOU_THRESHOLDS': np.arange(0.5, 0.96, 0.05), # the IoU thresholds
|
26 |
+
'RECALL_THRESHOLDS': np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01) + 1), endpoint=True),
|
27 |
+
# recall thresholds at which precision is evaluated
|
28 |
+
'MAX_DETECTIONS': 0, # limit the maximum number of considered tracks per sequence (0 for unlimited)
|
29 |
+
'PRINT_CONFIG': False
|
30 |
+
}
|
31 |
+
return default_config
|
32 |
+
|
33 |
+
def __init__(self, config=None):
|
34 |
+
super().__init__()
|
35 |
+
self.config = utils.init_config(config, self.get_default_metric_config(), self.get_name())
|
36 |
+
|
37 |
+
self.num_ig_masks = 1
|
38 |
+
self.lbls = ['all']
|
39 |
+
self.use_area_rngs = self.config['USE_AREA_RANGES']
|
40 |
+
if self.use_area_rngs:
|
41 |
+
self.area_rngs = self.config['AREA_RANGES']
|
42 |
+
self.area_rng_lbls = self.config['AREA_RANGE_LABELS']
|
43 |
+
self.num_ig_masks += len(self.area_rng_lbls)
|
44 |
+
self.lbls += self.area_rng_lbls
|
45 |
+
|
46 |
+
self.use_time_rngs = self.config['USE_TIME_RANGES']
|
47 |
+
if self.use_time_rngs:
|
48 |
+
self.time_rngs = self.config['TIME_RANGES']
|
49 |
+
self.time_rng_lbls = self.config['TIME_RANGE_LABELS']
|
50 |
+
self.num_ig_masks += len(self.time_rng_lbls)
|
51 |
+
self.lbls += self.time_rng_lbls
|
52 |
+
|
53 |
+
self.array_labels = self.config['IOU_THRESHOLDS']
|
54 |
+
self.rec_thrs = self.config['RECALL_THRESHOLDS']
|
55 |
+
|
56 |
+
self.maxDet = self.config['MAX_DETECTIONS']
|
57 |
+
self.float_array_fields = ['AP_' + lbl for lbl in self.lbls] + ['AR_' + lbl for lbl in self.lbls]
|
58 |
+
self.fields = self.float_array_fields
|
59 |
+
self.summary_fields = self.float_array_fields
|
60 |
+
|
61 |
+
@_timing.time
|
62 |
+
def eval_sequence(self, data):
|
63 |
+
"""Calculates GT and Tracker matches for one sequence for TrackMAP metrics. Adapted from
|
64 |
+
https://github.com/TAO-Dataset/"""
|
65 |
+
|
66 |
+
# Initialise results to zero for each sequence as the fields are only defined over the set of all sequences
|
67 |
+
res = {}
|
68 |
+
for field in self.fields:
|
69 |
+
res[field] = [0 for _ in self.array_labels]
|
70 |
+
|
71 |
+
gt_ids, dt_ids = data['gt_track_ids'], data['dt_track_ids']
|
72 |
+
|
73 |
+
if len(gt_ids) == 0 and len(dt_ids) == 0:
|
74 |
+
for idx in range(self.num_ig_masks):
|
75 |
+
res[idx] = None
|
76 |
+
return res
|
77 |
+
|
78 |
+
# get track data
|
79 |
+
gt_tr_areas = data.get('gt_track_areas', None) if self.use_area_rngs else None
|
80 |
+
gt_tr_lengths = data.get('gt_track_lengths', None) if self.use_time_rngs else None
|
81 |
+
gt_tr_iscrowd = data.get('gt_track_iscrowd', None)
|
82 |
+
dt_tr_areas = data.get('dt_track_areas', None) if self.use_area_rngs else None
|
83 |
+
dt_tr_lengths = data.get('dt_track_lengths', None) if self.use_time_rngs else None
|
84 |
+
is_nel = data.get('not_exhaustively_labeled', False)
|
85 |
+
|
86 |
+
# compute ignore masks for different track sets to eval
|
87 |
+
gt_ig_masks = self._compute_track_ig_masks(len(gt_ids), track_lengths=gt_tr_lengths, track_areas=gt_tr_areas,
|
88 |
+
iscrowd=gt_tr_iscrowd)
|
89 |
+
dt_ig_masks = self._compute_track_ig_masks(len(dt_ids), track_lengths=dt_tr_lengths, track_areas=dt_tr_areas,
|
90 |
+
is_not_exhaustively_labeled=is_nel, is_gt=False)
|
91 |
+
|
92 |
+
boxformat = data.get('boxformat', 'xywh')
|
93 |
+
ious = self._compute_track_ious(data['dt_tracks'], data['gt_tracks'], iou_function=data['iou_type'],
|
94 |
+
boxformat=boxformat)
|
95 |
+
|
96 |
+
for mask_idx in range(self.num_ig_masks):
|
97 |
+
gt_ig_mask = gt_ig_masks[mask_idx]
|
98 |
+
|
99 |
+
# Sort gt ignore last
|
100 |
+
gt_idx = np.argsort([g for g in gt_ig_mask], kind="mergesort")
|
101 |
+
gt_ids = [gt_ids[i] for i in gt_idx]
|
102 |
+
|
103 |
+
ious_sorted = ious[:, gt_idx] if len(ious) > 0 else ious
|
104 |
+
|
105 |
+
num_thrs = len(self.array_labels)
|
106 |
+
num_gt = len(gt_ids)
|
107 |
+
num_dt = len(dt_ids)
|
108 |
+
|
109 |
+
# Array to store the "id" of the matched dt/gt
|
110 |
+
gt_m = np.zeros((num_thrs, num_gt)) - 1
|
111 |
+
dt_m = np.zeros((num_thrs, num_dt)) - 1
|
112 |
+
|
113 |
+
gt_ig = np.array([gt_ig_mask[idx] for idx in gt_idx])
|
114 |
+
dt_ig = np.zeros((num_thrs, num_dt))
|
115 |
+
|
116 |
+
for iou_thr_idx, iou_thr in enumerate(self.array_labels):
|
117 |
+
if len(ious_sorted) == 0:
|
118 |
+
break
|
119 |
+
|
120 |
+
for dt_idx, _dt in enumerate(dt_ids):
|
121 |
+
iou = min([iou_thr, 1 - 1e-10])
|
122 |
+
# information about best match so far (m=-1 -> unmatched)
|
123 |
+
# store the gt_idx which matched for _dt
|
124 |
+
m = -1
|
125 |
+
for gt_idx, _ in enumerate(gt_ids):
|
126 |
+
# if this gt already matched continue
|
127 |
+
if gt_m[iou_thr_idx, gt_idx] > 0:
|
128 |
+
continue
|
129 |
+
# if _dt matched to reg gt, and on ignore gt, stop
|
130 |
+
if m > -1 and gt_ig[m] == 0 and gt_ig[gt_idx] == 1:
|
131 |
+
break
|
132 |
+
# continue to next gt unless better match made
|
133 |
+
if ious_sorted[dt_idx, gt_idx] < iou - np.finfo('float').eps:
|
134 |
+
continue
|
135 |
+
# if match successful and best so far, store appropriately
|
136 |
+
iou = ious_sorted[dt_idx, gt_idx]
|
137 |
+
m = gt_idx
|
138 |
+
|
139 |
+
# No match found for _dt, go to next _dt
|
140 |
+
if m == -1:
|
141 |
+
continue
|
142 |
+
|
143 |
+
# if gt to ignore for some reason update dt_ig.
|
144 |
+
# Should not be used in evaluation.
|
145 |
+
dt_ig[iou_thr_idx, dt_idx] = gt_ig[m]
|
146 |
+
# _dt match found, update gt_m, and dt_m with "id"
|
147 |
+
dt_m[iou_thr_idx, dt_idx] = gt_ids[m]
|
148 |
+
gt_m[iou_thr_idx, m] = _dt
|
149 |
+
|
150 |
+
dt_ig_mask = dt_ig_masks[mask_idx]
|
151 |
+
|
152 |
+
dt_ig_mask = np.array(dt_ig_mask).reshape((1, num_dt)) # 1 X num_dt
|
153 |
+
dt_ig_mask = np.repeat(dt_ig_mask, num_thrs, 0) # num_thrs X num_dt
|
154 |
+
|
155 |
+
# Based on dt_ig_mask ignore any unmatched detection by updating dt_ig
|
156 |
+
dt_ig = np.logical_or(dt_ig, np.logical_and(dt_m == -1, dt_ig_mask))
|
157 |
+
# store results for given video and category
|
158 |
+
res[mask_idx] = {
|
159 |
+
"dt_ids": dt_ids,
|
160 |
+
"gt_ids": gt_ids,
|
161 |
+
"dt_matches": dt_m,
|
162 |
+
"gt_matches": gt_m,
|
163 |
+
"dt_scores": data['dt_track_scores'],
|
164 |
+
"gt_ignore": gt_ig,
|
165 |
+
"dt_ignore": dt_ig,
|
166 |
+
}
|
167 |
+
|
168 |
+
return res
|
169 |
+
|
170 |
+
def combine_sequences(self, all_res):
|
171 |
+
"""Combines metrics across all sequences. Computes precision and recall values based on track matches.
|
172 |
+
Adapted from https://github.com/TAO-Dataset/
|
173 |
+
"""
|
174 |
+
num_thrs = len(self.array_labels)
|
175 |
+
num_recalls = len(self.rec_thrs)
|
176 |
+
|
177 |
+
# -1 for absent categories
|
178 |
+
precision = -np.ones(
|
179 |
+
(num_thrs, num_recalls, self.num_ig_masks)
|
180 |
+
)
|
181 |
+
recall = -np.ones((num_thrs, self.num_ig_masks))
|
182 |
+
|
183 |
+
for ig_idx in range(self.num_ig_masks):
|
184 |
+
ig_idx_results = [res[ig_idx] for res in all_res.values() if res[ig_idx] is not None]
|
185 |
+
|
186 |
+
# Remove elements which are None
|
187 |
+
if len(ig_idx_results) == 0:
|
188 |
+
continue
|
189 |
+
|
190 |
+
# Append all scores: shape (N,)
|
191 |
+
# limit considered tracks for each sequence if maxDet > 0
|
192 |
+
if self.maxDet == 0:
|
193 |
+
dt_scores = np.concatenate([res["dt_scores"] for res in ig_idx_results], axis=0)
|
194 |
+
|
195 |
+
dt_idx = np.argsort(-dt_scores, kind="mergesort")
|
196 |
+
|
197 |
+
dt_m = np.concatenate([e["dt_matches"] for e in ig_idx_results],
|
198 |
+
axis=1)[:, dt_idx]
|
199 |
+
dt_ig = np.concatenate([e["dt_ignore"] for e in ig_idx_results],
|
200 |
+
axis=1)[:, dt_idx]
|
201 |
+
elif self.maxDet > 0:
|
202 |
+
dt_scores = np.concatenate([res["dt_scores"][0:self.maxDet] for res in ig_idx_results], axis=0)
|
203 |
+
|
204 |
+
dt_idx = np.argsort(-dt_scores, kind="mergesort")
|
205 |
+
|
206 |
+
dt_m = np.concatenate([e["dt_matches"][:, 0:self.maxDet] for e in ig_idx_results],
|
207 |
+
axis=1)[:, dt_idx]
|
208 |
+
dt_ig = np.concatenate([e["dt_ignore"][:, 0:self.maxDet] for e in ig_idx_results],
|
209 |
+
axis=1)[:, dt_idx]
|
210 |
+
else:
|
211 |
+
raise Exception("Number of maximum detections must be >= 0, but is set to %i" % self.maxDet)
|
212 |
+
|
213 |
+
gt_ig = np.concatenate([res["gt_ignore"] for res in ig_idx_results])
|
214 |
+
# num gt anns to consider
|
215 |
+
num_gt = np.count_nonzero(gt_ig == 0)
|
216 |
+
|
217 |
+
if num_gt == 0:
|
218 |
+
continue
|
219 |
+
|
220 |
+
tps = np.logical_and(dt_m != -1, np.logical_not(dt_ig))
|
221 |
+
fps = np.logical_and(dt_m == -1, np.logical_not(dt_ig))
|
222 |
+
|
223 |
+
tp_sum = np.cumsum(tps, axis=1).astype(dtype=float)
|
224 |
+
fp_sum = np.cumsum(fps, axis=1).astype(dtype=float)
|
225 |
+
|
226 |
+
for iou_thr_idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
|
227 |
+
tp = np.array(tp)
|
228 |
+
fp = np.array(fp)
|
229 |
+
num_tp = len(tp)
|
230 |
+
rc = tp / num_gt
|
231 |
+
if num_tp:
|
232 |
+
recall[iou_thr_idx, ig_idx] = rc[-1]
|
233 |
+
else:
|
234 |
+
recall[iou_thr_idx, ig_idx] = 0
|
235 |
+
|
236 |
+
# np.spacing(1) ~= eps
|
237 |
+
pr = tp / (fp + tp + np.spacing(1))
|
238 |
+
pr = pr.tolist()
|
239 |
+
|
240 |
+
# Ensure precision values are monotonically decreasing
|
241 |
+
for i in range(num_tp - 1, 0, -1):
|
242 |
+
if pr[i] > pr[i - 1]:
|
243 |
+
pr[i - 1] = pr[i]
|
244 |
+
|
245 |
+
# find indices at the predefined recall values
|
246 |
+
rec_thrs_insert_idx = np.searchsorted(rc, self.rec_thrs, side="left")
|
247 |
+
|
248 |
+
pr_at_recall = [0.0] * num_recalls
|
249 |
+
|
250 |
+
try:
|
251 |
+
for _idx, pr_idx in enumerate(rec_thrs_insert_idx):
|
252 |
+
pr_at_recall[_idx] = pr[pr_idx]
|
253 |
+
except IndexError:
|
254 |
+
pass
|
255 |
+
|
256 |
+
precision[iou_thr_idx, :, ig_idx] = (np.array(pr_at_recall))
|
257 |
+
|
258 |
+
res = {'precision': precision, 'recall': recall}
|
259 |
+
|
260 |
+
# compute the precision and recall averages for the respective alpha thresholds and ignore masks
|
261 |
+
for lbl in self.lbls:
|
262 |
+
res['AP_' + lbl] = np.zeros((len(self.array_labels)), dtype=float)
|
263 |
+
res['AR_' + lbl] = np.zeros((len(self.array_labels)), dtype=float)
|
264 |
+
|
265 |
+
for a_id, alpha in enumerate(self.array_labels):
|
266 |
+
for lbl_idx, lbl in enumerate(self.lbls):
|
267 |
+
p = precision[a_id, :, lbl_idx]
|
268 |
+
if len(p[p > -1]) == 0:
|
269 |
+
mean_p = -1
|
270 |
+
else:
|
271 |
+
mean_p = np.mean(p[p > -1])
|
272 |
+
res['AP_' + lbl][a_id] = mean_p
|
273 |
+
res['AR_' + lbl][a_id] = recall[a_id, lbl_idx]
|
274 |
+
|
275 |
+
return res
|
276 |
+
|
277 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=True):
|
278 |
+
"""Combines metrics across all classes by averaging over the class values
|
279 |
+
Note mAP is not well defined for 'empty classes' so 'ignore empty classes' is always true here.
|
280 |
+
"""
|
281 |
+
res = {}
|
282 |
+
for field in self.fields:
|
283 |
+
res[field] = np.zeros((len(self.array_labels)), dtype=float)
|
284 |
+
field_stacked = np.array([res[field] for res in all_res.values()])
|
285 |
+
|
286 |
+
for a_id, alpha in enumerate(self.array_labels):
|
287 |
+
values = field_stacked[:, a_id]
|
288 |
+
if len(values[values > -1]) == 0:
|
289 |
+
mean = -1
|
290 |
+
else:
|
291 |
+
mean = np.mean(values[values > -1])
|
292 |
+
res[field][a_id] = mean
|
293 |
+
return res
|
294 |
+
|
295 |
+
def combine_classes_det_averaged(self, all_res):
|
296 |
+
"""Combines metrics across all classes by averaging over the detection values"""
|
297 |
+
|
298 |
+
res = {}
|
299 |
+
for field in self.fields:
|
300 |
+
res[field] = np.zeros((len(self.array_labels)), dtype=float)
|
301 |
+
field_stacked = np.array([res[field] for res in all_res.values()])
|
302 |
+
|
303 |
+
for a_id, alpha in enumerate(self.array_labels):
|
304 |
+
values = field_stacked[:, a_id]
|
305 |
+
if len(values[values > -1]) == 0:
|
306 |
+
mean = -1
|
307 |
+
else:
|
308 |
+
mean = np.mean(values[values > -1])
|
309 |
+
res[field][a_id] = mean
|
310 |
+
return res
|
311 |
+
|
312 |
+
def _compute_track_ig_masks(self, num_ids, track_lengths=None, track_areas=None, iscrowd=None,
|
313 |
+
is_not_exhaustively_labeled=False, is_gt=True):
|
314 |
+
"""
|
315 |
+
Computes ignore masks for different track sets to evaluate
|
316 |
+
:param num_ids: the number of track IDs
|
317 |
+
:param track_lengths: the lengths of the tracks (number of timesteps)
|
318 |
+
:param track_areas: the average area of a track
|
319 |
+
:param iscrowd: whether a track is marked as crowd
|
320 |
+
:param is_not_exhaustively_labeled: whether the track category is not exhaustively labeled
|
321 |
+
:param is_gt: whether it is gt
|
322 |
+
:return: the track ignore masks
|
323 |
+
"""
|
324 |
+
# for TAO tracks for classes which are not exhaustively labeled are not evaluated
|
325 |
+
if not is_gt and is_not_exhaustively_labeled:
|
326 |
+
track_ig_masks = [[1 for _ in range(num_ids)] for i in range(self.num_ig_masks)]
|
327 |
+
else:
|
328 |
+
# consider all tracks
|
329 |
+
track_ig_masks = [[0 for _ in range(num_ids)]]
|
330 |
+
|
331 |
+
# consider tracks with certain area
|
332 |
+
if self.use_area_rngs:
|
333 |
+
for rng in self.area_rngs:
|
334 |
+
track_ig_masks.append([0 if rng[0] - np.finfo('float').eps <= area <= rng[1] + np.finfo('float').eps
|
335 |
+
else 1 for area in track_areas])
|
336 |
+
|
337 |
+
# consider tracks with certain duration
|
338 |
+
if self.use_time_rngs:
|
339 |
+
for rng in self.time_rngs:
|
340 |
+
track_ig_masks.append([0 if rng[0] - np.finfo('float').eps <= length
|
341 |
+
<= rng[1] + np.finfo('float').eps else 1 for length in track_lengths])
|
342 |
+
|
343 |
+
# for YouTubeVIS evaluation tracks with crowd tag are not evaluated
|
344 |
+
if is_gt and iscrowd:
|
345 |
+
track_ig_masks = [np.logical_or(mask, iscrowd) for mask in track_ig_masks]
|
346 |
+
|
347 |
+
return track_ig_masks
|
348 |
+
|
349 |
+
@staticmethod
|
350 |
+
def _compute_bb_track_iou(dt_track, gt_track, boxformat='xywh'):
|
351 |
+
"""
|
352 |
+
Calculates the track IoU for one detected track and one ground truth track for bounding boxes
|
353 |
+
:param dt_track: the detected track (format: dictionary with frame index as keys and
|
354 |
+
numpy arrays as values)
|
355 |
+
:param gt_track: the ground truth track (format: dictionary with frame index as keys and
|
356 |
+
numpy array as values)
|
357 |
+
:param boxformat: the format of the boxes
|
358 |
+
:return: the track IoU
|
359 |
+
"""
|
360 |
+
intersect = 0
|
361 |
+
union = 0
|
362 |
+
image_ids = set(gt_track.keys()) | set(dt_track.keys())
|
363 |
+
for image in image_ids:
|
364 |
+
g = gt_track.get(image, None)
|
365 |
+
d = dt_track.get(image, None)
|
366 |
+
if boxformat == 'xywh':
|
367 |
+
if d is not None and g is not None:
|
368 |
+
dx, dy, dw, dh = d
|
369 |
+
gx, gy, gw, gh = g
|
370 |
+
w = max(min(dx + dw, gx + gw) - max(dx, gx), 0)
|
371 |
+
h = max(min(dy + dh, gy + gh) - max(dy, gy), 0)
|
372 |
+
i = w * h
|
373 |
+
u = dw * dh + gw * gh - i
|
374 |
+
intersect += i
|
375 |
+
union += u
|
376 |
+
elif d is None and g is not None:
|
377 |
+
union += g[2] * g[3]
|
378 |
+
elif d is not None and g is None:
|
379 |
+
union += d[2] * d[3]
|
380 |
+
elif boxformat == 'x0y0x1y1':
|
381 |
+
if d is not None and g is not None:
|
382 |
+
dx0, dy0, dx1, dy1 = d
|
383 |
+
gx0, gy0, gx1, gy1 = g
|
384 |
+
w = max(min(dx1, gx1) - max(dx0, gx0), 0)
|
385 |
+
h = max(min(dy1, gy1) - max(dy0, gy0), 0)
|
386 |
+
i = w * h
|
387 |
+
u = (dx1 - dx0) * (dy1 - dy0) + (gx1 - gx0) * (gy1 - gy0) - i
|
388 |
+
intersect += i
|
389 |
+
union += u
|
390 |
+
elif d is None and g is not None:
|
391 |
+
union += (g[2] - g[0]) * (g[3] - g[1])
|
392 |
+
elif d is not None and g is None:
|
393 |
+
union += (d[2] - d[0]) * (d[3] - d[1])
|
394 |
+
else:
|
395 |
+
raise TrackEvalException('BoxFormat not implemented')
|
396 |
+
if intersect > union:
|
397 |
+
raise TrackEvalException("Intersection value > union value. Are the box values corrupted?")
|
398 |
+
return intersect / union if union > 0 else 0
|
399 |
+
|
400 |
+
@staticmethod
|
401 |
+
def _compute_mask_track_iou(dt_track, gt_track):
|
402 |
+
"""
|
403 |
+
Calculates the track IoU for one detected track and one ground truth track for segmentation masks
|
404 |
+
:param dt_track: the detected track (format: dictionary with frame index as keys and
|
405 |
+
pycocotools rle encoded masks as values)
|
406 |
+
:param gt_track: the ground truth track (format: dictionary with frame index as keys and
|
407 |
+
pycocotools rle encoded masks as values)
|
408 |
+
:return: the track IoU
|
409 |
+
"""
|
410 |
+
# only loaded when needed to reduce minimum requirements
|
411 |
+
from pycocotools import mask as mask_utils
|
412 |
+
|
413 |
+
intersect = .0
|
414 |
+
union = .0
|
415 |
+
image_ids = set(gt_track.keys()) | set(dt_track.keys())
|
416 |
+
for image in image_ids:
|
417 |
+
g = gt_track.get(image, None)
|
418 |
+
d = dt_track.get(image, None)
|
419 |
+
if d and g:
|
420 |
+
intersect += mask_utils.area(mask_utils.merge([d, g], True))
|
421 |
+
union += mask_utils.area(mask_utils.merge([d, g], False))
|
422 |
+
elif not d and g:
|
423 |
+
union += mask_utils.area(g)
|
424 |
+
elif d and not g:
|
425 |
+
union += mask_utils.area(d)
|
426 |
+
if union < 0.0 - np.finfo('float').eps:
|
427 |
+
raise TrackEvalException("Union value < 0. Are the segmentaions corrupted?")
|
428 |
+
if intersect > union:
|
429 |
+
raise TrackEvalException("Intersection value > union value. Are the segmentations corrupted?")
|
430 |
+
iou = intersect / union if union > 0.0 + np.finfo('float').eps else 0.0
|
431 |
+
return iou
|
432 |
+
|
433 |
+
@staticmethod
|
434 |
+
def _compute_track_ious(dt, gt, iou_function='bbox', boxformat='xywh'):
|
435 |
+
"""
|
436 |
+
Calculate track IoUs for a set of ground truth tracks and a set of detected tracks
|
437 |
+
"""
|
438 |
+
|
439 |
+
if len(gt) == 0 and len(dt) == 0:
|
440 |
+
return []
|
441 |
+
|
442 |
+
if iou_function == 'bbox':
|
443 |
+
track_iou_function = partial(TrackMAP._compute_bb_track_iou, boxformat=boxformat)
|
444 |
+
elif iou_function == 'mask':
|
445 |
+
track_iou_function = partial(TrackMAP._compute_mask_track_iou)
|
446 |
+
else:
|
447 |
+
raise Exception('IoU function not implemented')
|
448 |
+
|
449 |
+
ious = np.zeros([len(dt), len(gt)])
|
450 |
+
for i, j in np.ndindex(ious.shape):
|
451 |
+
ious[i, j] = track_iou_function(dt[i], gt[j])
|
452 |
+
return ious
|
453 |
+
|
454 |
+
@staticmethod
|
455 |
+
def _row_print(*argv):
|
456 |
+
"""Prints results in an evenly spaced rows, with more space in first row"""
|
457 |
+
if len(argv) == 1:
|
458 |
+
argv = argv[0]
|
459 |
+
to_print = '%-40s' % argv[0]
|
460 |
+
for v in argv[1:]:
|
461 |
+
to_print += '%-12s' % str(v)
|
462 |
+
print(to_print)
|
avism/data/aviseval/metrics/vace.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.optimize import linear_sum_assignment
|
3 |
+
from ._base_metric import _BaseMetric
|
4 |
+
from .. import _timing
|
5 |
+
|
6 |
+
|
7 |
+
class VACE(_BaseMetric):
|
8 |
+
"""Class which implements the VACE metrics.
|
9 |
+
|
10 |
+
The metrics are described in:
|
11 |
+
Manohar et al. (2006) "Performance Evaluation of Object Detection and Tracking in Video"
|
12 |
+
https://link.springer.com/chapter/10.1007/11612704_16
|
13 |
+
|
14 |
+
This implementation uses the "relaxed" variant of the metrics,
|
15 |
+
where an overlap threshold is applied in each frame.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, config=None):
|
19 |
+
super().__init__()
|
20 |
+
self.integer_fields = ['VACE_IDs', 'VACE_GT_IDs', 'num_non_empty_timesteps']
|
21 |
+
self.float_fields = ['STDA', 'ATA', 'FDA', 'SFDA']
|
22 |
+
self.fields = self.integer_fields + self.float_fields
|
23 |
+
self.summary_fields = ['SFDA', 'ATA']
|
24 |
+
|
25 |
+
# Fields that are accumulated over multiple videos.
|
26 |
+
self._additive_fields = self.integer_fields + ['STDA', 'FDA']
|
27 |
+
|
28 |
+
self.threshold = 0.5
|
29 |
+
|
30 |
+
@_timing.time
|
31 |
+
def eval_sequence(self, data):
|
32 |
+
"""Calculates VACE metrics for one sequence.
|
33 |
+
|
34 |
+
Depends on the fields:
|
35 |
+
data['num_gt_ids']
|
36 |
+
data['num_tracker_ids']
|
37 |
+
data['gt_ids']
|
38 |
+
data['tracker_ids']
|
39 |
+
data['similarity_scores']
|
40 |
+
"""
|
41 |
+
res = {}
|
42 |
+
|
43 |
+
# Obtain Average Tracking Accuracy (ATA) using track correspondence.
|
44 |
+
# Obtain counts necessary to compute temporal IOU.
|
45 |
+
# Assume that integer counts can be represented exactly as floats.
|
46 |
+
potential_matches_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
|
47 |
+
gt_id_count = np.zeros(data['num_gt_ids'])
|
48 |
+
tracker_id_count = np.zeros(data['num_tracker_ids'])
|
49 |
+
both_present_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
|
50 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
|
51 |
+
# Count the number of frames in which two tracks satisfy the overlap criterion.
|
52 |
+
matches_mask = np.greater_equal(data['similarity_scores'][t], self.threshold)
|
53 |
+
match_idx_gt, match_idx_tracker = np.nonzero(matches_mask)
|
54 |
+
potential_matches_count[gt_ids_t[match_idx_gt], tracker_ids_t[match_idx_tracker]] += 1
|
55 |
+
# Count the number of frames in which the tracks are present.
|
56 |
+
gt_id_count[gt_ids_t] += 1
|
57 |
+
tracker_id_count[tracker_ids_t] += 1
|
58 |
+
both_present_count[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] += 1
|
59 |
+
# Number of frames in which either track is present (union of the two sets of frames).
|
60 |
+
union_count = (gt_id_count[:, np.newaxis]
|
61 |
+
+ tracker_id_count[np.newaxis, :]
|
62 |
+
- both_present_count)
|
63 |
+
# The denominator should always be non-zero if all tracks are non-empty.
|
64 |
+
with np.errstate(divide='raise', invalid='raise'):
|
65 |
+
temporal_iou = potential_matches_count / union_count
|
66 |
+
# Find assignment that maximizes temporal IOU.
|
67 |
+
match_rows, match_cols = linear_sum_assignment(-temporal_iou)
|
68 |
+
res['STDA'] = temporal_iou[match_rows, match_cols].sum()
|
69 |
+
res['VACE_IDs'] = data['num_tracker_ids']
|
70 |
+
res['VACE_GT_IDs'] = data['num_gt_ids']
|
71 |
+
|
72 |
+
# Obtain Frame Detection Accuracy (FDA) using per-frame correspondence.
|
73 |
+
non_empty_count = 0
|
74 |
+
fda = 0
|
75 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
|
76 |
+
n_g = len(gt_ids_t)
|
77 |
+
n_d = len(tracker_ids_t)
|
78 |
+
if not (n_g or n_d):
|
79 |
+
continue
|
80 |
+
# n_g > 0 or n_d > 0
|
81 |
+
non_empty_count += 1
|
82 |
+
if not (n_g and n_d):
|
83 |
+
continue
|
84 |
+
# n_g > 0 and n_d > 0
|
85 |
+
spatial_overlap = data['similarity_scores'][t]
|
86 |
+
match_rows, match_cols = linear_sum_assignment(-spatial_overlap)
|
87 |
+
overlap_ratio = spatial_overlap[match_rows, match_cols].sum()
|
88 |
+
fda += overlap_ratio / (0.5 * (n_g + n_d))
|
89 |
+
res['FDA'] = fda
|
90 |
+
res['num_non_empty_timesteps'] = non_empty_count
|
91 |
+
|
92 |
+
res.update(self._compute_final_fields(res))
|
93 |
+
return res
|
94 |
+
|
95 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=True):
|
96 |
+
"""Combines metrics across all classes by averaging over the class values.
|
97 |
+
If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
|
98 |
+
"""
|
99 |
+
res = {}
|
100 |
+
for field in self.fields:
|
101 |
+
if ignore_empty_classes:
|
102 |
+
res[field] = np.mean([v[field] for v in all_res.values()
|
103 |
+
if v['VACE_GT_IDs'] > 0 or v['VACE_IDs'] > 0], axis=0)
|
104 |
+
else:
|
105 |
+
res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
|
106 |
+
return res
|
107 |
+
|
108 |
+
def combine_classes_det_averaged(self, all_res):
|
109 |
+
"""Combines metrics across all classes by averaging over the detection values"""
|
110 |
+
res = {}
|
111 |
+
for field in self._additive_fields:
|
112 |
+
res[field] = _BaseMetric._combine_sum(all_res, field)
|
113 |
+
res = self._compute_final_fields(res)
|
114 |
+
return res
|
115 |
+
|
116 |
+
def combine_sequences(self, all_res):
|
117 |
+
"""Combines metrics across all sequences"""
|
118 |
+
res = {}
|
119 |
+
for header in self._additive_fields:
|
120 |
+
res[header] = _BaseMetric._combine_sum(all_res, header)
|
121 |
+
res.update(self._compute_final_fields(res))
|
122 |
+
return res
|
123 |
+
|
124 |
+
@staticmethod
|
125 |
+
def _compute_final_fields(additive):
|
126 |
+
final = {}
|
127 |
+
with np.errstate(invalid='ignore'): # Permit nan results.
|
128 |
+
final['ATA'] = (additive['STDA'] /
|
129 |
+
(0.5 * (additive['VACE_IDs'] + additive['VACE_GT_IDs'])))
|
130 |
+
final['SFDA'] = additive['FDA'] / additive['num_non_empty_timesteps']
|
131 |
+
return final
|
avism/data/aviseval/plotting.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
from .utils import TrackEvalException
|
5 |
+
|
6 |
+
|
7 |
+
def plot_compare_trackers(tracker_folder, tracker_list, cls, output_folder, plots_list=None):
|
8 |
+
"""Create plots which compare metrics across different trackers."""
|
9 |
+
# Define what to plot
|
10 |
+
if plots_list is None:
|
11 |
+
plots_list = get_default_plots_list()
|
12 |
+
|
13 |
+
# Load data
|
14 |
+
data = load_multiple_tracker_summaries(tracker_folder, tracker_list, cls)
|
15 |
+
out_loc = os.path.join(output_folder, cls)
|
16 |
+
|
17 |
+
# Plot
|
18 |
+
for args in plots_list:
|
19 |
+
create_comparison_plot(data, out_loc, *args)
|
20 |
+
|
21 |
+
|
22 |
+
def get_default_plots_list():
|
23 |
+
# y_label, x_label, sort_label, bg_label, bg_function
|
24 |
+
plots_list = [
|
25 |
+
['AssA', 'DetA', 'HOTA', 'HOTA', 'geometric_mean'],
|
26 |
+
['AssPr', 'AssRe', 'HOTA', 'AssA', 'jaccard'],
|
27 |
+
['DetPr', 'DetRe', 'HOTA', 'DetA', 'jaccard'],
|
28 |
+
['HOTA(0)', 'LocA(0)', 'HOTA', 'HOTALocA(0)', 'multiplication'],
|
29 |
+
['HOTA', 'LocA', 'HOTA', None, None],
|
30 |
+
|
31 |
+
['HOTA', 'MOTA', 'HOTA', None, None],
|
32 |
+
['HOTA', 'IDF1', 'HOTA', None, None],
|
33 |
+
['IDF1', 'MOTA', 'HOTA', None, None],
|
34 |
+
]
|
35 |
+
return plots_list
|
36 |
+
|
37 |
+
|
38 |
+
def load_multiple_tracker_summaries(tracker_folder, tracker_list, cls):
|
39 |
+
"""Loads summary data for multiple trackers."""
|
40 |
+
data = {}
|
41 |
+
for tracker in tracker_list:
|
42 |
+
with open(os.path.join(tracker_folder, tracker, cls + '_summary.txt')) as f:
|
43 |
+
keys = next(f).split(' ')
|
44 |
+
done = False
|
45 |
+
while not done:
|
46 |
+
values = next(f).split(' ')
|
47 |
+
if len(values) == len(keys):
|
48 |
+
done = True
|
49 |
+
data[tracker] = dict(zip(keys, map(float, values)))
|
50 |
+
return data
|
51 |
+
|
52 |
+
|
53 |
+
def create_comparison_plot(data, out_loc, y_label, x_label, sort_label, bg_label=None, bg_function=None, settings=None):
|
54 |
+
""" Creates a scatter plot comparing multiple trackers between two metric fields, with one on the x-axis and the
|
55 |
+
other on the y axis. Adds pareto optical lines and (optionally) a background contour.
|
56 |
+
|
57 |
+
Inputs:
|
58 |
+
data: dict of dicts such that data[tracker_name][metric_field_name] = float
|
59 |
+
y_label: the metric_field_name to be plotted on the y-axis
|
60 |
+
x_label: the metric_field_name to be plotted on the x-axis
|
61 |
+
sort_label: the metric_field_name by which trackers are ordered and ranked
|
62 |
+
bg_label: the metric_field_name by which (optional) background contours are plotted
|
63 |
+
bg_function: the (optional) function bg_function(x,y) which converts the x_label / y_label values into bg_label.
|
64 |
+
settings: dict of plot settings with keys:
|
65 |
+
'gap_val': gap between axis ticks and bg curves.
|
66 |
+
'num_to_plot': maximum number of trackers to plot
|
67 |
+
"""
|
68 |
+
|
69 |
+
# Only loaded when run to reduce minimum requirements
|
70 |
+
from matplotlib import pyplot as plt
|
71 |
+
|
72 |
+
# Get plot settings
|
73 |
+
if settings is None:
|
74 |
+
gap_val = 2
|
75 |
+
num_to_plot = 20
|
76 |
+
else:
|
77 |
+
gap_val = settings['gap_val']
|
78 |
+
num_to_plot = settings['num_to_plot']
|
79 |
+
|
80 |
+
if (bg_label is None) != (bg_function is None):
|
81 |
+
raise TrackEvalException('bg_function and bg_label must either be both given or neither given.')
|
82 |
+
|
83 |
+
# Extract data
|
84 |
+
tracker_names = np.array(list(data.keys()))
|
85 |
+
sort_index = np.array([data[t][sort_label] for t in tracker_names]).argsort()[::-1]
|
86 |
+
x_values = np.array([data[t][x_label] for t in tracker_names])[sort_index][:num_to_plot]
|
87 |
+
y_values = np.array([data[t][y_label] for t in tracker_names])[sort_index][:num_to_plot]
|
88 |
+
|
89 |
+
# Print info on what is being plotted
|
90 |
+
tracker_names = tracker_names[sort_index][:num_to_plot]
|
91 |
+
print('\nPlotting %s vs %s, for the following (ordered) trackers:' % (y_label, x_label))
|
92 |
+
for i, name in enumerate(tracker_names):
|
93 |
+
print('%i: %s' % (i+1, name))
|
94 |
+
|
95 |
+
# Find best fitting boundaries for data
|
96 |
+
boundaries = _get_boundaries(x_values, y_values, round_val=gap_val/2)
|
97 |
+
|
98 |
+
fig = plt.figure()
|
99 |
+
|
100 |
+
# Plot background contour
|
101 |
+
if bg_function is not None:
|
102 |
+
_plot_bg_contour(bg_function, boundaries, gap_val)
|
103 |
+
|
104 |
+
# Plot pareto optimal lines
|
105 |
+
_plot_pareto_optimal_lines(x_values, y_values)
|
106 |
+
|
107 |
+
# Plot data points with number labels
|
108 |
+
labels = np.arange(len(y_values)) + 1
|
109 |
+
plt.plot(x_values, y_values, 'b.', markersize=15)
|
110 |
+
for xx, yy, l in zip(x_values, y_values, labels):
|
111 |
+
plt.text(xx, yy, str(l), color="red", fontsize=15)
|
112 |
+
|
113 |
+
# Add extra explanatory text to plots
|
114 |
+
plt.text(0, -0.11, 'label order:\nHOTA', horizontalalignment='left', verticalalignment='center',
|
115 |
+
transform=fig.axes[0].transAxes, color="red", fontsize=12)
|
116 |
+
if bg_label is not None:
|
117 |
+
plt.text(1, -0.11, 'curve values:\n' + bg_label, horizontalalignment='right', verticalalignment='center',
|
118 |
+
transform=fig.axes[0].transAxes, color="grey", fontsize=12)
|
119 |
+
|
120 |
+
plt.xlabel(x_label, fontsize=15)
|
121 |
+
plt.ylabel(y_label, fontsize=15)
|
122 |
+
title = y_label + ' vs ' + x_label
|
123 |
+
if bg_label is not None:
|
124 |
+
title += ' (' + bg_label + ')'
|
125 |
+
plt.title(title, fontsize=17)
|
126 |
+
plt.xticks(np.arange(0, 100, gap_val))
|
127 |
+
plt.yticks(np.arange(0, 100, gap_val))
|
128 |
+
min_x, max_x, min_y, max_y = boundaries
|
129 |
+
plt.xlim(min_x, max_x)
|
130 |
+
plt.ylim(min_y, max_y)
|
131 |
+
plt.gca().set_aspect('equal', adjustable='box')
|
132 |
+
plt.tight_layout()
|
133 |
+
|
134 |
+
os.makedirs(out_loc, exist_ok=True)
|
135 |
+
filename = os.path.join(out_loc, title.replace(' ', '_'))
|
136 |
+
plt.savefig(filename + '.pdf', bbox_inches='tight', pad_inches=0.05)
|
137 |
+
plt.savefig(filename + '.png', bbox_inches='tight', pad_inches=0.05)
|
138 |
+
|
139 |
+
|
140 |
+
def _get_boundaries(x_values, y_values, round_val):
|
141 |
+
x1 = np.min(np.floor((x_values - 0.5) / round_val) * round_val)
|
142 |
+
x2 = np.max(np.ceil((x_values + 0.5) / round_val) * round_val)
|
143 |
+
y1 = np.min(np.floor((y_values - 0.5) / round_val) * round_val)
|
144 |
+
y2 = np.max(np.ceil((y_values + 0.5) / round_val) * round_val)
|
145 |
+
x_range = x2 - x1
|
146 |
+
y_range = y2 - y1
|
147 |
+
max_range = max(x_range, y_range)
|
148 |
+
x_center = (x1 + x2) / 2
|
149 |
+
y_center = (y1 + y2) / 2
|
150 |
+
min_x = max(x_center - max_range / 2, 0)
|
151 |
+
max_x = min(x_center + max_range / 2, 100)
|
152 |
+
min_y = max(y_center - max_range / 2, 0)
|
153 |
+
max_y = min(y_center + max_range / 2, 100)
|
154 |
+
return min_x, max_x, min_y, max_y
|
155 |
+
|
156 |
+
|
157 |
+
def geometric_mean(x, y):
|
158 |
+
return np.sqrt(x * y)
|
159 |
+
|
160 |
+
|
161 |
+
def jaccard(x, y):
|
162 |
+
x = x / 100
|
163 |
+
y = y / 100
|
164 |
+
return 100 * (x * y) / (x + y - x * y)
|
165 |
+
|
166 |
+
|
167 |
+
def multiplication(x, y):
|
168 |
+
return x * y / 100
|
169 |
+
|
170 |
+
|
171 |
+
bg_function_dict = {
|
172 |
+
"geometric_mean": geometric_mean,
|
173 |
+
"jaccard": jaccard,
|
174 |
+
"multiplication": multiplication,
|
175 |
+
}
|
176 |
+
|
177 |
+
|
178 |
+
def _plot_bg_contour(bg_function, plot_boundaries, gap_val):
|
179 |
+
""" Plot background contour. """
|
180 |
+
|
181 |
+
# Only loaded when run to reduce minimum requirements
|
182 |
+
from matplotlib import pyplot as plt
|
183 |
+
|
184 |
+
# Plot background contour
|
185 |
+
min_x, max_x, min_y, max_y = plot_boundaries
|
186 |
+
x = np.arange(min_x, max_x, 0.1)
|
187 |
+
y = np.arange(min_y, max_y, 0.1)
|
188 |
+
x_grid, y_grid = np.meshgrid(x, y)
|
189 |
+
if bg_function in bg_function_dict.keys():
|
190 |
+
z_grid = bg_function_dict[bg_function](x_grid, y_grid)
|
191 |
+
else:
|
192 |
+
raise TrackEvalException("background plotting function '%s' is not defined." % bg_function)
|
193 |
+
levels = np.arange(0, 100, gap_val)
|
194 |
+
con = plt.contour(x_grid, y_grid, z_grid, levels, colors='grey')
|
195 |
+
|
196 |
+
def bg_format(val):
|
197 |
+
s = '{:1f}'.format(val)
|
198 |
+
return '{:.0f}'.format(val) if s[-1] == '0' else s
|
199 |
+
|
200 |
+
con.levels = [bg_format(val) for val in con.levels]
|
201 |
+
plt.clabel(con, con.levels, inline=True, fmt='%r', fontsize=8)
|
202 |
+
|
203 |
+
|
204 |
+
def _plot_pareto_optimal_lines(x_values, y_values):
|
205 |
+
""" Plot pareto optimal lines """
|
206 |
+
|
207 |
+
# Only loaded when run to reduce minimum requirements
|
208 |
+
from matplotlib import pyplot as plt
|
209 |
+
|
210 |
+
# Plot pareto optimal lines
|
211 |
+
cxs = x_values
|
212 |
+
cys = y_values
|
213 |
+
best_y = np.argmax(cys)
|
214 |
+
x_pareto = [0, cxs[best_y]]
|
215 |
+
y_pareto = [cys[best_y], cys[best_y]]
|
216 |
+
t = 2
|
217 |
+
remaining = cxs > x_pareto[t - 1]
|
218 |
+
cys = cys[remaining]
|
219 |
+
cxs = cxs[remaining]
|
220 |
+
while len(cxs) > 0 and len(cys) > 0:
|
221 |
+
best_y = np.argmax(cys)
|
222 |
+
x_pareto += [x_pareto[t - 1], cxs[best_y]]
|
223 |
+
y_pareto += [cys[best_y], cys[best_y]]
|
224 |
+
t += 2
|
225 |
+
remaining = cxs > x_pareto[t - 1]
|
226 |
+
cys = cys[remaining]
|
227 |
+
cxs = cxs[remaining]
|
228 |
+
x_pareto.append(x_pareto[t - 1])
|
229 |
+
y_pareto.append(0)
|
230 |
+
plt.plot(np.array(x_pareto), np.array(y_pareto), '--r')
|
avism/data/aviseval/utils.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import csv
|
4 |
+
import argparse
|
5 |
+
from collections import OrderedDict
|
6 |
+
|
7 |
+
|
8 |
+
def init_config(config, default_config, name=None):
|
9 |
+
"""Initialise non-given config values with defaults"""
|
10 |
+
if config is None:
|
11 |
+
config = default_config
|
12 |
+
else:
|
13 |
+
for k in default_config.keys():
|
14 |
+
if k not in config.keys():
|
15 |
+
config[k] = default_config[k]
|
16 |
+
if name and config['PRINT_CONFIG']:
|
17 |
+
print('\n%s Config:' % name)
|
18 |
+
for c in config.keys():
|
19 |
+
print('%-20s : %-30s' % (c, config[c]))
|
20 |
+
return config
|
21 |
+
|
22 |
+
|
23 |
+
def update_config(config):
|
24 |
+
"""
|
25 |
+
Parse the arguments of a script and updates the config values for a given value if specified in the arguments.
|
26 |
+
:param config: the config to update
|
27 |
+
:return: the updated config
|
28 |
+
"""
|
29 |
+
parser = argparse.ArgumentParser()
|
30 |
+
for setting in config.keys():
|
31 |
+
if type(config[setting]) == list or type(config[setting]) == type(None):
|
32 |
+
parser.add_argument("--" + setting, nargs='+')
|
33 |
+
else:
|
34 |
+
parser.add_argument("--" + setting)
|
35 |
+
args = parser.parse_args().__dict__
|
36 |
+
for setting in args.keys():
|
37 |
+
if args[setting] is not None:
|
38 |
+
if type(config[setting]) == type(True):
|
39 |
+
if args[setting] == 'True':
|
40 |
+
x = True
|
41 |
+
elif args[setting] == 'False':
|
42 |
+
x = False
|
43 |
+
else:
|
44 |
+
raise Exception('Command line parameter ' + setting + 'must be True or False')
|
45 |
+
elif type(config[setting]) == type(1):
|
46 |
+
x = int(args[setting])
|
47 |
+
elif type(args[setting]) == type(None):
|
48 |
+
x = None
|
49 |
+
else:
|
50 |
+
x = args[setting]
|
51 |
+
config[setting] = x
|
52 |
+
return config
|
53 |
+
|
54 |
+
|
55 |
+
def get_code_path():
|
56 |
+
"""Get base path where code is"""
|
57 |
+
return os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
58 |
+
|
59 |
+
|
60 |
+
def validate_metrics_list(metrics_list):
|
61 |
+
"""Get names of metric class and ensures they are unique, further checks that the fields within each metric class
|
62 |
+
do not have overlapping names.
|
63 |
+
"""
|
64 |
+
metric_names = [metric.get_name() for metric in metrics_list]
|
65 |
+
# check metric names are unique
|
66 |
+
if len(metric_names) != len(set(metric_names)):
|
67 |
+
raise TrackEvalException('Code being run with multiple metrics of the same name')
|
68 |
+
fields = []
|
69 |
+
for m in metrics_list:
|
70 |
+
fields += m.fields
|
71 |
+
# check metric fields are unique
|
72 |
+
if len(fields) != len(set(fields)):
|
73 |
+
raise TrackEvalException('Code being run with multiple metrics with fields of the same name')
|
74 |
+
return metric_names
|
75 |
+
|
76 |
+
|
77 |
+
def write_summary_results(summaries, cls, output_folder):
|
78 |
+
"""Write summary results to file"""
|
79 |
+
|
80 |
+
fields = sum([list(s.keys()) for s in summaries], [])
|
81 |
+
values = sum([list(s.values()) for s in summaries], [])
|
82 |
+
|
83 |
+
# In order to remain consistent upon new fields being adding, for each of the following fields if they are present
|
84 |
+
# they will be output in the summary first in the order below. Any further fields will be output in the order each
|
85 |
+
# metric family is called, and within each family either in the order they were added to the dict (python >= 3.6) or
|
86 |
+
# randomly (python < 3.6).
|
87 |
+
default_order = ['HOTA', 'DetA', 'AssA', 'DetRe', 'DetPr', 'AssRe', 'AssPr', 'LocA', 'OWTA', 'HOTA(0)', 'LocA(0)',
|
88 |
+
'HOTALocA(0)', 'MOTA', 'MOTP', 'MODA', 'CLR_Re', 'CLR_Pr', 'MTR', 'PTR', 'MLR', 'CLR_TP', 'CLR_FN',
|
89 |
+
'CLR_FP', 'IDSW', 'MT', 'PT', 'ML', 'Frag', 'sMOTA', 'IDF1', 'IDR', 'IDP', 'IDTP', 'IDFN', 'IDFP',
|
90 |
+
'Dets', 'GT_Dets', 'IDs', 'GT_IDs']
|
91 |
+
default_ordered_dict = OrderedDict(zip(default_order, [None for _ in default_order]))
|
92 |
+
for f, v in zip(fields, values):
|
93 |
+
default_ordered_dict[f] = v
|
94 |
+
for df in default_order:
|
95 |
+
if default_ordered_dict[df] is None:
|
96 |
+
del default_ordered_dict[df]
|
97 |
+
fields = list(default_ordered_dict.keys())
|
98 |
+
values = list(default_ordered_dict.values())
|
99 |
+
|
100 |
+
out_file = os.path.join(output_folder, cls + '_summary.txt')
|
101 |
+
os.makedirs(os.path.dirname(out_file), exist_ok=True)
|
102 |
+
with open(out_file, 'w', newline='') as f:
|
103 |
+
writer = csv.writer(f, delimiter=' ')
|
104 |
+
writer.writerow(fields)
|
105 |
+
writer.writerow(values)
|
106 |
+
|
107 |
+
|
108 |
+
def write_detailed_results(details, cls, output_folder):
|
109 |
+
"""Write detailed results to file"""
|
110 |
+
sequences = details[0].keys()
|
111 |
+
fields = ['seq'] + sum([list(s['COMBINED_SEQ'].keys()) for s in details], [])
|
112 |
+
out_file = os.path.join(output_folder, cls + '_detailed.csv')
|
113 |
+
os.makedirs(os.path.dirname(out_file), exist_ok=True)
|
114 |
+
with open(out_file, 'w', newline='') as f:
|
115 |
+
writer = csv.writer(f)
|
116 |
+
writer.writerow(fields)
|
117 |
+
for seq in sorted(sequences):
|
118 |
+
if seq == 'COMBINED_SEQ':
|
119 |
+
continue
|
120 |
+
writer.writerow([seq] + sum([list(s[seq].values()) for s in details], []))
|
121 |
+
writer.writerow(['COMBINED'] + sum([list(s['COMBINED_SEQ'].values()) for s in details], []))
|
122 |
+
|
123 |
+
|
124 |
+
def load_detail(file):
|
125 |
+
"""Loads detailed data for a tracker."""
|
126 |
+
data = {}
|
127 |
+
with open(file) as f:
|
128 |
+
for i, row_text in enumerate(f):
|
129 |
+
row = row_text.replace('\r', '').replace('\n', '').split(',')
|
130 |
+
if i == 0:
|
131 |
+
keys = row[1:]
|
132 |
+
continue
|
133 |
+
current_values = row[1:]
|
134 |
+
seq = row[0]
|
135 |
+
if seq == 'COMBINED':
|
136 |
+
seq = 'COMBINED_SEQ'
|
137 |
+
if (len(current_values) == len(keys)) and seq != '':
|
138 |
+
data[seq] = {}
|
139 |
+
for key, value in zip(keys, current_values):
|
140 |
+
data[seq][key] = float(value)
|
141 |
+
return data
|
142 |
+
|
143 |
+
|
144 |
+
class TrackEvalException(Exception):
|
145 |
+
"""Custom exception for catching expected errors."""
|
146 |
+
...
|
avism/data/build.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import logging
|
3 |
+
import torch.utils.data
|
4 |
+
|
5 |
+
from detectron2.config import CfgNode, configurable
|
6 |
+
from detectron2.data.build import (
|
7 |
+
build_batch_data_loader,
|
8 |
+
load_proposals_into_dataset,
|
9 |
+
trivial_batch_collator,
|
10 |
+
)
|
11 |
+
from detectron2.data.catalog import DatasetCatalog
|
12 |
+
from detectron2.data.common import DatasetFromList, MapDataset
|
13 |
+
from detectron2.data.dataset_mapper import DatasetMapper
|
14 |
+
from detectron2.data.samplers import InferenceSampler, TrainingSampler
|
15 |
+
from detectron2.utils.comm import get_world_size
|
16 |
+
|
17 |
+
|
18 |
+
def _compute_num_images_per_worker(cfg: CfgNode):
|
19 |
+
num_workers = get_world_size()
|
20 |
+
images_per_batch = cfg.SOLVER.IMS_PER_BATCH
|
21 |
+
assert (
|
22 |
+
images_per_batch % num_workers == 0
|
23 |
+
), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
|
24 |
+
images_per_batch, num_workers
|
25 |
+
)
|
26 |
+
assert (
|
27 |
+
images_per_batch >= num_workers
|
28 |
+
), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
|
29 |
+
images_per_batch, num_workers
|
30 |
+
)
|
31 |
+
images_per_worker = images_per_batch // num_workers
|
32 |
+
return images_per_worker
|
33 |
+
|
34 |
+
|
35 |
+
def filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names):
|
36 |
+
"""
|
37 |
+
Filter out images with none annotations or only crowd annotations
|
38 |
+
(i.e., images without non-crowd annotations).
|
39 |
+
A common training-time preprocessing on COCO dataset.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
list[dict]: the same format, but filtered.
|
46 |
+
"""
|
47 |
+
num_before = len(dataset_dicts)
|
48 |
+
|
49 |
+
def valid(anns):
|
50 |
+
for ann in anns:
|
51 |
+
if isinstance(ann, list):
|
52 |
+
for instance in ann:
|
53 |
+
if instance.get("iscrowd", 0) == 0:
|
54 |
+
return True
|
55 |
+
else:
|
56 |
+
if ann.get("iscrowd", 0) == 0:
|
57 |
+
return True
|
58 |
+
return False
|
59 |
+
|
60 |
+
dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
|
61 |
+
num_after = len(dataset_dicts)
|
62 |
+
logger = logging.getLogger(__name__)
|
63 |
+
logger.info(
|
64 |
+
"Removed {} images with no usable annotations. {} images left.".format(
|
65 |
+
num_before - num_after, num_after
|
66 |
+
)
|
67 |
+
)
|
68 |
+
return dataset_dicts
|
69 |
+
|
70 |
+
|
71 |
+
def get_detection_dataset_dicts(
|
72 |
+
dataset_names, filter_empty=True, proposal_files=None
|
73 |
+
):
|
74 |
+
"""
|
75 |
+
Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
dataset_names (str or list[str]): a dataset name or a list of dataset names
|
79 |
+
filter_empty (bool): whether to filter out images without instance annotations
|
80 |
+
proposal_files (list[str]): if given, a list of object proposal files
|
81 |
+
that match each dataset in `dataset_names`.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
list[dict]: a list of dicts following the standard dataset dict format.
|
85 |
+
"""
|
86 |
+
if isinstance(dataset_names, str):
|
87 |
+
dataset_names = [dataset_names]
|
88 |
+
assert len(dataset_names)
|
89 |
+
dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
|
90 |
+
for dataset_name, dicts in zip(dataset_names, dataset_dicts):
|
91 |
+
assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
|
92 |
+
|
93 |
+
if proposal_files is not None:
|
94 |
+
assert len(dataset_names) == len(proposal_files)
|
95 |
+
# load precomputed proposals from proposal files
|
96 |
+
dataset_dicts = [
|
97 |
+
load_proposals_into_dataset(dataset_i_dicts, proposal_file)
|
98 |
+
for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
|
99 |
+
]
|
100 |
+
|
101 |
+
dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
|
102 |
+
|
103 |
+
has_instances = "annotations" in dataset_dicts[0]
|
104 |
+
if filter_empty and has_instances:
|
105 |
+
dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names)
|
106 |
+
|
107 |
+
assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names))
|
108 |
+
return dataset_dicts
|
109 |
+
|
110 |
+
|
111 |
+
def _train_loader_from_config(cfg, mapper, dataset_name=None, *, dataset=None, sampler=None):
|
112 |
+
if dataset is None:
|
113 |
+
dataset = get_detection_dataset_dicts(
|
114 |
+
dataset_name,
|
115 |
+
filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
|
116 |
+
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
|
117 |
+
)
|
118 |
+
|
119 |
+
if mapper is None:
|
120 |
+
mapper = DatasetMapper(cfg, True)
|
121 |
+
|
122 |
+
if sampler is None:
|
123 |
+
sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
|
124 |
+
logger = logging.getLogger(__name__)
|
125 |
+
logger.info("Using training sampler {}".format(sampler_name))
|
126 |
+
sampler = TrainingSampler(len(dataset))
|
127 |
+
|
128 |
+
return {
|
129 |
+
"dataset": dataset,
|
130 |
+
"sampler": sampler,
|
131 |
+
"mapper": mapper,
|
132 |
+
"total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
|
133 |
+
"aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
|
134 |
+
"num_workers": cfg.DATALOADER.NUM_WORKERS,
|
135 |
+
}
|
136 |
+
|
137 |
+
|
138 |
+
# TODO can allow dataset as an iterable or IterableDataset to make this function more general
|
139 |
+
@configurable(from_config=_train_loader_from_config)
|
140 |
+
def build_detection_train_loader(
|
141 |
+
dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
|
142 |
+
):
|
143 |
+
"""
|
144 |
+
Build a dataloader for object detection with some default features.
|
145 |
+
This interface is experimental.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
|
149 |
+
or a map-style pytorch dataset. They can be obtained by using
|
150 |
+
:func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
|
151 |
+
mapper (callable): a callable which takes a sample (dict) from dataset and
|
152 |
+
returns the format to be consumed by the model.
|
153 |
+
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
|
154 |
+
sampler (torch.utils.data.sampler.Sampler or None): a sampler that
|
155 |
+
produces indices to be applied on ``dataset``.
|
156 |
+
Default to :class:`TrainingSampler`, which coordinates a random shuffle
|
157 |
+
sequence across all workers.
|
158 |
+
total_batch_size (int): total batch size across all workers. Batching
|
159 |
+
simply puts data into a list.
|
160 |
+
aspect_ratio_grouping (bool): whether to group images with similar
|
161 |
+
aspect ratio for efficiency. When enabled, it requires each
|
162 |
+
element in dataset be a dict with keys "width" and "height".
|
163 |
+
num_workers (int): number of parallel data loading workers
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
torch.utils.data.DataLoader: a dataloader. Each output from it is a
|
167 |
+
``list[mapped_element]`` of length ``total_batch_size / num_workers``,
|
168 |
+
where ``mapped_element`` is produced by the ``mapper``.
|
169 |
+
"""
|
170 |
+
if isinstance(dataset, list):
|
171 |
+
dataset = DatasetFromList(dataset, copy=False)
|
172 |
+
if mapper is not None:
|
173 |
+
dataset = MapDataset(dataset, mapper)
|
174 |
+
if sampler is None:
|
175 |
+
sampler = TrainingSampler(len(dataset))
|
176 |
+
assert isinstance(sampler, torch.utils.data.sampler.Sampler)
|
177 |
+
return build_batch_data_loader(
|
178 |
+
dataset,
|
179 |
+
sampler,
|
180 |
+
total_batch_size,
|
181 |
+
aspect_ratio_grouping=aspect_ratio_grouping,
|
182 |
+
num_workers=num_workers,
|
183 |
+
)
|
184 |
+
|
185 |
+
|
186 |
+
def _test_loader_from_config(cfg, dataset_name, mapper=None):
|
187 |
+
"""
|
188 |
+
Uses the given `dataset_name` argument (instead of the names in cfg), because the
|
189 |
+
standard practice is to evaluate each test set individually (not combining them).
|
190 |
+
"""
|
191 |
+
dataset = get_detection_dataset_dicts(
|
192 |
+
[dataset_name],
|
193 |
+
filter_empty=False,
|
194 |
+
proposal_files=[
|
195 |
+
cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]
|
196 |
+
]
|
197 |
+
if cfg.MODEL.LOAD_PROPOSALS
|
198 |
+
else None,
|
199 |
+
)
|
200 |
+
if mapper is None:
|
201 |
+
mapper = DatasetMapper(cfg, False)
|
202 |
+
return {"dataset": dataset, "mapper": mapper, "num_workers": cfg.DATALOADER.NUM_WORKERS}
|
203 |
+
|
204 |
+
|
205 |
+
@configurable(from_config=_test_loader_from_config)
|
206 |
+
def build_detection_test_loader(dataset, *, mapper, num_workers=0):
|
207 |
+
"""
|
208 |
+
Similar to `build_detection_train_loader`, but uses a batch size of 1.
|
209 |
+
This interface is experimental.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
|
213 |
+
or a map-style pytorch dataset. They can be obtained by using
|
214 |
+
:func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
|
215 |
+
mapper (callable): a callable which takes a sample (dict) from dataset
|
216 |
+
and returns the format to be consumed by the model.
|
217 |
+
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
|
218 |
+
num_workers (int): number of parallel data loading workers
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
DataLoader: a torch DataLoader, that loads the given detection
|
222 |
+
dataset, with test-time transformation and batching.
|
223 |
+
|
224 |
+
Examples:
|
225 |
+
::
|
226 |
+
data_loader = build_detection_test_loader(
|
227 |
+
DatasetRegistry.get("my_test"),
|
228 |
+
mapper=DatasetMapper(...))
|
229 |
+
|
230 |
+
# or, instantiate with a CfgNode:
|
231 |
+
data_loader = build_detection_test_loader(cfg, "my_test")
|
232 |
+
"""
|
233 |
+
if isinstance(dataset, list):
|
234 |
+
dataset = DatasetFromList(dataset, copy=False)
|
235 |
+
if mapper is not None:
|
236 |
+
dataset = MapDataset(dataset, mapper)
|
237 |
+
sampler = InferenceSampler(len(dataset))
|
238 |
+
# Always use 1 image per worker during inference since this is the
|
239 |
+
# standard when reporting inference time in papers.
|
240 |
+
batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False)
|
241 |
+
data_loader = torch.utils.data.DataLoader(
|
242 |
+
dataset,
|
243 |
+
num_workers=num_workers,
|
244 |
+
batch_sampler=batch_sampler,
|
245 |
+
collate_fn=trivial_batch_collator,
|
246 |
+
)
|
247 |
+
return data_loader
|
avism/data/dataset_mapper.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import logging
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
from typing import List, Union
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from detectron2.config import configurable
|
9 |
+
from detectron2.structures import (
|
10 |
+
BitMasks,
|
11 |
+
Boxes,
|
12 |
+
BoxMode,
|
13 |
+
Instances,
|
14 |
+
)
|
15 |
+
|
16 |
+
from detectron2.data import detection_utils as utils
|
17 |
+
from detectron2.data import transforms as T
|
18 |
+
|
19 |
+
from .augmentation import build_augmentation
|
20 |
+
|
21 |
+
__all__ = ["AVISDatasetMapper"]
|
22 |
+
|
23 |
+
|
24 |
+
def filter_empty_instances(instances, by_box=True, by_mask=True, box_threshold=1e-5):
|
25 |
+
"""
|
26 |
+
Filter out empty instances in an `Instances` object.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
instances (Instances):
|
30 |
+
by_box (bool): whether to filter out instances with empty boxes
|
31 |
+
by_mask (bool): whether to filter out instances with empty masks
|
32 |
+
box_threshold (float): minimum width and height to be considered non-empty
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
Instances: the filtered instances.
|
36 |
+
"""
|
37 |
+
assert by_box or by_mask
|
38 |
+
r = []
|
39 |
+
if by_box:
|
40 |
+
r.append(instances.gt_boxes.nonempty(threshold=box_threshold))
|
41 |
+
if instances.has("gt_masks") and by_mask:
|
42 |
+
r.append(instances.gt_masks.nonempty())
|
43 |
+
r.append(instances.gt_classes != -1)
|
44 |
+
|
45 |
+
if not r:
|
46 |
+
return instances
|
47 |
+
m = r[0]
|
48 |
+
for x in r[1:]:
|
49 |
+
m = m & x
|
50 |
+
|
51 |
+
instances.gt_ids[~m] = -1
|
52 |
+
return instances
|
53 |
+
|
54 |
+
|
55 |
+
def _get_dummy_anno(num_classes):
|
56 |
+
return {
|
57 |
+
"iscrowd": 0,
|
58 |
+
"category_id": num_classes,
|
59 |
+
"id": -1,
|
60 |
+
"bbox": np.array([0, 0, 0, 0]),
|
61 |
+
"bbox_mode": BoxMode.XYXY_ABS,
|
62 |
+
"segmentation": [np.array([0.0] * 6)]
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def avis_annotations_to_instances(annos, image_size):
|
67 |
+
"""
|
68 |
+
Create an :class:`Instances` object used by the models,
|
69 |
+
from instance annotations in the dataset dict.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
annos (list[dict]): a list of instance annotations in one image, each
|
73 |
+
element for one instance.
|
74 |
+
image_size (tuple): height, width
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
Instances:
|
78 |
+
It will contain fields "gt_boxes", "gt_classes", "gt_ids",
|
79 |
+
"gt_masks", if they can be obtained from `annos`.
|
80 |
+
This is the format that builtin models expect.
|
81 |
+
"""
|
82 |
+
boxes = [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
|
83 |
+
target = Instances(image_size)
|
84 |
+
target.gt_boxes = Boxes(boxes)
|
85 |
+
|
86 |
+
classes = [int(obj["category_id"]) for obj in annos]
|
87 |
+
classes = torch.tensor(classes, dtype=torch.int64)
|
88 |
+
target.gt_classes = classes
|
89 |
+
|
90 |
+
ids = [int(obj["id"]) for obj in annos]
|
91 |
+
ids = torch.tensor(ids, dtype=torch.int64)
|
92 |
+
target.gt_ids = ids
|
93 |
+
|
94 |
+
if len(annos) and "segmentation" in annos[0]:
|
95 |
+
segms = [obj["segmentation"] for obj in annos]
|
96 |
+
masks = []
|
97 |
+
for segm in segms:
|
98 |
+
assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
|
99 |
+
segm.ndim
|
100 |
+
)
|
101 |
+
# mask array
|
102 |
+
masks.append(segm)
|
103 |
+
# torch.from_numpy does not support array with negative stride.
|
104 |
+
masks = BitMasks(
|
105 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
|
106 |
+
)
|
107 |
+
target.gt_masks = masks
|
108 |
+
|
109 |
+
return target
|
110 |
+
|
111 |
+
|
112 |
+
class AVISDatasetMapper:
|
113 |
+
"""
|
114 |
+
A callable which takes a dataset dict in AVIS Dataset format,
|
115 |
+
and map it into a format used by the model.
|
116 |
+
"""
|
117 |
+
|
118 |
+
@configurable
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
is_train: bool,
|
122 |
+
*,
|
123 |
+
augmentations: List[Union[T.Augmentation, T.Transform]],
|
124 |
+
image_format: str,
|
125 |
+
use_instance_mask: bool = False,
|
126 |
+
sampling_frame_num: int = 2,
|
127 |
+
sampling_frame_range: int = 5,
|
128 |
+
sampling_frame_shuffle: bool = False,
|
129 |
+
num_classes: int = 26,
|
130 |
+
):
|
131 |
+
"""
|
132 |
+
NOTE: this interface is experimental.
|
133 |
+
Args:
|
134 |
+
is_train: whether it's used in training or inference
|
135 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
136 |
+
image_format: an image format supported by :func:`detection_utils.read_image`.
|
137 |
+
use_instance_mask: whether to process instance segmentation annotations, if available
|
138 |
+
"""
|
139 |
+
# fmt: off
|
140 |
+
self.is_train = is_train
|
141 |
+
self.augmentations = T.AugmentationList(augmentations)
|
142 |
+
self.image_format = image_format
|
143 |
+
self.use_instance_mask = use_instance_mask
|
144 |
+
self.sampling_frame_num = sampling_frame_num
|
145 |
+
self.sampling_frame_range = sampling_frame_range
|
146 |
+
self.sampling_frame_shuffle = sampling_frame_shuffle
|
147 |
+
self.num_classes = num_classes
|
148 |
+
# fmt: on
|
149 |
+
logger = logging.getLogger(__name__)
|
150 |
+
mode = "training" if is_train else "inference"
|
151 |
+
logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
|
152 |
+
|
153 |
+
@classmethod
|
154 |
+
def from_config(cls, cfg, is_train: bool = True):
|
155 |
+
augs = build_augmentation(cfg, is_train)
|
156 |
+
|
157 |
+
sampling_frame_num = cfg.INPUT.SAMPLING_FRAME_NUM
|
158 |
+
sampling_frame_range = cfg.INPUT.SAMPLING_FRAME_RANGE
|
159 |
+
sampling_frame_shuffle = cfg.INPUT.SAMPLING_FRAME_SHUFFLE
|
160 |
+
|
161 |
+
ret = {
|
162 |
+
"is_train": is_train,
|
163 |
+
"augmentations": augs,
|
164 |
+
"image_format": cfg.INPUT.FORMAT,
|
165 |
+
"use_instance_mask": cfg.MODEL.MASK_ON,
|
166 |
+
"sampling_frame_num": sampling_frame_num,
|
167 |
+
"sampling_frame_range": sampling_frame_range,
|
168 |
+
"sampling_frame_shuffle": sampling_frame_shuffle,
|
169 |
+
"num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
|
170 |
+
}
|
171 |
+
|
172 |
+
return ret
|
173 |
+
|
174 |
+
def __call__(self, dataset_dict):
|
175 |
+
"""
|
176 |
+
Args:
|
177 |
+
dataset_dict (dict): Metadata of one video, in YTVIS Dataset format.
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
dict: a format that builtin models in detectron2 accept
|
181 |
+
"""
|
182 |
+
# TODO consider examining below deepcopy as it costs huge amount of computations.
|
183 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
184 |
+
|
185 |
+
video_length = dataset_dict["length"]
|
186 |
+
if self.is_train:
|
187 |
+
ref_frame = random.randrange(video_length)
|
188 |
+
|
189 |
+
start_idx = max(0, ref_frame-self.sampling_frame_range)
|
190 |
+
end_idx = min(video_length, ref_frame+self.sampling_frame_range + 1)
|
191 |
+
|
192 |
+
selected_idx = np.random.choice(
|
193 |
+
np.array(list(range(start_idx, ref_frame)) + list(range(ref_frame+1, end_idx))),
|
194 |
+
self.sampling_frame_num - 1,
|
195 |
+
)
|
196 |
+
selected_idx = selected_idx.tolist() + [ref_frame]
|
197 |
+
selected_idx = sorted(selected_idx)
|
198 |
+
if self.sampling_frame_shuffle:
|
199 |
+
random.shuffle(selected_idx)
|
200 |
+
else:
|
201 |
+
selected_idx = range(video_length)
|
202 |
+
|
203 |
+
video_annos = dataset_dict.pop("annotations", None)
|
204 |
+
file_names = dataset_dict.pop("file_names", None)
|
205 |
+
audio_feats = dataset_dict.pop("audio", None)
|
206 |
+
|
207 |
+
if self.is_train:
|
208 |
+
_ids = set()
|
209 |
+
for frame_idx in selected_idx:
|
210 |
+
_ids.update([anno["id"] for anno in video_annos[frame_idx]])
|
211 |
+
ids = dict()
|
212 |
+
for i, _id in enumerate(_ids):
|
213 |
+
ids[_id] = i
|
214 |
+
|
215 |
+
dataset_dict["image"] = []
|
216 |
+
dataset_dict["instances"] = []
|
217 |
+
dataset_dict["file_names"] = []
|
218 |
+
dataset_dict["audio"] = []
|
219 |
+
dataset_dict["frame_idx"] = list(selected_idx)
|
220 |
+
for frame_idx in selected_idx:
|
221 |
+
dataset_dict["file_names"].append(file_names[frame_idx])
|
222 |
+
dataset_dict["audio"].append(audio_feats[frame_idx])
|
223 |
+
|
224 |
+
# Read image
|
225 |
+
image = utils.read_image(file_names[frame_idx], format=self.image_format)
|
226 |
+
utils.check_image_size(dataset_dict, image)
|
227 |
+
|
228 |
+
aug_input = T.AugInput(image)
|
229 |
+
transforms = self.augmentations(aug_input)
|
230 |
+
image = aug_input.image
|
231 |
+
|
232 |
+
image_shape = image.shape[:2] # h, w
|
233 |
+
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
234 |
+
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
235 |
+
# Therefore it's important to use torch.Tensor.
|
236 |
+
dataset_dict["image"].append(torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))))
|
237 |
+
|
238 |
+
if (video_annos is None) or (not self.is_train):
|
239 |
+
continue
|
240 |
+
|
241 |
+
# NOTE copy() is to prevent annotations getting changed from applying augmentations
|
242 |
+
_frame_annos = []
|
243 |
+
for anno in video_annos[frame_idx]:
|
244 |
+
_anno = {}
|
245 |
+
for k, v in anno.items():
|
246 |
+
_anno[k] = copy.deepcopy(v)
|
247 |
+
_frame_annos.append(_anno)
|
248 |
+
|
249 |
+
# USER: Implement additional transformations if you have other types of data
|
250 |
+
annos = [
|
251 |
+
utils.transform_instance_annotations(obj, transforms, image_shape)
|
252 |
+
for obj in _frame_annos
|
253 |
+
if obj.get("iscrowd", 0) == 0
|
254 |
+
]
|
255 |
+
sorted_annos = [_get_dummy_anno(self.num_classes) for _ in range(len(ids))]
|
256 |
+
|
257 |
+
for _anno in annos:
|
258 |
+
idx = ids[_anno["id"]]
|
259 |
+
sorted_annos[idx] = _anno
|
260 |
+
_gt_ids = [_anno["id"] for _anno in sorted_annos]
|
261 |
+
|
262 |
+
instances = utils.annotations_to_instances(sorted_annos, image_shape, mask_format="bitmask")
|
263 |
+
instances.gt_ids = torch.tensor(_gt_ids)
|
264 |
+
if instances.has("gt_masks"):
|
265 |
+
instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
|
266 |
+
instances = filter_empty_instances(instances)
|
267 |
+
else:
|
268 |
+
instances.gt_masks = BitMasks(torch.empty((0, *image_shape)))
|
269 |
+
dataset_dict["instances"].append(instances)
|
270 |
+
|
271 |
+
return dataset_dict
|
272 |
+
|
avism/data/datasets/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from . import builtin # ensure the builtin datasets are registered
|
2 |
+
|
3 |
+
__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
|
avism/data/datasets/avis.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import io
|
3 |
+
import logging
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import pycocotools.mask as mask_util
|
7 |
+
from fvcore.common.file_io import PathManager
|
8 |
+
from fvcore.common.timer import Timer
|
9 |
+
|
10 |
+
from detectron2.structures import Boxes, BoxMode, PolygonMasks
|
11 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
12 |
+
|
13 |
+
from .avis_api.avos import AVOS
|
14 |
+
|
15 |
+
|
16 |
+
"""
|
17 |
+
This file contains functions to parse AVIS dataset of
|
18 |
+
COCO-format annotations into dicts in "Detectron2 format".
|
19 |
+
"""
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
__all__ = ["load_avis_json", "register_avis_instances"]
|
24 |
+
|
25 |
+
|
26 |
+
AVIS_CATEGORIES = [
|
27 |
+
{"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
|
28 |
+
{"color": [0, 82, 0], "isthing": 1, "id": 2, "name": "violin"},
|
29 |
+
{"color": [119, 11, 32], "isthing": 1, "id": 3, "name": "guitar"},
|
30 |
+
{"color": [165, 42, 42], "isthing": 1, "id": 4, "name": "cello"},
|
31 |
+
{"color": [134, 134, 103], "isthing": 1, "id": 5, "name": "flute"},
|
32 |
+
{"color": [0, 0, 142], "isthing": 1, "id": 6, "name": "piano"},
|
33 |
+
{"color": [255, 109, 65], "isthing": 1, "id": 7, "name": "ukulele"},
|
34 |
+
{"color": [0, 226, 252], "isthing": 1, "id": 8, "name": "accordion"},
|
35 |
+
{"color": [5, 121, 0], "isthing": 1, "id": 9, "name": "guzheng"},
|
36 |
+
{"color": [0, 60, 100], "isthing": 1, "id": 10, "name": "clarinet"},
|
37 |
+
{"color": [250, 170, 30], "isthing": 1, "id": 11, "name": "cat"},
|
38 |
+
{"color": [100, 170, 30], "isthing": 1, "id": 12, "name": "car"},
|
39 |
+
{"color": [179, 0, 194], "isthing": 1, "id": 13, "name": "saxophone"},
|
40 |
+
{"color": [255, 77, 255], "isthing": 1, "id": 14, "name": "dog"},
|
41 |
+
{"color": [120, 166, 157], "isthing": 1, "id": 15, "name": "lawn_mover"},
|
42 |
+
{"color": [73, 77, 174], "isthing": 1, "id": 16, "name": "tuba"},
|
43 |
+
{"color": [0, 80, 100], "isthing": 1, "id": 17, "name": "banjo"},
|
44 |
+
{"color": [182, 182, 255], "isthing": 1, "id": 18, "name": "pipa"},
|
45 |
+
{"color": [0, 143, 149], "isthing": 1, "id": 19, "name": "bassoon"},
|
46 |
+
{"color": [174, 57, 255], "isthing": 1, "id": 20, "name": "airplane"},
|
47 |
+
{"color": [0, 0, 230], "isthing": 1, "id": 21, "name": "tree_harvester"},
|
48 |
+
{"color": [72, 0, 118], "isthing": 1, "id": 22, "name": "trumpet"},
|
49 |
+
{"color": [255, 179, 240], "isthing": 1, "id": 23, "name": "lion"},
|
50 |
+
{"color": [0, 125, 92], "isthing": 1, "id": 24, "name": "bass"},
|
51 |
+
{"color": [209, 0, 151], "isthing": 1, "id": 25, "name": "erhu"},
|
52 |
+
{"color": [188, 208, 182], "isthing": 1, "id": 26, "name": "horse"}]
|
53 |
+
|
54 |
+
|
55 |
+
def _get_avis_instances_meta():
|
56 |
+
thing_ids = [k["id"] for k in AVIS_CATEGORIES if k["isthing"] == 1]
|
57 |
+
thing_colors = [k["color"] for k in AVIS_CATEGORIES if k["isthing"] == 1]
|
58 |
+
assert len(thing_ids) == 26, len(thing_ids)
|
59 |
+
# Mapping from the incontiguous AVIS category id to an id in [0, 25]
|
60 |
+
thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
|
61 |
+
thing_classes = [k["name"] for k in AVIS_CATEGORIES if k["isthing"] == 1]
|
62 |
+
ret = {
|
63 |
+
"thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
|
64 |
+
"thing_classes": thing_classes,
|
65 |
+
"thing_colors": thing_colors,
|
66 |
+
}
|
67 |
+
return ret
|
68 |
+
|
69 |
+
|
70 |
+
def load_avis_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
|
71 |
+
|
72 |
+
timer = Timer()
|
73 |
+
json_file = PathManager.get_local_path(json_file)
|
74 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
75 |
+
avis_api = AVOS(json_file)
|
76 |
+
if timer.seconds() > 1:
|
77 |
+
logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
|
78 |
+
|
79 |
+
id_map = None
|
80 |
+
if dataset_name is not None:
|
81 |
+
meta = MetadataCatalog.get(dataset_name)
|
82 |
+
cat_ids = sorted(avis_api.getCatIds())
|
83 |
+
cats = avis_api.loadCats(cat_ids)
|
84 |
+
# The categories in a custom json file may not be sorted.
|
85 |
+
thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
|
86 |
+
meta.thing_classes = thing_classes
|
87 |
+
|
88 |
+
# It works by looking at the "categories" field in the json, therefore
|
89 |
+
# if users' own json also have incontiguous ids, we'll
|
90 |
+
# apply this mapping as well but print a warning.
|
91 |
+
if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
|
92 |
+
if "coco" not in dataset_name:
|
93 |
+
logger.warning(
|
94 |
+
"""
|
95 |
+
Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
|
96 |
+
"""
|
97 |
+
)
|
98 |
+
id_map = {v: i for i, v in enumerate(cat_ids)}
|
99 |
+
meta.thing_dataset_id_to_contiguous_id = id_map
|
100 |
+
|
101 |
+
# sort indices for reproducible results
|
102 |
+
vid_ids = sorted(avis_api.vids.keys())
|
103 |
+
vids = avis_api.loadVids(vid_ids)
|
104 |
+
|
105 |
+
anns = [avis_api.vidToAnns[vid_id] for vid_id in vid_ids]
|
106 |
+
total_num_valid_anns = sum([len(x) for x in anns])
|
107 |
+
total_num_anns = len(avis_api.anns)
|
108 |
+
if total_num_valid_anns < total_num_anns:
|
109 |
+
logger.warning(
|
110 |
+
f"{json_file} contains {total_num_anns} annotations, but only "
|
111 |
+
f"{total_num_valid_anns} of them match to images in the file."
|
112 |
+
)
|
113 |
+
|
114 |
+
vids_anns = list(zip(vids, anns))
|
115 |
+
logger.info("Loaded {} videos in AVIS format from {}".format(len(vids_anns), json_file))
|
116 |
+
|
117 |
+
dataset_dicts = []
|
118 |
+
|
119 |
+
ann_keys = ["iscrowd", "category_id", "id"] + (extra_annotation_keys or [])
|
120 |
+
|
121 |
+
num_instances_without_valid_segmentation = 0
|
122 |
+
|
123 |
+
for (vid_dict, anno_dict_list) in vids_anns:
|
124 |
+
record = {}
|
125 |
+
record["file_names"] = [os.path.join(image_root, vid_dict["file_names"][i]) for i in range(vid_dict["length"])]
|
126 |
+
record["height"] = vid_dict["height"]
|
127 |
+
record["width"] = vid_dict["width"]
|
128 |
+
record["length"] = vid_dict["length"]
|
129 |
+
video_id = record["video_id"] = vid_dict["id"]
|
130 |
+
|
131 |
+
video_objs = []
|
132 |
+
for frame_idx in range(record["length"]):
|
133 |
+
frame_objs = []
|
134 |
+
for anno in anno_dict_list:
|
135 |
+
assert anno["video_id"] == video_id
|
136 |
+
|
137 |
+
obj = {key: anno[key] for key in ann_keys if key in anno}
|
138 |
+
|
139 |
+
_bboxes = anno.get("bboxes", None)
|
140 |
+
_segm = anno.get("segmentations", None)
|
141 |
+
|
142 |
+
if not (_bboxes and _segm and _bboxes[frame_idx] and _segm[frame_idx]):
|
143 |
+
continue
|
144 |
+
|
145 |
+
bbox = _bboxes[frame_idx]
|
146 |
+
segm = _segm[frame_idx]
|
147 |
+
|
148 |
+
obj["bbox"] = bbox
|
149 |
+
obj["bbox_mode"] = BoxMode.XYWH_ABS
|
150 |
+
|
151 |
+
if isinstance(segm, dict):
|
152 |
+
if isinstance(segm["counts"], list):
|
153 |
+
# convert to compressed RLE
|
154 |
+
segm = mask_util.frPyObjects(segm, *segm["size"])
|
155 |
+
elif segm:
|
156 |
+
# filter out invalid polygons (< 3 points)
|
157 |
+
segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
|
158 |
+
if len(segm) == 0:
|
159 |
+
num_instances_without_valid_segmentation += 1
|
160 |
+
continue # ignore this instance
|
161 |
+
obj["segmentation"] = segm
|
162 |
+
|
163 |
+
if id_map:
|
164 |
+
obj["category_id"] = id_map[obj["category_id"]]
|
165 |
+
frame_objs.append(obj)
|
166 |
+
video_objs.append(frame_objs)
|
167 |
+
record["annotations"] = video_objs
|
168 |
+
|
169 |
+
# audio:
|
170 |
+
audio_feats_pth = os.path.join(image_root[:-10], "FEATAudios", vid_dict['file_names'][0].split("/")[0] + '.npy')
|
171 |
+
record["audio"] = np.load(audio_feats_pth)
|
172 |
+
|
173 |
+
dataset_dicts.append(record)
|
174 |
+
|
175 |
+
if num_instances_without_valid_segmentation > 0:
|
176 |
+
logger.warning(
|
177 |
+
"Filtered out {} instances without valid segmentation. ".format(
|
178 |
+
num_instances_without_valid_segmentation
|
179 |
+
)
|
180 |
+
+ "There might be issues in your dataset generation process. "
|
181 |
+
"A valid polygon should be a list[float] with even length >= 6."
|
182 |
+
)
|
183 |
+
return dataset_dicts
|
184 |
+
|
185 |
+
|
186 |
+
def register_avis_instances(name, metadata, json_file, image_root):
|
187 |
+
"""
|
188 |
+
Register a dataset in AVIS's json annotation format for
|
189 |
+
instance tracking.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
name (str): the name that identifies a dataset, e.g. "avis_train".
|
193 |
+
metadata (dict): extra metadata associated with this dataset. You can
|
194 |
+
leave it as an empty dict.
|
195 |
+
json_file (str): path to the json instance annotation file.
|
196 |
+
image_root (str or path-like): directory which contains all the images.
|
197 |
+
"""
|
198 |
+
assert isinstance(name, str), name
|
199 |
+
assert isinstance(json_file, (str, os.PathLike)), json_file
|
200 |
+
assert isinstance(image_root, (str, os.PathLike)), image_root
|
201 |
+
# 1. register a function which returns dicts
|
202 |
+
DatasetCatalog.register(name, lambda: load_avis_json(json_file, image_root, name))
|
203 |
+
|
204 |
+
# 2. Optionally, add metadata about this dataset,
|
205 |
+
# since they might be useful in evaluation, visualization or logging
|
206 |
+
MetadataCatalog.get(name).set(
|
207 |
+
json_file=json_file, image_root=image_root, evaluator_type="avis", **metadata
|
208 |
+
)
|
209 |
+
|
avism/data/datasets/avis_api/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
avism/data/datasets/avis_api/avos.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# The following API functions are defined:
|
4 |
+
# AVOS - AVOS api class that loads AVIS annotation file and prepare data structures.
|
5 |
+
# decodeMask - Decode binary mask M encoded via run-length encoding.
|
6 |
+
# encodeMask - Encode binary mask M using run-length encoding.
|
7 |
+
# getAnnIds - Get ann ids that satisfy given filter conditions.
|
8 |
+
# getCatIds - Get cat ids that satisfy given filter conditions.
|
9 |
+
# getImgIds - Get img ids that satisfy given filter conditions.
|
10 |
+
# loadAnns - Load anns with the specified ids.
|
11 |
+
# loadCats - Load cats with the specified ids.
|
12 |
+
# loadImgs - Load imgs with the specified ids.
|
13 |
+
# annToMask - Convert segmentation in an annotation to binary mask.
|
14 |
+
# loadRes - Load algorithm results and create API for accessing them.
|
15 |
+
|
16 |
+
import json
|
17 |
+
import time
|
18 |
+
import numpy as np
|
19 |
+
import copy
|
20 |
+
import itertools
|
21 |
+
from pycocotools import mask as maskUtils
|
22 |
+
from collections import defaultdict
|
23 |
+
import sys
|
24 |
+
PYTHON_VERSION = sys.version_info[0]
|
25 |
+
if PYTHON_VERSION == 2:
|
26 |
+
from urllib import urlretrieve
|
27 |
+
elif PYTHON_VERSION == 3:
|
28 |
+
from urllib.request import urlretrieve
|
29 |
+
|
30 |
+
|
31 |
+
def _isArrayLike(obj):
|
32 |
+
return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
|
33 |
+
|
34 |
+
|
35 |
+
class AVOS:
|
36 |
+
def __init__(self, annotation_file=None):
|
37 |
+
"""
|
38 |
+
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
|
39 |
+
:param annotation_file (str): location of annotation file
|
40 |
+
:param image_folder (str): location to the folder that hosts images.
|
41 |
+
:return:
|
42 |
+
"""
|
43 |
+
# load dataset
|
44 |
+
self.dataset,self.anns,self.cats,self.vids = dict(),dict(),dict(),dict()
|
45 |
+
self.vidToAnns, self.catToVids = defaultdict(list), defaultdict(list)
|
46 |
+
if not annotation_file == None:
|
47 |
+
print('loading annotations into memory...')
|
48 |
+
tic = time.time()
|
49 |
+
dataset = json.load(open(annotation_file, 'r'))
|
50 |
+
assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
|
51 |
+
print('Done (t={:0.2f}s)'.format(time.time()- tic))
|
52 |
+
self.dataset = dataset
|
53 |
+
self.createIndex()
|
54 |
+
|
55 |
+
def createIndex(self):
|
56 |
+
# create index
|
57 |
+
print('creating index...')
|
58 |
+
anns, cats, vids = {}, {}, {}
|
59 |
+
vidToAnns,catToVids = defaultdict(list),defaultdict(list)
|
60 |
+
if 'annotations' in self.dataset:
|
61 |
+
for ann in self.dataset['annotations']:
|
62 |
+
vidToAnns[ann['video_id']].append(ann)
|
63 |
+
anns[ann['id']] = ann
|
64 |
+
|
65 |
+
if 'videos' in self.dataset:
|
66 |
+
for vid in self.dataset['videos']:
|
67 |
+
vids[vid['id']] = vid
|
68 |
+
|
69 |
+
if 'categories' in self.dataset:
|
70 |
+
for cat in self.dataset['categories']:
|
71 |
+
cats[cat['id']] = cat
|
72 |
+
|
73 |
+
if 'annotations' in self.dataset and 'categories' in self.dataset:
|
74 |
+
for ann in self.dataset['annotations']:
|
75 |
+
catToVids[ann['category_id']].append(ann['video_id'])
|
76 |
+
|
77 |
+
print('index created!')
|
78 |
+
|
79 |
+
# create class members
|
80 |
+
self.anns = anns
|
81 |
+
self.vidToAnns = vidToAnns
|
82 |
+
self.catToVids = catToVids
|
83 |
+
self.vids = vids
|
84 |
+
self.cats = cats
|
85 |
+
|
86 |
+
def info(self):
|
87 |
+
"""
|
88 |
+
Print information about the annotation file.
|
89 |
+
:return:
|
90 |
+
"""
|
91 |
+
for key, value in self.dataset['info'].items():
|
92 |
+
print('{}: {}'.format(key, value))
|
93 |
+
|
94 |
+
def getAnnIds(self, vidIds=[], catIds=[], areaRng=[], iscrowd=None):
|
95 |
+
"""
|
96 |
+
Get ann ids that satisfy given filter conditions. default skips that filter
|
97 |
+
:param vidIds (int array) : get anns for given vids
|
98 |
+
catIds (int array) : get anns for given cats
|
99 |
+
areaRng (float array) : get anns for given area range (e.g. [0 inf])
|
100 |
+
iscrowd (boolean) : get anns for given crowd label (False or True)
|
101 |
+
:return: ids (int array) : integer array of ann ids
|
102 |
+
"""
|
103 |
+
vidIds = vidIds if _isArrayLike(vidIds) else [vidIds]
|
104 |
+
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
105 |
+
|
106 |
+
if len(vidIds) == len(catIds) == len(areaRng) == 0:
|
107 |
+
anns = self.dataset['annotations']
|
108 |
+
else:
|
109 |
+
if not len(vidIds) == 0:
|
110 |
+
lists = [self.vidToAnns[vidId] for vidId in vidIds if vidId in self.vidToAnns]
|
111 |
+
anns = list(itertools.chain.from_iterable(lists))
|
112 |
+
else:
|
113 |
+
anns = self.dataset['annotations']
|
114 |
+
anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
|
115 |
+
anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['avg_area'] > areaRng[0] and ann['avg_area'] < areaRng[1]]
|
116 |
+
if not iscrowd == None:
|
117 |
+
ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
|
118 |
+
else:
|
119 |
+
ids = [ann['id'] for ann in anns]
|
120 |
+
return ids
|
121 |
+
|
122 |
+
def getCatIds(self, catNms=[], supNms=[], catIds=[]):
|
123 |
+
"""
|
124 |
+
filtering parameters. default skips that filter.
|
125 |
+
:param catNms (str array) : get cats for given cat names
|
126 |
+
:param supNms (str array) : get cats for given supercategory names
|
127 |
+
:param catIds (int array) : get cats for given cat ids
|
128 |
+
:return: ids (int array) : integer array of cat ids
|
129 |
+
"""
|
130 |
+
catNms = catNms if _isArrayLike(catNms) else [catNms]
|
131 |
+
supNms = supNms if _isArrayLike(supNms) else [supNms]
|
132 |
+
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
133 |
+
|
134 |
+
if len(catNms) == len(supNms) == len(catIds) == 0:
|
135 |
+
cats = self.dataset['categories']
|
136 |
+
else:
|
137 |
+
cats = self.dataset['categories']
|
138 |
+
cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
|
139 |
+
cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
|
140 |
+
cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
|
141 |
+
ids = [cat['id'] for cat in cats]
|
142 |
+
return ids
|
143 |
+
|
144 |
+
def getVidIds(self, vidIds=[], catIds=[]):
|
145 |
+
'''
|
146 |
+
Get vid ids that satisfy given filter conditions.
|
147 |
+
:param vidIds (int array) : get vids for given ids
|
148 |
+
:param catIds (int array) : get vids with all given cats
|
149 |
+
:return: ids (int array) : integer array of vid ids
|
150 |
+
'''
|
151 |
+
vidIds = vidIds if _isArrayLike(vidIds) else [vidIds]
|
152 |
+
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
153 |
+
|
154 |
+
if len(vidIds) == len(catIds) == 0:
|
155 |
+
ids = self.vids.keys()
|
156 |
+
else:
|
157 |
+
ids = set(vidIds)
|
158 |
+
for i, catId in enumerate(catIds):
|
159 |
+
if i == 0 and len(ids) == 0:
|
160 |
+
ids = set(self.catToVids[catId])
|
161 |
+
else:
|
162 |
+
ids &= set(self.catToVids[catId])
|
163 |
+
return list(ids)
|
164 |
+
|
165 |
+
def loadAnns(self, ids=[]):
|
166 |
+
"""
|
167 |
+
Load anns with the specified ids.
|
168 |
+
:param ids (int array) : integer ids specifying anns
|
169 |
+
:return: anns (object array) : loaded ann objects
|
170 |
+
"""
|
171 |
+
if _isArrayLike(ids):
|
172 |
+
return [self.anns[id] for id in ids]
|
173 |
+
elif type(ids) == int:
|
174 |
+
return [self.anns[ids]]
|
175 |
+
|
176 |
+
def loadCats(self, ids=[]):
|
177 |
+
"""
|
178 |
+
Load cats with the specified ids.
|
179 |
+
:param ids (int array) : integer ids specifying cats
|
180 |
+
:return: cats (object array) : loaded cat objects
|
181 |
+
"""
|
182 |
+
if _isArrayLike(ids):
|
183 |
+
return [self.cats[id] for id in ids]
|
184 |
+
elif type(ids) == int:
|
185 |
+
return [self.cats[ids]]
|
186 |
+
|
187 |
+
def loadVids(self, ids=[]):
|
188 |
+
"""
|
189 |
+
Load anns with the specified ids.
|
190 |
+
:param ids (int array) : integer ids specifying vid
|
191 |
+
:return: vids (object array) : loaded vid objects
|
192 |
+
"""
|
193 |
+
if _isArrayLike(ids):
|
194 |
+
return [self.vids[id] for id in ids]
|
195 |
+
elif type(ids) == int:
|
196 |
+
return [self.vids[ids]]
|
197 |
+
|
198 |
+
|
199 |
+
def loadRes(self, resFile):
|
200 |
+
"""
|
201 |
+
Load result file and return a result api object.
|
202 |
+
:param resFile (str) : file name of result file
|
203 |
+
:return: res (obj) : result api object
|
204 |
+
"""
|
205 |
+
res = AVOS()
|
206 |
+
res.dataset['videos'] = [img for img in self.dataset['videos']]
|
207 |
+
|
208 |
+
print('Loading and preparing results...')
|
209 |
+
tic = time.time()
|
210 |
+
if type(resFile) == str or (PYTHON_VERSION == 2 and type(resFile) == unicode):
|
211 |
+
anns = json.load(open(resFile))
|
212 |
+
elif type(resFile) == np.ndarray:
|
213 |
+
anns = self.loadNumpyAnnotations(resFile)
|
214 |
+
else:
|
215 |
+
anns = resFile
|
216 |
+
assert type(anns) == list, 'results in not an array of objects'
|
217 |
+
annsVidIds = [ann['video_id'] for ann in anns]
|
218 |
+
assert set(annsVidIds) == (set(annsVidIds) & set(self.getVidIds())), \
|
219 |
+
'Results do not correspond to current coco set'
|
220 |
+
if 'segmentations' in anns[0]:
|
221 |
+
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
222 |
+
for id, ann in enumerate(anns):
|
223 |
+
ann['areas'] = []
|
224 |
+
if not 'bboxes' in ann:
|
225 |
+
ann['bboxes'] = []
|
226 |
+
for seg in ann['segmentations']:
|
227 |
+
# now only support compressed RLE format as segmentation results
|
228 |
+
if seg:
|
229 |
+
ann['areas'].append(maskUtils.area(seg))
|
230 |
+
if len(ann['bboxes']) < len(ann['areas']):
|
231 |
+
ann['bboxes'].append(maskUtils.toBbox(seg))
|
232 |
+
else:
|
233 |
+
ann['areas'].append(None)
|
234 |
+
if len(ann['bboxes']) < len(ann['areas']):
|
235 |
+
ann['bboxes'].append(None)
|
236 |
+
ann['id'] = id+1
|
237 |
+
l = [a for a in ann['areas'] if a]
|
238 |
+
if len(l)==0:
|
239 |
+
ann['avg_area'] = 0
|
240 |
+
else:
|
241 |
+
ann['avg_area'] = np.array(l).mean()
|
242 |
+
ann['iscrowd'] = 0
|
243 |
+
print('DONE (t={:0.2f}s)'.format(time.time()- tic))
|
244 |
+
|
245 |
+
res.dataset['annotations'] = anns
|
246 |
+
res.createIndex()
|
247 |
+
return res
|
248 |
+
|
249 |
+
def annToRLE(self, ann, frameId):
|
250 |
+
"""
|
251 |
+
Convert annotation which can be polygons, uncompressed RLE to RLE.
|
252 |
+
:return: binary mask (numpy 2D array)
|
253 |
+
"""
|
254 |
+
t = self.vids[ann['video_id']]
|
255 |
+
h, w = t['height'], t['width']
|
256 |
+
segm = ann['segmentations'][frameId]
|
257 |
+
if type(segm) == list:
|
258 |
+
# polygon -- a single object might consist of multiple parts
|
259 |
+
# we merge all parts into one mask rle code
|
260 |
+
rles = maskUtils.frPyObjects(segm, h, w)
|
261 |
+
rle = maskUtils.merge(rles)
|
262 |
+
elif type(segm['counts']) == list:
|
263 |
+
# uncompressed RLE
|
264 |
+
rle = maskUtils.frPyObjects(segm, h, w)
|
265 |
+
else:
|
266 |
+
# rle
|
267 |
+
rle = segm
|
268 |
+
return rle
|
269 |
+
|
270 |
+
def annToMask(self, ann, frameId):
|
271 |
+
"""
|
272 |
+
Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
|
273 |
+
:return: binary mask (numpy 2D array)
|
274 |
+
"""
|
275 |
+
rle = self.annToRLE(ann, frameId)
|
276 |
+
m = maskUtils.decode(rle)
|
277 |
+
return m
|
avism/data/datasets/avis_api/avoseval.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import datetime
|
5 |
+
import time
|
6 |
+
from collections import defaultdict
|
7 |
+
from pycocotools import mask as maskUtils
|
8 |
+
import copy
|
9 |
+
|
10 |
+
class AVOSeval:
|
11 |
+
# Interface for evaluating video instance segmentation on the AVIS dataset.
|
12 |
+
#
|
13 |
+
# The usage for AVOSeval is as follows:
|
14 |
+
# cocoGt=..., cocoDt=... # load dataset and results
|
15 |
+
# E = AVOSeval(cocoGt,cocoDt); # initialize AVOSeval object
|
16 |
+
# E.params.recThrs = ...; # set parameters as desired
|
17 |
+
# E.evaluate(); # run per image evaluation
|
18 |
+
# E.accumulate(); # accumulate per image results
|
19 |
+
# E.summarize(); # display summary metrics of results
|
20 |
+
#
|
21 |
+
# The evaluation parameters are as follows (defaults in brackets):
|
22 |
+
# imgIds - [all] N img ids to use for evaluation
|
23 |
+
# catIds - [all] K cat ids to use for evaluation
|
24 |
+
# iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation
|
25 |
+
# recThrs - [0:.01:1] R=101 recall thresholds for evaluation
|
26 |
+
# areaRng - [...] A=4 object area ranges for evaluation
|
27 |
+
# maxDets - [1 10 100] M=3 thresholds on max detections per image
|
28 |
+
# iouType - ['segm'] set iouType to 'segm', 'bbox' or 'keypoints'
|
29 |
+
# iouType replaced the now DEPRECATED useSegm parameter.
|
30 |
+
# useCats - [1] if true use category labels for evaluation
|
31 |
+
# Note: if useCats=0 category labels are ignored as in proposal scoring.
|
32 |
+
# Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified.
|
33 |
+
#
|
34 |
+
# evaluate(): evaluates detections on every image and every category and
|
35 |
+
# concats the results into the "evalImgs" with fields:
|
36 |
+
# dtIds - [1xD] id for each of the D detections (dt)
|
37 |
+
# gtIds - [1xG] id for each of the G ground truths (gt)
|
38 |
+
# dtMatches - [TxD] matching gt id at each IoU or 0
|
39 |
+
# gtMatches - [TxG] matching dt id at each IoU or 0
|
40 |
+
# dtScores - [1xD] confidence of each dt
|
41 |
+
# gtIgnore - [1xG] ignore flag for each gt
|
42 |
+
# dtIgnore - [TxD] ignore flag for each dt at each IoU
|
43 |
+
#
|
44 |
+
# accumulate(): accumulates the per-image, per-category evaluation
|
45 |
+
# results in "evalImgs" into the dictionary "eval" with fields:
|
46 |
+
# params - parameters used for evaluation
|
47 |
+
# date - date evaluation was performed
|
48 |
+
# counts - [T,R,K,A,M] parameter dimensions (see above)
|
49 |
+
# precision - [TxRxKxAxM] precision for every evaluation setting
|
50 |
+
# recall - [TxKxAxM] max recall for every evaluation setting
|
51 |
+
# Note: precision and recall==-1 for settings with no gt objects.
|
52 |
+
#
|
53 |
+
# See also coco, mask, pycocoDemo, pycocoEvalDemo
|
54 |
+
def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'):
|
55 |
+
'''
|
56 |
+
Initialize CocoEval using coco APIs for gt and dt
|
57 |
+
:param cocoGt: coco object with ground truth annotations
|
58 |
+
:param cocoDt: coco object with detection results
|
59 |
+
:return: None
|
60 |
+
'''
|
61 |
+
if not iouType:
|
62 |
+
print('iouType not specified. use default iouType segm')
|
63 |
+
self.cocoGt = cocoGt # ground truth COCO API
|
64 |
+
self.cocoDt = cocoDt # detections COCO API
|
65 |
+
self.params = {} # evaluation parameters
|
66 |
+
self.evalVids = defaultdict(list) # per-image per-category evaluation results [KxAxI] elements
|
67 |
+
self.eval = {} # accumulated evaluation results
|
68 |
+
self._gts = defaultdict(list) # gt for evaluation
|
69 |
+
self._dts = defaultdict(list) # dt for evaluation
|
70 |
+
self.params = Params(iouType=iouType) # parameters
|
71 |
+
self._paramsEval = {} # parameters for evaluation
|
72 |
+
self.stats = [] # result summarization
|
73 |
+
self.ious = {} # ious between all gts and dts
|
74 |
+
if not cocoGt is None:
|
75 |
+
self.params.vidIds = sorted(cocoGt.getVidIds())
|
76 |
+
self.params.catIds = sorted(cocoGt.getCatIds())
|
77 |
+
|
78 |
+
|
79 |
+
def _prepare(self):
|
80 |
+
'''
|
81 |
+
Prepare ._gts and ._dts for evaluation based on params
|
82 |
+
:return: None
|
83 |
+
'''
|
84 |
+
def _toMask(anns, coco):
|
85 |
+
# modify ann['segmentation'] by reference
|
86 |
+
for ann in anns:
|
87 |
+
for i, a in enumerate(ann['segmentations']):
|
88 |
+
if a:
|
89 |
+
rle = coco.annToRLE(ann, i)
|
90 |
+
ann['segmentations'][i] = rle
|
91 |
+
l = [a for a in ann['areas'] if a]
|
92 |
+
if len(l)==0:
|
93 |
+
ann['avg_area'] = 0
|
94 |
+
else:
|
95 |
+
ann['avg_area'] = np.array(l).mean()
|
96 |
+
p = self.params
|
97 |
+
if p.useCats:
|
98 |
+
gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(vidIds=p.vidIds, catIds=p.catIds))
|
99 |
+
dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(vidIds=p.vidIds, catIds=p.catIds))
|
100 |
+
else:
|
101 |
+
gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(vidIds=p.vidIds))
|
102 |
+
dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(vidIds=p.vidIds))
|
103 |
+
|
104 |
+
# convert ground truth to mask if iouType == 'segm'
|
105 |
+
if p.iouType == 'segm':
|
106 |
+
_toMask(gts, self.cocoGt)
|
107 |
+
_toMask(dts, self.cocoDt)
|
108 |
+
# set ignore flag
|
109 |
+
for gt in gts:
|
110 |
+
gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0
|
111 |
+
gt['ignore'] = 'iscrowd' in gt and gt['iscrowd']
|
112 |
+
if p.iouType == 'keypoints':
|
113 |
+
gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore']
|
114 |
+
|
115 |
+
self._gts = defaultdict(list) # gt for evaluation
|
116 |
+
self._dts = defaultdict(list) # dt for evaluation
|
117 |
+
for gt in gts:
|
118 |
+
self._gts[gt['video_id'], gt['category_id']].append(gt)
|
119 |
+
for dt in dts:
|
120 |
+
self._dts[dt['video_id'], dt['category_id']].append(dt)
|
121 |
+
self.evalVids = defaultdict(list) # per-image per-category evaluation results
|
122 |
+
self.eval = {} # accumulated evaluation results
|
123 |
+
|
124 |
+
def evaluate(self):
|
125 |
+
'''
|
126 |
+
Run per image evaluation on given images and store results (a list of dict) in self.evalVids
|
127 |
+
:return: None
|
128 |
+
'''
|
129 |
+
tic = time.time()
|
130 |
+
print('Running per image evaluation...')
|
131 |
+
p = self.params
|
132 |
+
# add backward compatibility if useSegm is specified in params
|
133 |
+
if not p.useSegm is None:
|
134 |
+
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
|
135 |
+
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
|
136 |
+
print('Evaluate annotation type *{}*'.format(p.iouType))
|
137 |
+
p.vidIds = list(np.unique(p.vidIds))
|
138 |
+
if p.useCats:
|
139 |
+
p.catIds = list(np.unique(p.catIds))
|
140 |
+
p.maxDets = sorted(p.maxDets)
|
141 |
+
self.params=p
|
142 |
+
|
143 |
+
self._prepare()
|
144 |
+
# loop through images, area range, max detection number
|
145 |
+
catIds = p.catIds if p.useCats else [-1]
|
146 |
+
|
147 |
+
if p.iouType == 'segm' or p.iouType == 'bbox':
|
148 |
+
computeIoU = self.computeIoU
|
149 |
+
elif p.iouType == 'keypoints':
|
150 |
+
computeIoU = self.computeOks
|
151 |
+
self.ious = {(vidId, catId): computeIoU(vidId, catId) \
|
152 |
+
for vidId in p.vidIds
|
153 |
+
for catId in catIds}
|
154 |
+
|
155 |
+
evaluateVid = self.evaluateVid
|
156 |
+
maxDet = p.maxDets[-1]
|
157 |
+
|
158 |
+
|
159 |
+
self.evalImgs = [evaluateVid(vidId, catId, areaRng, maxDet)
|
160 |
+
for catId in catIds
|
161 |
+
for areaRng in p.areaRng
|
162 |
+
for vidId in p.vidIds
|
163 |
+
]
|
164 |
+
self._paramsEval = copy.deepcopy(self.params)
|
165 |
+
toc = time.time()
|
166 |
+
print('DONE (t={:0.2f}s).'.format(toc-tic))
|
167 |
+
|
168 |
+
def computeIoU(self, vidId, catId):
|
169 |
+
p = self.params
|
170 |
+
if p.useCats:
|
171 |
+
gt = self._gts[vidId,catId]
|
172 |
+
dt = self._dts[vidId,catId]
|
173 |
+
else:
|
174 |
+
gt = [_ for cId in p.catIds for _ in self._gts[vidId,cId]]
|
175 |
+
dt = [_ for cId in p.catIds for _ in self._dts[vidId,cId]]
|
176 |
+
if len(gt) == 0 and len(dt) ==0:
|
177 |
+
return []
|
178 |
+
inds = np.argsort([-d['score'] for d in dt], kind='mergesort')
|
179 |
+
dt = [dt[i] for i in inds]
|
180 |
+
if len(dt) > p.maxDets[-1]:
|
181 |
+
dt=dt[0:p.maxDets[-1]]
|
182 |
+
|
183 |
+
if p.iouType == 'segm':
|
184 |
+
g = [g['segmentations'] for g in gt]
|
185 |
+
d = [d['segmentations'] for d in dt]
|
186 |
+
elif p.iouType == 'bbox':
|
187 |
+
g = [g['bboxes'] for g in gt]
|
188 |
+
d = [d['bboxes'] for d in dt]
|
189 |
+
else:
|
190 |
+
raise Exception('unknown iouType for iou computation')
|
191 |
+
|
192 |
+
# compute iou between each dt and gt region
|
193 |
+
iscrowd = [int(o['iscrowd']) for o in gt]
|
194 |
+
#ious = maskUtils.iou(d,g,iscrowd)
|
195 |
+
def iou_seq(d_seq, g_seq):
|
196 |
+
i = .0
|
197 |
+
u = .0
|
198 |
+
for d, g in zip(d_seq, g_seq):
|
199 |
+
if d and g:
|
200 |
+
i += maskUtils.area(maskUtils.merge([d, g], True))
|
201 |
+
u += maskUtils.area(maskUtils.merge([d, g], False))
|
202 |
+
elif not d and g:
|
203 |
+
u += maskUtils.area(g)
|
204 |
+
elif d and not g:
|
205 |
+
u += maskUtils.area(d)
|
206 |
+
if not u > .0:
|
207 |
+
print("Mask sizes in video {} and category {} may not match!".format(vidId, catId))
|
208 |
+
iou = i / u if u > .0 else .0
|
209 |
+
return iou
|
210 |
+
ious = np.zeros([len(d), len(g)])
|
211 |
+
for i, j in np.ndindex(ious.shape):
|
212 |
+
ious[i, j] = iou_seq(d[i], g[j])
|
213 |
+
#print(vidId, catId, ious.shape, ious)
|
214 |
+
return ious
|
215 |
+
|
216 |
+
def computeOks(self, imgId, catId):
|
217 |
+
p = self.params
|
218 |
+
# dimention here should be Nxm
|
219 |
+
gts = self._gts[imgId, catId]
|
220 |
+
dts = self._dts[imgId, catId]
|
221 |
+
inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
|
222 |
+
dts = [dts[i] for i in inds]
|
223 |
+
if len(dts) > p.maxDets[-1]:
|
224 |
+
dts = dts[0:p.maxDets[-1]]
|
225 |
+
# if len(gts) == 0 and len(dts) == 0:
|
226 |
+
if len(gts) == 0 or len(dts) == 0:
|
227 |
+
return []
|
228 |
+
ious = np.zeros((len(dts), len(gts)))
|
229 |
+
sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0
|
230 |
+
vars = (sigmas * 2)**2
|
231 |
+
k = len(sigmas)
|
232 |
+
# compute oks between each detection and ground truth object
|
233 |
+
for j, gt in enumerate(gts):
|
234 |
+
# create bounds for ignore regions(double the gt bbox)
|
235 |
+
g = np.array(gt['keypoints'])
|
236 |
+
xg = g[0::3]; yg = g[1::3]; vg = g[2::3]
|
237 |
+
k1 = np.count_nonzero(vg > 0)
|
238 |
+
bb = gt['bbox']
|
239 |
+
x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
|
240 |
+
y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2
|
241 |
+
for i, dt in enumerate(dts):
|
242 |
+
d = np.array(dt['keypoints'])
|
243 |
+
xd = d[0::3]; yd = d[1::3]
|
244 |
+
if k1>0:
|
245 |
+
# measure the per-keypoint distance if keypoints visible
|
246 |
+
dx = xd - xg
|
247 |
+
dy = yd - yg
|
248 |
+
else:
|
249 |
+
# measure minimum distance to keypoints in (x0,y0) & (x1,y1)
|
250 |
+
z = np.zeros((k))
|
251 |
+
dx = np.max((z, x0-xd),axis=0)+np.max((z, xd-x1),axis=0)
|
252 |
+
dy = np.max((z, y0-yd),axis=0)+np.max((z, yd-y1),axis=0)
|
253 |
+
e = (dx**2 + dy**2) / vars / (gt['avg_area']+np.spacing(1)) / 2
|
254 |
+
if k1 > 0:
|
255 |
+
e=e[vg > 0]
|
256 |
+
ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
|
257 |
+
return ious
|
258 |
+
|
259 |
+
def evaluateVid(self, vidId, catId, aRng, maxDet):
|
260 |
+
'''
|
261 |
+
perform evaluation for single category and image
|
262 |
+
:return: dict (single image results)
|
263 |
+
'''
|
264 |
+
p = self.params
|
265 |
+
if p.useCats:
|
266 |
+
gt = self._gts[vidId,catId]
|
267 |
+
dt = self._dts[vidId,catId]
|
268 |
+
else:
|
269 |
+
gt = [_ for cId in p.catIds for _ in self._gts[vidId,cId]]
|
270 |
+
dt = [_ for cId in p.catIds for _ in self._dts[vidId,cId]]
|
271 |
+
if len(gt) == 0 and len(dt) ==0:
|
272 |
+
return None
|
273 |
+
|
274 |
+
for g in gt:
|
275 |
+
if g['ignore'] or (g['avg_area']<aRng[0] or g['avg_area']>aRng[1]):
|
276 |
+
g['_ignore'] = 1
|
277 |
+
else:
|
278 |
+
g['_ignore'] = 0
|
279 |
+
|
280 |
+
# sort dt highest score first, sort gt ignore last
|
281 |
+
gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
|
282 |
+
gt = [gt[i] for i in gtind]
|
283 |
+
dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
|
284 |
+
dt = [dt[i] for i in dtind[0:maxDet]]
|
285 |
+
iscrowd = [int(o['iscrowd']) for o in gt]
|
286 |
+
# load computed ious
|
287 |
+
ious = self.ious[vidId, catId][:, gtind] if len(self.ious[vidId, catId]) > 0 else self.ious[vidId, catId]
|
288 |
+
|
289 |
+
T = len(p.iouThrs)
|
290 |
+
G = len(gt)
|
291 |
+
D = len(dt)
|
292 |
+
gtm = np.zeros((T,G))
|
293 |
+
dtm = np.zeros((T,D))
|
294 |
+
gtIg = np.array([g['_ignore'] for g in gt])
|
295 |
+
dtIg = np.zeros((T,D))
|
296 |
+
if not len(ious)==0:
|
297 |
+
for tind, t in enumerate(p.iouThrs):
|
298 |
+
for dind, d in enumerate(dt):
|
299 |
+
# information about best match so far (m=-1 -> unmatched)
|
300 |
+
iou = min([t,1-1e-10])
|
301 |
+
m = -1
|
302 |
+
for gind, g in enumerate(gt):
|
303 |
+
# if this gt already matched, and not a crowd, continue
|
304 |
+
if gtm[tind,gind]>0 and not iscrowd[gind]:
|
305 |
+
continue
|
306 |
+
# if dt matched to reg gt, and on ignore gt, stop
|
307 |
+
if m>-1 and gtIg[m]==0 and gtIg[gind]==1:
|
308 |
+
break
|
309 |
+
# continue to next gt unless better match made
|
310 |
+
if ious[dind,gind] < iou:
|
311 |
+
continue
|
312 |
+
# if match successful and best so far, store appropriately
|
313 |
+
iou=ious[dind,gind]
|
314 |
+
m=gind
|
315 |
+
# if match made store id of match for both dt and gt
|
316 |
+
if m ==-1:
|
317 |
+
continue
|
318 |
+
dtIg[tind,dind] = gtIg[m]
|
319 |
+
dtm[tind,dind] = gt[m]['id']
|
320 |
+
gtm[tind,m] = d['id']
|
321 |
+
# set unmatched detections outside of area range to ignore
|
322 |
+
a = np.array([d['avg_area']<aRng[0] or d['avg_area']>aRng[1] for d in dt]).reshape((1, len(dt)))
|
323 |
+
dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0)))
|
324 |
+
# store results for given image and category
|
325 |
+
return {
|
326 |
+
'video_id': vidId,
|
327 |
+
'category_id': catId,
|
328 |
+
'aRng': aRng,
|
329 |
+
'maxDet': maxDet,
|
330 |
+
'dtIds': [d['id'] for d in dt],
|
331 |
+
'gtIds': [g['id'] for g in gt],
|
332 |
+
'dtMatches': dtm,
|
333 |
+
'gtMatches': gtm,
|
334 |
+
'dtScores': [d['score'] for d in dt],
|
335 |
+
'gtIgnore': gtIg,
|
336 |
+
'dtIgnore': dtIg,
|
337 |
+
}
|
338 |
+
|
339 |
+
def accumulate(self, p = None):
|
340 |
+
'''
|
341 |
+
Accumulate per image evaluation results and store the result in self.eval
|
342 |
+
:param p: input params for evaluation
|
343 |
+
:return: None
|
344 |
+
'''
|
345 |
+
print('Accumulating evaluation results...')
|
346 |
+
tic = time.time()
|
347 |
+
if not self.evalImgs:
|
348 |
+
print('Please run evaluate() first')
|
349 |
+
# allows input customized parameters
|
350 |
+
if p is None:
|
351 |
+
p = self.params
|
352 |
+
p.catIds = p.catIds if p.useCats == 1 else [-1]
|
353 |
+
T = len(p.iouThrs)
|
354 |
+
R = len(p.recThrs)
|
355 |
+
K = len(p.catIds) if p.useCats else 1
|
356 |
+
A = len(p.areaRng)
|
357 |
+
M = len(p.maxDets)
|
358 |
+
precision = -np.ones((T,R,K,A,M)) # -1 for the precision of absent categories
|
359 |
+
recall = -np.ones((T,K,A,M))
|
360 |
+
scores = -np.ones((T,R,K,A,M))
|
361 |
+
|
362 |
+
# create dictionary for future indexing
|
363 |
+
_pe = self._paramsEval
|
364 |
+
catIds = _pe.catIds if _pe.useCats else [-1]
|
365 |
+
setK = set(catIds)
|
366 |
+
setA = set(map(tuple, _pe.areaRng))
|
367 |
+
setM = set(_pe.maxDets)
|
368 |
+
setI = set(_pe.vidIds)
|
369 |
+
# get inds to evaluate
|
370 |
+
k_list = [n for n, k in enumerate(p.catIds) if k in setK]
|
371 |
+
m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
|
372 |
+
a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA]
|
373 |
+
i_list = [n for n, i in enumerate(p.vidIds) if i in setI]
|
374 |
+
I0 = len(_pe.vidIds)
|
375 |
+
A0 = len(_pe.areaRng)
|
376 |
+
# retrieve E at each category, area range, and max number of detections
|
377 |
+
for k, k0 in enumerate(k_list):
|
378 |
+
Nk = k0*A0*I0
|
379 |
+
for a, a0 in enumerate(a_list):
|
380 |
+
Na = a0*I0
|
381 |
+
for m, maxDet in enumerate(m_list):
|
382 |
+
E = [self.evalImgs[Nk + Na + i] for i in i_list]
|
383 |
+
E = [e for e in E if not e is None]
|
384 |
+
if len(E) == 0:
|
385 |
+
continue
|
386 |
+
dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E])
|
387 |
+
|
388 |
+
# different sorting method generates slightly different results.
|
389 |
+
# mergesort is used to be consistent as Matlab implementation.
|
390 |
+
inds = np.argsort(-dtScores, kind='mergesort')
|
391 |
+
dtScoresSorted = dtScores[inds]
|
392 |
+
|
393 |
+
dtm = np.concatenate([e['dtMatches'][:,0:maxDet] for e in E], axis=1)[:,inds]
|
394 |
+
dtIg = np.concatenate([e['dtIgnore'][:,0:maxDet] for e in E], axis=1)[:,inds]
|
395 |
+
gtIg = np.concatenate([e['gtIgnore'] for e in E])
|
396 |
+
npig = np.count_nonzero(gtIg==0 )
|
397 |
+
if npig == 0:
|
398 |
+
continue
|
399 |
+
tps = np.logical_and( dtm, np.logical_not(dtIg) )
|
400 |
+
fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg) )
|
401 |
+
|
402 |
+
tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float64)
|
403 |
+
fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float64)
|
404 |
+
for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
|
405 |
+
tp = np.array(tp)
|
406 |
+
fp = np.array(fp)
|
407 |
+
nd = len(tp)
|
408 |
+
rc = tp / npig
|
409 |
+
pr = tp / (fp+tp+np.spacing(1))
|
410 |
+
q = np.zeros((R,))
|
411 |
+
ss = np.zeros((R,))
|
412 |
+
|
413 |
+
if nd:
|
414 |
+
recall[t,k,a,m] = rc[-1]
|
415 |
+
else:
|
416 |
+
recall[t,k,a,m] = 0
|
417 |
+
|
418 |
+
# numpy is slow without cython optimization for accessing elements
|
419 |
+
# use python array gets significant speed improvement
|
420 |
+
pr = pr.tolist(); q = q.tolist()
|
421 |
+
|
422 |
+
for i in range(nd-1, 0, -1):
|
423 |
+
if pr[i] > pr[i-1]:
|
424 |
+
pr[i-1] = pr[i]
|
425 |
+
|
426 |
+
inds = np.searchsorted(rc, p.recThrs, side='left')
|
427 |
+
try:
|
428 |
+
for ri, pi in enumerate(inds):
|
429 |
+
q[ri] = pr[pi]
|
430 |
+
ss[ri] = dtScoresSorted[pi]
|
431 |
+
except:
|
432 |
+
pass
|
433 |
+
precision[t,:,k,a,m] = np.array(q)
|
434 |
+
scores[t,:,k,a,m] = np.array(ss)
|
435 |
+
self.eval = {
|
436 |
+
'params': p,
|
437 |
+
'counts': [T, R, K, A, M],
|
438 |
+
'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
439 |
+
'precision': precision,
|
440 |
+
'recall': recall,
|
441 |
+
'scores': scores,
|
442 |
+
}
|
443 |
+
toc = time.time()
|
444 |
+
print('DONE (t={:0.2f}s).'.format( toc-tic))
|
445 |
+
|
446 |
+
def summarize(self):
|
447 |
+
'''
|
448 |
+
Compute and display summary metrics for evaluation results.
|
449 |
+
Note this functin can *only* be applied on the default parameter setting
|
450 |
+
'''
|
451 |
+
def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
|
452 |
+
p = self.params
|
453 |
+
iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
|
454 |
+
titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
|
455 |
+
typeStr = '(AP)' if ap==1 else '(AR)'
|
456 |
+
iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
|
457 |
+
if iouThr is None else '{:0.2f}'.format(iouThr)
|
458 |
+
|
459 |
+
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
|
460 |
+
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
|
461 |
+
if ap == 1:
|
462 |
+
# dimension of precision: [TxRxKxAxM]
|
463 |
+
s = self.eval['precision']
|
464 |
+
# IoU
|
465 |
+
if iouThr is not None:
|
466 |
+
t = np.where(iouThr == p.iouThrs)[0]
|
467 |
+
s = s[t]
|
468 |
+
s = s[:,:,:,aind,mind]
|
469 |
+
else:
|
470 |
+
# dimension of recall: [TxKxAxM]
|
471 |
+
s = self.eval['recall']
|
472 |
+
if iouThr is not None:
|
473 |
+
t = np.where(iouThr == p.iouThrs)[0]
|
474 |
+
s = s[t]
|
475 |
+
s = s[:,:,aind,mind]
|
476 |
+
if len(s[s>-1])==0:
|
477 |
+
mean_s = -1
|
478 |
+
else:
|
479 |
+
mean_s = np.mean(s[s>-1])
|
480 |
+
print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
|
481 |
+
return mean_s
|
482 |
+
def _summarizeDets():
|
483 |
+
stats = np.zeros((12,))
|
484 |
+
stats[0] = _summarize(1)
|
485 |
+
stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
|
486 |
+
stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
|
487 |
+
stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
|
488 |
+
stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
|
489 |
+
stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
|
490 |
+
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
|
491 |
+
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
|
492 |
+
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
|
493 |
+
stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
|
494 |
+
stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
|
495 |
+
stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
|
496 |
+
return stats
|
497 |
+
def _summarizeKps():
|
498 |
+
stats = np.zeros((10,))
|
499 |
+
stats[0] = _summarize(1, maxDets=20)
|
500 |
+
stats[1] = _summarize(1, maxDets=20, iouThr=.5)
|
501 |
+
stats[2] = _summarize(1, maxDets=20, iouThr=.75)
|
502 |
+
stats[3] = _summarize(1, maxDets=20, areaRng='medium')
|
503 |
+
stats[4] = _summarize(1, maxDets=20, areaRng='large')
|
504 |
+
stats[5] = _summarize(0, maxDets=20)
|
505 |
+
stats[6] = _summarize(0, maxDets=20, iouThr=.5)
|
506 |
+
stats[7] = _summarize(0, maxDets=20, iouThr=.75)
|
507 |
+
stats[8] = _summarize(0, maxDets=20, areaRng='medium')
|
508 |
+
stats[9] = _summarize(0, maxDets=20, areaRng='large')
|
509 |
+
return stats
|
510 |
+
if not self.eval:
|
511 |
+
raise Exception('Please run accumulate() first')
|
512 |
+
iouType = self.params.iouType
|
513 |
+
if iouType == 'segm' or iouType == 'bbox':
|
514 |
+
summarize = _summarizeDets
|
515 |
+
elif iouType == 'keypoints':
|
516 |
+
summarize = _summarizeKps
|
517 |
+
self.stats = summarize()
|
518 |
+
|
519 |
+
def __str__(self):
|
520 |
+
self.summarize()
|
521 |
+
|
522 |
+
class Params:
|
523 |
+
'''
|
524 |
+
Params for coco evaluation api
|
525 |
+
'''
|
526 |
+
def setDetParams(self):
|
527 |
+
self.vidIds = []
|
528 |
+
self.catIds = []
|
529 |
+
# np.arange causes trouble. the data point on arange is slightly larger than the true value
|
530 |
+
#self.iouThrs = np.linspace(.5, 0.95, np.round((0.95 - .5) / .05) + 1, endpoint=True)
|
531 |
+
#self.recThrs = np.linspace(.0, 1.00, np.round((1.00 - .0) / .01) + 1, endpoint=True)
|
532 |
+
self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
|
533 |
+
self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
|
534 |
+
self.maxDets = [1, 10, 100]
|
535 |
+
self.areaRng = [[0 ** 2, 1e5 ** 2], [0 ** 2, 128 ** 2], [ 128 ** 2, 256 ** 2], [256 ** 2, 1e5 ** 2]]
|
536 |
+
self.areaRngLbl = ['all', 'small', 'medium', 'large']
|
537 |
+
self.useCats = 1
|
538 |
+
|
539 |
+
def setKpParams(self):
|
540 |
+
self.vidIds = []
|
541 |
+
self.catIds = []
|
542 |
+
# np.arange causes trouble. the data point on arange is slightly larger than the true value
|
543 |
+
self.iouThrs = np.linspace(.5, 0.95, np.round((0.95 - .5) / .05) + 1, endpoint=True)
|
544 |
+
self.recThrs = np.linspace(.0, 1.00, np.round((1.00 - .0) / .01) + 1, endpoint=True)
|
545 |
+
self.maxDets = [20]
|
546 |
+
self.areaRng = [[0 ** 2, 1e5 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
|
547 |
+
self.areaRngLbl = ['all', 'medium', 'large']
|
548 |
+
self.useCats = 1
|
549 |
+
|
550 |
+
def __init__(self, iouType='segm'):
|
551 |
+
if iouType == 'segm' or iouType == 'bbox':
|
552 |
+
self.setDetParams()
|
553 |
+
elif iouType == 'keypoints':
|
554 |
+
self.setKpParams()
|
555 |
+
else:
|
556 |
+
raise Exception('iouType not supported')
|
557 |
+
self.iouType = iouType
|
558 |
+
# useSegm is deprecated
|
559 |
+
self.useSegm = None
|
avism/data/datasets/builtin.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from .avis import (
|
4 |
+
register_avis_instances,
|
5 |
+
_get_avis_instances_meta,
|
6 |
+
)
|
7 |
+
|
8 |
+
|
9 |
+
# ==== Predefined splits for AVIS ===========
|
10 |
+
_PREDEFINED_SPLITS_AVIS = {
|
11 |
+
"avis_train": ("train/JPEGImages", "train.json"),
|
12 |
+
"avis_val": ("val/JPEGImages", "val.json"),
|
13 |
+
"avis_test": ("test/JPEGImages", "test.json"),
|
14 |
+
}
|
15 |
+
|
16 |
+
def register_all_avis(root):
|
17 |
+
for key, (image_root, json_file) in _PREDEFINED_SPLITS_AVIS.items():
|
18 |
+
# Assume pre-defined datasets live in `./datasets`.
|
19 |
+
register_avis_instances(
|
20 |
+
key,
|
21 |
+
_get_avis_instances_meta(),
|
22 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
23 |
+
os.path.join(root, image_root),
|
24 |
+
)
|
25 |
+
|
26 |
+
if __name__.endswith(".builtin"):
|
27 |
+
# Assume pre-defined datasets live in `./datasets`.
|
28 |
+
_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
29 |
+
register_all_avis(_root)
|
avism/data/datasets/extract_audio_feat/audio_feature_extractor.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
3 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # set gpu number
|
4 |
+
import numpy as np
|
5 |
+
import tensorflow as tf
|
6 |
+
|
7 |
+
import vggish_input
|
8 |
+
import vggish_params
|
9 |
+
import vggish_slim
|
10 |
+
import contextlib
|
11 |
+
import wave
|
12 |
+
|
13 |
+
|
14 |
+
# get audio length
|
15 |
+
def get_audio_len(audio_file):
|
16 |
+
with contextlib.closing(wave.open(audio_file, 'r')) as f:
|
17 |
+
frames = f.getnframes()
|
18 |
+
rate = f.getframerate()
|
19 |
+
wav_length = int(frames / float(rate))
|
20 |
+
return wav_length
|
21 |
+
|
22 |
+
# Paths to downloaded VGGish files.
|
23 |
+
checkpoint_path = './vggish_model.ckpt'
|
24 |
+
pca_params_path = './vggish_pca_params.npz'
|
25 |
+
freq = 1000
|
26 |
+
sr = 44100
|
27 |
+
|
28 |
+
|
29 |
+
audio_root = "./datasets/"
|
30 |
+
for subset in ["train", "val", "test"]:
|
31 |
+
print("{} ----------> ".format(subset))
|
32 |
+
|
33 |
+
audio_dir = os.path.join(audio_root, subset, "WAVAudios")
|
34 |
+
save_dir = os.path.join(audio_root, subset, "FEATAudios")
|
35 |
+
if not os.path.exists(save_dir):
|
36 |
+
os.makedirs(save_dir)
|
37 |
+
|
38 |
+
lis = sorted(os.listdir(audio_dir))
|
39 |
+
len_data = len(lis)
|
40 |
+
print(len_data)
|
41 |
+
|
42 |
+
i = 0
|
43 |
+
for n in range(len_data):
|
44 |
+
i += 1
|
45 |
+
# save file
|
46 |
+
outfile = os.path.join(save_dir, lis[n][:-4] + '.npy')
|
47 |
+
if os.path.exists(outfile):
|
48 |
+
print("\nProcessing: ", i, " / ", len_data, " ----> ", lis[n][:-4] + '.npy', " is already exist! ")
|
49 |
+
continue
|
50 |
+
|
51 |
+
'''feature learning by VGG-net trained by audioset'''
|
52 |
+
audio_index = os.path.join(audio_dir, lis[n]) # path of your audio files
|
53 |
+
num_secs = len(os.listdir(os.path.join(audio_root, subset, "JPEGImages", lis[n][:-4])))
|
54 |
+
# num_secs_real = get_audio_len(audio_index)
|
55 |
+
# print("\nProcessing: ", i, " / ", len_data, " --------> video: ", lis[n], " ---> sec: ", num_secs_real)
|
56 |
+
|
57 |
+
input_batch = vggish_input.wavfile_to_examples(audio_index, num_secs)
|
58 |
+
np.testing.assert_equal(
|
59 |
+
input_batch.shape,
|
60 |
+
[num_secs, vggish_params.NUM_FRAMES, vggish_params.NUM_BANDS])
|
61 |
+
|
62 |
+
# Define VGGish, load the checkpoint, and run the batch through the model to produce embeddings.
|
63 |
+
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
|
64 |
+
vggish_slim.define_vggish_slim()
|
65 |
+
vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint_path)
|
66 |
+
|
67 |
+
features_tensor = sess.graph.get_tensor_by_name(vggish_params.INPUT_TENSOR_NAME)
|
68 |
+
embedding_tensor = sess.graph.get_tensor_by_name(vggish_params.OUTPUT_TENSOR_NAME)
|
69 |
+
[embedding_batch] = sess.run([embedding_tensor], feed_dict={features_tensor: input_batch})
|
70 |
+
np.save(outfile, embedding_batch)
|
71 |
+
print(" save info: ", lis[n][:-4] + '.npy', " ---> ", embedding_batch.shape)
|
72 |
+
|
73 |
+
i += 1
|
74 |
+
|
75 |
+
print("\n---------------------------------- end ----------------------------------\n")
|
76 |
+
|
77 |
+
|
avism/data/datasets/extract_audio_feat/mel_features.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Defines routines to compute mel spectrogram features from audio waveform."""
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
|
21 |
+
def frame(data, window_length, hop_length):
|
22 |
+
"""Convert array into a sequence of successive possibly overlapping frames.
|
23 |
+
|
24 |
+
An n-dimensional array of shape (num_samples, ...) is converted into an
|
25 |
+
(n+1)-D array of shape (num_frames, window_length, ...), where each frame
|
26 |
+
starts hop_length points after the preceding one.
|
27 |
+
|
28 |
+
This is accomplished using stride_tricks, so the original data is not
|
29 |
+
copied. However, there is no zero-padding, so any incomplete frames at the
|
30 |
+
end are not included.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
data: np.array of dimension N >= 1.
|
34 |
+
window_length: Number of samples in each frame.
|
35 |
+
hop_length: Advance (in samples) between each window.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
(N+1)-D np.array with as many rows as there are complete frames that can be
|
39 |
+
extracted.
|
40 |
+
"""
|
41 |
+
|
42 |
+
# print("data: ", data.shape)
|
43 |
+
num_samples = data.shape[0]
|
44 |
+
# print("num_samples: ", num_samples)
|
45 |
+
num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
|
46 |
+
# print("num_frames: ", num_frames)
|
47 |
+
shape = (num_frames, window_length) + data.shape[1:]
|
48 |
+
# print("shape: ", shape)
|
49 |
+
strides = (data.strides[0] * hop_length,) + data.strides
|
50 |
+
# print("strides: ", strides)
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
# 按shape进行分块
|
55 |
+
return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
|
56 |
+
|
57 |
+
|
58 |
+
def periodic_hann(window_length):
|
59 |
+
"""Calculate a "periodic" Hann window.
|
60 |
+
|
61 |
+
The classic Hann window is defined as a raised cosine that starts and
|
62 |
+
ends on zero, and where every value appears twice, except the middle
|
63 |
+
point for an odd-length window. Matlab calls this a "symmetric" window
|
64 |
+
and np.hanning() returns it. However, for Fourier analysis, this
|
65 |
+
actually represents just over one cycle of a period N-1 cosine, and
|
66 |
+
thus is not compactly expressed on a length-N Fourier basis. Instead,
|
67 |
+
it's better to use a raised cosine that ends just before the final
|
68 |
+
zero value - i.e. a complete cycle of a period-N cosine. Matlab
|
69 |
+
calls this a "periodic" window. This routine calculates it.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
window_length: The number of points in the returned window.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
A 1D np.array containing the periodic hann window.
|
76 |
+
"""
|
77 |
+
return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
|
78 |
+
np.arange(window_length)))
|
79 |
+
|
80 |
+
|
81 |
+
def stft_magnitude(signal, fft_length,
|
82 |
+
hop_length=None,
|
83 |
+
window_length=None):
|
84 |
+
"""Calculate the short-time Fourier transform magnitude.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
signal: 1D np.array of the input time-domain signal.
|
88 |
+
fft_length: Size of the FFT to apply.
|
89 |
+
hop_length: Advance (in samples) between each frame passed to FFT.
|
90 |
+
window_length: Length of each block of samples to pass to FFT.
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
2D np.array where each row contains the magnitudes of the fft_length/2+1
|
94 |
+
unique values of the FFT for the corresponding frame of input samples.
|
95 |
+
"""
|
96 |
+
frames = frame(signal, window_length, hop_length)
|
97 |
+
# Apply frame window to each frame. We use a periodic Hann (cosine of period
|
98 |
+
# window_length) instead of the symmetric Hann of np.hanning (period
|
99 |
+
# window_length-1).
|
100 |
+
window = periodic_hann(window_length)
|
101 |
+
windowed_frames = frames * window
|
102 |
+
return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
|
103 |
+
|
104 |
+
|
105 |
+
# Mel spectrum constants and functions.
|
106 |
+
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
|
107 |
+
_MEL_HIGH_FREQUENCY_Q = 1127.0
|
108 |
+
|
109 |
+
|
110 |
+
def hertz_to_mel(frequencies_hertz):
|
111 |
+
"""Convert frequencies to mel scale using HTK formula.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
frequencies_hertz: Scalar or np.array of frequencies in hertz.
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
Object of same size as frequencies_hertz containing corresponding values
|
118 |
+
on the mel scale.
|
119 |
+
"""
|
120 |
+
return _MEL_HIGH_FREQUENCY_Q * np.log(
|
121 |
+
1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
|
122 |
+
|
123 |
+
|
124 |
+
def spectrogram_to_mel_matrix(num_mel_bins=20,
|
125 |
+
num_spectrogram_bins=129,
|
126 |
+
audio_sample_rate=8000,
|
127 |
+
lower_edge_hertz=125.0,
|
128 |
+
upper_edge_hertz=3800.0):
|
129 |
+
"""Return a matrix that can post-multiply spectrogram rows to make mel.
|
130 |
+
|
131 |
+
Returns a np.array matrix A that can be used to post-multiply a matrix S of
|
132 |
+
spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
|
133 |
+
"mel spectrogram" M of frames x num_mel_bins. M = S A.
|
134 |
+
|
135 |
+
The classic HTK algorithm exploits the complementarity of adjacent mel bands
|
136 |
+
to multiply each FFT bin by only one mel weight, then add it, with positive
|
137 |
+
and negative signs, to the two adjacent mel bands to which that bin
|
138 |
+
contributes. Here, by expressing this operation as a matrix multiply, we go
|
139 |
+
from num_fft multiplies per frame (plus around 2*num_fft adds) to around
|
140 |
+
num_fft^2 multiplies and adds. However, because these are all presumably
|
141 |
+
accomplished in a single call to np.dot(), it's not clear which approach is
|
142 |
+
faster in Python. The matrix multiplication has the attraction of being more
|
143 |
+
general and flexible, and much easier to read.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
num_mel_bins: How many bands in the resulting mel spectrum. This is
|
147 |
+
the number of columns in the output matrix.
|
148 |
+
num_spectrogram_bins: How many bins there are in the source spectrogram
|
149 |
+
data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
|
150 |
+
only contains the nonredundant FFT bins.
|
151 |
+
audio_sample_rate: Samples per second of the audio at the input to the
|
152 |
+
spectrogram. We need this to figure out the actual frequencies for
|
153 |
+
each spectrogram bin, which dictates how they are mapped into mel.
|
154 |
+
lower_edge_hertz: Lower bound on the frequencies to be included in the mel
|
155 |
+
spectrum. This corresponds to the lower edge of the lowest triangular
|
156 |
+
band.
|
157 |
+
upper_edge_hertz: The desired top edge of the highest frequency band.
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
An np.array with shape (num_spectrogram_bins, num_mel_bins).
|
161 |
+
|
162 |
+
Raises:
|
163 |
+
ValueError: if frequency edges are incorrectly ordered or out of range.
|
164 |
+
"""
|
165 |
+
nyquist_hertz = audio_sample_rate / 2.
|
166 |
+
if lower_edge_hertz < 0.0:
|
167 |
+
raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
|
168 |
+
if lower_edge_hertz >= upper_edge_hertz:
|
169 |
+
raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
|
170 |
+
(lower_edge_hertz, upper_edge_hertz))
|
171 |
+
if upper_edge_hertz > nyquist_hertz:
|
172 |
+
raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
|
173 |
+
(upper_edge_hertz, nyquist_hertz))
|
174 |
+
spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
|
175 |
+
spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
|
176 |
+
# The i'th mel band (starting from i=1) has center frequency
|
177 |
+
# band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
|
178 |
+
# band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
|
179 |
+
# the band_edges_mel arrays.
|
180 |
+
band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
|
181 |
+
hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
|
182 |
+
# Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
|
183 |
+
# of spectrogram values.
|
184 |
+
mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
|
185 |
+
for i in range(num_mel_bins):
|
186 |
+
lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
|
187 |
+
# Calculate lower and upper slopes for every spectrogram bin.
|
188 |
+
# Line segments are linear in the *mel* domain, not hertz.
|
189 |
+
lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
|
190 |
+
(center_mel - lower_edge_mel))
|
191 |
+
upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
|
192 |
+
(upper_edge_mel - center_mel))
|
193 |
+
# .. then intersect them with each other and zero.
|
194 |
+
mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
|
195 |
+
upper_slope))
|
196 |
+
# HTK excludes the spectrogram DC bin; make sure it always gets a zero
|
197 |
+
# coefficient.
|
198 |
+
mel_weights_matrix[0, :] = 0.0
|
199 |
+
return mel_weights_matrix
|
200 |
+
|
201 |
+
|
202 |
+
def log_mel_spectrogram(data,
|
203 |
+
audio_sample_rate=8000,
|
204 |
+
log_offset=0.0,
|
205 |
+
window_length_secs=0.025,
|
206 |
+
hop_length_secs=0.010,
|
207 |
+
**kwargs):
|
208 |
+
"""Convert waveform to a log magnitude mel-frequency spectrogram.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
data: 1D np.array of waveform data.
|
212 |
+
audio_sample_rate: The sampling rate of data.
|
213 |
+
log_offset: Add this to values when taking log to avoid -Infs.
|
214 |
+
window_length_secs: Duration of each window to analyze.
|
215 |
+
hop_length_secs: Advance between successive analysis windows.
|
216 |
+
**kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
|
217 |
+
|
218 |
+
Returns:
|
219 |
+
2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
|
220 |
+
magnitudes for successive frames.
|
221 |
+
"""
|
222 |
+
window_length_samples = int(round(audio_sample_rate * window_length_secs))
|
223 |
+
hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
|
224 |
+
fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
|
225 |
+
spectrogram = stft_magnitude(
|
226 |
+
data,
|
227 |
+
fft_length=fft_length,
|
228 |
+
hop_length=hop_length_samples,
|
229 |
+
window_length=window_length_samples)
|
230 |
+
mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
|
231 |
+
num_spectrogram_bins=spectrogram.shape[1],
|
232 |
+
audio_sample_rate=audio_sample_rate, **kwargs))
|
233 |
+
return np.log(mel_spectrogram + log_offset)
|
avism/data/datasets/extract_audio_feat/vggish_input.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Compute input examples for VGGish from audio waveform."""
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import resampy
|
20 |
+
from scipy.io import wavfile
|
21 |
+
|
22 |
+
import mel_features
|
23 |
+
import vggish_params
|
24 |
+
|
25 |
+
|
26 |
+
def waveform_to_examples(data, sample_rate):
|
27 |
+
"""Converts audio waveform into an array of examples for VGGish.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
data: np.array of either one dimension (mono) or two dimensions
|
31 |
+
(multi-channel, with the outer dimension representing channels).
|
32 |
+
Each sample is generally expected to lie in the range [-1.0, +1.0],
|
33 |
+
although this is not required.
|
34 |
+
sample_rate: Sample rate of data.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
3-D np.array of shape [num_examples, num_frames, num_bands] which represents
|
38 |
+
a sequence of examples, each of which contains a patch of log mel
|
39 |
+
spectrogram, covering num_frames frames of audio and num_bands mel frequency
|
40 |
+
bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
|
41 |
+
"""
|
42 |
+
# Convert to mono.
|
43 |
+
if len(data.shape) > 1:
|
44 |
+
data = np.mean(data, axis=1)
|
45 |
+
# Resample to the rate assumed by VGGish.
|
46 |
+
if sample_rate != vggish_params.SAMPLE_RATE:
|
47 |
+
data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
|
48 |
+
|
49 |
+
# Compute log mel spectrogram features.
|
50 |
+
log_mel = mel_features.log_mel_spectrogram(
|
51 |
+
data,
|
52 |
+
audio_sample_rate=vggish_params.SAMPLE_RATE,
|
53 |
+
log_offset=vggish_params.LOG_OFFSET,
|
54 |
+
window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
|
55 |
+
hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
|
56 |
+
num_mel_bins=vggish_params.NUM_MEL_BINS,
|
57 |
+
lower_edge_hertz=vggish_params.MEL_MIN_HZ,
|
58 |
+
upper_edge_hertz=vggish_params.MEL_MAX_HZ)
|
59 |
+
|
60 |
+
# Frame features into examples.
|
61 |
+
features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
|
62 |
+
example_window_length = int(round(
|
63 |
+
vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
|
64 |
+
example_hop_length = int(round(
|
65 |
+
vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
|
66 |
+
log_mel_examples = mel_features.frame(
|
67 |
+
log_mel,
|
68 |
+
window_length=example_window_length,
|
69 |
+
hop_length=example_hop_length)
|
70 |
+
return log_mel_examples
|
71 |
+
|
72 |
+
|
73 |
+
def wavfile_to_examples(wav_file, num_secs):
|
74 |
+
"""Convenience wrapper around waveform_to_examples() for a common WAV format.
|
75 |
+
Args:
|
76 |
+
wav_file: String path to a file, or a file-like object. The file
|
77 |
+
is assumed to contain WAV audio data with signed 16-bit PCM samples.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
See waveform_to_examples.
|
81 |
+
"""
|
82 |
+
sr, snd = wavfile.read(wav_file)
|
83 |
+
L = sr * num_secs
|
84 |
+
wav_data = snd[:L, :]
|
85 |
+
wav_data = wav_data / 32768.0 # Convert to [-1.0, +1.0]
|
86 |
+
T = num_secs
|
87 |
+
log_mel = np.zeros([T, 96, 64])
|
88 |
+
|
89 |
+
for i in range(T):
|
90 |
+
s = i * sr
|
91 |
+
e = (i + 1) * sr
|
92 |
+
if len(wav_data.shape) > 1:
|
93 |
+
data = wav_data[s:e, :]
|
94 |
+
else:
|
95 |
+
data = wav_data[s:e]
|
96 |
+
|
97 |
+
wave_data_array = waveform_to_examples(data, sr)
|
98 |
+
if len(wave_data_array) != 0:
|
99 |
+
log_mel[i, :, :] = wave_data_array
|
100 |
+
else:
|
101 |
+
log_mel[i, :, :] = np.zeros((1, 96, 64), dtype=float)
|
102 |
+
|
103 |
+
return log_mel
|
avism/data/datasets/extract_audio_feat/vggish_params.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Global parameters for the VGGish model.
|
17 |
+
|
18 |
+
See vggish_slim.py for more information.
|
19 |
+
"""
|
20 |
+
|
21 |
+
# Architectural constants.
|
22 |
+
NUM_FRAMES = 96 # Frames in input mel-spectrogram patch.
|
23 |
+
NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
|
24 |
+
EMBEDDING_SIZE = 128 # Size of embedding layer.
|
25 |
+
|
26 |
+
# Hyperparameters used in feature and example generation.
|
27 |
+
SAMPLE_RATE = 16000
|
28 |
+
STFT_WINDOW_LENGTH_SECONDS = 0.025
|
29 |
+
STFT_HOP_LENGTH_SECONDS = 0.010
|
30 |
+
NUM_MEL_BINS = NUM_BANDS
|
31 |
+
MEL_MIN_HZ = 125
|
32 |
+
MEL_MAX_HZ = 7500
|
33 |
+
LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
|
34 |
+
EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
|
35 |
+
EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
|
36 |
+
|
37 |
+
# Parameters used for embedding postprocessing.
|
38 |
+
PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
|
39 |
+
PCA_MEANS_NAME = 'pca_means'
|
40 |
+
QUANTIZE_MIN_VAL = -2.0
|
41 |
+
QUANTIZE_MAX_VAL = +2.0
|
42 |
+
|
43 |
+
# Hyperparameters used in training.
|
44 |
+
INIT_STDDEV = 0.01 # Standard deviation used to initialize weights.
|
45 |
+
LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer.
|
46 |
+
ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer.
|
47 |
+
|
48 |
+
# Names of ops, tensors, and features.
|
49 |
+
INPUT_OP_NAME = 'vggish/input_features'
|
50 |
+
INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
|
51 |
+
OUTPUT_OP_NAME = 'vggish/embedding'
|
52 |
+
OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
|
53 |
+
AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
|
avism/data/datasets/extract_audio_feat/vggish_slim.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Defines the 'VGGish' model used to generate AudioSet embedding features.
|
17 |
+
|
18 |
+
The public AudioSet release (https://research.google.com/audioset/download.html)
|
19 |
+
includes 128-D features extracted from the embedding layer of a VGG-like model
|
20 |
+
that was trained on a large Google-internal YouTube dataset. Here we provide
|
21 |
+
a TF-Slim definition of the same model, without any dependences on libraries
|
22 |
+
internal to Google. We call it 'VGGish'.
|
23 |
+
|
24 |
+
Note that we only define the model up to the embedding layer, which is the
|
25 |
+
penultimate layer before the final classifier layer. We also provide various
|
26 |
+
hyperparameter values (in vggish_params.py) that were used to train this model
|
27 |
+
internally.
|
28 |
+
|
29 |
+
For comparison, here is TF-Slim's VGG definition:
|
30 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py
|
31 |
+
"""
|
32 |
+
|
33 |
+
import tensorflow._api.v2.compat.v1 as tf
|
34 |
+
tf.disable_v2_behavior()
|
35 |
+
import tf_slim as slim
|
36 |
+
|
37 |
+
import vggish_params as params
|
38 |
+
|
39 |
+
|
40 |
+
def define_vggish_slim(training=False):
|
41 |
+
"""Defines the VGGish TensorFlow model.
|
42 |
+
|
43 |
+
All ops are created in the current default graph, under the scope 'vggish/'.
|
44 |
+
|
45 |
+
The input is a placeholder named 'vggish/input_features' of type float32 and
|
46 |
+
shape [batch_size, num_frames, num_bands] where batch_size is variable and
|
47 |
+
num_frames and num_bands are constants, and [num_frames, num_bands] represents
|
48 |
+
a log-mel-scale spectrogram patch covering num_bands frequency bands and
|
49 |
+
num_frames time frames (where each frame step is usually 10ms). This is
|
50 |
+
produced by computing the stabilized log(mel-spectrogram + params.LOG_OFFSET).
|
51 |
+
The output is an op named 'vggish/embedding' which produces the activations of
|
52 |
+
a 128-D embedding layer, which is usually the penultimate layer when used as
|
53 |
+
part of a full model with a final classifier layer.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
training: If true, all parameters are marked trainable.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
The op 'vggish/embeddings'.
|
60 |
+
"""
|
61 |
+
# Defaults:
|
62 |
+
# - All weights are initialized to N(0, INIT_STDDEV).
|
63 |
+
# - All biases are initialized to 0.
|
64 |
+
# - All activations are ReLU.
|
65 |
+
# - All convolutions are 3x3 with stride 1 and SAME padding.
|
66 |
+
# - All max-pools are 2x2 with stride 2 and SAME padding.
|
67 |
+
with slim.arg_scope([slim.conv2d, slim.fully_connected],
|
68 |
+
weights_initializer=tf.truncated_normal_initializer(
|
69 |
+
stddev=params.INIT_STDDEV),
|
70 |
+
biases_initializer=tf.zeros_initializer(),
|
71 |
+
activation_fn=tf.nn.relu,
|
72 |
+
trainable=training), \
|
73 |
+
slim.arg_scope([slim.conv2d],
|
74 |
+
kernel_size=[3, 3], stride=1, padding='SAME'), \
|
75 |
+
slim.arg_scope([slim.max_pool2d],
|
76 |
+
kernel_size=[2, 2], stride=2, padding='SAME'), \
|
77 |
+
tf.compat.v1.variable_scope('vggish'):
|
78 |
+
# tf.variable_scope('vggish'):
|
79 |
+
# Input: a batch of 2-D log-mel-spectrogram patches.
|
80 |
+
# features = tf.placeholder(
|
81 |
+
features = tf.compat.v1.placeholder(
|
82 |
+
tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS),
|
83 |
+
name='input_features')
|
84 |
+
# Reshape to 4-D so that we can convolve a batch with conv2d().
|
85 |
+
net = tf.reshape(features, [-1, params.NUM_FRAMES, params.NUM_BANDS, 1])
|
86 |
+
|
87 |
+
# The VGG stack of alternating convolutions and max-pools.
|
88 |
+
net = slim.conv2d(net, 64, scope='conv1')
|
89 |
+
net = slim.max_pool2d(net, scope='pool1')
|
90 |
+
net = slim.conv2d(net, 128, scope='conv2')
|
91 |
+
net = slim.max_pool2d(net, scope='pool2')
|
92 |
+
net = slim.repeat(net, 2, slim.conv2d, 256, scope='conv3')
|
93 |
+
net = slim.max_pool2d(net, scope='pool3')
|
94 |
+
net = slim.repeat(net, 2, slim.conv2d, 512, scope='conv4')
|
95 |
+
net = slim.max_pool2d(net, scope='pool4')
|
96 |
+
|
97 |
+
# Flatten before entering fully-connected layers
|
98 |
+
net = slim.flatten(net)
|
99 |
+
net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1')
|
100 |
+
# The embedding layer.
|
101 |
+
net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2')
|
102 |
+
return tf.identity(net, name='embedding')
|
103 |
+
|
104 |
+
|
105 |
+
def load_vggish_slim_checkpoint(session, checkpoint_path):
|
106 |
+
"""Loads a pre-trained VGGish-compatible checkpoint.
|
107 |
+
|
108 |
+
This function can be used as an initialization function (referred to as
|
109 |
+
init_fn in TensorFlow documentation) which is called in a Session after
|
110 |
+
initializating all variables. When used as an init_fn, this will load
|
111 |
+
a pre-trained checkpoint that is compatible with the VGGish model
|
112 |
+
definition. Only variables defined by VGGish will be loaded.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
session: an active TensorFlow session.
|
116 |
+
checkpoint_path: path to a file containing a checkpoint that is
|
117 |
+
compatible with the VGGish model definition.
|
118 |
+
"""
|
119 |
+
# Get the list of names of all VGGish variables that exist in
|
120 |
+
# the checkpoint (i.e., all inference-mode VGGish variables).
|
121 |
+
with tf.Graph().as_default():
|
122 |
+
define_vggish_slim(training=False)
|
123 |
+
# vggish_var_names = [v.name for v in tf.global_variables()]
|
124 |
+
vggish_var_names = [v.name for v in tf.compat.v1.global_variables()]
|
125 |
+
|
126 |
+
# Get the list of all currently existing variables that match
|
127 |
+
# the list of variable names we just computed.
|
128 |
+
# vggish_vars = [v for v in tf.global_variables() if v.name in vggish_var_names]
|
129 |
+
vggish_vars = [v for v in tf.compat.v1.global_variables() if v.name in vggish_var_names]
|
130 |
+
|
131 |
+
# Use a Saver to restore just the variables selected above.
|
132 |
+
# saver = tf.train.Saver(vggish_vars, name='vggish_load_pretrained')
|
133 |
+
saver = tf.compat.v1.train.Saver(vggish_vars, name='vggish_load_pretrained')
|
134 |
+
saver.restore(session, checkpoint_path)
|
avism/modeling/__init__.py
ADDED
File without changes
|
avism/modeling/avism_criterion.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
from detectron2.utils.comm import get_world_size
|
6 |
+
from detectron2.projects.point_rend.point_features import (
|
7 |
+
get_uncertain_point_coords_with_randomness,
|
8 |
+
point_sample,
|
9 |
+
)
|
10 |
+
|
11 |
+
from ..utils.misc import is_dist_avail_and_initialized
|
12 |
+
|
13 |
+
|
14 |
+
def dice_loss(
|
15 |
+
inputs: torch.Tensor,
|
16 |
+
targets: torch.Tensor,
|
17 |
+
num_masks: float,
|
18 |
+
):
|
19 |
+
"""
|
20 |
+
Compute the DICE loss, similar to generalized IOU for masks
|
21 |
+
Args:
|
22 |
+
inputs: A float tensor of arbitrary shape.
|
23 |
+
The predictions for each example.
|
24 |
+
targets: A float tensor with the same shape as inputs. Stores the binary
|
25 |
+
classification label for each element in inputs
|
26 |
+
(0 for the negative class and 1 for the positive class).
|
27 |
+
"""
|
28 |
+
inputs = inputs.sigmoid()
|
29 |
+
inputs = inputs.flatten(1)
|
30 |
+
numerator = 2 * (inputs * targets).sum(-1)
|
31 |
+
denominator = inputs.sum(-1) + targets.sum(-1)
|
32 |
+
loss = 1 - (numerator + 1) / (denominator + 1)
|
33 |
+
return loss.sum() / num_masks
|
34 |
+
|
35 |
+
|
36 |
+
dice_loss_jit = torch.jit.script(
|
37 |
+
dice_loss
|
38 |
+
) # type: torch.jit.ScriptModule
|
39 |
+
|
40 |
+
|
41 |
+
def sigmoid_ce_loss(
|
42 |
+
inputs: torch.Tensor,
|
43 |
+
targets: torch.Tensor,
|
44 |
+
num_masks: float,
|
45 |
+
):
|
46 |
+
"""
|
47 |
+
Args:
|
48 |
+
inputs: A float tensor of arbitrary shape.
|
49 |
+
The predictions for each example.
|
50 |
+
targets: A float tensor with the same shape as inputs. Stores the binary
|
51 |
+
classification label for each element in inputs
|
52 |
+
(0 for the negative class and 1 for the positive class).
|
53 |
+
Returns:
|
54 |
+
Loss tensor
|
55 |
+
"""
|
56 |
+
loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
57 |
+
|
58 |
+
return loss.mean(1).sum() / num_masks
|
59 |
+
|
60 |
+
|
61 |
+
sigmoid_ce_loss_jit = torch.jit.script(
|
62 |
+
sigmoid_ce_loss
|
63 |
+
) # type: torch.jit.ScriptModule
|
64 |
+
|
65 |
+
|
66 |
+
def calculate_uncertainty(logits):
|
67 |
+
"""
|
68 |
+
We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
|
69 |
+
foreground class in `classes`.
|
70 |
+
Args:
|
71 |
+
logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
|
72 |
+
class-agnostic, where R is the total number of predicted masks in all images and C is
|
73 |
+
the number of foreground classes. The values are logits.
|
74 |
+
Returns:
|
75 |
+
scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
|
76 |
+
the most uncertain locations having the highest uncertainty score.
|
77 |
+
"""
|
78 |
+
assert logits.shape[1] == 1
|
79 |
+
gt_class_logits = logits.clone()
|
80 |
+
return -(torch.abs(gt_class_logits))
|
81 |
+
|
82 |
+
|
83 |
+
class AvismSetCriterion(nn.Module):
|
84 |
+
"""This class computes the loss for DETR.
|
85 |
+
The process happens in two steps:
|
86 |
+
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
|
87 |
+
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
|
91 |
+
num_points, oversample_ratio, importance_sample_ratio, sim_use_clip):
|
92 |
+
"""Create the criterion.
|
93 |
+
Parameters:
|
94 |
+
num_classes: number of object categories, omitting the special no-object category
|
95 |
+
matcher: module able to compute a matching between targets and proposals
|
96 |
+
weight_dict: dict containing as key the names of the losses and as values their relative weight.
|
97 |
+
eos_coef: relative classification weight applied to the no-object category
|
98 |
+
losses: list of all the losses to be applied. See get_loss for list of available losses.
|
99 |
+
"""
|
100 |
+
super().__init__()
|
101 |
+
self.num_classes = num_classes
|
102 |
+
self.matcher = matcher
|
103 |
+
self.weight_dict = weight_dict
|
104 |
+
self.eos_coef = eos_coef
|
105 |
+
self.losses = losses
|
106 |
+
empty_weight = torch.ones(self.num_classes + 1)
|
107 |
+
empty_weight[-1] = self.eos_coef
|
108 |
+
self.register_buffer("empty_weight", empty_weight)
|
109 |
+
|
110 |
+
# pointwise mask loss parameters
|
111 |
+
self.num_points = num_points
|
112 |
+
self.oversample_ratio = oversample_ratio
|
113 |
+
self.importance_sample_ratio = importance_sample_ratio
|
114 |
+
self.sim_use_clip = sim_use_clip
|
115 |
+
|
116 |
+
def loss_labels(self, outputs, targets, indices, num_masks):
|
117 |
+
"""Classification loss (NLL)
|
118 |
+
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
|
119 |
+
"""
|
120 |
+
assert "pred_logits" in outputs
|
121 |
+
src_logits = outputs['pred_logits']
|
122 |
+
L, B, cQ, _ = src_logits.shape
|
123 |
+
src_logits = src_logits.reshape(L*B, cQ, self.num_classes+1)
|
124 |
+
|
125 |
+
idx = self._get_src_permutation_idx(indices)
|
126 |
+
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets * L, indices)])
|
127 |
+
target_classes = torch.full(
|
128 |
+
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
|
129 |
+
)
|
130 |
+
target_classes[idx] = target_classes_o
|
131 |
+
|
132 |
+
loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
|
133 |
+
losses = {'loss_avism_ce': loss_ce}
|
134 |
+
|
135 |
+
return losses
|
136 |
+
|
137 |
+
def loss_masks(self, outputs, targets, indices, num_masks):
|
138 |
+
"""Compute the losses related to the masks: the focal loss and the dice loss.
|
139 |
+
targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
|
140 |
+
"""
|
141 |
+
assert "pred_masks" in outputs
|
142 |
+
|
143 |
+
idx = self._get_src_permutation_idx(indices)
|
144 |
+
src_masks = outputs["pred_masks"]
|
145 |
+
L, B, cQ, T, H, W = src_masks.shape
|
146 |
+
src_masks = src_masks.reshape(L*B, cQ, T, H, W)
|
147 |
+
|
148 |
+
src_masks = src_masks[idx] # Nt x T x Hp x Wp
|
149 |
+
target_masks = torch.cat([t['masks'][i] for t, (_, i) in zip(targets * L, indices)]).to(src_masks)
|
150 |
+
# Nt x T x Ht x Wt
|
151 |
+
src_masks = src_masks.flatten(0, 1)[:, None]
|
152 |
+
target_masks = target_masks.flatten(0, 1)[:, None]
|
153 |
+
|
154 |
+
with torch.no_grad():
|
155 |
+
# sample point_coords
|
156 |
+
point_coords = get_uncertain_point_coords_with_randomness(
|
157 |
+
src_masks,
|
158 |
+
lambda logits: calculate_uncertainty(logits),
|
159 |
+
self.num_points,
|
160 |
+
self.oversample_ratio,
|
161 |
+
self.importance_sample_ratio,
|
162 |
+
)
|
163 |
+
# get gt labels
|
164 |
+
point_labels = point_sample(
|
165 |
+
target_masks,
|
166 |
+
point_coords,
|
167 |
+
align_corners=False,
|
168 |
+
).squeeze(1)
|
169 |
+
|
170 |
+
point_logits = point_sample(
|
171 |
+
src_masks,
|
172 |
+
point_coords,
|
173 |
+
align_corners=False,
|
174 |
+
).squeeze(1)
|
175 |
+
|
176 |
+
# Nt*T, randN -> Nt, T*randN
|
177 |
+
point_logits = point_logits.view(len(idx[0]), T * self.num_points)
|
178 |
+
point_labels = point_labels.view(len(idx[0]), T * self.num_points)
|
179 |
+
|
180 |
+
losses = {
|
181 |
+
"loss_avism_mask": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
|
182 |
+
"loss_avism_dice": dice_loss_jit(point_logits, point_labels, num_masks),
|
183 |
+
}
|
184 |
+
|
185 |
+
del src_masks
|
186 |
+
del target_masks
|
187 |
+
return losses
|
188 |
+
|
189 |
+
def loss_fg_sim(
|
190 |
+
self, outputs, clip_targets, frame_targets,
|
191 |
+
clip_indices, frame_indices, num_masks, MULTIPLIER=1000
|
192 |
+
):
|
193 |
+
total_src_q, total_tgt_ids, total_batch_idx = [], [], []
|
194 |
+
|
195 |
+
# Frame
|
196 |
+
src_fq = outputs["pred_fq_embed"] # L, B, T, fQ, C
|
197 |
+
# L = number of frame_decoder layers
|
198 |
+
L, B, T, fQ, C = src_fq.shape
|
199 |
+
src_fq = src_fq.flatten(0, 2) # LBT, fQ, C
|
200 |
+
|
201 |
+
frame_indices = sum(frame_indices, [])
|
202 |
+
frame_src_idx = self._get_src_permutation_idx(frame_indices) # len = LBT
|
203 |
+
src_fq = src_fq[frame_src_idx] # Nf, C
|
204 |
+
target_frame_ids = torch.cat(
|
205 |
+
[t["ids"][J] for t, (_, J) in zip(frame_targets * L, frame_indices)]
|
206 |
+
)
|
207 |
+
frame_batch_idx = torch.div(frame_src_idx[0].to(device=src_fq.device), T, rounding_mode="floor")
|
208 |
+
is_frame_valid = target_frame_ids != -1
|
209 |
+
target_frame_ids += frame_batch_idx * MULTIPLIER
|
210 |
+
|
211 |
+
total_src_q.append(src_fq[is_frame_valid])
|
212 |
+
total_tgt_ids.append(target_frame_ids[is_frame_valid])
|
213 |
+
total_batch_idx.append(frame_batch_idx[is_frame_valid])
|
214 |
+
|
215 |
+
# Clip
|
216 |
+
if self.sim_use_clip:
|
217 |
+
src_cq = outputs["pred_cq_embed"] # L, B, cQ, C
|
218 |
+
src_cq = src_cq.flatten(0, 1) # LB , cQ, C
|
219 |
+
|
220 |
+
clip_src_idx = self._get_src_permutation_idx(clip_indices) # len = LB
|
221 |
+
src_cq = src_cq[clip_src_idx] # Nc, C
|
222 |
+
target_clip_ids = torch.cat( # clip_ids' shape = (N, num_frames) -> (N,)
|
223 |
+
[t["ids"][J] for t, (_, J) in zip(clip_targets * L, clip_indices)]
|
224 |
+
).amax(dim=1)
|
225 |
+
clip_batch_idx = clip_src_idx[0].to(device=src_fq.device)
|
226 |
+
is_clip_valid = target_clip_ids != -1
|
227 |
+
target_clip_ids += clip_batch_idx * MULTIPLIER
|
228 |
+
|
229 |
+
total_src_q.append(src_cq[is_clip_valid])
|
230 |
+
total_tgt_ids.append(target_clip_ids[is_clip_valid])
|
231 |
+
total_batch_idx.append(clip_batch_idx[is_clip_valid])
|
232 |
+
|
233 |
+
# Clip + Frame
|
234 |
+
total_src_q = torch.cat(total_src_q) # Nc+Nf, C
|
235 |
+
total_tgt_ids = torch.cat(total_tgt_ids) # Nc+Nf
|
236 |
+
total_batch_idx = torch.cat(total_batch_idx) # Nc+Nf
|
237 |
+
|
238 |
+
sim_pred_logits = torch.matmul(total_src_q, total_src_q.T) # Nc+Nf, Nc+Nf
|
239 |
+
sim_tgt = (total_tgt_ids[:, None] == total_tgt_ids[None]).float() # Nc+Nf, Nc+Nf
|
240 |
+
|
241 |
+
same_clip = (total_batch_idx[:, None] == total_batch_idx[None]).float()
|
242 |
+
loss = F.binary_cross_entropy_with_logits(sim_pred_logits, sim_tgt, reduction='none')
|
243 |
+
|
244 |
+
loss = loss * same_clip
|
245 |
+
loss_clip_sim = loss.sum() / (same_clip.sum() + 1e-6)
|
246 |
+
|
247 |
+
return {"loss_clip_sim": loss_clip_sim}
|
248 |
+
|
249 |
+
def _get_src_permutation_idx(self, indices):
|
250 |
+
# permute predictions following indices
|
251 |
+
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
|
252 |
+
src_idx = torch.cat([src for (src, _) in indices])
|
253 |
+
return batch_idx, src_idx
|
254 |
+
|
255 |
+
def _get_tgt_permutation_idx(self, indices):
|
256 |
+
# permute targets following indices
|
257 |
+
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
|
258 |
+
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
|
259 |
+
return batch_idx, tgt_idx
|
260 |
+
|
261 |
+
def get_loss(
|
262 |
+
self, loss, outputs, clip_targets, frame_targets, clip_indices, frame_indices, num_masks
|
263 |
+
):
|
264 |
+
loss_map = {
|
265 |
+
'avism_labels': self.loss_labels,
|
266 |
+
'avism_masks': self.loss_masks,
|
267 |
+
'fg_sim': self.loss_fg_sim,
|
268 |
+
}
|
269 |
+
assert loss in loss_map, f"do you really want to compute {loss} loss?"
|
270 |
+
if loss == 'fg_sim':
|
271 |
+
return loss_map[loss](
|
272 |
+
outputs, clip_targets, frame_targets, clip_indices, frame_indices, num_masks
|
273 |
+
)
|
274 |
+
return loss_map[loss](outputs, clip_targets, clip_indices, num_masks)
|
275 |
+
|
276 |
+
def forward(self, outputs, clip_targets, frame_targets, frame_indices=None):
|
277 |
+
"""This performs the loss computation.
|
278 |
+
Parameters:
|
279 |
+
outputs: dict of tensors, see the output specification of the model for the format
|
280 |
+
targets: list of dicts, such that len(targets) == batch_size.
|
281 |
+
The expected keys in each dict depends on the losses applied, see each loss' doc
|
282 |
+
"""
|
283 |
+
outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
|
284 |
+
|
285 |
+
# Retrieve the matching between the outputs of the last layer and the targets
|
286 |
+
clip_indices = self.matcher(outputs_without_aux, clip_targets)
|
287 |
+
|
288 |
+
# Compute the average number of target boxes accross all nodes, for normalization purposes
|
289 |
+
num_masks = sum(len(t["labels"]) for t in clip_targets) * len(outputs_without_aux["pred_masks"])
|
290 |
+
num_masks = torch.as_tensor(
|
291 |
+
[num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
|
292 |
+
)
|
293 |
+
if is_dist_avail_and_initialized():
|
294 |
+
torch.distributed.all_reduce(num_masks)
|
295 |
+
num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
|
296 |
+
|
297 |
+
# Compute all the requested losses
|
298 |
+
losses = {}
|
299 |
+
for loss in self.losses:
|
300 |
+
losses.update(
|
301 |
+
self.get_loss(
|
302 |
+
loss, outputs, clip_targets, frame_targets, clip_indices, frame_indices, num_masks
|
303 |
+
)
|
304 |
+
)
|
305 |
+
|
306 |
+
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
307 |
+
if "aux_outputs" in outputs:
|
308 |
+
for i, aux_outputs in enumerate(outputs["aux_outputs"]):
|
309 |
+
clip_indices = self.matcher(aux_outputs, clip_targets)
|
310 |
+
for loss in self.losses:
|
311 |
+
if loss == "fg_sim":
|
312 |
+
continue
|
313 |
+
l_dict = self.get_loss(
|
314 |
+
loss, aux_outputs, clip_targets, frame_targets, clip_indices, frame_indices, num_masks
|
315 |
+
)
|
316 |
+
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
|
317 |
+
losses.update(l_dict)
|
318 |
+
|
319 |
+
return losses
|
320 |
+
|
321 |
+
def __repr__(self):
|
322 |
+
head = "Criterion " + self.__class__.__name__
|
323 |
+
body = [
|
324 |
+
"matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
|
325 |
+
"losses: {}".format(self.losses),
|
326 |
+
"weight_dict: {}".format(self.weight_dict),
|
327 |
+
"num_classes: {}".format(self.num_classes),
|
328 |
+
"eos_coef: {}".format(self.eos_coef),
|
329 |
+
"num_points: {}".format(self.num_points),
|
330 |
+
"oversample_ratio: {}".format(self.oversample_ratio),
|
331 |
+
"importance_sample_ratio: {}".format(self.importance_sample_ratio),
|
332 |
+
]
|
333 |
+
_repr_indent = 4
|
334 |
+
lines = [head] + [" " * _repr_indent + line for line in body]
|
335 |
+
return "\n".join(lines)
|
avism/modeling/avism_matcher.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Modules to compute the matching cost and solve the corresponding LSAP.
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from scipy.optimize import linear_sum_assignment
|
7 |
+
from torch import nn
|
8 |
+
from torch.cuda.amp import autocast
|
9 |
+
|
10 |
+
from detectron2.projects.point_rend.point_features import point_sample
|
11 |
+
|
12 |
+
|
13 |
+
def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
|
14 |
+
"""
|
15 |
+
Compute the DICE loss, similar to generalized IOU for masks
|
16 |
+
Args:
|
17 |
+
inputs: A float tensor of arbitrary shape.
|
18 |
+
The predictions for each example.
|
19 |
+
targets: A float tensor with the same shape as inputs. Stores the binary
|
20 |
+
classification label for each element in inputs
|
21 |
+
(0 for the negative class and 1 for the positive class).
|
22 |
+
"""
|
23 |
+
inputs = inputs.sigmoid()
|
24 |
+
inputs = inputs.flatten(1)
|
25 |
+
numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
|
26 |
+
denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
|
27 |
+
loss = 1 - (numerator + 1) / (denominator + 1)
|
28 |
+
return loss
|
29 |
+
|
30 |
+
|
31 |
+
batch_dice_loss_jit = torch.jit.script(
|
32 |
+
batch_dice_loss
|
33 |
+
) # type: torch.jit.ScriptModule
|
34 |
+
|
35 |
+
|
36 |
+
def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
|
37 |
+
"""
|
38 |
+
Args:
|
39 |
+
inputs: A float tensor of arbitrary shape.
|
40 |
+
The predictions for each example.
|
41 |
+
targets: A float tensor with the same shape as inputs. Stores the binary
|
42 |
+
classification label for each element in inputs
|
43 |
+
(0 for the negative class and 1 for the positive class).
|
44 |
+
Returns:
|
45 |
+
Loss tensor
|
46 |
+
"""
|
47 |
+
hw = inputs.shape[1]
|
48 |
+
|
49 |
+
pos = F.binary_cross_entropy_with_logits(
|
50 |
+
inputs, torch.ones_like(inputs), reduction="none"
|
51 |
+
)
|
52 |
+
neg = F.binary_cross_entropy_with_logits(
|
53 |
+
inputs, torch.zeros_like(inputs), reduction="none"
|
54 |
+
)
|
55 |
+
|
56 |
+
loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
|
57 |
+
"nc,mc->nm", neg, (1 - targets)
|
58 |
+
)
|
59 |
+
|
60 |
+
return loss / hw
|
61 |
+
|
62 |
+
|
63 |
+
batch_sigmoid_ce_loss_jit = torch.jit.script(
|
64 |
+
batch_sigmoid_ce_loss
|
65 |
+
) # type: torch.jit.ScriptModule
|
66 |
+
|
67 |
+
|
68 |
+
class AvismHungarianMatcher(nn.Module):
|
69 |
+
"""This class computes an assignment between the targets and the predictions of the network
|
70 |
+
|
71 |
+
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
|
72 |
+
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
|
73 |
+
while the others are un-matched (and thus treated as non-objects).
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0):
|
77 |
+
"""Creates the matcher
|
78 |
+
|
79 |
+
Params:
|
80 |
+
cost_class: This is the relative weight of the classification error in the matching cost
|
81 |
+
cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
|
82 |
+
cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
|
83 |
+
"""
|
84 |
+
super().__init__()
|
85 |
+
self.cost_class = cost_class
|
86 |
+
self.cost_mask = cost_mask
|
87 |
+
self.cost_dice = cost_dice
|
88 |
+
|
89 |
+
assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0"
|
90 |
+
|
91 |
+
self.num_points = num_points
|
92 |
+
|
93 |
+
@torch.no_grad()
|
94 |
+
def memory_efficient_forward(self, outputs, targets):
|
95 |
+
# We flatten to compute the cost matrices in a batch
|
96 |
+
|
97 |
+
# Here, "L" is the number of frame-level decoder layers.
|
98 |
+
out_prob = outputs["pred_logits"].softmax(-1) # L, B, cQ, K+1
|
99 |
+
out_mask = outputs["pred_masks"] # L, B, cQ, T, H, W
|
100 |
+
|
101 |
+
L, B, cQ, T, s_h, s_w = out_mask.shape
|
102 |
+
|
103 |
+
out_prob = out_prob.reshape(L*B, cQ, -1)
|
104 |
+
out_mask = out_mask.reshape(L*B, cQ, T, s_h, s_w)
|
105 |
+
|
106 |
+
# If target is [vid1, vid2, vid3],
|
107 |
+
# it now becomes [vid1, vid2, vid3, vid1, vid2, vid3, ...].
|
108 |
+
targets = targets * L
|
109 |
+
|
110 |
+
indices = []
|
111 |
+
for b in range(L*B):
|
112 |
+
b_out_prob = out_prob[b]
|
113 |
+
tgt_ids = targets[b]["labels"]
|
114 |
+
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
|
115 |
+
# but approximate it in 1 - proba[target class].
|
116 |
+
# The 1 is a constant that doesn't change the matching, it can be ommitted.
|
117 |
+
cost_class = -b_out_prob[:, tgt_ids]
|
118 |
+
|
119 |
+
b_out_mask = out_mask[b] # cQ x T x H_pred x W_pred
|
120 |
+
# gt masks are already padded when preparing target
|
121 |
+
tgt_mask = targets[b]["masks"].to(b_out_mask) # Nins x T x H_tgt x W_tgt
|
122 |
+
|
123 |
+
# out_mask = out_mask[:, None]
|
124 |
+
# tgt_mask = tgt_mask[:, None]
|
125 |
+
# all masks share the same set of points for efficient matching!
|
126 |
+
point_coords = torch.rand(1, self.num_points, 2, device=b_out_mask.device)
|
127 |
+
# get gt labels
|
128 |
+
tgt_mask = point_sample(
|
129 |
+
tgt_mask,
|
130 |
+
point_coords.repeat(tgt_mask.shape[0], 1, 1),
|
131 |
+
align_corners=False,
|
132 |
+
).flatten(1)
|
133 |
+
|
134 |
+
b_out_mask = point_sample(
|
135 |
+
b_out_mask,
|
136 |
+
point_coords.repeat(b_out_mask.shape[0], 1, 1),
|
137 |
+
align_corners=False,
|
138 |
+
).flatten(1)
|
139 |
+
|
140 |
+
with autocast(enabled=False):
|
141 |
+
b_out_mask = b_out_mask.float()
|
142 |
+
tgt_mask = tgt_mask.float()
|
143 |
+
# Compute the focal loss between masks
|
144 |
+
cost_mask = batch_sigmoid_ce_loss_jit(b_out_mask, tgt_mask)
|
145 |
+
# Compute the dice loss betwen masks
|
146 |
+
cost_dice = batch_dice_loss(b_out_mask, tgt_mask)
|
147 |
+
|
148 |
+
# Final cost matrix
|
149 |
+
C = (
|
150 |
+
self.cost_mask * cost_mask
|
151 |
+
+ self.cost_class * cost_class
|
152 |
+
+ self.cost_dice * cost_dice
|
153 |
+
)
|
154 |
+
C = C.reshape(cQ, -1).cpu()
|
155 |
+
|
156 |
+
indices.append(linear_sum_assignment(C))
|
157 |
+
|
158 |
+
return [
|
159 |
+
(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
|
160 |
+
for i, j in indices
|
161 |
+
]
|
162 |
+
|
163 |
+
@torch.no_grad()
|
164 |
+
def forward(self, outputs, targets):
|
165 |
+
"""Performs the matching
|
166 |
+
|
167 |
+
Params:
|
168 |
+
outputs: This is a dict that contains at least these entries:
|
169 |
+
"pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
170 |
+
"pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
|
171 |
+
|
172 |
+
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
|
173 |
+
"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
|
174 |
+
objects in the target) containing the class labels
|
175 |
+
"masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
A list of size batch_size, containing tuples of (index_i, index_j) where:
|
179 |
+
- index_i is the indices of the selected predictions (in order)
|
180 |
+
- index_j is the indices of the corresponding selected targets (in order)
|
181 |
+
For each batch element, it holds:
|
182 |
+
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
183 |
+
"""
|
184 |
+
return self.memory_efficient_forward(outputs, targets)
|
185 |
+
|
186 |
+
def __repr__(self, _repr_indent=4):
|
187 |
+
head = "Matcher " + self.__class__.__name__
|
188 |
+
body = [
|
189 |
+
"cost_class: {}".format(self.cost_class),
|
190 |
+
"cost_mask: {}".format(self.cost_mask),
|
191 |
+
"cost_dice: {}".format(self.cost_dice),
|
192 |
+
]
|
193 |
+
lines = [head] + [" " * _repr_indent + line for line in body]
|
194 |
+
return "\n".join(lines)
|
avism/modeling/transformer_decoder/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .avism_transformer_decoder import AVISMMultiScaleMaskedTransformerDecoder
|
avism/modeling/transformer_decoder/avism.py
ADDED
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import ceil
|
2 |
+
import fvcore.nn.weight_init as weight_init
|
3 |
+
from typing import Optional
|
4 |
+
import torch
|
5 |
+
from torch import nn, Tensor
|
6 |
+
from torch.nn import functional as F
|
7 |
+
import copy
|
8 |
+
|
9 |
+
from detectron2.config import configurable
|
10 |
+
from detectron2.layers import Conv2d
|
11 |
+
|
12 |
+
|
13 |
+
class SelfAttentionLayer(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self, d_model, nhead, dropout=0.0,
|
16 |
+
activation="relu", normalize_before=False):
|
17 |
+
super().__init__()
|
18 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
19 |
+
|
20 |
+
self.norm = nn.LayerNorm(d_model)
|
21 |
+
self.dropout = nn.Dropout(dropout)
|
22 |
+
|
23 |
+
self.activation = _get_activation_fn(activation)
|
24 |
+
self.normalize_before = normalize_before
|
25 |
+
|
26 |
+
self._reset_parameters()
|
27 |
+
|
28 |
+
def _reset_parameters(self):
|
29 |
+
for p in self.parameters():
|
30 |
+
if p.dim() > 1:
|
31 |
+
nn.init.xavier_uniform_(p)
|
32 |
+
|
33 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
34 |
+
return tensor if pos is None else tensor + pos
|
35 |
+
|
36 |
+
def forward_post(self, tgt,
|
37 |
+
tgt_mask: Optional[Tensor] = None,
|
38 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
39 |
+
query_pos: Optional[Tensor] = None):
|
40 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
41 |
+
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
42 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
43 |
+
tgt = tgt + self.dropout(tgt2)
|
44 |
+
tgt = self.norm(tgt)
|
45 |
+
|
46 |
+
return tgt
|
47 |
+
|
48 |
+
def forward_pre(self, tgt,
|
49 |
+
tgt_mask: Optional[Tensor] = None,
|
50 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
51 |
+
query_pos: Optional[Tensor] = None):
|
52 |
+
tgt2 = self.norm(tgt)
|
53 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
54 |
+
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
55 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
56 |
+
tgt = tgt + self.dropout(tgt2)
|
57 |
+
|
58 |
+
return tgt
|
59 |
+
|
60 |
+
def forward(self, tgt,
|
61 |
+
tgt_mask: Optional[Tensor] = None,
|
62 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
63 |
+
query_pos: Optional[Tensor] = None):
|
64 |
+
if self.normalize_before:
|
65 |
+
return self.forward_pre(tgt, tgt_mask,
|
66 |
+
tgt_key_padding_mask, query_pos)
|
67 |
+
return self.forward_post(tgt, tgt_mask,
|
68 |
+
tgt_key_padding_mask, query_pos)
|
69 |
+
|
70 |
+
|
71 |
+
class CrossAttentionLayer(nn.Module):
|
72 |
+
|
73 |
+
def __init__(self, d_model, nhead, dropout=0.0,
|
74 |
+
activation="relu", normalize_before=False):
|
75 |
+
super().__init__()
|
76 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
77 |
+
|
78 |
+
self.norm = nn.LayerNorm(d_model)
|
79 |
+
self.dropout = nn.Dropout(dropout)
|
80 |
+
|
81 |
+
self.activation = _get_activation_fn(activation)
|
82 |
+
self.normalize_before = normalize_before
|
83 |
+
|
84 |
+
self._reset_parameters()
|
85 |
+
|
86 |
+
def _reset_parameters(self):
|
87 |
+
for p in self.parameters():
|
88 |
+
if p.dim() > 1:
|
89 |
+
nn.init.xavier_uniform_(p)
|
90 |
+
|
91 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
92 |
+
return tensor if pos is None else tensor + pos
|
93 |
+
|
94 |
+
def forward_post(self, tgt, memory,
|
95 |
+
memory_mask: Optional[Tensor] = None,
|
96 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
97 |
+
pos: Optional[Tensor] = None,
|
98 |
+
query_pos: Optional[Tensor] = None):
|
99 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
100 |
+
key=self.with_pos_embed(memory, pos),
|
101 |
+
value=memory, attn_mask=memory_mask,
|
102 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
103 |
+
tgt = tgt + self.dropout(tgt2)
|
104 |
+
tgt = self.norm(tgt)
|
105 |
+
|
106 |
+
return tgt
|
107 |
+
|
108 |
+
def forward_pre(self, tgt, memory,
|
109 |
+
memory_mask: Optional[Tensor] = None,
|
110 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
111 |
+
pos: Optional[Tensor] = None,
|
112 |
+
query_pos: Optional[Tensor] = None):
|
113 |
+
tgt2 = self.norm(tgt)
|
114 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
115 |
+
key=self.with_pos_embed(memory, pos),
|
116 |
+
value=memory, attn_mask=memory_mask,
|
117 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
118 |
+
tgt = tgt + self.dropout(tgt2)
|
119 |
+
|
120 |
+
return tgt
|
121 |
+
|
122 |
+
def forward(self, tgt, memory,
|
123 |
+
memory_mask: Optional[Tensor] = None,
|
124 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
125 |
+
pos: Optional[Tensor] = None,
|
126 |
+
query_pos: Optional[Tensor] = None):
|
127 |
+
if self.normalize_before:
|
128 |
+
return self.forward_pre(tgt, memory, memory_mask,
|
129 |
+
memory_key_padding_mask, pos, query_pos)
|
130 |
+
return self.forward_post(tgt, memory, memory_mask,
|
131 |
+
memory_key_padding_mask, pos, query_pos)
|
132 |
+
|
133 |
+
|
134 |
+
class FFNLayer(nn.Module):
|
135 |
+
|
136 |
+
def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
|
137 |
+
activation="relu", normalize_before=False):
|
138 |
+
super().__init__()
|
139 |
+
# Implementation of Feedforward model
|
140 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
141 |
+
self.dropout = nn.Dropout(dropout)
|
142 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
143 |
+
|
144 |
+
self.norm = nn.LayerNorm(d_model)
|
145 |
+
|
146 |
+
self.activation = _get_activation_fn(activation)
|
147 |
+
self.normalize_before = normalize_before
|
148 |
+
|
149 |
+
self._reset_parameters()
|
150 |
+
|
151 |
+
def _reset_parameters(self):
|
152 |
+
for p in self.parameters():
|
153 |
+
if p.dim() > 1:
|
154 |
+
nn.init.xavier_uniform_(p)
|
155 |
+
|
156 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
157 |
+
return tensor if pos is None else tensor + pos
|
158 |
+
|
159 |
+
def forward_post(self, tgt):
|
160 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
161 |
+
tgt = tgt + self.dropout(tgt2)
|
162 |
+
tgt = self.norm(tgt)
|
163 |
+
return tgt
|
164 |
+
|
165 |
+
def forward_pre(self, tgt):
|
166 |
+
tgt2 = self.norm(tgt)
|
167 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
168 |
+
tgt = tgt + self.dropout(tgt2)
|
169 |
+
return tgt
|
170 |
+
|
171 |
+
def forward(self, tgt):
|
172 |
+
if self.normalize_before:
|
173 |
+
return self.forward_pre(tgt)
|
174 |
+
return self.forward_post(tgt)
|
175 |
+
|
176 |
+
|
177 |
+
def _get_activation_fn(activation):
|
178 |
+
"""Return an activation function given a string"""
|
179 |
+
if activation == "relu":
|
180 |
+
return F.relu
|
181 |
+
if activation == "gelu":
|
182 |
+
return F.gelu
|
183 |
+
if activation == "glu":
|
184 |
+
return F.glu
|
185 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
186 |
+
|
187 |
+
|
188 |
+
class MLP(nn.Module):
|
189 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
190 |
+
|
191 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
192 |
+
super().__init__()
|
193 |
+
self.num_layers = num_layers
|
194 |
+
h = [hidden_dim] * (num_layers - 1)
|
195 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
for i, layer in enumerate(self.layers):
|
199 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
200 |
+
return x
|
201 |
+
|
202 |
+
|
203 |
+
def _get_clones(module, N):
|
204 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
205 |
+
|
206 |
+
|
207 |
+
class Avism(nn.Module):
|
208 |
+
|
209 |
+
@configurable
|
210 |
+
def __init__(
|
211 |
+
self,
|
212 |
+
in_channels,
|
213 |
+
aux_loss,
|
214 |
+
*,
|
215 |
+
hidden_dim: int,
|
216 |
+
num_frame_queries: int,
|
217 |
+
num_queries: int,
|
218 |
+
nheads: int,
|
219 |
+
dim_feedforward: int,
|
220 |
+
enc_layers: int,
|
221 |
+
dec_layers: int,
|
222 |
+
enc_window_size: int,
|
223 |
+
pre_norm: bool,
|
224 |
+
enforce_input_project: bool,
|
225 |
+
num_frames: int,
|
226 |
+
num_classes: int,
|
227 |
+
clip_last_layer_num: bool,
|
228 |
+
conv_dim: int,
|
229 |
+
mask_dim: int,
|
230 |
+
sim_use_clip: list,
|
231 |
+
use_sim: bool,
|
232 |
+
):
|
233 |
+
"""
|
234 |
+
NOTE: this interface is experimental.
|
235 |
+
Args:
|
236 |
+
in_channels: channels of the input features
|
237 |
+
hidden_dim: Transformer feature dimension
|
238 |
+
num_queries: number of queries
|
239 |
+
nheads: number of heads
|
240 |
+
dim_feedforward: feature dimension in feedforward network
|
241 |
+
enc_layers: number of Transformer encoder layers
|
242 |
+
dec_layers: number of Transformer decoder layers
|
243 |
+
pre_norm: whether to use pre-LayerNorm or not
|
244 |
+
enforce_input_project: add input project 1x1 conv even if input
|
245 |
+
channels and hidden dim is identical
|
246 |
+
"""
|
247 |
+
super().__init__()
|
248 |
+
|
249 |
+
# define Transformer decoder here
|
250 |
+
self.num_heads = nheads
|
251 |
+
self.num_layers = dec_layers
|
252 |
+
self.transformer_self_attention_layers = nn.ModuleList()
|
253 |
+
self.transformer_cross_attention_layers = nn.ModuleList()
|
254 |
+
self.transformer_ffn_layers = nn.ModuleList()
|
255 |
+
self.num_frames = num_frames
|
256 |
+
self.num_classes = num_classes
|
257 |
+
self.clip_last_layer_num = clip_last_layer_num
|
258 |
+
|
259 |
+
self.enc_layers = enc_layers
|
260 |
+
self.window_size = enc_window_size
|
261 |
+
self.sim_use_clip = sim_use_clip
|
262 |
+
self.use_sim = use_sim
|
263 |
+
self.aux_loss = aux_loss
|
264 |
+
|
265 |
+
self.av_proj = nn.Linear(128, hidden_dim)
|
266 |
+
|
267 |
+
self.enc_layers = enc_layers
|
268 |
+
if enc_layers > 0:
|
269 |
+
self.enc_self_attn = nn.ModuleList()
|
270 |
+
self.enc_ffn = nn.ModuleList()
|
271 |
+
for _ in range(self.enc_layers):
|
272 |
+
self.enc_self_attn.append(
|
273 |
+
SelfAttentionLayer(
|
274 |
+
d_model=hidden_dim,
|
275 |
+
nhead=nheads,
|
276 |
+
dropout=0.0,
|
277 |
+
normalize_before=pre_norm,
|
278 |
+
),
|
279 |
+
)
|
280 |
+
self.enc_ffn.append(
|
281 |
+
FFNLayer(
|
282 |
+
d_model=hidden_dim,
|
283 |
+
dim_feedforward=dim_feedforward,
|
284 |
+
dropout=0.0,
|
285 |
+
normalize_before=pre_norm,
|
286 |
+
)
|
287 |
+
)
|
288 |
+
|
289 |
+
if enc_layers > 0:
|
290 |
+
self.enc_av_cross_attn = nn.ModuleList()
|
291 |
+
self.enc_av_ffn = nn.ModuleList()
|
292 |
+
for _ in range(self.enc_layers):
|
293 |
+
self.enc_av_cross_attn.append(
|
294 |
+
CrossAttentionLayer(
|
295 |
+
d_model=hidden_dim,
|
296 |
+
nhead=nheads,
|
297 |
+
dropout=0.0,
|
298 |
+
normalize_before=pre_norm,
|
299 |
+
),
|
300 |
+
)
|
301 |
+
self.enc_av_ffn.append(
|
302 |
+
FFNLayer(
|
303 |
+
d_model=hidden_dim,
|
304 |
+
dim_feedforward=dim_feedforward,
|
305 |
+
dropout=0.0,
|
306 |
+
normalize_before=pre_norm,
|
307 |
+
)
|
308 |
+
)
|
309 |
+
|
310 |
+
for _ in range(self.num_layers):
|
311 |
+
self.transformer_self_attention_layers.append(
|
312 |
+
SelfAttentionLayer(
|
313 |
+
d_model=hidden_dim,
|
314 |
+
nhead=nheads,
|
315 |
+
dropout=0.0,
|
316 |
+
normalize_before=pre_norm,
|
317 |
+
)
|
318 |
+
)
|
319 |
+
|
320 |
+
self.transformer_cross_attention_layers.append(
|
321 |
+
CrossAttentionLayer(
|
322 |
+
d_model=hidden_dim,
|
323 |
+
nhead=nheads,
|
324 |
+
dropout=0.0,
|
325 |
+
normalize_before=pre_norm,
|
326 |
+
)
|
327 |
+
)
|
328 |
+
|
329 |
+
self.transformer_ffn_layers.append(
|
330 |
+
FFNLayer(
|
331 |
+
d_model=hidden_dim,
|
332 |
+
dim_feedforward=dim_feedforward,
|
333 |
+
dropout=0.0,
|
334 |
+
normalize_before=pre_norm,
|
335 |
+
)
|
336 |
+
)
|
337 |
+
|
338 |
+
self.avism_mask_features = Conv2d(
|
339 |
+
conv_dim,
|
340 |
+
mask_dim,
|
341 |
+
kernel_size=1,
|
342 |
+
stride=1,
|
343 |
+
padding=0,
|
344 |
+
)
|
345 |
+
weight_init.c2_xavier_fill(self.avism_mask_features)
|
346 |
+
|
347 |
+
self.decoder_norm = nn.LayerNorm(hidden_dim)
|
348 |
+
|
349 |
+
self.num_queries = num_queries
|
350 |
+
# learnable query features
|
351 |
+
self.query_feat = nn.Embedding(num_queries, hidden_dim)
|
352 |
+
# learnable query p.e.
|
353 |
+
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
354 |
+
|
355 |
+
self.fq_pos = nn.Embedding(num_frame_queries, hidden_dim)
|
356 |
+
|
357 |
+
if in_channels != hidden_dim or enforce_input_project:
|
358 |
+
self.input_proj_dec = nn.Linear(hidden_dim, hidden_dim)
|
359 |
+
else:
|
360 |
+
self.input_proj_dec = nn.Sequential()
|
361 |
+
self.src_embed = nn.Identity()
|
362 |
+
|
363 |
+
self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
|
364 |
+
self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
|
365 |
+
if self.use_sim:
|
366 |
+
self.sim_embed_frame = nn.Linear(hidden_dim, hidden_dim)
|
367 |
+
if self.sim_use_clip:
|
368 |
+
self.sim_embed_clip = nn.Linear(hidden_dim, hidden_dim)
|
369 |
+
|
370 |
+
@classmethod
|
371 |
+
def from_config(cls, cfg, in_channels):
|
372 |
+
ret = {}
|
373 |
+
ret["in_channels"] = in_channels
|
374 |
+
|
375 |
+
ret["hidden_dim"] = cfg.MODEL.AVISM.HIDDEN_DIM
|
376 |
+
ret["num_frame_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
|
377 |
+
ret["num_queries"] = cfg.MODEL.AVISM.NUM_OBJECT_QUERIES
|
378 |
+
# Transformer parameters:
|
379 |
+
ret["nheads"] = cfg.MODEL.AVISM.NHEADS
|
380 |
+
ret["dim_feedforward"] = cfg.MODEL.AVISM.DIM_FEEDFORWARD
|
381 |
+
|
382 |
+
assert cfg.MODEL.AVISM.DEC_LAYERS >= 1
|
383 |
+
ret["enc_layers"] = cfg.MODEL.AVISM.ENC_LAYERS
|
384 |
+
ret["dec_layers"] = cfg.MODEL.AVISM.DEC_LAYERS
|
385 |
+
ret["enc_window_size"] = cfg.MODEL.AVISM.ENC_WINDOW_SIZE
|
386 |
+
ret["pre_norm"] = cfg.MODEL.AVISM.PRE_NORM
|
387 |
+
ret["enforce_input_project"] = cfg.MODEL.AVISM.ENFORCE_INPUT_PROJ
|
388 |
+
|
389 |
+
ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
|
390 |
+
ret["num_frames"] = cfg.INPUT.SAMPLING_FRAME_NUM
|
391 |
+
ret["clip_last_layer_num"] = cfg.MODEL.AVISM.LAST_LAYER_NUM
|
392 |
+
|
393 |
+
ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
|
394 |
+
ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
|
395 |
+
ret["sim_use_clip"] = cfg.MODEL.AVISM.SIM_USE_CLIP
|
396 |
+
ret["use_sim"] = cfg.MODEL.AVISM.SIM_WEIGHT > 0.0
|
397 |
+
|
398 |
+
return ret
|
399 |
+
|
400 |
+
def forward(self, frame_query, audio_features):
|
401 |
+
"""
|
402 |
+
L: Number of Layers.
|
403 |
+
B: Batch size.
|
404 |
+
T: Temporal window size. Number of frames per video.
|
405 |
+
C: Channel size.
|
406 |
+
fQ: Number of frame-wise queries from IFC.
|
407 |
+
cQ: Number of clip-wise queries to decode Q.
|
408 |
+
"""
|
409 |
+
if not self.training:
|
410 |
+
frame_query = frame_query[[-1]]
|
411 |
+
|
412 |
+
L, BT, fQ, C = frame_query.shape
|
413 |
+
B = BT // self.num_frames if self.training else 1
|
414 |
+
T = self.num_frames if self.training else BT // B
|
415 |
+
|
416 |
+
frame_query = frame_query.reshape(L * B, T, fQ, C)
|
417 |
+
frame_query = frame_query.permute(1, 2, 0, 3).contiguous()
|
418 |
+
frame_query = self.input_proj_dec(frame_query) # T, fQ, LB, C
|
419 |
+
|
420 |
+
audio_feat = self.av_proj(audio_features) # T, C
|
421 |
+
audio_feat = audio_feat[:, None, None, :].repeat(1, fQ, L * B, 1)
|
422 |
+
|
423 |
+
if self.window_size > 0:
|
424 |
+
pad = int(ceil(T / self.window_size)) * self.window_size - T
|
425 |
+
_T = pad + T
|
426 |
+
frame_query = F.pad(frame_query, (0, 0, 0, 0, 0, 0, 0, pad)) # _T, fQ, LB, C
|
427 |
+
audio_feat = F.pad(audio_feat, (0, 0, 0, 0, 0, 0, 0, pad))
|
428 |
+
enc_mask = frame_query.new_ones(L * B, _T).bool() # LB, _T
|
429 |
+
enc_mask[:, :T] = False
|
430 |
+
else:
|
431 |
+
enc_mask = None
|
432 |
+
|
433 |
+
frame_query = self.encode_frame_query(frame_query, enc_mask)
|
434 |
+
|
435 |
+
# audio
|
436 |
+
av_feat = self.encode_av_fusion(frame_query, enc_mask, audio_feat)
|
437 |
+
|
438 |
+
frame_query = frame_query[:T].flatten(0, 1) # TfQ, LB, C
|
439 |
+
av_feat = av_feat[:T].flatten(0, 1)
|
440 |
+
frame_query = frame_query + av_feat
|
441 |
+
|
442 |
+
if self.use_sim:
|
443 |
+
pred_fq_embed = self.sim_embed_frame(frame_query) # TfQ, LB, C
|
444 |
+
pred_fq_embed = pred_fq_embed.transpose(0, 1).reshape(L, B, T, fQ, C)
|
445 |
+
else:
|
446 |
+
pred_fq_embed = None
|
447 |
+
|
448 |
+
src = self.src_embed(frame_query) # TfQ, LB, C
|
449 |
+
dec_pos = self.fq_pos.weight[None, :, None, :].repeat(T, 1, L * B, 1).flatten(0, 1) # TfQ, LB, C
|
450 |
+
|
451 |
+
# QxNxC
|
452 |
+
query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, L * B, 1) # cQ, LB, C
|
453 |
+
output = self.query_feat.weight.unsqueeze(1).repeat(1, L * B, 1) # cQ, LB, C
|
454 |
+
|
455 |
+
decoder_outputs = []
|
456 |
+
for i in range(self.num_layers):
|
457 |
+
# attention: cross-attention first
|
458 |
+
output = self.transformer_cross_attention_layers[i](
|
459 |
+
output, src,
|
460 |
+
memory_mask=None,
|
461 |
+
memory_key_padding_mask=None,
|
462 |
+
pos=dec_pos, query_pos=query_embed
|
463 |
+
)
|
464 |
+
|
465 |
+
output = self.transformer_self_attention_layers[i](
|
466 |
+
output, tgt_mask=None,
|
467 |
+
tgt_key_padding_mask=None,
|
468 |
+
query_pos=query_embed
|
469 |
+
)
|
470 |
+
|
471 |
+
# FFN
|
472 |
+
output = self.transformer_ffn_layers[i](
|
473 |
+
output
|
474 |
+
)
|
475 |
+
|
476 |
+
if (self.training and self.aux_loss) or (i == self.num_layers - 1):
|
477 |
+
dec_out = self.decoder_norm(output) # cQ, LB, C
|
478 |
+
dec_out = dec_out.transpose(0, 1) # LB, cQ, C
|
479 |
+
decoder_outputs.append(dec_out.view(L, B, self.num_queries, C))
|
480 |
+
|
481 |
+
decoder_outputs = torch.stack(decoder_outputs, dim=0) # D, L, B, cQ, C
|
482 |
+
|
483 |
+
pred_cls = self.class_embed(decoder_outputs)
|
484 |
+
pred_mask_embed = self.mask_embed(decoder_outputs)
|
485 |
+
if self.use_sim and self.sim_use_clip:
|
486 |
+
pred_cq_embed = self.sim_embed_clip(decoder_outputs)
|
487 |
+
else:
|
488 |
+
pred_cq_embed = [None] * self.num_layers
|
489 |
+
|
490 |
+
out = {
|
491 |
+
'pred_logits': pred_cls[-1],
|
492 |
+
'pred_mask_embed': pred_mask_embed[-1],
|
493 |
+
'pred_fq_embed': pred_fq_embed,
|
494 |
+
'pred_cq_embed': pred_cq_embed[-1],
|
495 |
+
'aux_outputs': self._set_aux_loss(
|
496 |
+
pred_cls, pred_mask_embed, pred_cq_embed, pred_fq_embed
|
497 |
+
)
|
498 |
+
}
|
499 |
+
return out
|
500 |
+
|
501 |
+
@torch.jit.unused
|
502 |
+
def _set_aux_loss(
|
503 |
+
self, outputs_cls, outputs_mask_embed, outputs_cq_embed, outputs_fq_embed
|
504 |
+
):
|
505 |
+
return [{"pred_logits": a, "pred_mask_embed": b, "pred_cq_embed": c, "pred_fq_embed": outputs_fq_embed}
|
506 |
+
for a, b, c in zip(outputs_cls[:-1], outputs_mask_embed[:-1], outputs_cq_embed[:-1])]
|
507 |
+
|
508 |
+
def encode_frame_query(self, frame_query, attn_mask):
|
509 |
+
"""
|
510 |
+
input shape (frame_query) : T, fQ, LB, C
|
511 |
+
output shape (frame_query) : T, fQ, LB, C
|
512 |
+
"""
|
513 |
+
|
514 |
+
# Not using window-based attention if self.window_size == 0.
|
515 |
+
if self.window_size == 0:
|
516 |
+
return_shape = frame_query.shape # T, fQ, LB, C
|
517 |
+
frame_query = frame_query.flatten(0, 1) # TfQ, LB, C
|
518 |
+
|
519 |
+
for i in range(self.enc_layers):
|
520 |
+
frame_query = self.enc_self_attn[i](frame_query)
|
521 |
+
frame_query = self.enc_ffn[i](frame_query)
|
522 |
+
|
523 |
+
frame_query = frame_query.view(return_shape)
|
524 |
+
return frame_query
|
525 |
+
# Using window-based attention if self.window_size > 0.
|
526 |
+
else:
|
527 |
+
T, fQ, LB, C = frame_query.shape
|
528 |
+
W = self.window_size
|
529 |
+
Nw = T // W
|
530 |
+
half_W = int(ceil(W / 2))
|
531 |
+
|
532 |
+
window_mask = attn_mask.view(LB * Nw, W)[..., None].repeat(1, 1, fQ).flatten(1)
|
533 |
+
|
534 |
+
_attn_mask = torch.roll(attn_mask, half_W, 1)
|
535 |
+
_attn_mask = _attn_mask.view(LB, Nw, W)[..., None].repeat(1, 1, 1, W) # LB, Nw, W, W
|
536 |
+
_attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose(-2, -1)
|
537 |
+
_attn_mask[:, -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose(-2, -1)
|
538 |
+
_attn_mask[:, 0, :half_W, half_W:] = True
|
539 |
+
_attn_mask[:, 0, half_W:, :half_W] = True
|
540 |
+
_attn_mask = _attn_mask.view(LB * Nw, 1, W, 1, W, 1).repeat(1, self.num_heads, 1, fQ, 1, fQ).view(
|
541 |
+
LB * Nw * self.num_heads, W * fQ, W * fQ)
|
542 |
+
shift_window_mask = _attn_mask.float() * -1000
|
543 |
+
|
544 |
+
for layer_idx in range(self.enc_layers):
|
545 |
+
if self.training or layer_idx % 2 == 0:
|
546 |
+
frame_query = self._window_attn(frame_query, window_mask, layer_idx)
|
547 |
+
else:
|
548 |
+
frame_query = self._shift_window_attn(frame_query, shift_window_mask, layer_idx)
|
549 |
+
return frame_query
|
550 |
+
|
551 |
+
def _window_attn(self, frame_query, attn_mask, layer_idx):
|
552 |
+
T, fQ, LB, C = frame_query.shape
|
553 |
+
# LBN, WTfQ = attn_mask.shape
|
554 |
+
|
555 |
+
W = self.window_size
|
556 |
+
Nw = T // W
|
557 |
+
|
558 |
+
frame_query = frame_query.view(Nw, W, fQ, LB, C)
|
559 |
+
frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
560 |
+
|
561 |
+
frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_key_padding_mask=attn_mask)
|
562 |
+
frame_query = self.enc_ffn[layer_idx](frame_query)
|
563 |
+
frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
564 |
+
|
565 |
+
return frame_query
|
566 |
+
|
567 |
+
def _shift_window_attn(self, frame_query, attn_mask, layer_idx):
|
568 |
+
T, fQ, LB, C = frame_query.shape
|
569 |
+
# LBNH, WfQ, WfQ = attn_mask.shape
|
570 |
+
|
571 |
+
W = self.window_size
|
572 |
+
Nw = T // W
|
573 |
+
half_W = int(ceil(W / 2))
|
574 |
+
|
575 |
+
frame_query = torch.roll(frame_query, half_W, 0)
|
576 |
+
frame_query = frame_query.view(Nw, W, fQ, LB, C)
|
577 |
+
frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
578 |
+
|
579 |
+
frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_mask=attn_mask)
|
580 |
+
frame_query = self.enc_ffn[layer_idx](frame_query)
|
581 |
+
frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
582 |
+
|
583 |
+
frame_query = torch.roll(frame_query, -half_W, 0)
|
584 |
+
|
585 |
+
return frame_query
|
586 |
+
|
587 |
+
def encode_av_fusion(self, frame_query, attn_mask, audio_feats):
|
588 |
+
"""
|
589 |
+
input shape (frame_query) : T, fQ, LB, C
|
590 |
+
output shape (frame_query) : T, fQ, LB, C
|
591 |
+
"""
|
592 |
+
|
593 |
+
# Not using window-based attention if self.window_size == 0.
|
594 |
+
if self.window_size == 0:
|
595 |
+
return_shape = frame_query.shape # T, fQ, LB, C
|
596 |
+
frame_query = frame_query.flatten(0, 1) # TfQ, LB, C
|
597 |
+
audio_feats = audio_feats.flatten(0, 1)
|
598 |
+
|
599 |
+
for i in range(self.enc_layers):
|
600 |
+
audio_feats = self.enc_av_cross_attn[i](audio_feats, frame_query)
|
601 |
+
audio_feats = self.enc_av_ffn[i](audio_feats)
|
602 |
+
|
603 |
+
audio_feats = audio_feats.view(return_shape)
|
604 |
+
return audio_feats
|
605 |
+
# Using window-based attention if self.window_size > 0.
|
606 |
+
else:
|
607 |
+
T, fQ, LB, C = frame_query.shape
|
608 |
+
W = self.window_size
|
609 |
+
Nw = T // W
|
610 |
+
half_W = int(ceil(W / 2))
|
611 |
+
|
612 |
+
window_mask = attn_mask.view(LB * Nw, W)[..., None].repeat(1, 1, fQ).flatten(1)
|
613 |
+
|
614 |
+
_attn_mask = torch.roll(attn_mask, half_W, 1)
|
615 |
+
_attn_mask = _attn_mask.view(LB, Nw, W)[..., None].repeat(1, 1, 1, W) # LB, Nw, W, W
|
616 |
+
_attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose(-2, -1)
|
617 |
+
_attn_mask[:, -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose(-2, -1)
|
618 |
+
_attn_mask[:, 0, :half_W, half_W:] = True
|
619 |
+
_attn_mask[:, 0, half_W:, :half_W] = True
|
620 |
+
_attn_mask = _attn_mask.view(LB * Nw, 1, W, 1, W, 1).repeat(1, self.num_heads, 1, fQ, 1, fQ).view(
|
621 |
+
LB * Nw * self.num_heads, W * fQ, W * fQ)
|
622 |
+
shift_window_mask = _attn_mask.float() * -1000
|
623 |
+
|
624 |
+
for layer_idx in range(self.enc_layers):
|
625 |
+
if layer_idx % 2 == 0:
|
626 |
+
frame_query, audio_feats = self._window_av_attn(frame_query, window_mask, layer_idx, audio_feats)
|
627 |
+
else:
|
628 |
+
frame_query, audio_feats = self._shift_window_av_attn(frame_query, shift_window_mask, layer_idx, audio_feats)
|
629 |
+
return audio_feats
|
630 |
+
|
631 |
+
def _window_av_attn(self, frame_query, attn_mask, layer_idx, audio_feats):
|
632 |
+
T, fQ, LB, C = frame_query.shape
|
633 |
+
|
634 |
+
W = self.window_size
|
635 |
+
Nw = T // W
|
636 |
+
|
637 |
+
frame_query = frame_query.view(Nw, W, fQ, LB, C)
|
638 |
+
frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
639 |
+
|
640 |
+
audio_feats = audio_feats.view(Nw, W, fQ, LB, C)
|
641 |
+
audio_feats = audio_feats.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
642 |
+
|
643 |
+
audio_feats = self.enc_av_cross_attn[layer_idx](audio_feats, frame_query, memory_key_padding_mask=attn_mask)
|
644 |
+
audio_feats = self.enc_av_ffn[layer_idx](audio_feats)
|
645 |
+
|
646 |
+
frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
647 |
+
audio_feats = audio_feats.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
648 |
+
|
649 |
+
return frame_query, audio_feats
|
650 |
+
|
651 |
+
def _shift_window_av_attn(self, frame_query, attn_mask, layer_idx, audio_feats):
|
652 |
+
T, fQ, LB, C = frame_query.shape
|
653 |
+
|
654 |
+
W = self.window_size
|
655 |
+
Nw = T // W
|
656 |
+
half_W = int(ceil(W / 2))
|
657 |
+
|
658 |
+
frame_query = torch.roll(frame_query, half_W, 0)
|
659 |
+
frame_query = frame_query.view(Nw, W, fQ, LB, C)
|
660 |
+
frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
661 |
+
|
662 |
+
audio_feats = torch.roll(audio_feats, half_W, 0)
|
663 |
+
audio_feats = audio_feats.view(Nw, W, fQ, LB, C)
|
664 |
+
audio_feats = audio_feats.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
665 |
+
|
666 |
+
audio_feats = self.enc_av_cross_attn[layer_idx](audio_feats, frame_query, memory_mask=attn_mask)
|
667 |
+
audio_feats = self.enc_av_ffn[layer_idx](audio_feats)
|
668 |
+
|
669 |
+
frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
670 |
+
frame_query = torch.roll(frame_query, -half_W, 0)
|
671 |
+
|
672 |
+
audio_feats = audio_feats.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
673 |
+
audio_feats = torch.roll(audio_feats, -half_W, 0)
|
674 |
+
|
675 |
+
return frame_query, audio_feats
|
avism/modeling/transformer_decoder/avism_coco.py
ADDED
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import ceil
|
2 |
+
import fvcore.nn.weight_init as weight_init
|
3 |
+
from typing import Optional
|
4 |
+
import torch
|
5 |
+
from torch import nn, Tensor
|
6 |
+
from torch.nn import functional as F
|
7 |
+
import copy
|
8 |
+
|
9 |
+
from detectron2.config import configurable
|
10 |
+
from detectron2.layers import Conv2d
|
11 |
+
|
12 |
+
|
13 |
+
class SelfAttentionLayer(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self, d_model, nhead, dropout=0.0,
|
16 |
+
activation="relu", normalize_before=False):
|
17 |
+
super().__init__()
|
18 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
19 |
+
|
20 |
+
self.norm = nn.LayerNorm(d_model)
|
21 |
+
self.dropout = nn.Dropout(dropout)
|
22 |
+
|
23 |
+
self.activation = _get_activation_fn(activation)
|
24 |
+
self.normalize_before = normalize_before
|
25 |
+
|
26 |
+
self._reset_parameters()
|
27 |
+
|
28 |
+
def _reset_parameters(self):
|
29 |
+
for p in self.parameters():
|
30 |
+
if p.dim() > 1:
|
31 |
+
nn.init.xavier_uniform_(p)
|
32 |
+
|
33 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
34 |
+
return tensor if pos is None else tensor + pos
|
35 |
+
|
36 |
+
def forward_post(self, tgt,
|
37 |
+
tgt_mask: Optional[Tensor] = None,
|
38 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
39 |
+
query_pos: Optional[Tensor] = None):
|
40 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
41 |
+
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
42 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
43 |
+
tgt = tgt + self.dropout(tgt2)
|
44 |
+
tgt = self.norm(tgt)
|
45 |
+
|
46 |
+
return tgt
|
47 |
+
|
48 |
+
def forward_pre(self, tgt,
|
49 |
+
tgt_mask: Optional[Tensor] = None,
|
50 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
51 |
+
query_pos: Optional[Tensor] = None):
|
52 |
+
tgt2 = self.norm(tgt)
|
53 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
54 |
+
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
55 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
56 |
+
tgt = tgt + self.dropout(tgt2)
|
57 |
+
|
58 |
+
return tgt
|
59 |
+
|
60 |
+
def forward(self, tgt,
|
61 |
+
tgt_mask: Optional[Tensor] = None,
|
62 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
63 |
+
query_pos: Optional[Tensor] = None):
|
64 |
+
if self.normalize_before:
|
65 |
+
return self.forward_pre(tgt, tgt_mask,
|
66 |
+
tgt_key_padding_mask, query_pos)
|
67 |
+
return self.forward_post(tgt, tgt_mask,
|
68 |
+
tgt_key_padding_mask, query_pos)
|
69 |
+
|
70 |
+
|
71 |
+
class CrossAttentionLayer(nn.Module):
|
72 |
+
|
73 |
+
def __init__(self, d_model, nhead, dropout=0.0,
|
74 |
+
activation="relu", normalize_before=False):
|
75 |
+
super().__init__()
|
76 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
77 |
+
|
78 |
+
self.norm = nn.LayerNorm(d_model)
|
79 |
+
self.dropout = nn.Dropout(dropout)
|
80 |
+
|
81 |
+
self.activation = _get_activation_fn(activation)
|
82 |
+
self.normalize_before = normalize_before
|
83 |
+
|
84 |
+
self._reset_parameters()
|
85 |
+
|
86 |
+
def _reset_parameters(self):
|
87 |
+
for p in self.parameters():
|
88 |
+
if p.dim() > 1:
|
89 |
+
nn.init.xavier_uniform_(p)
|
90 |
+
|
91 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
92 |
+
return tensor if pos is None else tensor + pos
|
93 |
+
|
94 |
+
def forward_post(self, tgt, memory,
|
95 |
+
memory_mask: Optional[Tensor] = None,
|
96 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
97 |
+
pos: Optional[Tensor] = None,
|
98 |
+
query_pos: Optional[Tensor] = None):
|
99 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
100 |
+
key=self.with_pos_embed(memory, pos),
|
101 |
+
value=memory, attn_mask=memory_mask,
|
102 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
103 |
+
tgt = tgt + self.dropout(tgt2)
|
104 |
+
tgt = self.norm(tgt)
|
105 |
+
|
106 |
+
return tgt
|
107 |
+
|
108 |
+
def forward_pre(self, tgt, memory,
|
109 |
+
memory_mask: Optional[Tensor] = None,
|
110 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
111 |
+
pos: Optional[Tensor] = None,
|
112 |
+
query_pos: Optional[Tensor] = None):
|
113 |
+
tgt2 = self.norm(tgt)
|
114 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
115 |
+
key=self.with_pos_embed(memory, pos),
|
116 |
+
value=memory, attn_mask=memory_mask,
|
117 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
118 |
+
tgt = tgt + self.dropout(tgt2)
|
119 |
+
|
120 |
+
return tgt
|
121 |
+
|
122 |
+
def forward(self, tgt, memory,
|
123 |
+
memory_mask: Optional[Tensor] = None,
|
124 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
125 |
+
pos: Optional[Tensor] = None,
|
126 |
+
query_pos: Optional[Tensor] = None):
|
127 |
+
if self.normalize_before:
|
128 |
+
return self.forward_pre(tgt, memory, memory_mask,
|
129 |
+
memory_key_padding_mask, pos, query_pos)
|
130 |
+
return self.forward_post(tgt, memory, memory_mask,
|
131 |
+
memory_key_padding_mask, pos, query_pos)
|
132 |
+
|
133 |
+
|
134 |
+
class FFNLayer(nn.Module):
|
135 |
+
|
136 |
+
def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
|
137 |
+
activation="relu", normalize_before=False):
|
138 |
+
super().__init__()
|
139 |
+
# Implementation of Feedforward model
|
140 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
141 |
+
self.dropout = nn.Dropout(dropout)
|
142 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
143 |
+
|
144 |
+
self.norm = nn.LayerNorm(d_model)
|
145 |
+
|
146 |
+
self.activation = _get_activation_fn(activation)
|
147 |
+
self.normalize_before = normalize_before
|
148 |
+
|
149 |
+
self._reset_parameters()
|
150 |
+
|
151 |
+
def _reset_parameters(self):
|
152 |
+
for p in self.parameters():
|
153 |
+
if p.dim() > 1:
|
154 |
+
nn.init.xavier_uniform_(p)
|
155 |
+
|
156 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
157 |
+
return tensor if pos is None else tensor + pos
|
158 |
+
|
159 |
+
def forward_post(self, tgt):
|
160 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
161 |
+
tgt = tgt + self.dropout(tgt2)
|
162 |
+
tgt = self.norm(tgt)
|
163 |
+
return tgt
|
164 |
+
|
165 |
+
def forward_pre(self, tgt):
|
166 |
+
tgt2 = self.norm(tgt)
|
167 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
168 |
+
tgt = tgt + self.dropout(tgt2)
|
169 |
+
return tgt
|
170 |
+
|
171 |
+
def forward(self, tgt):
|
172 |
+
if self.normalize_before:
|
173 |
+
return self.forward_pre(tgt)
|
174 |
+
return self.forward_post(tgt)
|
175 |
+
|
176 |
+
|
177 |
+
def _get_activation_fn(activation):
|
178 |
+
"""Return an activation function given a string"""
|
179 |
+
if activation == "relu":
|
180 |
+
return F.relu
|
181 |
+
if activation == "gelu":
|
182 |
+
return F.gelu
|
183 |
+
if activation == "glu":
|
184 |
+
return F.glu
|
185 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
186 |
+
|
187 |
+
|
188 |
+
class MLP(nn.Module):
|
189 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
190 |
+
|
191 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
192 |
+
super().__init__()
|
193 |
+
self.num_layers = num_layers
|
194 |
+
h = [hidden_dim] * (num_layers - 1)
|
195 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
for i, layer in enumerate(self.layers):
|
199 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
200 |
+
return x
|
201 |
+
|
202 |
+
|
203 |
+
def _get_clones(module, N):
|
204 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
205 |
+
|
206 |
+
|
207 |
+
class Avism_COCO(nn.Module):
|
208 |
+
|
209 |
+
@configurable
|
210 |
+
def __init__(
|
211 |
+
self,
|
212 |
+
in_channels,
|
213 |
+
aux_loss,
|
214 |
+
*,
|
215 |
+
hidden_dim: int,
|
216 |
+
num_frame_queries: int,
|
217 |
+
num_queries: int,
|
218 |
+
nheads: int,
|
219 |
+
dim_feedforward: int,
|
220 |
+
enc_layers: int,
|
221 |
+
dec_layers: int,
|
222 |
+
enc_window_size: int,
|
223 |
+
pre_norm: bool,
|
224 |
+
enforce_input_project: bool,
|
225 |
+
num_frames: int,
|
226 |
+
num_classes: int,
|
227 |
+
clip_last_layer_num: bool,
|
228 |
+
conv_dim: int,
|
229 |
+
mask_dim: int,
|
230 |
+
sim_use_clip: list,
|
231 |
+
use_sim: bool,
|
232 |
+
):
|
233 |
+
"""
|
234 |
+
NOTE: this interface is experimental.
|
235 |
+
Args:
|
236 |
+
in_channels: channels of the input features
|
237 |
+
hidden_dim: Transformer feature dimension
|
238 |
+
num_queries: number of queries
|
239 |
+
nheads: number of heads
|
240 |
+
dim_feedforward: feature dimension in feedforward network
|
241 |
+
enc_layers: number of Transformer encoder layers
|
242 |
+
dec_layers: number of Transformer decoder layers
|
243 |
+
pre_norm: whether to use pre-LayerNorm or not
|
244 |
+
enforce_input_project: add input project 1x1 conv even if input
|
245 |
+
channels and hidden dim is identical
|
246 |
+
"""
|
247 |
+
super().__init__()
|
248 |
+
|
249 |
+
# define Transformer decoder here
|
250 |
+
self.num_heads = nheads
|
251 |
+
self.num_layers = dec_layers
|
252 |
+
self.transformer_self_attention_layers = nn.ModuleList()
|
253 |
+
self.transformer_cross_attention_layers = nn.ModuleList()
|
254 |
+
self.transformer_ffn_layers = nn.ModuleList()
|
255 |
+
self.num_frames = num_frames
|
256 |
+
self.num_classes = num_classes
|
257 |
+
self.clip_last_layer_num = clip_last_layer_num
|
258 |
+
|
259 |
+
self.enc_layers = enc_layers
|
260 |
+
self.window_size = enc_window_size
|
261 |
+
self.sim_use_clip = sim_use_clip
|
262 |
+
self.use_sim = use_sim
|
263 |
+
self.aux_loss = aux_loss
|
264 |
+
|
265 |
+
self.av_proj = nn.Linear(128, hidden_dim)
|
266 |
+
|
267 |
+
self.enc_layers = enc_layers
|
268 |
+
if enc_layers > 0:
|
269 |
+
self.enc_self_attn = nn.ModuleList()
|
270 |
+
self.enc_ffn = nn.ModuleList()
|
271 |
+
for _ in range(self.enc_layers):
|
272 |
+
self.enc_self_attn.append(
|
273 |
+
SelfAttentionLayer(
|
274 |
+
d_model=hidden_dim,
|
275 |
+
nhead=nheads,
|
276 |
+
dropout=0.0,
|
277 |
+
normalize_before=pre_norm,
|
278 |
+
),
|
279 |
+
)
|
280 |
+
self.enc_ffn.append(
|
281 |
+
FFNLayer(
|
282 |
+
d_model=hidden_dim,
|
283 |
+
dim_feedforward=dim_feedforward,
|
284 |
+
dropout=0.0,
|
285 |
+
normalize_before=pre_norm,
|
286 |
+
)
|
287 |
+
)
|
288 |
+
|
289 |
+
if enc_layers > 0:
|
290 |
+
self.enc_av_cross_attn = nn.ModuleList()
|
291 |
+
self.enc_av_ffn = nn.ModuleList()
|
292 |
+
for _ in range(self.enc_layers):
|
293 |
+
self.enc_av_cross_attn.append(
|
294 |
+
CrossAttentionLayer(
|
295 |
+
d_model=hidden_dim,
|
296 |
+
nhead=nheads,
|
297 |
+
dropout=0.0,
|
298 |
+
normalize_before=pre_norm,
|
299 |
+
),
|
300 |
+
)
|
301 |
+
self.enc_av_ffn.append(
|
302 |
+
FFNLayer(
|
303 |
+
d_model=hidden_dim,
|
304 |
+
dim_feedforward=dim_feedforward,
|
305 |
+
dropout=0.0,
|
306 |
+
normalize_before=pre_norm,
|
307 |
+
)
|
308 |
+
)
|
309 |
+
|
310 |
+
for _ in range(self.num_layers):
|
311 |
+
self.transformer_self_attention_layers.append(
|
312 |
+
SelfAttentionLayer(
|
313 |
+
d_model=hidden_dim,
|
314 |
+
nhead=nheads,
|
315 |
+
dropout=0.0,
|
316 |
+
normalize_before=pre_norm,
|
317 |
+
)
|
318 |
+
)
|
319 |
+
|
320 |
+
self.transformer_cross_attention_layers.append(
|
321 |
+
CrossAttentionLayer(
|
322 |
+
d_model=hidden_dim,
|
323 |
+
nhead=nheads,
|
324 |
+
dropout=0.0,
|
325 |
+
normalize_before=pre_norm,
|
326 |
+
)
|
327 |
+
)
|
328 |
+
|
329 |
+
self.transformer_ffn_layers.append(
|
330 |
+
FFNLayer(
|
331 |
+
d_model=hidden_dim,
|
332 |
+
dim_feedforward=dim_feedforward,
|
333 |
+
dropout=0.0,
|
334 |
+
normalize_before=pre_norm,
|
335 |
+
)
|
336 |
+
)
|
337 |
+
|
338 |
+
self.vita_mask_features = Conv2d(
|
339 |
+
conv_dim,
|
340 |
+
mask_dim,
|
341 |
+
kernel_size=1,
|
342 |
+
stride=1,
|
343 |
+
padding=0,
|
344 |
+
)
|
345 |
+
weight_init.c2_xavier_fill(self.vita_mask_features)
|
346 |
+
|
347 |
+
self.decoder_norm = nn.LayerNorm(hidden_dim)
|
348 |
+
|
349 |
+
self.num_queries = num_queries
|
350 |
+
# learnable query features
|
351 |
+
self.query_feat = nn.Embedding(num_queries, hidden_dim)
|
352 |
+
# learnable query p.e.
|
353 |
+
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
354 |
+
|
355 |
+
self.fq_pos = nn.Embedding(num_frame_queries, hidden_dim)
|
356 |
+
|
357 |
+
if in_channels != hidden_dim or enforce_input_project:
|
358 |
+
self.input_proj_dec = nn.Linear(hidden_dim, hidden_dim)
|
359 |
+
else:
|
360 |
+
self.input_proj_dec = nn.Sequential()
|
361 |
+
self.src_embed = nn.Identity()
|
362 |
+
|
363 |
+
self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
|
364 |
+
self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
|
365 |
+
if self.use_sim:
|
366 |
+
self.sim_embed_frame = nn.Linear(hidden_dim, hidden_dim)
|
367 |
+
if self.sim_use_clip:
|
368 |
+
self.sim_embed_clip = nn.Linear(hidden_dim, hidden_dim)
|
369 |
+
|
370 |
+
@classmethod
|
371 |
+
def from_config(cls, cfg, in_channels):
|
372 |
+
ret = {}
|
373 |
+
ret["in_channels"] = in_channels
|
374 |
+
|
375 |
+
ret["hidden_dim"] = cfg.MODEL.AVISM.HIDDEN_DIM
|
376 |
+
ret["num_frame_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
|
377 |
+
ret["num_queries"] = cfg.MODEL.AVISM.NUM_OBJECT_QUERIES
|
378 |
+
# Transformer parameters:
|
379 |
+
ret["nheads"] = cfg.MODEL.AVISM.NHEADS
|
380 |
+
ret["dim_feedforward"] = cfg.MODEL.AVISM.DIM_FEEDFORWARD
|
381 |
+
|
382 |
+
assert cfg.MODEL.AVISM.DEC_LAYERS >= 1
|
383 |
+
ret["enc_layers"] = cfg.MODEL.AVISM.ENC_LAYERS
|
384 |
+
ret["dec_layers"] = cfg.MODEL.AVISM.DEC_LAYERS
|
385 |
+
ret["enc_window_size"] = cfg.MODEL.AVISM.ENC_WINDOW_SIZE
|
386 |
+
ret["pre_norm"] = cfg.MODEL.AVISM.PRE_NORM
|
387 |
+
ret["enforce_input_project"] = cfg.MODEL.AVISM.ENFORCE_INPUT_PROJ
|
388 |
+
|
389 |
+
ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
|
390 |
+
ret["num_frames"] = cfg.INPUT.SAMPLING_FRAME_NUM
|
391 |
+
ret["clip_last_layer_num"] = cfg.MODEL.AVISM.LAST_LAYER_NUM
|
392 |
+
|
393 |
+
ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
|
394 |
+
ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
|
395 |
+
ret["sim_use_clip"] = cfg.MODEL.AVISM.SIM_USE_CLIP
|
396 |
+
ret["use_sim"] = cfg.MODEL.AVISM.SIM_WEIGHT > 0.0
|
397 |
+
|
398 |
+
return ret
|
399 |
+
|
400 |
+
def forward(self, frame_query, audio_features):
|
401 |
+
"""
|
402 |
+
L: Number of Layers.
|
403 |
+
B: Batch size.
|
404 |
+
T: Temporal window size. Number of frames per video.
|
405 |
+
C: Channel size.
|
406 |
+
fQ: Number of frame-wise queries from IFC.
|
407 |
+
cQ: Number of clip-wise queries to decode Q.
|
408 |
+
"""
|
409 |
+
if not self.training:
|
410 |
+
frame_query = frame_query[[-1]]
|
411 |
+
|
412 |
+
L, BT, fQ, C = frame_query.shape
|
413 |
+
B = BT // self.num_frames if self.training else 1
|
414 |
+
T = self.num_frames if self.training else BT // B
|
415 |
+
|
416 |
+
frame_query = frame_query.reshape(L * B, T, fQ, C)
|
417 |
+
frame_query = frame_query.permute(1, 2, 0, 3).contiguous()
|
418 |
+
frame_query = self.input_proj_dec(frame_query) # T, fQ, LB, C
|
419 |
+
|
420 |
+
audio_feat = self.av_proj(audio_features) # T, C
|
421 |
+
audio_feat = audio_feat[:, None, None, :].repeat(1, fQ, L * B, 1)
|
422 |
+
|
423 |
+
if self.window_size > 0:
|
424 |
+
pad = int(ceil(T / self.window_size)) * self.window_size - T
|
425 |
+
_T = pad + T
|
426 |
+
frame_query = F.pad(frame_query, (0, 0, 0, 0, 0, 0, 0, pad)) # _T, fQ, LB, C
|
427 |
+
audio_feat = F.pad(audio_feat, (0, 0, 0, 0, 0, 0, 0, pad))
|
428 |
+
enc_mask = frame_query.new_ones(L * B, _T).bool() # LB, _T
|
429 |
+
enc_mask[:, :T] = False
|
430 |
+
else:
|
431 |
+
enc_mask = None
|
432 |
+
|
433 |
+
frame_query = self.encode_frame_query(frame_query, enc_mask)
|
434 |
+
|
435 |
+
# audio
|
436 |
+
av_feat = self.encode_av_fusion(frame_query, enc_mask, audio_feat)
|
437 |
+
|
438 |
+
frame_query = frame_query[:T].flatten(0, 1) # TfQ, LB, C
|
439 |
+
av_feat = av_feat[:T].flatten(0, 1)
|
440 |
+
frame_query = frame_query + av_feat
|
441 |
+
|
442 |
+
if self.use_sim:
|
443 |
+
pred_fq_embed = self.sim_embed_frame(frame_query) # TfQ, LB, C
|
444 |
+
pred_fq_embed = pred_fq_embed.transpose(0, 1).reshape(L, B, T, fQ, C)
|
445 |
+
else:
|
446 |
+
pred_fq_embed = None
|
447 |
+
|
448 |
+
src = self.src_embed(frame_query) # TfQ, LB, C
|
449 |
+
dec_pos = self.fq_pos.weight[None, :, None, :].repeat(T, 1, L * B, 1).flatten(0, 1) # TfQ, LB, C
|
450 |
+
|
451 |
+
# QxNxC
|
452 |
+
query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, L * B, 1) # cQ, LB, C
|
453 |
+
output = self.query_feat.weight.unsqueeze(1).repeat(1, L * B, 1) # cQ, LB, C
|
454 |
+
|
455 |
+
decoder_outputs = []
|
456 |
+
for i in range(self.num_layers):
|
457 |
+
# attention: cross-attention first
|
458 |
+
output = self.transformer_cross_attention_layers[i](
|
459 |
+
output, src,
|
460 |
+
memory_mask=None,
|
461 |
+
memory_key_padding_mask=None,
|
462 |
+
pos=dec_pos, query_pos=query_embed
|
463 |
+
)
|
464 |
+
|
465 |
+
output = self.transformer_self_attention_layers[i](
|
466 |
+
output, tgt_mask=None,
|
467 |
+
tgt_key_padding_mask=None,
|
468 |
+
query_pos=query_embed
|
469 |
+
)
|
470 |
+
|
471 |
+
# FFN
|
472 |
+
output = self.transformer_ffn_layers[i](
|
473 |
+
output
|
474 |
+
)
|
475 |
+
|
476 |
+
if (self.training and self.aux_loss) or (i == self.num_layers - 1):
|
477 |
+
dec_out = self.decoder_norm(output) # cQ, LB, C
|
478 |
+
dec_out = dec_out.transpose(0, 1) # LB, cQ, C
|
479 |
+
decoder_outputs.append(dec_out.view(L, B, self.num_queries, C))
|
480 |
+
|
481 |
+
decoder_outputs = torch.stack(decoder_outputs, dim=0) # D, L, B, cQ, C
|
482 |
+
|
483 |
+
pred_cls = self.class_embed(decoder_outputs)
|
484 |
+
pred_mask_embed = self.mask_embed(decoder_outputs)
|
485 |
+
if self.use_sim and self.sim_use_clip:
|
486 |
+
pred_cq_embed = self.sim_embed_clip(decoder_outputs)
|
487 |
+
else:
|
488 |
+
pred_cq_embed = [None] * self.num_layers
|
489 |
+
|
490 |
+
out = {
|
491 |
+
'pred_logits': pred_cls[-1],
|
492 |
+
'pred_mask_embed': pred_mask_embed[-1],
|
493 |
+
'pred_fq_embed': pred_fq_embed,
|
494 |
+
'pred_cq_embed': pred_cq_embed[-1],
|
495 |
+
'aux_outputs': self._set_aux_loss(
|
496 |
+
pred_cls, pred_mask_embed, pred_cq_embed, pred_fq_embed
|
497 |
+
)
|
498 |
+
}
|
499 |
+
return out
|
500 |
+
|
501 |
+
@torch.jit.unused
|
502 |
+
def _set_aux_loss(
|
503 |
+
self, outputs_cls, outputs_mask_embed, outputs_cq_embed, outputs_fq_embed
|
504 |
+
):
|
505 |
+
return [{"pred_logits": a, "pred_mask_embed": b, "pred_cq_embed": c, "pred_fq_embed": outputs_fq_embed}
|
506 |
+
for a, b, c in zip(outputs_cls[:-1], outputs_mask_embed[:-1], outputs_cq_embed[:-1])]
|
507 |
+
|
508 |
+
def encode_frame_query(self, frame_query, attn_mask):
|
509 |
+
"""
|
510 |
+
input shape (frame_query) : T, fQ, LB, C
|
511 |
+
output shape (frame_query) : T, fQ, LB, C
|
512 |
+
"""
|
513 |
+
|
514 |
+
# Not using window-based attention if self.window_size == 0.
|
515 |
+
if self.window_size == 0:
|
516 |
+
return_shape = frame_query.shape # T, fQ, LB, C
|
517 |
+
frame_query = frame_query.flatten(0, 1) # TfQ, LB, C
|
518 |
+
|
519 |
+
for i in range(self.enc_layers):
|
520 |
+
frame_query = self.enc_self_attn[i](frame_query)
|
521 |
+
frame_query = self.enc_ffn[i](frame_query)
|
522 |
+
|
523 |
+
frame_query = frame_query.view(return_shape)
|
524 |
+
return frame_query
|
525 |
+
# Using window-based attention if self.window_size > 0.
|
526 |
+
else:
|
527 |
+
T, fQ, LB, C = frame_query.shape
|
528 |
+
W = self.window_size
|
529 |
+
Nw = T // W
|
530 |
+
half_W = int(ceil(W / 2))
|
531 |
+
|
532 |
+
window_mask = attn_mask.view(LB * Nw, W)[..., None].repeat(1, 1, fQ).flatten(1)
|
533 |
+
|
534 |
+
_attn_mask = torch.roll(attn_mask, half_W, 1)
|
535 |
+
_attn_mask = _attn_mask.view(LB, Nw, W)[..., None].repeat(1, 1, 1, W) # LB, Nw, W, W
|
536 |
+
_attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose(-2, -1)
|
537 |
+
_attn_mask[:, -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose(-2, -1)
|
538 |
+
_attn_mask[:, 0, :half_W, half_W:] = True
|
539 |
+
_attn_mask[:, 0, half_W:, :half_W] = True
|
540 |
+
_attn_mask = _attn_mask.view(LB * Nw, 1, W, 1, W, 1).repeat(1, self.num_heads, 1, fQ, 1, fQ).view(
|
541 |
+
LB * Nw * self.num_heads, W * fQ, W * fQ)
|
542 |
+
shift_window_mask = _attn_mask.float() * -1000
|
543 |
+
|
544 |
+
for layer_idx in range(self.enc_layers):
|
545 |
+
if self.training or layer_idx % 2 == 0:
|
546 |
+
frame_query = self._window_attn(frame_query, window_mask, layer_idx)
|
547 |
+
else:
|
548 |
+
frame_query = self._shift_window_attn(frame_query, shift_window_mask, layer_idx)
|
549 |
+
return frame_query
|
550 |
+
|
551 |
+
def _window_attn(self, frame_query, attn_mask, layer_idx):
|
552 |
+
T, fQ, LB, C = frame_query.shape
|
553 |
+
# LBN, WTfQ = attn_mask.shape
|
554 |
+
|
555 |
+
W = self.window_size
|
556 |
+
Nw = T // W
|
557 |
+
|
558 |
+
frame_query = frame_query.view(Nw, W, fQ, LB, C)
|
559 |
+
frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
560 |
+
|
561 |
+
frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_key_padding_mask=attn_mask)
|
562 |
+
frame_query = self.enc_ffn[layer_idx](frame_query)
|
563 |
+
frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
564 |
+
|
565 |
+
return frame_query
|
566 |
+
|
567 |
+
def _shift_window_attn(self, frame_query, attn_mask, layer_idx):
|
568 |
+
T, fQ, LB, C = frame_query.shape
|
569 |
+
# LBNH, WfQ, WfQ = attn_mask.shape
|
570 |
+
|
571 |
+
W = self.window_size
|
572 |
+
Nw = T // W
|
573 |
+
half_W = int(ceil(W / 2))
|
574 |
+
|
575 |
+
frame_query = torch.roll(frame_query, half_W, 0)
|
576 |
+
frame_query = frame_query.view(Nw, W, fQ, LB, C)
|
577 |
+
frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
578 |
+
|
579 |
+
frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_mask=attn_mask)
|
580 |
+
frame_query = self.enc_ffn[layer_idx](frame_query)
|
581 |
+
frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
582 |
+
|
583 |
+
frame_query = torch.roll(frame_query, -half_W, 0)
|
584 |
+
|
585 |
+
return frame_query
|
586 |
+
|
587 |
+
def encode_av_fusion(self, frame_query, attn_mask, audio_feats):
|
588 |
+
"""
|
589 |
+
input shape (frame_query) : T, fQ, LB, C
|
590 |
+
output shape (frame_query) : T, fQ, LB, C
|
591 |
+
"""
|
592 |
+
|
593 |
+
# Not using window-based attention if self.window_size == 0.
|
594 |
+
if self.window_size == 0:
|
595 |
+
return_shape = frame_query.shape # T, fQ, LB, C
|
596 |
+
frame_query = frame_query.flatten(0, 1) # TfQ, LB, C
|
597 |
+
audio_feats = audio_feats.flatten(0, 1)
|
598 |
+
|
599 |
+
for i in range(self.enc_layers):
|
600 |
+
audio_feats = self.enc_av_cross_attn[i](audio_feats, frame_query)
|
601 |
+
audio_feats = self.enc_av_ffn[i](audio_feats)
|
602 |
+
|
603 |
+
audio_feats = audio_feats.view(return_shape)
|
604 |
+
return audio_feats
|
605 |
+
# Using window-based attention if self.window_size > 0.
|
606 |
+
else:
|
607 |
+
T, fQ, LB, C = frame_query.shape
|
608 |
+
W = self.window_size
|
609 |
+
Nw = T // W
|
610 |
+
half_W = int(ceil(W / 2))
|
611 |
+
|
612 |
+
window_mask = attn_mask.view(LB * Nw, W)[..., None].repeat(1, 1, fQ).flatten(1)
|
613 |
+
|
614 |
+
_attn_mask = torch.roll(attn_mask, half_W, 1)
|
615 |
+
_attn_mask = _attn_mask.view(LB, Nw, W)[..., None].repeat(1, 1, 1, W) # LB, Nw, W, W
|
616 |
+
_attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose(-2, -1)
|
617 |
+
_attn_mask[:, -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose(-2, -1)
|
618 |
+
_attn_mask[:, 0, :half_W, half_W:] = True
|
619 |
+
_attn_mask[:, 0, half_W:, :half_W] = True
|
620 |
+
_attn_mask = _attn_mask.view(LB * Nw, 1, W, 1, W, 1).repeat(1, self.num_heads, 1, fQ, 1, fQ).view(
|
621 |
+
LB * Nw * self.num_heads, W * fQ, W * fQ)
|
622 |
+
shift_window_mask = _attn_mask.float() * -1000
|
623 |
+
|
624 |
+
for layer_idx in range(self.enc_layers):
|
625 |
+
if layer_idx % 2 == 0:
|
626 |
+
frame_query, audio_feats = self._window_av_attn(frame_query, window_mask, layer_idx, audio_feats)
|
627 |
+
else:
|
628 |
+
frame_query, audio_feats = self._shift_window_av_attn(frame_query, shift_window_mask, layer_idx, audio_feats)
|
629 |
+
return audio_feats
|
630 |
+
|
631 |
+
def _window_av_attn(self, frame_query, attn_mask, layer_idx, audio_feats):
|
632 |
+
T, fQ, LB, C = frame_query.shape
|
633 |
+
|
634 |
+
W = self.window_size
|
635 |
+
Nw = T // W
|
636 |
+
|
637 |
+
frame_query = frame_query.view(Nw, W, fQ, LB, C)
|
638 |
+
frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
639 |
+
|
640 |
+
audio_feats = audio_feats.view(Nw, W, fQ, LB, C)
|
641 |
+
audio_feats = audio_feats.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
642 |
+
|
643 |
+
audio_feats = self.enc_av_cross_attn[layer_idx](audio_feats, frame_query, memory_key_padding_mask=attn_mask)
|
644 |
+
audio_feats = self.enc_av_ffn[layer_idx](audio_feats)
|
645 |
+
|
646 |
+
frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
647 |
+
audio_feats = audio_feats.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
648 |
+
|
649 |
+
return frame_query, audio_feats
|
650 |
+
|
651 |
+
def _shift_window_av_attn(self, frame_query, attn_mask, layer_idx, audio_feats):
|
652 |
+
T, fQ, LB, C = frame_query.shape
|
653 |
+
|
654 |
+
W = self.window_size
|
655 |
+
Nw = T // W
|
656 |
+
half_W = int(ceil(W / 2))
|
657 |
+
|
658 |
+
frame_query = torch.roll(frame_query, half_W, 0)
|
659 |
+
frame_query = frame_query.view(Nw, W, fQ, LB, C)
|
660 |
+
frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
661 |
+
|
662 |
+
audio_feats = torch.roll(audio_feats, half_W, 0)
|
663 |
+
audio_feats = audio_feats.view(Nw, W, fQ, LB, C)
|
664 |
+
audio_feats = audio_feats.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
|
665 |
+
|
666 |
+
audio_feats = self.enc_av_cross_attn[layer_idx](audio_feats, frame_query, memory_mask=attn_mask)
|
667 |
+
audio_feats = self.enc_av_ffn[layer_idx](audio_feats)
|
668 |
+
|
669 |
+
frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
670 |
+
frame_query = torch.roll(frame_query, -half_W, 0)
|
671 |
+
|
672 |
+
audio_feats = audio_feats.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
|
673 |
+
audio_feats = torch.roll(audio_feats, -half_W, 0)
|
674 |
+
|
675 |
+
return frame_query, audio_feats
|