ruohguo committed (verified)
Commit b80ae90 · 1 Parent(s): 35d18ec

Upload 117 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. LICENSE +21 -0
  3. README.md +141 -3
  4. assets/teaser_figure.png +3 -0
  5. avism/__init__.py +12 -0
  6. avism/avism_model.py +460 -0
  7. avism/avism_model_coco.py +460 -0
  8. avism/config.py +56 -0
  9. avism/data/__init__.py +4 -0
  10. avism/data/augmentation.py +623 -0
  11. avism/data/avis_eval.py +203 -0
  12. avism/data/aviseval/__init__.py +5 -0
  13. avism/data/aviseval/_timing.py +65 -0
  14. avism/data/aviseval/datasets/__init__.py +1 -0
  15. avism/data/aviseval/datasets/_base_dataset.py +326 -0
  16. avism/data/aviseval/datasets/avis.py +367 -0
  17. avism/data/aviseval/eval.py +209 -0
  18. avism/data/aviseval/metrics/__init__.py +11 -0
  19. avism/data/aviseval/metrics/_base_metric.py +132 -0
  20. avism/data/aviseval/metrics/av_loc.py +191 -0
  21. avism/data/aviseval/metrics/avisa.py +190 -0
  22. avism/data/aviseval/metrics/clear.py +186 -0
  23. avism/data/aviseval/metrics/count.py +44 -0
  24. avism/data/aviseval/metrics/hota.py +202 -0
  25. avism/data/aviseval/metrics/identity.py +135 -0
  26. avism/data/aviseval/metrics/ideucl.py +135 -0
  27. avism/data/aviseval/metrics/j_and_f.py +310 -0
  28. avism/data/aviseval/metrics/track_map.py +462 -0
  29. avism/data/aviseval/metrics/vace.py +131 -0
  30. avism/data/aviseval/plotting.py +230 -0
  31. avism/data/aviseval/utils.py +146 -0
  32. avism/data/build.py +247 -0
  33. avism/data/dataset_mapper.py +272 -0
  34. avism/data/datasets/__init__.py +3 -0
  35. avism/data/datasets/avis.py +209 -0
  36. avism/data/datasets/avis_api/__init__.py +1 -0
  37. avism/data/datasets/avis_api/avos.py +277 -0
  38. avism/data/datasets/avis_api/avoseval.py +559 -0
  39. avism/data/datasets/builtin.py +29 -0
  40. avism/data/datasets/extract_audio_feat/audio_feature_extractor.py +77 -0
  41. avism/data/datasets/extract_audio_feat/mel_features.py +233 -0
  42. avism/data/datasets/extract_audio_feat/vggish_input.py +103 -0
  43. avism/data/datasets/extract_audio_feat/vggish_params.py +53 -0
  44. avism/data/datasets/extract_audio_feat/vggish_slim.py +134 -0
  45. avism/modeling/__init__.py +0 -0
  46. avism/modeling/avism_criterion.py +335 -0
  47. avism/modeling/avism_matcher.py +194 -0
  48. avism/modeling/transformer_decoder/__init__.py +1 -0
  49. avism/modeling/transformer_decoder/avism.py +675 -0
  50. avism/modeling/transformer_decoder/avism_coco.py +675 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/teaser_figure.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 GeWu-Lab
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,141 @@
- ---
- license: mit
- ---
+ # Audio-Visual Instance Segmentation
+
+
+ [![AVIS](https://img.shields.io/badge/Paper-AVIS-2b9348.svg?logo=arXiv)](https://arxiv.org/abs/2412.03069)
+ [![Project Page](https://img.shields.io/badge/Project_page-Visualizations-blue)](https://ruohaoguo.github.io/avis/)
+ [![Dataset](https://img.shields.io/badge/Dataset-Download-yellow)](https://1drv.ms/u/c/3c9af704fb61931d/EVOs609SGMxLsbvVzVJHAa4Bmnu4GVZGjqYHQxDz0NKTew?e=WQU2Uf)
+
+ Ruohao Guo, Xianghua Ying*, Yaru Chen, Dantong Niu, Guangyao Li, Liao Qu, Yanyu Qi, Jinxing Zhou, Bowei Xing, Wenzhen Yue, Ji Shi, Qixun Wang, Peiliang Zhang, Buwen Liang
+
+ ## 📰 News
+
+ 🔥**2025.03.01**: Code and checkpoints are released!
+
+ 🔥**2025.02.27**: AVIS was accepted to **CVPR 2025**! 🎉🎉🎉
+
+ 🔥**2024.11.12**: Our [project page](https://ruohaoguo.github.io/avis/) is now available!
+
+ 🔥**2024.11.11**: The AVISeg dataset has been uploaded to [OneDrive](https://1drv.ms/u/c/3c9af704fb61931d/EVOs609SGMxLsbvVzVJHAa4Bmnu4GVZGjqYHQxDz0NKTew?e=WQU2Uf); feel free to download and use it!
+
+
+ ## 🌿 Introduction
+
+ In this paper, we propose a new multi-modal task, termed audio-visual instance segmentation (AVIS), which aims to simultaneously identify, segment, and track individual sounding object instances in audible videos. To facilitate this research, we introduce a high-quality benchmark named AVISeg, containing over 90K instance masks from 26 semantic categories in 926 long videos. Additionally, we propose a strong baseline model for this task. Our model first localizes sound sources within each frame and condenses object-specific contexts into concise tokens. It then builds long-range audio-visual dependencies between these tokens using window-based attention, and tracks sounding objects across the entire video sequence.
+
+ <div align='center'>
+ <img src="./assets/teaser_figure.png" class="interpolation-image" alt="Teaser figure" height="50%" width="100%" />
+ </div>
+
+
+
+ ## ⚙️ Installation
+
+ ```bash
+ conda create --name avism python=3.8 -y
+ conda activate avism
+
+ conda install pytorch==1.9.0 torchvision==0.10.0 cudatoolkit=11.1 -c pytorch -c nvidia
+ pip install -U opencv-python
+
+ cd ./AVISM
+ git clone https://github.com/facebookresearch/detectron2
+ cd detectron2
+ pip install -e .
+
+ cd ../
+ pip install -r requirements.txt
+ cd mask2former/modeling/pixel_decoder/ops
+ sh make.sh
+ ```
+
+ ## 🤗 Setup
+
+ ### Datasets
+
+ Download and unzip the datasets from [OneDrive](https://1drv.ms/u/c/3c9af704fb61931d/EVOs609SGMxLsbvVzVJHAa4Bmnu4GVZGjqYHQxDz0NKTew?e=WQU2Uf) and put them in `./datasets`.
+
+ ### Pretrained Backbones
+ Download and unzip the pre-trained backbones from [OneDrive](https://1drv.ms/u/c/3c9af704fb61931d/ETDDliQ8zZFGmYxlLVPyi3sBis_fdjX0w8mJhyQnYVSdXA?e=Wt7pUb) and put them in `./pre_models`.
+
+ ### Checkpoints
+
+ Download the following checkpoints and put them in `./checkpoints`.
+
+ <table>
+ <tr>
+ <th style="width: 150px;">Backbone</th>
+ <th>Pre-trained Datasets</th>
+ <th>FSLA</th>
+ <th>HOTA</th>
+ <th>mAP</th>
+ <th>Model Weight</th>
+ </tr>
+ <tr>
+ <td align="center">ResNet-50</td>
+ <td align="center">ImageNet</td>
+ <td align="center">42.78</td>
+ <td align="center">61.73</td>
+ <td align="center">40.57</td>
+ <td align="center"><a href="https://1drv.ms/u/c/3c9af704fb61931d/EYyAuCNpRjxDqEohJfoDLO0BYgw0lbwKqQ1lwVXe_kIPVQ?e=PeRlyx">AVISM_R50_IN.pth</a></td>
+ </tr>
+ <tr>
+ <td align="center">ResNet-50</td>
+ <td align="center">ImageNet & COCO</td>
+ <td align="center">44.42</td>
+ <td align="center">64.52</td>
+ <td align="center">45.04</td>
+ <td align="center"><a href="https://1drv.ms/u/c/3c9af704fb61931d/EX0snZsxQwdBswQFdG4sc9kBd-Bd7lw5zaTGR6FvrSxinQ?e=bdZF5G">AVISM_R50_COCO.pth</a></td>
+ </tr>
+ <tr>
+ <td align="center">Swin-L</td>
+ <td align="center">ImageNet</td>
+ <td align="center">49.15</td>
+ <td align="center">68.81</td>
+ <td align="center">49.06</td>
+ <td align="center"><a href="https://1drv.ms/u/c/3c9af704fb61931d/EV4V5Bh5AqVBhLVMM1ucdN0BuOZgHu17W3JDGjKDMLZ1bg?e=hF8umh">AVISM_SwinL_IN.pth</a></td>
+ </tr>
+ <tr>
+ <td align="center">Swin-L</td>
+ <td align="center">ImageNet & COCO</td>
+ <td align="center">52.49</td>
+ <td align="center">71.13</td>
+ <td align="center">53.46</td>
+ <td align="center"><a href="https://1drv.ms/u/c/3c9af704fb61931d/EXuM4cUxPTpEk1M7FoPqtNEBi47L7uR-ZlnqDCJscmNsiA?e=7prFiN">AVISM_SwinL_COCO.pth</a></td>
+ </tr>
+ </table>
+
+
+ ## 📌 Getting Started
+
+ ### Training
+ ```bash
+ python train_net.py --num-gpus 2 --config-file configs/avism/R50/avism_R50_IN.yaml
+ ```
+
+ ### Evaluation
+ ```bash
+ python train_net.py --config-file configs/avism/R50/avism_R50_IN.yaml --eval-only MODEL.WEIGHTS checkpoints/AVISM_R50_IN.pth
+ ```
+
+ ### Demo
+ ```bash
+ python demo_video/demo.py --config-file configs/avism/R50/avism_R50_IN.yaml --opts MODEL.WEIGHTS checkpoints/AVISM_R50_IN.pth
+ ```
+
+ ## Acknowledgement
+
+ We thank the authors of [Detectron2](https://github.com/facebookresearch/detectron2), [Mask2Former](https://github.com/facebookresearch/MaskFormer) and [VITA](https://github.com/sukjunhwang/VITA) for their great work.
+
+
+ ## 📄 Citation
+
+ If our work assists your research, feel free to give us a star ⭐ or cite us using:
+
+ ```bibtex
+ @article{guo2023audio,
+ title={Audio-Visual Instance Segmentation},
+ author={Guo, Ruohao and Ying, Xianghua and Chen, Yaru and Niu, Dantong and Li, Guangyao and Qu, Liao and Qi, Yanyu and Zhou, Jinxing and Xing, Bowei and Yue, Wenzhen and Shi, Ji and Wang, Qixun and Zhang, Peiliang and Liang, Buwen},
+ journal={arXiv preprint arXiv:2310.18709},
+ year={2023}
+ }
+ ```
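To make the README's description of the baseline concrete (object-specific tokens condensed per frame, fused with audio, then linked by window-based attention over time), here is a minimal, self-contained sketch of that mechanism. It is illustrative only: the tensor shapes, the 128-d VGGish-style audio features, and the plain `TransformerEncoderLayer` are assumptions, not the repository's actual modules (those live in `avism/modeling/transformer_decoder/avism.py`).

```python
# Illustrative sketch only, not the repository's implementation.
import torch
import torch.nn as nn

T, fQ, C, W = 24, 100, 256, 6            # frames, tokens per frame, channels, window size
frame_tokens = torch.randn(T, fQ, C)     # object-specific tokens condensed per frame
audio_feats = torch.randn(T, 128)        # per-frame audio embeddings (VGGish-style, assumed)

audio_proj = nn.Linear(128, C)
temporal_attn = nn.TransformerEncoderLayer(d_model=C, nhead=8, batch_first=True)

# Fuse audio into every token of its frame, then attend within windows of W frames,
# which gives the tokens long-range audio-visual context at modest cost.
tokens = frame_tokens + audio_proj(audio_feats)[:, None, :]   # (T, fQ, C)
windows = tokens.reshape(T // W, W * fQ, C)                   # (T/W, W*fQ, C)
linked = temporal_attn(windows).reshape(T, fQ, C)             # back to per-frame tokens
print(linked.shape)                                           # torch.Size([24, 100, 256])
```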
assets/teaser_figure.png ADDED

Git LFS Details

  • SHA256: 12ed508f8fa304c94b3ef327e7b3c938567c3f4b72375d7edd6c7bd4e2965035
  • Pointer size: 132 Bytes
  • Size of remote file: 8.13 MB
avism/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # model code
+ from . import modeling
+
+ # config
+ from .config import add_avism_config
+
+ # models
+ from .avism_model import AVISM
+ from .avism_model_coco import AVISM_COCO
+
+ # video
+ from .data import *
avism/avism_model.py ADDED
@@ -0,0 +1,460 @@
1
+ from typing import Tuple
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from detectron2.config import configurable
9
+ from detectron2.data import MetadataCatalog
10
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
11
+ from detectron2.modeling.backbone import Backbone
12
+ from detectron2.structures import Boxes, ImageList, Instances, BitMasks
13
+ from detectron2.utils.memory import retry_if_cuda_oom
14
+
15
+ from mask2former.modeling.criterion import SetCriterion
16
+ from mask2former.modeling.matcher import HungarianMatcher
17
+ from .modeling.avism_criterion import AvismSetCriterion
18
+ from .modeling.avism_matcher import AvismHungarianMatcher
19
+ from .modeling.transformer_decoder.avism import Avism
20
+
21
+
22
+ @META_ARCH_REGISTRY.register()
23
+ class AVISM(nn.Module):
24
+ """
25
+ Main class for mask classification semantic segmentation architectures.
26
+ """
27
+
28
+ @configurable
29
+ def __init__(
30
+ self,
31
+ *,
32
+ backbone: Backbone,
33
+ sem_seg_head: nn.Module,
34
+ criterion: nn.Module,
35
+ num_queries: int,
36
+ object_mask_threshold: float,
37
+ overlap_threshold: float,
38
+ metadata,
39
+ size_divisibility: int,
40
+ pixel_mean: Tuple[float],
41
+ pixel_std: Tuple[float],
42
+ # inference
43
+ test_topk_per_image: int,
44
+ # avism
45
+ avism_module: nn.Module,
46
+ avism_criterion: nn.Module,
47
+ num_frames: int,
48
+ num_classes: int,
49
+ is_multi_cls: bool,
50
+ apply_cls_thres: float,
51
+ freeze_detector: bool,
52
+ test_run_chunk_size: int,
53
+ test_interpolate_chunk_size: int,
54
+ is_coco: bool,
55
+ ):
56
+ """
57
+ Args:
58
+ backbone: a backbone module, must follow detectron2's backbone interface
59
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
60
+ criterion: a module that defines the loss
61
+ num_queries: int, number of queries
62
+ object_mask_threshold: float, threshold to filter query based on classification score
63
+ for panoptic segmentation inference
64
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
65
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
66
+ segmentation inference
67
+ size_divisibility: Some backbones require the input height and width to be divisible by a
68
+ specific integer. We can use this to override such requirement.
69
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
70
+ the per-channel mean and std to be used to normalize the input image
71
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
72
+ """
73
+ super().__init__()
74
+ self.backbone = backbone
75
+ self.sem_seg_head = sem_seg_head
76
+ self.criterion = criterion
77
+ self.num_queries = num_queries
78
+ self.overlap_threshold = overlap_threshold
79
+ self.object_mask_threshold = object_mask_threshold
80
+ self.metadata = metadata
81
+ if size_divisibility < 0:
82
+ # use backbone size_divisibility if not set
83
+ size_divisibility = self.backbone.size_divisibility
84
+ self.size_divisibility = size_divisibility
85
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
86
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
87
+
88
+ # additional args
89
+ self.test_topk_per_image = test_topk_per_image
90
+
91
+ # avism hyper-parameters
92
+ self.num_frames = num_frames
93
+ self.num_classes = num_classes
94
+ self.avism_module = avism_module
95
+ self.avism_criterion = avism_criterion
96
+ self.is_multi_cls = is_multi_cls
97
+ self.apply_cls_thres = apply_cls_thres
98
+
99
+ if freeze_detector:
100
+ for name, p in self.named_parameters():
101
+ if not "avism_module" in name:
102
+ p.requires_grad_(False)
103
+ self.test_run_chunk_size = test_run_chunk_size
104
+ self.test_interpolate_chunk_size = test_interpolate_chunk_size
105
+
106
+ self.is_coco = is_coco
107
+
108
+ @classmethod
109
+ def from_config(cls, cfg):
110
+ backbone = build_backbone(cfg)
111
+ sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
112
+
113
+ # Loss parameters:
114
+ deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
115
+ no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
116
+ avism_deep_supervision = cfg.MODEL.AVISM.DEEP_SUPERVISION
117
+
118
+ # loss weights
119
+ class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
120
+ dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
121
+ mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
122
+ sim_weight = cfg.MODEL.AVISM.SIM_WEIGHT
123
+
124
+ # building criterion
125
+ matcher = HungarianMatcher(
126
+ cost_class=class_weight,
127
+ cost_mask=mask_weight,
128
+ cost_dice=dice_weight,
129
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
130
+ )
131
+
132
+ weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
133
+
134
+ if deep_supervision:
135
+ dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
136
+ aux_weight_dict = {}
137
+ for i in range(dec_layers - 1):
138
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
139
+ weight_dict.update(aux_weight_dict)
140
+
141
+ losses = ["labels", "masks"]
142
+
143
+ criterion = SetCriterion(
144
+ sem_seg_head.num_classes,
145
+ matcher=matcher,
146
+ weight_dict=weight_dict,
147
+ eos_coef=no_object_weight,
148
+ losses=losses,
149
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
150
+ oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
151
+ importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
152
+ avism_last_layer_num=cfg.MODEL.AVISM.LAST_LAYER_NUM,
153
+ )
154
+
155
+ # Avism
156
+ num_classes = sem_seg_head.num_classes
157
+ hidden_dim = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
158
+ avism_module = Avism(cfg=cfg, in_channels=hidden_dim, aux_loss=avism_deep_supervision)
159
+
160
+ # building criterion for avism inference
161
+ avism_matcher = AvismHungarianMatcher(
162
+ cost_class=class_weight,
163
+ cost_mask=mask_weight,
164
+ cost_dice=dice_weight,
165
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
166
+ )
167
+ avism_weight_dict = {
168
+ "loss_avism_ce": class_weight, "loss_avism_mask": mask_weight, "loss_avism_dice": dice_weight
169
+ }
170
+ if sim_weight > 0.0:
171
+ avism_weight_dict["loss_avism_sim"] = sim_weight
172
+
173
+ if avism_deep_supervision:
174
+ avism_dec_layers = cfg.MODEL.AVISM.DEC_LAYERS
175
+ aux_weight_dict = {}
176
+ for i in range(avism_dec_layers - 1):
177
+ aux_weight_dict.update({k + f"_{i}": v for k, v in avism_weight_dict.items()})
178
+ avism_weight_dict.update(aux_weight_dict)
179
+ avism_losses = ["avism_labels", "avism_masks"]
180
+ if sim_weight > 0.0:
181
+ avism_losses.append("fg_sim")
182
+
183
+ avism_criterion = AvismSetCriterion(
184
+ num_classes,
185
+ matcher=avism_matcher,
186
+ weight_dict=avism_weight_dict,
187
+ eos_coef=cfg.MODEL.AVISM.NO_OBJECT_WEIGHT,
188
+ losses=avism_losses,
189
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
190
+ oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
191
+ importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
192
+ sim_use_clip=cfg.MODEL.AVISM.SIM_USE_CLIP,
193
+ )
194
+
195
+ return {
196
+ "backbone": backbone,
197
+ "sem_seg_head": sem_seg_head,
198
+ "criterion": criterion,
199
+ "num_queries": cfg.MODEL.AVISM.NUM_OBJECT_QUERIES,
200
+ "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
201
+ "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
202
+ "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
203
+ "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
204
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
205
+ "pixel_std": cfg.MODEL.PIXEL_STD,
206
+ # inference
207
+ "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
208
+ # avism
209
+ "avism_module": avism_module,
210
+ "avism_criterion": avism_criterion,
211
+ "num_frames": cfg.INPUT.SAMPLING_FRAME_NUM,
212
+ "num_classes": num_classes,
213
+ "is_multi_cls": cfg.MODEL.AVISM.MULTI_CLS_ON,
214
+ "apply_cls_thres": cfg.MODEL.AVISM.APPLY_CLS_THRES,
215
+ "freeze_detector": cfg.MODEL.AVISM.FREEZE_DETECTOR,
216
+ "test_run_chunk_size": cfg.MODEL.AVISM.TEST_RUN_CHUNK_SIZE,
217
+ "test_interpolate_chunk_size": cfg.MODEL.AVISM.TEST_INTERPOLATE_CHUNK_SIZE,
218
+ "is_coco": cfg.DATASETS.TEST[0].startswith("coco"),
219
+ }
220
+
221
+ @property
222
+ def device(self):
223
+ return self.pixel_mean.device
224
+
225
+ def forward(self, batched_inputs):
226
+ """
227
+ Args:
228
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
229
+ Each item in the list contains the inputs for one image.
230
+ For now, each item in the list is a dict that contains:
231
+ * "image": Tensor, image in (C, H, W) format.
232
+ * "instances": per-region ground truth
233
+ * Other information that's included in the original dicts, such as:
234
+ "height", "width" (int): the output resolution of the model (may be different
235
+ from input resolution), used in inference.
236
+ Returns:
237
+ list[dict]:
238
+ each dict has the results for one image. The dict contains the following keys:
239
+
240
+ * "sem_seg":
241
+ A Tensor that represents the
242
+ per-pixel segmentation predicted by the head.
243
+ The prediction has shape KxHxW that represents the logits of
244
+ each class for each pixel.
245
+ * "panoptic_seg":
246
+ A tuple that represents the panoptic output
247
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
248
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
249
+ Each dict contains keys "id", "category_id", "isthing".
250
+ """
251
+ if self.training:
252
+ return self.train_model(batched_inputs)
253
+ else:
254
+ # NOTE consider only B=1 case.
255
+ return self.inference(batched_inputs[0])
256
+
257
+ def train_model(self, batched_inputs):
258
+ images = []
259
+ audio_features = []
260
+ for video in batched_inputs:
261
+ for frame in video["image"]:
262
+ images.append(frame.to(self.device))
263
+ for audio_feat in video["audio"]:
264
+ audio_features.append(torch.tensor(audio_feat).to(self.device))
265
+
266
+ audio_features = torch.stack(audio_features)
267
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
268
+ images = ImageList.from_tensors(images, self.size_divisibility)
269
+ image_features = self.backbone(images.tensor)
270
+
271
+ BT = len(images)
272
+ T = self.num_frames if self.training else BT
273
+ B = BT // T
274
+
275
+ outputs, frame_queries, mask_features = self.sem_seg_head(image_features, audio_features)
276
+
277
+ mask_features = self.avism_module.avism_mask_features(mask_features)
278
+ mask_features = mask_features.view(B, self.num_frames, *mask_features.shape[-3:])
279
+
280
+ # mask classification target
281
+ frame_targets, clip_targets = self.prepare_targets(batched_inputs, images)
282
+
283
+ # bipartite matching-based loss
284
+ losses, fg_indices = self.criterion(outputs, frame_targets)
285
+
286
+ avism_outputs = self.avism_module(frame_queries, audio_features)
287
+ avism_outputs["pred_masks"] = torch.einsum("lbqc,btchw->lbqthw", avism_outputs["pred_mask_embed"], mask_features)
288
+ for out in avism_outputs["aux_outputs"]:
289
+ out["pred_masks"] = torch.einsum("lbqc,btchw->lbqthw", out["pred_mask_embed"], mask_features)
290
+
291
+ for k in list(losses.keys()):
292
+ if k in self.criterion.weight_dict:
293
+ losses[k] *= self.criterion.weight_dict[k]
294
+ else:
295
+ # remove this loss if not specified in `weight_dict`
296
+ losses.pop(k)
297
+ avism_loss_dict = self.avism_criterion(avism_outputs, clip_targets, frame_targets, fg_indices)
298
+ avism_weight_dict = self.avism_criterion.weight_dict
299
+
300
+ for k in avism_loss_dict.keys():
301
+ if k in avism_weight_dict:
302
+ avism_loss_dict[k] *= avism_weight_dict[k]
303
+ losses.update(avism_loss_dict)
304
+ return losses
305
+
306
+ def prepare_targets(self, targets, images):
307
+ h_pad, w_pad = images.tensor.shape[-2:]
308
+ frame_gt_instances = []
309
+ clip_gt_instances = []
310
+ for targets_per_video in targets:
311
+ _num_instance = len(targets_per_video["instances"][0])
312
+ mask_shape = [_num_instance, self.num_frames, h_pad, w_pad]
313
+ gt_masks_per_video = torch.zeros(mask_shape, dtype=torch.bool, device=self.device)
314
+
315
+ gt_classes_per_video = targets_per_video["instances"][0].gt_classes.to(self.device)
316
+ gt_ids_per_video = []
317
+ for f_i, targets_per_frame in enumerate(targets_per_video["instances"]):
318
+ targets_per_frame = targets_per_frame.to(self.device)
319
+ h, w = targets_per_frame.image_size
320
+
321
+ _update_cls = gt_classes_per_video == -1
322
+ gt_classes_per_video[_update_cls] = targets_per_frame.gt_classes[_update_cls]
323
+ gt_ids_per_video.append(targets_per_frame.gt_ids)
324
+ if isinstance(targets_per_frame.gt_masks, BitMasks):
325
+ gt_masks_per_video[:, f_i, :h, :w] = targets_per_frame.gt_masks.tensor
326
+ else: #polygon
327
+ gt_masks_per_video[:, f_i, :h, :w] = targets_per_frame.gt_masks
328
+
329
+ gt_ids_per_video = torch.stack(gt_ids_per_video, dim=1)
330
+ gt_ids_per_video[gt_masks_per_video.sum(dim=(2,3)) == 0] = -1
331
+ valid_bool_frame = (gt_ids_per_video != -1)
332
+ valid_bool_clip = valid_bool_frame.any(dim=-1)
333
+
334
+ gt_classes_per_video = gt_classes_per_video[valid_bool_clip].long() # N,
335
+ gt_ids_per_video = gt_ids_per_video[valid_bool_clip].long() # N, num_frames
336
+ gt_masks_per_video = gt_masks_per_video[valid_bool_clip].float() # N, num_frames, H, W
337
+ valid_bool_frame = valid_bool_frame[valid_bool_clip]
338
+
339
+ if len(gt_ids_per_video) > 0:
340
+ min_id = max(gt_ids_per_video[valid_bool_frame].min(), 0)
341
+ gt_ids_per_video[valid_bool_frame] -= min_id
342
+
343
+ clip_gt_instances.append(
344
+ {
345
+ "labels": gt_classes_per_video, "ids": gt_ids_per_video, "masks": gt_masks_per_video,
346
+ "video_len": targets_per_video["length"], "frame_idx": targets_per_video["frame_idx"],
347
+ }
348
+ )
349
+
350
+ for f_i in range(self.num_frames):
351
+ _cls = gt_classes_per_video.clone()
352
+ _ids = gt_ids_per_video[:, f_i].clone()
353
+ _mask = gt_masks_per_video[:, f_i].clone()
354
+
355
+ valid = _ids != -1
356
+ frame_gt_instances.append({
357
+ "labels": _cls[valid],
358
+ "ids": _ids[valid],
359
+ "masks": _mask[valid],
360
+ })
361
+
362
+ return frame_gt_instances, clip_gt_instances
363
+
364
+ def inference(self, batched_inputs):
365
+ frame_queries, mask_features = [], []
366
+ num_frames = len(batched_inputs["image"])
367
+ to_store = self.device if num_frames <= 36 else "cpu"
368
+
369
+ audio_features = torch.tensor(batched_inputs["audio"]).to(self.device)
370
+
371
+ with torch.no_grad():
372
+ for i in range(math.ceil(num_frames / self.test_run_chunk_size)):
373
+ images = batched_inputs["image"][i*self.test_run_chunk_size : (i+1)*self.test_run_chunk_size]
374
+ images = [(x.to(self.device) - self.pixel_mean) / self.pixel_std for x in images]
375
+ images = ImageList.from_tensors(images, self.size_divisibility)
376
+
377
+ audio_features_chunk = audio_features[i*self.test_run_chunk_size : (i+1)*self.test_run_chunk_size]
378
+
379
+ features = self.backbone(images.tensor)
380
+ outputs, _frame_queries, _mask_features = self.sem_seg_head(features, audio_features_chunk)
381
+
382
+ _mask_features = self.avism_module.avism_mask_features(_mask_features)
383
+
384
+ # BT is 1 as runs per frame
385
+ frame_queries.append(_frame_queries[-1]) # T', fQ, C
386
+ mask_features.append(_mask_features.to(to_store)) # T', C, H, W
387
+
388
+ interim_size = images.tensor.shape[-2:]
389
+ image_size = images.image_sizes[0] # image size without padding after data augmentation
390
+
391
+ out_height = batched_inputs.get("height", image_size[0]) # raw image size before data augmentation
392
+ out_width = batched_inputs.get("width", image_size[1])
393
+
394
+ del outputs, images, batched_inputs
395
+
396
+ frame_queries = torch.cat(frame_queries)[None] # 1, T, fQ, C
397
+ mask_features = torch.cat(mask_features) # T, C, H, W
398
+
399
+ avism_outputs = self.avism_module(frame_queries, audio_features)
400
+
401
+ mask_cls = avism_outputs["pred_logits"][-1, 0] # cQ, K+1
402
+ mask_embed = avism_outputs["pred_mask_embed"][-1, 0] # cQ, C
403
+
404
+ del avism_outputs
405
+
406
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
407
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
408
+
409
+ num_topk = self.test_topk_per_image
410
+ scores_per_video, topk_indices = scores.flatten(0, 1).topk(num_topk, sorted=False)
411
+ labels_per_video = labels[topk_indices]
412
+
413
+ topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode='floor')
414
+ mask_embed = mask_embed[topk_indices]
415
+
416
+ masks_per_video = []
417
+ numerator = torch.zeros(len(mask_embed), dtype=torch.float, device=self.device)
418
+ denominator = torch.zeros(len(mask_embed), dtype=torch.float, device=self.device)
419
+ for i in range(math.ceil(len(mask_features) / self.test_interpolate_chunk_size)):
420
+ m_f = mask_features[i*self.test_interpolate_chunk_size : (i+1)*self.test_interpolate_chunk_size].to(self.device)
421
+
422
+ mask_pred = torch.einsum("qc,tchw->qthw", mask_embed, m_f)
423
+
424
+ # upsample masks
425
+ mask_pred = retry_if_cuda_oom(F.interpolate)(
426
+ mask_pred,
427
+ size=interim_size,
428
+ mode="bilinear",
429
+ align_corners=False,
430
+ ) # cQ, T, H, W
431
+
432
+ mask_pred = mask_pred[:, :, : image_size[0], : image_size[1]]
433
+
434
+ interim_mask_soft = mask_pred.sigmoid()
435
+ interim_mask_hard = interim_mask_soft > 0.5
436
+
437
+ numerator += (interim_mask_soft.flatten(1) * interim_mask_hard.flatten(1)).sum(1)
438
+ denominator += interim_mask_hard.flatten(1).sum(1)
439
+
440
+ mask_pred = F.interpolate(
441
+ mask_pred, size=(out_height, out_width), mode="bilinear", align_corners=False
442
+ ) > 0.
443
+ masks_per_video.append(mask_pred.to(to_store))
444
+ masks_per_video = torch.cat(masks_per_video, dim=1)
445
+ scores_per_video *= (numerator / (denominator + 1e-6))
446
+
447
+ confidence = 0.3
448
+ indices = torch.nonzero(scores_per_video > confidence).squeeze(-1)
449
+ scores_per_video = scores_per_video[indices]
450
+ labels_per_video = labels_per_video[indices]
451
+ masks_per_video = masks_per_video[indices]
452
+
453
+ processed_results = {
454
+ "image_size": (out_height, out_width),
455
+ "pred_scores": scores_per_video.tolist(),
456
+ "pred_labels": labels_per_video.tolist(),
457
+ "pred_masks": masks_per_video.cpu(),
458
+ }
459
+
460
+ return processed_results
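A detail of the `inference()` method above that is easy to miss: before the 0.3 confidence filter, the per-query class scores are rescaled by a mask-quality term accumulated over chunks, namely the mean sigmoid confidence inside each binarized mask. Below is a small self-contained illustration of that rescaling step; the random tensors stand in for real predictions.

```python
import torch

mask_logits = torch.randn(5, 4, 32, 32)    # (cQ, T, H, W): 5 query masks over 4 frames
soft = mask_logits.sigmoid()
hard = soft > 0.5                          # binarized masks

# Mean soft confidence inside each predicted mask region (the numerator/denominator
# accumulation in the code above, computed here in one shot).
mask_quality = (soft.flatten(1) * hard.flatten(1)).sum(1) / (hard.flatten(1).sum(1) + 1e-6)

class_scores = torch.rand(5)                # stand-in for the softmax classification scores
final_scores = class_scores * mask_quality  # values compared against the 0.3 threshold
```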
avism/avism_model_coco.py ADDED
@@ -0,0 +1,460 @@
1
+ from typing import Tuple
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from detectron2.config import configurable
9
+ from detectron2.data import MetadataCatalog
10
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
11
+ from detectron2.modeling.backbone import Backbone
12
+ from detectron2.structures import Boxes, ImageList, Instances, BitMasks
13
+ from detectron2.utils.memory import retry_if_cuda_oom
14
+
15
+ from mask2former.modeling.criterion import SetCriterion
16
+ from mask2former.modeling.matcher import HungarianMatcher
17
+ from .modeling.avism_criterion import AvismSetCriterion
18
+ from .modeling.avism_matcher import AvismHungarianMatcher
19
+ from .modeling.transformer_decoder.avism_coco import Avism_COCO
20
+
21
+
22
+ @META_ARCH_REGISTRY.register()
23
+ class AVISM_COCO(nn.Module):
24
+ """
25
+ Main class for mask classification semantic segmentation architectures.
26
+ """
27
+
28
+ @configurable
29
+ def __init__(
30
+ self,
31
+ *,
32
+ backbone: Backbone,
33
+ sem_seg_head: nn.Module,
34
+ criterion: nn.Module,
35
+ num_queries: int,
36
+ object_mask_threshold: float,
37
+ overlap_threshold: float,
38
+ metadata,
39
+ size_divisibility: int,
40
+ pixel_mean: Tuple[float],
41
+ pixel_std: Tuple[float],
42
+ # inference
43
+ test_topk_per_image: int,
44
+ # avism
45
+ avism_module: nn.Module,
46
+ avism_criterion: nn.Module,
47
+ num_frames: int,
48
+ num_classes: int,
49
+ is_multi_cls: bool,
50
+ apply_cls_thres: float,
51
+ freeze_detector: bool,
52
+ test_run_chunk_size: int,
53
+ test_interpolate_chunk_size: int,
54
+ is_coco: bool,
55
+ ):
56
+ """
57
+ Args:
58
+ backbone: a backbone module, must follow detectron2's backbone interface
59
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
60
+ criterion: a module that defines the loss
61
+ num_queries: int, number of queries
62
+ object_mask_threshold: float, threshold to filter query based on classification score
63
+ for panoptic segmentation inference
64
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
65
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
66
+ segmentation inference
67
+ size_divisibility: Some backbones require the input height and width to be divisible by a
68
+ specific integer. We can use this to override such requirement.
69
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
70
+ the per-channel mean and std to be used to normalize the input image
71
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
72
+ """
73
+ super().__init__()
74
+ self.backbone = backbone
75
+ self.sem_seg_head = sem_seg_head
76
+ self.criterion = criterion
77
+ self.num_queries = num_queries
78
+ self.overlap_threshold = overlap_threshold
79
+ self.object_mask_threshold = object_mask_threshold
80
+ self.metadata = metadata
81
+ if size_divisibility < 0:
82
+ # use backbone size_divisibility if not set
83
+ size_divisibility = self.backbone.size_divisibility
84
+ self.size_divisibility = size_divisibility
85
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
86
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
87
+
88
+ # additional args
89
+ self.test_topk_per_image = test_topk_per_image
90
+
91
+ # avism hyper-parameters
92
+ self.num_frames = num_frames
93
+ self.num_classes = num_classes
94
+ self.vita_module = avism_module
95
+ self.vita_criterion = avism_criterion
96
+ self.is_multi_cls = is_multi_cls
97
+ self.apply_cls_thres = apply_cls_thres
98
+
99
+ if freeze_detector:
100
+ for name, p in self.named_parameters():
101
+ if not "vita_module" in name:
102
+ p.requires_grad_(False)
103
+ self.test_run_chunk_size = test_run_chunk_size
104
+ self.test_interpolate_chunk_size = test_interpolate_chunk_size
105
+
106
+ self.is_coco = is_coco
107
+
108
+ @classmethod
109
+ def from_config(cls, cfg):
110
+ backbone = build_backbone(cfg)
111
+ sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
112
+
113
+ # Loss parameters:
114
+ deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
115
+ no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
116
+ avism_deep_supervision = cfg.MODEL.AVISM.DEEP_SUPERVISION
117
+
118
+ # loss weights
119
+ class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
120
+ dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
121
+ mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
122
+ sim_weight = cfg.MODEL.AVISM.SIM_WEIGHT
123
+
124
+ # building criterion
125
+ matcher = HungarianMatcher(
126
+ cost_class=class_weight,
127
+ cost_mask=mask_weight,
128
+ cost_dice=dice_weight,
129
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
130
+ )
131
+
132
+ weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
133
+
134
+ if deep_supervision:
135
+ dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
136
+ aux_weight_dict = {}
137
+ for i in range(dec_layers - 1):
138
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
139
+ weight_dict.update(aux_weight_dict)
140
+
141
+ losses = ["labels", "masks"]
142
+
143
+ criterion = SetCriterion(
144
+ sem_seg_head.num_classes,
145
+ matcher=matcher,
146
+ weight_dict=weight_dict,
147
+ eos_coef=no_object_weight,
148
+ losses=losses,
149
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
150
+ oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
151
+ importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
152
+ avism_last_layer_num=cfg.MODEL.AVISM.LAST_LAYER_NUM,
153
+ )
154
+
155
+ # Avism
156
+ num_classes = sem_seg_head.num_classes
157
+ hidden_dim = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
158
+ avism_module = Avism_COCO(cfg=cfg, in_channels=hidden_dim, aux_loss=avism_deep_supervision)
159
+
160
+ # building criterion for avism inference
161
+ avism_matcher = AvismHungarianMatcher(
162
+ cost_class=class_weight,
163
+ cost_mask=mask_weight,
164
+ cost_dice=dice_weight,
165
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
166
+ )
167
+ avism_weight_dict = {
168
+ "loss_avism_ce": class_weight, "loss_avism_mask": mask_weight, "loss_avism_dice": dice_weight
169
+ }
170
+ if sim_weight > 0.0:
171
+ avism_weight_dict["loss_avism_sim"] = sim_weight
172
+
173
+ if avism_deep_supervision:
174
+ avism_dec_layers = cfg.MODEL.AVISM.DEC_LAYERS
175
+ aux_weight_dict = {}
176
+ for i in range(avism_dec_layers - 1):
177
+ aux_weight_dict.update({k + f"_{i}": v for k, v in avism_weight_dict.items()})
178
+ avism_weight_dict.update(aux_weight_dict)
179
+ avism_losses = ["avism_labels", "avism_masks"]
180
+ if sim_weight > 0.0:
181
+ avism_losses.append("fg_sim")
182
+
183
+ avism_criterion = AvismSetCriterion(
184
+ num_classes,
185
+ matcher=avism_matcher,
186
+ weight_dict=avism_weight_dict,
187
+ eos_coef=cfg.MODEL.AVISM.NO_OBJECT_WEIGHT,
188
+ losses=avism_losses,
189
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
190
+ oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
191
+ importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
192
+ sim_use_clip=cfg.MODEL.AVISM.SIM_USE_CLIP,
193
+ )
194
+
195
+ return {
196
+ "backbone": backbone,
197
+ "sem_seg_head": sem_seg_head,
198
+ "criterion": criterion,
199
+ "num_queries": cfg.MODEL.AVISM.NUM_OBJECT_QUERIES,
200
+ "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
201
+ "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
202
+ "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
203
+ "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
204
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
205
+ "pixel_std": cfg.MODEL.PIXEL_STD,
206
+ # inference
207
+ "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
208
+ # avism
209
+ "avism_module": avism_module,
210
+ "avism_criterion": avism_criterion,
211
+ "num_frames": cfg.INPUT.SAMPLING_FRAME_NUM,
212
+ "num_classes": num_classes,
213
+ "is_multi_cls": cfg.MODEL.AVISM.MULTI_CLS_ON,
214
+ "apply_cls_thres": cfg.MODEL.AVISM.APPLY_CLS_THRES,
215
+ "freeze_detector": cfg.MODEL.AVISM.FREEZE_DETECTOR,
216
+ "test_run_chunk_size": cfg.MODEL.AVISM.TEST_RUN_CHUNK_SIZE,
217
+ "test_interpolate_chunk_size": cfg.MODEL.AVISM.TEST_INTERPOLATE_CHUNK_SIZE,
218
+ "is_coco": cfg.DATASETS.TEST[0].startswith("coco"),
219
+ }
220
+
221
+ @property
222
+ def device(self):
223
+ return self.pixel_mean.device
224
+
225
+ def forward(self, batched_inputs):
226
+ """
227
+ Args:
228
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
229
+ Each item in the list contains the inputs for one image.
230
+ For now, each item in the list is a dict that contains:
231
+ * "image": Tensor, image in (C, H, W) format.
232
+ * "instances": per-region ground truth
233
+ * Other information that's included in the original dicts, such as:
234
+ "height", "width" (int): the output resolution of the model (may be different
235
+ from input resolution), used in inference.
236
+ Returns:
237
+ list[dict]:
238
+ each dict has the results for one image. The dict contains the following keys:
239
+
240
+ * "sem_seg":
241
+ A Tensor that represents the
242
+ per-pixel segmentation prediced by the head.
243
+ The prediction has shape KxHxW that represents the logits of
244
+ each class for each pixel.
245
+ * "panoptic_seg":
246
+ A tuple that represent panoptic output
247
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
248
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
249
+ Each dict contains keys "id", "category_id", "isthing".
250
+ """
251
+ if self.training:
252
+ return self.train_model(batched_inputs)
253
+ else:
254
+ # NOTE consider only B=1 case.
255
+ return self.inference(batched_inputs[0])
256
+
257
+ def train_model(self, batched_inputs):
258
+ images = []
259
+ audio_features = []
260
+ for video in batched_inputs:
261
+ for frame in video["image"]:
262
+ images.append(frame.to(self.device))
263
+ for audio_feat in video["audio"]:
264
+ audio_features.append(torch.tensor(audio_feat).to(self.device))
265
+
266
+ audio_features = torch.stack(audio_features)
267
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
268
+ images = ImageList.from_tensors(images, self.size_divisibility)
269
+ image_features = self.backbone(images.tensor)
270
+
271
+ BT = len(images)
272
+ T = self.num_frames if self.training else BT
273
+ B = BT // T
274
+
275
+ outputs, frame_queries, mask_features = self.sem_seg_head(image_features, audio_features)
276
+
277
+ mask_features = self.vita_module.vita_mask_features(mask_features)
278
+ mask_features = mask_features.view(B, self.num_frames, *mask_features.shape[-3:])
279
+
280
+ # mask classification target
281
+ frame_targets, clip_targets = self.prepare_targets(batched_inputs, images)
282
+
283
+ # bipartite matching-based loss
284
+ losses, fg_indices = self.criterion(outputs, frame_targets)
285
+
286
+ avism_outputs = self.vita_module(frame_queries, audio_features)
287
+ avism_outputs["pred_masks"] = torch.einsum("lbqc,btchw->lbqthw", avism_outputs["pred_mask_embed"], mask_features)
288
+ for out in avism_outputs["aux_outputs"]:
289
+ out["pred_masks"] = torch.einsum("lbqc,btchw->lbqthw", out["pred_mask_embed"], mask_features)
290
+
291
+ for k in list(losses.keys()):
292
+ if k in self.criterion.weight_dict:
293
+ losses[k] *= self.criterion.weight_dict[k]
294
+ else:
295
+ # remove this loss if not specified in `weight_dict`
296
+ losses.pop(k)
297
+ avism_loss_dict = self.vita_criterion(avism_outputs, clip_targets, frame_targets, fg_indices)
298
+ avism_weight_dict = self.vita_criterion.weight_dict
299
+
300
+ for k in avism_loss_dict.keys():
301
+ if k in avism_weight_dict:
302
+ avism_loss_dict[k] *= avism_weight_dict[k]
303
+ losses.update(avism_loss_dict)
304
+ return losses
305
+
306
+ def prepare_targets(self, targets, images):
307
+ h_pad, w_pad = images.tensor.shape[-2:]
308
+ frame_gt_instances = []
309
+ clip_gt_instances = []
310
+ for targets_per_video in targets:
311
+ _num_instance = len(targets_per_video["instances"][0])
312
+ mask_shape = [_num_instance, self.num_frames, h_pad, w_pad]
313
+ gt_masks_per_video = torch.zeros(mask_shape, dtype=torch.bool, device=self.device)
314
+
315
+ gt_classes_per_video = targets_per_video["instances"][0].gt_classes.to(self.device)
316
+ gt_ids_per_video = []
317
+ for f_i, targets_per_frame in enumerate(targets_per_video["instances"]):
318
+ targets_per_frame = targets_per_frame.to(self.device)
319
+ h, w = targets_per_frame.image_size
320
+
321
+ _update_cls = gt_classes_per_video == -1
322
+ gt_classes_per_video[_update_cls] = targets_per_frame.gt_classes[_update_cls]
323
+ gt_ids_per_video.append(targets_per_frame.gt_ids)
324
+ if isinstance(targets_per_frame.gt_masks, BitMasks):
325
+ gt_masks_per_video[:, f_i, :h, :w] = targets_per_frame.gt_masks.tensor
326
+ else: #polygon
327
+ gt_masks_per_video[:, f_i, :h, :w] = targets_per_frame.gt_masks
328
+
329
+ gt_ids_per_video = torch.stack(gt_ids_per_video, dim=1)
330
+ gt_ids_per_video[gt_masks_per_video.sum(dim=(2,3)) == 0] = -1
331
+ valid_bool_frame = (gt_ids_per_video != -1)
332
+ valid_bool_clip = valid_bool_frame.any(dim=-1)
333
+
334
+ gt_classes_per_video = gt_classes_per_video[valid_bool_clip].long() # N,
335
+ gt_ids_per_video = gt_ids_per_video[valid_bool_clip].long() # N, num_frames
336
+ gt_masks_per_video = gt_masks_per_video[valid_bool_clip].float() # N, num_frames, H, W
337
+ valid_bool_frame = valid_bool_frame[valid_bool_clip]
338
+
339
+ if len(gt_ids_per_video) > 0:
340
+ min_id = max(gt_ids_per_video[valid_bool_frame].min(), 0)
341
+ gt_ids_per_video[valid_bool_frame] -= min_id
342
+
343
+ clip_gt_instances.append(
344
+ {
345
+ "labels": gt_classes_per_video, "ids": gt_ids_per_video, "masks": gt_masks_per_video,
346
+ "video_len": targets_per_video["length"], "frame_idx": targets_per_video["frame_idx"],
347
+ }
348
+ )
349
+
350
+ for f_i in range(self.num_frames):
351
+ _cls = gt_classes_per_video.clone()
352
+ _ids = gt_ids_per_video[:, f_i].clone()
353
+ _mask = gt_masks_per_video[:, f_i].clone()
354
+
355
+ valid = _ids != -1
356
+ frame_gt_instances.append({
357
+ "labels": _cls[valid],
358
+ "ids": _ids[valid],
359
+ "masks": _mask[valid],
360
+ })
361
+
362
+ return frame_gt_instances, clip_gt_instances
363
+
364
+ def inference(self, batched_inputs):
365
+ frame_queries, mask_features = [], []
366
+ num_frames = len(batched_inputs["image"])
367
+ to_store = self.device if num_frames <= 36 else "cpu"
368
+
369
+ audio_features = torch.tensor(batched_inputs["audio"]).to(self.device)
370
+
371
+ with torch.no_grad():
372
+ for i in range(math.ceil(num_frames / self.test_run_chunk_size)):
373
+ images = batched_inputs["image"][i*self.test_run_chunk_size : (i+1)*self.test_run_chunk_size]
374
+ images = [(x.to(self.device) - self.pixel_mean) / self.pixel_std for x in images]
375
+ images = ImageList.from_tensors(images, self.size_divisibility)
376
+
377
+ audio_features_chunk = audio_features[i*self.test_run_chunk_size : (i+1)*self.test_run_chunk_size]
378
+
379
+ features = self.backbone(images.tensor)
380
+ outputs, _frame_queries, _mask_features = self.sem_seg_head(features, audio_features_chunk)
381
+
382
+ _mask_features = self.vita_module.vita_mask_features(_mask_features)
383
+
384
+ # BT is 1 as runs per frame
385
+ frame_queries.append(_frame_queries[-1]) # T', fQ, C
386
+ mask_features.append(_mask_features.to(to_store)) # T', C, H, W
387
+
388
+ interim_size = images.tensor.shape[-2:]
389
+ image_size = images.image_sizes[0] # image size without padding after data augmentation
390
+
391
+ out_height = batched_inputs.get("height", image_size[0]) # raw image size before data augmentation
392
+ out_width = batched_inputs.get("width", image_size[1])
393
+
394
+ del outputs, images, batched_inputs
395
+
396
+ frame_queries = torch.cat(frame_queries)[None] # 1, T, fQ, C
397
+ mask_features = torch.cat(mask_features) # T, C, H, W
398
+
399
+ avism_outputs = self.vita_module(frame_queries, audio_features)
400
+
401
+ mask_cls = avism_outputs["pred_logits"][-1, 0] # cQ, K+1
402
+ mask_embed = avism_outputs["pred_mask_embed"][-1, 0] # cQ, C
403
+
404
+ del avism_outputs
405
+
406
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
407
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
408
+
409
+ num_topk = self.test_topk_per_image
410
+ scores_per_video, topk_indices = scores.flatten(0, 1).topk(num_topk, sorted=False)
411
+ labels_per_video = labels[topk_indices]
412
+
413
+ topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode='floor')
414
+ mask_embed = mask_embed[topk_indices]
415
+
416
+ masks_per_video = []
417
+ numerator = torch.zeros(len(mask_embed), dtype=torch.float, device=self.device)
418
+ denominator = torch.zeros(len(mask_embed), dtype=torch.float, device=self.device)
419
+ for i in range(math.ceil(len(mask_features) / self.test_interpolate_chunk_size)):
420
+ m_f = mask_features[i*self.test_interpolate_chunk_size : (i+1)*self.test_interpolate_chunk_size].to(self.device)
421
+
422
+ mask_pred = torch.einsum("qc,tchw->qthw", mask_embed, m_f)
423
+
424
+ # upsample masks
425
+ mask_pred = retry_if_cuda_oom(F.interpolate)(
426
+ mask_pred,
427
+ size=interim_size,
428
+ mode="bilinear",
429
+ align_corners=False,
430
+ ) # cQ, T, H, W
431
+
432
+ mask_pred = mask_pred[:, :, : image_size[0], : image_size[1]]
433
+
434
+ interim_mask_soft = mask_pred.sigmoid()
435
+ interim_mask_hard = interim_mask_soft > 0.5
436
+
437
+ numerator += (interim_mask_soft.flatten(1) * interim_mask_hard.flatten(1)).sum(1)
438
+ denominator += interim_mask_hard.flatten(1).sum(1)
439
+
440
+ mask_pred = F.interpolate(
441
+ mask_pred, size=(out_height, out_width), mode="bilinear", align_corners=False
442
+ ) > 0.
443
+ masks_per_video.append(mask_pred.to(to_store))
444
+ masks_per_video = torch.cat(masks_per_video, dim=1)
445
+ scores_per_video *= (numerator / (denominator + 1e-6))
446
+
447
+ confidence = 0.3
448
+ indices = torch.nonzero(scores_per_video > confidence).squeeze(-1)
449
+ scores_per_video = scores_per_video[indices]
450
+ labels_per_video = labels_per_video[indices]
451
+ masks_per_video = masks_per_video[indices]
452
+
453
+ processed_results = {
454
+ "image_size": (out_height, out_width),
455
+ "pred_scores": scores_per_video.tolist(),
456
+ "pred_labels": labels_per_video.tolist(),
457
+ "pred_masks": masks_per_video.cpu(),
458
+ }
459
+
460
+ return processed_results
avism/config.py ADDED
@@ -0,0 +1,56 @@
+ # -*- coding: utf-8 -*-
+ from detectron2.config import CfgNode as CN
+
+
+ def add_avism_config(cfg):
+     cfg.DATASETS.DATASET_RATIO = []
+
+     # DataLoader
+     cfg.INPUT.SAMPLING_FRAME_NUM = 2
+     cfg.INPUT.SAMPLING_FRAME_RANGE = 20
+     cfg.INPUT.SAMPLING_FRAME_SHUFFLE = False
+     cfg.INPUT.AUGMENTATIONS = []  # "brightness", "contrast", "saturation", "rotation"
+
+     # Pseudo Data Use
+     cfg.INPUT.PSEUDO = CN()
+     cfg.INPUT.PSEUDO.AUGMENTATIONS = ['rotation']
+     cfg.INPUT.PSEUDO.MIN_SIZE_TRAIN = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768)
+     cfg.INPUT.PSEUDO.MAX_SIZE_TRAIN = 768
+     cfg.INPUT.PSEUDO.MIN_SIZE_TRAIN_SAMPLING = "choice_by_clip"
+     cfg.INPUT.PSEUDO.CROP = CN()
+     cfg.INPUT.PSEUDO.CROP.ENABLED = False
+     cfg.INPUT.PSEUDO.CROP.TYPE = "absolute_range"
+     cfg.INPUT.PSEUDO.CROP.SIZE = (384, 600)
+
+     # LSJ
+     cfg.INPUT.LSJ_AUG = CN()
+     cfg.INPUT.LSJ_AUG.ENABLED = False
+     cfg.INPUT.LSJ_AUG.IMAGE_SIZE = 1024
+     cfg.INPUT.LSJ_AUG.MIN_SCALE = 0.1
+     cfg.INPUT.LSJ_AUG.MAX_SCALE = 2.0
+
+     # AVISM
+     cfg.MODEL.AVISM = CN()
+     cfg.MODEL.AVISM.NHEADS = 8
+     cfg.MODEL.AVISM.DROPOUT = 0.0
+     cfg.MODEL.AVISM.DIM_FEEDFORWARD = 2048
+     cfg.MODEL.AVISM.ENC_LAYERS = 6
+     cfg.MODEL.AVISM.DEC_LAYERS = 3
+     cfg.MODEL.AVISM.ENC_WINDOW_SIZE = 0
+     cfg.MODEL.AVISM.PRE_NORM = False
+     cfg.MODEL.AVISM.HIDDEN_DIM = 256
+     cfg.MODEL.AVISM.NUM_OBJECT_QUERIES = 100
+     cfg.MODEL.AVISM.ENFORCE_INPUT_PROJ = True
+
+     cfg.MODEL.AVISM.NO_OBJECT_WEIGHT = 0.1
+     cfg.MODEL.AVISM.DEEP_SUPERVISION = True
+     cfg.MODEL.AVISM.LAST_LAYER_NUM = 3
+     cfg.MODEL.AVISM.MULTI_CLS_ON = True
+     cfg.MODEL.AVISM.APPLY_CLS_THRES = 0.01
+
+     cfg.MODEL.AVISM.SIM_USE_CLIP = True
+     cfg.MODEL.AVISM.SIM_WEIGHT = 0.5
+
+     cfg.MODEL.AVISM.FREEZE_DETECTOR = False
+     cfg.MODEL.AVISM.TEST_RUN_CHUNK_SIZE = 18
+     cfg.MODEL.AVISM.TEST_INTERPOLATE_CHUNK_SIZE = 5
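For context, a minimal sketch of how this hook is used: `add_avism_config` simply attaches AVISM's extra keys to a stock detectron2 `CfgNode` before a YAML from `configs/avism/` is merged in. This assumes a working detectron2 install; the actual `train_net.py` setup (not shown in this view) presumably also registers the Mask2Former config keys.

```python
# Minimal sketch, assuming this mirrors how the training script wires the config hook.
from detectron2.config import get_cfg
from avism.config import add_avism_config   # imported directly to avoid loading model code

cfg = get_cfg()        # stock detectron2 config
add_avism_config(cfg)  # attaches cfg.MODEL.AVISM, cfg.INPUT.SAMPLING_FRAME_NUM, etc.

print(cfg.MODEL.AVISM.NUM_OBJECT_QUERIES)  # 100
print(cfg.INPUT.SAMPLING_FRAME_NUM)        # 2
```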
avism/data/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .datasets import *
+ from .dataset_mapper import AVISDatasetMapper
+ from .build import *
+ from .avis_eval import AVISEvaluator
avism/data/augmentation.py ADDED
@@ -0,0 +1,623 @@
1
+ import copy
2
+ import numpy as np
3
+ import logging
4
+ import sys
5
+ from fvcore.transforms.transform import (
6
+ HFlipTransform,
7
+ NoOpTransform,
8
+ VFlipTransform,
9
+ )
10
+ from PIL import Image
11
+ from typing import Tuple
12
+ from fvcore.transforms.transform import (
13
+ BlendTransform,
14
+ CropTransform,
15
+ HFlipTransform,
16
+ NoOpTransform,
17
+ PadTransform,
18
+ Transform,
19
+ TransformList,
20
+ VFlipTransform,
21
+ )
22
+
23
+ from detectron2.data import transforms as T
24
+
25
+
26
+ class RandomApplyClip(T.Augmentation):
27
+ """
28
+ Randomly apply an augmentation with a given probability.
29
+ """
30
+
31
+ def __init__(self, tfm_or_aug, prob=0.5, clip_frame_cnt=1):
32
+ """
33
+ Args:
34
+ tfm_or_aug (Transform, Augmentation): the transform or augmentation
35
+ to be applied. It can either be a `Transform` or `Augmentation`
36
+ instance.
37
+ prob (float): probability between 0.0 and 1.0 that
38
+ the wrapper transformation is applied
39
+ """
40
+ super().__init__()
41
+ self.aug = T.augmentation._transform_to_aug(tfm_or_aug)
42
+ assert 0.0 <= prob <= 1.0, f"Probability must be between 0.0 and 1.0 (given: {prob})"
43
+ self.prob = prob
44
+ self._cnt = 0
45
+ self.clip_frame_cnt = clip_frame_cnt
46
+
47
+ def get_transform(self, *args):
48
+ if self._cnt % self.clip_frame_cnt == 0:
49
+ self.do = self._rand_range() < self.prob
50
+ self._cnt = 0 # avoiding overflow
51
+ self._cnt += 1
52
+
53
+ if self.do:
54
+ return self.aug.get_transform(*args)
55
+ else:
56
+ return NoOpTransform()
57
+
58
+ def __call__(self, aug_input):
59
+ if self._cnt % self.clip_frame_cnt == 0:
60
+ self.do = self._rand_range() < self.prob
61
+ self._cnt = 0 # avoiding overflow
62
+ self._cnt += 1
63
+
64
+ if self.do:
65
+ return self.aug(aug_input)
66
+ else:
67
+ return NoOpTransform()
68
+
69
+
70
+ class RandomRotationClip(T.Augmentation):
71
+ """
72
+ This method returns a copy of this image, rotated the given
73
+ number of degrees counter clockwise around the given center.
74
+ """
75
+
76
+ def __init__(self, angle, prob=0.5, expand=True, center=None, interp=None, clip_frame_cnt=1):
77
+ """
78
+ Args:
79
+ angle (list[float]): a [min, max] interval (in degrees); one angle per frame
+     is sampled from it and the angles are sorted so the rotation changes
+     monotonically across the clip.
+ prob (float): probability of reversing the sorted angle sequence.
+ expand (bool): choose if the image should be resized to fit the whole
+     rotated image (default), or simply cropped
+ center (list[[float, float]]): a [[minx, miny], [maxx, maxy]] relative interval
+     from which to sample one rotation center per clip, [0, 0] being the top
+     left of the image and [1, 1] the bottom right.
+     Default: None, which means that the center of rotation is the center of the image.
+     center has no effect if expand=True because it only affects shifting
+ clip_frame_cnt (int): number of consecutive frames that share one sampled rotation trajectory.
90
+ """
91
+ super().__init__()
92
+ if isinstance(angle, (float, int)):
93
+ angle = (angle, angle)
94
+ if center is not None and isinstance(center[0], (float, int)):
95
+ center = (center, center)
96
+ self.angle_save = None
97
+ self.center_save = None
98
+ self._cnt = 0
99
+ self._init(locals())
100
+
101
+ def get_transform(self, image):
102
+ h, w = image.shape[:2]
103
+ if self._cnt % self.clip_frame_cnt == 0:
104
+ center = None
105
+ angle = np.random.uniform(self.angle[0], self.angle[1], size=self.clip_frame_cnt)
106
+ if self.center is not None:
107
+ center = (
108
+ np.random.uniform(self.center[0][0], self.center[1][0]),
109
+ np.random.uniform(self.center[0][1], self.center[1][1]),
110
+ )
111
+ angle = np.sort(angle)
112
+ if self._rand_range() < self.prob:
113
+ angle = angle[::-1]
114
+ self.angle_save = angle
115
+ self.center_save = center
116
+
117
+ self._cnt = 0 # avoiding overflow
118
+
119
+ angle = self.angle_save[self._cnt]
120
+ center = self.center_save
121
+
122
+ self._cnt += 1
123
+
124
+ if center is not None:
125
+ center = (w * center[0], h * center[1]) # Convert to absolute coordinates
126
+
127
+ if angle % 360 == 0:
128
+ return NoOpTransform()
129
+
130
+ return T.RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)
131
+
132
+
133
+ class ResizeScaleClip(T.Augmentation):
134
+ """
135
+ Takes target size as input and randomly scales the given target size between `min_scale`
136
+ and `max_scale`. It then scales the input image such that it fits inside the scaled target
137
+ box, keeping the aspect ratio constant.
138
+ This implements the resize part of the Google's 'resize_and_crop' data augmentation:
139
+ https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/input_utils.py#L127
140
+ """
141
+
142
+ def __init__(
143
+ self,
144
+ min_scale: float,
145
+ max_scale: float,
146
+ target_height: int,
147
+ target_width: int,
148
+ interp: int = Image.BILINEAR,
149
+ clip_frame_cnt=1,
150
+ ):
151
+ """
152
+ Args:
153
+ min_scale: minimum image scale range.
154
+ max_scale: maximum image scale range.
155
+ target_height: target image height.
156
+ target_width: target image width.
157
+ interp: image interpolation method.
158
+ """
159
+ super().__init__()
160
+ self._init(locals())
161
+ self._cnt = 0
162
+
163
+ def _get_resize(self, image: np.ndarray, scale: float):
164
+ input_size = image.shape[:2]
165
+
166
+ # Compute new target size given a scale.
167
+ target_size = (self.target_height, self.target_width)
168
+ target_scale_size = np.multiply(target_size, scale)
169
+
170
+ # Compute actual rescaling applied to input image and output size.
171
+ output_scale = np.minimum(
172
+ target_scale_size[0] / input_size[0], target_scale_size[1] / input_size[1]
173
+ )
174
+ output_size = np.round(np.multiply(input_size, output_scale)).astype(int)
175
+
176
+ return T.ResizeTransform(
177
+ input_size[0], input_size[1], output_size[0], output_size[1], self.interp
178
+ )
179
+
180
+ def get_transform(self, image: np.ndarray):
181
+ if self._cnt % self.clip_frame_cnt == 0:
182
+ random_scale = np.random.uniform(self.min_scale, self.max_scale)
183
+ self.random_scale_save = random_scale
184
+
185
+ self._cnt = 0 # avoiding overflow
186
+ self._cnt += 1
187
+ random_scale = self.random_scale_save
188
+
189
+ return self._get_resize(image, random_scale)
190
+
191
+
192
+ class RandomCropClip(T.Augmentation):
193
+ """
194
+ Randomly crop a rectangle region out of an image.
195
+ """
196
+
197
+ def __init__(self, crop_type: str, crop_size, clip_frame_cnt=1):
198
+ """
199
+ Args:
200
+ crop_type (str): one of "relative_range", "relative", "absolute", "absolute_range".
201
+ crop_size (tuple[float, float]): two floats, explained below.
202
+ - "relative": crop a (H * crop_size[0], W * crop_size[1]) region from an input image of
203
+ size (H, W). crop size should be in (0, 1]
204
+ - "relative_range": uniformly sample two values from [crop_size[0], 1]
205
+ and [crop_size[1], 1], and use them as in "relative" crop type.
206
+ - "absolute" crop a (crop_size[0], crop_size[1]) region from input image.
207
+ crop_size must be smaller than the input image size.
208
+ - "absolute_range", for an input of size (H, W), uniformly sample H_crop in
209
+ [crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])].
210
+ Then crop a region (H_crop, W_crop).
211
+ """
212
+ # TODO style of relative_range and absolute_range are not consistent:
213
+ # one takes (h, w) but another takes (min, max)
214
+ super().__init__()
215
+ assert crop_type in ["relative_range", "relative", "absolute", "absolute_range"]
216
+ self._init(locals())
217
+ self._cnt = 0
218
+
219
+ def get_transform(self, image):
220
+ h, w = image.shape[:2]
221
+ if self._cnt % self.clip_frame_cnt == 0:
222
+ croph, cropw = self.get_crop_size((h, w))
223
+ assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self)
224
+
225
+ h0 = np.random.randint(h - croph + 1)
+ w0 = np.random.randint(w - cropw + 1)
227
+
228
+ h1 = np.random.randint(h0, h - croph + 1)
229
+ w1 = np.random.randint(w0, w - cropw + 1)
230
+
231
+ x = np.sort(np.random.rand(self.clip_frame_cnt))
232
+
233
+ h = h0 * x + h1 * (1-x)
234
+ w = w0 * x + w1 * (1-x)
235
+ h = np.round_(h).astype(int)
236
+ w = np.round_(w).astype(int)
237
+
238
+ if self._rand_range() < 0.5:
239
+ h = h[::-1]
240
+ w = w[::-1]
241
+
242
+ self.hw_save = (h, w)
243
+ self.crop_h_save, self.crop_w_save = croph, cropw
244
+ self._cnt = 0 # avoiding overflow
245
+ _h, _w = self.hw_save[0][self._cnt], self.hw_save[1][self._cnt]
246
+ self._cnt += 1
247
+
248
+ return T.CropTransform(_w, _h, self.crop_w_save, self.crop_h_save)
249
+
250
+ def get_crop_size(self, image_size):
251
+ """
252
+ Args:
253
+ image_size (tuple): height, width
254
+ Returns:
255
+ crop_size (tuple): height, width in absolute pixels
256
+ """
257
+ h, w = image_size
258
+ if self.crop_type == "relative":
259
+ ch, cw = self.crop_size
260
+ return int(h * ch + 0.5), int(w * cw + 0.5)
261
+ elif self.crop_type == "relative_range":
262
+ crop_size = np.asarray(self.crop_size, dtype=float)
263
+ ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
264
+ return int(h * ch + 0.5), int(w * cw + 0.5)
265
+ elif self.crop_type == "absolute":
266
+ return (min(self.crop_size[0], h), min(self.crop_size[1], w))
267
+ elif self.crop_type == "absolute_range":
268
+ assert self.crop_size[0] <= self.crop_size[1]
269
+ ch = np.random.randint(min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1)
270
+ cw = np.random.randint(min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1)
271
+ return ch, cw
272
+ else:
273
+ raise NotImplementedError("Unknown crop type {}".format(self.crop_type))
274
+
275
+
276
+ class FixedSizeCropClip(T.Augmentation):
277
+ """
278
+ If `crop_size` is smaller than the input image size, then it uses a random crop of
279
+ the crop size. If `crop_size` is larger than the input image size, then it pads
280
+ the right and the bottom of the image to the crop size if `pad` is True, otherwise
281
+ it returns the smaller image.
282
+ """
283
+
284
+ def __init__(self, crop_size: Tuple[int], pad: bool = True, pad_value: float = 128.0, clip_frame_cnt=1):
285
+ """
286
+ Args:
287
+ crop_size: target image (height, width).
288
+ pad: if True, will pad images smaller than `crop_size` up to `crop_size`
289
+ pad_value: the padding value.
290
+ """
291
+ super().__init__()
292
+ self._init(locals())
293
+ self._cnt = 0
294
+
295
+ def _get_crop(self, image: np.ndarray):
296
+ # Compute the image scale and scaled size.
297
+ input_size = image.shape[:2]
298
+ output_size = self.crop_size
299
+
300
+ # Add random crop if the image is scaled up.
301
+ max_offset = np.subtract(input_size, output_size)
302
+ max_offset = np.maximum(max_offset, 0)
303
+
304
+ if self._cnt % self.clip_frame_cnt == 0:
305
+ offset = np.multiply(max_offset, np.random.uniform(0.0, 1.0))
306
+ offset = np.round(offset).astype(int)
307
+ self.offset_save = offset
308
+ self._cnt = 0 # avoiding overflow
309
+ self._cnt += 1
310
+ offset = self.offset_save
311
+ return CropTransform(
312
+ offset[1], offset[0], output_size[1], output_size[0], input_size[1], input_size[0]
313
+ )
314
+
315
+ def _get_pad(self, image: np.ndarray):
316
+ # Compute the image scale and scaled size.
317
+ input_size = image.shape[:2]
318
+ output_size = self.crop_size
319
+
320
+ # Add padding if the image is scaled down.
321
+ pad_size = np.subtract(output_size, input_size)
322
+ pad_size = np.maximum(pad_size, 0)
323
+ original_size = np.minimum(input_size, output_size)
324
+ return PadTransform(
325
+ 0, 0, pad_size[1], pad_size[0], original_size[1], original_size[0], self.pad_value
326
+ )
327
+
328
+ def get_transform(self, image: np.ndarray):
329
+ transforms = [self._get_crop(image)]
330
+ if self.pad:
331
+ transforms.append(self._get_pad(image))
332
+ return TransformList(transforms)
333
+
334
+
335
+ class ResizeShortestEdgeClip(T.Augmentation):
336
+ """
337
+ Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
338
+ If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
339
+ """
340
+
341
+ def __init__(
342
+ self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR, clip_frame_cnt=1
343
+ ):
344
+ """
345
+ Args:
346
+ short_edge_length (list[int]): If ``sample_style=="range"``,
347
+ a [min, max] interval from which to sample the shortest edge length.
348
+ If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
349
+ max_size (int): maximum allowed longest edge length.
350
+ sample_style (str): one of "range", "choice", "range_by_clip" or "choice_by_clip".
351
+ """
352
+ super().__init__()
353
+ assert sample_style in ["range", "choice", "range_by_clip", "choice_by_clip"], sample_style
354
+
355
+ self.is_range = ("range" in sample_style)
356
+ if isinstance(short_edge_length, int):
357
+ short_edge_length = (short_edge_length, short_edge_length)
358
+ if self.is_range:
359
+ assert len(short_edge_length) == 2, (
360
+ "short_edge_length must be two values using 'range' sample style."
361
+ f" Got {short_edge_length}!"
362
+ )
363
+ self._cnt = 0
364
+ self._init(locals())
365
+
366
+ def get_transform(self, image):
367
+ if self._cnt % self.clip_frame_cnt == 0:
368
+ if self.is_range:
369
+ self.size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
370
+ else:
371
+ self.size = np.random.choice(self.short_edge_length)
372
+ self._cnt = 0 # avoiding overflow
373
+
374
+ if self.size == 0:
375
+ return NoOpTransform()
376
+ self._cnt += 1
377
+
378
+ h, w = image.shape[:2]
379
+
380
+ scale = self.size * 1.0 / min(h, w)
381
+ if h < w:
382
+ newh, neww = self.size, scale * w
383
+ else:
384
+ newh, neww = scale * h, self.size
385
+ if max(newh, neww) > self.max_size:
386
+ scale = self.max_size * 1.0 / max(newh, neww)
387
+ newh = newh * scale
388
+ neww = neww * scale
389
+ neww = int(neww + 0.5)
390
+ newh = int(newh + 0.5)
391
+ return T.ResizeTransform(h, w, newh, neww, self.interp)
392
+
393
+
394
+ class RandomFlipClip(T.Augmentation):
395
+ """
396
+ Flip the image horizontally or vertically with the given probability.
397
+ """
398
+
399
+ def __init__(self, prob=0.5, *, horizontal=True, vertical=False, clip_frame_cnt=1):
400
+ """
401
+ Args:
402
+ prob (float): probability of flip.
403
+ horizontal (boolean): whether to apply horizontal flipping
404
+ vertical (boolean): whether to apply vertical flipping
405
+ """
406
+ super().__init__()
407
+
408
+ if horizontal and vertical:
409
+ raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
410
+ if not horizontal and not vertical:
411
+ raise ValueError("At least one of horiz or vert has to be True!")
412
+ self._cnt = 0
413
+
414
+ self._init(locals())
415
+
416
+ def get_transform(self, image):
417
+ if self._cnt % self.clip_frame_cnt == 0:
418
+ self.do = self._rand_range() < self.prob
419
+ self._cnt = 0 # avoiding overflow
420
+ self._cnt += 1
421
+
422
+ h, w = image.shape[:2]
423
+
424
+ if self.do:
425
+ if self.horizontal:
426
+ return HFlipTransform(w)
427
+ elif self.vertical:
428
+ return VFlipTransform(h)
429
+ else:
430
+ return NoOpTransform()
431
+
432
+
433
+ def build_augmentation(cfg, is_train):
434
+ logger = logging.getLogger(__name__)
435
+ aug_list = []
436
+ if is_train:
437
+ use_lsj = cfg.INPUT.LSJ_AUG.ENABLED
438
+ if use_lsj:
439
+ image_size = cfg.INPUT.LSJ_AUG.IMAGE_SIZE
440
+ min_scale = cfg.INPUT.LSJ_AUG.MIN_SCALE
441
+ max_scale = cfg.INPUT.LSJ_AUG.MAX_SCALE
442
+
443
+ if cfg.INPUT.RANDOM_FLIP != "none":
444
+ if cfg.INPUT.RANDOM_FLIP == "flip_by_clip":
445
+ flip_clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM
446
+ else:
447
+ flip_clip_frame_cnt = 1
448
+
449
+ aug_list.append(
450
+ # NOTE using RandomFlip modified for the support of flip maintenance
451
+ RandomFlipClip(
452
+ horizontal=(cfg.INPUT.RANDOM_FLIP == "horizontal") or (cfg.INPUT.RANDOM_FLIP == "flip_by_clip"),
453
+ vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
454
+ clip_frame_cnt=flip_clip_frame_cnt,
455
+ )
456
+ )
457
+
458
+ aug_list.extend([
459
+ T.ResizeScale(
460
+ min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
461
+ ),
462
+ T.FixedSizeCrop(crop_size=(image_size, image_size)),
463
+ ])
464
+
465
+ else:
466
+ min_size = cfg.INPUT.MIN_SIZE_TRAIN
467
+ max_size = cfg.INPUT.MAX_SIZE_TRAIN
468
+ sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
469
+ clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM if "by_clip" in cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING else 1
470
+
471
+ # Crop
472
+ if cfg.INPUT.CROP.ENABLED:
473
+ crop_aug = RandomApplyClip(
474
+ T.AugmentationList([
475
+ ResizeShortestEdgeClip([400, 500, 600], 1333, sample_style, clip_frame_cnt=clip_frame_cnt),
476
+ RandomCropClip(cfg.INPUT.PSEUDO.CROP.TYPE, cfg.INPUT.PSEUDO.CROP.SIZE, clip_frame_cnt=clip_frame_cnt)
477
+ ]),
478
+ clip_frame_cnt=clip_frame_cnt
479
+ )
480
+ aug_list.append(crop_aug)
481
+
482
+ # Resize
483
+ aug_list.append(ResizeShortestEdgeClip(min_size, max_size, sample_style, clip_frame_cnt=clip_frame_cnt))
484
+
485
+ # Flip
486
+ if cfg.INPUT.RANDOM_FLIP != "none":
487
+ if cfg.INPUT.RANDOM_FLIP == "flip_by_clip":
488
+ flip_clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM
489
+ else:
490
+ flip_clip_frame_cnt = 1
491
+
492
+ aug_list.append(
493
+ # NOTE using RandomFlip modified for the support of flip maintenance
494
+ RandomFlipClip(
495
+ horizontal=(cfg.INPUT.RANDOM_FLIP == "horizontal") or (cfg.INPUT.RANDOM_FLIP == "flip_by_clip"),
496
+ vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
497
+ clip_frame_cnt=flip_clip_frame_cnt,
498
+ )
499
+ )
500
+
501
+ # Additional augmentations : brightness, contrast, saturation, rotation
502
+ augmentations = cfg.INPUT.AUGMENTATIONS
503
+ if "brightness" in augmentations:
504
+ aug_list.append(T.RandomBrightness(0.9, 1.1))
505
+ if "contrast" in augmentations:
506
+ aug_list.append(T.RandomContrast(0.9, 1.1))
507
+ if "saturation" in augmentations:
508
+ aug_list.append(T.RandomSaturation(0.9, 1.1))
509
+ if "rotation" in augmentations:
510
+ aug_list.append(
511
+ T.RandomRotation(
512
+ [-15, 15], expand=False, center=[(0.4, 0.4), (0.6, 0.6)], sample_style="range"
513
+ )
514
+ )
515
+ else:
516
+ # Resize
517
+ min_size = cfg.INPUT.MIN_SIZE_TEST
518
+ max_size = cfg.INPUT.MAX_SIZE_TEST
519
+ sample_style = "choice"
520
+ aug_list.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
521
+
522
+ return aug_list
523
+
524
+
525
+ def build_pseudo_augmentation(cfg, is_train):
526
+ logger = logging.getLogger(__name__)
527
+ aug_list = []
528
+ if is_train:
529
+ use_lsj = cfg.INPUT.LSJ_AUG.ENABLED
530
+ if use_lsj:
531
+ image_size = cfg.INPUT.LSJ_AUG.IMAGE_SIZE
532
+ min_scale = cfg.INPUT.LSJ_AUG.MIN_SCALE
533
+ max_scale = cfg.INPUT.LSJ_AUG.MAX_SCALE
534
+
535
+ if cfg.INPUT.RANDOM_FLIP != "none":
536
+ if cfg.INPUT.RANDOM_FLIP == "flip_by_clip":
537
+ clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM
538
+ else:
539
+ clip_frame_cnt = 1
540
+
541
+ aug_list.append(
542
+ # NOTE using RandomFlip modified for the support of flip maintenance
543
+ RandomFlipClip(
544
+ horizontal=(cfg.INPUT.RANDOM_FLIP == "horizontal") or (cfg.INPUT.RANDOM_FLIP == "flip_by_clip"),
545
+ vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
546
+ clip_frame_cnt=clip_frame_cnt,
547
+ )
548
+ )
549
+
550
+ # Additional augmentations : brightness, contrast, saturation, rotation
551
+ augmentations = cfg.INPUT.PSEUDO.AUGMENTATIONS
552
+ if "brightness" in augmentations:
553
+ aug_list.append(T.RandomBrightness(0.9, 1.1))
554
+ if "contrast" in augmentations:
555
+ aug_list.append(T.RandomContrast(0.9, 1.1))
556
+ if "saturation" in augmentations:
557
+ aug_list.append(T.RandomSaturation(0.9, 1.1))
558
+ if "rotation" in augmentations:
559
+ aug_list.append(
560
+ RandomRotationClip(
561
+ [-15, 15], expand=False, center=[(0.4, 0.4), (0.6, 0.6)], clip_frame_cnt=clip_frame_cnt,
562
+ )
563
+ )
564
+
565
+ aug_list.extend([
566
+ ResizeScaleClip(
567
+ min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size,
568
+ clip_frame_cnt=clip_frame_cnt,
569
+ ),
570
+ FixedSizeCropClip(crop_size=(image_size, image_size), clip_frame_cnt=clip_frame_cnt),
571
+ ])
572
+ else:
573
+ min_size = cfg.INPUT.PSEUDO.MIN_SIZE_TRAIN
574
+ max_size = cfg.INPUT.PSEUDO.MAX_SIZE_TRAIN
575
+ sample_style = cfg.INPUT.PSEUDO.MIN_SIZE_TRAIN_SAMPLING
576
+ clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM
577
+
578
+ # Crop
579
+ if cfg.INPUT.PSEUDO.CROP.ENABLED:
580
+ crop_aug = RandomApplyClip(
581
+ T.AugmentationList([
582
+ ResizeShortestEdgeClip([400, 500, 600], 1333, sample_style, clip_frame_cnt=clip_frame_cnt),
583
+ RandomCropClip(cfg.INPUT.PSEUDO.CROP.TYPE, cfg.INPUT.PSEUDO.CROP.SIZE, clip_frame_cnt=clip_frame_cnt)
584
+ ]),
585
+ clip_frame_cnt=clip_frame_cnt
586
+ )
587
+ aug_list.append(crop_aug)
588
+
589
+ # Resize
590
+ aug_list.append(ResizeShortestEdgeClip(min_size, max_size, sample_style, clip_frame_cnt=clip_frame_cnt))
591
+
592
+ # Flip
593
+ aug_list.append(
594
+ # NOTE using RandomFlip modified for the support of flip maintenance
595
+ RandomFlipClip(
596
+ horizontal=(cfg.INPUT.RANDOM_FLIP == "horizontal") or (cfg.INPUT.RANDOM_FLIP == "flip_by_clip"),
597
+ vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
598
+ clip_frame_cnt=clip_frame_cnt,
599
+ )
600
+ )
601
+
602
+ # Additional augmentations : brightness, contrast, saturation, rotation
603
+ augmentations = cfg.INPUT.PSEUDO.AUGMENTATIONS
604
+ if "brightness" in augmentations:
605
+ aug_list.append(T.RandomBrightness(0.9, 1.1))
606
+ if "contrast" in augmentations:
607
+ aug_list.append(T.RandomContrast(0.9, 1.1))
608
+ if "saturation" in augmentations:
609
+ aug_list.append(T.RandomSaturation(0.9, 1.1))
610
+ if "rotation" in augmentations:
611
+ aug_list.append(
612
+ RandomRotationClip(
613
+ [-15, 15], expand=False, center=[(0.4, 0.4), (0.6, 0.6)], clip_frame_cnt=clip_frame_cnt,
614
+ )
615
+ )
616
+ else:
617
+ # Resize
618
+ min_size = cfg.INPUT.MIN_SIZE_TEST
619
+ max_size = cfg.INPUT.MAX_SIZE_TEST
620
+ sample_style = "choice"
621
+ aug_list.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
622
+
623
+ return aug_list
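Usage sketch (illustrative, not definitive): the *_Clip augmentations above keep one random draw alive for `clip_frame_cnt` consecutive calls, which is what makes every frame of a clip receive the same geometric decision. The snippet below assumes detectron2 and numpy are installed, that these classes are importable from `avism.data.augmentation`, and uses made-up frame sizes and clip length.

    import numpy as np
    from detectron2.data import transforms as T
    from avism.data.augmentation import ResizeShortestEdgeClip, RandomFlipClip

    clip_len = 3  # illustrative clip length
    frames = [np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) for _ in range(clip_len)]

    augs = T.AugmentationList([
        # "range_by_clip": the shortest-edge size is drawn once and reused for the whole clip
        ResizeShortestEdgeClip([360, 480], 768, "range_by_clip", clip_frame_cnt=clip_len),
        RandomFlipClip(prob=0.5, horizontal=True, clip_frame_cnt=clip_len),
    ])

    out_frames = []
    for frame in frames:
        aug_input = T.AugInput(frame)
        augs(aug_input)                    # same flip / resize decision across the clip
        out_frames.append(aug_input.image)

This presumably mirrors how the list returned by build_augmentation is applied frame by frame in the dataset mapper.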
avism/data/avis_eval.py ADDED
@@ -0,0 +1,203 @@
1
+ import os
2
+ import io
3
+ import copy
4
+ import json
5
+ import logging
6
+ import contextlib
7
+ from collections import OrderedDict
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ import pycocotools.mask as mask_util
13
+ from multiprocessing import freeze_support
15
+ from detectron2.data import MetadataCatalog
16
+ from detectron2.utils.file_io import PathManager
17
+ from detectron2.evaluation import DatasetEvaluator
18
+
19
+ from .datasets.avis_api.avos import AVOS
20
+
21
+ import sys
22
+ sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
23
+ import aviseval
24
+
25
+
26
+ def eval_track(out_dir, gt_file):
27
+ freeze_support()
28
+
29
+ # Command line interface:
30
+ default_eval_config = aviseval.Evaluator.get_default_eval_config()
31
+ default_dataset_config = aviseval.datasets.AVIS.get_default_dataset_config()
32
+ default_dataset_config['TRACKERS_FOLDER'] = out_dir
33
+ default_dataset_config['GT_File'] = gt_file
34
+ default_metrics_config = {'METRICS': ['TrackMAP', 'HOTA']} # 'CLEAR', 'Identity'
35
+ config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs
36
+ eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()}
37
+ dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()}
38
+ metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()}
39
+
40
+ # Run code
41
+ evaluator = aviseval.Evaluator(eval_config)
42
+ dataset_list = [aviseval.datasets.AVIS(dataset_config)]
43
+ metrics_list = []
44
+ for metric in [aviseval.metrics.TrackMAP, aviseval.metrics.HOTA]:
45
+ if metric.get_name() in metrics_config['METRICS']:
46
+ if metric == aviseval.metrics.TrackMAP:
47
+ default_track_map_config = metric.get_default_metric_config()
48
+ default_track_map_config['USE_TIME_RANGES'] = False
49
+ default_track_map_config['AREA_RANGES'] = [[0 ** 2, 128 ** 2],
50
+ [128 ** 2, 256 ** 2],
51
+ [256 ** 2, 1e5 ** 2]]
52
+ metrics_list.append(metric(default_track_map_config))
53
+ else:
54
+ metrics_list.append(metric())
55
+ if len(metrics_list) == 0:
56
+ raise Exception('No metrics selected for evaluation')
57
+
58
+ output_res, output_msg = evaluator.evaluate(dataset_list, metrics_list)
59
+
60
+ return output_res
61
+
62
+
63
+ def instances_to_coco_json_video(inputs, outputs):
64
+ """
+ Convert the predictions for one video into the COCO/AVIS-style json entries used for evaluation.
+
+ Args:
+     inputs (list[dict]): a single-element list with the dataset dict of the video
+         (must provide "video_id" and "length").
+     outputs (dict): model outputs with keys "pred_scores", "pred_labels" and "pred_masks".
+
+ Returns:
+     list[dict]: list of json annotations in AVIS (video COCO) format.
+ """
74
+ assert len(inputs) == 1, "More than one input was loaded for inference!"
75
+
76
+ video_id = inputs[0]["video_id"]
77
+ video_length = inputs[0]["length"]
78
+
79
+ scores = outputs["pred_scores"]
80
+ labels = outputs["pred_labels"]
81
+ masks = outputs["pred_masks"]
82
+
83
+ avis_results = []
84
+ for instance_id, (s, l, m) in enumerate(zip(scores, labels, masks)):
85
+ segms = [
86
+ mask_util.encode(np.array(_mask[:, :, None], order="F", dtype="uint8"))[0]
87
+ for _mask in m
88
+ ]
89
+ for rle in segms:
90
+ rle["counts"] = rle["counts"].decode("utf-8")
91
+
92
+ res = {
93
+ "video_id": video_id,
94
+ "score": s,
95
+ "category_id": l,
96
+ "segmentations": segms,
97
+ }
98
+ avis_results.append(res)
99
+
100
+ return avis_results
101
+
102
+
103
+
104
+ class AVISEvaluator(DatasetEvaluator):
105
+ def __init__(
106
+ self,
107
+ dataset_name,
108
+ tasks=None,
109
+ distributed=True,
110
+ output_dir=None,
111
+ *,
112
+ use_fast_impl=True,
113
+ ):
114
+ self._logger = logging.getLogger(__name__)
115
+ self._distributed = distributed
116
+ self._output_dir = output_dir
117
+ self._use_fast_impl = use_fast_impl
118
+
119
+ self._cpu_device = torch.device("cpu")
120
+
121
+ self.dataset_name = dataset_name
122
+ self._metadata = MetadataCatalog.get(dataset_name)
123
+
124
+ json_file = PathManager.get_local_path(self._metadata.json_file)
125
+ with contextlib.redirect_stdout(io.StringIO()):
126
+ self._avis_api = AVOS(json_file)
127
+
128
+ self._do_evaluation = "annotations" in self._avis_api.dataset
129
+
130
+
131
+ def reset(self):
132
+ self._predictions = []
133
+
134
+
135
+ def process(self, inputs, outputs):
136
+ """
137
+ Args:
138
+ inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
139
+ It is a list of dict. Each dict corresponds to an image and
140
+ contains keys like "height", "width", "file_name", "image_id".
141
+ outputs: the outputs of a COCO model. It is a list of dicts with key
142
+ "instances" that contains :class:`Instances`.
143
+ """
144
+ prediction = instances_to_coco_json_video(inputs, outputs)
145
+ self._predictions.extend(prediction)
146
+
147
+
148
+ def evaluate(self):
149
+ """
150
+ Args:
151
+ img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
152
+ """
153
+
154
+ predictions = self._predictions
155
+
156
+ self._results = OrderedDict()
157
+ self._eval_predictions(predictions)
158
+ # Copy so the caller can do whatever with results
159
+ return copy.deepcopy(self._results)
160
+
161
+
162
+ def _eval_predictions(self, predictions):
163
+ """
164
+ Evaluate predictions. Fill self._results with the metrics of the tasks.
165
+ """
166
+ self._logger.info("Preparing results for AVIS format ...")
167
+
168
+ # unmap the category ids for COCO
169
+ if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
170
+ dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
171
+
172
+ all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
173
+ num_classes = len(all_contiguous_ids)
174
+ assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
175
+
176
+ reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
177
+ for result in predictions:
178
+ category_id = result["category_id"]
179
+ assert category_id < num_classes, (
180
+ f"A prediction has class={category_id}, "
181
+ f"but the dataset only has {num_classes} classes and "
182
+ f"predicted class id should be in [0, {num_classes - 1}]."
183
+ )
184
+ result["category_id"] = reverse_id_mapping[category_id]
185
+
186
+ o_d = None
187
+ if self._output_dir:
188
+ o_d = os.path.join(self._output_dir, "results")
189
+ os.makedirs(os.path.join(o_d, "model_final"), exist_ok=True)
190
+ file_path = os.path.join(o_d, "model_final", "results.json")
191
+
192
+ self._logger.info("Saving results to {}".format(file_path))
193
+ with PathManager.open(file_path, "w") as f:
194
+ f.write(json.dumps(predictions))
195
+ f.flush()
196
+
197
+ if not self._do_evaluation:
198
+ self._logger.info("Annotations are not available for evaluation.")
199
+ return
200
+
201
+ assert o_d is not None
202
+ output_res = eval_track(o_d, "test.json")
203
+ self._results["segm"] = output_res['AVIS']['model_final']
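Usage sketch (illustrative): eval_track treats out_dir as a TrackEval-style trackers folder, i.e. one sub-directory per tracker holding a single json of predictions, which is exactly the results/model_final/results.json layout written by _eval_predictions above. The paths below are placeholders and real RLE-encoded predictions are needed for the scores to be meaningful.

    import json, os
    from avism.data.avis_eval import eval_track

    out_dir = "./outputs/avism_R50_IN/inference/results"             # placeholder trackers folder
    pred_file = os.path.join(out_dir, "model_final", "results.json")
    os.makedirs(os.path.dirname(pred_file), exist_ok=True)

    predictions = []  # normally the list built by instances_to_coco_json_video(...)
    with open(pred_file, "w") as f:
        json.dump(predictions, f)

    # "test.json" is resolved against the dataset's GT_FOLDER ("./datasets/" by default)
    output_res = eval_track(out_dir, "test.json")
    print(output_res["AVIS"]["model_final"])  # per-metric results for this tracker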
avism/data/aviseval/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .eval import Evaluator
+ from . import datasets
+ from . import metrics
+ from . import plotting
+ from . import utils
avism/data/aviseval/_timing.py ADDED
@@ -0,0 +1,65 @@
1
+ from functools import wraps
2
+ from time import perf_counter
3
+ import inspect
4
+
5
+ DO_TIMING = False
6
+ DISPLAY_LESS_PROGRESS = False
7
+ timer_dict = {}
8
+ counter = 0
9
+
10
+
11
+ def time(f):
12
+ @wraps(f)
13
+ def wrap(*args, **kw):
14
+ if DO_TIMING:
15
+ # Run function with timing
16
+ ts = perf_counter()
17
+ result = f(*args, **kw)
18
+ te = perf_counter()
19
+ tt = te-ts
20
+
21
+ # Get function name
22
+ arg_names = inspect.getfullargspec(f)[0]
23
+ if arg_names[0] == 'self' and DISPLAY_LESS_PROGRESS:
24
+ return result
25
+ elif arg_names[0] == 'self':
26
+ method_name = type(args[0]).__name__ + '.' + f.__name__
27
+ else:
28
+ method_name = f.__name__
29
+
30
+ # Record accumulative time in each function for analysis
31
+ if method_name in timer_dict.keys():
32
+ timer_dict[method_name] += tt
33
+ else:
34
+ timer_dict[method_name] = tt
35
+
36
+ # If code is finished, display timing summary
37
+ if method_name == "Evaluator.evaluate":
38
+ print("")
39
+ print("Timing analysis:")
40
+ for key, value in timer_dict.items():
41
+ print('%-70s %2.4f sec' % (key, value))
42
+ else:
43
+ # Get function argument values for printing special arguments of interest
44
+ arg_titles = ['tracker', 'seq', 'cls']
45
+ arg_vals = []
46
+ for i, a in enumerate(arg_names):
47
+ if a in arg_titles:
48
+ arg_vals.append(args[i])
49
+ arg_text = '(' + ', '.join(arg_vals) + ')'
50
+
51
+ # Display methods and functions with different indentation.
52
+ if arg_names[0] == 'self':
53
+ print('%-74s %2.4f sec' % (' '*4 + method_name + arg_text, tt))
54
+ elif arg_names[0] == 'test':
55
+ pass
56
+ else:
57
+ global counter
58
+ counter += 1
59
+ print('%i %-70s %2.4f sec' % (counter, method_name + arg_text, tt))
60
+
61
+ return result
62
+ else:
63
+ # If config["TIME_PROGRESS"] is false, or config["USE_PARALLEL"] is true, run functions normally without timing.
64
+ return f(*args, **kw)
65
+ return wrap
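Usage sketch (illustrative): the decorator is a no-op unless the module-level DO_TIMING flag is switched on (per the comment above, the evaluator only enables it when TIME_PROGRESS is true and USE_PARALLEL is false). The snippet assumes the module is importable as avism.data.aviseval._timing.

    from avism.data.aviseval import _timing

    _timing.DO_TIMING = True  # normally controlled by the evaluator's config

    @_timing.time
    def load_sequence(seq):
        return sum(i * i for i in range(100_000))  # stand-in for real work

    load_sequence("video_0001")
    # prints a line like "1 load_sequence(video_0001)  0.0xx sec", because "seq"
    # is one of the argument names ('tracker', 'seq', 'cls') the wrapper displays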
avism/data/aviseval/datasets/__init__.py ADDED
@@ -0,0 +1 @@
+ from .avis import AVIS
avism/data/aviseval/datasets/_base_dataset.py ADDED
@@ -0,0 +1,326 @@
1
+ import csv
2
+ import io
3
+ import zipfile
4
+ import os
5
+ import traceback
6
+ import numpy as np
7
+ from copy import deepcopy
8
+ from abc import ABC, abstractmethod
9
+ from .. import _timing
10
+ from ..utils import TrackEvalException
11
+
12
+
13
+ class _BaseDataset(ABC):
14
+ @abstractmethod
15
+ def __init__(self):
16
+ self.tracker_list = None
17
+ self.seq_list = None
18
+ self.class_list = None
19
+ self.output_fol = None
20
+ self.output_sub_fol = None
21
+ self.should_classes_combine = True
22
+ self.use_super_categories = False
23
+
24
+ # Functions to implement:
25
+
26
+ @staticmethod
27
+ @abstractmethod
28
+ def get_default_dataset_config():
29
+ ...
30
+
31
+ @abstractmethod
32
+ def _load_raw_file(self, tracker, seq, is_gt):
33
+ ...
34
+
35
+ @_timing.time
36
+ @abstractmethod
37
+ def get_preprocessed_seq_data(self, raw_data, cls):
38
+ ...
39
+
40
+ @abstractmethod
41
+ def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
42
+ ...
43
+
44
+ # Helper functions for all datasets:
45
+
46
+ @classmethod
47
+ def get_class_name(cls):
48
+ return cls.__name__
49
+
50
+ def get_name(self):
51
+ return self.get_class_name()
52
+
53
+ def get_output_fol(self, tracker):
54
+ return os.path.join(self.output_fol, tracker, self.output_sub_fol)
55
+
56
+ def get_display_name(self, tracker):
57
+ """ Can be overwritten if the trackers name (in files) is different to how it should be displayed.
58
+ By default this method just returns the trackers name as is.
59
+ """
60
+ return tracker
61
+
62
+ def get_eval_info(self):
63
+ """Return info about the dataset needed for the Evaluator"""
64
+ return self.tracker_list, self.seq_list, self.class_list
65
+
66
+ @_timing.time
67
+ def get_raw_seq_data(self, tracker, seq):
68
+ """ Loads raw data (tracker and ground-truth) for a single tracker on a single sequence.
69
+ Raw data includes all of the information needed for both preprocessing and evaluation, for all classes.
70
+ A later function (get_processed_seq_data) will perform such preprocessing and extract relevant information for
71
+ the evaluation of each class.
72
+
73
+ This returns a dict which contains the fields:
74
+ [num_timesteps]: integer
75
+ [gt_ids, tracker_ids, gt_classes, tracker_classes, tracker_confidences]:
76
+ list (for each timestep) of 1D NDArrays (for each det).
77
+ [gt_dets, tracker_dets, gt_crowd_ignore_regions]: list (for each timestep) of lists of detections.
78
+ [similarity_scores]: list (for each timestep) of 2D NDArrays.
79
+ [gt_extras]: dict (for each extra) of lists (for each timestep) of 1D NDArrays (for each det).
80
+
81
+ gt_extras contains dataset specific information used for preprocessing such as occlusion and truncation levels.
82
+
83
+ Note that similarities are extracted as part of the dataset and not the metric, because almost all metrics are
84
+ independent of the exact method of calculating the similarity. However datasets are not (e.g. segmentation
85
+ masks vs 2D boxes vs 3D boxes).
86
+ We calculate the similarity before preprocessing because often both preprocessing and evaluation require it and
87
+ we don't wish to calculate this twice.
88
+ We calculate similarity between all gt and tracker classes (not just each class individually) to allow for
89
+ calculation of metrics such as class confusion matrices. Typically the impact of this on performance is low.
90
+ """
91
+ # Load raw data.
92
+ raw_gt_data = self._load_raw_file(tracker, seq, is_gt=True)
93
+ raw_tracker_data = self._load_raw_file(tracker, seq, is_gt=False)
94
+ raw_data = {**raw_tracker_data, **raw_gt_data} # Merges dictionaries
95
+
96
+ # Calculate similarities for each timestep.
97
+ similarity_scores = []
98
+ for t, (gt_dets_t, tracker_dets_t) in enumerate(zip(raw_data['gt_dets'], raw_data['tracker_dets'])):
99
+ ious = self._calculate_similarities(gt_dets_t, tracker_dets_t)
100
+ similarity_scores.append(ious)
101
+ raw_data['similarity_scores'] = similarity_scores
102
+ return raw_data
103
+
104
+ @staticmethod
105
+ def _load_simple_text_file(file, time_col=0, id_col=None, remove_negative_ids=False, valid_filter=None,
106
+ crowd_ignore_filter=None, convert_filter=None, is_zipped=False, zip_file=None,
107
+ force_delimiters=None):
108
+ """ Function that loads data which is in a commonly used text file format.
109
+ Assumes each det is given by one row of a text file.
110
+ There is no limit to the number or meaning of each column,
111
+ however one column needs to give the timestep of each det (time_col) which is default col 0.
112
+
113
+ The file dialect (deliminator, num cols, etc) is determined automatically.
114
+ This function automatically separates dets by timestep,
115
+ and is much faster than alternatives such as np.loadtext or pandas.
116
+
117
+ If remove_negative_ids is True and id_col is not None, dets with negative values in id_col are excluded.
118
+ These are not excluded from ignore data.
119
+
120
+ valid_filter can be used to only include certain classes.
121
+ It is a dict with ints as keys, and lists as values,
122
+ such that a row is included if "row[key].lower() is in value" for all key/value pairs in the dict.
123
+ If None, all classes are included.
124
+
125
+ crowd_ignore_filter can be used to read crowd_ignore regions separately. It has the same format as valid filter.
126
+
127
+ convert_filter can be used to convert value read to another format.
128
+ This is used most commonly to convert classes given as string to a class id.
129
+ This is a dict such that the key is the column to convert, and the value is another dict giving the mapping.
130
+
131
+ Optionally, input files could be a zip of multiple text files for storage efficiency.
132
+
133
+ Returns read_data and ignore_data.
134
+ Each is a dict (with keys as timesteps as strings) of lists (over dets) of lists (over column values).
135
+ Note that all data is returned as strings, and must be converted to float/int later if needed.
136
+ Note that timesteps will not be present in the returned dict keys if there are no dets for them
137
+ """
138
+
139
+ if remove_negative_ids and id_col is None:
140
+ raise TrackEvalException('remove_negative_ids is True, but id_col is not given.')
141
+ if crowd_ignore_filter is None:
142
+ crowd_ignore_filter = {}
143
+ if convert_filter is None:
144
+ convert_filter = {}
145
+ try:
146
+ if is_zipped: # Either open file directly or within a zip.
147
+ if zip_file is None:
148
+ raise TrackEvalException('is_zipped set to True, but no zip_file is given.')
149
+ archive = zipfile.ZipFile(os.path.join(zip_file), 'r')
150
+ fp = io.TextIOWrapper(archive.open(file, 'r'))
151
+ else:
152
+ fp = open(file)
153
+ read_data = {}
154
+ crowd_ignore_data = {}
155
+ fp.seek(0, os.SEEK_END)
156
+ # check if file is empty
157
+ if fp.tell():
158
+ fp.seek(0)
159
+ dialect = csv.Sniffer().sniff(fp.readline(), delimiters=force_delimiters) # Auto determine structure.
160
+ dialect.skipinitialspace = True # Deal with extra spaces between columns
161
+ fp.seek(0)
162
+ reader = csv.reader(fp, dialect)
163
+ for row in reader:
164
+ try:
165
+ # Deal with extra trailing spaces at the end of rows
166
+ if row[-1] in '':
167
+ row = row[:-1]
168
+ timestep = str(int(float(row[time_col])))
169
+ # Read ignore regions separately.
170
+ is_ignored = False
171
+ for ignore_key, ignore_value in crowd_ignore_filter.items():
172
+ if row[ignore_key].lower() in ignore_value:
173
+ # Convert values in one column (e.g. string to id)
174
+ for convert_key, convert_value in convert_filter.items():
175
+ row[convert_key] = convert_value[row[convert_key].lower()]
176
+ # Save data separated by timestep.
177
+ if timestep in crowd_ignore_data.keys():
178
+ crowd_ignore_data[timestep].append(row)
179
+ else:
180
+ crowd_ignore_data[timestep] = [row]
181
+ is_ignored = True
182
+ if is_ignored: # if det is an ignore region, it cannot be a normal det.
183
+ continue
184
+ # Exclude some dets if not valid.
185
+ if valid_filter is not None:
186
+ for key, value in valid_filter.items():
187
+ if row[key].lower() not in value:
188
+ continue
189
+ if remove_negative_ids:
190
+ if int(float(row[id_col])) < 0:
191
+ continue
192
+ # Convert values in one column (e.g. string to id)
193
+ for convert_key, convert_value in convert_filter.items():
194
+ row[convert_key] = convert_value[row[convert_key].lower()]
195
+ # Save data separated by timestep.
196
+ if timestep in read_data.keys():
197
+ read_data[timestep].append(row)
198
+ else:
199
+ read_data[timestep] = [row]
200
+ except Exception:
201
+ exc_str_init = 'In file %s the following line cannot be read correctly: \n' % os.path.basename(
202
+ file)
203
+ exc_str = ' '.join([exc_str_init]+row)
204
+ raise TrackEvalException(exc_str)
205
+ fp.close()
206
+ except Exception:
207
+ print('Error loading file: %s, printing traceback.' % file)
208
+ traceback.print_exc()
209
+ raise TrackEvalException(
210
+ 'File %s cannot be read because it is either not present or invalidly formatted' % os.path.basename(
211
+ file))
212
+ return read_data, crowd_ignore_data
213
+
214
+ @staticmethod
215
+ def _calculate_mask_ious(masks1, masks2, is_encoded=False, do_ioa=False):
216
+ """ Calculates the IOU (intersection over union) between two arrays of segmentation masks.
217
+ If is_encoded a run length encoding with pycocotools is assumed as input format, otherwise an input of numpy
218
+ arrays of the shape (num_masks, height, width) is assumed and the encoding is performed.
219
+ If do_ioa (intersection over area) , then calculates the intersection over the area of masks1 - this is commonly
220
+ used to determine if detections are within crowd ignore region.
221
+ :param masks1: first set of masks (numpy array of shape (num_masks, height, width) if not encoded,
222
+ else pycocotools rle encoded format)
223
+ :param masks2: second set of masks (numpy array of shape (num_masks, height, width) if not encoded,
224
+ else pycocotools rle encoded format)
225
+ :param is_encoded: whether the input is in pycocotools rle encoded format
226
+ :param do_ioa: whether to perform IoA computation
227
+ :return: the IoU/IoA scores
228
+ """
229
+
230
+ # Only loaded when run to reduce minimum requirements
231
+ from pycocotools import mask as mask_utils
232
+
233
+ # use pycocotools for run length encoding of masks
234
+ if not is_encoded:
235
+ masks1 = mask_utils.encode(np.array(np.transpose(masks1, (1, 2, 0)), order='F'))
236
+ masks2 = mask_utils.encode(np.array(np.transpose(masks2, (1, 2, 0)), order='F'))
237
+
238
+ # use pycocotools for iou computation of rle encoded masks
239
+ ious = mask_utils.iou(masks1, masks2, [do_ioa]*len(masks2))
240
+ if len(masks1) == 0 or len(masks2) == 0:
241
+ ious = np.asarray(ious).reshape(len(masks1), len(masks2))
242
+ assert (ious >= 0 - np.finfo('float').eps).all()
243
+ assert (ious <= 1 + np.finfo('float').eps).all()
244
+
245
+ return ious
246
+
247
+ @staticmethod
248
+ def _calculate_box_ious(bboxes1, bboxes2, box_format='xywh', do_ioa=False):
249
+ """ Calculates the IOU (intersection over union) between two arrays of boxes.
250
+ Allows variable box formats ('xywh' and 'x0y0x1y1').
251
+ If do_ioa (intersection over area) , then calculates the intersection over the area of boxes1 - this is commonly
252
+ used to determine if detections are within crowd ignore region.
253
+ """
254
+ if box_format in 'xywh':
255
+ # layout: (x0, y0, w, h)
256
+ bboxes1 = deepcopy(bboxes1)
257
+ bboxes2 = deepcopy(bboxes2)
258
+
259
+ bboxes1[:, 2] = bboxes1[:, 0] + bboxes1[:, 2]
260
+ bboxes1[:, 3] = bboxes1[:, 1] + bboxes1[:, 3]
261
+ bboxes2[:, 2] = bboxes2[:, 0] + bboxes2[:, 2]
262
+ bboxes2[:, 3] = bboxes2[:, 1] + bboxes2[:, 3]
263
+ elif box_format not in 'x0y0x1y1':
264
+ raise (TrackEvalException('box_format %s is not implemented' % box_format))
265
+
266
+ # layout: (x0, y0, x1, y1)
267
+ min_ = np.minimum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
268
+ max_ = np.maximum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
269
+ intersection = np.maximum(min_[..., 2] - max_[..., 0], 0) * np.maximum(min_[..., 3] - max_[..., 1], 0)
270
+ area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
271
+
272
+ if do_ioa:
273
+ ioas = np.zeros_like(intersection)
274
+ valid_mask = area1 > 0 + np.finfo('float').eps
275
+ ioas[valid_mask, :] = intersection[valid_mask, :] / area1[valid_mask][:, np.newaxis]
276
+
277
+ return ioas
278
+ else:
279
+ area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
280
+ union = area1[:, np.newaxis] + area2[np.newaxis, :] - intersection
281
+ intersection[area1 <= 0 + np.finfo('float').eps, :] = 0
282
+ intersection[:, area2 <= 0 + np.finfo('float').eps] = 0
283
+ intersection[union <= 0 + np.finfo('float').eps] = 0
284
+ union[union <= 0 + np.finfo('float').eps] = 1
285
+ ious = intersection / union
286
+ return ious
287
+
288
+ @staticmethod
289
+ def _calculate_euclidean_similarity(dets1, dets2, zero_distance=2.0):
290
+ """ Calculates the euclidean distance between two sets of detections, and then converts this into a similarity
291
+ measure with values between 0 and 1 using the following formula: sim = max(0, 1 - dist/zero_distance).
292
+ The default zero_distance of 2.0, corresponds to the default used in MOT15_3D, such that a 0.5 similarity
293
+ threshold corresponds to a 1m distance threshold for TPs.
294
+ """
295
+ dist = np.linalg.norm(dets1[:, np.newaxis]-dets2[np.newaxis, :], axis=2)
296
+ sim = np.maximum(0, 1 - dist/zero_distance)
297
+ return sim
298
+
299
+ @staticmethod
300
+ def _check_unique_ids(data, after_preproc=False):
301
+ """Check the requirement that the tracker_ids and gt_ids are unique per timestep"""
302
+ gt_ids = data['gt_ids']
303
+ tracker_ids = data['tracker_ids']
304
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(gt_ids, tracker_ids)):
305
+ if len(tracker_ids_t) > 0:
306
+ unique_ids, counts = np.unique(tracker_ids_t, return_counts=True)
307
+ if np.max(counts) != 1:
308
+ duplicate_ids = unique_ids[counts > 1]
309
+ exc_str_init = 'Tracker predicts the same ID more than once in a single timestep ' \
310
+ '(seq: %s, frame: %i, ids:' % (data['seq'], t+1)
311
+ exc_str = ' '.join([exc_str_init] + [str(d) for d in duplicate_ids]) + ')'
312
+ if after_preproc:
313
+ exc_str_init += '\n Note that this error occurred after preprocessing (but not before), ' \
314
+ 'so ids may not be as in file, and something seems wrong with preproc.'
315
+ raise TrackEvalException(exc_str)
316
+ if len(gt_ids_t) > 0:
317
+ unique_ids, counts = np.unique(gt_ids_t, return_counts=True)
318
+ if np.max(counts) != 1:
319
+ duplicate_ids = unique_ids[counts > 1]
320
+ exc_str_init = 'Ground-truth has the same ID more than once in a single timestep ' \
321
+ '(seq: %s, frame: %i, ids:' % (data['seq'], t+1)
322
+ exc_str = ' '.join([exc_str_init] + [str(d) for d in duplicate_ids]) + ')'
323
+ if after_preproc:
324
+ exc_str_init += '\n Note that this error occurred after preprocessing (but not before), ' \
325
+ 'so ids may not be as in file, and something seems wrong with preproc.'
326
+ raise TrackEvalException(exc_str)
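Usage sketch (illustrative): the _calculate_* helpers are plain static methods, so they can be sanity-checked on toy data without constructing a dataset. The hand-computed expectations below assume numpy and pycocotools are installed; the import path follows the file location above.

    import numpy as np
    from avism.data.aviseval.datasets._base_dataset import _BaseDataset

    # Box IoU, boxes given as (x0, y0, w, h)
    boxes1 = np.array([[0.0, 0.0, 10.0, 10.0]])
    boxes2 = np.array([[5.0, 5.0, 10.0, 10.0]])
    box_ious = _BaseDataset._calculate_box_ious(boxes1, boxes2, box_format='xywh')
    # 25 px intersection / 175 px union -> ~0.143

    # Mask IoU on binary (num_masks, H, W) arrays; RLE encoding happens internally
    masks1 = np.zeros((1, 16, 16), dtype=np.uint8); masks1[0, :8, :8] = 1
    masks2 = np.zeros((1, 16, 16), dtype=np.uint8); masks2[0, 4:12, :8] = 1
    mask_ious = _BaseDataset._calculate_mask_ious(masks1, masks2, is_encoded=False, do_ioa=False)
    # 32 px intersection / 96 px union -> ~0.333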
avism/data/aviseval/datasets/avis.py ADDED
@@ -0,0 +1,367 @@
1
+ import os
2
+ import numpy as np
3
+ import json
4
+ from ._base_dataset import _BaseDataset
5
+ from ..utils import TrackEvalException
6
+ from .. import utils
7
+ from .. import _timing
8
+
9
+
10
+ class AVIS(_BaseDataset):
11
+ """Dataset class for AVIS tracking"""
12
+
13
+ @staticmethod
14
+ def get_default_dataset_config():
15
+ """Default class config values"""
16
+ default_config = {
17
+ 'GT_FOLDER': "./datasets/", # Location of GT data
18
+ 'TRACKERS_FOLDER': "./outputs/avism_R50_IN/inference/",
19
+ 'GT_File': "test.json",
20
+ # Trackers location
21
+ 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER)
22
+ 'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder)
23
+ 'CLASSES_TO_EVAL': None, # Classes to eval (if None, all classes)
24
+ 'SPLIT_TO_EVAL': None, # Valid: 'train', 'val', 'train_sub_split'
25
+ 'PRINT_CONFIG': False, # Whether to print current config
26
+ 'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
27
+ 'TRACKER_DISPLAY_NAMES': None, # Names of trackers to display, if None: TRACKERS_TO_EVAL
28
+ }
29
+ return default_config
30
+
31
+ def __init__(self, config=None):
32
+ """Initialise dataset, checking that all required files are present"""
33
+ super().__init__()
34
+ # Fill non-given config values with defaults
35
+ self.config = utils.init_config(config, self.get_default_dataset_config(), self.get_name())
36
+ self.gt_fol = self.config['GT_FOLDER']
37
+ self.tracker_fol = self.config['TRACKERS_FOLDER']
38
+ self.use_super_categories = False
39
+ self.should_classes_combine = True
40
+
41
+ self.output_fol = self.config['OUTPUT_FOLDER']
42
+ if self.output_fol is None:
43
+ self.output_fol = self.tracker_fol
44
+ self.output_sub_fol = self.config['OUTPUT_SUB_FOLDER']
45
+
46
+ if not os.path.exists(self.gt_fol):
47
+ print("GT folder not found: " + self.gt_fol)
48
+ raise TrackEvalException("GT folder not found: " + os.path.basename(self.gt_fol))
49
+ gt_dir_files = [self.config['GT_File']]
50
+ if len(gt_dir_files) != 1:
51
+ raise TrackEvalException(self.gt_fol + ' does not contain exactly one json file.')
52
+
53
+ with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f:
54
+ self.gt_data = json.load(f)
55
+
56
+ # Get classes to eval
57
+ self.valid_classes = [cls['name'] for cls in self.gt_data['categories']]
58
+ cls_name_to_cls_id_map = {cls['name']: cls['id'] for cls in self.gt_data['categories']}
59
+
60
+ if self.config['CLASSES_TO_EVAL']:
61
+ self.class_list = [cls.lower() if cls.lower() in self.valid_classes else None
62
+ for cls in self.config['CLASSES_TO_EVAL']]
63
+ if not all(self.class_list):
64
+ raise TrackEvalException('Attempted to evaluate an invalid class. Only classes ' +
65
+ ', '.join(self.valid_classes) + ' are valid.')
66
+ else:
67
+ self.class_list = [cls['name'] for cls in self.gt_data['categories']]
68
+ self.class_name_to_class_id = {k: v for k, v in cls_name_to_cls_id_map.items() if k in self.class_list}
69
+
70
+ # Get sequences to eval and check gt files exist
71
+ self.seq_list = [vid['file_names'][0].split('/')[0] for vid in self.gt_data['videos']]
72
+ self.seq_name_to_seq_id = {vid['file_names'][0].split('/')[0]: vid['id'] for vid in self.gt_data['videos']}
73
+ self.seq_lengths = {vid['id']: len(vid['file_names']) for vid in self.gt_data['videos']}
74
+
75
+ # encode masks and compute track areas
76
+ self._prepare_gt_annotations()
77
+
78
+ # Get trackers to eval
79
+ if self.config['TRACKERS_TO_EVAL'] is None:
80
+ self.tracker_list = os.listdir(self.tracker_fol)
81
+ else:
82
+ self.tracker_list = self.config['TRACKERS_TO_EVAL']
83
+
84
+ if self.config['TRACKER_DISPLAY_NAMES'] is None:
85
+ self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list))
86
+ elif (self.config['TRACKERS_TO_EVAL'] is not None) and (
87
+ len(self.config['TRACKER_DISPLAY_NAMES']) == len(self.tracker_list)):
88
+ self.tracker_to_disp = dict(zip(self.tracker_list, self.config['TRACKER_DISPLAY_NAMES']))
89
+ else:
90
+ raise TrackEvalException('List of tracker files and tracker display names do not match.')
91
+
92
+ # counter for globally unique track IDs
93
+ self.global_tid_counter = 0
94
+
95
+ self.tracker_data = dict()
96
+ for tracker in self.tracker_list:
97
+ tracker_dir_path = os.path.join(self.tracker_fol, tracker)
98
+ tr_dir_files = [file for file in os.listdir(tracker_dir_path) if file.endswith('.json')]
99
+ if len(tr_dir_files) != 1:
100
+ raise TrackEvalException(tracker_dir_path + ' does not contain exactly one json file.')
101
+
102
+ with open(os.path.join(tracker_dir_path, tr_dir_files[0])) as f:
103
+ curr_data = json.load(f)
104
+
105
+ self.tracker_data[tracker] = curr_data
106
+
107
+ def get_display_name(self, tracker):
108
+ return self.tracker_to_disp[tracker]
109
+
110
+ def _load_raw_file(self, tracker, seq, is_gt):
111
+ """Load a file (gt or tracker) in the YouTubeVIS format
112
+ If is_gt, this returns a dict which contains the fields:
113
+ [gt_ids, gt_classes] : list (for each timestep) of 1D NDArrays (for each det).
114
+ [gt_dets]: list (for each timestep) of lists of detections.
115
+ [classes_to_gt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
116
+ keys and corresponding segmentations as values) for each track
117
+ [classes_to_gt_track_ids, classes_to_gt_track_areas, classes_to_gt_track_iscrowd]: dictionary with class values
118
+ as keys and lists (for each track) as values
119
+
120
+ if not is_gt, this returns a dict which contains the fields:
121
+ [tracker_ids, tracker_classes, tracker_confidences] : list (for each timestep) of 1D NDArrays (for each det).
122
+ [tracker_dets]: list (for each timestep) of lists of detections.
123
+ [classes_to_dt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
124
+ keys and corresponding segmentations as values) for each track
125
+ [classes_to_dt_track_ids, classes_to_dt_track_areas]: dictionary with class values as keys and lists as values
126
+ [classes_to_dt_track_scores]: dictionary with class values as keys and 1D numpy arrays as values
127
+ """
128
+ # select sequence tracks
129
+ seq_id = self.seq_name_to_seq_id[seq]
130
+ if is_gt:
131
+ tracks = [ann for ann in self.gt_data['annotations'] if ann['video_id'] == seq_id]
132
+ else:
133
+ tracks = self._get_tracker_seq_tracks(tracker, seq_id)
134
+
135
+ # Convert data to required format
136
+ num_timesteps = self.seq_lengths[seq_id]
137
+ data_keys = ['ids', 'classes', 'dets']
138
+ if not is_gt:
139
+ data_keys += ['tracker_confidences']
140
+ raw_data = {key: [None] * num_timesteps for key in data_keys}
141
+ raw_data['raw_dets'] = [None] * num_timesteps
142
+ raw_data['raw_classes'] = [None] * num_timesteps
143
+
144
+ for t in range(num_timesteps):
145
+ raw_data['raw_dets'][t] = [track['segmentations'][t] for track in tracks]
146
+ raw_data['raw_classes'][t] = np.atleast_1d([track['category_id'] for track in tracks]).astype(int)
147
+
148
+ raw_data['dets'][t] = [track['segmentations'][t] for track in tracks if track['segmentations'][t]]
149
+ raw_data['ids'][t] = np.atleast_1d([track['id'] for track in tracks if track['segmentations'][t]]).astype(int)
150
+ raw_data['classes'][t] = np.atleast_1d([track['category_id'] for track in tracks if track['segmentations'][t]]).astype(int)
151
+ if not is_gt:
152
+ raw_data['tracker_confidences'][t] = np.atleast_1d([track['score'] for track in tracks if track['segmentations'][t]]).astype(float)
153
+
154
+ if is_gt:
155
+ key_map = {'ids': 'gt_ids',
156
+ 'classes': 'gt_classes',
157
+ 'dets': 'gt_dets'}
158
+ else:
159
+ key_map = {'ids': 'tracker_ids',
160
+ 'classes': 'tracker_classes',
161
+ 'dets': 'tracker_dets'}
162
+ for k, v in key_map.items():
163
+ raw_data[v] = raw_data.pop(k)
164
+
165
+ all_cls_ids = {self.class_name_to_class_id[cls] for cls in self.class_list}
166
+ classes_to_tracks = {cls: [track for track in tracks if track['category_id'] == cls] for cls in all_cls_ids}
167
+
168
+ # mapping from classes to track representations and track information
169
+ raw_data['classes_to_tracks'] = {cls: [{i: track['segmentations'][i]
170
+ for i in range(len(track['segmentations']))} for track in tracks]
171
+ for cls, tracks in classes_to_tracks.items()}
172
+ raw_data['classes_to_track_ids'] = {cls: [track['id'] for track in tracks]
173
+ for cls, tracks in classes_to_tracks.items()}
174
+ raw_data['classes_to_track_areas'] = {cls: [track['area'] for track in tracks]
175
+ for cls, tracks in classes_to_tracks.items()}
176
+
177
+ if is_gt:
178
+ raw_data['classes_to_gt_track_iscrowd'] = {cls: [track['iscrowd'] for track in tracks]
179
+ for cls, tracks in classes_to_tracks.items()}
180
+ else:
181
+ raw_data['classes_to_dt_track_scores'] = {cls: np.array([track['score'] for track in tracks])
182
+ for cls, tracks in classes_to_tracks.items()}
183
+
184
+ if is_gt:
185
+ key_map = {'classes_to_tracks': 'classes_to_gt_tracks',
186
+ 'classes_to_track_ids': 'classes_to_gt_track_ids',
187
+ 'classes_to_track_areas': 'classes_to_gt_track_areas'}
188
+ else:
189
+ key_map = {'classes_to_tracks': 'classes_to_dt_tracks',
190
+ 'classes_to_track_ids': 'classes_to_dt_track_ids',
191
+ 'classes_to_track_areas': 'classes_to_dt_track_areas'}
192
+ for k, v in key_map.items():
193
+ raw_data[v] = raw_data.pop(k)
194
+
195
+ raw_data['num_timesteps'] = num_timesteps
196
+ raw_data['seq'] = seq
197
+ return raw_data
198
+
199
+ @_timing.time
200
+ def get_preprocessed_seq_data(self, raw_data, cls):
201
+ """ Preprocess data for a single sequence for a single class ready for evaluation.
202
+ Inputs:
203
+ - raw_data is a dict containing the data for the sequence already read in by get_raw_seq_data().
204
+ - cls is the class to be evaluated.
205
+ Outputs:
206
+ - data is a dict containing all of the information that metrics need to perform evaluation.
207
+ It contains the following fields:
208
+ [num_timesteps, num_gt_ids, num_tracker_ids, num_gt_dets, num_tracker_dets] : integers.
209
+ [gt_ids, tracker_ids, tracker_confidences]: list (for each timestep) of 1D NDArrays (for each det).
210
+ [gt_dets, tracker_dets]: list (for each timestep) of lists of detections.
211
+ [similarity_scores]: list (for each timestep) of 2D NDArrays.
212
+ Notes:
213
+ General preprocessing (preproc) occurs in 4 steps. Some datasets may not use all of these steps.
214
+ 1) Extract only detections relevant for the class to be evaluated (including distractor detections).
215
+ 2) Match gt dets and tracker dets. Remove tracker dets that are matched to a gt det that is of a
216
+ distractor class, or otherwise marked as to be removed.
217
+ 3) Remove unmatched tracker dets if they fall within a crowd ignore region or don't meet certain
218
+ other criteria (e.g. are too small).
219
+ 4) Remove gt dets that were only useful for preprocessing and not for actual evaluation.
220
+ After the above preprocessing steps, this function also calculates the number of gt and tracker detections
221
+ and unique track ids. It also relabels gt and tracker ids to be contiguous and checks that ids are
222
+ unique within each timestep.
223
+ YouTubeVIS:
224
+ In YouTubeVIS, the 4 preproc steps are as follows:
225
+ 1) There are 40 classes which are evaluated separately.
226
+ 2) No matched tracker dets are removed.
227
+ 3) No unmatched tracker dets are removed.
228
+ 4) No gt dets are removed.
229
+ Further, for TrackMAP computation track representations for the given class are accessed from a dictionary
230
+ and the tracks from the tracker data are sorted according to the tracker confidence.
231
+ """
232
+ cls_id = self.class_name_to_class_id[cls]
233
+
234
+ data_keys = ['gt_ids', 'tracker_ids', 'gt_dets', 'tracker_dets', 'similarity_scores']
235
+ data = {key: [None] * raw_data['num_timesteps'] for key in data_keys}
236
+ unique_gt_ids = []
237
+ unique_tracker_ids = []
238
+ num_gt_dets = 0
239
+ num_tracker_dets = 0
240
+
241
+ for t in range(raw_data['num_timesteps']):
242
+
243
+ # Only extract relevant dets for this class for eval (cls)
244
+ gt_class_mask = np.atleast_1d(raw_data['gt_classes'][t] == cls_id)
245
+ gt_class_mask = gt_class_mask.astype(bool)
246
+ gt_ids = raw_data['gt_ids'][t][gt_class_mask]
247
+ gt_dets = [raw_data['gt_dets'][t][ind] for ind in range(len(gt_class_mask)) if gt_class_mask[ind]]
248
+
249
+ tracker_class_mask = np.atleast_1d(raw_data['tracker_classes'][t] == cls_id)
250
+ tracker_class_mask = tracker_class_mask.astype(bool)
251
+ tracker_ids = raw_data['tracker_ids'][t][tracker_class_mask]
252
+ tracker_dets = [raw_data['tracker_dets'][t][ind] for ind in range(len(tracker_class_mask)) if
253
+ tracker_class_mask[ind]]
254
+ similarity_scores = raw_data['similarity_scores'][t][gt_class_mask, :][:, tracker_class_mask]
255
+
256
+ data['tracker_ids'][t] = tracker_ids
257
+ data['tracker_dets'][t] = tracker_dets
258
+ data['gt_ids'][t] = gt_ids
259
+ data['gt_dets'][t] = gt_dets
260
+ data['similarity_scores'][t] = similarity_scores
261
+
262
+ unique_gt_ids += list(np.unique(data['gt_ids'][t]))
263
+ unique_tracker_ids += list(np.unique(data['tracker_ids'][t]))
264
+ num_tracker_dets += len(data['tracker_ids'][t])
265
+ num_gt_dets += len(data['gt_ids'][t])
266
+
267
+ # Re-label IDs such that there are no empty IDs
268
+ if len(unique_gt_ids) > 0:
269
+ unique_gt_ids = np.unique(unique_gt_ids)
270
+ gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
271
+ gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
272
+ for t in range(raw_data['num_timesteps']):
273
+ if len(data['gt_ids'][t]) > 0:
274
+ data['gt_ids'][t] = gt_id_map[data['gt_ids'][t]].astype(int)
275
+ if len(unique_tracker_ids) > 0:
276
+ unique_tracker_ids = np.unique(unique_tracker_ids)
277
+ tracker_id_map = np.nan * np.ones((np.max(unique_tracker_ids) + 1))
278
+ tracker_id_map[unique_tracker_ids] = np.arange(len(unique_tracker_ids))
279
+ for t in range(raw_data['num_timesteps']):
280
+ if len(data['tracker_ids'][t]) > 0:
281
+ data['tracker_ids'][t] = tracker_id_map[data['tracker_ids'][t]].astype(int)
282
+
283
+ # Ensure that ids are unique per timestep.
284
+ self._check_unique_ids(data)
285
+
286
+ # Record overview statistics.
287
+ data['num_tracker_dets'] = num_tracker_dets
288
+ data['num_gt_dets'] = num_gt_dets
289
+ data['num_tracker_ids'] = len(unique_tracker_ids)
290
+ data['num_gt_ids'] = len(unique_gt_ids)
291
+ data['num_timesteps'] = raw_data['num_timesteps']
292
+ data['seq'] = raw_data['seq']
293
+
294
+ # get track representations
295
+ data['gt_tracks'] = raw_data['classes_to_gt_tracks'][cls_id]
296
+ data['gt_track_ids'] = raw_data['classes_to_gt_track_ids'][cls_id]
297
+ data['gt_track_areas'] = raw_data['classes_to_gt_track_areas'][cls_id]
298
+ data['gt_track_iscrowd'] = raw_data['classes_to_gt_track_iscrowd'][cls_id]
299
+ data['dt_tracks'] = raw_data['classes_to_dt_tracks'][cls_id]
300
+ data['dt_track_ids'] = raw_data['classes_to_dt_track_ids'][cls_id]
301
+ data['dt_track_areas'] = raw_data['classes_to_dt_track_areas'][cls_id]
302
+ data['dt_track_scores'] = raw_data['classes_to_dt_track_scores'][cls_id]
303
+ data['iou_type'] = 'mask'
304
+
305
+ # sort tracker data tracks by tracker confidence scores
306
+ if data['dt_tracks']:
307
+ idx = np.argsort([-score for score in data['dt_track_scores']], kind="mergesort")
308
+ data['dt_track_scores'] = [data['dt_track_scores'][i] for i in idx]
309
+ data['dt_tracks'] = [data['dt_tracks'][i] for i in idx]
310
+ data['dt_track_ids'] = [data['dt_track_ids'][i] for i in idx]
311
+ data['dt_track_areas'] = [data['dt_track_areas'][i] for i in idx]
312
+
313
+ return data
314
+
315
+ def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
316
+ similarity_scores = self._calculate_mask_ious(gt_dets_t, tracker_dets_t, is_encoded=True, do_ioa=False)
317
+ return similarity_scores
318
+
319
+ def _prepare_gt_annotations(self):
320
+ """
321
+ Prepares GT data by rle encoding segmentations and computing the average track area.
322
+ :return: None
323
+ """
324
+ # only loaded when needed to reduce minimum requirements
325
+ from pycocotools import mask as mask_utils
326
+
327
+ for track in self.gt_data['annotations']:
328
+ h = track['height']
329
+ w = track['width']
330
+ for i, seg in enumerate(track['segmentations']):
331
+ if seg:
332
+ masks = mask_utils.frPyObjects(seg, h, w)
333
+ track['segmentations'][i] = mask_utils.merge(masks)
334
+ # track['segmentations'][i] = mask_utils.frPyObjects(seg, h, w)
335
+ areas = [a for a in track['areas'] if a]
336
+ if len(areas) == 0:
337
+ track['area'] = 0
338
+ else:
339
+ track['area'] = np.array(areas).mean()
340
+
341
+ def _get_tracker_seq_tracks(self, tracker, seq_id):
342
+ """
343
+ Prepares tracker data for a given sequence. Extracts all annotations for given sequence ID, computes
344
+ average track area and assigns a track ID.
345
+ :param tracker: the given tracker
346
+ :param seq_id: the sequence ID
347
+ :return: the extracted tracks
348
+ """
349
+ # only loaded when needed to reduce minimum requirements
350
+ from pycocotools import mask as mask_utils
351
+
352
+ tracks = [ann for ann in self.tracker_data[tracker] if ann['video_id'] == seq_id]
353
+ for track in tracks:
354
+ track['areas'] = []
355
+ for seg in track['segmentations']:
356
+ if seg:
357
+ track['areas'].append(mask_utils.area(seg))
358
+ else:
359
+ track['areas'].append(None)
360
+ areas = [a for a in track['areas'] if a]
361
+ if len(areas) == 0:
362
+ track['area'] = 0
363
+ else:
364
+ track['area'] = np.array(areas).mean()
365
+ track['id'] = self.global_tid_counter
366
+ self.global_tid_counter += 1
367
+ return tracks
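Before moving on to the evaluator, it helps to see the shape of the data this dataset class hands to the metrics. Below is a hedged sketch (not part of the upload) of how the per-class dict returned by `get_preprocessed_seq_data()` can be consumed; `dataset`, `tracker`, `seq` and the class name are placeholders for values configured elsewhere in `aviseval`.

```python
import numpy as np

def recall_at_alpha(data, alpha=0.5):
    """Fraction of GT detections whose best mask IoU with any predicted det
    reaches `alpha` -- a rough recall proxy, not HOTA's Hungarian-matched TP."""
    hit, total = 0, 0
    for t in range(data['num_timesteps']):
        sim = data['similarity_scores'][t]   # shape (num_gt_t, num_pred_t), mask IoUs
        total += sim.shape[0]
        if sim.size:
            hit += int(np.sum(sim.max(axis=1) >= alpha))
    return hit / max(1, total)

# Placeholders: a configured dataset/tracker pair, one sequence name and one class name.
# raw_data = dataset.get_raw_seq_data(tracker, seq)
# data = dataset.get_preprocessed_seq_data(raw_data, cls)
# print(recall_at_alpha(data, alpha=0.5))
```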
avism/data/aviseval/eval.py ADDED
@@ -0,0 +1,209 @@
1
+ import os
2
+ import tqdm
3
+ import traceback
4
+ import numpy as np
5
+
6
+ from . import utils
7
+ from . import _timing
8
+ from .metrics import Count
9
+ from .utils import TrackEvalException
10
+ from .metrics import compute_av_loc, combine_av_loc_sequences
11
+
12
+
13
+ class Evaluator:
14
+ """Evaluator class for evaluating different metrics for different datasets"""
15
+
16
+ @staticmethod
17
+ def get_default_eval_config():
18
+ """Returns the default config values for evaluation"""
19
+ code_path = utils.get_code_path()
20
+ default_config = {
21
+ 'USE_PARALLEL': False,
22
+ 'NUM_PARALLEL_CORES': 8,
23
+ 'BREAK_ON_ERROR': True, # Raises exception and exits with error
24
+ 'RETURN_ON_ERROR': False, # if not BREAK_ON_ERROR, then returns from function on error
25
+ 'LOG_ON_ERROR': os.path.join(code_path, 'error_log.txt'), # if not None, save any errors into a log file.
26
+
27
+ 'PRINT_RESULTS': False,
28
+ 'PRINT_ONLY_COMBINED': False,
29
+ 'PRINT_CONFIG': False,
30
+ 'TIME_PROGRESS': False,
31
+ 'DISPLAY_LESS_PROGRESS': True,
32
+
33
+ 'OUTPUT_SUMMARY': False,
34
+ 'OUTPUT_EMPTY_CLASSES': False,
35
+ 'OUTPUT_DETAILED': False,
36
+ 'PLOT_CURVES': False,
37
+ }
38
+ return default_config
39
+
40
+ def __init__(self, config=None):
41
+ """Initialise the evaluator with a config file"""
42
+ self.config = utils.init_config(config, self.get_default_eval_config(), 'Eval')
43
+ # Only run timing analysis if not run in parallel.
44
+ if self.config['TIME_PROGRESS'] and not self.config['USE_PARALLEL']:
45
+ _timing.DO_TIMING = True
46
+ if self.config['DISPLAY_LESS_PROGRESS']:
47
+ _timing.DISPLAY_LESS_PROGRESS = True
48
+
49
+ @_timing.time
50
+ def evaluate(self, dataset_list, metrics_list):
51
+ """Evaluate a set of metrics on a set of datasets"""
52
+ config = self.config
53
+ metrics_list = metrics_list + [Count()] # Count metrics are always run
54
+ metric_names = utils.validate_metrics_list(metrics_list)
55
+ dataset_names = [dataset.get_name() for dataset in dataset_list]
56
+ output_res = {}
57
+ output_msg = {}
58
+
59
+ for dataset, dataset_name in zip(dataset_list, dataset_names):
60
+ # Get dataset info about what to evaluate
61
+ output_res[dataset_name] = {}
62
+ output_msg[dataset_name] = {}
63
+ tracker_list, seq_list, class_list = dataset.get_eval_info()
64
+
65
+ # Evaluate each tracker
66
+ for tracker in tracker_list:
67
+ # if not config['BREAK_ON_ERROR'] then go to next tracker without breaking
68
+ try:
69
+ print('\nEvaluating model ...... \n')
70
+ res = {}
71
+ res_av_loc = {}
72
+
73
+ seq_list_sorted = sorted(seq_list)
74
+ for curr_seq in tqdm.tqdm(seq_list_sorted):
75
+ res[curr_seq] = eval_sequence(curr_seq, dataset, tracker, class_list, metrics_list, metric_names)
76
+ res_av_loc[curr_seq] = eval_av_loc_sequence(curr_seq, dataset, tracker)
77
+
78
+ # Combine results over all sequences and then over all classes
79
+ res_av_loc_all = combine_av_loc_sequences(res_av_loc)
80
+
81
+ # collecting combined cls keys (cls averaged, det averaged, super classes)
82
+ combined_cls_keys = []
83
+ res['COMBINED_SEQ'] = {}
84
+ # combine sequences for each class
85
+ for c_cls in class_list:
86
+ res['COMBINED_SEQ'][c_cls] = {}
87
+ for metric, metric_name in zip(metrics_list, metric_names):
88
+ curr_res = {seq_key: seq_value[c_cls][metric_name] for seq_key, seq_value in res.items() if
89
+ seq_key != 'COMBINED_SEQ'}
90
+ res['COMBINED_SEQ'][c_cls][metric_name] = metric.combine_sequences(curr_res)
91
+ # combine classes
92
+ if dataset.should_classes_combine:
93
+ combined_cls_keys += ['cls_comb_cls_av', 'cls_comb_det_av', 'all']
94
+ res['COMBINED_SEQ']['cls_comb_cls_av'] = {}
95
+ res['COMBINED_SEQ']['cls_comb_det_av'] = {}
96
+ for metric, metric_name in zip(metrics_list, metric_names):
97
+ cls_res = {cls_key: cls_value[metric_name] for cls_key, cls_value in
98
+ res['COMBINED_SEQ'].items() if cls_key not in combined_cls_keys}
99
+ res['COMBINED_SEQ']['cls_comb_cls_av'][metric_name] = \
100
+ metric.combine_classes_class_averaged(cls_res)
101
+ res['COMBINED_SEQ']['cls_comb_det_av'][metric_name] = \
102
+ metric.combine_classes_det_averaged(cls_res)
103
+ # combine classes to super classes
104
+ if dataset.use_super_categories:
105
+ for cat, sub_cats in dataset.super_categories.items():
106
+ combined_cls_keys.append(cat)
107
+ res['COMBINED_SEQ'][cat] = {}
108
+ for metric, metric_name in zip(metrics_list, metric_names):
109
+ cat_res = {cls_key: cls_value[metric_name] for cls_key, cls_value in
110
+ res['COMBINED_SEQ'].items() if cls_key in sub_cats}
111
+ res['COMBINED_SEQ'][cat][metric_name] = metric.combine_classes_det_averaged(cat_res)
112
+
113
+ # Print and output results in various formats
114
+ output_fol = dataset.get_output_fol(tracker)
115
+ tracker_display_name = dataset.get_display_name(tracker)
116
+ for c_cls in res['COMBINED_SEQ'].keys(): # class_list + combined classes if calculated
117
+ summaries = []
118
+ details = []
119
+ num_dets = res['COMBINED_SEQ'][c_cls]['Count']['Dets']
120
+ if config['OUTPUT_EMPTY_CLASSES'] or num_dets > 0:
121
+ for metric, metric_name in zip(metrics_list, metric_names):
122
+ # for combined classes there is no per sequence evaluation
123
+ if c_cls in combined_cls_keys:
124
+ table_res = {'COMBINED_SEQ': res['COMBINED_SEQ'][c_cls][metric_name]}
125
+ else:
126
+ table_res = {seq_key: seq_value[c_cls][metric_name] for seq_key, seq_value in res.items()}
127
+ if config['PLOT_CURVES']:
128
+ metric.plot_single_tracker_results(table_res, tracker_display_name, c_cls, output_fol)
129
+ if config['OUTPUT_SUMMARY']:
130
+ utils.write_summary_results(summaries, c_cls, output_fol)
131
+ if config['OUTPUT_DETAILED']:
132
+ utils.write_detailed_results(details, c_cls, output_fol)
133
+
134
+ # Output for returning from function
135
+ res_output = {}
136
+
137
+ res_output["AP_all"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['TrackMAP']['AP_all']), 2)
138
+ res_output["AP_s"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['TrackMAP']['AP_area_s']), 2)
139
+ res_output["AP_m"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['TrackMAP']['AP_area_m']), 2)
140
+ res_output["AP_l"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['TrackMAP']['AP_area_l']), 2)
141
+ res_output["AR_all"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['TrackMAP']['AR_all']), 2)
142
+
143
+ res_output["HOTA"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['HOTA']), 2)
144
+ res_output["DetA"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['DetA']), 2)
145
+ res_output["DetRe"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['DetRe']), 2)
146
+ res_output["DetPr"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['DetPr']), 2)
147
+ res_output["AssA"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['AssA']), 2)
148
+ res_output["AssRe"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['AssRe']), 2)
149
+ res_output["AssPr"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['AssPr']), 2)
150
+ res_output["LocA"] = round(100 * np.mean(res['COMBINED_SEQ']['cls_comb_cls_av']['HOTA']['LocA']), 2)
151
+
152
+ res_output["FA"] = round(100 * np.mean(res_av_loc_all['FA']), 2)
153
+ res_output["FAn"] = round(100 * np.mean(res_av_loc_all['FAn']), 2)
154
+ res_output['FAn_count'] = int(np.mean(res_av_loc_all['FAn_count']))
155
+ res_output['FAn_all'] = int(np.mean(res_av_loc_all['FAn_all']))
156
+ res_output["FAs"] = round(100 * np.mean(res_av_loc_all['FAs']), 2)
157
+ res_output['FAs_count'] = int(np.mean(res_av_loc_all['FAs_count']))
158
+ res_output['FAs_all'] = int(np.mean(res_av_loc_all['FAs_all']))
159
+ res_output["FAm"] = round(100 * np.mean(res_av_loc_all['FAm']), 2)
160
+ res_output['FAm_count'] = int(np.mean(res_av_loc_all['FAm_count']))
161
+ res_output['FAm_all'] = int(np.mean(res_av_loc_all['FAm_all']))
162
+
163
+ output_res[dataset_name][tracker] = res_output
164
+ output_msg[dataset_name][tracker] = 'Success'
165
+
166
+ except Exception as err:
167
+ output_res[dataset_name][tracker] = None
168
+ if type(err) == TrackEvalException:
169
+ output_msg[dataset_name][tracker] = str(err)
170
+ else:
171
+ output_msg[dataset_name][tracker] = 'Unknown error occurred.'
172
+ print('Tracker %s was unable to be evaluated.' % tracker)
173
+ print(err)
174
+ traceback.print_exc()
175
+ if config['LOG_ON_ERROR'] is not None:
176
+ with open(config['LOG_ON_ERROR'], 'a') as f:
177
+ print(dataset_name, file=f)
178
+ print(tracker, file=f)
179
+ print(traceback.format_exc(), file=f)
180
+ print('\n\n\n', file=f)
181
+ if config['BREAK_ON_ERROR']:
182
+ raise err
183
+ elif config['RETURN_ON_ERROR']:
184
+ return output_res, output_msg
185
+
186
+ return output_res, output_msg
187
+
188
+
189
+ @_timing.time
190
+ def eval_sequence(seq, dataset, tracker, class_list, metrics_list, metric_names):
191
+ """Function for evaluating a single sequence"""
192
+
193
+ raw_data = dataset.get_raw_seq_data(tracker, seq)
194
+ seq_res = {}
195
+ for cls in class_list:
196
+ seq_res[cls] = {}
197
+ data = dataset.get_preprocessed_seq_data(raw_data, cls)
198
+ for metric, met_name in zip(metrics_list, metric_names):
199
+ seq_res[cls][met_name] = metric.eval_sequence(data)
200
+ return seq_res
201
+
202
+
203
+ def eval_av_loc_sequence(seq, dataset, tracker):
204
+ """Function for evaluating a single sequence"""
205
+
206
+ raw_data = dataset.get_raw_seq_data(tracker, seq)
207
+ av_loc_res = compute_av_loc(raw_data)
208
+
209
+ return av_loc_res
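A hedged sketch of driving this evaluator end to end. The metric classes are the ones exported by `aviseval.metrics`; the dataset construction is left as a placeholder because its class name and config keys live in `datasets/avis.py`, which is not shown in this excerpt.

```python
from avism.data.aviseval.eval import Evaluator
from avism.data.aviseval.metrics import HOTA, TrackMAP

eval_config = Evaluator.get_default_eval_config()
eval_config['PRINT_RESULTS'] = True          # the defaults keep all printing/output switched off

evaluator = Evaluator(eval_config)
# dataset = ...  # an aviseval dataset instance (see datasets/avis.py); placeholder here
# output_res, output_msg = evaluator.evaluate([dataset], [HOTA(), TrackMAP()])
# output_res[dataset_name][tracker] is the flat dict built above
# (AP_all, HOTA, DetA, AssA, FA, FAn, FAs, FAm, ...), with scores already in percent.
```

Note that a `Count()` metric is always appended internally, so it never needs to be passed explicitly.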
avism/data/aviseval/metrics/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from .hota import HOTA
2
+ from .clear import CLEAR
3
+ from .identity import Identity
4
+ from .count import Count
5
+ from .j_and_f import JAndF
6
+ from .track_map import TrackMAP
7
+ from .vace import VACE
8
+ from .ideucl import IDEucl
9
+
10
+ from .avisa import avisA
11
+ from .av_loc import compute_av_loc, combine_av_loc_sequences
avism/data/aviseval/metrics/_base_metric.py ADDED
@@ -0,0 +1,132 @@
1
+ import numpy as np
2
+ from abc import ABC, abstractmethod
3
+ from .. import _timing
4
+ from ..utils import TrackEvalException
5
+
6
+
7
+ class _BaseMetric(ABC):
8
+ @abstractmethod
9
+ def __init__(self):
10
+ self.plottable = False
11
+ self.integer_fields = []
12
+ self.float_fields = []
13
+ self.array_labels = []
14
+ self.integer_array_fields = []
15
+ self.float_array_fields = []
16
+ self.fields = []
17
+ self.summary_fields = []
18
+ self.registered = False
19
+
20
+ #####################################################################
21
+ # Abstract functions for subclasses to implement
22
+
23
+ @_timing.time
24
+ @abstractmethod
25
+ def eval_sequence(self, data):
26
+ ...
27
+
28
+ @abstractmethod
29
+ def combine_sequences(self, all_res):
30
+ ...
31
+
32
+ @abstractmethod
33
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
34
+ ...
35
+
36
+ @abstractmethod
37
+ def combine_classes_det_averaged(self, all_res):
38
+ ...
39
+
40
+ def plot_single_tracker_results(self, all_res, tracker, output_folder, cls):
41
+ """Plot results of metrics, only valid for metrics with self.plottable"""
42
+ if self.plottable:
43
+ raise NotImplementedError('plot_results is not implemented for metric %s' % self.get_name())
44
+ else:
45
+ pass
46
+
47
+ #####################################################################
48
+ # Helper functions which are useful for all metrics:
49
+
50
+ @classmethod
51
+ def get_name(cls):
52
+ return cls.__name__
53
+
54
+ @staticmethod
55
+ def _combine_sum(all_res, field):
56
+ """Combine sequence results via sum"""
57
+ return sum([all_res[k][field] for k in all_res.keys()])
58
+
59
+ @staticmethod
60
+ def _combine_weighted_av(all_res, field, comb_res, weight_field):
61
+ """Combine sequence results via weighted average"""
62
+ return sum([all_res[k][field] * all_res[k][weight_field] for k in all_res.keys()]) / np.maximum(1.0, comb_res[
63
+ weight_field])
64
+
65
+ def print_table(self, table_res, tracker, cls):
66
+ """Prints table of results for all sequences"""
67
+ print('')
68
+ metric_name = self.get_name()
69
+ self._row_print([metric_name + ': ' + tracker + '-' + cls] + self.summary_fields)
70
+ for seq, results in sorted(table_res.items()):
71
+ if seq == 'COMBINED_SEQ':
72
+ continue
73
+ summary_res = self._summary_row(results)
74
+ self._row_print([seq] + summary_res)
75
+ summary_res = self._summary_row(table_res['COMBINED_SEQ'])
76
+ self._row_print(['COMBINED'] + summary_res)
77
+
78
+ def _summary_row(self, results_):
79
+ vals = []
80
+ for h in self.summary_fields:
81
+ if h in self.float_array_fields:
82
+ vals.append("{0:1.5g}".format(100 * np.mean(results_[h])))
83
+ elif h in self.float_fields:
84
+ vals.append("{0:1.5g}".format(100 * float(results_[h])))
85
+ elif h in self.integer_fields:
86
+ vals.append("{0:d}".format(int(results_[h])))
87
+ else:
88
+ raise NotImplementedError("Summary function not implemented for this field type.")
89
+ return vals
90
+
91
+ @staticmethod
92
+ def _row_print(*argv):
93
+ """Prints results in an evenly spaced rows, with more space in first row"""
94
+ if len(argv) == 1:
95
+ argv = argv[0]
96
+ to_print = '%-35s' % argv[0]
97
+ for v in argv[1:]:
98
+ to_print += '%-10s' % str(v)
99
+ print(to_print)
100
+
101
+ def summary_results(self, table_res):
102
+ """Returns a simple summary of final results for a tracker"""
103
+ return dict(zip(self.summary_fields, self._summary_row(table_res['COMBINED_SEQ'])))
104
+
105
+ def detailed_results(self, table_res):
106
+ """Returns detailed final results for a tracker"""
107
+ # Get detailed field information
108
+ detailed_fields = self.float_fields + self.integer_fields
109
+ for h in self.float_array_fields + self.integer_array_fields:
110
+ for alpha in [int(100*x) for x in self.array_labels]:
111
+ detailed_fields.append(h + '___' + str(alpha))
112
+ detailed_fields.append(h + '___AUC')
113
+
114
+ # Get detailed results
115
+ detailed_results = {}
116
+ for seq, res in table_res.items():
117
+ detailed_row = self._detailed_row(res)
118
+ if len(detailed_row) != len(detailed_fields):
119
+ raise TrackEvalException(
120
+ 'Field names and data have different sizes (%i and %i)' % (len(detailed_row), len(detailed_fields)))
121
+ detailed_results[seq] = dict(zip(detailed_fields, detailed_row))
122
+ return detailed_results
123
+
124
+ def _detailed_row(self, res):
125
+ detailed_row = []
126
+ for h in self.float_fields + self.integer_fields:
127
+ detailed_row.append(res[h])
128
+ for h in self.float_array_fields + self.integer_array_fields:
129
+ for i, alpha in enumerate([int(100 * x) for x in self.array_labels]):
130
+ detailed_row.append(res[h][i])
131
+ detailed_row.append(np.mean(res[h]))
132
+ return detailed_row
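A hedged sketch of the minimal surface a new metric needs in order to plug into this base class. The metric below is illustrative only (it is not part of the toolkit) and mirrors the simple Count-style pattern; only the abstract methods plus `__init__` are overridden, while `print_table`, `summary_results` and `detailed_results` come from the base class.

```python
import numpy as np
from avism.data.aviseval.metrics._base_metric import _BaseMetric

class GTFrameCoverage(_BaseMetric):
    """Illustrative metric: fraction of timesteps containing at least one GT det."""

    def __init__(self, config=None):
        super().__init__()
        self.float_fields = ['GT_Coverage']
        self.fields = self.float_fields
        self.summary_fields = self.fields

    def eval_sequence(self, data):
        covered = sum(1 for t in range(data['num_timesteps']) if len(data['gt_ids'][t]) > 0)
        return {'GT_Coverage': covered / max(1, data['num_timesteps'])}

    def combine_sequences(self, all_res):
        return {'GT_Coverage': float(np.mean([v['GT_Coverage'] for v in all_res.values()]))}

    def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
        return self.combine_sequences(all_res)

    def combine_classes_det_averaged(self, all_res):
        return self.combine_sequences(all_res)
```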
avism/data/aviseval/metrics/av_loc.py ADDED
@@ -0,0 +1,191 @@
1
+ import numpy as np
2
+ from pycocotools import mask as mask_utils
3
+ from scipy.optimize import linear_sum_assignment
4
+
5
+
6
+ def compute_av_loc(data):
7
+ alphas = np.arange(0.05, 0.99, 0.05)
8
+ res = {}
9
+ res['FA'] = np.zeros((len(alphas)), dtype=float)
10
+ res['FAn'] = np.array([None] * len(alphas)) # frame accuracy on frames with no sound source
11
+ res['FAs'] = np.array([None] * len(alphas)) # frame accuracy on frames with a single sound source
12
+ res['FAm'] = np.array([None] * len(alphas)) # frame accuracy on frames with multiple sound sources
13
+
14
+ res['frame_num_n_all'] = np.zeros((len(alphas)), dtype=int)
15
+ res['frame_num_n_tp'] = np.zeros((len(alphas)), dtype=int)
16
+ res['frame_num_s_all'] = np.zeros((len(alphas)), dtype=int)
17
+ res['frame_num_s_tp'] = np.zeros((len(alphas)), dtype=int)
18
+ res['frame_num_m_all'] = np.zeros((len(alphas)), dtype=int)
19
+ res['frame_num_m_tp'] = np.zeros((len(alphas)), dtype=int)
20
+
21
+ frame_num_all = data['num_timesteps']
22
+ gt_classes = data['gt_classes']
23
+ gt_dets = data['gt_dets']
24
+ raw_classes = data['raw_classes']
25
+ raw_dets = data['raw_dets']
26
+ pred_classes = data['tracker_classes']
27
+ pred_dets = data['tracker_dets']
28
+
29
+ # 1. Find the best trajectory between gt and pred
30
+ unique_gt_ids = []
31
+ unique_tracker_ids = []
32
+ for t in range(data['num_timesteps']):
33
+ unique_gt_ids += list(np.unique(data['gt_ids'][t]))
34
+ unique_tracker_ids += list(np.unique(data['tracker_ids'][t]))
35
+ # Re-label IDs such that there are no empty IDs
36
+ if len(unique_gt_ids) > 0:
37
+ unique_gt_ids = np.unique(unique_gt_ids)
38
+ gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
39
+ gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
40
+ for t in range(data['num_timesteps']):
41
+ if len(data['gt_ids'][t]) > 0:
42
+ data['gt_ids'][t] = gt_id_map[data['gt_ids'][t]].astype(int)
43
+ if len(unique_tracker_ids) > 0:
44
+ unique_tracker_ids = np.unique(unique_tracker_ids)
45
+ tracker_id_map = np.nan * np.ones((np.max(unique_tracker_ids) + 1))
46
+ tracker_id_map[unique_tracker_ids] = np.arange(len(unique_tracker_ids))
47
+ for t in range(data['num_timesteps']):
48
+ if len(data['tracker_ids'][t]) > 0:
49
+ data['tracker_ids'][t] = tracker_id_map[data['tracker_ids'][t]].astype(int)
50
+ data['num_tracker_ids'] = len(unique_tracker_ids)
51
+ data['num_gt_ids'] = len(unique_gt_ids)
52
+ # Variables counting global association
53
+ potential_matches_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
54
+ gt_id_count = np.zeros((data['num_gt_ids'], 1))
55
+ tracker_id_count = np.zeros((1, data['num_tracker_ids']))
56
+ # First loop through each timestep and accumulate global track information.
57
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
58
+ # Count the potential matches between ids in each timestep
59
+ # These are normalised, weighted by the match similarity.
60
+ similarity = data['similarity_scores'][t]
61
+ sim_iou_denom = similarity.sum(0)[np.newaxis, :] + similarity.sum(1)[:, np.newaxis] - similarity
62
+ sim_iou = np.zeros_like(similarity)
63
+ sim_iou_mask = sim_iou_denom > 0 + np.finfo('float').eps
64
+ sim_iou[sim_iou_mask] = similarity[sim_iou_mask] / sim_iou_denom[sim_iou_mask]
65
+ potential_matches_count[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] += sim_iou
66
+ # Calculate the total number of dets for each gt_id and tracker_id.
67
+ gt_id_count[gt_ids_t] += 1
68
+ tracker_id_count[0, tracker_ids_t] += 1
69
+ # Calculate overall jaccard alignment score (before unique matching) between IDs
70
+ global_alignment_score = potential_matches_count / (gt_id_count + tracker_id_count - potential_matches_count)
71
+ # Hungarian algorithm to find best matches
72
+ match_rows, match_cols = linear_sum_assignment(-global_alignment_score)
73
+
74
+ # 2. Compute FSLA
75
+ for a, alpha in enumerate(alphas):
76
+ frame_num_n_all = 0 # total frames with no sound source
77
+ frame_num_s_all = 0 # total frames with a single sound source
78
+ frame_num_m_all = 0 # total frames with multiple sound sources
79
+ frame_num_n_tp = 0 # true positive frames with no sound source
80
+ frame_num_s_tp = 0 # true positive frames with a single sound source
81
+ frame_num_m_tp = 0 # true positive frames with multiple sound sources
82
+
83
+ for frame_id in range(frame_num_all):
84
+ # classes
85
+ gt_classes_per_frame = gt_classes[frame_id]
86
+ raw_classes_per_frame = raw_classes[frame_id]
87
+ pred_classes_per_frame = pred_classes[frame_id]
88
+ # masks
89
+ gt_dets_per_frame = gt_dets[frame_id]
90
+ raw_dets_per_frame = raw_dets[frame_id]
91
+ pred_dets_per_frame = pred_dets[frame_id]
92
+
93
+ if len(pred_dets_per_frame) > 0:
94
+ pred_dets_per_frame_f = [di for di in pred_dets_per_frame if di['counts'] != 'PPTl0']
95
+ else:
96
+ pred_dets_per_frame_f = pred_dets_per_frame
97
+
98
+ # Masks must have the same class and number
99
+ if (set(gt_classes_per_frame) == set(pred_classes_per_frame)) and (len(gt_dets_per_frame) == len(pred_dets_per_frame_f)):
100
+ # 1) no sound source
101
+ if len(gt_dets_per_frame) == 0:
102
+ frame_num_n_all += 1
103
+ frame_num_n_tp += 1
104
+ # 2) single sound source
105
+ elif len(gt_dets_per_frame) == 1:
106
+ frame_num_s_all += 1
107
+ index_gt = [index for index, value in enumerate(raw_dets_per_frame) if value is not None][0]
108
+ index_pred = [index for index, element in enumerate(match_cols) if element == index_gt]
109
+ if index_pred != []:
110
+ ious = mask_utils.iou(gt_dets_per_frame, [pred_dets_per_frame[index_pred[0]]], [False])
111
+ if np.all(ious > alpha):
112
+ frame_num_s_tp += 1
113
+ # 3) multi sound source
114
+ else:
115
+ frame_num_m_all += 1
116
+ flags = [0] * len(match_rows)
117
+ for tr in range(len(match_rows)):
118
+ if (raw_classes_per_frame[match_rows[tr]] == pred_classes_per_frame[match_cols[tr]]):
119
+ if raw_dets_per_frame[match_rows[tr]] is None:
120
+ if pred_dets_per_frame[match_cols[tr]]['counts'] == 'PPTl0':
121
+ flags[tr] = 1
122
+ else:
123
+ iou = mask_utils.iou([raw_dets_per_frame[match_rows[tr]]],
124
+ [pred_dets_per_frame[match_cols[tr]]], [False])
125
+ if np.all(iou > alpha):
126
+ flags[tr] = 1
127
+ if all(ff == 1 for ff in flags):
128
+ frame_num_m_tp += 1
129
+ else:
130
+ if len(gt_dets_per_frame) == 0:
131
+ frame_num_n_all += 1
132
+ elif len(gt_dets_per_frame) == 1:
133
+ frame_num_s_all += 1
134
+ else:
135
+ frame_num_m_all += 1
136
+
137
+ assert frame_num_all == (frame_num_n_all + frame_num_s_all + frame_num_m_all)
138
+
139
+ if frame_num_n_all > 0:
140
+ res['FAn'][a] = frame_num_n_tp / frame_num_n_all
141
+ res['frame_num_n_all'][a] = frame_num_n_all
142
+ res['frame_num_n_tp'][a] = frame_num_n_tp
143
+ else:
144
+ res['FAn'][a] = None
145
+ res['frame_num_n_all'][a] = 0
146
+ res['frame_num_n_tp'][a] = 0
147
+
148
+ if frame_num_s_all > 0:
149
+ res['FAs'][a] = frame_num_s_tp / frame_num_s_all
150
+ res['frame_num_s_all'][a] = frame_num_s_all
151
+ res['frame_num_s_tp'][a] = frame_num_s_tp
152
+ else:
153
+ res['FAs'][a] = None
154
+ res['frame_num_s_all'][a] = 0
155
+ res['frame_num_s_tp'][a] = 0
156
+
157
+ if frame_num_m_all > 0:
158
+ res['FAm'][a] = frame_num_m_tp / frame_num_m_all
159
+ res['frame_num_m_all'][a] = frame_num_m_all
160
+ res['frame_num_m_tp'][a] = frame_num_m_tp
161
+ else:
162
+ res['FAm'][a] = None
163
+ res['frame_num_m_all'][a] = 0
164
+ res['frame_num_m_tp'][a] = 0
165
+
166
+ res['FA'][a] = (frame_num_n_tp + frame_num_s_tp + frame_num_m_tp) / frame_num_all
167
+
168
+ return res
169
+
170
+
171
+ def combine_av_loc_sequences(all_res):
172
+ """Combines metrics across all sequences"""
173
+ res = {}
174
+ fields_num = ['frame_num_n_all', 'frame_num_s_all', 'frame_num_m_all', 'frame_num_n_tp', 'frame_num_s_tp', 'frame_num_m_tp']
175
+ for field in fields_num:
176
+ res[field] = sum([all_res[k][field] for k in all_res.keys()])
177
+
178
+ res_final = {}
179
+
180
+ res_final['FAn'] = res['frame_num_n_tp'] / res['frame_num_n_all']
181
+ res_final['FAn_count'] = res['frame_num_n_tp']
182
+ res_final['FAn_all'] = res['frame_num_n_all']
183
+ res_final['FAs'] = res['frame_num_s_tp'] / res['frame_num_s_all']
184
+ res_final['FAs_count'] = res['frame_num_s_tp']
185
+ res_final['FAs_all'] = res['frame_num_s_all']
186
+ res_final['FAm'] = res['frame_num_m_tp'] / res['frame_num_m_all']
187
+ res_final['FAm_count'] = res['frame_num_m_tp']
188
+ res_final['FAm_all'] = res['frame_num_m_all']
189
+ res_final['FA'] = (res['frame_num_n_tp'] + res['frame_num_s_tp'] + res['frame_num_m_tp']) / (res['frame_num_n_all'] + res['frame_num_s_all'] + res['frame_num_m_all'])
190
+
191
+ return res_final
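Note the aggregation order in `combine_av_loc_sequences`: the raw per-category frame counts are summed across sequences first and only then divided, so the combined frame accuracy is count-weighted rather than a mean of per-video ratios. A tiny numeric illustration with made-up counts:

```python
# Hypothetical counts at one alpha for two sequences (no-sound-source frames only):
# seq A: 10 frames, 8 correct;  seq B: 2 frames, 0 correct.
fa_n_pooled = (8 + 0) / (10 + 2)            # 0.667 -> what combine_av_loc_sequences reports
fa_n_mean_of_ratios = (8 / 10 + 0 / 2) / 2  # 0.400 -> what naive per-sequence averaging would give
```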
avism/data/aviseval/metrics/avisa.py ADDED
@@ -0,0 +1,190 @@
1
+ import os
2
+ import numpy as np
3
+ from scipy.optimize import linear_sum_assignment
4
+ from ._base_metric import _BaseMetric
5
+ from .. import _timing
6
+
7
+
8
+ class avisA(_BaseMetric):
9
+ def __init__(self, config=None):
10
+ super().__init__()
11
+ self.plottable = True
12
+ self.array_labels = np.arange(0.05, 0.99, 0.05)
13
+ self.integer_array_fields = ['HOTA_TP', 'HOTA_FN', 'HOTA_FP']
14
+ self.float_array_fields = ['AssA', 'AssRe', 'AssPr', 'SegA']
15
+ self.float_fields = ['SegA(0)']
16
+ self.fields = self.float_array_fields + self.integer_array_fields + self.float_fields
17
+ self.summary_fields = self.float_array_fields + self.float_fields
18
+
19
+ @_timing.time
20
+ def eval_sequence(self, data):
21
+ """Calculates the AssA and SegA metrics for one sequence"""
22
+
23
+ # Initialise results
24
+ res = {}
25
+ for field in self.float_array_fields + self.integer_array_fields:
26
+ res[field] = np.zeros((len(self.array_labels)), dtype=float)
27
+ for field in self.float_fields:
28
+ res[field] = 0
29
+
30
+ # Return result quickly if tracker or gt sequence is empty
31
+ if data['num_tracker_dets'] == 0:
32
+ res['HOTA_FN'] = data['num_gt_dets'] * np.ones((len(self.array_labels)), dtype=float)
33
+ res['SegA'] = np.ones((len(self.array_labels)), dtype=float)
34
+ res['SegA(0)'] = 1.0
35
+ return res
36
+ if data['num_gt_dets'] == 0:
37
+ res['HOTA_FP'] = data['num_tracker_dets'] * np.ones((len(self.array_labels)), dtype=float)
38
+ res['SegA'] = np.ones((len(self.array_labels)), dtype=float)
39
+ res['SegA(0)'] = 1.0
40
+ return res
41
+
42
+ # Variables counting global association
43
+ potential_matches_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
44
+ gt_id_count = np.zeros((data['num_gt_ids'], 1))
45
+ tracker_id_count = np.zeros((1, data['num_tracker_ids']))
46
+
47
+ # First loop through each timestep and accumulate global track information.
48
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
49
+ # Count the potential matches between ids in each timestep
50
+ # These are normalised, weighted by the match similarity.
51
+ similarity = data['similarity_scores'][t]
52
+ sim_iou_denom = similarity.sum(0)[np.newaxis, :] + similarity.sum(1)[:, np.newaxis] - similarity
53
+ sim_iou = np.zeros_like(similarity)
54
+ sim_iou_mask = sim_iou_denom > 0 + np.finfo('float').eps
55
+ sim_iou[sim_iou_mask] = similarity[sim_iou_mask] / sim_iou_denom[sim_iou_mask]
56
+ potential_matches_count[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] += sim_iou
57
+
58
+ # Calculate the total number of dets for each gt_id and tracker_id.
59
+ gt_id_count[gt_ids_t] += 1
60
+ tracker_id_count[0, tracker_ids_t] += 1
61
+
62
+ # Calculate overall jaccard alignment score (before unique matching) between IDs
63
+ global_alignment_score = potential_matches_count / (gt_id_count + tracker_id_count - potential_matches_count)
64
+ matches_counts = [np.zeros_like(potential_matches_count) for _ in self.array_labels]
65
+
66
+ # Calculate scores for each timestep
67
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
68
+ # Deal with the case that there are no gt_det/tracker_det in a timestep.
69
+ if len(gt_ids_t) == 0:
70
+ for a, alpha in enumerate(self.array_labels):
71
+ res['HOTA_FP'][a] += len(tracker_ids_t)
72
+ continue
73
+ if len(tracker_ids_t) == 0:
74
+ for a, alpha in enumerate(self.array_labels):
75
+ res['HOTA_FN'][a] += len(gt_ids_t)
76
+ continue
77
+
78
+ # Get matching scores between pairs of dets for optimizing HOTA
79
+ similarity = data['similarity_scores'][t]
80
+ score_mat = global_alignment_score[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] * similarity
81
+
82
+ # Hungarian algorithm to find the best matches (i.e. the optimal GT-to-prediction track assignment)
83
+ match_rows, match_cols = linear_sum_assignment(-score_mat)
84
+
85
+ # Calculate and accumulate basic statistics
86
+ for a, alpha in enumerate(self.array_labels):
87
+ actually_matched_mask = similarity[match_rows, match_cols] >= alpha - np.finfo('float').eps
88
+ alpha_match_rows = match_rows[actually_matched_mask]
89
+ alpha_match_cols = match_cols[actually_matched_mask]
90
+ num_matches = len(alpha_match_rows)
91
+ res['HOTA_TP'][a] += num_matches
92
+ res['HOTA_FN'][a] += len(gt_ids_t) - num_matches
93
+ res['HOTA_FP'][a] += len(tracker_ids_t) - num_matches
94
+ if num_matches > 0:
95
+ res['SegA'][a] += sum(similarity[alpha_match_rows, alpha_match_cols])
96
+ matches_counts[a][gt_ids_t[alpha_match_rows], tracker_ids_t[alpha_match_cols]] += 1
97
+
98
+ # Calculate association scores (AssA, AssRe, AssPr) for the alpha value.
99
+ # First calculate scores per gt_id/tracker_id combo and then average over the number of detections.
100
+ for a, alpha in enumerate(self.array_labels):
101
+ matches_count = matches_counts[a]
102
+ ass_a = matches_count / np.maximum(1, gt_id_count + tracker_id_count - matches_count)
103
+ res['AssA'][a] = np.sum(matches_count * ass_a) / np.maximum(1, res['HOTA_TP'][a])
104
+ ass_re = matches_count / np.maximum(1, gt_id_count)
105
+ res['AssRe'][a] = np.sum(matches_count * ass_re) / np.maximum(1, res['HOTA_TP'][a])
106
+ ass_pr = matches_count / np.maximum(1, tracker_id_count)
107
+ res['AssPr'][a] = np.sum(matches_count * ass_pr) / np.maximum(1, res['HOTA_TP'][a])
108
+
109
+ # Calculate final scores
110
+ res['SegA'] = np.maximum(1e-10, res['SegA']) / np.maximum(1e-10, res['HOTA_TP'])
111
+ res = self._compute_final_fields(res)
112
+ return res
113
+
114
+ def combine_sequences(self, all_res):
115
+ """Combines metrics across all sequences"""
116
+ res = {}
117
+ for field in self.integer_array_fields:
118
+ res[field] = self._combine_sum(all_res, field)
119
+ for field in ['AssRe', 'AssPr', 'AssA']:
120
+ res[field] = self._combine_weighted_av(all_res, field, res, weight_field='HOTA_TP')
121
+ sega_weighted_sum = sum([all_res[k]['SegA'] * all_res[k]['HOTA_TP'] for k in all_res.keys()])
122
+ res['SegA'] = np.maximum(1e-10, sega_weighted_sum) / np.maximum(1e-10, res['HOTA_TP'])
123
+ res = self._compute_final_fields(res)
124
+ return res
125
+
126
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
127
+ """Combines metrics across all classes by averaging over the class values.
128
+ If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
129
+ """
130
+ res = {}
131
+ for field in self.integer_array_fields:
132
+ if ignore_empty_classes:
133
+ res[field] = self._combine_sum(
134
+ {k: v for k, v in all_res.items()
135
+ if (v['HOTA_TP'] + v['HOTA_FN'] + v['HOTA_FP'] > 0 + np.finfo('float').eps).any()}, field)
136
+ else:
137
+ res[field] = self._combine_sum({k: v for k, v in all_res.items()}, field)
138
+
139
+ for field in self.float_fields + self.float_array_fields:
140
+ if ignore_empty_classes:
141
+ res[field] = np.mean([v[field] for v in all_res.values() if
142
+ (v['HOTA_TP'] + v['HOTA_FN'] + v['HOTA_FP'] > 0 + np.finfo('float').eps).any()],
143
+ axis=0)
144
+ else:
145
+ res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
146
+ return res
147
+
148
+ def combine_classes_det_averaged(self, all_res):
149
+ """Combines metrics across all classes by averaging over the detection values"""
150
+ res = {}
151
+ for field in self.integer_array_fields:
152
+ res[field] = self._combine_sum(all_res, field)
153
+ for field in ['AssRe', 'AssPr', 'AssA']:
154
+ res[field] = self._combine_weighted_av(all_res, field, res, weight_field='HOTA_TP')
155
+ sega_weighted_sum = sum([all_res[k]['SegA'] * all_res[k]['HOTA_TP'] for k in all_res.keys()])
156
+ res['SegA'] = np.maximum(1e-10, sega_weighted_sum) / np.maximum(1e-10, res['HOTA_TP'])
157
+ res = self._compute_final_fields(res)
158
+ return res
159
+
160
+ @staticmethod
161
+ def _compute_final_fields(res):
162
+ """Calculate sub-metric ('field') values which only depend on other sub-metric values.
163
+ This function is used both for per-sequence calculation and for combining values across sequences.
164
+ """
165
+ res['SegA(0)'] = res['SegA'][0]
166
+ return res
167
+
168
+ def plot_single_tracker_results(self, table_res, tracker, cls, output_folder):
169
+ """Create plot of results"""
170
+
171
+ # Only loaded when run to reduce minimum requirements
172
+ from matplotlib import pyplot as plt
173
+
174
+ res = table_res['COMBINED_SEQ']
175
+ styles_to_plot = ['r', 'b', 'g', 'b--', 'b:', 'g--', 'g:', 'm']
176
+ for name, style in zip(self.float_array_fields, styles_to_plot):
177
+ plt.plot(self.array_labels, res[name], style)
178
+ plt.xlabel('alpha')
179
+ plt.ylabel('score')
180
+ plt.title(tracker + ' - ' + cls)
181
+ plt.axis([0, 1, 0, 1])
182
+ legend = []
183
+ for name in self.float_array_fields:
184
+ legend += [name + ' (' + str(np.round(np.mean(res[name]), 2)) + ')']
185
+ plt.legend(legend, loc='lower left')
186
+ out_file = os.path.join(output_folder, cls + '_plot.pdf')
187
+ os.makedirs(os.path.dirname(out_file), exist_ok=True)
188
+ plt.savefig(out_file)
189
+ plt.savefig(out_file.replace('.pdf', '.png'))
190
+ plt.clf()
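As with HOTA, every array-valued field here is indexed by `self.array_labels = np.arange(0.05, 0.99, 0.05)`, i.e. 19 IoU thresholds from 0.05 to 0.95, so the scalar `SegA(0)` produced by `_compute_final_fields` is simply SegA at index 0, the loosest threshold. A quick check:

```python
import numpy as np

alphas = np.arange(0.05, 0.99, 0.05)
print(len(alphas), alphas[0])   # -> 19 0.05  (so 'SegA(0)' is SegA at alpha = 0.05)
```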
avism/data/aviseval/metrics/clear.py ADDED
@@ -0,0 +1,186 @@
1
+
2
+ import numpy as np
3
+ from scipy.optimize import linear_sum_assignment
4
+ from ._base_metric import _BaseMetric
5
+ from .. import _timing
6
+ from .. import utils
7
+
8
+ class CLEAR(_BaseMetric):
9
+ """Class which implements the CLEAR metrics"""
10
+
11
+ @staticmethod
12
+ def get_default_config():
13
+ """Default class config values"""
14
+ default_config = {
15
+ 'THRESHOLD': 0.5, # Similarity score threshold required for a TP match. Default 0.5.
16
+ 'PRINT_CONFIG': True, # Whether to print the config information on init. Default: False.
17
+ }
18
+ return default_config
19
+
20
+ def __init__(self, config=None):
21
+ super().__init__()
22
+ main_integer_fields = ['CLR_TP', 'CLR_FN', 'CLR_FP', 'IDSW', 'MT', 'PT', 'ML', 'Frag']
23
+ extra_integer_fields = ['CLR_Frames']
24
+ self.integer_fields = main_integer_fields + extra_integer_fields
25
+ main_float_fields = ['MOTA', 'MOTP', 'MODA', 'CLR_Re', 'CLR_Pr', 'MTR', 'PTR', 'MLR', 'sMOTA']
26
+ extra_float_fields = ['CLR_F1', 'FP_per_frame', 'MOTAL', 'MOTP_sum']
27
+ self.float_fields = main_float_fields + extra_float_fields
28
+ self.fields = self.float_fields + self.integer_fields
29
+ self.summed_fields = self.integer_fields + ['MOTP_sum']
30
+ self.summary_fields = main_float_fields + main_integer_fields
31
+
32
+ # Configuration options:
33
+ self.config = utils.init_config(config, self.get_default_config(), self.get_name())
34
+ self.threshold = float(self.config['THRESHOLD'])
35
+
36
+
37
+ @_timing.time
38
+ def eval_sequence(self, data):
39
+ """Calculates CLEAR metrics for one sequence"""
40
+ # Initialise results
41
+ res = {}
42
+ for field in self.fields:
43
+ res[field] = 0
44
+
45
+ # Return result quickly if tracker or gt sequence is empty
46
+ if data['num_tracker_dets'] == 0:
47
+ res['CLR_FN'] = data['num_gt_dets']
48
+ res['ML'] = data['num_gt_ids']
49
+ res['MLR'] = 1.0
50
+ return res
51
+ if data['num_gt_dets'] == 0:
52
+ res['CLR_FP'] = data['num_tracker_dets']
53
+ res['MLR'] = 1.0
54
+ return res
55
+
56
+ # Variables counting global association
57
+ num_gt_ids = data['num_gt_ids']
58
+ gt_id_count = np.zeros(num_gt_ids) # For MT/ML/PT
59
+ gt_matched_count = np.zeros(num_gt_ids) # For MT/ML/PT
60
+ gt_frag_count = np.zeros(num_gt_ids) # For Frag
61
+
62
+ # Note that IDSWs are counted based on the last time each gt_id was present (any number of frames previously),
63
+ # but are only used in matching to continue current tracks based on the gt_id in the single previous timestep.
64
+ prev_tracker_id = np.nan * np.zeros(num_gt_ids) # For scoring IDSW
65
+ prev_timestep_tracker_id = np.nan * np.zeros(num_gt_ids) # For matching IDSW
66
+
67
+ # Calculate scores for each timestep
68
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
69
+ # Deal with the case that there are no gt_det/tracker_det in a timestep.
70
+ if len(gt_ids_t) == 0:
71
+ res['CLR_FP'] += len(tracker_ids_t)
72
+ continue
73
+ if len(tracker_ids_t) == 0:
74
+ res['CLR_FN'] += len(gt_ids_t)
75
+ gt_id_count[gt_ids_t] += 1
76
+ continue
77
+
78
+ # Calc score matrix to first minimise IDSWs from previous frame, and then maximise MOTP secondarily
79
+ similarity = data['similarity_scores'][t]
80
+ score_mat = (tracker_ids_t[np.newaxis, :] == prev_timestep_tracker_id[gt_ids_t[:, np.newaxis]])
81
+ score_mat = 1000 * score_mat + similarity
82
+ score_mat[similarity < self.threshold - np.finfo('float').eps] = 0
83
+
84
+ # Hungarian algorithm to find best matches
85
+ match_rows, match_cols = linear_sum_assignment(-score_mat)
86
+ actually_matched_mask = score_mat[match_rows, match_cols] > 0 + np.finfo('float').eps
87
+ match_rows = match_rows[actually_matched_mask]
88
+ match_cols = match_cols[actually_matched_mask]
89
+
90
+ matched_gt_ids = gt_ids_t[match_rows]
91
+ matched_tracker_ids = tracker_ids_t[match_cols]
92
+
93
+ # Calc IDSW for MOTA
94
+ prev_matched_tracker_ids = prev_tracker_id[matched_gt_ids]
95
+ is_idsw = (np.logical_not(np.isnan(prev_matched_tracker_ids))) & (
96
+ np.not_equal(matched_tracker_ids, prev_matched_tracker_ids))
97
+ res['IDSW'] += np.sum(is_idsw)
98
+
99
+ # Update counters for MT/ML/PT/Frag and record for IDSW/Frag for next timestep
100
+ gt_id_count[gt_ids_t] += 1
101
+ gt_matched_count[matched_gt_ids] += 1
102
+ not_previously_tracked = np.isnan(prev_timestep_tracker_id)
103
+ prev_tracker_id[matched_gt_ids] = matched_tracker_ids
104
+ prev_timestep_tracker_id[:] = np.nan
105
+ prev_timestep_tracker_id[matched_gt_ids] = matched_tracker_ids
106
+ currently_tracked = np.logical_not(np.isnan(prev_timestep_tracker_id))
107
+ gt_frag_count += np.logical_and(not_previously_tracked, currently_tracked)
108
+
109
+ # Calculate and accumulate basic statistics
110
+ num_matches = len(matched_gt_ids)
111
+ res['CLR_TP'] += num_matches
112
+ res['CLR_FN'] += len(gt_ids_t) - num_matches
113
+ res['CLR_FP'] += len(tracker_ids_t) - num_matches
114
+ if num_matches > 0:
115
+ res['MOTP_sum'] += sum(similarity[match_rows, match_cols])
116
+
117
+ # Calculate MT/ML/PT/Frag/MOTP
118
+ tracked_ratio = gt_matched_count[gt_id_count > 0] / gt_id_count[gt_id_count > 0]
119
+ res['MT'] = np.sum(np.greater(tracked_ratio, 0.8))
120
+ res['PT'] = np.sum(np.greater_equal(tracked_ratio, 0.2)) - res['MT']
121
+ res['ML'] = num_gt_ids - res['MT'] - res['PT']
122
+ res['Frag'] = np.sum(np.subtract(gt_frag_count[gt_frag_count > 0], 1))
123
+ res['MOTP'] = res['MOTP_sum'] / np.maximum(1.0, res['CLR_TP'])
124
+
125
+ res['CLR_Frames'] = data['num_timesteps']
126
+
127
+ # Calculate final CLEAR scores
128
+ res = self._compute_final_fields(res)
129
+ return res
130
+
131
+ def combine_sequences(self, all_res):
132
+ """Combines metrics across all sequences"""
133
+ res = {}
134
+ for field in self.summed_fields:
135
+ res[field] = self._combine_sum(all_res, field)
136
+ res = self._compute_final_fields(res)
137
+ return res
138
+
139
+ def combine_classes_det_averaged(self, all_res):
140
+ """Combines metrics across all classes by averaging over the detection values"""
141
+ res = {}
142
+ for field in self.summed_fields:
143
+ res[field] = self._combine_sum(all_res, field)
144
+ res = self._compute_final_fields(res)
145
+ return res
146
+
147
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
148
+ """Combines metrics across all classes by averaging over the class values.
149
+ If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
150
+ """
151
+ res = {}
152
+ for field in self.integer_fields:
153
+ if ignore_empty_classes:
154
+ res[field] = self._combine_sum(
155
+ {k: v for k, v in all_res.items() if v['CLR_TP'] + v['CLR_FN'] + v['CLR_FP'] > 0}, field)
156
+ else:
157
+ res[field] = self._combine_sum({k: v for k, v in all_res.items()}, field)
158
+ for field in self.float_fields:
159
+ if ignore_empty_classes:
160
+ res[field] = np.mean(
161
+ [v[field] for v in all_res.values() if v['CLR_TP'] + v['CLR_FN'] + v['CLR_FP'] > 0], axis=0)
162
+ else:
163
+ res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
164
+ return res
165
+
166
+ @staticmethod
167
+ def _compute_final_fields(res):
168
+ """Calculate sub-metric ('field') values which only depend on other sub-metric values.
169
+ This function is used both for per-sequence calculation and for combining values across sequences.
170
+ """
171
+ num_gt_ids = res['MT'] + res['ML'] + res['PT']
172
+ res['MTR'] = res['MT'] / np.maximum(1.0, num_gt_ids)
173
+ res['MLR'] = res['ML'] / np.maximum(1.0, num_gt_ids)
174
+ res['PTR'] = res['PT'] / np.maximum(1.0, num_gt_ids)
175
+ res['CLR_Re'] = res['CLR_TP'] / np.maximum(1.0, res['CLR_TP'] + res['CLR_FN'])
176
+ res['CLR_Pr'] = res['CLR_TP'] / np.maximum(1.0, res['CLR_TP'] + res['CLR_FP'])
177
+ res['MODA'] = (res['CLR_TP'] - res['CLR_FP']) / np.maximum(1.0, res['CLR_TP'] + res['CLR_FN'])
178
+ res['MOTA'] = (res['CLR_TP'] - res['CLR_FP'] - res['IDSW']) / np.maximum(1.0, res['CLR_TP'] + res['CLR_FN'])
179
+ res['MOTP'] = res['MOTP_sum'] / np.maximum(1.0, res['CLR_TP'])
180
+ res['sMOTA'] = (res['MOTP_sum'] - res['CLR_FP'] - res['IDSW']) / np.maximum(1.0, res['CLR_TP'] + res['CLR_FN'])
181
+
182
+ res['CLR_F1'] = res['CLR_TP'] / np.maximum(1.0, res['CLR_TP'] + 0.5*res['CLR_FN'] + 0.5*res['CLR_FP'])
183
+ res['FP_per_frame'] = res['CLR_FP'] / np.maximum(1.0, res['CLR_Frames'])
184
+ safe_log_idsw = np.log10(res['IDSW']) if res['IDSW'] > 0 else res['IDSW']
185
+ res['MOTAL'] = (res['CLR_TP'] - res['CLR_FP'] - safe_log_idsw) / np.maximum(1.0, res['CLR_TP'] + res['CLR_FN'])
186
+ return res
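Because `_compute_final_fields` is a pure function of the summed counts, the CLEAR scores can be sanity-checked by hand. A worked example with hypothetical per-class totals:

```python
# Hypothetical totals after combine_sequences: 90 TP, 10 FN, 5 FP, 2 ID switches.
clr_tp, clr_fn, clr_fp, idsw = 90, 10, 5, 2
mota   = (clr_tp - clr_fp - idsw) / max(1.0, clr_tp + clr_fn)   # (90 - 5 - 2) / 100 = 0.83
clr_re = clr_tp / max(1.0, clr_tp + clr_fn)                     # 0.90 recall
clr_pr = clr_tp / max(1.0, clr_tp + clr_fp)                     # ~0.947 precision
```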
avism/data/aviseval/metrics/count.py ADDED
@@ -0,0 +1,44 @@
1
+
2
+ from ._base_metric import _BaseMetric
3
+ from .. import _timing
4
+
5
+
6
+ class Count(_BaseMetric):
7
+ """Class which simply counts the number of tracker and gt detections and ids."""
8
+ def __init__(self, config=None):
9
+ super().__init__()
10
+ self.integer_fields = ['Dets', 'GT_Dets', 'IDs', 'GT_IDs']
11
+ self.fields = self.integer_fields
12
+ self.summary_fields = self.fields
13
+
14
+ @_timing.time
15
+ def eval_sequence(self, data):
16
+ """Returns counts for one sequence"""
17
+ # Get results
18
+ res = {'Dets': data['num_tracker_dets'],
19
+ 'GT_Dets': data['num_gt_dets'],
20
+ 'IDs': data['num_tracker_ids'],
21
+ 'GT_IDs': data['num_gt_ids'],
22
+ 'Frames': data['num_timesteps']}
23
+ return res
24
+
25
+ def combine_sequences(self, all_res):
26
+ """Combines metrics across all sequences"""
27
+ res = {}
28
+ for field in self.integer_fields:
29
+ res[field] = self._combine_sum(all_res, field)
30
+ return res
31
+
32
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=None):
33
+ """Combines metrics across all classes by averaging over the class values"""
34
+ res = {}
35
+ for field in self.integer_fields:
36
+ res[field] = self._combine_sum(all_res, field)
37
+ return res
38
+
39
+ def combine_classes_det_averaged(self, all_res):
40
+ """Combines metrics across all classes by averaging over the detection values"""
41
+ res = {}
42
+ for field in self.integer_fields:
43
+ res[field] = self._combine_sum(all_res, field)
44
+ return res
avism/data/aviseval/metrics/hota.py ADDED
@@ -0,0 +1,202 @@
1
+ import os
2
+ import numpy as np
3
+ from scipy.optimize import linear_sum_assignment
4
+ from ._base_metric import _BaseMetric
5
+ from .. import _timing
6
+
7
+
8
+ class HOTA(_BaseMetric):
9
+ """Class which implements the HOTA metrics.
10
+ See: https://link.springer.com/article/10.1007/s11263-020-01375-2
11
+ """
12
+
13
+ def __init__(self, config=None):
14
+ super().__init__()
15
+ self.plottable = True
16
+ self.array_labels = np.arange(0.05, 0.99, 0.05)
17
+ self.integer_array_fields = ['HOTA_TP', 'HOTA_FN', 'HOTA_FP']
18
+ self.float_array_fields = ['HOTA', 'DetA', 'AssA', 'DetRe', 'DetPr', 'AssRe', 'AssPr', 'LocA', 'OWTA']
19
+ self.float_fields = ['HOTA(0)', 'LocA(0)', 'HOTALocA(0)']
20
+ self.fields = self.float_array_fields + self.integer_array_fields + self.float_fields
21
+ self.summary_fields = self.float_array_fields + self.float_fields
22
+
23
+ @_timing.time
24
+ def eval_sequence(self, data):
25
+ """Calculates the HOTA metrics for one sequence"""
26
+
27
+ # Initialise results
28
+ res = {}
29
+ for field in self.float_array_fields + self.integer_array_fields:
30
+ res[field] = np.zeros((len(self.array_labels)), dtype=float)
31
+ for field in self.float_fields:
32
+ res[field] = 0
33
+
34
+ # Return result quickly if tracker or gt sequence is empty
35
+ if data['num_tracker_dets'] == 0:
36
+ res['HOTA_FN'] = data['num_gt_dets'] * np.ones((len(self.array_labels)), dtype=float)
37
+ res['LocA'] = np.ones((len(self.array_labels)), dtype=float)
38
+ res['LocA(0)'] = 1.0
39
+ return res
40
+ if data['num_gt_dets'] == 0:
41
+ res['HOTA_FP'] = data['num_tracker_dets'] * np.ones((len(self.array_labels)), dtype=float)
42
+ res['LocA'] = np.ones((len(self.array_labels)), dtype=float)
43
+ res['LocA(0)'] = 1.0
44
+ return res
45
+
46
+ # Variables counting global association
47
+ potential_matches_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
48
+ gt_id_count = np.zeros((data['num_gt_ids'], 1))
49
+ tracker_id_count = np.zeros((1, data['num_tracker_ids']))
50
+
51
+ # First loop through each timestep and accumulate global track information.
52
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
53
+ # Count the potential matches between ids in each timestep
54
+ # These are normalised, weighted by the match similarity.
55
+ similarity = data['similarity_scores'][t]
56
+ sim_iou_denom = similarity.sum(0)[np.newaxis, :] + similarity.sum(1)[:, np.newaxis] - similarity
57
+ sim_iou = np.zeros_like(similarity)
58
+ sim_iou_mask = sim_iou_denom > 0 + np.finfo('float').eps
59
+ sim_iou[sim_iou_mask] = similarity[sim_iou_mask] / sim_iou_denom[sim_iou_mask]
60
+ potential_matches_count[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] += sim_iou
61
+
62
+ # Calculate the total number of dets for each gt_id and tracker_id.
63
+ gt_id_count[gt_ids_t] += 1
64
+ tracker_id_count[0, tracker_ids_t] += 1
65
+
66
+ # Calculate overall jaccard alignment score (before unique matching) between IDs
67
+ global_alignment_score = potential_matches_count / (gt_id_count + tracker_id_count - potential_matches_count)
68
+ matches_counts = [np.zeros_like(potential_matches_count) for _ in self.array_labels]
69
+
70
+ # Calculate scores for each timestep
71
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
72
+ # Deal with the case that there are no gt_det/tracker_det in a timestep.
73
+ if len(gt_ids_t) == 0:
74
+ for a, alpha in enumerate(self.array_labels):
75
+ res['HOTA_FP'][a] += len(tracker_ids_t)
76
+ continue
77
+ if len(tracker_ids_t) == 0:
78
+ for a, alpha in enumerate(self.array_labels):
79
+ res['HOTA_FN'][a] += len(gt_ids_t)
80
+ continue
81
+
82
+ # Get matching scores between pairs of dets for optimizing HOTA
83
+ similarity = data['similarity_scores'][t]
84
+ score_mat = global_alignment_score[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] * similarity
85
+
86
+ # Hungarian algorithm to find the best matches (i.e. the optimal GT-to-prediction track assignment)
87
+ match_rows, match_cols = linear_sum_assignment(-score_mat)
88
+
89
+ # Calculate and accumulate basic statistics
90
+ for a, alpha in enumerate(self.array_labels):
91
+ actually_matched_mask = similarity[match_rows, match_cols] >= alpha - np.finfo('float').eps
92
+ alpha_match_rows = match_rows[actually_matched_mask]
93
+ alpha_match_cols = match_cols[actually_matched_mask]
94
+ num_matches = len(alpha_match_rows)
95
+ res['HOTA_TP'][a] += num_matches
96
+ res['HOTA_FN'][a] += len(gt_ids_t) - num_matches
97
+ res['HOTA_FP'][a] += len(tracker_ids_t) - num_matches
98
+ if num_matches > 0:
99
+ res['LocA'][a] += sum(similarity[alpha_match_rows, alpha_match_cols])
100
+ matches_counts[a][gt_ids_t[alpha_match_rows], tracker_ids_t[alpha_match_cols]] += 1
101
+
102
+ # Calculate association scores (AssA, AssRe, AssPr) for the alpha value.
103
+ # First calculate scores per gt_id/tracker_id combo and then average over the number of detections.
104
+ for a, alpha in enumerate(self.array_labels):
105
+ matches_count = matches_counts[a]
106
+ ass_a = matches_count / np.maximum(1, gt_id_count + tracker_id_count - matches_count)
107
+ res['AssA'][a] = np.sum(matches_count * ass_a) / np.maximum(1, res['HOTA_TP'][a])
108
+ ass_re = matches_count / np.maximum(1, gt_id_count)
109
+ res['AssRe'][a] = np.sum(matches_count * ass_re) / np.maximum(1, res['HOTA_TP'][a])
110
+ ass_pr = matches_count / np.maximum(1, tracker_id_count)
111
+ res['AssPr'][a] = np.sum(matches_count * ass_pr) / np.maximum(1, res['HOTA_TP'][a])
112
+
113
+ # Calculate final scores
114
+ res['LocA'] = np.maximum(1e-10, res['LocA']) / np.maximum(1e-10, res['HOTA_TP'])
115
+ res = self._compute_final_fields(res)
116
+ return res
117
+
118
+ def combine_sequences(self, all_res):
119
+ """Combines metrics across all sequences"""
120
+ res = {}
121
+ for field in self.integer_array_fields:
122
+ res[field] = self._combine_sum(all_res, field)
123
+ for field in ['AssRe', 'AssPr', 'AssA']:
124
+ res[field] = self._combine_weighted_av(all_res, field, res, weight_field='HOTA_TP')
125
+ loca_weighted_sum = sum([all_res[k]['LocA'] * all_res[k]['HOTA_TP'] for k in all_res.keys()])
126
+ res['LocA'] = np.maximum(1e-10, loca_weighted_sum) / np.maximum(1e-10, res['HOTA_TP'])
127
+ res = self._compute_final_fields(res)
128
+ return res
129
+
130
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
131
+ """Combines metrics across all classes by averaging over the class values.
132
+ If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
133
+ """
134
+ res = {}
135
+ for field in self.integer_array_fields:
136
+ if ignore_empty_classes:
137
+ res[field] = self._combine_sum(
138
+ {k: v for k, v in all_res.items()
139
+ if (v['HOTA_TP'] + v['HOTA_FN'] + v['HOTA_FP'] > 0 + np.finfo('float').eps).any()}, field)
140
+ else:
141
+ res[field] = self._combine_sum({k: v for k, v in all_res.items()}, field)
142
+
143
+ for field in self.float_fields + self.float_array_fields:
144
+ if ignore_empty_classes:
145
+ res[field] = np.mean([v[field] for v in all_res.values() if
146
+ (v['HOTA_TP'] + v['HOTA_FN'] + v['HOTA_FP'] > 0 + np.finfo('float').eps).any()],
147
+ axis=0)
148
+ else:
149
+ res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
150
+ return res
151
+
152
+ def combine_classes_det_averaged(self, all_res):
153
+ """Combines metrics across all classes by averaging over the detection values"""
154
+ res = {}
155
+ for field in self.integer_array_fields:
156
+ res[field] = self._combine_sum(all_res, field)
157
+ for field in ['AssRe', 'AssPr', 'AssA']:
158
+ res[field] = self._combine_weighted_av(all_res, field, res, weight_field='HOTA_TP')
159
+ loca_weighted_sum = sum([all_res[k]['LocA'] * all_res[k]['HOTA_TP'] for k in all_res.keys()])
160
+ res['LocA'] = np.maximum(1e-10, loca_weighted_sum) / np.maximum(1e-10, res['HOTA_TP'])
161
+ res = self._compute_final_fields(res)
162
+ return res
163
+
164
+ @staticmethod
165
+ def _compute_final_fields(res):
166
+ """Calculate sub-metric ('field') values which only depend on other sub-metric values.
167
+ This function is used both for per-sequence calculation and for combining values across sequences.
168
+ """
169
+ res['DetRe'] = res['HOTA_TP'] / np.maximum(1, res['HOTA_TP'] + res['HOTA_FN'])
170
+ res['DetPr'] = res['HOTA_TP'] / np.maximum(1, res['HOTA_TP'] + res['HOTA_FP'])
171
+ res['DetA'] = res['HOTA_TP'] / np.maximum(1, res['HOTA_TP'] + res['HOTA_FN'] + res['HOTA_FP'])
172
+ res['HOTA'] = np.sqrt(res['DetA'] * res['AssA'])
173
+ res['OWTA'] = np.sqrt(res['DetRe'] * res['AssA'])
174
+
175
+ res['HOTA(0)'] = res['HOTA'][0]
176
+ res['LocA(0)'] = res['LocA'][0]
177
+ res['HOTALocA(0)'] = res['HOTA(0)']*res['LocA(0)']
178
+ return res
179
+
180
+ def plot_single_tracker_results(self, table_res, tracker, cls, output_folder):
181
+ """Create plot of results"""
182
+
183
+ # Only loaded when run to reduce minimum requirements
184
+ from matplotlib import pyplot as plt
185
+
186
+ res = table_res['COMBINED_SEQ']
187
+ styles_to_plot = ['r', 'b', 'g', 'b--', 'b:', 'g--', 'g:', 'm']
188
+ for name, style in zip(self.float_array_fields, styles_to_plot):
189
+ plt.plot(self.array_labels, res[name], style)
190
+ plt.xlabel('alpha')
191
+ plt.ylabel('score')
192
+ plt.title(tracker + ' - ' + cls)
193
+ plt.axis([0, 1, 0, 1])
194
+ legend = []
195
+ for name in self.float_array_fields:
196
+ legend += [name + ' (' + str(np.round(np.mean(res[name]), 2)) + ')']
197
+ plt.legend(legend, loc='lower left')
198
+ out_file = os.path.join(output_folder, cls + '_plot.pdf')
199
+ os.makedirs(os.path.dirname(out_file), exist_ok=True)
200
+ plt.savefig(out_file)
201
+ plt.savefig(out_file.replace('.pdf', '.png'))
202
+ plt.clf()
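Note: the final HOTA fields above are pure arithmetic on the accumulated counts, so they can be sanity-checked in isolation. The snippet below uses made-up counts and an assumed AssA value at a single alpha threshold; it mirrors the formulas in `_compute_final_fields` rather than calling the class.

    import numpy as np

    # Hypothetical counts at one localisation threshold (alpha).
    hota_tp = np.array([80.0])
    hota_fn = np.array([20.0])
    hota_fp = np.array([10.0])
    ass_a = np.array([0.6])  # assumed association accuracy at this alpha

    det_re = hota_tp / np.maximum(1, hota_tp + hota_fn)           # 0.800
    det_pr = hota_tp / np.maximum(1, hota_tp + hota_fp)           # ~0.889
    det_a = hota_tp / np.maximum(1, hota_tp + hota_fn + hota_fp)  # ~0.727
    hota = np.sqrt(det_a * ass_a)                                 # ~0.661
    print(det_re, det_pr, det_a, hota)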
avism/data/aviseval/metrics/identity.py ADDED
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+ from scipy.optimize import linear_sum_assignment
3
+ from ._base_metric import _BaseMetric
4
+ from .. import _timing
5
+ from .. import utils
6
+
7
+
8
+ class Identity(_BaseMetric):
9
+ """Class which implements the ID metrics"""
10
+
11
+ @staticmethod
12
+ def get_default_config():
13
+ """Default class config values"""
14
+ default_config = {
15
+ 'THRESHOLD': 0.5, # Similarity score threshold required for an IDTP match. Default 0.5.
16
+ 'PRINT_CONFIG': True, # Whether to print the config information on init. Default: True.
17
+ }
18
+ return default_config
19
+
20
+ def __init__(self, config=None):
21
+ super().__init__()
22
+ self.integer_fields = ['IDTP', 'IDFN', 'IDFP']
23
+ self.float_fields = ['IDF1', 'IDR', 'IDP']
24
+ self.fields = self.float_fields + self.integer_fields
25
+ self.summary_fields = self.fields
26
+
27
+ # Configuration options:
28
+ self.config = utils.init_config(config, self.get_default_config(), self.get_name())
29
+ self.threshold = float(self.config['THRESHOLD'])
30
+
31
+ @_timing.time
32
+ def eval_sequence(self, data):
33
+ """Calculates ID metrics for one sequence"""
34
+ # Initialise results
35
+ res = {}
36
+ for field in self.fields:
37
+ res[field] = 0
38
+
39
+ # Return result quickly if tracker or gt sequence is empty
40
+ if data['num_tracker_dets'] == 0:
41
+ res['IDFN'] = data['num_gt_dets']
42
+ return res
43
+ if data['num_gt_dets'] == 0:
44
+ res['IDFP'] = data['num_tracker_dets']
45
+ return res
46
+
47
+ # Variables counting global association
48
+ potential_matches_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
49
+ gt_id_count = np.zeros(data['num_gt_ids'])
50
+ tracker_id_count = np.zeros(data['num_tracker_ids'])
51
+
52
+ # First loop through each timestep and accumulate global track information.
53
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
54
+ # Count the potential matches between ids in each timestep
55
+ matches_mask = np.greater_equal(data['similarity_scores'][t], self.threshold)
56
+ match_idx_gt, match_idx_tracker = np.nonzero(matches_mask)
57
+ potential_matches_count[gt_ids_t[match_idx_gt], tracker_ids_t[match_idx_tracker]] += 1
58
+
59
+ # Calculate the total number of dets for each gt_id and tracker_id.
60
+ gt_id_count[gt_ids_t] += 1
61
+ tracker_id_count[tracker_ids_t] += 1
62
+
63
+ # Calculate optimal assignment cost matrix for ID metrics
64
+ num_gt_ids = data['num_gt_ids']
65
+ num_tracker_ids = data['num_tracker_ids']
66
+ fp_mat = np.zeros((num_gt_ids + num_tracker_ids, num_gt_ids + num_tracker_ids))
67
+ fn_mat = np.zeros((num_gt_ids + num_tracker_ids, num_gt_ids + num_tracker_ids))
68
+ fp_mat[num_gt_ids:, :num_tracker_ids] = 1e10
69
+ fn_mat[:num_gt_ids, num_tracker_ids:] = 1e10
70
+ for gt_id in range(num_gt_ids):
71
+ fn_mat[gt_id, :num_tracker_ids] = gt_id_count[gt_id]
72
+ fn_mat[gt_id, num_tracker_ids + gt_id] = gt_id_count[gt_id]
73
+ for tracker_id in range(num_tracker_ids):
74
+ fp_mat[:num_gt_ids, tracker_id] = tracker_id_count[tracker_id]
75
+ fp_mat[tracker_id + num_gt_ids, tracker_id] = tracker_id_count[tracker_id]
76
+ fn_mat[:num_gt_ids, :num_tracker_ids] -= potential_matches_count
77
+ fp_mat[:num_gt_ids, :num_tracker_ids] -= potential_matches_count
78
+
79
+ # Hungarian algorithm
80
+ match_rows, match_cols = linear_sum_assignment(fn_mat + fp_mat)
81
+
82
+ # Accumulate basic statistics
83
+ res['IDFN'] = fn_mat[match_rows, match_cols].sum().astype(int)
84
+ res['IDFP'] = fp_mat[match_rows, match_cols].sum().astype(int)
85
+ res['IDTP'] = (gt_id_count.sum() - res['IDFN']).astype(int)
86
+
87
+ # Calculate final ID scores
88
+ res = self._compute_final_fields(res)
89
+ return res
90
+
91
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
92
+ """Combines metrics across all classes by averaging over the class values.
93
+ If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
94
+ """
95
+ res = {}
96
+ for field in self.integer_fields:
97
+ if ignore_empty_classes:
98
+ res[field] = self._combine_sum({k: v for k, v in all_res.items()
99
+ if v['IDTP'] + v['IDFN'] + v['IDFP'] > 0 + np.finfo('float').eps},
100
+ field)
101
+ else:
102
+ res[field] = self._combine_sum({k: v for k, v in all_res.items()}, field)
103
+ for field in self.float_fields:
104
+ if ignore_empty_classes:
105
+ res[field] = np.mean([v[field] for v in all_res.values()
106
+ if v['IDTP'] + v['IDFN'] + v['IDFP'] > 0 + np.finfo('float').eps], axis=0)
107
+ else:
108
+ res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
109
+ return res
110
+
111
+ def combine_classes_det_averaged(self, all_res):
112
+ """Combines metrics across all classes by averaging over the detection values"""
113
+ res = {}
114
+ for field in self.integer_fields:
115
+ res[field] = self._combine_sum(all_res, field)
116
+ res = self._compute_final_fields(res)
117
+ return res
118
+
119
+ def combine_sequences(self, all_res):
120
+ """Combines metrics across all sequences"""
121
+ res = {}
122
+ for field in self.integer_fields:
123
+ res[field] = self._combine_sum(all_res, field)
124
+ res = self._compute_final_fields(res)
125
+ return res
126
+
127
+ @staticmethod
128
+ def _compute_final_fields(res):
129
+ """Calculate sub-metric ('field') values which only depend on other sub-metric values.
130
+ This function is used both for per-sequence calculation and for combining values across sequences.
131
+ """
132
+ res['IDR'] = res['IDTP'] / np.maximum(1.0, res['IDTP'] + res['IDFN'])
133
+ res['IDP'] = res['IDTP'] / np.maximum(1.0, res['IDTP'] + res['IDFP'])
134
+ res['IDF1'] = res['IDTP'] / np.maximum(1.0, res['IDTP'] + 0.5 * res['IDFP'] + 0.5 * res['IDFN'])
135
+ return res
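Note: as a quick check of the final ID fields, the standalone snippet below plugs made-up IDTP / IDFN / IDFP counts into the same IDR / IDP / IDF1 relations used by `_compute_final_fields`; it does not exercise the Hungarian matching itself.

    import numpy as np

    # Hypothetical identity counts for one sequence.
    res = {'IDTP': 70, 'IDFN': 30, 'IDFP': 20}
    res['IDR'] = res['IDTP'] / np.maximum(1.0, res['IDTP'] + res['IDFN'])  # 0.70
    res['IDP'] = res['IDTP'] / np.maximum(1.0, res['IDTP'] + res['IDFP'])  # ~0.78
    res['IDF1'] = res['IDTP'] / np.maximum(1.0, res['IDTP'] + 0.5 * res['IDFP'] + 0.5 * res['IDFN'])  # ~0.74
    print(res)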
avism/data/aviseval/metrics/ideucl.py ADDED
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+ from scipy.optimize import linear_sum_assignment
3
+ from ._base_metric import _BaseMetric
4
+ from .. import _timing
5
+ from collections import defaultdict
6
+ from .. import utils
7
+
8
+
9
+ class IDEucl(_BaseMetric):
10
+ """Class which implements the ID metrics"""
11
+
12
+ @staticmethod
13
+ def get_default_config():
14
+ """Default class config values"""
15
+ default_config = {
16
+ 'THRESHOLD': 0.4, # Similarity score threshold required for an IDTP match. 0.4 for IDEucl.
17
+ 'PRINT_CONFIG': True, # Whether to print the config information on init. Default: True.
18
+ }
19
+ return default_config
20
+
21
+ def __init__(self, config=None):
22
+ super().__init__()
23
+ self.fields = ['IDEucl']
24
+ self.float_fields = self.fields
25
+ self.summary_fields = self.fields
26
+
27
+ # Configuration options:
28
+ self.config = utils.init_config(config, self.get_default_config(), self.get_name())
29
+ self.threshold = float(self.config['THRESHOLD'])
30
+
31
+
32
+ @_timing.time
33
+ def eval_sequence(self, data):
34
+ """Calculates IDEucl metrics for all frames"""
35
+ # Initialise results
36
+ res = {'IDEucl' : 0}
37
+
38
+ # Return result quickly if tracker or gt sequence is empty
39
+ if data['num_tracker_dets'] == 0 or data['num_gt_dets'] == 0.:
40
+ return res
41
+
42
+ data['centroid'] = []
43
+ for t, gt_det in enumerate(data['gt_dets']):
44
+ # import pdb;pdb.set_trace()
45
+ data['centroid'].append(self._compute_centroid(gt_det))
46
+
47
+ oid_hid_cent = defaultdict(list)
48
+ oid_cent = defaultdict(list)
49
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
50
+ matches_mask = np.greater_equal(data['similarity_scores'][t], self.threshold)
51
+
52
+ # The order of ids and boxes is assumed to be consistent in `data`
53
+ for ind, gid in enumerate(gt_ids_t):
54
+ oid_cent[gid].append(data['centroid'][t][ind])
55
+
56
+ match_idx_gt, match_idx_tracker = np.nonzero(matches_mask)
57
+ for m_gid, m_tid in zip(match_idx_gt, match_idx_tracker):
58
+ oid_hid_cent[gt_ids_t[m_gid], tracker_ids_t[m_tid]].append(data['centroid'][t][m_gid])
59
+
60
+ oid_hid_dist = {k : np.sum(np.linalg.norm(np.diff(np.array(v), axis=0), axis=1)) for k, v in oid_hid_cent.items()}
61
+ oid_dist = {int(k) : np.sum(np.linalg.norm(np.diff(np.array(v), axis=0), axis=1)) for k, v in oid_cent.items()}
62
+
63
+ unique_oid = np.unique([i[0] for i in oid_hid_dist.keys()]).tolist()
64
+ unique_hid = np.unique([i[1] for i in oid_hid_dist.keys()]).tolist()
65
+ o_len = len(unique_oid)
66
+ h_len = len(unique_hid)
67
+ dist_matrix = np.zeros((o_len, h_len))
68
+ for ((oid, hid), dist) in oid_hid_dist.items():
69
+ oid_ind = unique_oid.index(oid)
70
+ hid_ind = unique_hid.index(hid)
71
+ dist_matrix[oid_ind, hid_ind] = dist
72
+
73
+ # opt_hyp_dist contains GT ID : max dist covered by track
74
+ opt_hyp_dist = dict.fromkeys(oid_dist.keys(), 0.)
75
+ cost_matrix = np.max(dist_matrix) - dist_matrix
76
+ rows, cols = linear_sum_assignment(cost_matrix)
77
+ for (row, col) in zip(rows, cols):
78
+ value = dist_matrix[row, col]
79
+ opt_hyp_dist[int(unique_oid[row])] = value
80
+
81
+ assert len(opt_hyp_dist.keys()) == len(oid_dist.keys())
82
+ hyp_length = np.sum(list(opt_hyp_dist.values()))
83
+ gt_length = np.sum(list(oid_dist.values()))
84
+ id_eucl = np.mean([np.divide(a, b, out=np.zeros_like(a), where=b != 0) for a, b in zip(opt_hyp_dist.values(), oid_dist.values())])
85
+ res['IDEucl'] = np.divide(hyp_length, gt_length, out=np.zeros_like(hyp_length), where=gt_length!=0)
86
+ return res
87
+
88
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
89
+ """Combines metrics across all classes by averaging over the class values.
90
+ If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
91
+ """
92
+ res = {}
93
+
94
+ for field in self.float_fields:
95
+ if ignore_empty_classes:
96
+ res[field] = np.mean([v[field] for v in all_res.values()
97
+ if v['IDEucl'] > 0 + np.finfo('float').eps], axis=0)
98
+ else:
99
+ res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
100
+ return res
101
+
102
+ def combine_classes_det_averaged(self, all_res):
103
+ """Combines metrics across all classes by averaging over the detection values"""
104
+ res = {}
105
+ for field in self.float_fields:
106
+ res[field] = self._combine_sum(all_res, field)
107
+ res = self._compute_final_fields(res, len(all_res))
108
+ return res
109
+
110
+ def combine_sequences(self, all_res):
111
+ """Combines metrics across all sequences"""
112
+ res = {}
113
+ for field in self.float_fields:
114
+ res[field] = self._combine_sum(all_res, field)
115
+ res = self._compute_final_fields(res, len(all_res))
116
+ return res
117
+
118
+
119
+ @staticmethod
120
+ def _compute_centroid(box):
121
+ box = np.array(box)
122
+ if len(box.shape) == 1:
123
+ centroid = (box[0:2] + box[2:4])/2
124
+ else:
125
+ centroid = (box[:, 0:2] + box[:, 2:4])/2
126
+ return np.flip(centroid, axis=1)
127
+
128
+
129
+ @staticmethod
130
+ def _compute_final_fields(res, res_len):
131
+ """
132
+ Exists only to match the signature of the original Identity class.
133
+
134
+ """
135
+ return {k:v/res_len for k,v in res.items()}
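Note: the core quantity behind IDEucl is the Euclidean path length traced by a track's centroids (the `np.linalg.norm(np.diff(...))` expression in `eval_sequence`). The self-contained sketch below uses hypothetical centroid coordinates to show how the covered-distance ratio is formed.

    import numpy as np

    # Hypothetical (y, x) centroids of one gt track over four frames.
    gt_centroids = np.array([[10.0, 10.0], [10.0, 14.0], [13.0, 14.0], [13.0, 20.0]])
    # Centroids of the frames in which a hypothesis track matched this gt track.
    matched_centroids = np.array([[10.0, 10.0], [10.0, 14.0], [13.0, 14.0]])

    def path_length(centroids):
        # Sum of frame-to-frame Euclidean steps.
        return np.sum(np.linalg.norm(np.diff(centroids, axis=0), axis=1))

    gt_len = path_length(gt_centroids)        # 4 + 3 + 6 = 13
    hyp_len = path_length(matched_centroids)  # 4 + 3 = 7
    print(hyp_len / gt_len)                   # fraction of the gt trajectory covered, ~0.54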
avism/data/aviseval/metrics/j_and_f.py ADDED
@@ -0,0 +1,310 @@
1
+
2
+ import numpy as np
3
+ import math
4
+ from scipy.optimize import linear_sum_assignment
5
+ from ..utils import TrackEvalException
6
+ from ._base_metric import _BaseMetric
7
+ from .. import _timing
8
+
9
+
10
+ class JAndF(_BaseMetric):
11
+ """Class which implements the J&F metrics"""
12
+ def __init__(self, config=None):
13
+ super().__init__()
14
+ self.integer_fields = ['num_gt_tracks']
15
+ self.float_fields = ['J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay', 'J&F']
16
+ self.fields = self.float_fields + self.integer_fields
17
+ self.summary_fields = self.float_fields
18
+ self.optim_type = 'J' # possible values J, J&F
19
+
20
+ @_timing.time
21
+ def eval_sequence(self, data):
22
+ """Returns J&F metrics for one sequence"""
23
+
24
+ # Only loaded when run to reduce minimum requirements
25
+ from pycocotools import mask as mask_utils
26
+
27
+ num_timesteps = data['num_timesteps']
28
+ num_tracker_ids = data['num_tracker_ids']
29
+ num_gt_ids = data['num_gt_ids']
30
+ gt_dets = data['gt_dets']
31
+ tracker_dets = data['tracker_dets']
32
+ gt_ids = data['gt_ids']
33
+ tracker_ids = data['tracker_ids']
34
+
35
+ # get shape of frames
36
+ frame_shape = None
37
+ if num_gt_ids > 0:
38
+ for t in range(num_timesteps):
39
+ if len(gt_ids[t]) > 0:
40
+ frame_shape = gt_dets[t][0]['size']
41
+ break
42
+ elif num_tracker_ids > 0:
43
+ for t in range(num_timesteps):
44
+ if len(tracker_ids[t]) > 0:
45
+ frame_shape = tracker_dets[t][0]['size']
46
+ break
47
+
48
+ if frame_shape:
49
+ # append all zero masks for timesteps in which tracks do not have a detection
50
+ zero_padding = np.zeros(frame_shape, order='F').astype(np.uint8)
51
+ padding_mask = mask_utils.encode(zero_padding)
52
+ for t in range(num_timesteps):
53
+ gt_id_det_mapping = {gt_ids[t][i]: gt_dets[t][i] for i in range(len(gt_ids[t]))}
54
+ gt_dets[t] = [gt_id_det_mapping[index] if index in gt_ids[t] else padding_mask for index
55
+ in range(num_gt_ids)]
56
+ tracker_id_det_mapping = {tracker_ids[t][i]: tracker_dets[t][i] for i in range(len(tracker_ids[t]))}
57
+ tracker_dets[t] = [tracker_id_det_mapping[index] if index in tracker_ids[t] else padding_mask for index
58
+ in range(num_tracker_ids)]
59
+ # also perform zero padding if number of tracker IDs < number of ground truth IDs
60
+ if num_tracker_ids < num_gt_ids:
61
+ diff = num_gt_ids - num_tracker_ids
62
+ for t in range(num_timesteps):
63
+ tracker_dets[t] = tracker_dets[t] + [padding_mask for _ in range(diff)]
64
+ num_tracker_ids += diff
65
+
66
+ j = self._compute_j(gt_dets, tracker_dets, num_gt_ids, num_tracker_ids, num_timesteps)
67
+
68
+ # boundary threshold for F computation
69
+ bound_th = 0.008
70
+
71
+ # perform matching
72
+ if self.optim_type == 'J&F':
73
+ f = np.zeros_like(j)
74
+ for k in range(num_tracker_ids):
75
+ for i in range(num_gt_ids):
76
+ f[k, i, :] = self._compute_f(gt_dets, tracker_dets, k, i, bound_th)
77
+ optim_metrics = (np.mean(j, axis=2) + np.mean(f, axis=2)) / 2
78
+ row_ind, col_ind = linear_sum_assignment(- optim_metrics)
79
+ j_m = j[row_ind, col_ind, :]
80
+ f_m = f[row_ind, col_ind, :]
81
+ elif self.optim_type == 'J':
82
+ optim_metrics = np.mean(j, axis=2)
83
+ row_ind, col_ind = linear_sum_assignment(- optim_metrics)
84
+ j_m = j[row_ind, col_ind, :]
85
+ f_m = np.zeros_like(j_m)
86
+ for i, (tr_ind, gt_ind) in enumerate(zip(row_ind, col_ind)):
87
+ f_m[i] = self._compute_f(gt_dets, tracker_dets, tr_ind, gt_ind, bound_th)
88
+ else:
89
+ raise TrackEvalException('Unsupported optimization type %s for J&F metric.' % self.optim_type)
90
+
91
+ # append zeros for false negatives
92
+ if j_m.shape[0] < data['num_gt_ids']:
93
+ diff = data['num_gt_ids'] - j_m.shape[0]
94
+ j_m = np.concatenate((j_m, np.zeros((diff, j_m.shape[1]))), axis=0)
95
+ f_m = np.concatenate((f_m, np.zeros((diff, f_m.shape[1]))), axis=0)
96
+
97
+ # compute the metrics for each ground truth track
98
+ res = {
99
+ 'J-Mean': [np.nanmean(j_m[i, :]) for i in range(j_m.shape[0])],
100
+ 'J-Recall': [np.nanmean(j_m[i, :] > 0.5 + np.finfo('float').eps) for i in range(j_m.shape[0])],
101
+ 'F-Mean': [np.nanmean(f_m[i, :]) for i in range(f_m.shape[0])],
102
+ 'F-Recall': [np.nanmean(f_m[i, :] > 0.5 + np.finfo('float').eps) for i in range(f_m.shape[0])],
103
+ 'J-Decay': [],
104
+ 'F-Decay': []
105
+ }
106
+ n_bins = 4
107
+ ids = np.round(np.linspace(1, data['num_timesteps'], n_bins + 1) + 1e-10) - 1
108
+ ids = ids.astype(np.uint8)
109
+
110
+ for k in range(j_m.shape[0]):
111
+ d_bins_j = [j_m[k][ids[i]:ids[i + 1] + 1] for i in range(0, n_bins)]
112
+ res['J-Decay'].append(np.nanmean(d_bins_j[0]) - np.nanmean(d_bins_j[3]))
113
+ for k in range(f_m.shape[0]):
114
+ d_bins_f = [f_m[k][ids[i]:ids[i + 1] + 1] for i in range(0, n_bins)]
115
+ res['F-Decay'].append(np.nanmean(d_bins_f[0]) - np.nanmean(d_bins_f[3]))
116
+
117
+ # count number of tracks for weighting of the result
118
+ res['num_gt_tracks'] = len(res['J-Mean'])
119
+ for field in ['J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay']:
120
+ res[field] = np.mean(res[field])
121
+ res['J&F'] = (res['J-Mean'] + res['F-Mean']) / 2
122
+ return res
123
+
124
+ def combine_sequences(self, all_res):
125
+ """Combines metrics across all sequences"""
126
+ res = {'num_gt_tracks': self._combine_sum(all_res, 'num_gt_tracks')}
127
+ for field in self.summary_fields:
128
+ res[field] = self._combine_weighted_av(all_res, field, res, weight_field='num_gt_tracks')
129
+ return res
130
+
131
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
132
+ """Combines metrics across all classes by averaging over the class values
133
+ 'ignore empty classes' is not yet implemented here.
134
+ """
135
+ res = {'num_gt_tracks': self._combine_sum(all_res, 'num_gt_tracks')}
136
+ for field in self.float_fields:
137
+ res[field] = np.mean([v[field] for v in all_res.values()])
138
+ return res
139
+
140
+ def combine_classes_det_averaged(self, all_res):
141
+ """Combines metrics across all classes by averaging over the detection values"""
142
+ res = {'num_gt_tracks': self._combine_sum(all_res, 'num_gt_tracks')}
143
+ for field in self.float_fields:
144
+ res[field] = np.mean([v[field] for v in all_res.values()])
145
+ return res
146
+
147
+ @staticmethod
148
+ def _seg2bmap(seg, width=None, height=None):
149
+ """
150
+ From a segmentation, compute a binary boundary map with 1 pixel wide
151
+ boundaries. The boundary pixels are offset by 1/2 pixel towards the
152
+ origin from the actual segment boundary.
153
+ Arguments:
154
+ seg : Segments labeled from 1..k.
155
+ width : Width of desired bmap <= seg.shape[1]
156
+ height : Height of desired bmap <= seg.shape[0]
157
+ Returns:
158
+ bmap (ndarray): Binary boundary map.
159
+ David Martin <[email protected]>
160
+ January 2003
161
+ """
162
+
163
+ seg = seg.astype(bool)
164
+ seg[seg > 0] = 1
165
+
166
+ assert np.atleast_3d(seg).shape[2] == 1
167
+
168
+ width = seg.shape[1] if width is None else width
169
+ height = seg.shape[0] if height is None else height
170
+
171
+ h, w = seg.shape[:2]
172
+
173
+ ar1 = float(width) / float(height)
174
+ ar2 = float(w) / float(h)
175
+
176
+ assert not (
177
+ width > w | height > h | abs(ar1 - ar2) > 0.01
178
+ ), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
179
+
180
+ e = np.zeros_like(seg)
181
+ s = np.zeros_like(seg)
182
+ se = np.zeros_like(seg)
183
+
184
+ e[:, :-1] = seg[:, 1:]
185
+ s[:-1, :] = seg[1:, :]
186
+ se[:-1, :-1] = seg[1:, 1:]
187
+
188
+ b = seg ^ e | seg ^ s | seg ^ se
189
+ b[-1, :] = seg[-1, :] ^ e[-1, :]
190
+ b[:, -1] = seg[:, -1] ^ s[:, -1]
191
+ b[-1, -1] = 0
192
+
193
+ if w == width and h == height:
194
+ bmap = b
195
+ else:
196
+ bmap = np.zeros((height, width))
197
+ for x in range(w):
198
+ for y in range(h):
199
+ if b[y, x]:
200
+ j = 1 + math.floor((y - 1) + height / h)
201
+ i = 1 + math.floor((x - 1) + width / h)
202
+ bmap[j, i] = 1
203
+
204
+ return bmap
205
+
206
+ @staticmethod
207
+ def _compute_f(gt_data, tracker_data, tracker_data_id, gt_id, bound_th):
208
+ """
209
+ Perform F computation for a given gt and a given tracker ID. Adapted from
210
+ https://github.com/davisvideochallenge/davis2017-evaluation
211
+ :param gt_data: the encoded gt masks
212
+ :param tracker_data: the encoded tracker masks
213
+ :param tracker_data_id: the tracker ID
214
+ :param gt_id: the ground truth ID
215
+ :param bound_th: boundary threshold parameter
216
+ :return: the F value for the given tracker and gt ID
217
+ """
218
+
219
+ # Only loaded when run to reduce minimum requirements
220
+ from pycocotools import mask as mask_utils
221
+ from skimage.morphology import disk
222
+ import cv2
223
+
224
+ f = np.zeros(len(gt_data))
225
+
226
+ for t, (gt_masks, tracker_masks) in enumerate(zip(gt_data, tracker_data)):
227
+ curr_tracker_mask = mask_utils.decode(tracker_masks[tracker_data_id])
228
+ curr_gt_mask = mask_utils.decode(gt_masks[gt_id])
229
+
230
+ bound_pix = bound_th if bound_th >= 1 - np.finfo('float').eps else \
231
+ np.ceil(bound_th * np.linalg.norm(curr_tracker_mask.shape))
232
+
233
+ # Get the pixel boundaries of both masks
234
+ fg_boundary = JAndF._seg2bmap(curr_tracker_mask)
235
+ gt_boundary = JAndF._seg2bmap(curr_gt_mask)
236
+
237
+ # fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
238
+ fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
239
+ # gt_dil = binary_dilation(gt_boundary, disk(bound_pix))
240
+ gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
241
+
242
+ # Get the intersection
243
+ gt_match = gt_boundary * fg_dil
244
+ fg_match = fg_boundary * gt_dil
245
+
246
+ # Area of the intersection
247
+ n_fg = np.sum(fg_boundary)
248
+ n_gt = np.sum(gt_boundary)
249
+
250
+ # % Compute precision and recall
251
+ if n_fg == 0 and n_gt > 0:
252
+ precision = 1
253
+ recall = 0
254
+ elif n_fg > 0 and n_gt == 0:
255
+ precision = 0
256
+ recall = 1
257
+ elif n_fg == 0 and n_gt == 0:
258
+ precision = 1
259
+ recall = 1
260
+ else:
261
+ precision = np.sum(fg_match) / float(n_fg)
262
+ recall = np.sum(gt_match) / float(n_gt)
263
+
264
+ # Compute F measure
265
+ if precision + recall == 0:
266
+ f_val = 0
267
+ else:
268
+ f_val = 2 * precision * recall / (precision + recall)
269
+
270
+ f[t] = f_val
271
+
272
+ return f
273
+
274
+ @staticmethod
275
+ def _compute_j(gt_data, tracker_data, num_gt_ids, num_tracker_ids, num_timesteps):
276
+ """
277
+ Computation of J value for all ground truth IDs and all tracker IDs in the given sequence. Adapted from
278
+ https://github.com/davisvideochallenge/davis2017-evaluation
279
+ :param gt_data: the ground truth masks
280
+ :param tracker_data: the tracker masks
281
+ :param num_gt_ids: the number of ground truth IDs
282
+ :param num_tracker_ids: the number of tracker IDs
283
+ :param num_timesteps: the number of timesteps
284
+ :return: the J values
285
+ """
286
+
287
+ # Only loaded when run to reduce minimum requirements
288
+ from pycocotools import mask as mask_utils
289
+
290
+ j = np.zeros((num_tracker_ids, num_gt_ids, num_timesteps))
291
+
292
+ for t, (time_gt, time_data) in enumerate(zip(gt_data, tracker_data)):
293
+ # run length encoded masks with pycocotools
294
+ area_gt = mask_utils.area(time_gt)
295
+ time_data = list(time_data)
296
+ area_tr = mask_utils.area(time_data)
297
+
298
+ area_tr = np.repeat(area_tr[:, np.newaxis], len(area_gt), axis=1)
299
+ area_gt = np.repeat(area_gt[np.newaxis, :], len(area_tr), axis=0)
300
+
301
+ # mask iou computation with pycocotools
302
+ ious = np.atleast_2d(mask_utils.iou(time_data, time_gt, [0]*len(time_gt)))
303
+ # set iou to 1 if both masks are close to 0 (no ground truth and no predicted mask in timestep)
304
+ ious[np.isclose(area_tr, 0) & np.isclose(area_gt, 0)] = 1
305
+ assert (ious >= 0 - np.finfo('float').eps).all()
306
+ assert (ious <= 1 + np.finfo('float').eps).all()
307
+
308
+ j[..., t] = ious
309
+
310
+ return j
avism/data/aviseval/metrics/track_map.py ADDED
@@ -0,0 +1,462 @@
 
1
+ import numpy as np
2
+ from ._base_metric import _BaseMetric
3
+ from .. import _timing
4
+ from functools import partial
5
+ from .. import utils
6
+ from ..utils import TrackEvalException
7
+
8
+
9
+ class TrackMAP(_BaseMetric):
10
+ """Class which implements the TrackMAP metrics"""
11
+
12
+ @staticmethod
13
+ def get_default_metric_config():
14
+ """Default class config values"""
15
+ default_config = {
16
+ 'USE_AREA_RANGES': True, # whether to evaluate for certain area ranges
17
+ 'AREA_RANGES': [[0 ** 2, 32 ** 2], # additional area range sets for which TrackMAP is evaluated
18
+ [32 ** 2, 96 ** 2], # (all area range always included), default values for TAO
19
+ [96 ** 2, 1e5 ** 2]], # evaluation
20
+ 'AREA_RANGE_LABELS': ["area_s", "area_m", "area_l"], # the labels for the area ranges
21
+ 'USE_TIME_RANGES': True, # whether to evaluate for certain time ranges (length of tracks)
22
+ 'TIME_RANGES': [[0, 3], [3, 10], [10, 1e5]], # additional time range sets for which TrackMAP is evaluated
23
+ # (all time range always included) , default values for TAO evaluation
24
+ 'TIME_RANGE_LABELS': ["time_s", "time_m", "time_l"], # the labels for the time ranges
25
+ 'IOU_THRESHOLDS': np.arange(0.5, 0.96, 0.05), # the IoU thresholds
26
+ 'RECALL_THRESHOLDS': np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01) + 1), endpoint=True),
27
+ # recall thresholds at which precision is evaluated
28
+ 'MAX_DETECTIONS': 0, # limit the maximum number of considered tracks per sequence (0 for unlimited)
29
+ 'PRINT_CONFIG': False
30
+ }
31
+ return default_config
32
+
33
+ def __init__(self, config=None):
34
+ super().__init__()
35
+ self.config = utils.init_config(config, self.get_default_metric_config(), self.get_name())
36
+
37
+ self.num_ig_masks = 1
38
+ self.lbls = ['all']
39
+ self.use_area_rngs = self.config['USE_AREA_RANGES']
40
+ if self.use_area_rngs:
41
+ self.area_rngs = self.config['AREA_RANGES']
42
+ self.area_rng_lbls = self.config['AREA_RANGE_LABELS']
43
+ self.num_ig_masks += len(self.area_rng_lbls)
44
+ self.lbls += self.area_rng_lbls
45
+
46
+ self.use_time_rngs = self.config['USE_TIME_RANGES']
47
+ if self.use_time_rngs:
48
+ self.time_rngs = self.config['TIME_RANGES']
49
+ self.time_rng_lbls = self.config['TIME_RANGE_LABELS']
50
+ self.num_ig_masks += len(self.time_rng_lbls)
51
+ self.lbls += self.time_rng_lbls
52
+
53
+ self.array_labels = self.config['IOU_THRESHOLDS']
54
+ self.rec_thrs = self.config['RECALL_THRESHOLDS']
55
+
56
+ self.maxDet = self.config['MAX_DETECTIONS']
57
+ self.float_array_fields = ['AP_' + lbl for lbl in self.lbls] + ['AR_' + lbl for lbl in self.lbls]
58
+ self.fields = self.float_array_fields
59
+ self.summary_fields = self.float_array_fields
60
+
61
+ @_timing.time
62
+ def eval_sequence(self, data):
63
+ """Calculates GT and Tracker matches for one sequence for TrackMAP metrics. Adapted from
64
+ https://github.com/TAO-Dataset/"""
65
+
66
+ # Initialise results to zero for each sequence as the fields are only defined over the set of all sequences
67
+ res = {}
68
+ for field in self.fields:
69
+ res[field] = [0 for _ in self.array_labels]
70
+
71
+ gt_ids, dt_ids = data['gt_track_ids'], data['dt_track_ids']
72
+
73
+ if len(gt_ids) == 0 and len(dt_ids) == 0:
74
+ for idx in range(self.num_ig_masks):
75
+ res[idx] = None
76
+ return res
77
+
78
+ # get track data
79
+ gt_tr_areas = data.get('gt_track_areas', None) if self.use_area_rngs else None
80
+ gt_tr_lengths = data.get('gt_track_lengths', None) if self.use_time_rngs else None
81
+ gt_tr_iscrowd = data.get('gt_track_iscrowd', None)
82
+ dt_tr_areas = data.get('dt_track_areas', None) if self.use_area_rngs else None
83
+ dt_tr_lengths = data.get('dt_track_lengths', None) if self.use_time_rngs else None
84
+ is_nel = data.get('not_exhaustively_labeled', False)
85
+
86
+ # compute ignore masks for different track sets to eval
87
+ gt_ig_masks = self._compute_track_ig_masks(len(gt_ids), track_lengths=gt_tr_lengths, track_areas=gt_tr_areas,
88
+ iscrowd=gt_tr_iscrowd)
89
+ dt_ig_masks = self._compute_track_ig_masks(len(dt_ids), track_lengths=dt_tr_lengths, track_areas=dt_tr_areas,
90
+ is_not_exhaustively_labeled=is_nel, is_gt=False)
91
+
92
+ boxformat = data.get('boxformat', 'xywh')
93
+ ious = self._compute_track_ious(data['dt_tracks'], data['gt_tracks'], iou_function=data['iou_type'],
94
+ boxformat=boxformat)
95
+
96
+ for mask_idx in range(self.num_ig_masks):
97
+ gt_ig_mask = gt_ig_masks[mask_idx]
98
+
99
+ # Sort gt ignore last
100
+ gt_idx = np.argsort([g for g in gt_ig_mask], kind="mergesort")
101
+ gt_ids = [gt_ids[i] for i in gt_idx]
102
+
103
+ ious_sorted = ious[:, gt_idx] if len(ious) > 0 else ious
104
+
105
+ num_thrs = len(self.array_labels)
106
+ num_gt = len(gt_ids)
107
+ num_dt = len(dt_ids)
108
+
109
+ # Array to store the "id" of the matched dt/gt
110
+ gt_m = np.zeros((num_thrs, num_gt)) - 1
111
+ dt_m = np.zeros((num_thrs, num_dt)) - 1
112
+
113
+ gt_ig = np.array([gt_ig_mask[idx] for idx in gt_idx])
114
+ dt_ig = np.zeros((num_thrs, num_dt))
115
+
116
+ for iou_thr_idx, iou_thr in enumerate(self.array_labels):
117
+ if len(ious_sorted) == 0:
118
+ break
119
+
120
+ for dt_idx, _dt in enumerate(dt_ids):
121
+ iou = min([iou_thr, 1 - 1e-10])
122
+ # information about best match so far (m=-1 -> unmatched)
123
+ # store the gt_idx which matched for _dt
124
+ m = -1
125
+ for gt_idx, _ in enumerate(gt_ids):
126
+ # if this gt already matched continue
127
+ if gt_m[iou_thr_idx, gt_idx] > 0:
128
+ continue
129
+ # if _dt matched to reg gt, and on ignore gt, stop
130
+ if m > -1 and gt_ig[m] == 0 and gt_ig[gt_idx] == 1:
131
+ break
132
+ # continue to next gt unless better match made
133
+ if ious_sorted[dt_idx, gt_idx] < iou - np.finfo('float').eps:
134
+ continue
135
+ # if match successful and best so far, store appropriately
136
+ iou = ious_sorted[dt_idx, gt_idx]
137
+ m = gt_idx
138
+
139
+ # No match found for _dt, go to next _dt
140
+ if m == -1:
141
+ continue
142
+
143
+ # if gt to ignore for some reason update dt_ig.
144
+ # Should not be used in evaluation.
145
+ dt_ig[iou_thr_idx, dt_idx] = gt_ig[m]
146
+ # _dt match found, update gt_m, and dt_m with "id"
147
+ dt_m[iou_thr_idx, dt_idx] = gt_ids[m]
148
+ gt_m[iou_thr_idx, m] = _dt
149
+
150
+ dt_ig_mask = dt_ig_masks[mask_idx]
151
+
152
+ dt_ig_mask = np.array(dt_ig_mask).reshape((1, num_dt)) # 1 X num_dt
153
+ dt_ig_mask = np.repeat(dt_ig_mask, num_thrs, 0) # num_thrs X num_dt
154
+
155
+ # Based on dt_ig_mask ignore any unmatched detection by updating dt_ig
156
+ dt_ig = np.logical_or(dt_ig, np.logical_and(dt_m == -1, dt_ig_mask))
157
+ # store results for given video and category
158
+ res[mask_idx] = {
159
+ "dt_ids": dt_ids,
160
+ "gt_ids": gt_ids,
161
+ "dt_matches": dt_m,
162
+ "gt_matches": gt_m,
163
+ "dt_scores": data['dt_track_scores'],
164
+ "gt_ignore": gt_ig,
165
+ "dt_ignore": dt_ig,
166
+ }
167
+
168
+ return res
169
+
170
+ def combine_sequences(self, all_res):
171
+ """Combines metrics across all sequences. Computes precision and recall values based on track matches.
172
+ Adapted from https://github.com/TAO-Dataset/
173
+ """
174
+ num_thrs = len(self.array_labels)
175
+ num_recalls = len(self.rec_thrs)
176
+
177
+ # -1 for absent categories
178
+ precision = -np.ones(
179
+ (num_thrs, num_recalls, self.num_ig_masks)
180
+ )
181
+ recall = -np.ones((num_thrs, self.num_ig_masks))
182
+
183
+ for ig_idx in range(self.num_ig_masks):
184
+ ig_idx_results = [res[ig_idx] for res in all_res.values() if res[ig_idx] is not None]
185
+
186
+ # Remove elements which are None
187
+ if len(ig_idx_results) == 0:
188
+ continue
189
+
190
+ # Append all scores: shape (N,)
191
+ # limit considered tracks for each sequence if maxDet > 0
192
+ if self.maxDet == 0:
193
+ dt_scores = np.concatenate([res["dt_scores"] for res in ig_idx_results], axis=0)
194
+
195
+ dt_idx = np.argsort(-dt_scores, kind="mergesort")
196
+
197
+ dt_m = np.concatenate([e["dt_matches"] for e in ig_idx_results],
198
+ axis=1)[:, dt_idx]
199
+ dt_ig = np.concatenate([e["dt_ignore"] for e in ig_idx_results],
200
+ axis=1)[:, dt_idx]
201
+ elif self.maxDet > 0:
202
+ dt_scores = np.concatenate([res["dt_scores"][0:self.maxDet] for res in ig_idx_results], axis=0)
203
+
204
+ dt_idx = np.argsort(-dt_scores, kind="mergesort")
205
+
206
+ dt_m = np.concatenate([e["dt_matches"][:, 0:self.maxDet] for e in ig_idx_results],
207
+ axis=1)[:, dt_idx]
208
+ dt_ig = np.concatenate([e["dt_ignore"][:, 0:self.maxDet] for e in ig_idx_results],
209
+ axis=1)[:, dt_idx]
210
+ else:
211
+ raise Exception("Number of maximum detections must be >= 0, but is set to %i" % self.maxDet)
212
+
213
+ gt_ig = np.concatenate([res["gt_ignore"] for res in ig_idx_results])
214
+ # num gt anns to consider
215
+ num_gt = np.count_nonzero(gt_ig == 0)
216
+
217
+ if num_gt == 0:
218
+ continue
219
+
220
+ tps = np.logical_and(dt_m != -1, np.logical_not(dt_ig))
221
+ fps = np.logical_and(dt_m == -1, np.logical_not(dt_ig))
222
+
223
+ tp_sum = np.cumsum(tps, axis=1).astype(dtype=float)
224
+ fp_sum = np.cumsum(fps, axis=1).astype(dtype=float)
225
+
226
+ for iou_thr_idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
227
+ tp = np.array(tp)
228
+ fp = np.array(fp)
229
+ num_tp = len(tp)
230
+ rc = tp / num_gt
231
+ if num_tp:
232
+ recall[iou_thr_idx, ig_idx] = rc[-1]
233
+ else:
234
+ recall[iou_thr_idx, ig_idx] = 0
235
+
236
+ # np.spacing(1) ~= eps
237
+ pr = tp / (fp + tp + np.spacing(1))
238
+ pr = pr.tolist()
239
+
240
+ # Ensure precision values are monotonically decreasing
241
+ for i in range(num_tp - 1, 0, -1):
242
+ if pr[i] > pr[i - 1]:
243
+ pr[i - 1] = pr[i]
244
+
245
+ # find indices at the predefined recall values
246
+ rec_thrs_insert_idx = np.searchsorted(rc, self.rec_thrs, side="left")
247
+
248
+ pr_at_recall = [0.0] * num_recalls
249
+
250
+ try:
251
+ for _idx, pr_idx in enumerate(rec_thrs_insert_idx):
252
+ pr_at_recall[_idx] = pr[pr_idx]
253
+ except IndexError:
254
+ pass
255
+
256
+ precision[iou_thr_idx, :, ig_idx] = (np.array(pr_at_recall))
257
+
258
+ res = {'precision': precision, 'recall': recall}
259
+
260
+ # compute the precision and recall averages for the respective alpha thresholds and ignore masks
261
+ for lbl in self.lbls:
262
+ res['AP_' + lbl] = np.zeros((len(self.array_labels)), dtype=float)
263
+ res['AR_' + lbl] = np.zeros((len(self.array_labels)), dtype=float)
264
+
265
+ for a_id, alpha in enumerate(self.array_labels):
266
+ for lbl_idx, lbl in enumerate(self.lbls):
267
+ p = precision[a_id, :, lbl_idx]
268
+ if len(p[p > -1]) == 0:
269
+ mean_p = -1
270
+ else:
271
+ mean_p = np.mean(p[p > -1])
272
+ res['AP_' + lbl][a_id] = mean_p
273
+ res['AR_' + lbl][a_id] = recall[a_id, lbl_idx]
274
+
275
+ return res
276
+
277
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=True):
278
+ """Combines metrics across all classes by averaging over the class values
279
+ Note mAP is not well defined for 'empty classes' so 'ignore empty classes' is always true here.
280
+ """
281
+ res = {}
282
+ for field in self.fields:
283
+ res[field] = np.zeros((len(self.array_labels)), dtype=float)
284
+ field_stacked = np.array([res[field] for res in all_res.values()])
285
+
286
+ for a_id, alpha in enumerate(self.array_labels):
287
+ values = field_stacked[:, a_id]
288
+ if len(values[values > -1]) == 0:
289
+ mean = -1
290
+ else:
291
+ mean = np.mean(values[values > -1])
292
+ res[field][a_id] = mean
293
+ return res
294
+
295
+ def combine_classes_det_averaged(self, all_res):
296
+ """Combines metrics across all classes by averaging over the detection values"""
297
+
298
+ res = {}
299
+ for field in self.fields:
300
+ res[field] = np.zeros((len(self.array_labels)), dtype=float)
301
+ field_stacked = np.array([res[field] for res in all_res.values()])
302
+
303
+ for a_id, alpha in enumerate(self.array_labels):
304
+ values = field_stacked[:, a_id]
305
+ if len(values[values > -1]) == 0:
306
+ mean = -1
307
+ else:
308
+ mean = np.mean(values[values > -1])
309
+ res[field][a_id] = mean
310
+ return res
311
+
312
+ def _compute_track_ig_masks(self, num_ids, track_lengths=None, track_areas=None, iscrowd=None,
313
+ is_not_exhaustively_labeled=False, is_gt=True):
314
+ """
315
+ Computes ignore masks for different track sets to evaluate
316
+ :param num_ids: the number of track IDs
317
+ :param track_lengths: the lengths of the tracks (number of timesteps)
318
+ :param track_areas: the average area of a track
319
+ :param iscrowd: whether a track is marked as crowd
320
+ :param is_not_exhaustively_labeled: whether the track category is not exhaustively labeled
321
+ :param is_gt: whether it is gt
322
+ :return: the track ignore masks
323
+ """
324
+ # for TAO tracks for classes which are not exhaustively labeled are not evaluated
325
+ if not is_gt and is_not_exhaustively_labeled:
326
+ track_ig_masks = [[1 for _ in range(num_ids)] for i in range(self.num_ig_masks)]
327
+ else:
328
+ # consider all tracks
329
+ track_ig_masks = [[0 for _ in range(num_ids)]]
330
+
331
+ # consider tracks with certain area
332
+ if self.use_area_rngs:
333
+ for rng in self.area_rngs:
334
+ track_ig_masks.append([0 if rng[0] - np.finfo('float').eps <= area <= rng[1] + np.finfo('float').eps
335
+ else 1 for area in track_areas])
336
+
337
+ # consider tracks with certain duration
338
+ if self.use_time_rngs:
339
+ for rng in self.time_rngs:
340
+ track_ig_masks.append([0 if rng[0] - np.finfo('float').eps <= length
341
+ <= rng[1] + np.finfo('float').eps else 1 for length in track_lengths])
342
+
343
+ # for YouTubeVIS evaluation tracks with crowd tag are not evaluated
344
+ if is_gt and iscrowd:
345
+ track_ig_masks = [np.logical_or(mask, iscrowd) for mask in track_ig_masks]
346
+
347
+ return track_ig_masks
348
+
349
+ @staticmethod
350
+ def _compute_bb_track_iou(dt_track, gt_track, boxformat='xywh'):
351
+ """
352
+ Calculates the track IoU for one detected track and one ground truth track for bounding boxes
353
+ :param dt_track: the detected track (format: dictionary with frame index as keys and
354
+ numpy arrays as values)
355
+ :param gt_track: the ground truth track (format: dictionary with frame index as keys and
356
+ numpy array as values)
357
+ :param boxformat: the format of the boxes
358
+ :return: the track IoU
359
+ """
360
+ intersect = 0
361
+ union = 0
362
+ image_ids = set(gt_track.keys()) | set(dt_track.keys())
363
+ for image in image_ids:
364
+ g = gt_track.get(image, None)
365
+ d = dt_track.get(image, None)
366
+ if boxformat == 'xywh':
367
+ if d is not None and g is not None:
368
+ dx, dy, dw, dh = d
369
+ gx, gy, gw, gh = g
370
+ w = max(min(dx + dw, gx + gw) - max(dx, gx), 0)
371
+ h = max(min(dy + dh, gy + gh) - max(dy, gy), 0)
372
+ i = w * h
373
+ u = dw * dh + gw * gh - i
374
+ intersect += i
375
+ union += u
376
+ elif d is None and g is not None:
377
+ union += g[2] * g[3]
378
+ elif d is not None and g is None:
379
+ union += d[2] * d[3]
380
+ elif boxformat == 'x0y0x1y1':
381
+ if d is not None and g is not None:
382
+ dx0, dy0, dx1, dy1 = d
383
+ gx0, gy0, gx1, gy1 = g
384
+ w = max(min(dx1, gx1) - max(dx0, gx0), 0)
385
+ h = max(min(dy1, gy1) - max(dy0, gy0), 0)
386
+ i = w * h
387
+ u = (dx1 - dx0) * (dy1 - dy0) + (gx1 - gx0) * (gy1 - gy0) - i
388
+ intersect += i
389
+ union += u
390
+ elif d is None and g is not None:
391
+ union += (g[2] - g[0]) * (g[3] - g[1])
392
+ elif d is not None and g is None:
393
+ union += (d[2] - d[0]) * (d[3] - d[1])
394
+ else:
395
+ raise TrackEvalException('BoxFormat not implemented')
396
+ if intersect > union:
397
+ raise TrackEvalException("Intersection value > union value. Are the box values corrupted?")
398
+ return intersect / union if union > 0 else 0
399
+
400
+ @staticmethod
401
+ def _compute_mask_track_iou(dt_track, gt_track):
402
+ """
403
+ Calculates the track IoU for one detected track and one ground truth track for segmentation masks
404
+ :param dt_track: the detected track (format: dictionary with frame index as keys and
405
+ pycocotools rle encoded masks as values)
406
+ :param gt_track: the ground truth track (format: dictionary with frame index as keys and
407
+ pycocotools rle encoded masks as values)
408
+ :return: the track IoU
409
+ """
410
+ # only loaded when needed to reduce minimum requirements
411
+ from pycocotools import mask as mask_utils
412
+
413
+ intersect = .0
414
+ union = .0
415
+ image_ids = set(gt_track.keys()) | set(dt_track.keys())
416
+ for image in image_ids:
417
+ g = gt_track.get(image, None)
418
+ d = dt_track.get(image, None)
419
+ if d and g:
420
+ intersect += mask_utils.area(mask_utils.merge([d, g], True))
421
+ union += mask_utils.area(mask_utils.merge([d, g], False))
422
+ elif not d and g:
423
+ union += mask_utils.area(g)
424
+ elif d and not g:
425
+ union += mask_utils.area(d)
426
+ if union < 0.0 - np.finfo('float').eps:
427
+ raise TrackEvalException("Union value < 0. Are the segmentations corrupted?")
428
+ if intersect > union:
429
+ raise TrackEvalException("Intersection value > union value. Are the segmentations corrupted?")
430
+ iou = intersect / union if union > 0.0 + np.finfo('float').eps else 0.0
431
+ return iou
432
+
433
+ @staticmethod
434
+ def _compute_track_ious(dt, gt, iou_function='bbox', boxformat='xywh'):
435
+ """
436
+ Calculate track IoUs for a set of ground truth tracks and a set of detected tracks
437
+ """
438
+
439
+ if len(gt) == 0 and len(dt) == 0:
440
+ return []
441
+
442
+ if iou_function == 'bbox':
443
+ track_iou_function = partial(TrackMAP._compute_bb_track_iou, boxformat=boxformat)
444
+ elif iou_function == 'mask':
445
+ track_iou_function = partial(TrackMAP._compute_mask_track_iou)
446
+ else:
447
+ raise Exception('IoU function not implemented')
448
+
449
+ ious = np.zeros([len(dt), len(gt)])
450
+ for i, j in np.ndindex(ious.shape):
451
+ ious[i, j] = track_iou_function(dt[i], gt[j])
452
+ return ious
453
+
454
+ @staticmethod
455
+ def _row_print(*argv):
456
+ """Prints results in evenly spaced rows, with more space in the first row"""
457
+ if len(argv) == 1:
458
+ argv = argv[0]
459
+ to_print = '%-40s' % argv[0]
460
+ for v in argv[1:]:
461
+ to_print += '%-12s' % str(v)
462
+ print(to_print)
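Note: the track IoU used by TrackMAP accumulates per-frame intersections and unions over the union of frames in which either track appears, so missed frames still enlarge the denominator. The sketch below feeds two tiny xywh tracks to `_compute_bb_track_iou`; the import path assumes this repository is on PYTHONPATH and may need adjusting.

    from avism.data.aviseval.metrics.track_map import TrackMAP

    # Tracks are dicts mapping frame index -> box. Frame 2 has no detection,
    # so its gt area only contributes to the union.
    gt_track = {0: [0, 0, 10, 10], 1: [0, 0, 10, 10], 2: [0, 0, 10, 10]}
    dt_track = {0: [0, 0, 10, 10], 1: [5, 0, 10, 10]}

    iou = TrackMAP._compute_bb_track_iou(dt_track, gt_track, boxformat='xywh')
    print(iou)  # (100 + 50) / (100 + 150 + 100) = 150 / 350 ~= 0.43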
avism/data/aviseval/metrics/vace.py ADDED
@@ -0,0 +1,131 @@
1
+ import numpy as np
2
+ from scipy.optimize import linear_sum_assignment
3
+ from ._base_metric import _BaseMetric
4
+ from .. import _timing
5
+
6
+
7
+ class VACE(_BaseMetric):
8
+ """Class which implements the VACE metrics.
9
+
10
+ The metrics are described in:
11
+ Manohar et al. (2006) "Performance Evaluation of Object Detection and Tracking in Video"
12
+ https://link.springer.com/chapter/10.1007/11612704_16
13
+
14
+ This implementation uses the "relaxed" variant of the metrics,
15
+ where an overlap threshold is applied in each frame.
16
+ """
17
+
18
+ def __init__(self, config=None):
19
+ super().__init__()
20
+ self.integer_fields = ['VACE_IDs', 'VACE_GT_IDs', 'num_non_empty_timesteps']
21
+ self.float_fields = ['STDA', 'ATA', 'FDA', 'SFDA']
22
+ self.fields = self.integer_fields + self.float_fields
23
+ self.summary_fields = ['SFDA', 'ATA']
24
+
25
+ # Fields that are accumulated over multiple videos.
26
+ self._additive_fields = self.integer_fields + ['STDA', 'FDA']
27
+
28
+ self.threshold = 0.5
29
+
30
+ @_timing.time
31
+ def eval_sequence(self, data):
32
+ """Calculates VACE metrics for one sequence.
33
+
34
+ Depends on the fields:
35
+ data['num_gt_ids']
36
+ data['num_tracker_ids']
37
+ data['gt_ids']
38
+ data['tracker_ids']
39
+ data['similarity_scores']
40
+ """
41
+ res = {}
42
+
43
+ # Obtain Average Tracking Accuracy (ATA) using track correspondence.
44
+ # Obtain counts necessary to compute temporal IOU.
45
+ # Assume that integer counts can be represented exactly as floats.
46
+ potential_matches_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
47
+ gt_id_count = np.zeros(data['num_gt_ids'])
48
+ tracker_id_count = np.zeros(data['num_tracker_ids'])
49
+ both_present_count = np.zeros((data['num_gt_ids'], data['num_tracker_ids']))
50
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
51
+ # Count the number of frames in which two tracks satisfy the overlap criterion.
52
+ matches_mask = np.greater_equal(data['similarity_scores'][t], self.threshold)
53
+ match_idx_gt, match_idx_tracker = np.nonzero(matches_mask)
54
+ potential_matches_count[gt_ids_t[match_idx_gt], tracker_ids_t[match_idx_tracker]] += 1
55
+ # Count the number of frames in which the tracks are present.
56
+ gt_id_count[gt_ids_t] += 1
57
+ tracker_id_count[tracker_ids_t] += 1
58
+ both_present_count[gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]] += 1
59
+ # Number of frames in which either track is present (union of the two sets of frames).
60
+ union_count = (gt_id_count[:, np.newaxis]
61
+ + tracker_id_count[np.newaxis, :]
62
+ - both_present_count)
63
+ # The denominator should always be non-zero if all tracks are non-empty.
64
+ with np.errstate(divide='raise', invalid='raise'):
65
+ temporal_iou = potential_matches_count / union_count
66
+ # Find assignment that maximizes temporal IOU.
67
+ match_rows, match_cols = linear_sum_assignment(-temporal_iou)
68
+ res['STDA'] = temporal_iou[match_rows, match_cols].sum()
69
+ res['VACE_IDs'] = data['num_tracker_ids']
70
+ res['VACE_GT_IDs'] = data['num_gt_ids']
71
+
72
+ # Obtain Frame Detection Accuracy (FDA) using per-frame correspondence.
73
+ non_empty_count = 0
74
+ fda = 0
75
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(data['gt_ids'], data['tracker_ids'])):
76
+ n_g = len(gt_ids_t)
77
+ n_d = len(tracker_ids_t)
78
+ if not (n_g or n_d):
79
+ continue
80
+ # n_g > 0 or n_d > 0
81
+ non_empty_count += 1
82
+ if not (n_g and n_d):
83
+ continue
84
+ # n_g > 0 and n_d > 0
85
+ spatial_overlap = data['similarity_scores'][t]
86
+ match_rows, match_cols = linear_sum_assignment(-spatial_overlap)
87
+ overlap_ratio = spatial_overlap[match_rows, match_cols].sum()
88
+ fda += overlap_ratio / (0.5 * (n_g + n_d))
89
+ res['FDA'] = fda
90
+ res['num_non_empty_timesteps'] = non_empty_count
91
+
92
+ res.update(self._compute_final_fields(res))
93
+ return res
94
+
95
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=True):
96
+ """Combines metrics across all classes by averaging over the class values.
97
+ If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
98
+ """
99
+ res = {}
100
+ for field in self.fields:
101
+ if ignore_empty_classes:
102
+ res[field] = np.mean([v[field] for v in all_res.values()
103
+ if v['VACE_GT_IDs'] > 0 or v['VACE_IDs'] > 0], axis=0)
104
+ else:
105
+ res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
106
+ return res
107
+
108
+ def combine_classes_det_averaged(self, all_res):
109
+ """Combines metrics across all classes by averaging over the detection values"""
110
+ res = {}
111
+ for field in self._additive_fields:
112
+ res[field] = _BaseMetric._combine_sum(all_res, field)
113
+ res = self._compute_final_fields(res)
114
+ return res
115
+
116
+ def combine_sequences(self, all_res):
117
+ """Combines metrics across all sequences"""
118
+ res = {}
119
+ for header in self._additive_fields:
120
+ res[header] = _BaseMetric._combine_sum(all_res, header)
121
+ res.update(self._compute_final_fields(res))
122
+ return res
123
+
124
+ @staticmethod
125
+ def _compute_final_fields(additive):
126
+ final = {}
127
+ with np.errstate(invalid='ignore'): # Permit nan results.
128
+ final['ATA'] = (additive['STDA'] /
129
+ (0.5 * (additive['VACE_IDs'] + additive['VACE_GT_IDs'])))
130
+ final['SFDA'] = additive['FDA'] / additive['num_non_empty_timesteps']
131
+ return final
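Note: the combined ATA and SFDA values are plain ratios of the accumulated fields, so they can be checked with made-up numbers; the snippet below mirrors `_compute_final_fields` without touching the evaluator.

    additive = {'STDA': 1.5, 'VACE_IDs': 3, 'VACE_GT_IDs': 2,
                'FDA': 7.2, 'num_non_empty_timesteps': 10}

    ata = additive['STDA'] / (0.5 * (additive['VACE_IDs'] + additive['VACE_GT_IDs']))  # 1.5 / 2.5 = 0.6
    sfda = additive['FDA'] / additive['num_non_empty_timesteps']                       # 0.72
    print(ata, sfda)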
avism/data/aviseval/plotting.py ADDED
@@ -0,0 +1,230 @@
1
+
2
+ import os
3
+ import numpy as np
4
+ from .utils import TrackEvalException
5
+
6
+
7
+ def plot_compare_trackers(tracker_folder, tracker_list, cls, output_folder, plots_list=None):
8
+ """Create plots which compare metrics across different trackers."""
9
+ # Define what to plot
10
+ if plots_list is None:
11
+ plots_list = get_default_plots_list()
12
+
13
+ # Load data
14
+ data = load_multiple_tracker_summaries(tracker_folder, tracker_list, cls)
15
+ out_loc = os.path.join(output_folder, cls)
16
+
17
+ # Plot
18
+ for args in plots_list:
19
+ create_comparison_plot(data, out_loc, *args)
20
+
21
+
22
+ def get_default_plots_list():
23
+ # y_label, x_label, sort_label, bg_label, bg_function
24
+ plots_list = [
25
+ ['AssA', 'DetA', 'HOTA', 'HOTA', 'geometric_mean'],
26
+ ['AssPr', 'AssRe', 'HOTA', 'AssA', 'jaccard'],
27
+ ['DetPr', 'DetRe', 'HOTA', 'DetA', 'jaccard'],
28
+ ['HOTA(0)', 'LocA(0)', 'HOTA', 'HOTALocA(0)', 'multiplication'],
29
+ ['HOTA', 'LocA', 'HOTA', None, None],
30
+
31
+ ['HOTA', 'MOTA', 'HOTA', None, None],
32
+ ['HOTA', 'IDF1', 'HOTA', None, None],
33
+ ['IDF1', 'MOTA', 'HOTA', None, None],
34
+ ]
35
+ return plots_list
36
+
37
+
38
+ def load_multiple_tracker_summaries(tracker_folder, tracker_list, cls):
39
+ """Loads summary data for multiple trackers."""
40
+ data = {}
41
+ for tracker in tracker_list:
42
+ with open(os.path.join(tracker_folder, tracker, cls + '_summary.txt')) as f:
43
+ keys = next(f).split(' ')
44
+ done = False
45
+ while not done:
46
+ values = next(f).split(' ')
47
+ if len(values) == len(keys):
48
+ done = True
49
+ data[tracker] = dict(zip(keys, map(float, values)))
50
+ return data
51
+
52
+
53
+ def create_comparison_plot(data, out_loc, y_label, x_label, sort_label, bg_label=None, bg_function=None, settings=None):
54
+ """ Creates a scatter plot comparing multiple trackers between two metric fields, with one on the x-axis and the
55
+ other on the y-axis. Adds Pareto optimal lines and (optionally) a background contour.
56
+
57
+ Inputs:
58
+ data: dict of dicts such that data[tracker_name][metric_field_name] = float
59
+ y_label: the metric_field_name to be plotted on the y-axis
60
+ x_label: the metric_field_name to be plotted on the x-axis
61
+ sort_label: the metric_field_name by which trackers are ordered and ranked
62
+ bg_label: the metric_field_name by which (optional) background contours are plotted
63
+ bg_function: the (optional) function bg_function(x,y) which converts the x_label / y_label values into bg_label.
64
+ settings: dict of plot settings with keys:
65
+ 'gap_val': gap between axis ticks and bg curves.
66
+ 'num_to_plot': maximum number of trackers to plot
67
+ """
68
+
69
+ # Only loaded when run to reduce minimum requirements
70
+ from matplotlib import pyplot as plt
71
+
72
+ # Get plot settings
73
+ if settings is None:
74
+ gap_val = 2
75
+ num_to_plot = 20
76
+ else:
77
+ gap_val = settings['gap_val']
78
+ num_to_plot = settings['num_to_plot']
79
+
80
+ if (bg_label is None) != (bg_function is None):
81
+ raise TrackEvalException('bg_function and bg_label must either be both given or neither given.')
82
+
83
+ # Extract data
84
+ tracker_names = np.array(list(data.keys()))
85
+ sort_index = np.array([data[t][sort_label] for t in tracker_names]).argsort()[::-1]
86
+ x_values = np.array([data[t][x_label] for t in tracker_names])[sort_index][:num_to_plot]
87
+ y_values = np.array([data[t][y_label] for t in tracker_names])[sort_index][:num_to_plot]
88
+
89
+ # Print info on what is being plotted
90
+ tracker_names = tracker_names[sort_index][:num_to_plot]
91
+ print('\nPlotting %s vs %s, for the following (ordered) trackers:' % (y_label, x_label))
92
+ for i, name in enumerate(tracker_names):
93
+ print('%i: %s' % (i+1, name))
94
+
95
+ # Find best fitting boundaries for data
96
+ boundaries = _get_boundaries(x_values, y_values, round_val=gap_val/2)
97
+
98
+ fig = plt.figure()
99
+
100
+ # Plot background contour
101
+ if bg_function is not None:
102
+ _plot_bg_contour(bg_function, boundaries, gap_val)
103
+
104
+ # Plot pareto optimal lines
105
+ _plot_pareto_optimal_lines(x_values, y_values)
106
+
107
+ # Plot data points with number labels
108
+ labels = np.arange(len(y_values)) + 1
109
+ plt.plot(x_values, y_values, 'b.', markersize=15)
110
+ for xx, yy, l in zip(x_values, y_values, labels):
111
+ plt.text(xx, yy, str(l), color="red", fontsize=15)
112
+
113
+ # Add extra explanatory text to plots
114
+ plt.text(0, -0.11, 'label order:\nHOTA', horizontalalignment='left', verticalalignment='center',
115
+ transform=fig.axes[0].transAxes, color="red", fontsize=12)
116
+ if bg_label is not None:
117
+ plt.text(1, -0.11, 'curve values:\n' + bg_label, horizontalalignment='right', verticalalignment='center',
118
+ transform=fig.axes[0].transAxes, color="grey", fontsize=12)
119
+
120
+ plt.xlabel(x_label, fontsize=15)
121
+ plt.ylabel(y_label, fontsize=15)
122
+ title = y_label + ' vs ' + x_label
123
+ if bg_label is not None:
124
+ title += ' (' + bg_label + ')'
125
+ plt.title(title, fontsize=17)
126
+ plt.xticks(np.arange(0, 100, gap_val))
127
+ plt.yticks(np.arange(0, 100, gap_val))
128
+ min_x, max_x, min_y, max_y = boundaries
129
+ plt.xlim(min_x, max_x)
130
+ plt.ylim(min_y, max_y)
131
+ plt.gca().set_aspect('equal', adjustable='box')
132
+ plt.tight_layout()
133
+
134
+ os.makedirs(out_loc, exist_ok=True)
135
+ filename = os.path.join(out_loc, title.replace(' ', '_'))
136
+ plt.savefig(filename + '.pdf', bbox_inches='tight', pad_inches=0.05)
137
+ plt.savefig(filename + '.png', bbox_inches='tight', pad_inches=0.05)
138
+
139
+
140
+ def _get_boundaries(x_values, y_values, round_val):
141
+ x1 = np.min(np.floor((x_values - 0.5) / round_val) * round_val)
142
+ x2 = np.max(np.ceil((x_values + 0.5) / round_val) * round_val)
143
+ y1 = np.min(np.floor((y_values - 0.5) / round_val) * round_val)
144
+ y2 = np.max(np.ceil((y_values + 0.5) / round_val) * round_val)
145
+ x_range = x2 - x1
146
+ y_range = y2 - y1
147
+ max_range = max(x_range, y_range)
148
+ x_center = (x1 + x2) / 2
149
+ y_center = (y1 + y2) / 2
150
+ min_x = max(x_center - max_range / 2, 0)
151
+ max_x = min(x_center + max_range / 2, 100)
152
+ min_y = max(y_center - max_range / 2, 0)
153
+ max_y = min(y_center + max_range / 2, 100)
154
+ return min_x, max_x, min_y, max_y
155
+
156
+
157
+ def geometric_mean(x, y):
158
+ return np.sqrt(x * y)
159
+
160
+
161
+ def jaccard(x, y):
162
+ x = x / 100
163
+ y = y / 100
164
+ return 100 * (x * y) / (x + y - x * y)
165
+
166
+
167
+ def multiplication(x, y):
168
+ return x * y / 100
169
+
170
+
171
+ bg_function_dict = {
172
+ "geometric_mean": geometric_mean,
173
+ "jaccard": jaccard,
174
+ "multiplication": multiplication,
175
+ }
176
+
177
+
178
+ def _plot_bg_contour(bg_function, plot_boundaries, gap_val):
179
+ """ Plot background contour. """
180
+
181
+ # Only loaded when run to reduce minimum requirements
182
+ from matplotlib import pyplot as plt
183
+
184
+ # Plot background contour
185
+ min_x, max_x, min_y, max_y = plot_boundaries
186
+ x = np.arange(min_x, max_x, 0.1)
187
+ y = np.arange(min_y, max_y, 0.1)
188
+ x_grid, y_grid = np.meshgrid(x, y)
189
+ if bg_function in bg_function_dict.keys():
190
+ z_grid = bg_function_dict[bg_function](x_grid, y_grid)
191
+ else:
192
+ raise TrackEvalException("background plotting function '%s' is not defined." % bg_function)
193
+ levels = np.arange(0, 100, gap_val)
194
+ con = plt.contour(x_grid, y_grid, z_grid, levels, colors='grey')
195
+
196
+ def bg_format(val):
197
+ s = '{:.1f}'.format(val)  # one decimal place; the decimal is dropped below when it is zero
198
+ return '{:.0f}'.format(val) if s[-1] == '0' else s
199
+
200
+ con.levels = [bg_format(val) for val in con.levels]
201
+ plt.clabel(con, con.levels, inline=True, fmt='%r', fontsize=8)
202
+
203
+
204
+ def _plot_pareto_optimal_lines(x_values, y_values):
205
+ """ Plot pareto optimal lines """
206
+
207
+ # Only loaded when run to reduce minimum requirements
208
+ from matplotlib import pyplot as plt
209
+
210
+ # Plot pareto optimal lines
211
+ cxs = x_values
212
+ cys = y_values
213
+ best_y = np.argmax(cys)
214
+ x_pareto = [0, cxs[best_y]]
215
+ y_pareto = [cys[best_y], cys[best_y]]
216
+ t = 2
217
+ remaining = cxs > x_pareto[t - 1]
218
+ cys = cys[remaining]
219
+ cxs = cxs[remaining]
220
+ while len(cxs) > 0 and len(cys) > 0:
221
+ best_y = np.argmax(cys)
222
+ x_pareto += [x_pareto[t - 1], cxs[best_y]]
223
+ y_pareto += [cys[best_y], cys[best_y]]
224
+ t += 2
225
+ remaining = cxs > x_pareto[t - 1]
226
+ cys = cys[remaining]
227
+ cxs = cxs[remaining]
228
+ x_pareto.append(x_pareto[t - 1])
229
+ y_pareto.append(0)
230
+ plt.plot(np.array(x_pareto), np.array(y_pareto), '--r')
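
A hypothetical usage sketch for the comparison plots above. The tracker folder layout, tracker names and class name are assumptions for illustration; the only hard requirement in the code is that each tracker directory contains a `<cls>_summary.txt` file (as written by `write_summary_results` in `utils.py` below).

```python
# Hypothetical invocation; paths and names are placeholders, not repository defaults.
from avism.data.aviseval.plotting import plot_compare_trackers

plot_compare_trackers(
    tracker_folder="output/trackers",            # contains one sub-folder per tracker run
    tracker_list=["avism_r50", "avism_swin"],    # assumed run names
    cls="combined",                              # summaries are read from <run>/combined_summary.txt
    output_folder="output/plots",                # figures are written to output/plots/combined/
)
```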
avism/data/aviseval/utils.py ADDED
@@ -0,0 +1,146 @@
1
+
2
+ import os
3
+ import csv
4
+ import argparse
5
+ from collections import OrderedDict
6
+
7
+
8
+ def init_config(config, default_config, name=None):
9
+ """Initialise non-given config values with defaults"""
10
+ if config is None:
11
+ config = default_config
12
+ else:
13
+ for k in default_config.keys():
14
+ if k not in config.keys():
15
+ config[k] = default_config[k]
16
+ if name and config['PRINT_CONFIG']:
17
+ print('\n%s Config:' % name)
18
+ for c in config.keys():
19
+ print('%-20s : %-30s' % (c, config[c]))
20
+ return config
21
+
22
+
23
+ def update_config(config):
24
+ """
25
+ Parses the arguments of a script and updates the config values for any setting specified on the command line.
26
+ :param config: the config to update
27
+ :return: the updated config
28
+ """
29
+ parser = argparse.ArgumentParser()
30
+ for setting in config.keys():
31
+ if type(config[setting]) == list or type(config[setting]) == type(None):
32
+ parser.add_argument("--" + setting, nargs='+')
33
+ else:
34
+ parser.add_argument("--" + setting)
35
+ args = parser.parse_args().__dict__
36
+ for setting in args.keys():
37
+ if args[setting] is not None:
38
+ if type(config[setting]) == type(True):
39
+ if args[setting] == 'True':
40
+ x = True
41
+ elif args[setting] == 'False':
42
+ x = False
43
+ else:
44
+ raise Exception('Command line parameter ' + setting + ' must be True or False')
45
+ elif type(config[setting]) == type(1):
46
+ x = int(args[setting])
47
+ elif type(args[setting]) == type(None):
48
+ x = None
49
+ else:
50
+ x = args[setting]
51
+ config[setting] = x
52
+ return config
53
+
54
+
55
+ def get_code_path():
56
+ """Get base path where code is"""
57
+ return os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
58
+
59
+
60
+ def validate_metrics_list(metrics_list):
61
+ """Gets the names of the metric classes and ensures they are unique; further checks that the fields within each metric class
62
+ do not have overlapping names.
63
+ """
64
+ metric_names = [metric.get_name() for metric in metrics_list]
65
+ # check metric names are unique
66
+ if len(metric_names) != len(set(metric_names)):
67
+ raise TrackEvalException('Code being run with multiple metrics of the same name')
68
+ fields = []
69
+ for m in metrics_list:
70
+ fields += m.fields
71
+ # check metric fields are unique
72
+ if len(fields) != len(set(fields)):
73
+ raise TrackEvalException('Code being run with multiple metrics with fields of the same name')
74
+ return metric_names
75
+
76
+
77
+ def write_summary_results(summaries, cls, output_folder):
78
+ """Write summary results to file"""
79
+
80
+ fields = sum([list(s.keys()) for s in summaries], [])
81
+ values = sum([list(s.values()) for s in summaries], [])
82
+
83
+ # In order to remain consistent when new fields are added, for each of the following fields, if they are present
84
+ # they will be output in the summary first in the order below. Any further fields will be output in the order each
85
+ # metric family is called, and within each family either in the order they were added to the dict (python >= 3.6) or
86
+ # randomly (python < 3.6).
87
+ default_order = ['HOTA', 'DetA', 'AssA', 'DetRe', 'DetPr', 'AssRe', 'AssPr', 'LocA', 'OWTA', 'HOTA(0)', 'LocA(0)',
88
+ 'HOTALocA(0)', 'MOTA', 'MOTP', 'MODA', 'CLR_Re', 'CLR_Pr', 'MTR', 'PTR', 'MLR', 'CLR_TP', 'CLR_FN',
89
+ 'CLR_FP', 'IDSW', 'MT', 'PT', 'ML', 'Frag', 'sMOTA', 'IDF1', 'IDR', 'IDP', 'IDTP', 'IDFN', 'IDFP',
90
+ 'Dets', 'GT_Dets', 'IDs', 'GT_IDs']
91
+ default_ordered_dict = OrderedDict(zip(default_order, [None for _ in default_order]))
92
+ for f, v in zip(fields, values):
93
+ default_ordered_dict[f] = v
94
+ for df in default_order:
95
+ if default_ordered_dict[df] is None:
96
+ del default_ordered_dict[df]
97
+ fields = list(default_ordered_dict.keys())
98
+ values = list(default_ordered_dict.values())
99
+
100
+ out_file = os.path.join(output_folder, cls + '_summary.txt')
101
+ os.makedirs(os.path.dirname(out_file), exist_ok=True)
102
+ with open(out_file, 'w', newline='') as f:
103
+ writer = csv.writer(f, delimiter=' ')
104
+ writer.writerow(fields)
105
+ writer.writerow(values)
106
+
107
+
108
+ def write_detailed_results(details, cls, output_folder):
109
+ """Write detailed results to file"""
110
+ sequences = details[0].keys()
111
+ fields = ['seq'] + sum([list(s['COMBINED_SEQ'].keys()) for s in details], [])
112
+ out_file = os.path.join(output_folder, cls + '_detailed.csv')
113
+ os.makedirs(os.path.dirname(out_file), exist_ok=True)
114
+ with open(out_file, 'w', newline='') as f:
115
+ writer = csv.writer(f)
116
+ writer.writerow(fields)
117
+ for seq in sorted(sequences):
118
+ if seq == 'COMBINED_SEQ':
119
+ continue
120
+ writer.writerow([seq] + sum([list(s[seq].values()) for s in details], []))
121
+ writer.writerow(['COMBINED'] + sum([list(s['COMBINED_SEQ'].values()) for s in details], []))
122
+
123
+
124
+ def load_detail(file):
125
+ """Loads detailed data for a tracker."""
126
+ data = {}
127
+ with open(file) as f:
128
+ for i, row_text in enumerate(f):
129
+ row = row_text.replace('\r', '').replace('\n', '').split(',')
130
+ if i == 0:
131
+ keys = row[1:]
132
+ continue
133
+ current_values = row[1:]
134
+ seq = row[0]
135
+ if seq == 'COMBINED':
136
+ seq = 'COMBINED_SEQ'
137
+ if (len(current_values) == len(keys)) and seq != '':
138
+ data[seq] = {}
139
+ for key, value in zip(keys, current_values):
140
+ data[seq][key] = float(value)
141
+ return data
142
+
143
+
144
+ class TrackEvalException(Exception):
145
+ """Custom exception for catching expected errors."""
146
+ ...
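
A minimal sketch of the config pattern used by these helpers. The `THREADS` and `OUTPUT_FOLDER` keys are invented for the example; only `PRINT_CONFIG` is actually referenced by `init_config` above. `update_config` builds an argparse parser over the config keys, so it is meant to be called from a script entry point.

```python
# Sketch only: demonstrates init_config() defaults and command-line overrides.
from avism.data.aviseval.utils import init_config, update_config

default_config = {
    'PRINT_CONFIG': True,   # checked by init_config before printing the table
    'THREADS': 1,           # invented key for illustration
    'OUTPUT_FOLDER': None,  # invented key for illustration
}

config = init_config({'THREADS': 4}, default_config, name='Demo')
# -> {'THREADS': 4, 'PRINT_CONFIG': True, 'OUTPUT_FOLDER': None}, printed because PRINT_CONFIG is True

# In a script, `python demo.py --THREADS 8` would now override the value:
config = update_config(config)
```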
avism/data/build.py ADDED
@@ -0,0 +1,247 @@
1
+ import itertools
2
+ import logging
3
+ import torch.utils.data
4
+
5
+ from detectron2.config import CfgNode, configurable
6
+ from detectron2.data.build import (
7
+ build_batch_data_loader,
8
+ load_proposals_into_dataset,
9
+ trivial_batch_collator,
10
+ )
11
+ from detectron2.data.catalog import DatasetCatalog
12
+ from detectron2.data.common import DatasetFromList, MapDataset
13
+ from detectron2.data.dataset_mapper import DatasetMapper
14
+ from detectron2.data.samplers import InferenceSampler, TrainingSampler
15
+ from detectron2.utils.comm import get_world_size
16
+
17
+
18
+ def _compute_num_images_per_worker(cfg: CfgNode):
19
+ num_workers = get_world_size()
20
+ images_per_batch = cfg.SOLVER.IMS_PER_BATCH
21
+ assert (
22
+ images_per_batch % num_workers == 0
23
+ ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
24
+ images_per_batch, num_workers
25
+ )
26
+ assert (
27
+ images_per_batch >= num_workers
28
+ ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
29
+ images_per_batch, num_workers
30
+ )
31
+ images_per_worker = images_per_batch // num_workers
32
+ return images_per_worker
33
+
34
+
35
+ def filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names):
36
+ """
37
+ Filter out images with no annotations or with only crowd annotations
38
+ (i.e., images without non-crowd annotations).
39
+ A common training-time preprocessing on COCO dataset.
40
+
41
+ Args:
42
+ dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
43
+
44
+ Returns:
45
+ list[dict]: the same format, but filtered.
46
+ """
47
+ num_before = len(dataset_dicts)
48
+
49
+ def valid(anns):
50
+ for ann in anns:
51
+ if isinstance(ann, list):
52
+ for instance in ann:
53
+ if instance.get("iscrowd", 0) == 0:
54
+ return True
55
+ else:
56
+ if ann.get("iscrowd", 0) == 0:
57
+ return True
58
+ return False
59
+
60
+ dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
61
+ num_after = len(dataset_dicts)
62
+ logger = logging.getLogger(__name__)
63
+ logger.info(
64
+ "Removed {} images with no usable annotations. {} images left.".format(
65
+ num_before - num_after, num_after
66
+ )
67
+ )
68
+ return dataset_dicts
69
+
70
+
71
+ def get_detection_dataset_dicts(
72
+ dataset_names, filter_empty=True, proposal_files=None
73
+ ):
74
+ """
75
+ Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
76
+
77
+ Args:
78
+ dataset_names (str or list[str]): a dataset name or a list of dataset names
79
+ filter_empty (bool): whether to filter out images without instance annotations
80
+ proposal_files (list[str]): if given, a list of object proposal files
81
+ that match each dataset in `dataset_names`.
82
+
83
+ Returns:
84
+ list[dict]: a list of dicts following the standard dataset dict format.
85
+ """
86
+ if isinstance(dataset_names, str):
87
+ dataset_names = [dataset_names]
88
+ assert len(dataset_names)
89
+ dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
90
+ for dataset_name, dicts in zip(dataset_names, dataset_dicts):
91
+ assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
92
+
93
+ if proposal_files is not None:
94
+ assert len(dataset_names) == len(proposal_files)
95
+ # load precomputed proposals from proposal files
96
+ dataset_dicts = [
97
+ load_proposals_into_dataset(dataset_i_dicts, proposal_file)
98
+ for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
99
+ ]
100
+
101
+ dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
102
+
103
+ has_instances = "annotations" in dataset_dicts[0]
104
+ if filter_empty and has_instances:
105
+ dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names)
106
+
107
+ assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names))
108
+ return dataset_dicts
109
+
110
+
111
+ def _train_loader_from_config(cfg, mapper, dataset_name=None, *, dataset=None, sampler=None):
112
+ if dataset is None:
113
+ dataset = get_detection_dataset_dicts(
114
+ dataset_name,
115
+ filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
116
+ proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
117
+ )
118
+
119
+ if mapper is None:
120
+ mapper = DatasetMapper(cfg, True)
121
+
122
+ if sampler is None:
123
+ sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
124
+ logger = logging.getLogger(__name__)
125
+ logger.info("Using training sampler {}".format(sampler_name))
126
+ sampler = TrainingSampler(len(dataset))
127
+
128
+ return {
129
+ "dataset": dataset,
130
+ "sampler": sampler,
131
+ "mapper": mapper,
132
+ "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
133
+ "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
134
+ "num_workers": cfg.DATALOADER.NUM_WORKERS,
135
+ }
136
+
137
+
138
+ # TODO can allow dataset as an iterable or IterableDataset to make this function more general
139
+ @configurable(from_config=_train_loader_from_config)
140
+ def build_detection_train_loader(
141
+ dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
142
+ ):
143
+ """
144
+ Build a dataloader for object detection with some default features.
145
+ This interface is experimental.
146
+
147
+ Args:
148
+ dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
149
+ or a map-style pytorch dataset. They can be obtained by using
150
+ :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
151
+ mapper (callable): a callable which takes a sample (dict) from dataset and
152
+ returns the format to be consumed by the model.
153
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
154
+ sampler (torch.utils.data.sampler.Sampler or None): a sampler that
155
+ produces indices to be applied on ``dataset``.
156
+ Default to :class:`TrainingSampler`, which coordinates a random shuffle
157
+ sequence across all workers.
158
+ total_batch_size (int): total batch size across all workers. Batching
159
+ simply puts data into a list.
160
+ aspect_ratio_grouping (bool): whether to group images with similar
161
+ aspect ratio for efficiency. When enabled, it requires each
162
+ element in dataset be a dict with keys "width" and "height".
163
+ num_workers (int): number of parallel data loading workers
164
+
165
+ Returns:
166
+ torch.utils.data.DataLoader: a dataloader. Each output from it is a
167
+ ``list[mapped_element]`` of length ``total_batch_size / num_workers``,
168
+ where ``mapped_element`` is produced by the ``mapper``.
169
+ """
170
+ if isinstance(dataset, list):
171
+ dataset = DatasetFromList(dataset, copy=False)
172
+ if mapper is not None:
173
+ dataset = MapDataset(dataset, mapper)
174
+ if sampler is None:
175
+ sampler = TrainingSampler(len(dataset))
176
+ assert isinstance(sampler, torch.utils.data.sampler.Sampler)
177
+ return build_batch_data_loader(
178
+ dataset,
179
+ sampler,
180
+ total_batch_size,
181
+ aspect_ratio_grouping=aspect_ratio_grouping,
182
+ num_workers=num_workers,
183
+ )
184
+
185
+
186
+ def _test_loader_from_config(cfg, dataset_name, mapper=None):
187
+ """
188
+ Uses the given `dataset_name` argument (instead of the names in cfg), because the
189
+ standard practice is to evaluate each test set individually (not combining them).
190
+ """
191
+ dataset = get_detection_dataset_dicts(
192
+ [dataset_name],
193
+ filter_empty=False,
194
+ proposal_files=[
195
+ cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]
196
+ ]
197
+ if cfg.MODEL.LOAD_PROPOSALS
198
+ else None,
199
+ )
200
+ if mapper is None:
201
+ mapper = DatasetMapper(cfg, False)
202
+ return {"dataset": dataset, "mapper": mapper, "num_workers": cfg.DATALOADER.NUM_WORKERS}
203
+
204
+
205
+ @configurable(from_config=_test_loader_from_config)
206
+ def build_detection_test_loader(dataset, *, mapper, num_workers=0):
207
+ """
208
+ Similar to `build_detection_train_loader`, but uses a batch size of 1.
209
+ This interface is experimental.
210
+
211
+ Args:
212
+ dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
213
+ or a map-style pytorch dataset. They can be obtained by using
214
+ :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
215
+ mapper (callable): a callable which takes a sample (dict) from dataset
216
+ and returns the format to be consumed by the model.
217
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
218
+ num_workers (int): number of parallel data loading workers
219
+
220
+ Returns:
221
+ DataLoader: a torch DataLoader, that loads the given detection
222
+ dataset, with test-time transformation and batching.
223
+
224
+ Examples:
225
+ ::
226
+ data_loader = build_detection_test_loader(
227
+ DatasetRegistry.get("my_test"),
228
+ mapper=DatasetMapper(...))
229
+
230
+ # or, instantiate with a CfgNode:
231
+ data_loader = build_detection_test_loader(cfg, "my_test")
232
+ """
233
+ if isinstance(dataset, list):
234
+ dataset = DatasetFromList(dataset, copy=False)
235
+ if mapper is not None:
236
+ dataset = MapDataset(dataset, mapper)
237
+ sampler = InferenceSampler(len(dataset))
238
+ # Always use 1 image per worker during inference since this is the
239
+ # standard when reporting inference time in papers.
240
+ batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False)
241
+ data_loader = torch.utils.data.DataLoader(
242
+ dataset,
243
+ num_workers=num_workers,
244
+ batch_sampler=batch_sampler,
245
+ collate_fn=trivial_batch_collator,
246
+ )
247
+ return data_loader
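
A sketch of how these loaders are typically wired together, assuming a registered dataset named `avis_dataset_val` (a placeholder; the real names are registered in `avism/data/datasets/builtin.py`) and the `AVISDatasetMapper` defined in `dataset_mapper.py` below. The base detectron2 config would also need the project-specific keys (e.g. `INPUT.SAMPLING_FRAME_NUM`) added before the mapper can be built from it.

```python
# Sketch under the stated assumptions; the dataset name is a placeholder.
from detectron2.config import get_cfg
from avism.data.build import build_detection_test_loader
from avism.data.dataset_mapper import AVISDatasetMapper

cfg = get_cfg()
# ... project-specific config keys and a model config file would normally be merged here ...

mapper = AVISDatasetMapper(cfg, is_train=False)                 # test-time mapper (no annotations)
data_loader = build_detection_test_loader(cfg, "avis_dataset_val", mapper=mapper)

for batch in data_loader:
    video = batch[0]              # batch size is fixed to 1 at test time
    print(sorted(video.keys()))   # e.g. image, audio, file_names, video_id, ...
    break
```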
avism/data/dataset_mapper.py ADDED
@@ -0,0 +1,272 @@
1
+ import copy
2
+ import logging
3
+ import random
4
+ import numpy as np
5
+ from typing import List, Union
6
+ import torch
7
+
8
+ from detectron2.config import configurable
9
+ from detectron2.structures import (
10
+ BitMasks,
11
+ Boxes,
12
+ BoxMode,
13
+ Instances,
14
+ )
15
+
16
+ from detectron2.data import detection_utils as utils
17
+ from detectron2.data import transforms as T
18
+
19
+ from .augmentation import build_augmentation
20
+
21
+ __all__ = ["AVISDatasetMapper"]
22
+
23
+
24
+ def filter_empty_instances(instances, by_box=True, by_mask=True, box_threshold=1e-5):
25
+ """
26
+ Filter out empty instances in an `Instances` object.
27
+
28
+ Args:
29
+ instances (Instances):
30
+ by_box (bool): whether to filter out instances with empty boxes
31
+ by_mask (bool): whether to filter out instances with empty masks
32
+ box_threshold (float): minimum width and height to be considered non-empty
33
+
34
+ Returns:
35
+ Instances: the filtered instances.
36
+ """
37
+ assert by_box or by_mask
38
+ r = []
39
+ if by_box:
40
+ r.append(instances.gt_boxes.nonempty(threshold=box_threshold))
41
+ if instances.has("gt_masks") and by_mask:
42
+ r.append(instances.gt_masks.nonempty())
43
+ r.append(instances.gt_classes != -1)
44
+
45
+ if not r:
46
+ return instances
47
+ m = r[0]
48
+ for x in r[1:]:
49
+ m = m & x
50
+
51
+ instances.gt_ids[~m] = -1
52
+ return instances
53
+
54
+
55
+ def _get_dummy_anno(num_classes):
56
+ return {
57
+ "iscrowd": 0,
58
+ "category_id": num_classes,
59
+ "id": -1,
60
+ "bbox": np.array([0, 0, 0, 0]),
61
+ "bbox_mode": BoxMode.XYXY_ABS,
62
+ "segmentation": [np.array([0.0] * 6)]
63
+ }
64
+
65
+
66
+ def avis_annotations_to_instances(annos, image_size):
67
+ """
68
+ Create an :class:`Instances` object used by the models,
69
+ from instance annotations in the dataset dict.
70
+
71
+ Args:
72
+ annos (list[dict]): a list of instance annotations in one image, each
73
+ element for one instance.
74
+ image_size (tuple): height, width
75
+
76
+ Returns:
77
+ Instances:
78
+ It will contain fields "gt_boxes", "gt_classes", "gt_ids",
79
+ "gt_masks", if they can be obtained from `annos`.
80
+ This is the format that builtin models expect.
81
+ """
82
+ boxes = [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
83
+ target = Instances(image_size)
84
+ target.gt_boxes = Boxes(boxes)
85
+
86
+ classes = [int(obj["category_id"]) for obj in annos]
87
+ classes = torch.tensor(classes, dtype=torch.int64)
88
+ target.gt_classes = classes
89
+
90
+ ids = [int(obj["id"]) for obj in annos]
91
+ ids = torch.tensor(ids, dtype=torch.int64)
92
+ target.gt_ids = ids
93
+
94
+ if len(annos) and "segmentation" in annos[0]:
95
+ segms = [obj["segmentation"] for obj in annos]
96
+ masks = []
97
+ for segm in segms:
98
+ assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
99
+ segm.ndim
100
+ )
101
+ # mask array
102
+ masks.append(segm)
103
+ # torch.from_numpy does not support array with negative stride.
104
+ masks = BitMasks(
105
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
106
+ )
107
+ target.gt_masks = masks
108
+
109
+ return target
110
+
111
+
112
+ class AVISDatasetMapper:
113
+ """
114
+ A callable which takes a dataset dict in AVIS Dataset format,
115
+ and maps it into a format used by the model.
116
+ """
117
+
118
+ @configurable
119
+ def __init__(
120
+ self,
121
+ is_train: bool,
122
+ *,
123
+ augmentations: List[Union[T.Augmentation, T.Transform]],
124
+ image_format: str,
125
+ use_instance_mask: bool = False,
126
+ sampling_frame_num: int = 2,
127
+ sampling_frame_range: int = 5,
128
+ sampling_frame_shuffle: bool = False,
129
+ num_classes: int = 26,
130
+ ):
131
+ """
132
+ NOTE: this interface is experimental.
133
+ Args:
134
+ is_train: whether it's used in training or inference
135
+ augmentations: a list of augmentations or deterministic transforms to apply
136
+ image_format: an image format supported by :func:`detection_utils.read_image`.
137
+ use_instance_mask: whether to process instance segmentation annotations, if available
138
+ """
139
+ # fmt: off
140
+ self.is_train = is_train
141
+ self.augmentations = T.AugmentationList(augmentations)
142
+ self.image_format = image_format
143
+ self.use_instance_mask = use_instance_mask
144
+ self.sampling_frame_num = sampling_frame_num
145
+ self.sampling_frame_range = sampling_frame_range
146
+ self.sampling_frame_shuffle = sampling_frame_shuffle
147
+ self.num_classes = num_classes
148
+ # fmt: on
149
+ logger = logging.getLogger(__name__)
150
+ mode = "training" if is_train else "inference"
151
+ logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
152
+
153
+ @classmethod
154
+ def from_config(cls, cfg, is_train: bool = True):
155
+ augs = build_augmentation(cfg, is_train)
156
+
157
+ sampling_frame_num = cfg.INPUT.SAMPLING_FRAME_NUM
158
+ sampling_frame_range = cfg.INPUT.SAMPLING_FRAME_RANGE
159
+ sampling_frame_shuffle = cfg.INPUT.SAMPLING_FRAME_SHUFFLE
160
+
161
+ ret = {
162
+ "is_train": is_train,
163
+ "augmentations": augs,
164
+ "image_format": cfg.INPUT.FORMAT,
165
+ "use_instance_mask": cfg.MODEL.MASK_ON,
166
+ "sampling_frame_num": sampling_frame_num,
167
+ "sampling_frame_range": sampling_frame_range,
168
+ "sampling_frame_shuffle": sampling_frame_shuffle,
169
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
170
+ }
171
+
172
+ return ret
173
+
174
+ def __call__(self, dataset_dict):
175
+ """
176
+ Args:
177
+ dataset_dict (dict): Metadata of one video, in AVIS Dataset format.
178
+
179
+ Returns:
180
+ dict: a format that builtin models in detectron2 accept
181
+ """
182
+ # TODO consider examining below deepcopy as it costs huge amount of computations.
183
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
184
+
185
+ video_length = dataset_dict["length"]
186
+ if self.is_train:
187
+ ref_frame = random.randrange(video_length)
188
+
189
+ start_idx = max(0, ref_frame-self.sampling_frame_range)
190
+ end_idx = min(video_length, ref_frame+self.sampling_frame_range + 1)
191
+
192
+ selected_idx = np.random.choice(
193
+ np.array(list(range(start_idx, ref_frame)) + list(range(ref_frame+1, end_idx))),
194
+ self.sampling_frame_num - 1,
195
+ )
196
+ selected_idx = selected_idx.tolist() + [ref_frame]
197
+ selected_idx = sorted(selected_idx)
198
+ if self.sampling_frame_shuffle:
199
+ random.shuffle(selected_idx)
200
+ else:
201
+ selected_idx = range(video_length)
202
+
203
+ video_annos = dataset_dict.pop("annotations", None)
204
+ file_names = dataset_dict.pop("file_names", None)
205
+ audio_feats = dataset_dict.pop("audio", None)
206
+
207
+ if self.is_train:
208
+ _ids = set()
209
+ for frame_idx in selected_idx:
210
+ _ids.update([anno["id"] for anno in video_annos[frame_idx]])
211
+ ids = dict()
212
+ for i, _id in enumerate(_ids):
213
+ ids[_id] = i
214
+
215
+ dataset_dict["image"] = []
216
+ dataset_dict["instances"] = []
217
+ dataset_dict["file_names"] = []
218
+ dataset_dict["audio"] = []
219
+ dataset_dict["frame_idx"] = list(selected_idx)
220
+ for frame_idx in selected_idx:
221
+ dataset_dict["file_names"].append(file_names[frame_idx])
222
+ dataset_dict["audio"].append(audio_feats[frame_idx])
223
+
224
+ # Read image
225
+ image = utils.read_image(file_names[frame_idx], format=self.image_format)
226
+ utils.check_image_size(dataset_dict, image)
227
+
228
+ aug_input = T.AugInput(image)
229
+ transforms = self.augmentations(aug_input)
230
+ image = aug_input.image
231
+
232
+ image_shape = image.shape[:2] # h, w
233
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
234
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
235
+ # Therefore it's important to use torch.Tensor.
236
+ dataset_dict["image"].append(torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))))
237
+
238
+ if (video_annos is None) or (not self.is_train):
239
+ continue
240
+
241
+ # NOTE copy() is to prevent annotations getting changed from applying augmentations
242
+ _frame_annos = []
243
+ for anno in video_annos[frame_idx]:
244
+ _anno = {}
245
+ for k, v in anno.items():
246
+ _anno[k] = copy.deepcopy(v)
247
+ _frame_annos.append(_anno)
248
+
249
+ # USER: Implement additional transformations if you have other types of data
250
+ annos = [
251
+ utils.transform_instance_annotations(obj, transforms, image_shape)
252
+ for obj in _frame_annos
253
+ if obj.get("iscrowd", 0) == 0
254
+ ]
255
+ sorted_annos = [_get_dummy_anno(self.num_classes) for _ in range(len(ids))]
256
+
257
+ for _anno in annos:
258
+ idx = ids[_anno["id"]]
259
+ sorted_annos[idx] = _anno
260
+ _gt_ids = [_anno["id"] for _anno in sorted_annos]
261
+
262
+ instances = utils.annotations_to_instances(sorted_annos, image_shape, mask_format="bitmask")
263
+ instances.gt_ids = torch.tensor(_gt_ids)
264
+ if instances.has("gt_masks"):
265
+ instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
266
+ instances = filter_empty_instances(instances)
267
+ else:
268
+ instances.gt_masks = BitMasks(torch.empty((0, *image_shape)))
269
+ dataset_dict["instances"].append(instances)
270
+
271
+ return dataset_dict
272
+
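
The training branch of `__call__` above samples a small clip rather than the whole video: it draws a reference frame, then picks `sampling_frame_num - 1` further frames from a window of `sampling_frame_range` frames on either side. A standalone sketch with made-up values, mirroring that logic:

```python
# Mirrors the frame-sampling logic of AVISDatasetMapper.__call__ (training branch).
import random
import numpy as np

video_length = 60          # made-up video length
sampling_frame_num = 2     # total frames per training sample (cfg.INPUT.SAMPLING_FRAME_NUM)
sampling_frame_range = 5   # window radius around the reference frame

ref_frame = random.randrange(video_length)
start_idx = max(0, ref_frame - sampling_frame_range)
end_idx = min(video_length, ref_frame + sampling_frame_range + 1)

# Draw the remaining frames from the window, excluding the reference frame itself.
selected_idx = np.random.choice(
    np.array(list(range(start_idx, ref_frame)) + list(range(ref_frame + 1, end_idx))),
    sampling_frame_num - 1,
)
selected_idx = sorted(selected_idx.tolist() + [ref_frame])
print(ref_frame, selected_idx)   # e.g. 17 [15, 17]
```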
avism/data/datasets/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import builtin # ensure the builtin datasets are registered
2
+
3
+ __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
avism/data/datasets/avis.py ADDED
@@ -0,0 +1,209 @@
1
+ import contextlib
2
+ import io
3
+ import logging
4
+ import numpy as np
5
+ import os
6
+ import pycocotools.mask as mask_util
7
+ from fvcore.common.file_io import PathManager
8
+ from fvcore.common.timer import Timer
9
+
10
+ from detectron2.structures import Boxes, BoxMode, PolygonMasks
11
+ from detectron2.data import DatasetCatalog, MetadataCatalog
12
+
13
+ from .avis_api.avos import AVOS
14
+
15
+
16
+ """
17
+ This file contains functions to parse AVIS dataset of
18
+ COCO-format annotations into dicts in "Detectron2 format".
19
+ """
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ __all__ = ["load_avis_json", "register_avis_instances"]
24
+
25
+
26
+ AVIS_CATEGORIES = [
27
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
28
+ {"color": [0, 82, 0], "isthing": 1, "id": 2, "name": "violin"},
29
+ {"color": [119, 11, 32], "isthing": 1, "id": 3, "name": "guitar"},
30
+ {"color": [165, 42, 42], "isthing": 1, "id": 4, "name": "cello"},
31
+ {"color": [134, 134, 103], "isthing": 1, "id": 5, "name": "flute"},
32
+ {"color": [0, 0, 142], "isthing": 1, "id": 6, "name": "piano"},
33
+ {"color": [255, 109, 65], "isthing": 1, "id": 7, "name": "ukulele"},
34
+ {"color": [0, 226, 252], "isthing": 1, "id": 8, "name": "accordion"},
35
+ {"color": [5, 121, 0], "isthing": 1, "id": 9, "name": "guzheng"},
36
+ {"color": [0, 60, 100], "isthing": 1, "id": 10, "name": "clarinet"},
37
+ {"color": [250, 170, 30], "isthing": 1, "id": 11, "name": "cat"},
38
+ {"color": [100, 170, 30], "isthing": 1, "id": 12, "name": "car"},
39
+ {"color": [179, 0, 194], "isthing": 1, "id": 13, "name": "saxophone"},
40
+ {"color": [255, 77, 255], "isthing": 1, "id": 14, "name": "dog"},
41
+ {"color": [120, 166, 157], "isthing": 1, "id": 15, "name": "lawn_mover"},
42
+ {"color": [73, 77, 174], "isthing": 1, "id": 16, "name": "tuba"},
43
+ {"color": [0, 80, 100], "isthing": 1, "id": 17, "name": "banjo"},
44
+ {"color": [182, 182, 255], "isthing": 1, "id": 18, "name": "pipa"},
45
+ {"color": [0, 143, 149], "isthing": 1, "id": 19, "name": "bassoon"},
46
+ {"color": [174, 57, 255], "isthing": 1, "id": 20, "name": "airplane"},
47
+ {"color": [0, 0, 230], "isthing": 1, "id": 21, "name": "tree_harvester"},
48
+ {"color": [72, 0, 118], "isthing": 1, "id": 22, "name": "trumpet"},
49
+ {"color": [255, 179, 240], "isthing": 1, "id": 23, "name": "lion"},
50
+ {"color": [0, 125, 92], "isthing": 1, "id": 24, "name": "bass"},
51
+ {"color": [209, 0, 151], "isthing": 1, "id": 25, "name": "erhu"},
52
+ {"color": [188, 208, 182], "isthing": 1, "id": 26, "name": "horse"}]
53
+
54
+
55
+ def _get_avis_instances_meta():
56
+ thing_ids = [k["id"] for k in AVIS_CATEGORIES if k["isthing"] == 1]
57
+ thing_colors = [k["color"] for k in AVIS_CATEGORIES if k["isthing"] == 1]
58
+ assert len(thing_ids) == 26, len(thing_ids)
59
+ # Mapping from the incontiguous AVIS category id to an id in [0, 25]
60
+ thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
61
+ thing_classes = [k["name"] for k in AVIS_CATEGORIES if k["isthing"] == 1]
62
+ ret = {
63
+ "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
64
+ "thing_classes": thing_classes,
65
+ "thing_colors": thing_colors,
66
+ }
67
+ return ret
68
+
69
+
70
+ def load_avis_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
71
+
72
+ timer = Timer()
73
+ json_file = PathManager.get_local_path(json_file)
74
+ with contextlib.redirect_stdout(io.StringIO()):
75
+ avis_api = AVOS(json_file)
76
+ if timer.seconds() > 1:
77
+ logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
78
+
79
+ id_map = None
80
+ if dataset_name is not None:
81
+ meta = MetadataCatalog.get(dataset_name)
82
+ cat_ids = sorted(avis_api.getCatIds())
83
+ cats = avis_api.loadCats(cat_ids)
84
+ # The categories in a custom json file may not be sorted.
85
+ thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
86
+ meta.thing_classes = thing_classes
87
+
88
+ # It works by looking at the "categories" field in the json, therefore
89
+ # if users' own json also have incontiguous ids, we'll
90
+ # apply this mapping as well but print a warning.
91
+ if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
92
+ if "coco" not in dataset_name:
93
+ logger.warning(
94
+ """
95
+ Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
96
+ """
97
+ )
98
+ id_map = {v: i for i, v in enumerate(cat_ids)}
99
+ meta.thing_dataset_id_to_contiguous_id = id_map
100
+
101
+ # sort indices for reproducible results
102
+ vid_ids = sorted(avis_api.vids.keys())
103
+ vids = avis_api.loadVids(vid_ids)
104
+
105
+ anns = [avis_api.vidToAnns[vid_id] for vid_id in vid_ids]
106
+ total_num_valid_anns = sum([len(x) for x in anns])
107
+ total_num_anns = len(avis_api.anns)
108
+ if total_num_valid_anns < total_num_anns:
109
+ logger.warning(
110
+ f"{json_file} contains {total_num_anns} annotations, but only "
111
+ f"{total_num_valid_anns} of them match to images in the file."
112
+ )
113
+
114
+ vids_anns = list(zip(vids, anns))
115
+ logger.info("Loaded {} videos in AVIS format from {}".format(len(vids_anns), json_file))
116
+
117
+ dataset_dicts = []
118
+
119
+ ann_keys = ["iscrowd", "category_id", "id"] + (extra_annotation_keys or [])
120
+
121
+ num_instances_without_valid_segmentation = 0
122
+
123
+ for (vid_dict, anno_dict_list) in vids_anns:
124
+ record = {}
125
+ record["file_names"] = [os.path.join(image_root, vid_dict["file_names"][i]) for i in range(vid_dict["length"])]
126
+ record["height"] = vid_dict["height"]
127
+ record["width"] = vid_dict["width"]
128
+ record["length"] = vid_dict["length"]
129
+ video_id = record["video_id"] = vid_dict["id"]
130
+
131
+ video_objs = []
132
+ for frame_idx in range(record["length"]):
133
+ frame_objs = []
134
+ for anno in anno_dict_list:
135
+ assert anno["video_id"] == video_id
136
+
137
+ obj = {key: anno[key] for key in ann_keys if key in anno}
138
+
139
+ _bboxes = anno.get("bboxes", None)
140
+ _segm = anno.get("segmentations", None)
141
+
142
+ if not (_bboxes and _segm and _bboxes[frame_idx] and _segm[frame_idx]):
143
+ continue
144
+
145
+ bbox = _bboxes[frame_idx]
146
+ segm = _segm[frame_idx]
147
+
148
+ obj["bbox"] = bbox
149
+ obj["bbox_mode"] = BoxMode.XYWH_ABS
150
+
151
+ if isinstance(segm, dict):
152
+ if isinstance(segm["counts"], list):
153
+ # convert to compressed RLE
154
+ segm = mask_util.frPyObjects(segm, *segm["size"])
155
+ elif segm:
156
+ # filter out invalid polygons (< 3 points)
157
+ segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
158
+ if len(segm) == 0:
159
+ num_instances_without_valid_segmentation += 1
160
+ continue # ignore this instance
161
+ obj["segmentation"] = segm
162
+
163
+ if id_map:
164
+ obj["category_id"] = id_map[obj["category_id"]]
165
+ frame_objs.append(obj)
166
+ video_objs.append(frame_objs)
167
+ record["annotations"] = video_objs
168
+
169
+ # audio:
170
+ audio_feats_pth = os.path.join(image_root[:-10], "FEATAudios", vid_dict['file_names'][0].split("/")[0] + '.npy')
171
+ record["audio"] = np.load(audio_feats_pth)
172
+
173
+ dataset_dicts.append(record)
174
+
175
+ if num_instances_without_valid_segmentation > 0:
176
+ logger.warning(
177
+ "Filtered out {} instances without valid segmentation. ".format(
178
+ num_instances_without_valid_segmentation
179
+ )
180
+ + "There might be issues in your dataset generation process. "
181
+ "A valid polygon should be a list[float] with even length >= 6."
182
+ )
183
+ return dataset_dicts
184
+
185
+
186
+ def register_avis_instances(name, metadata, json_file, image_root):
187
+ """
188
+ Register a dataset in AVIS's json annotation format for
189
+ instance tracking.
190
+
191
+ Args:
192
+ name (str): the name that identifies a dataset, e.g. "avis_train".
193
+ metadata (dict): extra metadata associated with this dataset. You can
194
+ leave it as an empty dict.
195
+ json_file (str): path to the json instance annotation file.
196
+ image_root (str or path-like): directory which contains all the images.
197
+ """
198
+ assert isinstance(name, str), name
199
+ assert isinstance(json_file, (str, os.PathLike)), json_file
200
+ assert isinstance(image_root, (str, os.PathLike)), image_root
201
+ # 1. register a function which returns dicts
202
+ DatasetCatalog.register(name, lambda: load_avis_json(json_file, image_root, name))
203
+
204
+ # 2. Optionally, add metadata about this dataset,
205
+ # since they might be useful in evaluation, visualization or logging
206
+ MetadataCatalog.get(name).set(
207
+ json_file=json_file, image_root=image_root, evaluator_type="avis", **metadata
208
+ )
209
+
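
A hypothetical registration sketch for the function above. The dataset name and paths are placeholders (the real registrations live in `avism/data/datasets/builtin.py`); note that `load_avis_json` derives the audio-feature path from the image root, so the image root is expected to end in `JPEGImages` with a sibling `FEATAudios` directory.

```python
# Placeholder name and paths; calling DatasetCatalog.get() triggers load_avis_json().
from detectron2.data import DatasetCatalog, MetadataCatalog
from avism.data.datasets.avis import register_avis_instances, _get_avis_instances_meta

register_avis_instances(
    "avis_demo_train",                       # placeholder dataset name
    _get_avis_instances_meta(),
    "datasets/AVIS/train.json",              # placeholder annotation file
    "datasets/AVIS/train/JPEGImages",        # placeholder image root (sibling FEATAudios expected)
)

dicts = DatasetCatalog.get("avis_demo_train")
print(len(dicts), dicts[0]["length"], dicts[0]["audio"].shape)
print(MetadataCatalog.get("avis_demo_train").thing_classes[:5])
```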
avism/data/datasets/avis_api/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
avism/data/datasets/avis_api/avos.py ADDED
@@ -0,0 +1,277 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # The following API functions are defined:
4
+ # AVOS - AVOS api class that loads AVIS annotation file and prepare data structures.
5
+ # decodeMask - Decode binary mask M encoded via run-length encoding.
6
+ # encodeMask - Encode binary mask M using run-length encoding.
7
+ # getAnnIds - Get ann ids that satisfy given filter conditions.
8
+ # getCatIds - Get cat ids that satisfy given filter conditions.
9
+ # getImgIds - Get img ids that satisfy given filter conditions.
10
+ # loadAnns - Load anns with the specified ids.
11
+ # loadCats - Load cats with the specified ids.
12
+ # loadImgs - Load imgs with the specified ids.
13
+ # annToMask - Convert segmentation in an annotation to binary mask.
14
+ # loadRes - Load algorithm results and create API for accessing them.
15
+
16
+ import json
17
+ import time
18
+ import numpy as np
19
+ import copy
20
+ import itertools
21
+ from pycocotools import mask as maskUtils
22
+ from collections import defaultdict
23
+ import sys
24
+ PYTHON_VERSION = sys.version_info[0]
25
+ if PYTHON_VERSION == 2:
26
+ from urllib import urlretrieve
27
+ elif PYTHON_VERSION == 3:
28
+ from urllib.request import urlretrieve
29
+
30
+
31
+ def _isArrayLike(obj):
32
+ return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
33
+
34
+
35
+ class AVOS:
36
+ def __init__(self, annotation_file=None):
37
+ """
38
+ Constructor of the AVOS helper class (a COCO-style API) for reading and visualizing annotations.
39
+ :param annotation_file (str): location of annotation file
40
+ :param image_folder (str): location to the folder that hosts images.
41
+ :return:
42
+ """
43
+ # load dataset
44
+ self.dataset,self.anns,self.cats,self.vids = dict(),dict(),dict(),dict()
45
+ self.vidToAnns, self.catToVids = defaultdict(list), defaultdict(list)
46
+ if not annotation_file == None:
47
+ print('loading annotations into memory...')
48
+ tic = time.time()
49
+ dataset = json.load(open(annotation_file, 'r'))
50
+ assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
51
+ print('Done (t={:0.2f}s)'.format(time.time()- tic))
52
+ self.dataset = dataset
53
+ self.createIndex()
54
+
55
+ def createIndex(self):
56
+ # create index
57
+ print('creating index...')
58
+ anns, cats, vids = {}, {}, {}
59
+ vidToAnns,catToVids = defaultdict(list),defaultdict(list)
60
+ if 'annotations' in self.dataset:
61
+ for ann in self.dataset['annotations']:
62
+ vidToAnns[ann['video_id']].append(ann)
63
+ anns[ann['id']] = ann
64
+
65
+ if 'videos' in self.dataset:
66
+ for vid in self.dataset['videos']:
67
+ vids[vid['id']] = vid
68
+
69
+ if 'categories' in self.dataset:
70
+ for cat in self.dataset['categories']:
71
+ cats[cat['id']] = cat
72
+
73
+ if 'annotations' in self.dataset and 'categories' in self.dataset:
74
+ for ann in self.dataset['annotations']:
75
+ catToVids[ann['category_id']].append(ann['video_id'])
76
+
77
+ print('index created!')
78
+
79
+ # create class members
80
+ self.anns = anns
81
+ self.vidToAnns = vidToAnns
82
+ self.catToVids = catToVids
83
+ self.vids = vids
84
+ self.cats = cats
85
+
86
+ def info(self):
87
+ """
88
+ Print information about the annotation file.
89
+ :return:
90
+ """
91
+ for key, value in self.dataset['info'].items():
92
+ print('{}: {}'.format(key, value))
93
+
94
+ def getAnnIds(self, vidIds=[], catIds=[], areaRng=[], iscrowd=None):
95
+ """
96
+ Get ann ids that satisfy given filter conditions. default skips that filter
97
+ :param vidIds (int array) : get anns for given vids
98
+ catIds (int array) : get anns for given cats
99
+ areaRng (float array) : get anns for given area range (e.g. [0 inf])
100
+ iscrowd (boolean) : get anns for given crowd label (False or True)
101
+ :return: ids (int array) : integer array of ann ids
102
+ """
103
+ vidIds = vidIds if _isArrayLike(vidIds) else [vidIds]
104
+ catIds = catIds if _isArrayLike(catIds) else [catIds]
105
+
106
+ if len(vidIds) == len(catIds) == len(areaRng) == 0:
107
+ anns = self.dataset['annotations']
108
+ else:
109
+ if not len(vidIds) == 0:
110
+ lists = [self.vidToAnns[vidId] for vidId in vidIds if vidId in self.vidToAnns]
111
+ anns = list(itertools.chain.from_iterable(lists))
112
+ else:
113
+ anns = self.dataset['annotations']
114
+ anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
115
+ anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['avg_area'] > areaRng[0] and ann['avg_area'] < areaRng[1]]
116
+ if not iscrowd == None:
117
+ ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
118
+ else:
119
+ ids = [ann['id'] for ann in anns]
120
+ return ids
121
+
122
+ def getCatIds(self, catNms=[], supNms=[], catIds=[]):
123
+ """
124
+ filtering parameters. default skips that filter.
125
+ :param catNms (str array) : get cats for given cat names
126
+ :param supNms (str array) : get cats for given supercategory names
127
+ :param catIds (int array) : get cats for given cat ids
128
+ :return: ids (int array) : integer array of cat ids
129
+ """
130
+ catNms = catNms if _isArrayLike(catNms) else [catNms]
131
+ supNms = supNms if _isArrayLike(supNms) else [supNms]
132
+ catIds = catIds if _isArrayLike(catIds) else [catIds]
133
+
134
+ if len(catNms) == len(supNms) == len(catIds) == 0:
135
+ cats = self.dataset['categories']
136
+ else:
137
+ cats = self.dataset['categories']
138
+ cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
139
+ cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
140
+ cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
141
+ ids = [cat['id'] for cat in cats]
142
+ return ids
143
+
144
+ def getVidIds(self, vidIds=[], catIds=[]):
145
+ '''
146
+ Get vid ids that satisfy given filter conditions.
147
+ :param vidIds (int array) : get vids for given ids
148
+ :param catIds (int array) : get vids with all given cats
149
+ :return: ids (int array) : integer array of vid ids
150
+ '''
151
+ vidIds = vidIds if _isArrayLike(vidIds) else [vidIds]
152
+ catIds = catIds if _isArrayLike(catIds) else [catIds]
153
+
154
+ if len(vidIds) == len(catIds) == 0:
155
+ ids = self.vids.keys()
156
+ else:
157
+ ids = set(vidIds)
158
+ for i, catId in enumerate(catIds):
159
+ if i == 0 and len(ids) == 0:
160
+ ids = set(self.catToVids[catId])
161
+ else:
162
+ ids &= set(self.catToVids[catId])
163
+ return list(ids)
164
+
165
+ def loadAnns(self, ids=[]):
166
+ """
167
+ Load anns with the specified ids.
168
+ :param ids (int array) : integer ids specifying anns
169
+ :return: anns (object array) : loaded ann objects
170
+ """
171
+ if _isArrayLike(ids):
172
+ return [self.anns[id] for id in ids]
173
+ elif type(ids) == int:
174
+ return [self.anns[ids]]
175
+
176
+ def loadCats(self, ids=[]):
177
+ """
178
+ Load cats with the specified ids.
179
+ :param ids (int array) : integer ids specifying cats
180
+ :return: cats (object array) : loaded cat objects
181
+ """
182
+ if _isArrayLike(ids):
183
+ return [self.cats[id] for id in ids]
184
+ elif type(ids) == int:
185
+ return [self.cats[ids]]
186
+
187
+ def loadVids(self, ids=[]):
188
+ """
189
+ Load vids with the specified ids.
190
+ :param ids (int array) : integer ids specifying vid
191
+ :return: vids (object array) : loaded vid objects
192
+ """
193
+ if _isArrayLike(ids):
194
+ return [self.vids[id] for id in ids]
195
+ elif type(ids) == int:
196
+ return [self.vids[ids]]
197
+
198
+
199
+ def loadRes(self, resFile):
200
+ """
201
+ Load result file and return a result api object.
202
+ :param resFile (str) : file name of result file
203
+ :return: res (obj) : result api object
204
+ """
205
+ res = AVOS()
206
+ res.dataset['videos'] = [img for img in self.dataset['videos']]
207
+
208
+ print('Loading and preparing results...')
209
+ tic = time.time()
210
+ if type(resFile) == str or (PYTHON_VERSION == 2 and type(resFile) == unicode):
211
+ anns = json.load(open(resFile))
212
+ elif type(resFile) == np.ndarray:
213
+ anns = self.loadNumpyAnnotations(resFile)
214
+ else:
215
+ anns = resFile
216
+ assert type(anns) == list, 'results is not an array of objects'
217
+ annsVidIds = [ann['video_id'] for ann in anns]
218
+ assert set(annsVidIds) == (set(annsVidIds) & set(self.getVidIds())), \
219
+ 'Results do not correspond to current coco set'
220
+ if 'segmentations' in anns[0]:
221
+ res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
222
+ for id, ann in enumerate(anns):
223
+ ann['areas'] = []
224
+ if not 'bboxes' in ann:
225
+ ann['bboxes'] = []
226
+ for seg in ann['segmentations']:
227
+ # now only support compressed RLE format as segmentation results
228
+ if seg:
229
+ ann['areas'].append(maskUtils.area(seg))
230
+ if len(ann['bboxes']) < len(ann['areas']):
231
+ ann['bboxes'].append(maskUtils.toBbox(seg))
232
+ else:
233
+ ann['areas'].append(None)
234
+ if len(ann['bboxes']) < len(ann['areas']):
235
+ ann['bboxes'].append(None)
236
+ ann['id'] = id+1
237
+ l = [a for a in ann['areas'] if a]
238
+ if len(l)==0:
239
+ ann['avg_area'] = 0
240
+ else:
241
+ ann['avg_area'] = np.array(l).mean()
242
+ ann['iscrowd'] = 0
243
+ print('DONE (t={:0.2f}s)'.format(time.time()- tic))
244
+
245
+ res.dataset['annotations'] = anns
246
+ res.createIndex()
247
+ return res
248
+
249
+ def annToRLE(self, ann, frameId):
250
+ """
251
+ Convert annotation which can be polygons, uncompressed RLE to RLE.
252
+ :return: binary mask (numpy 2D array)
253
+ """
254
+ t = self.vids[ann['video_id']]
255
+ h, w = t['height'], t['width']
256
+ segm = ann['segmentations'][frameId]
257
+ if type(segm) == list:
258
+ # polygon -- a single object might consist of multiple parts
259
+ # we merge all parts into one mask rle code
260
+ rles = maskUtils.frPyObjects(segm, h, w)
261
+ rle = maskUtils.merge(rles)
262
+ elif type(segm['counts']) == list:
263
+ # uncompressed RLE
264
+ rle = maskUtils.frPyObjects(segm, h, w)
265
+ else:
266
+ # rle
267
+ rle = segm
268
+ return rle
269
+
270
+ def annToMask(self, ann, frameId):
271
+ """
272
+ Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
273
+ :return: binary mask (numpy 2D array)
274
+ """
275
+ rle = self.annToRLE(ann, frameId)
276
+ m = maskUtils.decode(rle)
277
+ return m
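
A minimal sketch of the `AVOS` helper API defined above; the annotation path is a placeholder for any AVIS-format ground-truth json.

```python
# Placeholder path; demonstrates the id/load/decode round trip of the AVOS API.
from avism.data.datasets.avis_api.avos import AVOS

gt = AVOS("datasets/AVIS/train.json")            # placeholder annotation file

vid_ids = gt.getVidIds()
ann_ids = gt.getAnnIds(vidIds=vid_ids[:1])       # annotations of the first video
anns = gt.loadAnns(ann_ids)

# Decode the first annotation's mask in its first annotated frame.
ann = anns[0]
frame_id = next(i for i, s in enumerate(ann['segmentations']) if s)
mask = gt.annToMask(ann, frame_id)               # 2D numpy array of 0/1
print(mask.shape, mask.sum())
```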
avism/data/datasets/avis_api/avoseval.py ADDED
@@ -0,0 +1,559 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import numpy as np
4
+ import datetime
5
+ import time
6
+ from collections import defaultdict
7
+ from pycocotools import mask as maskUtils
8
+ import copy
9
+
10
+ class AVOSeval:
11
+ # Interface for evaluating video instance segmentation on the AVIS dataset.
12
+ #
13
+ # The usage for AVOSeval is as follows:
14
+ # cocoGt=..., cocoDt=... # load dataset and results
15
+ # E = AVOSeval(cocoGt,cocoDt); # initialize AVOSeval object
16
+ # E.params.recThrs = ...; # set parameters as desired
17
+ # E.evaluate(); # run per image evaluation
18
+ # E.accumulate(); # accumulate per image results
19
+ # E.summarize(); # display summary metrics of results
20
+ #
21
+ # The evaluation parameters are as follows (defaults in brackets):
22
+ # imgIds - [all] N img ids to use for evaluation
23
+ # catIds - [all] K cat ids to use for evaluation
24
+ # iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation
25
+ # recThrs - [0:.01:1] R=101 recall thresholds for evaluation
26
+ # areaRng - [...] A=4 object area ranges for evaluation
27
+ # maxDets - [1 10 100] M=3 thresholds on max detections per image
28
+ # iouType - ['segm'] set iouType to 'segm', 'bbox' or 'keypoints'
29
+ # iouType replaced the now DEPRECATED useSegm parameter.
30
+ # useCats - [1] if true use category labels for evaluation
31
+ # Note: if useCats=0 category labels are ignored as in proposal scoring.
32
+ # Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified.
33
+ #
34
+ # evaluate(): evaluates detections on every image and every category and
35
+ # concats the results into the "evalImgs" with fields:
36
+ # dtIds - [1xD] id for each of the D detections (dt)
37
+ # gtIds - [1xG] id for each of the G ground truths (gt)
38
+ # dtMatches - [TxD] matching gt id at each IoU or 0
39
+ # gtMatches - [TxG] matching dt id at each IoU or 0
40
+ # dtScores - [1xD] confidence of each dt
41
+ # gtIgnore - [1xG] ignore flag for each gt
42
+ # dtIgnore - [TxD] ignore flag for each dt at each IoU
43
+ #
44
+ # accumulate(): accumulates the per-image, per-category evaluation
45
+ # results in "evalImgs" into the dictionary "eval" with fields:
46
+ # params - parameters used for evaluation
47
+ # date - date evaluation was performed
48
+ # counts - [T,R,K,A,M] parameter dimensions (see above)
49
+ # precision - [TxRxKxAxM] precision for every evaluation setting
50
+ # recall - [TxKxAxM] max recall for every evaluation setting
51
+ # Note: precision and recall==-1 for settings with no gt objects.
52
+ #
53
+ # See also coco, mask, pycocoDemo, pycocoEvalDemo
54
+ def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'):
55
+ '''
56
+ Initialize CocoEval using coco APIs for gt and dt
57
+ :param cocoGt: coco object with ground truth annotations
58
+ :param cocoDt: coco object with detection results
59
+ :return: None
60
+ '''
61
+ if not iouType:
62
+ print('iouType not specified. use default iouType segm')
63
+ self.cocoGt = cocoGt # ground truth COCO API
64
+ self.cocoDt = cocoDt # detections COCO API
65
+ self.params = {} # evaluation parameters
66
+ self.evalVids = defaultdict(list) # per-image per-category evaluation results [KxAxI] elements
67
+ self.eval = {} # accumulated evaluation results
68
+ self._gts = defaultdict(list) # gt for evaluation
69
+ self._dts = defaultdict(list) # dt for evaluation
70
+ self.params = Params(iouType=iouType) # parameters
71
+ self._paramsEval = {} # parameters for evaluation
72
+ self.stats = [] # result summarization
73
+ self.ious = {} # ious between all gts and dts
74
+ if not cocoGt is None:
75
+ self.params.vidIds = sorted(cocoGt.getVidIds())
76
+ self.params.catIds = sorted(cocoGt.getCatIds())
77
+
78
+
79
+ def _prepare(self):
80
+ '''
81
+ Prepare ._gts and ._dts for evaluation based on params
82
+ :return: None
83
+ '''
84
+ def _toMask(anns, coco):
85
+ # modify ann['segmentation'] by reference
86
+ for ann in anns:
87
+ for i, a in enumerate(ann['segmentations']):
88
+ if a:
89
+ rle = coco.annToRLE(ann, i)
90
+ ann['segmentations'][i] = rle
91
+ l = [a for a in ann['areas'] if a]
92
+ if len(l)==0:
93
+ ann['avg_area'] = 0
94
+ else:
95
+ ann['avg_area'] = np.array(l).mean()
96
+ p = self.params
97
+ if p.useCats:
98
+ gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(vidIds=p.vidIds, catIds=p.catIds))
99
+ dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(vidIds=p.vidIds, catIds=p.catIds))
100
+ else:
101
+ gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(vidIds=p.vidIds))
102
+ dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(vidIds=p.vidIds))
103
+
104
+ # convert ground truth to mask if iouType == 'segm'
105
+ if p.iouType == 'segm':
106
+ _toMask(gts, self.cocoGt)
107
+ _toMask(dts, self.cocoDt)
108
+ # set ignore flag
109
+ for gt in gts:
110
+ gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0
111
+ gt['ignore'] = 'iscrowd' in gt and gt['iscrowd']
112
+ if p.iouType == 'keypoints':
113
+ gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore']
114
+
115
+ self._gts = defaultdict(list) # gt for evaluation
116
+ self._dts = defaultdict(list) # dt for evaluation
117
+ for gt in gts:
118
+ self._gts[gt['video_id'], gt['category_id']].append(gt)
119
+ for dt in dts:
120
+ self._dts[dt['video_id'], dt['category_id']].append(dt)
121
+ self.evalVids = defaultdict(list) # per-image per-category evaluation results
122
+ self.eval = {} # accumulated evaluation results
123
+
124
+ def evaluate(self):
125
+ '''
126
+ Run per-video evaluation on the given videos and store the results (a list of dicts) in self.evalImgs
127
+ :return: None
128
+ '''
129
+ tic = time.time()
130
+ print('Running per image evaluation...')
131
+ p = self.params
132
+ # add backward compatibility if useSegm is specified in params
133
+ if not p.useSegm is None:
134
+ p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
135
+ print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
136
+ print('Evaluate annotation type *{}*'.format(p.iouType))
137
+ p.vidIds = list(np.unique(p.vidIds))
138
+ if p.useCats:
139
+ p.catIds = list(np.unique(p.catIds))
140
+ p.maxDets = sorted(p.maxDets)
141
+ self.params=p
142
+
143
+ self._prepare()
144
+ # loop through images, area range, max detection number
145
+ catIds = p.catIds if p.useCats else [-1]
146
+
147
+ if p.iouType == 'segm' or p.iouType == 'bbox':
148
+ computeIoU = self.computeIoU
149
+ elif p.iouType == 'keypoints':
150
+ computeIoU = self.computeOks
151
+ self.ious = {(vidId, catId): computeIoU(vidId, catId) \
152
+ for vidId in p.vidIds
153
+ for catId in catIds}
154
+
155
+ evaluateVid = self.evaluateVid
156
+ maxDet = p.maxDets[-1]
157
+
158
+
159
+ self.evalImgs = [evaluateVid(vidId, catId, areaRng, maxDet)
160
+ for catId in catIds
161
+ for areaRng in p.areaRng
162
+ for vidId in p.vidIds
163
+ ]
164
+ self._paramsEval = copy.deepcopy(self.params)
165
+ toc = time.time()
166
+ print('DONE (t={:0.2f}s).'.format(toc-tic))
167
+
168
+ def computeIoU(self, vidId, catId):
169
+ p = self.params
170
+ if p.useCats:
171
+ gt = self._gts[vidId,catId]
172
+ dt = self._dts[vidId,catId]
173
+ else:
174
+ gt = [_ for cId in p.catIds for _ in self._gts[vidId,cId]]
175
+ dt = [_ for cId in p.catIds for _ in self._dts[vidId,cId]]
176
+ if len(gt) == 0 and len(dt) ==0:
177
+ return []
178
+ inds = np.argsort([-d['score'] for d in dt], kind='mergesort')
179
+ dt = [dt[i] for i in inds]
180
+ if len(dt) > p.maxDets[-1]:
181
+ dt=dt[0:p.maxDets[-1]]
182
+
183
+ if p.iouType == 'segm':
184
+ g = [g['segmentations'] for g in gt]
185
+ d = [d['segmentations'] for d in dt]
186
+ elif p.iouType == 'bbox':
187
+ g = [g['bboxes'] for g in gt]
188
+ d = [d['bboxes'] for d in dt]
189
+ else:
190
+ raise Exception('unknown iouType for iou computation')
191
+
192
+ # compute iou between each dt and gt region
193
+ iscrowd = [int(o['iscrowd']) for o in gt]
194
+ #ious = maskUtils.iou(d,g,iscrowd)
195
+ def iou_seq(d_seq, g_seq):
196
+ i = .0
197
+ u = .0
198
+ for d, g in zip(d_seq, g_seq):
199
+ if d and g:
200
+ i += maskUtils.area(maskUtils.merge([d, g], True))
201
+ u += maskUtils.area(maskUtils.merge([d, g], False))
202
+ elif not d and g:
203
+ u += maskUtils.area(g)
204
+ elif d and not g:
205
+ u += maskUtils.area(d)
206
+ if not u > .0:
207
+ print("Mask sizes in video {} and category {} may not match!".format(vidId, catId))
208
+ iou = i / u if u > .0 else .0
209
+ return iou
210
+ ious = np.zeros([len(d), len(g)])
211
+ for i, j in np.ndindex(ious.shape):
212
+ ious[i, j] = iou_seq(d[i], g[j])
213
+ #print(vidId, catId, ious.shape, ious)
214
+ return ious
215
+
216
+ def computeOks(self, imgId, catId):
217
+ p = self.params
218
+ # dimension here should be Nxm
219
+ gts = self._gts[imgId, catId]
220
+ dts = self._dts[imgId, catId]
221
+ inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
222
+ dts = [dts[i] for i in inds]
223
+ if len(dts) > p.maxDets[-1]:
224
+ dts = dts[0:p.maxDets[-1]]
225
+ # if len(gts) == 0 and len(dts) == 0:
226
+ if len(gts) == 0 or len(dts) == 0:
227
+ return []
228
+ ious = np.zeros((len(dts), len(gts)))
229
+ sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0
230
+ vars = (sigmas * 2)**2
231
+ k = len(sigmas)
232
+ # compute oks between each detection and ground truth object
233
+ for j, gt in enumerate(gts):
234
+ # create bounds for ignore regions(double the gt bbox)
235
+ g = np.array(gt['keypoints'])
236
+ xg = g[0::3]; yg = g[1::3]; vg = g[2::3]
237
+ k1 = np.count_nonzero(vg > 0)
238
+ bb = gt['bbox']
239
+ x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
240
+ y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2
241
+ for i, dt in enumerate(dts):
242
+ d = np.array(dt['keypoints'])
243
+ xd = d[0::3]; yd = d[1::3]
244
+ if k1>0:
245
+ # measure the per-keypoint distance if keypoints visible
246
+ dx = xd - xg
247
+ dy = yd - yg
248
+ else:
249
+ # measure minimum distance to keypoints in (x0,y0) & (x1,y1)
250
+ z = np.zeros((k))
251
+ dx = np.max((z, x0-xd),axis=0)+np.max((z, xd-x1),axis=0)
252
+ dy = np.max((z, y0-yd),axis=0)+np.max((z, yd-y1),axis=0)
253
+ e = (dx**2 + dy**2) / vars / (gt['avg_area']+np.spacing(1)) / 2
254
+ if k1 > 0:
255
+ e=e[vg > 0]
256
+ ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
257
+ return ious
258
+
259
+ def evaluateVid(self, vidId, catId, aRng, maxDet):
260
+ '''
261
+ perform evaluation for a single category and video
262
+ :return: dict (single video results)
263
+ '''
264
+ p = self.params
265
+ if p.useCats:
266
+ gt = self._gts[vidId,catId]
267
+ dt = self._dts[vidId,catId]
268
+ else:
269
+ gt = [_ for cId in p.catIds for _ in self._gts[vidId,cId]]
270
+ dt = [_ for cId in p.catIds for _ in self._dts[vidId,cId]]
271
+ if len(gt) == 0 and len(dt) ==0:
272
+ return None
273
+
274
+ for g in gt:
275
+ if g['ignore'] or (g['avg_area']<aRng[0] or g['avg_area']>aRng[1]):
276
+ g['_ignore'] = 1
277
+ else:
278
+ g['_ignore'] = 0
279
+
280
+ # sort dt highest score first, sort gt ignore last
281
+ gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
282
+ gt = [gt[i] for i in gtind]
283
+ dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
284
+ dt = [dt[i] for i in dtind[0:maxDet]]
285
+ iscrowd = [int(o['iscrowd']) for o in gt]
286
+ # load computed ious
287
+ ious = self.ious[vidId, catId][:, gtind] if len(self.ious[vidId, catId]) > 0 else self.ious[vidId, catId]
288
+
289
+ T = len(p.iouThrs)
290
+ G = len(gt)
291
+ D = len(dt)
292
+ gtm = np.zeros((T,G))
293
+ dtm = np.zeros((T,D))
294
+ gtIg = np.array([g['_ignore'] for g in gt])
295
+ dtIg = np.zeros((T,D))
296
+ if not len(ious)==0:
297
+ for tind, t in enumerate(p.iouThrs):
298
+ for dind, d in enumerate(dt):
299
+ # information about best match so far (m=-1 -> unmatched)
300
+ iou = min([t,1-1e-10])
301
+ m = -1
302
+ for gind, g in enumerate(gt):
303
+ # if this gt already matched, and not a crowd, continue
304
+ if gtm[tind,gind]>0 and not iscrowd[gind]:
305
+ continue
306
+ # if dt matched to reg gt, and on ignore gt, stop
307
+ if m>-1 and gtIg[m]==0 and gtIg[gind]==1:
308
+ break
309
+ # continue to next gt unless better match made
310
+ if ious[dind,gind] < iou:
311
+ continue
312
+ # if match successful and best so far, store appropriately
313
+ iou=ious[dind,gind]
314
+ m=gind
315
+ # if match made store id of match for both dt and gt
316
+ if m ==-1:
317
+ continue
318
+ dtIg[tind,dind] = gtIg[m]
319
+ dtm[tind,dind] = gt[m]['id']
320
+ gtm[tind,m] = d['id']
321
+ # set unmatched detections outside of area range to ignore
322
+ a = np.array([d['avg_area']<aRng[0] or d['avg_area']>aRng[1] for d in dt]).reshape((1, len(dt)))
323
+ dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0)))
324
+ # store results for given image and category
325
+ return {
326
+ 'video_id': vidId,
327
+ 'category_id': catId,
328
+ 'aRng': aRng,
329
+ 'maxDet': maxDet,
330
+ 'dtIds': [d['id'] for d in dt],
331
+ 'gtIds': [g['id'] for g in gt],
332
+ 'dtMatches': dtm,
333
+ 'gtMatches': gtm,
334
+ 'dtScores': [d['score'] for d in dt],
335
+ 'gtIgnore': gtIg,
336
+ 'dtIgnore': dtIg,
337
+ }
338
+
339
+ def accumulate(self, p = None):
340
+ '''
341
+ Accumulate per image evaluation results and store the result in self.eval
342
+ :param p: input params for evaluation
343
+ :return: None
344
+ '''
345
+ print('Accumulating evaluation results...')
346
+ tic = time.time()
347
+ if not self.evalImgs:
348
+ print('Please run evaluate() first')
349
+ # allows input customized parameters
350
+ if p is None:
351
+ p = self.params
352
+ p.catIds = p.catIds if p.useCats == 1 else [-1]
353
+ T = len(p.iouThrs)
354
+ R = len(p.recThrs)
355
+ K = len(p.catIds) if p.useCats else 1
356
+ A = len(p.areaRng)
357
+ M = len(p.maxDets)
358
+ precision = -np.ones((T,R,K,A,M)) # -1 for the precision of absent categories
359
+ recall = -np.ones((T,K,A,M))
360
+ scores = -np.ones((T,R,K,A,M))
361
+
362
+ # create dictionary for future indexing
363
+ _pe = self._paramsEval
364
+ catIds = _pe.catIds if _pe.useCats else [-1]
365
+ setK = set(catIds)
366
+ setA = set(map(tuple, _pe.areaRng))
367
+ setM = set(_pe.maxDets)
368
+ setI = set(_pe.vidIds)
369
+ # get inds to evaluate
370
+ k_list = [n for n, k in enumerate(p.catIds) if k in setK]
371
+ m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
372
+ a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA]
373
+ i_list = [n for n, i in enumerate(p.vidIds) if i in setI]
374
+ I0 = len(_pe.vidIds)
375
+ A0 = len(_pe.areaRng)
376
+ # retrieve E at each category, area range, and max number of detections
377
+ for k, k0 in enumerate(k_list):
378
+ Nk = k0*A0*I0
379
+ for a, a0 in enumerate(a_list):
380
+ Na = a0*I0
381
+ for m, maxDet in enumerate(m_list):
382
+ E = [self.evalImgs[Nk + Na + i] for i in i_list]
383
+ E = [e for e in E if not e is None]
384
+ if len(E) == 0:
385
+ continue
386
+ dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E])
387
+
388
+ # Different sorting methods generate slightly different results.
389
+ # mergesort is used to be consistent with the Matlab implementation.
390
+ inds = np.argsort(-dtScores, kind='mergesort')
391
+ dtScoresSorted = dtScores[inds]
392
+
393
+ dtm = np.concatenate([e['dtMatches'][:,0:maxDet] for e in E], axis=1)[:,inds]
394
+ dtIg = np.concatenate([e['dtIgnore'][:,0:maxDet] for e in E], axis=1)[:,inds]
395
+ gtIg = np.concatenate([e['gtIgnore'] for e in E])
396
+ npig = np.count_nonzero(gtIg==0 )
397
+ if npig == 0:
398
+ continue
399
+ tps = np.logical_and( dtm, np.logical_not(dtIg) )
400
+ fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg) )
401
+
402
+ tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float64)
403
+ fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float64)
404
+ for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
405
+ tp = np.array(tp)
406
+ fp = np.array(fp)
407
+ nd = len(tp)
408
+ rc = tp / npig
409
+ pr = tp / (fp+tp+np.spacing(1))
410
+ q = np.zeros((R,))
411
+ ss = np.zeros((R,))
412
+
413
+ if nd:
414
+ recall[t,k,a,m] = rc[-1]
415
+ else:
416
+ recall[t,k,a,m] = 0
417
+
418
+ # numpy is slow without cython optimization for accessing elements
419
+ # use python array gets significant speed improvement
420
+ pr = pr.tolist(); q = q.tolist()
421
+
422
+ for i in range(nd-1, 0, -1):
423
+ if pr[i] > pr[i-1]:
424
+ pr[i-1] = pr[i]
425
+
426
+ inds = np.searchsorted(rc, p.recThrs, side='left')
427
+ try:
428
+ for ri, pi in enumerate(inds):
429
+ q[ri] = pr[pi]
430
+ ss[ri] = dtScoresSorted[pi]
431
+ except:
432
+ pass
433
+ precision[t,:,k,a,m] = np.array(q)
434
+ scores[t,:,k,a,m] = np.array(ss)
435
+ self.eval = {
436
+ 'params': p,
437
+ 'counts': [T, R, K, A, M],
438
+ 'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
439
+ 'precision': precision,
440
+ 'recall': recall,
441
+ 'scores': scores,
442
+ }
443
+ toc = time.time()
444
+ print('DONE (t={:0.2f}s).'.format( toc-tic))
445
+
446
+ def summarize(self):
447
+ '''
448
+ Compute and display summary metrics for evaluation results.
449
+ Note this function can *only* be applied on the default parameter setting
450
+ '''
451
+ def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
452
+ p = self.params
453
+ iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
454
+ titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
455
+ typeStr = '(AP)' if ap==1 else '(AR)'
456
+ iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
457
+ if iouThr is None else '{:0.2f}'.format(iouThr)
458
+
459
+ aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
460
+ mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
461
+ if ap == 1:
462
+ # dimension of precision: [TxRxKxAxM]
463
+ s = self.eval['precision']
464
+ # IoU
465
+ if iouThr is not None:
466
+ t = np.where(iouThr == p.iouThrs)[0]
467
+ s = s[t]
468
+ s = s[:,:,:,aind,mind]
469
+ else:
470
+ # dimension of recall: [TxKxAxM]
471
+ s = self.eval['recall']
472
+ if iouThr is not None:
473
+ t = np.where(iouThr == p.iouThrs)[0]
474
+ s = s[t]
475
+ s = s[:,:,aind,mind]
476
+ if len(s[s>-1])==0:
477
+ mean_s = -1
478
+ else:
479
+ mean_s = np.mean(s[s>-1])
480
+ print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
481
+ return mean_s
482
+ def _summarizeDets():
483
+ stats = np.zeros((12,))
484
+ stats[0] = _summarize(1)
485
+ stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
486
+ stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
487
+ stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
488
+ stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
489
+ stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
490
+ stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
491
+ stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
492
+ stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
493
+ stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
494
+ stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
495
+ stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
496
+ return stats
497
+ def _summarizeKps():
498
+ stats = np.zeros((10,))
499
+ stats[0] = _summarize(1, maxDets=20)
500
+ stats[1] = _summarize(1, maxDets=20, iouThr=.5)
501
+ stats[2] = _summarize(1, maxDets=20, iouThr=.75)
502
+ stats[3] = _summarize(1, maxDets=20, areaRng='medium')
503
+ stats[4] = _summarize(1, maxDets=20, areaRng='large')
504
+ stats[5] = _summarize(0, maxDets=20)
505
+ stats[6] = _summarize(0, maxDets=20, iouThr=.5)
506
+ stats[7] = _summarize(0, maxDets=20, iouThr=.75)
507
+ stats[8] = _summarize(0, maxDets=20, areaRng='medium')
508
+ stats[9] = _summarize(0, maxDets=20, areaRng='large')
509
+ return stats
510
+ if not self.eval:
511
+ raise Exception('Please run accumulate() first')
512
+ iouType = self.params.iouType
513
+ if iouType == 'segm' or iouType == 'bbox':
514
+ summarize = _summarizeDets
515
+ elif iouType == 'keypoints':
516
+ summarize = _summarizeKps
517
+ self.stats = summarize()
518
+
519
+ def __str__(self):
520
+ self.summarize()
521
+
522
+ class Params:
523
+ '''
524
+ Params for coco evaluation api
525
+ '''
526
+ def setDetParams(self):
527
+ self.vidIds = []
528
+ self.catIds = []
529
+ # np.arange causes trouble. the data point on arange is slightly larger than the true value
530
+ #self.iouThrs = np.linspace(.5, 0.95, np.round((0.95 - .5) / .05) + 1, endpoint=True)
531
+ #self.recThrs = np.linspace(.0, 1.00, np.round((1.00 - .0) / .01) + 1, endpoint=True)
532
+ self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
533
+ self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
534
+ self.maxDets = [1, 10, 100]
535
+ self.areaRng = [[0 ** 2, 1e5 ** 2], [0 ** 2, 128 ** 2], [ 128 ** 2, 256 ** 2], [256 ** 2, 1e5 ** 2]]
536
+ self.areaRngLbl = ['all', 'small', 'medium', 'large']
537
+ self.useCats = 1
538
+
539
+ def setKpParams(self):
540
+ self.vidIds = []
541
+ self.catIds = []
542
+ # np.arange causes trouble. the data point on arange is slightly larger than the true value
543
+ self.iouThrs = np.linspace(.5, 0.95, np.round((0.95 - .5) / .05) + 1, endpoint=True)
544
+ self.recThrs = np.linspace(.0, 1.00, np.round((1.00 - .0) / .01) + 1, endpoint=True)
545
+ self.maxDets = [20]
546
+ self.areaRng = [[0 ** 2, 1e5 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
547
+ self.areaRngLbl = ['all', 'medium', 'large']
548
+ self.useCats = 1
549
+
550
+ def __init__(self, iouType='segm'):
551
+ if iouType == 'segm' or iouType == 'bbox':
552
+ self.setDetParams()
553
+ elif iouType == 'keypoints':
554
+ self.setKpParams()
555
+ else:
556
+ raise Exception('iouType not supported')
557
+ self.iouType = iouType
558
+ # useSegm is deprecated
559
+ self.useSegm = None
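A short driver following the usage comment at the top of `AVOSeval`. How the ground-truth and result API objects are built is left out and assumed here; this is a sketch, not part of the commit.

```python
from avism.data.datasets.avis_api.avoseval import AVOSeval

def evaluate_avis(avis_gt, avis_dt):
    """avis_gt / avis_dt: ground-truth and result API objects (built elsewhere)."""
    E = AVOSeval(avis_gt, avis_dt, iouType='segm')
    E.evaluate()      # per-video, per-category matching at the 10 IoU thresholds
    E.accumulate()    # fills E.eval['precision'] with shape [T, R, K, A, M]
    E.summarize()     # prints the 12 AP/AR numbers and stores them in E.stats
    return E.stats
```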
avism/data/datasets/builtin.py ADDED
@@ -0,0 +1,29 @@
1
+ import os
2
+
3
+ from .avis import (
4
+ register_avis_instances,
5
+ _get_avis_instances_meta,
6
+ )
7
+
8
+
9
+ # ==== Predefined splits for AVIS ===========
10
+ _PREDEFINED_SPLITS_AVIS = {
11
+ "avis_train": ("train/JPEGImages", "train.json"),
12
+ "avis_val": ("val/JPEGImages", "val.json"),
13
+ "avis_test": ("test/JPEGImages", "test.json"),
14
+ }
15
+
16
+ def register_all_avis(root):
17
+ for key, (image_root, json_file) in _PREDEFINED_SPLITS_AVIS.items():
18
+ # Assume pre-defined datasets live in `./datasets`.
19
+ register_avis_instances(
20
+ key,
21
+ _get_avis_instances_meta(),
22
+ os.path.join(root, json_file) if "://" not in json_file else json_file,
23
+ os.path.join(root, image_root),
24
+ )
25
+
26
+ if __name__.endswith(".builtin"):
27
+ # Assume pre-defined datasets live in `./datasets`.
28
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
29
+ register_all_avis(_root)
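Once this module is imported, the three splits are available by name. A sketch of consuming a registered split through Detectron2's catalogs, assuming `register_avis_instances` follows the usual YTVIS-style registration into `DatasetCatalog`/`MetadataCatalog` and that `./datasets/train.json` exists locally:

```python
import os
os.environ.setdefault("DETECTRON2_DATASETS", "datasets")  # root assumed by builtin.py

from detectron2.data import DatasetCatalog, MetadataCatalog
import avism.data.datasets.builtin  # noqa: F401  (triggers register_all_avis)

dicts = DatasetCatalog.get("avis_train")   # list of per-video records
meta = MetadataCatalog.get("avis_train")
print(len(dicts), meta.get("thing_classes", [])[:5])
```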
avism/data/datasets/extract_audio_feat/audio_feature_extractor.py ADDED
@@ -0,0 +1,77 @@
1
+ import os
2
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1" # set gpu number
4
+ import numpy as np
5
+ import tensorflow as tf
6
+
7
+ import vggish_input
8
+ import vggish_params
9
+ import vggish_slim
10
+ import contextlib
11
+ import wave
12
+
13
+
14
+ # get audio length
15
+ def get_audio_len(audio_file):
16
+ with contextlib.closing(wave.open(audio_file, 'r')) as f:
17
+ frames = f.getnframes()
18
+ rate = f.getframerate()
19
+ wav_length = int(frames / float(rate))
20
+ return wav_length
21
+
22
+ # Paths to downloaded VGGish files.
23
+ checkpoint_path = './vggish_model.ckpt'
24
+ pca_params_path = './vggish_pca_params.npz'
25
+ freq = 1000
26
+ sr = 44100
27
+
28
+
29
+ audio_root = "./datasets/"
30
+ for subset in ["train", "val", "test"]:
31
+ print("{} ----------> ".format(subset))
32
+
33
+ audio_dir = os.path.join(audio_root, subset, "WAVAudios")
34
+ save_dir = os.path.join(audio_root, subset, "FEATAudios")
35
+ if not os.path.exists(save_dir):
36
+ os.makedirs(save_dir)
37
+
38
+ lis = sorted(os.listdir(audio_dir))
39
+ len_data = len(lis)
40
+ print(len_data)
41
+
42
+ i = 0
43
+ for n in range(len_data):
44
+ i += 1
45
+ # save file
46
+ outfile = os.path.join(save_dir, lis[n][:-4] + '.npy')
47
+ if os.path.exists(outfile):
48
+ print("\nProcessing: ", i, " / ", len_data, " ----> ", lis[n][:-4] + '.npy', " is already exist! ")
49
+ continue
50
+
51
+ '''feature learning by VGG-net trained by audioset'''
52
+ audio_index = os.path.join(audio_dir, lis[n]) # path of your audio files
53
+ num_secs = len(os.listdir(os.path.join(audio_root, subset, "JPEGImages", lis[n][:-4])))
54
+ # num_secs_real = get_audio_len(audio_index)
55
+ # print("\nProcessing: ", i, " / ", len_data, " --------> video: ", lis[n], " ---> sec: ", num_secs_real)
56
+
57
+ input_batch = vggish_input.wavfile_to_examples(audio_index, num_secs)
58
+ np.testing.assert_equal(
59
+ input_batch.shape,
60
+ [num_secs, vggish_params.NUM_FRAMES, vggish_params.NUM_BANDS])
61
+
62
+ # Define VGGish, load the checkpoint, and run the batch through the model to produce embeddings.
63
+ with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
64
+ vggish_slim.define_vggish_slim()
65
+ vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint_path)
66
+
67
+ features_tensor = sess.graph.get_tensor_by_name(vggish_params.INPUT_TENSOR_NAME)
68
+ embedding_tensor = sess.graph.get_tensor_by_name(vggish_params.OUTPUT_TENSOR_NAME)
69
+ [embedding_batch] = sess.run([embedding_tensor], feed_dict={features_tensor: input_batch})
70
+ np.save(outfile, embedding_batch)
71
+ print(" save info: ", lis[n][:-4] + '.npy', " ---> ", embedding_batch.shape)
72
+
73
+ i += 1
74
+
75
+ print("\n---------------------------------- end ----------------------------------\n")
76
+
77
+
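The script above writes one `.npy` per video whose first dimension matches the number of seconds/frames. A quick check of a saved feature file (the file name here is hypothetical):

```python
import numpy as np

feat = np.load("./datasets/train/FEATAudios/example_video.npy")  # hypothetical name
print(feat.shape)  # (num_secs, 128): one 128-D VGGish embedding per second
```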
avism/data/datasets/extract_audio_feat/mel_features.py ADDED
@@ -0,0 +1,233 @@
1
+ # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Defines routines to compute mel spectrogram features from audio waveform."""
17
+
18
+ import numpy as np
19
+
20
+
21
+ def frame(data, window_length, hop_length):
22
+ """Convert array into a sequence of successive possibly overlapping frames.
23
+
24
+ An n-dimensional array of shape (num_samples, ...) is converted into an
25
+ (n+1)-D array of shape (num_frames, window_length, ...), where each frame
26
+ starts hop_length points after the preceding one.
27
+
28
+ This is accomplished using stride_tricks, so the original data is not
29
+ copied. However, there is no zero-padding, so any incomplete frames at the
30
+ end are not included.
31
+
32
+ Args:
33
+ data: np.array of dimension N >= 1.
34
+ window_length: Number of samples in each frame.
35
+ hop_length: Advance (in samples) between each window.
36
+
37
+ Returns:
38
+ (N+1)-D np.array with as many rows as there are complete frames that can be
39
+ extracted.
40
+ """
41
+
42
+ # print("data: ", data.shape)
43
+ num_samples = data.shape[0]
44
+ # print("num_samples: ", num_samples)
45
+ num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
46
+ # print("num_frames: ", num_frames)
47
+ shape = (num_frames, window_length) + data.shape[1:]
48
+ # print("shape: ", shape)
49
+ strides = (data.strides[0] * hop_length,) + data.strides
50
+ # print("strides: ", strides)
51
+
52
+
53
+
54
+ # Split the data into overlapping frames of the given shape via stride tricks
55
+ return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
56
+
57
+
58
+ def periodic_hann(window_length):
59
+ """Calculate a "periodic" Hann window.
60
+
61
+ The classic Hann window is defined as a raised cosine that starts and
62
+ ends on zero, and where every value appears twice, except the middle
63
+ point for an odd-length window. Matlab calls this a "symmetric" window
64
+ and np.hanning() returns it. However, for Fourier analysis, this
65
+ actually represents just over one cycle of a period N-1 cosine, and
66
+ thus is not compactly expressed on a length-N Fourier basis. Instead,
67
+ it's better to use a raised cosine that ends just before the final
68
+ zero value - i.e. a complete cycle of a period-N cosine. Matlab
69
+ calls this a "periodic" window. This routine calculates it.
70
+
71
+ Args:
72
+ window_length: The number of points in the returned window.
73
+
74
+ Returns:
75
+ A 1D np.array containing the periodic hann window.
76
+ """
77
+ return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
78
+ np.arange(window_length)))
79
+
80
+
81
+ def stft_magnitude(signal, fft_length,
82
+ hop_length=None,
83
+ window_length=None):
84
+ """Calculate the short-time Fourier transform magnitude.
85
+
86
+ Args:
87
+ signal: 1D np.array of the input time-domain signal.
88
+ fft_length: Size of the FFT to apply.
89
+ hop_length: Advance (in samples) between each frame passed to FFT.
90
+ window_length: Length of each block of samples to pass to FFT.
91
+
92
+ Returns:
93
+ 2D np.array where each row contains the magnitudes of the fft_length/2+1
94
+ unique values of the FFT for the corresponding frame of input samples.
95
+ """
96
+ frames = frame(signal, window_length, hop_length)
97
+ # Apply frame window to each frame. We use a periodic Hann (cosine of period
98
+ # window_length) instead of the symmetric Hann of np.hanning (period
99
+ # window_length-1).
100
+ window = periodic_hann(window_length)
101
+ windowed_frames = frames * window
102
+ return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
103
+
104
+
105
+ # Mel spectrum constants and functions.
106
+ _MEL_BREAK_FREQUENCY_HERTZ = 700.0
107
+ _MEL_HIGH_FREQUENCY_Q = 1127.0
108
+
109
+
110
+ def hertz_to_mel(frequencies_hertz):
111
+ """Convert frequencies to mel scale using HTK formula.
112
+
113
+ Args:
114
+ frequencies_hertz: Scalar or np.array of frequencies in hertz.
115
+
116
+ Returns:
117
+ Object of same size as frequencies_hertz containing corresponding values
118
+ on the mel scale.
119
+ """
120
+ return _MEL_HIGH_FREQUENCY_Q * np.log(
121
+ 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
122
+
123
+
124
+ def spectrogram_to_mel_matrix(num_mel_bins=20,
125
+ num_spectrogram_bins=129,
126
+ audio_sample_rate=8000,
127
+ lower_edge_hertz=125.0,
128
+ upper_edge_hertz=3800.0):
129
+ """Return a matrix that can post-multiply spectrogram rows to make mel.
130
+
131
+ Returns a np.array matrix A that can be used to post-multiply a matrix S of
132
+ spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
133
+ "mel spectrogram" M of frames x num_mel_bins. M = S A.
134
+
135
+ The classic HTK algorithm exploits the complementarity of adjacent mel bands
136
+ to multiply each FFT bin by only one mel weight, then add it, with positive
137
+ and negative signs, to the two adjacent mel bands to which that bin
138
+ contributes. Here, by expressing this operation as a matrix multiply, we go
139
+ from num_fft multiplies per frame (plus around 2*num_fft adds) to around
140
+ num_fft^2 multiplies and adds. However, because these are all presumably
141
+ accomplished in a single call to np.dot(), it's not clear which approach is
142
+ faster in Python. The matrix multiplication has the attraction of being more
143
+ general and flexible, and much easier to read.
144
+
145
+ Args:
146
+ num_mel_bins: How many bands in the resulting mel spectrum. This is
147
+ the number of columns in the output matrix.
148
+ num_spectrogram_bins: How many bins there are in the source spectrogram
149
+ data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
150
+ only contains the nonredundant FFT bins.
151
+ audio_sample_rate: Samples per second of the audio at the input to the
152
+ spectrogram. We need this to figure out the actual frequencies for
153
+ each spectrogram bin, which dictates how they are mapped into mel.
154
+ lower_edge_hertz: Lower bound on the frequencies to be included in the mel
155
+ spectrum. This corresponds to the lower edge of the lowest triangular
156
+ band.
157
+ upper_edge_hertz: The desired top edge of the highest frequency band.
158
+
159
+ Returns:
160
+ An np.array with shape (num_spectrogram_bins, num_mel_bins).
161
+
162
+ Raises:
163
+ ValueError: if frequency edges are incorrectly ordered or out of range.
164
+ """
165
+ nyquist_hertz = audio_sample_rate / 2.
166
+ if lower_edge_hertz < 0.0:
167
+ raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
168
+ if lower_edge_hertz >= upper_edge_hertz:
169
+ raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
170
+ (lower_edge_hertz, upper_edge_hertz))
171
+ if upper_edge_hertz > nyquist_hertz:
172
+ raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
173
+ (upper_edge_hertz, nyquist_hertz))
174
+ spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
175
+ spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
176
+ # The i'th mel band (starting from i=1) has center frequency
177
+ # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
178
+ # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
179
+ # the band_edges_mel arrays.
180
+ band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
181
+ hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
182
+ # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
183
+ # of spectrogram values.
184
+ mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
185
+ for i in range(num_mel_bins):
186
+ lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
187
+ # Calculate lower and upper slopes for every spectrogram bin.
188
+ # Line segments are linear in the *mel* domain, not hertz.
189
+ lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
190
+ (center_mel - lower_edge_mel))
191
+ upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
192
+ (upper_edge_mel - center_mel))
193
+ # .. then intersect them with each other and zero.
194
+ mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
195
+ upper_slope))
196
+ # HTK excludes the spectrogram DC bin; make sure it always gets a zero
197
+ # coefficient.
198
+ mel_weights_matrix[0, :] = 0.0
199
+ return mel_weights_matrix
200
+
201
+
202
+ def log_mel_spectrogram(data,
203
+ audio_sample_rate=8000,
204
+ log_offset=0.0,
205
+ window_length_secs=0.025,
206
+ hop_length_secs=0.010,
207
+ **kwargs):
208
+ """Convert waveform to a log magnitude mel-frequency spectrogram.
209
+
210
+ Args:
211
+ data: 1D np.array of waveform data.
212
+ audio_sample_rate: The sampling rate of data.
213
+ log_offset: Add this to values when taking log to avoid -Infs.
214
+ window_length_secs: Duration of each window to analyze.
215
+ hop_length_secs: Advance between successive analysis windows.
216
+ **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
217
+
218
+ Returns:
219
+ 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
220
+ magnitudes for successive frames.
221
+ """
222
+ window_length_samples = int(round(audio_sample_rate * window_length_secs))
223
+ hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
224
+ fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
225
+ spectrogram = stft_magnitude(
226
+ data,
227
+ fft_length=fft_length,
228
+ hop_length=hop_length_samples,
229
+ window_length=window_length_samples)
230
+ mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
231
+ num_spectrogram_bins=spectrogram.shape[1],
232
+ audio_sample_rate=audio_sample_rate, **kwargs))
233
+ return np.log(mel_spectrogram + log_offset)
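A worked example of `log_mel_spectrogram` with VGGish-style settings on one second of a 440 Hz tone; the frame count follows the formula in `frame()`: 1 + floor((16000 - 400) / 160) = 98. It assumes this directory is on `sys.path` so the flat `import mel_features` used by the sibling scripts works.

```python
import numpy as np
import mel_features  # flat import, as in vggish_input.py

sr = 16000
t = np.arange(sr) / sr
tone = 0.5 * np.sin(2 * np.pi * 440.0 * t)

log_mel = mel_features.log_mel_spectrogram(
    tone, audio_sample_rate=sr, log_offset=0.01,
    window_length_secs=0.025, hop_length_secs=0.010, num_mel_bins=64)
print(log_mel.shape)  # (98, 64)
```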
avism/data/datasets/extract_audio_feat/vggish_input.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Compute input examples for VGGish from audio waveform."""
17
+
18
+ import numpy as np
19
+ import resampy
20
+ from scipy.io import wavfile
21
+
22
+ import mel_features
23
+ import vggish_params
24
+
25
+
26
+ def waveform_to_examples(data, sample_rate):
27
+ """Converts audio waveform into an array of examples for VGGish.
28
+
29
+ Args:
30
+ data: np.array of either one dimension (mono) or two dimensions
31
+ (multi-channel, with the outer dimension representing channels).
32
+ Each sample is generally expected to lie in the range [-1.0, +1.0],
33
+ although this is not required.
34
+ sample_rate: Sample rate of data.
35
+
36
+ Returns:
37
+ 3-D np.array of shape [num_examples, num_frames, num_bands] which represents
38
+ a sequence of examples, each of which contains a patch of log mel
39
+ spectrogram, covering num_frames frames of audio and num_bands mel frequency
40
+ bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
41
+ """
42
+ # Convert to mono.
43
+ if len(data.shape) > 1:
44
+ data = np.mean(data, axis=1)
45
+ # Resample to the rate assumed by VGGish.
46
+ if sample_rate != vggish_params.SAMPLE_RATE:
47
+ data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
48
+
49
+ # Compute log mel spectrogram features.
50
+ log_mel = mel_features.log_mel_spectrogram(
51
+ data,
52
+ audio_sample_rate=vggish_params.SAMPLE_RATE,
53
+ log_offset=vggish_params.LOG_OFFSET,
54
+ window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
55
+ hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
56
+ num_mel_bins=vggish_params.NUM_MEL_BINS,
57
+ lower_edge_hertz=vggish_params.MEL_MIN_HZ,
58
+ upper_edge_hertz=vggish_params.MEL_MAX_HZ)
59
+
60
+ # Frame features into examples.
61
+ features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
62
+ example_window_length = int(round(
63
+ vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
64
+ example_hop_length = int(round(
65
+ vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
66
+ log_mel_examples = mel_features.frame(
67
+ log_mel,
68
+ window_length=example_window_length,
69
+ hop_length=example_hop_length)
70
+ return log_mel_examples
71
+
72
+
73
+ def wavfile_to_examples(wav_file, num_secs):
74
+ """Convenience wrapper around waveform_to_examples() for a common WAV format.
75
+ Args:
76
+ wav_file: String path to a file, or a file-like object. The file
77
+ is assumed to contain WAV audio data with signed 16-bit PCM samples.
78
+
79
+ Returns:
80
+ See waveform_to_examples.
81
+ """
82
+ sr, snd = wavfile.read(wav_file)
83
+ L = sr * num_secs
84
+ wav_data = snd[:L]  # slice the first axis only, so mono (1-D) input also works
85
+ wav_data = wav_data / 32768.0 # Convert to [-1.0, +1.0]
86
+ T = num_secs
87
+ log_mel = np.zeros([T, 96, 64])
88
+
89
+ for i in range(T):
90
+ s = i * sr
91
+ e = (i + 1) * sr
92
+ if len(wav_data.shape) > 1:
93
+ data = wav_data[s:e, :]
94
+ else:
95
+ data = wav_data[s:e]
96
+
97
+ wave_data_array = waveform_to_examples(data, sr)
98
+ if len(wave_data_array) != 0:
99
+ log_mel[i, :, :] = wave_data_array
100
+ else:
101
+ log_mel[i, :, :] = np.zeros((1, 96, 64), dtype=float)
102
+
103
+ return log_mel
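A sketch of `waveform_to_examples` on three seconds of synthetic mono audio; per the docstring the result is `[num_examples, 96, 64]`, and with 0.96 s windows and hops that comes to one patch per second here.

```python
import numpy as np
import vggish_input  # assumes this directory is on sys.path

sr = 16000
data = np.random.uniform(-1.0, 1.0, size=3 * sr)   # 3 s of mono noise in [-1, 1]
examples = vggish_input.waveform_to_examples(data, sr)
print(examples.shape)  # (3, 96, 64)
```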
avism/data/datasets/extract_audio_feat/vggish_params.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Global parameters for the VGGish model.
17
+
18
+ See vggish_slim.py for more information.
19
+ """
20
+
21
+ # Architectural constants.
22
+ NUM_FRAMES = 96 # Frames in input mel-spectrogram patch.
23
+ NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
24
+ EMBEDDING_SIZE = 128 # Size of embedding layer.
25
+
26
+ # Hyperparameters used in feature and example generation.
27
+ SAMPLE_RATE = 16000
28
+ STFT_WINDOW_LENGTH_SECONDS = 0.025
29
+ STFT_HOP_LENGTH_SECONDS = 0.010
30
+ NUM_MEL_BINS = NUM_BANDS
31
+ MEL_MIN_HZ = 125
32
+ MEL_MAX_HZ = 7500
33
+ LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
34
+ EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
35
+ EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
36
+
37
+ # Parameters used for embedding postprocessing.
38
+ PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
39
+ PCA_MEANS_NAME = 'pca_means'
40
+ QUANTIZE_MIN_VAL = -2.0
41
+ QUANTIZE_MAX_VAL = +2.0
42
+
43
+ # Hyperparameters used in training.
44
+ INIT_STDDEV = 0.01 # Standard deviation used to initialize weights.
45
+ LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer.
46
+ ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer.
47
+
48
+ # Names of ops, tensors, and features.
49
+ INPUT_OP_NAME = 'vggish/input_features'
50
+ INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
51
+ OUTPUT_OP_NAME = 'vggish/embedding'
52
+ OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
53
+ AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
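A quick consistency check of the patch geometry implied by these constants: a 0.96 s example at a 10 ms STFT hop is exactly `NUM_FRAMES` frames, and `NUM_MEL_BINS` equals `NUM_BANDS`.

```python
import vggish_params as params  # assumes this directory is on sys.path

frames_per_example = int(round(
    params.EXAMPLE_WINDOW_SECONDS / params.STFT_HOP_LENGTH_SECONDS))
assert frames_per_example == params.NUM_FRAMES == 96
assert params.NUM_MEL_BINS == params.NUM_BANDS == 64
```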
avism/data/datasets/extract_audio_feat/vggish_slim.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Defines the 'VGGish' model used to generate AudioSet embedding features.
17
+
18
+ The public AudioSet release (https://research.google.com/audioset/download.html)
19
+ includes 128-D features extracted from the embedding layer of a VGG-like model
20
+ that was trained on a large Google-internal YouTube dataset. Here we provide
21
+ a TF-Slim definition of the same model, without any dependences on libraries
22
+ internal to Google. We call it 'VGGish'.
23
+
24
+ Note that we only define the model up to the embedding layer, which is the
25
+ penultimate layer before the final classifier layer. We also provide various
26
+ hyperparameter values (in vggish_params.py) that were used to train this model
27
+ internally.
28
+
29
+ For comparison, here is TF-Slim's VGG definition:
30
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py
31
+ """
32
+
33
+ import tensorflow._api.v2.compat.v1 as tf
34
+ tf.disable_v2_behavior()
35
+ import tf_slim as slim
36
+
37
+ import vggish_params as params
38
+
39
+
40
+ def define_vggish_slim(training=False):
41
+ """Defines the VGGish TensorFlow model.
42
+
43
+ All ops are created in the current default graph, under the scope 'vggish/'.
44
+
45
+ The input is a placeholder named 'vggish/input_features' of type float32 and
46
+ shape [batch_size, num_frames, num_bands] where batch_size is variable and
47
+ num_frames and num_bands are constants, and [num_frames, num_bands] represents
48
+ a log-mel-scale spectrogram patch covering num_bands frequency bands and
49
+ num_frames time frames (where each frame step is usually 10ms). This is
50
+ produced by computing the stabilized log(mel-spectrogram + params.LOG_OFFSET).
51
+ The output is an op named 'vggish/embedding' which produces the activations of
52
+ a 128-D embedding layer, which is usually the penultimate layer when used as
53
+ part of a full model with a final classifier layer.
54
+
55
+ Args:
56
+ training: If true, all parameters are marked trainable.
57
+
58
+ Returns:
59
+ The op 'vggish/embeddings'.
60
+ """
61
+ # Defaults:
62
+ # - All weights are initialized to N(0, INIT_STDDEV).
63
+ # - All biases are initialized to 0.
64
+ # - All activations are ReLU.
65
+ # - All convolutions are 3x3 with stride 1 and SAME padding.
66
+ # - All max-pools are 2x2 with stride 2 and SAME padding.
67
+ with slim.arg_scope([slim.conv2d, slim.fully_connected],
68
+ weights_initializer=tf.truncated_normal_initializer(
69
+ stddev=params.INIT_STDDEV),
70
+ biases_initializer=tf.zeros_initializer(),
71
+ activation_fn=tf.nn.relu,
72
+ trainable=training), \
73
+ slim.arg_scope([slim.conv2d],
74
+ kernel_size=[3, 3], stride=1, padding='SAME'), \
75
+ slim.arg_scope([slim.max_pool2d],
76
+ kernel_size=[2, 2], stride=2, padding='SAME'), \
77
+ tf.compat.v1.variable_scope('vggish'):
78
+ # tf.variable_scope('vggish'):
79
+ # Input: a batch of 2-D log-mel-spectrogram patches.
80
+ # features = tf.placeholder(
81
+ features = tf.compat.v1.placeholder(
82
+ tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS),
83
+ name='input_features')
84
+ # Reshape to 4-D so that we can convolve a batch with conv2d().
85
+ net = tf.reshape(features, [-1, params.NUM_FRAMES, params.NUM_BANDS, 1])
86
+
87
+ # The VGG stack of alternating convolutions and max-pools.
88
+ net = slim.conv2d(net, 64, scope='conv1')
89
+ net = slim.max_pool2d(net, scope='pool1')
90
+ net = slim.conv2d(net, 128, scope='conv2')
91
+ net = slim.max_pool2d(net, scope='pool2')
92
+ net = slim.repeat(net, 2, slim.conv2d, 256, scope='conv3')
93
+ net = slim.max_pool2d(net, scope='pool3')
94
+ net = slim.repeat(net, 2, slim.conv2d, 512, scope='conv4')
95
+ net = slim.max_pool2d(net, scope='pool4')
96
+
97
+ # Flatten before entering fully-connected layers
98
+ net = slim.flatten(net)
99
+ net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1')
100
+ # The embedding layer.
101
+ net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2')
102
+ return tf.identity(net, name='embedding')
103
+
104
+
105
+ def load_vggish_slim_checkpoint(session, checkpoint_path):
106
+ """Loads a pre-trained VGGish-compatible checkpoint.
107
+
108
+ This function can be used as an initialization function (referred to as
109
+ init_fn in TensorFlow documentation) which is called in a Session after
110
+ initializating all variables. When used as an init_fn, this will load
111
+ a pre-trained checkpoint that is compatible with the VGGish model
112
+ definition. Only variables defined by VGGish will be loaded.
113
+
114
+ Args:
115
+ session: an active TensorFlow session.
116
+ checkpoint_path: path to a file containing a checkpoint that is
117
+ compatible with the VGGish model definition.
118
+ """
119
+ # Get the list of names of all VGGish variables that exist in
120
+ # the checkpoint (i.e., all inference-mode VGGish variables).
121
+ with tf.Graph().as_default():
122
+ define_vggish_slim(training=False)
123
+ # vggish_var_names = [v.name for v in tf.global_variables()]
124
+ vggish_var_names = [v.name for v in tf.compat.v1.global_variables()]
125
+
126
+ # Get the list of all currently existing variables that match
127
+ # the list of variable names we just computed.
128
+ # vggish_vars = [v for v in tf.global_variables() if v.name in vggish_var_names]
129
+ vggish_vars = [v for v in tf.compat.v1.global_variables() if v.name in vggish_var_names]
130
+
131
+ # Use a Saver to restore just the variables selected above.
132
+ # saver = tf.train.Saver(vggish_vars, name='vggish_load_pretrained')
133
+ saver = tf.compat.v1.train.Saver(vggish_vars, name='vggish_load_pretrained')
134
+ saver.restore(session, checkpoint_path)
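A minimal inference sketch mirroring the pattern used in `audio_feature_extractor.py`: define the graph, restore the released checkpoint (path assumed), and run a random batch of log-mel patches through to get 128-D embeddings.

```python
import numpy as np
import tensorflow._api.v2.compat.v1 as tf
tf.disable_v2_behavior()

import vggish_params
import vggish_slim  # assumes this directory is on sys.path

batch = np.random.rand(5, vggish_params.NUM_FRAMES, vggish_params.NUM_BANDS)
with tf.Graph().as_default(), tf.Session() as sess:
    vggish_slim.define_vggish_slim(training=False)
    vggish_slim.load_vggish_slim_checkpoint(sess, './vggish_model.ckpt')
    features = sess.graph.get_tensor_by_name(vggish_params.INPUT_TENSOR_NAME)
    embedding = sess.graph.get_tensor_by_name(vggish_params.OUTPUT_TENSOR_NAME)
    [emb] = sess.run([embedding], feed_dict={features: batch})
print(emb.shape)  # (5, 128)
```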
avism/modeling/__init__.py ADDED
File without changes
avism/modeling/avism_criterion.py ADDED
@@ -0,0 +1,335 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ from detectron2.utils.comm import get_world_size
6
+ from detectron2.projects.point_rend.point_features import (
7
+ get_uncertain_point_coords_with_randomness,
8
+ point_sample,
9
+ )
10
+
11
+ from ..utils.misc import is_dist_avail_and_initialized
12
+
13
+
14
+ def dice_loss(
15
+ inputs: torch.Tensor,
16
+ targets: torch.Tensor,
17
+ num_masks: float,
18
+ ):
19
+ """
20
+ Compute the DICE loss, similar to generalized IOU for masks
21
+ Args:
22
+ inputs: A float tensor of arbitrary shape.
23
+ The predictions for each example.
24
+ targets: A float tensor with the same shape as inputs. Stores the binary
25
+ classification label for each element in inputs
26
+ (0 for the negative class and 1 for the positive class).
27
+ """
28
+ inputs = inputs.sigmoid()
29
+ inputs = inputs.flatten(1)
30
+ numerator = 2 * (inputs * targets).sum(-1)
31
+ denominator = inputs.sum(-1) + targets.sum(-1)
32
+ loss = 1 - (numerator + 1) / (denominator + 1)
33
+ return loss.sum() / num_masks
34
+
35
+
36
+ dice_loss_jit = torch.jit.script(
37
+ dice_loss
38
+ ) # type: torch.jit.ScriptModule
39
+
40
+
41
+ def sigmoid_ce_loss(
42
+ inputs: torch.Tensor,
43
+ targets: torch.Tensor,
44
+ num_masks: float,
45
+ ):
46
+ """
47
+ Args:
48
+ inputs: A float tensor of arbitrary shape.
49
+ The predictions for each example.
50
+ targets: A float tensor with the same shape as inputs. Stores the binary
51
+ classification label for each element in inputs
52
+ (0 for the negative class and 1 for the positive class).
53
+ Returns:
54
+ Loss tensor
55
+ """
56
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
57
+
58
+ return loss.mean(1).sum() / num_masks
59
+
60
+
61
+ sigmoid_ce_loss_jit = torch.jit.script(
62
+ sigmoid_ce_loss
63
+ ) # type: torch.jit.ScriptModule
64
+
65
+
66
+ def calculate_uncertainty(logits):
67
+ """
68
+ We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
69
+ foreground class in `classes`.
70
+ Args:
71
+ logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
72
+ class-agnostic, where R is the total number of predicted masks in all images and C is
73
+ the number of foreground classes. The values are logits.
74
+ Returns:
75
+ scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
76
+ the most uncertain locations having the highest uncertainty score.
77
+ """
78
+ assert logits.shape[1] == 1
79
+ gt_class_logits = logits.clone()
80
+ return -(torch.abs(gt_class_logits))
81
+
82
+
83
+ class AvismSetCriterion(nn.Module):
84
+ """This class computes the loss for DETR.
85
+ The process happens in two steps:
86
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
87
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
88
+ """
89
+
90
+ def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
91
+ num_points, oversample_ratio, importance_sample_ratio, sim_use_clip):
92
+ """Create the criterion.
93
+ Parameters:
94
+ num_classes: number of object categories, omitting the special no-object category
95
+ matcher: module able to compute a matching between targets and proposals
96
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
97
+ eos_coef: relative classification weight applied to the no-object category
98
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
99
+ """
100
+ super().__init__()
101
+ self.num_classes = num_classes
102
+ self.matcher = matcher
103
+ self.weight_dict = weight_dict
104
+ self.eos_coef = eos_coef
105
+ self.losses = losses
106
+ empty_weight = torch.ones(self.num_classes + 1)
107
+ empty_weight[-1] = self.eos_coef
108
+ self.register_buffer("empty_weight", empty_weight)
109
+
110
+ # pointwise mask loss parameters
111
+ self.num_points = num_points
112
+ self.oversample_ratio = oversample_ratio
113
+ self.importance_sample_ratio = importance_sample_ratio
114
+ self.sim_use_clip = sim_use_clip
115
+
116
+ def loss_labels(self, outputs, targets, indices, num_masks):
117
+ """Classification loss (NLL)
118
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
119
+ """
120
+ assert "pred_logits" in outputs
121
+ src_logits = outputs['pred_logits']
122
+ L, B, cQ, _ = src_logits.shape
123
+ src_logits = src_logits.reshape(L*B, cQ, self.num_classes+1)
124
+
125
+ idx = self._get_src_permutation_idx(indices)
126
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets * L, indices)])
127
+ target_classes = torch.full(
128
+ src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
129
+ )
130
+ target_classes[idx] = target_classes_o
131
+
132
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
133
+ losses = {'loss_avism_ce': loss_ce}
134
+
135
+ return losses
136
+
137
+ def loss_masks(self, outputs, targets, indices, num_masks):
138
+ """Compute the losses related to the masks: the focal loss and the dice loss.
139
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
140
+ """
141
+ assert "pred_masks" in outputs
142
+
143
+ idx = self._get_src_permutation_idx(indices)
144
+ src_masks = outputs["pred_masks"]
145
+ L, B, cQ, T, H, W = src_masks.shape
146
+ src_masks = src_masks.reshape(L*B, cQ, T, H, W)
147
+
148
+ src_masks = src_masks[idx] # Nt x T x Hp x Wp
149
+ target_masks = torch.cat([t['masks'][i] for t, (_, i) in zip(targets * L, indices)]).to(src_masks)
150
+ # Nt x T x Ht x Wt
151
+ src_masks = src_masks.flatten(0, 1)[:, None]
152
+ target_masks = target_masks.flatten(0, 1)[:, None]
153
+
154
+ with torch.no_grad():
155
+ # sample point_coords
156
+ point_coords = get_uncertain_point_coords_with_randomness(
157
+ src_masks,
158
+ lambda logits: calculate_uncertainty(logits),
159
+ self.num_points,
160
+ self.oversample_ratio,
161
+ self.importance_sample_ratio,
162
+ )
163
+ # get gt labels
164
+ point_labels = point_sample(
165
+ target_masks,
166
+ point_coords,
167
+ align_corners=False,
168
+ ).squeeze(1)
169
+
170
+ point_logits = point_sample(
171
+ src_masks,
172
+ point_coords,
173
+ align_corners=False,
174
+ ).squeeze(1)
175
+
176
+ # Nt*T, randN -> Nt, T*randN
177
+ point_logits = point_logits.view(len(idx[0]), T * self.num_points)
178
+ point_labels = point_labels.view(len(idx[0]), T * self.num_points)
179
+
180
+ losses = {
181
+ "loss_avism_mask": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
182
+ "loss_avism_dice": dice_loss_jit(point_logits, point_labels, num_masks),
183
+ }
184
+
185
+ del src_masks
186
+ del target_masks
187
+ return losses
188
+
189
+ def loss_fg_sim(
190
+ self, outputs, clip_targets, frame_targets,
191
+ clip_indices, frame_indices, num_masks, MULTIPLIER=1000
192
+ ):
193
+ total_src_q, total_tgt_ids, total_batch_idx = [], [], []
194
+
195
+ # Frame
196
+ src_fq = outputs["pred_fq_embed"] # L, B, T, fQ, C
197
+ # L = number of frame_decoder layers
198
+ L, B, T, fQ, C = src_fq.shape
199
+ src_fq = src_fq.flatten(0, 2) # LBT, fQ, C
200
+
201
+ frame_indices = sum(frame_indices, [])
202
+ frame_src_idx = self._get_src_permutation_idx(frame_indices) # len = LBT
203
+ src_fq = src_fq[frame_src_idx] # Nf, C
204
+ target_frame_ids = torch.cat(
205
+ [t["ids"][J] for t, (_, J) in zip(frame_targets * L, frame_indices)]
206
+ )
207
+ frame_batch_idx = torch.div(frame_src_idx[0].to(device=src_fq.device), T, rounding_mode="floor")
208
+ is_frame_valid = target_frame_ids != -1
209
+ target_frame_ids += frame_batch_idx * MULTIPLIER
210
+
211
+ total_src_q.append(src_fq[is_frame_valid])
212
+ total_tgt_ids.append(target_frame_ids[is_frame_valid])
213
+ total_batch_idx.append(frame_batch_idx[is_frame_valid])
214
+
215
+ # Clip
216
+ if self.sim_use_clip:
217
+ src_cq = outputs["pred_cq_embed"] # L, B, cQ, C
218
+ src_cq = src_cq.flatten(0, 1) # LB , cQ, C
219
+
220
+ clip_src_idx = self._get_src_permutation_idx(clip_indices) # len = LB
221
+ src_cq = src_cq[clip_src_idx] # Nc, C
222
+ target_clip_ids = torch.cat( # clip_ids' shape = (N, num_frames) -> (N,)
223
+ [t["ids"][J] for t, (_, J) in zip(clip_targets * L, clip_indices)]
224
+ ).amax(dim=1)
225
+ clip_batch_idx = clip_src_idx[0].to(device=src_fq.device)
226
+ is_clip_valid = target_clip_ids != -1
227
+ target_clip_ids += clip_batch_idx * MULTIPLIER
228
+
229
+ total_src_q.append(src_cq[is_clip_valid])
230
+ total_tgt_ids.append(target_clip_ids[is_clip_valid])
231
+ total_batch_idx.append(clip_batch_idx[is_clip_valid])
232
+
233
+ # Clip + Frame
234
+ total_src_q = torch.cat(total_src_q) # Nc+Nf, C
235
+ total_tgt_ids = torch.cat(total_tgt_ids) # Nc+Nf
236
+ total_batch_idx = torch.cat(total_batch_idx) # Nc+Nf
237
+
238
+ sim_pred_logits = torch.matmul(total_src_q, total_src_q.T) # Nc+Nf, Nc+Nf
239
+ sim_tgt = (total_tgt_ids[:, None] == total_tgt_ids[None]).float() # Nc+Nf, Nc+Nf
240
+
241
+ same_clip = (total_batch_idx[:, None] == total_batch_idx[None]).float()
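+ # Only pairs of embeddings coming from the same video contribute to the similarity loss.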
242
+ loss = F.binary_cross_entropy_with_logits(sim_pred_logits, sim_tgt, reduction='none')
243
+
244
+ loss = loss * same_clip
245
+ loss_clip_sim = loss.sum() / (same_clip.sum() + 1e-6)
246
+
247
+ return {"loss_clip_sim": loss_clip_sim}
248
+
249
+ def _get_src_permutation_idx(self, indices):
250
+ # permute predictions following indices
251
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
252
+ src_idx = torch.cat([src for (src, _) in indices])
253
+ return batch_idx, src_idx
254
+
255
+ def _get_tgt_permutation_idx(self, indices):
256
+ # permute targets following indices
257
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
258
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
259
+ return batch_idx, tgt_idx
260
+
261
+ def get_loss(
262
+ self, loss, outputs, clip_targets, frame_targets, clip_indices, frame_indices, num_masks
263
+ ):
264
+ loss_map = {
265
+ 'avism_labels': self.loss_labels,
266
+ 'avism_masks': self.loss_masks,
267
+ 'fg_sim': self.loss_fg_sim,
268
+ }
269
+ assert loss in loss_map, f"do you really want to compute {loss} loss?"
270
+ if loss == 'fg_sim':
271
+ return loss_map[loss](
272
+ outputs, clip_targets, frame_targets, clip_indices, frame_indices, num_masks
273
+ )
274
+ return loss_map[loss](outputs, clip_targets, clip_indices, num_masks)
275
+
276
+ def forward(self, outputs, clip_targets, frame_targets, frame_indices=None):
277
+ """This performs the loss computation.
278
+ Parameters:
279
+ outputs: dict of tensors, see the output specification of the model for the format
280
+ clip_targets / frame_targets: lists of dicts, such that len(clip_targets) == batch_size.
281
+ The expected keys in each dict depend on the losses applied; see each loss's doc.
282
+ """
283
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
284
+
285
+ # Retrieve the matching between the outputs of the last layer and the targets
286
+ clip_indices = self.matcher(outputs_without_aux, clip_targets)
287
+
288
+ # Compute the average number of target boxes across all nodes, for normalization purposes
289
+ num_masks = sum(len(t["labels"]) for t in clip_targets) * len(outputs_without_aux["pred_masks"])
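+ # Scaled by L (the leading dim of pred_masks) because every frame-decoder layer contributes its own matched set.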
290
+ num_masks = torch.as_tensor(
291
+ [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
292
+ )
293
+ if is_dist_avail_and_initialized():
294
+ torch.distributed.all_reduce(num_masks)
295
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
296
+
297
+ # Compute all the requested losses
298
+ losses = {}
299
+ for loss in self.losses:
300
+ losses.update(
301
+ self.get_loss(
302
+ loss, outputs, clip_targets, frame_targets, clip_indices, frame_indices, num_masks
303
+ )
304
+ )
305
+
306
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
307
+ if "aux_outputs" in outputs:
308
+ for i, aux_outputs in enumerate(outputs["aux_outputs"]):
309
+ clip_indices = self.matcher(aux_outputs, clip_targets)
310
+ for loss in self.losses:
311
+ if loss == "fg_sim":
312
+ continue
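+ # The frame/clip similarity loss is only computed on the final outputs, not on the auxiliary layer outputs.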
313
+ l_dict = self.get_loss(
314
+ loss, aux_outputs, clip_targets, frame_targets, clip_indices, frame_indices, num_masks
315
+ )
316
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
317
+ losses.update(l_dict)
318
+
319
+ return losses
320
+
321
+ def __repr__(self):
322
+ head = "Criterion " + self.__class__.__name__
323
+ body = [
324
+ "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
325
+ "losses: {}".format(self.losses),
326
+ "weight_dict: {}".format(self.weight_dict),
327
+ "num_classes: {}".format(self.num_classes),
328
+ "eos_coef: {}".format(self.eos_coef),
329
+ "num_points: {}".format(self.num_points),
330
+ "oversample_ratio: {}".format(self.oversample_ratio),
331
+ "importance_sample_ratio: {}".format(self.importance_sample_ratio),
332
+ ]
333
+ _repr_indent = 4
334
+ lines = [head] + [" " * _repr_indent + line for line in body]
335
+ return "\n".join(lines)
avism/modeling/avism_matcher.py ADDED
@@ -0,0 +1,194 @@
1
+ """
2
+ Modules to compute the matching cost and solve the corresponding LSAP.
3
+ """
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from scipy.optimize import linear_sum_assignment
7
+ from torch import nn
8
+ from torch.cuda.amp import autocast
9
+
10
+ from detectron2.projects.point_rend.point_features import point_sample
11
+
12
+
13
+ def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
14
+ """
15
+ Compute the DICE loss, similar to generalized IOU for masks
16
+ Args:
17
+ inputs: A float tensor of arbitrary shape.
18
+ The predictions for each example.
19
+ targets: A float tensor with the same shape as inputs. Stores the binary
20
+ classification label for each element in inputs
21
+ (0 for the negative class and 1 for the positive class).
22
+ """
23
+ inputs = inputs.sigmoid()
24
+ inputs = inputs.flatten(1)
25
+ numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
26
+ denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
27
+ loss = 1 - (numerator + 1) / (denominator + 1)
28
+ return loss
29
+
30
+
31
+ batch_dice_loss_jit = torch.jit.script(
32
+ batch_dice_loss
33
+ ) # type: torch.jit.ScriptModule
34
+
35
+
36
+ def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
37
+ """
38
+ Args:
39
+ inputs: A float tensor of arbitrary shape.
40
+ The predictions for each example.
41
+ targets: A float tensor with the same shape as inputs. Stores the binary
42
+ classification label for each element in inputs
43
+ (0 for the negative class and 1 for the positive class).
44
+ Returns:
45
+ Loss tensor
46
+ """
47
+ hw = inputs.shape[1]
48
+
49
+ pos = F.binary_cross_entropy_with_logits(
50
+ inputs, torch.ones_like(inputs), reduction="none"
51
+ )
52
+ neg = F.binary_cross_entropy_with_logits(
53
+ inputs, torch.zeros_like(inputs), reduction="none"
54
+ )
55
+
56
+ loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
57
+ "nc,mc->nm", neg, (1 - targets)
58
+ )
59
+
60
+ return loss / hw
61
+
62
+
63
+ batch_sigmoid_ce_loss_jit = torch.jit.script(
64
+ batch_sigmoid_ce_loss
65
+ ) # type: torch.jit.ScriptModule
66
+
67
+
68
+ class AvismHungarianMatcher(nn.Module):
69
+ """This class computes an assignment between the targets and the predictions of the network
70
+
71
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
72
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
73
+ while the others are un-matched (and thus treated as non-objects).
74
+ """
75
+
76
+ def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0):
77
+ """Creates the matcher
78
+
79
+ Params:
80
+ cost_class: This is the relative weight of the classification error in the matching cost
81
+ cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
82
+ cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
83
+ """
84
+ super().__init__()
85
+ self.cost_class = cost_class
86
+ self.cost_mask = cost_mask
87
+ self.cost_dice = cost_dice
88
+
89
+ assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs can't be 0"
90
+
91
+ self.num_points = num_points
92
+
93
+ @torch.no_grad()
94
+ def memory_efficient_forward(self, outputs, targets):
95
+ # We flatten to compute the cost matrices in a batch
96
+
97
+ # Here, "L" is the number of frame-level decoder layers.
98
+ out_prob = outputs["pred_logits"].softmax(-1) # L, B, cQ, K+1
99
+ out_mask = outputs["pred_masks"] # L, B, cQ, T, H, W
100
+
101
+ L, B, cQ, T, s_h, s_w = out_mask.shape
102
+
103
+ out_prob = out_prob.reshape(L*B, cQ, -1)
104
+ out_mask = out_mask.reshape(L*B, cQ, T, s_h, s_w)
105
+
106
+ # If target is [vid1, vid2, vid3],
107
+ # it now becomes [vid1, vid2, vid3, vid1, vid2, vid3, ...].
108
+ targets = targets * L
109
+
110
+ indices = []
111
+ for b in range(L*B):
112
+ b_out_prob = out_prob[b]
113
+ tgt_ids = targets[b]["labels"]
114
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
115
+ # but approximate it in 1 - proba[target class].
116
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
117
+ cost_class = -b_out_prob[:, tgt_ids]
118
+
119
+ b_out_mask = out_mask[b] # cQ x T x H_pred x W_pred
120
+ # gt masks are already padded when preparing target
121
+ tgt_mask = targets[b]["masks"].to(b_out_mask) # Nins x T x H_tgt x W_tgt
122
+
123
+ # out_mask = out_mask[:, None]
124
+ # tgt_mask = tgt_mask[:, None]
125
+ # all masks share the same set of points for efficient matching!
126
+ point_coords = torch.rand(1, self.num_points, 2, device=b_out_mask.device)
127
+ # get gt labels
128
+ tgt_mask = point_sample(
129
+ tgt_mask,
130
+ point_coords.repeat(tgt_mask.shape[0], 1, 1),
131
+ align_corners=False,
132
+ ).flatten(1)
133
+
134
+ b_out_mask = point_sample(
135
+ b_out_mask,
136
+ point_coords.repeat(b_out_mask.shape[0], 1, 1),
137
+ align_corners=False,
138
+ ).flatten(1)
139
+
140
+ with autocast(enabled=False):
141
+ b_out_mask = b_out_mask.float()
142
+ tgt_mask = tgt_mask.float()
143
+ # Compute the focal loss between masks
144
+ cost_mask = batch_sigmoid_ce_loss_jit(b_out_mask, tgt_mask)
145
+ # Compute the dice loss betwen masks
146
+ cost_dice = batch_dice_loss(b_out_mask, tgt_mask)
147
+
148
+ # Final cost matrix
149
+ C = (
150
+ self.cost_mask * cost_mask
151
+ + self.cost_class * cost_class
152
+ + self.cost_dice * cost_dice
153
+ )
154
+ C = C.reshape(cQ, -1).cpu()
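+ # linear_sum_assignment needs a dense CPU cost matrix: rows are queries, columns are ground-truth instances.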
155
+
156
+ indices.append(linear_sum_assignment(C))
157
+
158
+ return [
159
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
160
+ for i, j in indices
161
+ ]
162
+
163
+ @torch.no_grad()
164
+ def forward(self, outputs, targets):
165
+ """Performs the matching
166
+
167
+ Params:
168
+ outputs: This is a dict that contains at least these entries:
169
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
170
+ "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
171
+
172
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
173
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
174
+ objects in the target) containing the class labels
175
+ "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
176
+
177
+ Returns:
178
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
179
+ - index_i is the indices of the selected predictions (in order)
180
+ - index_j is the indices of the corresponding selected targets (in order)
181
+ For each batch element, it holds:
182
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
183
+ """
184
+ return self.memory_efficient_forward(outputs, targets)
185
+
186
+ def __repr__(self, _repr_indent=4):
187
+ head = "Matcher " + self.__class__.__name__
188
+ body = [
189
+ "cost_class: {}".format(self.cost_class),
190
+ "cost_mask: {}".format(self.cost_mask),
191
+ "cost_dice: {}".format(self.cost_dice),
192
+ ]
193
+ lines = [head] + [" " * _repr_indent + line for line in body]
194
+ return "\n".join(lines)
avism/modeling/transformer_decoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .avism_transformer_decoder import AVISMMultiScaleMaskedTransformerDecoder
avism/modeling/transformer_decoder/avism.py ADDED
@@ -0,0 +1,675 @@
1
+ from math import ceil
2
+ import fvcore.nn.weight_init as weight_init
3
+ from typing import Optional
4
+ import torch
5
+ from torch import nn, Tensor
6
+ from torch.nn import functional as F
7
+ import copy
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.layers import Conv2d
11
+
12
+
13
+ class SelfAttentionLayer(nn.Module):
14
+
15
+ def __init__(self, d_model, nhead, dropout=0.0,
16
+ activation="relu", normalize_before=False):
17
+ super().__init__()
18
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
19
+
20
+ self.norm = nn.LayerNorm(d_model)
21
+ self.dropout = nn.Dropout(dropout)
22
+
23
+ self.activation = _get_activation_fn(activation)
24
+ self.normalize_before = normalize_before
25
+
26
+ self._reset_parameters()
27
+
28
+ def _reset_parameters(self):
29
+ for p in self.parameters():
30
+ if p.dim() > 1:
31
+ nn.init.xavier_uniform_(p)
32
+
33
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
34
+ return tensor if pos is None else tensor + pos
35
+
36
+ def forward_post(self, tgt,
37
+ tgt_mask: Optional[Tensor] = None,
38
+ tgt_key_padding_mask: Optional[Tensor] = None,
39
+ query_pos: Optional[Tensor] = None):
40
+ q = k = self.with_pos_embed(tgt, query_pos)
41
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
42
+ key_padding_mask=tgt_key_padding_mask)[0]
43
+ tgt = tgt + self.dropout(tgt2)
44
+ tgt = self.norm(tgt)
45
+
46
+ return tgt
47
+
48
+ def forward_pre(self, tgt,
49
+ tgt_mask: Optional[Tensor] = None,
50
+ tgt_key_padding_mask: Optional[Tensor] = None,
51
+ query_pos: Optional[Tensor] = None):
52
+ tgt2 = self.norm(tgt)
53
+ q = k = self.with_pos_embed(tgt2, query_pos)
54
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
55
+ key_padding_mask=tgt_key_padding_mask)[0]
56
+ tgt = tgt + self.dropout(tgt2)
57
+
58
+ return tgt
59
+
60
+ def forward(self, tgt,
61
+ tgt_mask: Optional[Tensor] = None,
62
+ tgt_key_padding_mask: Optional[Tensor] = None,
63
+ query_pos: Optional[Tensor] = None):
64
+ if self.normalize_before:
65
+ return self.forward_pre(tgt, tgt_mask,
66
+ tgt_key_padding_mask, query_pos)
67
+ return self.forward_post(tgt, tgt_mask,
68
+ tgt_key_padding_mask, query_pos)
69
+
70
+
71
+ class CrossAttentionLayer(nn.Module):
72
+
73
+ def __init__(self, d_model, nhead, dropout=0.0,
74
+ activation="relu", normalize_before=False):
75
+ super().__init__()
76
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
77
+
78
+ self.norm = nn.LayerNorm(d_model)
79
+ self.dropout = nn.Dropout(dropout)
80
+
81
+ self.activation = _get_activation_fn(activation)
82
+ self.normalize_before = normalize_before
83
+
84
+ self._reset_parameters()
85
+
86
+ def _reset_parameters(self):
87
+ for p in self.parameters():
88
+ if p.dim() > 1:
89
+ nn.init.xavier_uniform_(p)
90
+
91
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
92
+ return tensor if pos is None else tensor + pos
93
+
94
+ def forward_post(self, tgt, memory,
95
+ memory_mask: Optional[Tensor] = None,
96
+ memory_key_padding_mask: Optional[Tensor] = None,
97
+ pos: Optional[Tensor] = None,
98
+ query_pos: Optional[Tensor] = None):
99
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
100
+ key=self.with_pos_embed(memory, pos),
101
+ value=memory, attn_mask=memory_mask,
102
+ key_padding_mask=memory_key_padding_mask)[0]
103
+ tgt = tgt + self.dropout(tgt2)
104
+ tgt = self.norm(tgt)
105
+
106
+ return tgt
107
+
108
+ def forward_pre(self, tgt, memory,
109
+ memory_mask: Optional[Tensor] = None,
110
+ memory_key_padding_mask: Optional[Tensor] = None,
111
+ pos: Optional[Tensor] = None,
112
+ query_pos: Optional[Tensor] = None):
113
+ tgt2 = self.norm(tgt)
114
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
115
+ key=self.with_pos_embed(memory, pos),
116
+ value=memory, attn_mask=memory_mask,
117
+ key_padding_mask=memory_key_padding_mask)[0]
118
+ tgt = tgt + self.dropout(tgt2)
119
+
120
+ return tgt
121
+
122
+ def forward(self, tgt, memory,
123
+ memory_mask: Optional[Tensor] = None,
124
+ memory_key_padding_mask: Optional[Tensor] = None,
125
+ pos: Optional[Tensor] = None,
126
+ query_pos: Optional[Tensor] = None):
127
+ if self.normalize_before:
128
+ return self.forward_pre(tgt, memory, memory_mask,
129
+ memory_key_padding_mask, pos, query_pos)
130
+ return self.forward_post(tgt, memory, memory_mask,
131
+ memory_key_padding_mask, pos, query_pos)
132
+
133
+
134
+ class FFNLayer(nn.Module):
135
+
136
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
137
+ activation="relu", normalize_before=False):
138
+ super().__init__()
139
+ # Implementation of Feedforward model
140
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
141
+ self.dropout = nn.Dropout(dropout)
142
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
143
+
144
+ self.norm = nn.LayerNorm(d_model)
145
+
146
+ self.activation = _get_activation_fn(activation)
147
+ self.normalize_before = normalize_before
148
+
149
+ self._reset_parameters()
150
+
151
+ def _reset_parameters(self):
152
+ for p in self.parameters():
153
+ if p.dim() > 1:
154
+ nn.init.xavier_uniform_(p)
155
+
156
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
157
+ return tensor if pos is None else tensor + pos
158
+
159
+ def forward_post(self, tgt):
160
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
161
+ tgt = tgt + self.dropout(tgt2)
162
+ tgt = self.norm(tgt)
163
+ return tgt
164
+
165
+ def forward_pre(self, tgt):
166
+ tgt2 = self.norm(tgt)
167
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
168
+ tgt = tgt + self.dropout(tgt2)
169
+ return tgt
170
+
171
+ def forward(self, tgt):
172
+ if self.normalize_before:
173
+ return self.forward_pre(tgt)
174
+ return self.forward_post(tgt)
175
+
176
+
177
+ def _get_activation_fn(activation):
178
+ """Return an activation function given a string"""
179
+ if activation == "relu":
180
+ return F.relu
181
+ if activation == "gelu":
182
+ return F.gelu
183
+ if activation == "glu":
184
+ return F.glu
185
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
186
+
187
+
188
+ class MLP(nn.Module):
189
+ """ Very simple multi-layer perceptron (also called FFN)"""
190
+
191
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
192
+ super().__init__()
193
+ self.num_layers = num_layers
194
+ h = [hidden_dim] * (num_layers - 1)
195
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
196
+
197
+ def forward(self, x):
198
+ for i, layer in enumerate(self.layers):
199
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
200
+ return x
201
+
202
+
203
+ def _get_clones(module, N):
204
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
205
+
206
+
207
+ class Avism(nn.Module):
208
+
209
+ @configurable
210
+ def __init__(
211
+ self,
212
+ in_channels,
213
+ aux_loss,
214
+ *,
215
+ hidden_dim: int,
216
+ num_frame_queries: int,
217
+ num_queries: int,
218
+ nheads: int,
219
+ dim_feedforward: int,
220
+ enc_layers: int,
221
+ dec_layers: int,
222
+ enc_window_size: int,
223
+ pre_norm: bool,
224
+ enforce_input_project: bool,
225
+ num_frames: int,
226
+ num_classes: int,
227
+ clip_last_layer_num: bool,
228
+ conv_dim: int,
229
+ mask_dim: int,
230
+ sim_use_clip: list,
231
+ use_sim: bool,
232
+ ):
233
+ """
234
+ NOTE: this interface is experimental.
235
+ Args:
236
+ in_channels: channels of the input features
237
+ hidden_dim: Transformer feature dimension
238
+ num_queries: number of queries
239
+ nheads: number of heads
240
+ dim_feedforward: feature dimension in feedforward network
241
+ enc_layers: number of Transformer encoder layers
242
+ dec_layers: number of Transformer decoder layers
243
+ pre_norm: whether to use pre-LayerNorm or not
244
+ enforce_input_project: add input project 1x1 conv even if input
245
+ channels and hidden dim is identical
246
+ """
247
+ super().__init__()
248
+
249
+ # define Transformer decoder here
250
+ self.num_heads = nheads
251
+ self.num_layers = dec_layers
252
+ self.transformer_self_attention_layers = nn.ModuleList()
253
+ self.transformer_cross_attention_layers = nn.ModuleList()
254
+ self.transformer_ffn_layers = nn.ModuleList()
255
+ self.num_frames = num_frames
256
+ self.num_classes = num_classes
257
+ self.clip_last_layer_num = clip_last_layer_num
258
+
259
+ self.enc_layers = enc_layers
260
+ self.window_size = enc_window_size
261
+ self.sim_use_clip = sim_use_clip
262
+ self.use_sim = use_sim
263
+ self.aux_loss = aux_loss
264
+
265
+ self.av_proj = nn.Linear(128, hidden_dim)
266
+
267
+ self.enc_layers = enc_layers
268
+ if enc_layers > 0:
269
+ self.enc_self_attn = nn.ModuleList()
270
+ self.enc_ffn = nn.ModuleList()
271
+ for _ in range(self.enc_layers):
272
+ self.enc_self_attn.append(
273
+ SelfAttentionLayer(
274
+ d_model=hidden_dim,
275
+ nhead=nheads,
276
+ dropout=0.0,
277
+ normalize_before=pre_norm,
278
+ ),
279
+ )
280
+ self.enc_ffn.append(
281
+ FFNLayer(
282
+ d_model=hidden_dim,
283
+ dim_feedforward=dim_feedforward,
284
+ dropout=0.0,
285
+ normalize_before=pre_norm,
286
+ )
287
+ )
288
+
289
+ if enc_layers > 0:
290
+ self.enc_av_cross_attn = nn.ModuleList()
291
+ self.enc_av_ffn = nn.ModuleList()
292
+ for _ in range(self.enc_layers):
293
+ self.enc_av_cross_attn.append(
294
+ CrossAttentionLayer(
295
+ d_model=hidden_dim,
296
+ nhead=nheads,
297
+ dropout=0.0,
298
+ normalize_before=pre_norm,
299
+ ),
300
+ )
301
+ self.enc_av_ffn.append(
302
+ FFNLayer(
303
+ d_model=hidden_dim,
304
+ dim_feedforward=dim_feedforward,
305
+ dropout=0.0,
306
+ normalize_before=pre_norm,
307
+ )
308
+ )
309
+
310
+ for _ in range(self.num_layers):
311
+ self.transformer_self_attention_layers.append(
312
+ SelfAttentionLayer(
313
+ d_model=hidden_dim,
314
+ nhead=nheads,
315
+ dropout=0.0,
316
+ normalize_before=pre_norm,
317
+ )
318
+ )
319
+
320
+ self.transformer_cross_attention_layers.append(
321
+ CrossAttentionLayer(
322
+ d_model=hidden_dim,
323
+ nhead=nheads,
324
+ dropout=0.0,
325
+ normalize_before=pre_norm,
326
+ )
327
+ )
328
+
329
+ self.transformer_ffn_layers.append(
330
+ FFNLayer(
331
+ d_model=hidden_dim,
332
+ dim_feedforward=dim_feedforward,
333
+ dropout=0.0,
334
+ normalize_before=pre_norm,
335
+ )
336
+ )
337
+
338
+ self.avism_mask_features = Conv2d(
339
+ conv_dim,
340
+ mask_dim,
341
+ kernel_size=1,
342
+ stride=1,
343
+ padding=0,
344
+ )
345
+ weight_init.c2_xavier_fill(self.avism_mask_features)
346
+
347
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
348
+
349
+ self.num_queries = num_queries
350
+ # learnable query features
351
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
352
+ # learnable query p.e.
353
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
354
+
355
+ self.fq_pos = nn.Embedding(num_frame_queries, hidden_dim)
356
+
357
+ if in_channels != hidden_dim or enforce_input_project:
358
+ self.input_proj_dec = nn.Linear(hidden_dim, hidden_dim)
359
+ else:
360
+ self.input_proj_dec = nn.Sequential()
361
+ self.src_embed = nn.Identity()
362
+
363
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
364
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
365
+ if self.use_sim:
366
+ self.sim_embed_frame = nn.Linear(hidden_dim, hidden_dim)
367
+ if self.sim_use_clip:
368
+ self.sim_embed_clip = nn.Linear(hidden_dim, hidden_dim)
369
+
370
+ @classmethod
371
+ def from_config(cls, cfg, in_channels):
372
+ ret = {}
373
+ ret["in_channels"] = in_channels
374
+
375
+ ret["hidden_dim"] = cfg.MODEL.AVISM.HIDDEN_DIM
376
+ ret["num_frame_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
377
+ ret["num_queries"] = cfg.MODEL.AVISM.NUM_OBJECT_QUERIES
378
+ # Transformer parameters:
379
+ ret["nheads"] = cfg.MODEL.AVISM.NHEADS
380
+ ret["dim_feedforward"] = cfg.MODEL.AVISM.DIM_FEEDFORWARD
381
+
382
+ assert cfg.MODEL.AVISM.DEC_LAYERS >= 1
383
+ ret["enc_layers"] = cfg.MODEL.AVISM.ENC_LAYERS
384
+ ret["dec_layers"] = cfg.MODEL.AVISM.DEC_LAYERS
385
+ ret["enc_window_size"] = cfg.MODEL.AVISM.ENC_WINDOW_SIZE
386
+ ret["pre_norm"] = cfg.MODEL.AVISM.PRE_NORM
387
+ ret["enforce_input_project"] = cfg.MODEL.AVISM.ENFORCE_INPUT_PROJ
388
+
389
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
390
+ ret["num_frames"] = cfg.INPUT.SAMPLING_FRAME_NUM
391
+ ret["clip_last_layer_num"] = cfg.MODEL.AVISM.LAST_LAYER_NUM
392
+
393
+ ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
394
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
395
+ ret["sim_use_clip"] = cfg.MODEL.AVISM.SIM_USE_CLIP
396
+ ret["use_sim"] = cfg.MODEL.AVISM.SIM_WEIGHT > 0.0
397
+
398
+ return ret
399
+
400
+ def forward(self, frame_query, audio_features):
401
+ """
402
+ L: Number of Layers.
403
+ B: Batch size.
404
+ T: Temporal window size. Number of frames per video.
405
+ C: Channel size.
406
+ fQ: Number of frame-wise queries from IFC.
407
+ cQ: Number of clip-wise queries to decode Q.
408
+ """
409
+ if not self.training:
410
+ frame_query = frame_query[[-1]]
411
+
412
+ L, BT, fQ, C = frame_query.shape
413
+ B = BT // self.num_frames if self.training else 1
414
+ T = self.num_frames if self.training else BT // B
415
+
416
+ frame_query = frame_query.reshape(L * B, T, fQ, C)
417
+ frame_query = frame_query.permute(1, 2, 0, 3).contiguous()
418
+ frame_query = self.input_proj_dec(frame_query) # T, fQ, LB, C
419
+
420
+ audio_feat = self.av_proj(audio_features) # T, C
421
+ audio_feat = audio_feat[:, None, None, :].repeat(1, fQ, L * B, 1)
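+ # Broadcast the per-frame audio embedding so every frame query sees the same audio vector: T, fQ, LB, C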
422
+
423
+ if self.window_size > 0:
424
+ pad = int(ceil(T / self.window_size)) * self.window_size - T
425
+ _T = pad + T
426
+ frame_query = F.pad(frame_query, (0, 0, 0, 0, 0, 0, 0, pad)) # _T, fQ, LB, C
427
+ audio_feat = F.pad(audio_feat, (0, 0, 0, 0, 0, 0, 0, pad))
428
+ enc_mask = frame_query.new_ones(L * B, _T).bool() # LB, _T
429
+ enc_mask[:, :T] = False
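+ # True marks padded frames (index >= T) that the encoder attention should ignore.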
430
+ else:
431
+ enc_mask = None
432
+
433
+ frame_query = self.encode_frame_query(frame_query, enc_mask)
434
+
435
+ # audio
436
+ av_feat = self.encode_av_fusion(frame_query, enc_mask, audio_feat)
437
+
438
+ frame_query = frame_query[:T].flatten(0, 1) # TfQ, LB, C
439
+ av_feat = av_feat[:T].flatten(0, 1)
440
+ frame_query = frame_query + av_feat
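+ # Residual fusion: the audio-conditioned features are added back onto the visual frame queries.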
441
+
442
+ if self.use_sim:
443
+ pred_fq_embed = self.sim_embed_frame(frame_query) # TfQ, LB, C
444
+ pred_fq_embed = pred_fq_embed.transpose(0, 1).reshape(L, B, T, fQ, C)
445
+ else:
446
+ pred_fq_embed = None
447
+
448
+ src = self.src_embed(frame_query) # TfQ, LB, C
449
+ dec_pos = self.fq_pos.weight[None, :, None, :].repeat(T, 1, L * B, 1).flatten(0, 1) # TfQ, LB, C
450
+
451
+ # QxNxC
452
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, L * B, 1) # cQ, LB, C
453
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, L * B, 1) # cQ, LB, C
454
+
455
+ decoder_outputs = []
456
+ for i in range(self.num_layers):
457
+ # attention: cross-attention first
458
+ output = self.transformer_cross_attention_layers[i](
459
+ output, src,
460
+ memory_mask=None,
461
+ memory_key_padding_mask=None,
462
+ pos=dec_pos, query_pos=query_embed
463
+ )
464
+
465
+ output = self.transformer_self_attention_layers[i](
466
+ output, tgt_mask=None,
467
+ tgt_key_padding_mask=None,
468
+ query_pos=query_embed
469
+ )
470
+
471
+ # FFN
472
+ output = self.transformer_ffn_layers[i](
473
+ output
474
+ )
475
+
476
+ if (self.training and self.aux_loss) or (i == self.num_layers - 1):
477
+ dec_out = self.decoder_norm(output) # cQ, LB, C
478
+ dec_out = dec_out.transpose(0, 1) # LB, cQ, C
479
+ decoder_outputs.append(dec_out.view(L, B, self.num_queries, C))
480
+
481
+ decoder_outputs = torch.stack(decoder_outputs, dim=0) # D, L, B, cQ, C
482
+
483
+ pred_cls = self.class_embed(decoder_outputs)
484
+ pred_mask_embed = self.mask_embed(decoder_outputs)
485
+ if self.use_sim and self.sim_use_clip:
486
+ pred_cq_embed = self.sim_embed_clip(decoder_outputs)
487
+ else:
488
+ pred_cq_embed = [None] * self.num_layers
489
+
490
+ out = {
491
+ 'pred_logits': pred_cls[-1],
492
+ 'pred_mask_embed': pred_mask_embed[-1],
493
+ 'pred_fq_embed': pred_fq_embed,
494
+ 'pred_cq_embed': pred_cq_embed[-1],
495
+ 'aux_outputs': self._set_aux_loss(
496
+ pred_cls, pred_mask_embed, pred_cq_embed, pred_fq_embed
497
+ )
498
+ }
499
+ return out
500
+
501
+ @torch.jit.unused
502
+ def _set_aux_loss(
503
+ self, outputs_cls, outputs_mask_embed, outputs_cq_embed, outputs_fq_embed
504
+ ):
505
+ return [{"pred_logits": a, "pred_mask_embed": b, "pred_cq_embed": c, "pred_fq_embed": outputs_fq_embed}
506
+ for a, b, c in zip(outputs_cls[:-1], outputs_mask_embed[:-1], outputs_cq_embed[:-1])]
507
+
508
+ def encode_frame_query(self, frame_query, attn_mask):
509
+ """
510
+ input shape (frame_query) : T, fQ, LB, C
511
+ output shape (frame_query) : T, fQ, LB, C
512
+ """
513
+
514
+ # Not using window-based attention if self.window_size == 0.
515
+ if self.window_size == 0:
516
+ return_shape = frame_query.shape # T, fQ, LB, C
517
+ frame_query = frame_query.flatten(0, 1) # TfQ, LB, C
518
+
519
+ for i in range(self.enc_layers):
520
+ frame_query = self.enc_self_attn[i](frame_query)
521
+ frame_query = self.enc_ffn[i](frame_query)
522
+
523
+ frame_query = frame_query.view(return_shape)
524
+ return frame_query
525
+ # Using window-based attention if self.window_size > 0.
526
+ else:
527
+ T, fQ, LB, C = frame_query.shape
528
+ W = self.window_size
529
+ Nw = T // W
530
+ half_W = int(ceil(W / 2))
531
+
532
+ window_mask = attn_mask.view(LB * Nw, W)[..., None].repeat(1, 1, fQ).flatten(1)
533
+
534
+ _attn_mask = torch.roll(attn_mask, half_W, 1)
535
+ _attn_mask = _attn_mask.view(LB, Nw, W)[..., None].repeat(1, 1, 1, W) # LB, Nw, W, W
536
+ _attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose(-2, -1)
537
+ _attn_mask[:, -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose(-2, -1)
538
+ _attn_mask[:, 0, :half_W, half_W:] = True
539
+ _attn_mask[:, 0, half_W:, :half_W] = True
540
+ _attn_mask = _attn_mask.view(LB * Nw, 1, W, 1, W, 1).repeat(1, self.num_heads, 1, fQ, 1, fQ).view(
541
+ LB * Nw * self.num_heads, W * fQ, W * fQ)
542
+ shift_window_mask = _attn_mask.float() * -1000
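+ # Additive mask with large negative values blocks attention across the shifted-window boundary.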
543
+
544
+ for layer_idx in range(self.enc_layers):
545
+ if self.training or layer_idx % 2 == 0:
546
+ frame_query = self._window_attn(frame_query, window_mask, layer_idx)
547
+ else:
548
+ frame_query = self._shift_window_attn(frame_query, shift_window_mask, layer_idx)
549
+ return frame_query
550
+
551
+ def _window_attn(self, frame_query, attn_mask, layer_idx):
552
+ T, fQ, LB, C = frame_query.shape
553
+ # LBN, WTfQ = attn_mask.shape
554
+
555
+ W = self.window_size
556
+ Nw = T // W
557
+
558
+ frame_query = frame_query.view(Nw, W, fQ, LB, C)
559
+ frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
560
+
561
+ frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_key_padding_mask=attn_mask)
562
+ frame_query = self.enc_ffn[layer_idx](frame_query)
563
+ frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
564
+
565
+ return frame_query
566
+
567
+ def _shift_window_attn(self, frame_query, attn_mask, layer_idx):
568
+ T, fQ, LB, C = frame_query.shape
569
+ # LBNH, WfQ, WfQ = attn_mask.shape
570
+
571
+ W = self.window_size
572
+ Nw = T // W
573
+ half_W = int(ceil(W / 2))
574
+
575
+ frame_query = torch.roll(frame_query, half_W, 0)
576
+ frame_query = frame_query.view(Nw, W, fQ, LB, C)
577
+ frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
578
+
579
+ frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_mask=attn_mask)
580
+ frame_query = self.enc_ffn[layer_idx](frame_query)
581
+ frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
582
+
583
+ frame_query = torch.roll(frame_query, -half_W, 0)
584
+
585
+ return frame_query
586
+
587
+ def encode_av_fusion(self, frame_query, attn_mask, audio_feats):
588
+ """
589
+ input shape (frame_query) : T, fQ, LB, C
590
+ output shape (audio_feats) : T, fQ, LB, C
591
+ """
592
+
593
+ # Not using window-based attention if self.window_size == 0.
594
+ if self.window_size == 0:
595
+ return_shape = frame_query.shape # T, fQ, LB, C
596
+ frame_query = frame_query.flatten(0, 1) # TfQ, LB, C
597
+ audio_feats = audio_feats.flatten(0, 1)
598
+
599
+ for i in range(self.enc_layers):
600
+ audio_feats = self.enc_av_cross_attn[i](audio_feats, frame_query)
601
+ audio_feats = self.enc_av_ffn[i](audio_feats)
602
+
603
+ audio_feats = audio_feats.view(return_shape)
604
+ return audio_feats
605
+ # Using window-based attention if self.window_size > 0.
606
+ else:
607
+ T, fQ, LB, C = frame_query.shape
608
+ W = self.window_size
609
+ Nw = T // W
610
+ half_W = int(ceil(W / 2))
611
+
612
+ window_mask = attn_mask.view(LB * Nw, W)[..., None].repeat(1, 1, fQ).flatten(1)
613
+
614
+ _attn_mask = torch.roll(attn_mask, half_W, 1)
615
+ _attn_mask = _attn_mask.view(LB, Nw, W)[..., None].repeat(1, 1, 1, W) # LB, Nw, W, W
616
+ _attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose(-2, -1)
617
+ _attn_mask[:, -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose(-2, -1)
618
+ _attn_mask[:, 0, :half_W, half_W:] = True
619
+ _attn_mask[:, 0, half_W:, :half_W] = True
620
+ _attn_mask = _attn_mask.view(LB * Nw, 1, W, 1, W, 1).repeat(1, self.num_heads, 1, fQ, 1, fQ).view(
621
+ LB * Nw * self.num_heads, W * fQ, W * fQ)
622
+ shift_window_mask = _attn_mask.float() * -1000
623
+
624
+ for layer_idx in range(self.enc_layers):
625
+ if layer_idx % 2 == 0:
626
+ frame_query, audio_feats = self._window_av_attn(frame_query, window_mask, layer_idx, audio_feats)
627
+ else:
628
+ frame_query, audio_feats = self._shift_window_av_attn(frame_query, shift_window_mask, layer_idx, audio_feats)
629
+ return audio_feats
630
+
631
+ def _window_av_attn(self, frame_query, attn_mask, layer_idx, audio_feats):
632
+ T, fQ, LB, C = frame_query.shape
633
+
634
+ W = self.window_size
635
+ Nw = T // W
636
+
637
+ frame_query = frame_query.view(Nw, W, fQ, LB, C)
638
+ frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
639
+
640
+ audio_feats = audio_feats.view(Nw, W, fQ, LB, C)
641
+ audio_feats = audio_feats.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
642
+
643
+ audio_feats = self.enc_av_cross_attn[layer_idx](audio_feats, frame_query, memory_key_padding_mask=attn_mask)
644
+ audio_feats = self.enc_av_ffn[layer_idx](audio_feats)
645
+
646
+ frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
647
+ audio_feats = audio_feats.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
648
+
649
+ return frame_query, audio_feats
650
+
651
+ def _shift_window_av_attn(self, frame_query, attn_mask, layer_idx, audio_feats):
652
+ T, fQ, LB, C = frame_query.shape
653
+
654
+ W = self.window_size
655
+ Nw = T // W
656
+ half_W = int(ceil(W / 2))
657
+
658
+ frame_query = torch.roll(frame_query, half_W, 0)
659
+ frame_query = frame_query.view(Nw, W, fQ, LB, C)
660
+ frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
661
+
662
+ audio_feats = torch.roll(audio_feats, half_W, 0)
663
+ audio_feats = audio_feats.view(Nw, W, fQ, LB, C)
664
+ audio_feats = audio_feats.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
665
+
666
+ audio_feats = self.enc_av_cross_attn[layer_idx](audio_feats, frame_query, memory_mask=attn_mask)
667
+ audio_feats = self.enc_av_ffn[layer_idx](audio_feats)
668
+
669
+ frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
670
+ frame_query = torch.roll(frame_query, -half_W, 0)
671
+
672
+ audio_feats = audio_feats.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
673
+ audio_feats = torch.roll(audio_feats, -half_W, 0)
674
+
675
+ return frame_query, audio_feats
avism/modeling/transformer_decoder/avism_coco.py ADDED
@@ -0,0 +1,675 @@
1
+ from math import ceil
2
+ import fvcore.nn.weight_init as weight_init
3
+ from typing import Optional
4
+ import torch
5
+ from torch import nn, Tensor
6
+ from torch.nn import functional as F
7
+ import copy
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.layers import Conv2d
11
+
12
+
13
+ class SelfAttentionLayer(nn.Module):
14
+
15
+ def __init__(self, d_model, nhead, dropout=0.0,
16
+ activation="relu", normalize_before=False):
17
+ super().__init__()
18
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
19
+
20
+ self.norm = nn.LayerNorm(d_model)
21
+ self.dropout = nn.Dropout(dropout)
22
+
23
+ self.activation = _get_activation_fn(activation)
24
+ self.normalize_before = normalize_before
25
+
26
+ self._reset_parameters()
27
+
28
+ def _reset_parameters(self):
29
+ for p in self.parameters():
30
+ if p.dim() > 1:
31
+ nn.init.xavier_uniform_(p)
32
+
33
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
34
+ return tensor if pos is None else tensor + pos
35
+
36
+ def forward_post(self, tgt,
37
+ tgt_mask: Optional[Tensor] = None,
38
+ tgt_key_padding_mask: Optional[Tensor] = None,
39
+ query_pos: Optional[Tensor] = None):
40
+ q = k = self.with_pos_embed(tgt, query_pos)
41
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
42
+ key_padding_mask=tgt_key_padding_mask)[0]
43
+ tgt = tgt + self.dropout(tgt2)
44
+ tgt = self.norm(tgt)
45
+
46
+ return tgt
47
+
48
+ def forward_pre(self, tgt,
49
+ tgt_mask: Optional[Tensor] = None,
50
+ tgt_key_padding_mask: Optional[Tensor] = None,
51
+ query_pos: Optional[Tensor] = None):
52
+ tgt2 = self.norm(tgt)
53
+ q = k = self.with_pos_embed(tgt2, query_pos)
54
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
55
+ key_padding_mask=tgt_key_padding_mask)[0]
56
+ tgt = tgt + self.dropout(tgt2)
57
+
58
+ return tgt
59
+
60
+ def forward(self, tgt,
61
+ tgt_mask: Optional[Tensor] = None,
62
+ tgt_key_padding_mask: Optional[Tensor] = None,
63
+ query_pos: Optional[Tensor] = None):
64
+ if self.normalize_before:
65
+ return self.forward_pre(tgt, tgt_mask,
66
+ tgt_key_padding_mask, query_pos)
67
+ return self.forward_post(tgt, tgt_mask,
68
+ tgt_key_padding_mask, query_pos)
69
+
70
+
71
+ class CrossAttentionLayer(nn.Module):
72
+
73
+ def __init__(self, d_model, nhead, dropout=0.0,
74
+ activation="relu", normalize_before=False):
75
+ super().__init__()
76
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
77
+
78
+ self.norm = nn.LayerNorm(d_model)
79
+ self.dropout = nn.Dropout(dropout)
80
+
81
+ self.activation = _get_activation_fn(activation)
82
+ self.normalize_before = normalize_before
83
+
84
+ self._reset_parameters()
85
+
86
+ def _reset_parameters(self):
87
+ for p in self.parameters():
88
+ if p.dim() > 1:
89
+ nn.init.xavier_uniform_(p)
90
+
91
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
92
+ return tensor if pos is None else tensor + pos
93
+
94
+ def forward_post(self, tgt, memory,
95
+ memory_mask: Optional[Tensor] = None,
96
+ memory_key_padding_mask: Optional[Tensor] = None,
97
+ pos: Optional[Tensor] = None,
98
+ query_pos: Optional[Tensor] = None):
99
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
100
+ key=self.with_pos_embed(memory, pos),
101
+ value=memory, attn_mask=memory_mask,
102
+ key_padding_mask=memory_key_padding_mask)[0]
103
+ tgt = tgt + self.dropout(tgt2)
104
+ tgt = self.norm(tgt)
105
+
106
+ return tgt
107
+
108
+ def forward_pre(self, tgt, memory,
109
+ memory_mask: Optional[Tensor] = None,
110
+ memory_key_padding_mask: Optional[Tensor] = None,
111
+ pos: Optional[Tensor] = None,
112
+ query_pos: Optional[Tensor] = None):
113
+ tgt2 = self.norm(tgt)
114
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
115
+ key=self.with_pos_embed(memory, pos),
116
+ value=memory, attn_mask=memory_mask,
117
+ key_padding_mask=memory_key_padding_mask)[0]
118
+ tgt = tgt + self.dropout(tgt2)
119
+
120
+ return tgt
121
+
122
+ def forward(self, tgt, memory,
123
+ memory_mask: Optional[Tensor] = None,
124
+ memory_key_padding_mask: Optional[Tensor] = None,
125
+ pos: Optional[Tensor] = None,
126
+ query_pos: Optional[Tensor] = None):
127
+ if self.normalize_before:
128
+ return self.forward_pre(tgt, memory, memory_mask,
129
+ memory_key_padding_mask, pos, query_pos)
130
+ return self.forward_post(tgt, memory, memory_mask,
131
+ memory_key_padding_mask, pos, query_pos)
132
+
133
+
134
+ class FFNLayer(nn.Module):
135
+
136
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
137
+ activation="relu", normalize_before=False):
138
+ super().__init__()
139
+ # Implementation of Feedforward model
140
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
141
+ self.dropout = nn.Dropout(dropout)
142
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
143
+
144
+ self.norm = nn.LayerNorm(d_model)
145
+
146
+ self.activation = _get_activation_fn(activation)
147
+ self.normalize_before = normalize_before
148
+
149
+ self._reset_parameters()
150
+
151
+ def _reset_parameters(self):
152
+ for p in self.parameters():
153
+ if p.dim() > 1:
154
+ nn.init.xavier_uniform_(p)
155
+
156
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
157
+ return tensor if pos is None else tensor + pos
158
+
159
+ def forward_post(self, tgt):
160
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
161
+ tgt = tgt + self.dropout(tgt2)
162
+ tgt = self.norm(tgt)
163
+ return tgt
164
+
165
+ def forward_pre(self, tgt):
166
+ tgt2 = self.norm(tgt)
167
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
168
+ tgt = tgt + self.dropout(tgt2)
169
+ return tgt
170
+
171
+ def forward(self, tgt):
172
+ if self.normalize_before:
173
+ return self.forward_pre(tgt)
174
+ return self.forward_post(tgt)
175
+
176
+
177
+ def _get_activation_fn(activation):
178
+ """Return an activation function given a string"""
179
+ if activation == "relu":
180
+ return F.relu
181
+ if activation == "gelu":
182
+ return F.gelu
183
+ if activation == "glu":
184
+ return F.glu
185
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
186
+
187
+
188
+ class MLP(nn.Module):
189
+ """ Very simple multi-layer perceptron (also called FFN)"""
190
+
191
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
192
+ super().__init__()
193
+ self.num_layers = num_layers
194
+ h = [hidden_dim] * (num_layers - 1)
195
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
196
+
197
+ def forward(self, x):
198
+ for i, layer in enumerate(self.layers):
199
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
200
+ return x
201
+
202
+
203
+ def _get_clones(module, N):
204
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
205
+
206
+
207
+ class Avism_COCO(nn.Module):
208
+
209
+ @configurable
210
+ def __init__(
211
+ self,
212
+ in_channels,
213
+ aux_loss,
214
+ *,
215
+ hidden_dim: int,
216
+ num_frame_queries: int,
217
+ num_queries: int,
218
+ nheads: int,
219
+ dim_feedforward: int,
220
+ enc_layers: int,
221
+ dec_layers: int,
222
+ enc_window_size: int,
223
+ pre_norm: bool,
224
+ enforce_input_project: bool,
225
+ num_frames: int,
226
+ num_classes: int,
227
+ clip_last_layer_num: bool,
228
+ conv_dim: int,
229
+ mask_dim: int,
230
+ sim_use_clip: list,
231
+ use_sim: bool,
232
+ ):
233
+ """
234
+ NOTE: this interface is experimental.
235
+ Args:
236
+ in_channels: channels of the input features
237
+ hidden_dim: Transformer feature dimension
238
+ num_queries: number of queries
239
+ nheads: number of heads
240
+ dim_feedforward: feature dimension in feedforward network
241
+ enc_layers: number of Transformer encoder layers
242
+ dec_layers: number of Transformer decoder layers
243
+ pre_norm: whether to use pre-LayerNorm or not
244
+ enforce_input_project: add input project 1x1 conv even if input
245
+ channels and hidden dim is identical
246
+ """
247
+ super().__init__()
248
+
249
+ # define Transformer decoder here
250
+ self.num_heads = nheads
251
+ self.num_layers = dec_layers
252
+ self.transformer_self_attention_layers = nn.ModuleList()
253
+ self.transformer_cross_attention_layers = nn.ModuleList()
254
+ self.transformer_ffn_layers = nn.ModuleList()
255
+ self.num_frames = num_frames
256
+ self.num_classes = num_classes
257
+ self.clip_last_layer_num = clip_last_layer_num
258
+
259
+ self.enc_layers = enc_layers
260
+ self.window_size = enc_window_size
261
+ self.sim_use_clip = sim_use_clip
262
+ self.use_sim = use_sim
263
+ self.aux_loss = aux_loss
264
+
265
+ self.av_proj = nn.Linear(128, hidden_dim)
266
+
267
+ self.enc_layers = enc_layers
268
+ if enc_layers > 0:
269
+ self.enc_self_attn = nn.ModuleList()
270
+ self.enc_ffn = nn.ModuleList()
271
+ for _ in range(self.enc_layers):
272
+ self.enc_self_attn.append(
273
+ SelfAttentionLayer(
274
+ d_model=hidden_dim,
275
+ nhead=nheads,
276
+ dropout=0.0,
277
+ normalize_before=pre_norm,
278
+ ),
279
+ )
280
+ self.enc_ffn.append(
281
+ FFNLayer(
282
+ d_model=hidden_dim,
283
+ dim_feedforward=dim_feedforward,
284
+ dropout=0.0,
285
+ normalize_before=pre_norm,
286
+ )
287
+ )
288
+
289
+ if enc_layers > 0:
290
+ self.enc_av_cross_attn = nn.ModuleList()
291
+ self.enc_av_ffn = nn.ModuleList()
292
+ for _ in range(self.enc_layers):
293
+ self.enc_av_cross_attn.append(
294
+ CrossAttentionLayer(
295
+ d_model=hidden_dim,
296
+ nhead=nheads,
297
+ dropout=0.0,
298
+ normalize_before=pre_norm,
299
+ ),
300
+ )
301
+ self.enc_av_ffn.append(
302
+ FFNLayer(
303
+ d_model=hidden_dim,
304
+ dim_feedforward=dim_feedforward,
305
+ dropout=0.0,
306
+ normalize_before=pre_norm,
307
+ )
308
+ )
309
+
310
+ for _ in range(self.num_layers):
311
+ self.transformer_self_attention_layers.append(
312
+ SelfAttentionLayer(
313
+ d_model=hidden_dim,
314
+ nhead=nheads,
315
+ dropout=0.0,
316
+ normalize_before=pre_norm,
317
+ )
318
+ )
319
+
320
+ self.transformer_cross_attention_layers.append(
321
+ CrossAttentionLayer(
322
+ d_model=hidden_dim,
323
+ nhead=nheads,
324
+ dropout=0.0,
325
+ normalize_before=pre_norm,
326
+ )
327
+ )
328
+
329
+ self.transformer_ffn_layers.append(
330
+ FFNLayer(
331
+ d_model=hidden_dim,
332
+ dim_feedforward=dim_feedforward,
333
+ dropout=0.0,
334
+ normalize_before=pre_norm,
335
+ )
336
+ )
337
+
338
+ self.vita_mask_features = Conv2d(
339
+ conv_dim,
340
+ mask_dim,
341
+ kernel_size=1,
342
+ stride=1,
343
+ padding=0,
344
+ )
345
+ weight_init.c2_xavier_fill(self.vita_mask_features)
346
+
347
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
348
+
349
+ self.num_queries = num_queries
350
+ # learnable query features
351
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
352
+ # learnable query p.e.
353
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
354
+
355
+ self.fq_pos = nn.Embedding(num_frame_queries, hidden_dim)
356
+
357
+ if in_channels != hidden_dim or enforce_input_project:
358
+ self.input_proj_dec = nn.Linear(hidden_dim, hidden_dim)
359
+ else:
360
+ self.input_proj_dec = nn.Sequential()
361
+ self.src_embed = nn.Identity()
362
+
363
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
364
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
365
+ if self.use_sim:
366
+ self.sim_embed_frame = nn.Linear(hidden_dim, hidden_dim)
367
+ if self.sim_use_clip:
368
+ self.sim_embed_clip = nn.Linear(hidden_dim, hidden_dim)
369
+
370
+ @classmethod
371
+ def from_config(cls, cfg, in_channels):
372
+ ret = {}
373
+ ret["in_channels"] = in_channels
374
+
375
+ ret["hidden_dim"] = cfg.MODEL.AVISM.HIDDEN_DIM
376
+ ret["num_frame_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
377
+ ret["num_queries"] = cfg.MODEL.AVISM.NUM_OBJECT_QUERIES
378
+ # Transformer parameters:
379
+ ret["nheads"] = cfg.MODEL.AVISM.NHEADS
380
+ ret["dim_feedforward"] = cfg.MODEL.AVISM.DIM_FEEDFORWARD
381
+
382
+ assert cfg.MODEL.AVISM.DEC_LAYERS >= 1
383
+ ret["enc_layers"] = cfg.MODEL.AVISM.ENC_LAYERS
384
+ ret["dec_layers"] = cfg.MODEL.AVISM.DEC_LAYERS
385
+ ret["enc_window_size"] = cfg.MODEL.AVISM.ENC_WINDOW_SIZE
386
+ ret["pre_norm"] = cfg.MODEL.AVISM.PRE_NORM
387
+ ret["enforce_input_project"] = cfg.MODEL.AVISM.ENFORCE_INPUT_PROJ
388
+
389
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
390
+ ret["num_frames"] = cfg.INPUT.SAMPLING_FRAME_NUM
391
+ ret["clip_last_layer_num"] = cfg.MODEL.AVISM.LAST_LAYER_NUM
392
+
393
+ ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
394
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
395
+ ret["sim_use_clip"] = cfg.MODEL.AVISM.SIM_USE_CLIP
396
+ ret["use_sim"] = cfg.MODEL.AVISM.SIM_WEIGHT > 0.0
397
+
398
+ return ret
399
+
400
+ def forward(self, frame_query, audio_features):
401
+ """
402
+ L: Number of Layers.
403
+ B: Batch size.
404
+ T: Temporal window size. Number of frames per video.
405
+ C: Channel size.
406
+ fQ: Number of frame-wise queries from IFC.
407
+ cQ: Number of clip-wise queries to decode Q.
408
+ """
409
+ if not self.training:
410
+ frame_query = frame_query[[-1]]
411
+
412
+ L, BT, fQ, C = frame_query.shape
413
+ B = BT // self.num_frames if self.training else 1
414
+ T = self.num_frames if self.training else BT // B
415
+
416
+ frame_query = frame_query.reshape(L * B, T, fQ, C)
417
+ frame_query = frame_query.permute(1, 2, 0, 3).contiguous()
418
+ frame_query = self.input_proj_dec(frame_query) # T, fQ, LB, C
419
+
420
+ audio_feat = self.av_proj(audio_features) # T, C
421
+ audio_feat = audio_feat[:, None, None, :].repeat(1, fQ, L * B, 1)
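+ # Broadcast the per-frame audio embedding across frame queries: T, fQ, LB, C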
422
+
423
+ if self.window_size > 0:
424
+ pad = int(ceil(T / self.window_size)) * self.window_size - T
425
+ _T = pad + T
426
+ frame_query = F.pad(frame_query, (0, 0, 0, 0, 0, 0, 0, pad)) # _T, fQ, LB, C
427
+ audio_feat = F.pad(audio_feat, (0, 0, 0, 0, 0, 0, 0, pad))
428
+ enc_mask = frame_query.new_ones(L * B, _T).bool() # LB, _T
429
+ enc_mask[:, :T] = False
430
+ else:
431
+ enc_mask = None
432
+
433
+ frame_query = self.encode_frame_query(frame_query, enc_mask)
434
+
435
+ # audio
436
+ av_feat = self.encode_av_fusion(frame_query, enc_mask, audio_feat)
437
+
438
+ frame_query = frame_query[:T].flatten(0, 1) # TfQ, LB, C
439
+ av_feat = av_feat[:T].flatten(0, 1)
440
+ frame_query = frame_query + av_feat
441
+
442
+ if self.use_sim:
443
+ pred_fq_embed = self.sim_embed_frame(frame_query) # TfQ, LB, C
444
+ pred_fq_embed = pred_fq_embed.transpose(0, 1).reshape(L, B, T, fQ, C)
445
+ else:
446
+ pred_fq_embed = None
447
+
448
+ src = self.src_embed(frame_query) # TfQ, LB, C
449
+ dec_pos = self.fq_pos.weight[None, :, None, :].repeat(T, 1, L * B, 1).flatten(0, 1) # TfQ, LB, C
450
+
451
+ # QxNxC
452
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, L * B, 1) # cQ, LB, C
453
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, L * B, 1) # cQ, LB, C
454
+
455
+ decoder_outputs = []
456
+ for i in range(self.num_layers):
457
+ # attention: cross-attention first
458
+ output = self.transformer_cross_attention_layers[i](
459
+ output, src,
460
+ memory_mask=None,
461
+ memory_key_padding_mask=None,
462
+ pos=dec_pos, query_pos=query_embed
463
+ )
464
+
465
+ output = self.transformer_self_attention_layers[i](
466
+ output, tgt_mask=None,
467
+ tgt_key_padding_mask=None,
468
+ query_pos=query_embed
469
+ )
470
+
471
+ # FFN
472
+ output = self.transformer_ffn_layers[i](
473
+ output
474
+ )
475
+
476
+ if (self.training and self.aux_loss) or (i == self.num_layers - 1):
477
+ dec_out = self.decoder_norm(output) # cQ, LB, C
478
+ dec_out = dec_out.transpose(0, 1) # LB, cQ, C
479
+ decoder_outputs.append(dec_out.view(L, B, self.num_queries, C))
480
+
481
+ decoder_outputs = torch.stack(decoder_outputs, dim=0) # D, L, B, cQ, C
482
+
483
+ pred_cls = self.class_embed(decoder_outputs)
484
+ pred_mask_embed = self.mask_embed(decoder_outputs)
485
+ if self.use_sim and self.sim_use_clip:
486
+ pred_cq_embed = self.sim_embed_clip(decoder_outputs)
487
+ else:
488
+ pred_cq_embed = [None] * self.num_layers
489
+
490
+        out = {
+            'pred_logits': pred_cls[-1],
+            'pred_mask_embed': pred_mask_embed[-1],
+            'pred_fq_embed': pred_fq_embed,
+            'pred_cq_embed': pred_cq_embed[-1],
+            'aux_outputs': self._set_aux_loss(
+                pred_cls, pred_mask_embed, pred_cq_embed, pred_fq_embed
+            )
+        }
+        return out
+
+    @torch.jit.unused
+    def _set_aux_loss(
+        self, outputs_cls, outputs_mask_embed, outputs_cq_embed, outputs_fq_embed
+    ):
+        return [{"pred_logits": a, "pred_mask_embed": b, "pred_cq_embed": c, "pred_fq_embed": outputs_fq_embed}
+                for a, b, c in zip(outputs_cls[:-1], outputs_mask_embed[:-1], outputs_cq_embed[:-1])]
+
+    def encode_frame_query(self, frame_query, attn_mask):
+        """
+        input shape (frame_query) : T, fQ, LB, C
+        output shape (frame_query) : T, fQ, LB, C
+        """
+
+        # No window-based attention when self.window_size == 0: attend over all frames at once.
+        if self.window_size == 0:
+            return_shape = frame_query.shape  # T, fQ, LB, C
+            frame_query = frame_query.flatten(0, 1)  # TfQ, LB, C
+
+            for i in range(self.enc_layers):
+                frame_query = self.enc_self_attn[i](frame_query)
+                frame_query = self.enc_ffn[i](frame_query)
+
+            frame_query = frame_query.view(return_shape)
+            return frame_query
+        # Window-based attention when self.window_size > 0.
+        else:
+            T, fQ, LB, C = frame_query.shape
+            W = self.window_size
+            Nw = T // W
+            half_W = int(ceil(W / 2))
+
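+            # window_mask: key padding mask over the W * fQ tokens of each of the LB * Nw windows.
+            # shift_window_mask (built below): additive attention mask for the shifted windows,
+            # which masks padded frames and blocks attention across the wrap-around boundary
+            # introduced by torch.roll.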
+            window_mask = attn_mask.view(LB * Nw, W)[..., None].repeat(1, 1, fQ).flatten(1)
+
+            _attn_mask = torch.roll(attn_mask, half_W, 1)
+            _attn_mask = _attn_mask.view(LB, Nw, W)[..., None].repeat(1, 1, 1, W)  # LB, Nw, W, W
+            _attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose(-2, -1)
+            _attn_mask[:, -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose(-2, -1)
+            _attn_mask[:, 0, :half_W, half_W:] = True
+            _attn_mask[:, 0, half_W:, :half_W] = True
+            _attn_mask = _attn_mask.view(LB * Nw, 1, W, 1, W, 1).repeat(1, self.num_heads, 1, fQ, 1, fQ).view(
+                LB * Nw * self.num_heads, W * fQ, W * fQ)
+            shift_window_mask = _attn_mask.float() * -1000
+
+            for layer_idx in range(self.enc_layers):
+                if self.training or layer_idx % 2 == 0:
+                    frame_query = self._window_attn(frame_query, window_mask, layer_idx)
+                else:
+                    frame_query = self._shift_window_attn(frame_query, shift_window_mask, layer_idx)
+            return frame_query
+
+    def _window_attn(self, frame_query, attn_mask, layer_idx):
+        T, fQ, LB, C = frame_query.shape
+        # LBNw, WfQ = attn_mask.shape
+
+        W = self.window_size
+        Nw = T // W
+
+        frame_query = frame_query.view(Nw, W, fQ, LB, C)
+        frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
+
+        frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_key_padding_mask=attn_mask)
+        frame_query = self.enc_ffn[layer_idx](frame_query)
+        frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
+
+        return frame_query
+
+    def _shift_window_attn(self, frame_query, attn_mask, layer_idx):
+        T, fQ, LB, C = frame_query.shape
+        # LBNwH, WfQ, WfQ = attn_mask.shape
+
+        W = self.window_size
+        Nw = T // W
+        half_W = int(ceil(W / 2))
+
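+        # Shift the frames by half a window before windowing (shifted-window attention),
+        # attend with the precomputed additive mask, then shift back afterwards.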
+        frame_query = torch.roll(frame_query, half_W, 0)
+        frame_query = frame_query.view(Nw, W, fQ, LB, C)
+        frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
+
+        frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_mask=attn_mask)
+        frame_query = self.enc_ffn[layer_idx](frame_query)
+        frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
+
+        frame_query = torch.roll(frame_query, -half_W, 0)
+
+        return frame_query
+
+    def encode_av_fusion(self, frame_query, attn_mask, audio_feats):
+        """
+        input shape (frame_query, audio_feats) : T, fQ, LB, C
+        output shape (audio_feats) : T, fQ, LB, C
+        """
+
+        # No window-based attention when self.window_size == 0: audio features cross-attend
+        # to all frame queries at once.
+        if self.window_size == 0:
+            return_shape = frame_query.shape  # T, fQ, LB, C
+            frame_query = frame_query.flatten(0, 1)  # TfQ, LB, C
+            audio_feats = audio_feats.flatten(0, 1)
+
+            for i in range(self.enc_layers):
+                audio_feats = self.enc_av_cross_attn[i](audio_feats, frame_query)
+                audio_feats = self.enc_av_ffn[i](audio_feats)
+
+            audio_feats = audio_feats.view(return_shape)
+            return audio_feats
+        # Window-based attention when self.window_size > 0, using the same window and
+        # shifted-window masks as encode_frame_query.
+        else:
+            T, fQ, LB, C = frame_query.shape
+            W = self.window_size
+            Nw = T // W
+            half_W = int(ceil(W / 2))
+
+            window_mask = attn_mask.view(LB * Nw, W)[..., None].repeat(1, 1, fQ).flatten(1)
+
+            _attn_mask = torch.roll(attn_mask, half_W, 1)
+            _attn_mask = _attn_mask.view(LB, Nw, W)[..., None].repeat(1, 1, 1, W)  # LB, Nw, W, W
+            _attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose(-2, -1)
+            _attn_mask[:, -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose(-2, -1)
+            _attn_mask[:, 0, :half_W, half_W:] = True
+            _attn_mask[:, 0, half_W:, :half_W] = True
+            _attn_mask = _attn_mask.view(LB * Nw, 1, W, 1, W, 1).repeat(1, self.num_heads, 1, fQ, 1, fQ).view(
+                LB * Nw * self.num_heads, W * fQ, W * fQ)
+            shift_window_mask = _attn_mask.float() * -1000
+
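+            # Alternate windowed and shifted-window audio-visual cross-attention across layers.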
+            for layer_idx in range(self.enc_layers):
+                if layer_idx % 2 == 0:
+                    frame_query, audio_feats = self._window_av_attn(frame_query, window_mask, layer_idx, audio_feats)
+                else:
+                    frame_query, audio_feats = self._shift_window_av_attn(frame_query, shift_window_mask, layer_idx, audio_feats)
+            return audio_feats
+
+    def _window_av_attn(self, frame_query, attn_mask, layer_idx, audio_feats):
+        T, fQ, LB, C = frame_query.shape
+
+        W = self.window_size
+        Nw = T // W
+
+        frame_query = frame_query.view(Nw, W, fQ, LB, C)
+        frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
+
+        audio_feats = audio_feats.view(Nw, W, fQ, LB, C)
+        audio_feats = audio_feats.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
+
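+        # Within each window, audio features (queries) cross-attend to the frame queries
+        # (keys/values); padded frames are excluded via the key padding mask.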
+        audio_feats = self.enc_av_cross_attn[layer_idx](audio_feats, frame_query, memory_key_padding_mask=attn_mask)
+        audio_feats = self.enc_av_ffn[layer_idx](audio_feats)
+
+        frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
+        audio_feats = audio_feats.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
+
+        return frame_query, audio_feats
+
+    def _shift_window_av_attn(self, frame_query, attn_mask, layer_idx, audio_feats):
+        T, fQ, LB, C = frame_query.shape
+
+        W = self.window_size
+        Nw = T // W
+        half_W = int(ceil(W / 2))
+
+        frame_query = torch.roll(frame_query, half_W, 0)
+        frame_query = frame_query.view(Nw, W, fQ, LB, C)
+        frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
+
+        audio_feats = torch.roll(audio_feats, half_W, 0)
+        audio_feats = audio_feats.view(Nw, W, fQ, LB, C)
+        audio_feats = audio_feats.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)
+
+        audio_feats = self.enc_av_cross_attn[layer_idx](audio_feats, frame_query, memory_mask=attn_mask)
+        audio_feats = self.enc_av_ffn[layer_idx](audio_feats)
+
+        frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
+        frame_query = torch.roll(frame_query, -half_W, 0)
+
+        audio_feats = audio_feats.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)
+        audio_feats = torch.roll(audio_feats, -half_W, 0)
+
+        return frame_query, audio_feats