Commit 4f09ecf
derektan committed
1 Parent(s): 66c5745
Init new app to handle planning. Fresh import from 27fe831777c12b25e504dd14e5b661742bdecce6 from VLM-Search
- .gitignore +146 -1
- Maps/flair_real_maps/envs_val_trained_clss/MSK_000310_great_horned_owl.png +3 -0
- Maps/flair_real_maps/masks_val_trained_clss/MSK_000310_great_horned_owl.png +3 -0
- Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_trained_clss_v3/MSK_000310_great_horned_owl.png +3 -0
- Taxabind/Taxabind/SatBind/clip_seg_tta.py +933 -0
- Taxabind/Taxabind/SatBind/config.py +61 -0
- Taxabind/Taxabind/SatBind/dataset.py +396 -0
- Taxabind/Taxabind/SatBind/kmeans_clustering.py +620 -0
- Taxabind/Taxabind/SatBind/model.py +226 -0
- Taxabind/Taxabind/SatBind/watershed_segmentation.py +1094 -0
- Taxabind/Taxabind/SoundBind/config.py +30 -0
- Taxabind/Taxabind/SoundBind/dataloader.py +57 -0
- Taxabind/Taxabind/SoundBind/model.py +141 -0
- Taxabind/Taxabind/SoundBind/sound_encoder.py +45 -0
- app.py +119 -260
- app_BACKUP.py +375 -0
- env.py +811 -0
- graph.py +179 -0
- graph_generator.py +330 -0
- inference/model/STAGE1_vlm_search_24x24_040425_no_tgt_rewards_iNAT_DS_16k.pth +3 -0
- model.py +319 -0
- node.py +102 -0
- robot.py +60 -0
- sensor.py +128 -0
- test_multi_robot_worker.py +727 -0
- test_parameter.py +183 -0
.gitignore
CHANGED
@@ -1 +1,146 @@
-
+# Additional Stuff
+!train/
+train/*
+!train/saved
+
+!model/
+model/*
+!model/saved
+
+!inference/
+inference/*
+!inference/saved
+
+!gifs/
+gifs/*
+!gifs/saved
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
Maps/flair_real_maps/envs_val_trained_clss/MSK_000310_great_horned_owl.png
ADDED
(image preview omitted; binary file stored via Git LFS)
Maps/flair_real_maps/masks_val_trained_clss/MSK_000310_great_horned_owl.png
ADDED
(image preview omitted; binary file stored via Git LFS)
Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_trained_clss_v3/MSK_000310_great_horned_owl.png
ADDED
(image preview omitted; binary file stored via Git LFS)
Taxabind/Taxabind/SatBind/clip_seg_tta.py
ADDED
@@ -0,0 +1,933 @@
+from itertools import cycle
+import sys
+import os
+
+import cv2
+
+# Ensure the correct directory is at the front of sys.path
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+# Remove cached 'model' module to force reloading from the correct path
+if "model" in sys.modules:
+    del sys.modules["model"]
+
+import json
+import glob
+from datetime import datetime
+import time
+
+import torch
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from matplotlib import pyplot as plt
+from torch.utils.data import DataLoader
+from torchvision.transforms import v2
+
+import open_clip
+from dataset import SatNatDataset
+from model import SatBind
+from transformers import CLIPVisionModelWithProjection
+from kmeans_clustering import CombinedSilhouetteInertiaClusterer
+
+# Just to import SoundBind (1 dir back)
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+from SoundBind.model import AudioBind
+
+# import matplotlib
+# matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
+
+class ClipSegTTA:
+    def __init__(
+        self,
+        img_dir: str,
+        imo_dir: str,
+        json_path: str,
+        sat_to_img_ids_json_path: str,
+        patch_size: int,
+        sat_checkpoint_path: str,
+        sample_index: int = 0,
+        blur_kernel = (5,5), # (0,0) for no gaussian blur
+        batch_size: int = 1,
+        num_workers: int = 1,
+        device: str = "cuda",
+        sat_to_img_ids_json_is_train_dict: bool = True,
+        tax_to_filter_val: str = "",
+        load_model: bool = True,
+        initial_modality: str = "image",
+        sound_data_path: str = None,
+        sound_checkpoint_path: str = None,
+    ):
+        """
+        Initialize the ClipSegTTA class with the required parameters.
+
+        :param img_dir: Path to the ground images directory.
+        :param imo_dir: Path to the satellite images directory.
+        :param json_path: Path to the iNat JSON file (full dataset).
+        :param sat_filtered_json_path: Path to the filtered pixel-CLIP JSON (satellite).
+        :param sat_to_img_ids_json_path: Path to the mapping from satellite image IDs to ground image IDs.
+        :param patch_size: Size of each satellite patch (e.g. 14).
+        :param sat_checkpoint_path: Path to the SatBind checkpoint (CKPT file).
+        :param batch_size: Batch size for loading data (default=1).
+        :param num_workers: Number of workers for data loading (default=1).
+        :param device: Device to use, e.g., 'cuda' or 'cpu' (default='cuda').
+        """
+
+        self.img_dir = img_dir
+        self.imo_dir = imo_dir
+        self.json_path = json_path
+        # self.sat_filtered_json_path = sat_filtered_json_path
+        self.sat_to_img_ids_json_path = sat_to_img_ids_json_path
+        self.patch_size = patch_size
+        self.sat_checkpoint_path = sat_checkpoint_path
+        self.sample_index = sample_index # Can be overridden by reset function
+        self.blur_kernel = blur_kernel
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.device = device
+        self.sat_to_img_ids_json_is_train_dict = sat_to_img_ids_json_is_train_dict
+        self.tax_to_filter_val = tax_to_filter_val
+        self.load_model = load_model
+        self.initial_modality = initial_modality
+        self.sound_data_path = sound_data_path
+        self.sound_checkpoint_path = sound_checkpoint_path
+
+        # Prepare the dataset
+        start_time = time.time()
+        self.load_data()
+        print(f"Dataset loaded in {(time.time()-start_time):.2f}s.")
+
+        # Load the global model (original/frozen checkpoint)
+        if self.load_model:
+            start_time = time.time()
+            self.load_global_model()
+            self.tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
+            print(f"Global model loaded in {(time.time()-start_time):.2f}s.")
+
+        # Create the local model that will be adapted for TTA
+        if self.load_model:
+            self.model_local = SatBind(train_dataset=None, val_dataset=None)
+            self.model_local.to(self.device)
+            self.model_local.eval()
+
+        # Load sound model if provided
+        if self.sound_checkpoint_path:
+            self.sound_model = AudioBind.load_from_checkpoint(sound_checkpoint_path, train_dataset=None, val_dataset=None)
+            self.sound_model.to(self.device)
+            self.sound_model.eval()
+
+        # Params
+        self.reset(sample_idx=self.sample_index)
+
+        # for idx, related_img_path in enumerate(self.related_imgs_paths):
+        # self.visualize_heatmap(
+        # step=0,
+        # img_path_viz=related_img_path,
+        # imo_path_viz=self.imo_path,
+        # patch_idx_viz=[self.patch_idx],
+        # patch_is_pos=[True],
+        # species_name=self.species_name
+        # )
+
+        self.clip_inference_time = 0.0
+        self.tta_time = 0.0
+
+
+    def load_data(self):
+        """Load or initialize the dataset."""
+        self.dataset = SatNatDataset(
+            img_dir=self.img_dir,
+            imo_dir=self.imo_dir,
+            json_path=self.json_path,
+            # sat_filtered_json_path=self.sat_filtered_json_path,
+            sat_to_img_ids_json_path=self.sat_to_img_ids_json_path,
+            sound_data_path=self.sound_data_path,
+            patch_size=self.patch_size,
+            mode="val",
+            get_img_path=True,
+            sat_to_img_ids_json_is_train_dict=self.sat_to_img_ids_json_is_train_dict,
+            tax_to_filter_val=self.tax_to_filter_val
+        )
+
+    def reset(self, sample_idx):
+        """Reset the parameters & local model for the current sample."""
+        if self.load_model:
+            self.reset_local_model() # Reset to global weights as init
+
+        self.img_paths, self.imo_path, self.imgs, self.imo, self.sounds, self.sound_ids, self.species_name, self.target_positions, self.gt_mask_name = self.dataset.get_search_ds_data(sample_idx)
+        self.imgs = self.imgs.to(self.device)
+        # self.img_path = self.img_paths[0] # Select 1st img as query
+        # self.img = self.imgs[0] # Select 1st img as query
+        # self.img = self.img.to(self.device)
+        # self.imo = self.imo.to(self.device)
+        self.tgts_gt_score = None
+        if self.load_model:
+            self.heatmap, self.heatmap_unnormalized, self.heatmap_unnormalized_initial, self.patch_embeds = None, None, None, None
+            img = self.imgs[0].unsqueeze(0).to(self.device)
+            imo = self.imo.unsqueeze(0).to(self.device)
+            sound = self.sounds[0].to(self.device) if self.sounds != [] else None
+            txt = [self.species_name]
+            self.generate_heatmap(img, imo, txt, sound=sound, modality=self.initial_modality)
+
+            # Find avg heatmap score for target positions (index target into heatmap)
+            scores = []
+            imo_orig = Image.open(self.imo_path)
+            for pos in self.target_positions:
+                row_trans = int(pos[0] * self.heatmap.shape[0] / imo_orig.size[0])
+                col_trans = int(pos[1] * self.heatmap.shape[1] / imo_orig.size[1])
+                # print("row_trans, col_trans: ", row_trans, col_trans)
+                scores.append(self.heatmap[row_trans, col_trans])
+            self.tgts_gt_score = np.mean(scores)
+            # print("scores: ", scores)
+            # print("self.tgts_gt_score: ", self.tgts_gt_score)
+
+
+        # print("Sample reset for TTA: ", self.species_name)
+
+    def load_global_model(self):
+        """Load the global SatBind model from checkpoint, move to device, and eval."""
+        self.model_global = SatBind.load_from_checkpoint(
+            self.sat_checkpoint_path, train_dataset=None, val_dataset=None
+        )
+        self.model_global = self.model_global.to(self.device)
+        self.model_global.eval()
+
+    def reset_local_model(self):
+        """
+        Reset the local model to match the global model's parameters
+        and freeze/unfreeze layers for TTA.
+        """
+        start_time = time.time()
+        with torch.no_grad():
+            for param_global, param_local in zip(
+                self.model_global.parameters(), self.model_local.parameters()
+            ):
+                param_local.data.copy_(param_global.data)
+
+        # Freeze everything except the satellite encoder & custom projection
+        for name, param in self.model_local.named_parameters():
+            if "imo_encoder" in name or "visual_projection_custom" in name:
+                param.requires_grad = True
+            else:
+                param.requires_grad = False
+
+        self.model_local.eval()
+        # print(f"Local model reset in {(time.time()-start_time):.2f}s.")
+
+
+    def execute_tta(
+        self,
+        patch_indices: list,
+        patch_is_pos: list,
+        pos_sample_weight: list,
+        neg_sample_weight: list,
+        tta_steps: int = 10,
+        lr: float = 2e-6,
+        modality: str = "image", # image, text, or combined
+        query_variety: bool = False, # Whether to use different query images
+        target_found_idxs: list = [],
+        reset_weights: bool = True,
+        num_viz_steps: int = 1,
+        viz_heatmap: bool = False,
+    ):
+        """
+        Run test-time adaptation using the local model. The local model is first
+        reset to the global weights. After TTA, the global model remains
+        unchanged; only the local model is updated.
+
+        :param sample_index: Index for selecting sample(s) from the validation set.
+        :param tta_steps: Number of test-time adaptation steps.
+        :param num_viz_steps: Visualize heatmap after every 'num_viz_steps' steps.
+        :param viz_heatmap: If True, perform visualization. If False, skip plotting.
+        """
+
+        ### Option 1: SAMPLE FROM DATASET
+        # 1) Reset the local model to global weights
+        if reset_weights:
+            self.reset_local_model()
+
+        # 2) Prepare the sample(s) for TTA
+        # print("target_found_idxs: ", target_found_idxs)
+        if query_variety:
+            indices = torch.tensor(target_found_idxs).to(dtype=torch.long).to(self.device)
+            img = self.imgs[indices].to(self.device)
+            # print("~variety")
+        elif (self.initial_modality == "text" or self.initial_modality == "sound") and modality == "combined" and len(target_found_idxs) == 0:
+            indices = torch.tensor(target_found_idxs).to(dtype=torch.long).to(self.device)
+            img = self.imgs[indices].to(self.device)
+            # print("~empty")
+        else:
+            img = self.imgs[0].unsqueeze(0).to(self.device)
+            # print("~single")
+        imo = self.imo.unsqueeze(0).to(self.device) # vectorize in shared_step to make imo uniform (faster)
+        txt = [self.species_name]
+        sound = self.sounds[0].to(self.device) if self.sounds != [] else None
+        patch_indices = [idx+1 for idx in patch_indices] # BUGFIX: Consider the [CLS] token offset
+        patch_idx = torch.tensor(patch_indices).to(self.device)
+
+        # print("img.shape: ", img.shape)
+        # print("imo.shape: ", imo.shape)
+        # print("patch_idx.shape: ", patch_idx.shape)
+        # print("species_txt: ", txt)
+        # ---------------------------------------------------------------------
+
+
+        # 5) Set up optimizer (only for sat branch) & AMP scaler
+        optimizer = torch.optim.Adam(
+            [p for p in self.model_local.parameters() if p.requires_grad], lr=lr
+        )
+        # scaler = torch.cuda.amp.GradScaler(init_scale=1000)
+
+        # print(
+        # f"Starting TTA (patch_idx={patch_idx[0].item()}) for {tta_steps} steps. "
+        # f"Visualization: {'ON' if viz_heatmap else 'OFF'} every {num_viz_steps} step(s)."
+        # )
+
+        start_time = time.time()
+
+        # 6) TTA loop
+        for step in range(tta_steps):
+            # # with torch.cuda.amp.autocast():
+            batch_size = imo.shape[0]
+
+            # Query embeds
+            query_embeds = self.generate_query_embeds(img, imo, txt, sound=sound, modality=modality)
+
+            # Sat Embeds
+            imo_embeds = self.model_local.imo_encoder(imo).last_hidden_state # (batch, Patches, hidden_dim)
+            imo_embeds = imo_embeds[torch.arange(batch_size), patch_idx] # (batch, hidden_dim)
+            imo_embeds = self.model_local.visual_projection_custom(imo_embeds) # (batch_size, proj_dim)
+            imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1)
+
+            # Compute Similarity Loss
+            logit_scale = self.model_local.logit_scale.exp()
+            similarity = imo_embeds @ query_embeds.t() * logit_scale
+            # target = torch.tensor(patch_is_pos, dtype=torch.float32, device=similarity.device)
+            # per_sample_weight = torch.tensor(per_sample_weight, dtype=torch.float32, device=similarity.device)
+            # criterion = torch.nn.BCEWithLogitsLoss(weight=per_sample_weight) # Includes sigmoid activation internally
+            # loss = criterion(similarity.squeeze(), target)
+            # # NOTE: Equivalent loss scaling
+            # # criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
+            # # loss = criterion(similarity.squeeze(), target)
+            # # loss = loss * per_sample_weight
+            # # loss = loss.mean()
+
+            # loss = 0
+            # patch_probs = similarity.squeeze()
+            # for probs, counts in zip(patch_probs, patch_is_pos):
+            # print("patch_probs.shape: ", patch_probs.shape)
+            # delta_loss = (probs - counts * torch.log(probs + 1e-6))
+            # print("delta_loss.shape: ", delta_loss.shape)
+            # loss += delta_loss
+
+            # Negative Log Likelihood loss for spatial Poisson point process
+            patch_probs = similarity.squeeze().sigmoid()
+            counts = torch.tensor(patch_is_pos, dtype=torch.float32, device=similarity.device)
+            pos_weights = torch.tensor(pos_sample_weight, dtype=torch.float32, device=similarity.device)
+            neg_weights = torch.tensor(neg_sample_weight, dtype=torch.float32, device=similarity.device)
+            # loss = (neg_weight * patch_probs - pos_weight * counts * torch.log(patch_probs + 1e-6))
+            loss = (neg_weights * patch_probs - pos_weights * counts * torch.log(patch_probs + 1e-6))
+            # print("patch_probs: ", patch_probs)
+            # print("counts: ", counts)
+            # print("pos_weights: ", pos_weights)
+            # print("neg_weights: ", neg_weights)
+            # print("loss: ", loss)
+            loss = loss.sum()
+
+            # # Print loss info
+            # print(f"Step {step:03d}: CLIP Loss = {loss.item():.4f}")
+
+            # # Backprop and update
+            # NOTE: No need for scaler since not using mixed precision (i.e. autocast)
+            # optimizer.zero_grad()
+            # scaler.scale(loss).backward()
+            # scaler.step(optimizer)
+            # scaler.update()
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            self.tta_time = time.time() - start_time
+            # print("self.tta_time: ", self.tta_time)
+
+            # Visualization every 'num_viz_steps' steps (if enabled)
+            if (step + 1) % num_viz_steps == 0 and viz_heatmap:
+                # Visualize only the first sample in the batch
+                # self.generate_heatmap(img[0], imo[0])
+                self.generate_heatmap(img, imo, txt, sound=sound, modality=modality)
+                self.visualize_heatmap(
+                    step=step,
+                    img_path_viz=self.img_paths[0], # Viz 1st image
+                    imo_path_viz=self.imo_path,
+                    patch_idx_viz=patch_idx,
+                    patch_is_pos=patch_is_pos,
+                    species_name=self.species_name
+                )
+
+        # Save final heatmap after TTA steps
+        self.generate_heatmap(img, imo, txt, sound=sound, modality=modality)
+
+
+    def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"):
+
+        # Query Embeds
+        if modality == "image":
+            # print("~Image modality")
+            query_embeds, *_ = self.model_local.bio_model(img) # (batch_size, proj_dim)
+            if query_embeds.shape[0] > 1:
+                query_embeds = query_embeds.mean(dim=0, keepdim=True) # (1, proj_dim)
+        elif modality == "text" or (modality == "combined" and img.shape[0] == 0):
+            # print("~Text modality")
+            txt_tokenized = self.tokenizer(txt).to(imo.device)
+            _, query_embeds, _ = self.model_local.bio_model(text=txt_tokenized)
+        elif modality == "sound" or (modality == "combined" and img.shape[0] == 0):
+            # print("~Sound modality")
+            if sound == None:
+                print("!!!! Sound modality requires sound input !!!")
+                exit(1)
+            unnormalized_audio_embeds = self.sound_model.audio_encoder(sound)
+            query_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
+        elif modality == "combined":
+            # print("~Combined modality")
+            img_embeds, *_ = self.model_local.bio_model(img)
+            if img_embeds.shape[0] > 1:
+                img_embeds = img_embeds.mean(dim=0, keepdim=True) # (1, proj_dim)
+            txt_tokenized = self.tokenizer(txt).to(imo.device)
+            _, txt_embeds, _ = self.model_local.bio_model(text=txt_tokenized)
+            query_embeds = (img_embeds + txt_embeds) / 2
+        else:
+            raise ValueError("Invalid modality")
+
+        return query_embeds
+
+
+    def generate_heatmap(self, img, imo, txt, sound=None, modality="image"):
+
+        start_time = time.time()
+
+        # Satellite encoder outputs
+        imo_embeds = self.model_local.imo_encoder(imo).last_hidden_state
+        # Apply custom projection
+        imo_embeds = self.model_local.visual_projection_custom(imo_embeds)
+        imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1)
+        # Remove batch dimension -> (num_tokens, proj_dim)
+        imo_embeds = imo_embeds.squeeze(0)
+        self.patch_embeds = imo_embeds.clone()[1:].cpu().detach().numpy()
+
+        # # Ground image embedding (bio CLIP model)
+        # img_embeds, *_ = self.model_local.bio_model(img)
+        query_embeds = self.generate_query_embeds(img, imo, txt, sound=sound, modality=modality)
+
+        # Same logit scale as in SatBind
+        logit_scale = self.model_local.logit_scale.exp()
+        sim = query_embeds @ imo_embeds.t() * logit_scale
+        # Sigmoid to get similarity scores
+        scores = sim.t().sigmoid() # (num_tokens, 1)
+
+        # Exclude [CLS] token at index 0
+        score_no_cls = scores[1:].squeeze() # shape: (num_tokens-1,)
+        num_tokens = score_no_cls.shape[0]
+        side_dim = int(num_tokens**0.5)
+        sim_scores = score_no_cls.reshape(side_dim, side_dim).clone()
+        sim_scores = sim_scores.cpu().detach().numpy()
+
+        self.clip_inference_time = time.time() - start_time
+        # print("self.clip_inference_time: ", self.clip_inference_time)
+
+        # Gaussian Smoothing
+        if self.blur_kernel != (0,0):
+            sim_scores = cv2.GaussianBlur(sim_scores, self.blur_kernel, 0)
+
+        # Normalize to expectation
+        self.heatmap_unnormalized = sim_scores
+        scale = len(self.target_positions) / (self.heatmap_unnormalized.sum())
+        self.heatmap_unnormalized *= scale
+        if self.heatmap_unnormalized_initial is None:
+            self.heatmap_unnormalized_initial = self.heatmap_unnormalized.copy()
+        # print("self.heatmap_unnormalized.sum(): ", self.heatmap_unnormalized.sum())
+
+        # Standard normalization to (0,1)
+        # self.heatmap = score_no_cls.reshape(side_dim, side_dim).clone()
+        # self.heatmap = self.heatmap.cpu().detach().numpy()
+        self.heatmap = sim_scores.copy()
+        self.heatmap = (self.heatmap - self.heatmap.min()) / (self.heatmap.max() - self.heatmap.min()) # normalize heatmap
+
+
+    def visualize_heatmap(
+        self,
+        step: int,
+        img_path_viz: str,
+        imo_path_viz: str,
+        patch_idx_viz: torch.Tensor,
+        patch_is_pos: list,
+        species_name: str
+    ):
+        """
+        Visualization function that plots the ground image, satellite image with
+        highlighted patch, and the learned heatmap.
+
+        :param step: Current TTA step (for labeling the plots).
+        :param img_path_viz: File path to the ground image.
+        :param imo_path_viz: File path to the satellite image.
+        :param patch_idx_viz: The patch index (tensor) being highlighted.
+        :param species_text: List of species text labels (from dataset).
+        :param species_idx_viz: Index of the species to visualize text for.
+        :param img_viz: The ground image tensor (3, H, W).
+        :param imo_viz: The satellite image tensor (3, H, W).
+        """
+
+        # Switch off gradients for visualization
+        with torch.no_grad():
+            side_dim = self.heatmap.shape[0]
+
+            # -----------------------------------------------------------------
+            # Highlight the patch in the satellite image
+            sat_img_orig = Image.open(imo_path_viz)
+            sat_highlight = np.array(
+                self.dataset.debug_imo_viz_transform(sat_img_orig.copy())
+            )
+
+            for idx, patch_idx in enumerate(patch_idx_viz):
+
+                # Because patch_idx includes the [CLS] offset, subtract 1
+                patch_idx_actual = patch_idx - 1
+
+                # Get dimensions (H x W)
+                H, W = sat_highlight.shape[0], sat_highlight.shape[1]
+
+                # Number of patches in each dimension
+                patches_per_col = W // self.patch_size
+                patches_per_row = H // self.patch_size
+
+                # Determine row/col in the patch grid
+                patch_row = patch_idx_actual // patches_per_col
+                patch_col = patch_idx_actual % patches_per_row
+
+                # Pixel boundaries
+                x_start = patch_col * self.patch_size
+                x_end = (patch_col + 1) * self.patch_size
+                y_start = patch_row * self.patch_size
+                y_end = (patch_row + 1) * self.patch_size
+
+                # Fill patch area with a grey color
+                # if sat_highlight.dtype == np.uint8:
+                # grey_value = 128
+                # else:
+                # grey_value = 0.5
+                # sat_highlight[y_start:y_end, x_start:x_end, :] = grey_value
+                # Blue color for positive patches (transparent)
+                if patch_is_pos[idx]:
+                    sat_highlight[y_start:y_end, x_start:x_end, 0] = 0
+                    sat_highlight[y_start:y_end, x_start:x_end, 1] = 0
+                    sat_highlight[y_start:y_end, x_start:x_end, 2] = 255
+                # Red color for negative patches (transparent)
+                else:
+                    sat_highlight[y_start:y_end, x_start:x_end, 0] = 255
+                    sat_highlight[y_start:y_end, x_start:x_end, 1] = 0
+                    sat_highlight[y_start:y_end, x_start:x_end, 2] = 0
+
+
+            # -----------------------------------------------------------------
+            # Plot results
+            fig, axes = plt.subplots(1, 3, figsize=(12, 6))
+            fig.suptitle(f"Query: {species_name}")
+
+            # Ground image
+            img_orig = Image.open(img_path_viz)
+            axes[0].imshow(img_orig)
+            axes[0].set_title("Ground Image")
+            axes[0].axis("off")
+
+            # Satellite image
+            axes[1].imshow(sat_highlight)
+            axes[1].set_title("Sat Image")
+            axes[1].axis("off")
+
+            # Heatmap
+            heatmap_np = self.heatmap_unnormalized
+            im = axes[2].imshow(heatmap_np, cmap="viridis")
+            axes[2].set_title(
+                f"Heatmap at TTA Step {step:03d} ({side_dim}x{side_dim})"
+            )
+            axes[2].axis("off")
+            fig.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)
+
+            plt.tight_layout()
+            plt.show()
+
+
+if __name__ == "__main__":
+
+    # # (CHANGE ME!) PARAMS: VAL
+    # clip_seg_tta = ClipSegTTA(
+    # img_dir="/mnt/hdd/inat2021_ds/inat21",
+    # imo_dir="/mnt/hdd/inat2021_ds/sat_test_jpg_512px",
+    # json_path="/mnt/hdd/inat2021_ds/inat21_val.json",
+    # sat_filtered_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_val.json",
+    # sat_to_img_ids_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/filtered_mapping_sat_to_img_ids_val.json",
+    # patch_size=14,
+    # sat_checkpoint_path="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt", # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
+    # sample_index = 45,
+    # batch_size=1,
+    # num_workers=1,
+    # device="cuda",
+    # )
+
+    # Run TTA on sample index 45, for 10 steps,
+    # visualizing every 2 steps, with heatmap enabled:
+
+    # # Step 1
+    # print("Executing step 1...")
+    # patch_indices = [422, 204, 45]
+    # patch_is_pos = [True, False, False]
+    # clip_seg_tta.execute_tta(patch_indices, patch_is_pos, tta_steps=10, neg_sample_weight_scale=1.0, num_viz_steps=2, viz_heatmap=True)
+
+    # # Step 2
+    # print("Executing step 2...")
+    # # patch_indices = [185, 89, 30] # for 256px
+    # patch_indices = [422, 204, 45, 423] # for 512px
+    # patch_is_pos = [True, False, False, True]
+    # clip_seg_tta.execute_tta(patch_indices, patch_is_pos, tta_steps=10, neg_sample_weight_scale=1.0, num_viz_steps=2, viz_heatmap=True)
+
+
+    # # (CHANGE ME!) PARAMS: TRAIN w/ TARGETS
+    # clip_seg_tta = ClipSegTTA(
+    # img_dir="/mnt/hdd/inat2021_ds/inat21",
+    # imo_dir="/mnt/hdd/inat2021_ds/sat_train_jpg_512px",
+    # json_path="/mnt/hdd/inat2021_ds/inat21_train.json",
+    # sat_filtered_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json",
+    # # sat_to_img_ids_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/filtered_mapping_sat_to_img_ids_train.json",
+    # sat_to_img_ids_json_path="/mnt/hdd/inat2021_ds/target_search_ds/OLD/taxon_sat_target_search_100x_per_10-20counts.json",
+    # patch_size=14,
+    # sat_checkpoint_path="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt", # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
+    # sample_index = 99,
+    # batch_size=1,
+    # num_workers=1,
+    # device="cuda",
+    # )
+
+    # # (CHANGE ME!) PARAMS: TRAIN w/ TARGETS
+    # clip_seg_tta = ClipSegTTA(
+    # img_dir="/mnt/hdd/inat2021_ds/inat21",
+    # imo_dir="/mnt/hdd/inat2021_ds/sat_train_jpg_512px",
+    # json_path="/mnt/hdd/inat2021_ds/inat21_train.json",
+    # # sat_filtered_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json",
+    # sat_to_img_ids_json_path="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/search_val_in.json",
+    # sound_data_path='/mnt/hdd/inat2021_ds/2_OTHERS/sound_train',
+    # patch_size=14,
+    # sat_checkpoint_path="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_180325_CLIP-L-336/satbind-epoch=02-val_loss=2.46-BACKUP.ckpt", # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt
+    # sound_checkpoint_path = "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_NOT_VAL_OUT_TAX_FILTER_TGT_ONLY_v2_130625/soundbind-epoch=19-val_loss=3.96_BACKUP.ckpt",
+    # sample_index = 8, # 5,6,8
+    # batch_size=1,
+    # num_workers=1,
+    # device="cuda",
+    # sat_to_img_ids_json_is_train_dict=False,
+    # )
+
+    # (CHANGE ME!) PARAMS: TRAIN w/ TARGETS
+    clip_seg_tta = ClipSegTTA(
+        img_dir="/mnt/hdd/inat2021_ds/inat21",
+        imo_dir="/mnt/hdd/inat2021_ds/sat_train_jpg_512px",
+        json_path="/mnt/hdd/inat2021_ds/inat21_train.json",
+        # sat_filtered_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json",
+        sat_to_img_ids_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/sound_train/sound_image_pairs_filtered_v3_with_in_domain_taxs_150625/val_in_with_sound_ids_v3.json",
+        sound_data_path='/mnt/hdd/inat2021_ds/2_OTHERS/sound_train',
+        patch_size=14,
+        sat_checkpoint_path="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_180325_CLIP-L-336/satbind-epoch=02-val_loss=2.46-BACKUP.ckpt", # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt
+        sound_checkpoint_path = "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_NOT_VAL_OUT_TAX_FILTER_TGT_ONLY_v2_130625/soundbind-epoch=19-val_loss=3.96_BACKUP.ckpt",
+        sample_index = 8, # 5,6,8
+        batch_size=1,
+        num_workers=1,
+        device="cuda",
+        sat_to_img_ids_json_is_train_dict=False,
+        initial_modality="sound"
+    )
+
+
+    ###########################################
+    # Scaling with list
+    ###########################################
+
+    # # print("Executing step 4b...")
+    # patch_indices = [422, 45]
+    # patch_is_pos = [True, False]
+    # # per_sample_weight = [1.0, 1.0]
+    # pos_sample_weight = 1.0
+    # neg_sample_weight = 1.0
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # pos_sample_weight,
+    # neg_sample_weight,
+    # tta_steps=10,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # print("Executing step 4b...")
+    patch_indices = [422, 32]
+    patch_is_pos = [True, False]
+    # per_sample_weight = [1.0, 1.0]
+    pos_sample_weight = 1.0
+    neg_sample_weight = 1.0
+    clip_seg_tta.execute_tta(
+        patch_indices,
+        patch_is_pos,
+        pos_sample_weight,
+        neg_sample_weight,
+        tta_steps=30,
+        num_viz_steps=2,
+        viz_heatmap=True,
+        modality="sound")
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 422]
+    # patch_is_pos = [True, False]
+    # per_sample_weight = [1.0, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=20,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 422]
+    # patch_is_pos = [2, False]
+    # per_sample_weight = [0.1, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=20,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 422]
+    # patch_is_pos = [True, False]
+    # per_sample_weight = [1.0, 0.1]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [422, 424, 398, 400, 45, 47, 41, 43]
+    # patch_is_pos = [True, True, True, True, False, False, False, False]
+    # per_sample_weight = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 47, 41, 43, 422, 424, 398, 400]
+    # patch_is_pos = [True, True, True, True, False, False, False, False]
+    # per_sample_weight = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 47, 41, 43, 422, 424, 398, 400]
+    # patch_is_pos = [True, True, True, True, False, False, False, False]
+    # per_sample_weight = [1.0, 1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 47, 41, 43, 422, 424, 398, 400]
+    # patch_is_pos = [True, True, True, True, False, False, False, False]
+    # per_sample_weight = [0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 47, 41, 43, 42, 44, 46, 40]
+    # patch_is_pos = [True, True, True, True, True, True, True, True]
+    # per_sample_weight = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 47, 41, 43, 422, 424, 398, 400]
+    # patch_is_pos = [True, True, True, True, False, False, False, False]
+    # per_sample_weight = [1.0, 1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [554, 45, 422]
+    # patch_is_pos = [True, True, False]
+    # per_sample_weight = [1.0, 1.0, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [554, 45, 422]
+    # patch_is_pos = [True, True, False]
+    # per_sample_weight = [1.0, 1.0, 0.1]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [554, 45, 422]
+    # patch_is_pos = [False, True, False]
+    # per_sample_weight = [1.0, 1.0, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [554, 45, 422]
+    # patch_is_pos = [False, True, False]
+    # per_sample_weight = [0.0, 1.0, 0.8]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+
+    ###########################################
+    # Text modality
+    # NOTE: Less aggressive vs image modality
+    ###########################################
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 422]
+    # patch_is_pos = [True, False]
+    # per_sample_weight = [1.0, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # modality="text",
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 422]
+    # patch_is_pos = [True, False]
+    # per_sample_weight = [0.1, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # modality="text",
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 422]
+    # patch_is_pos = [True, False]
+    # per_sample_weight = [1.0, 0.1]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # modality="text",
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+
+    ###########################################
+    # Combined modality
+    # NOTE: Best of both worlds
+    ###########################################
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 422]
+    # patch_is_pos = [True, False]
+    # per_sample_weight = [1.0, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # modality="combined",
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 422]
+    # patch_is_pos = [True, False]
+    # per_sample_weight = [0.1, 1.0]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # modality="combined",
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    # # print("Executing step 4b...")
+    # patch_indices = [45, 422]
+    # patch_is_pos = [True, False]
+    # per_sample_weight = [1.0, 0.1]
+    # clip_seg_tta.execute_tta(
+    # patch_indices,
+    # patch_is_pos,
+    # tta_steps=10,
+    # per_sample_weight=per_sample_weight,
+    # modality="combined",
+    # num_viz_steps=2,
+    # viz_heatmap=True)
+
+    ###########################################
+    # Real examples
+    ###########################################
+
+
+
+
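For reference, a minimal self-contained sketch of the spatial Poisson point-process objective that execute_tta above optimizes, evaluated on dummy tensors. This is not part of the committed file; the variable names and values below are illustrative assumptions only.

import torch

# Per-patch predicted "rate" (sigmoid of the patch-query similarity) and the
# observed counts (1.0 if a target was found in that patch, 0.0 otherwise).
patch_probs = torch.tensor([0.8, 0.3], requires_grad=True)
counts = torch.tensor([1.0, 0.0])
pos_weights = torch.tensor([1.0, 1.0])  # weight on the count (log-likelihood) term
neg_weights = torch.tensor([1.0, 1.0])  # weight on the predicted-rate term

# Poisson point-process negative log-likelihood, summed over queried patches,
# matching the loss line in the TTA loop: rate - count * log(rate).
loss = (neg_weights * patch_probs - pos_weights * counts * torch.log(patch_probs + 1e-6)).sum()
loss.backward()
print(f"loss={loss.item():.4f}, grad={patch_probs.grad}")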
Taxabind/Taxabind/SatBind/config.py
ADDED
@@ -0,0 +1,61 @@
+from easydict import EasyDict as edict
+
+config = edict()
+
+# Pixel level CLIP training
+config.img_dir = '/mnt/hdd/inat2021_ds/inat21'
+config.imo_dir = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_train_jpg_256px, sat_train_jpg_512px
+config.imo_dir_val = '/mnt/hdd/inat2021_ds/sat_test_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
+config.train_json_path = '/mnt/hdd/inat2021_ds/inat21_train.json'
+config.val_json_path = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed
+# config.sat_to_img_ids_train_json_path = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_train.json'
+config.sat_to_img_ids_train_json_path = '/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/SEARCH_DS_filtered_mapping_sat_to_img_ids_train_3-20counts.json'
+config.sat_to_img_ids_val_json_path = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
+# config.filtered_train_json_path = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_v2_train.json'
+# config.filtered_val_json_path = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_val.json' # no filter needed
+
+# # Pixel level CLIP training (Toy example)
+# config.img_dir = '../scripts/preprocess/expt/fake_img_dataset'
+# config.imo_dir = '../scripts/preprocess/expt/sat_train_dataset'
+# config.imo_dir_val = '../scripts/preprocess/expt/sat_val_dataset'
+# config.train_json_path = '../scripts/preprocess/expt/inat2021_train_micro.json'
+# config.val_json_path = '../scripts/preprocess/expt/inat2021_val_micro.json'
+# config.sat_to_img_ids_train_json_path = '../scripts/preprocess/expt/sat_filtered/filtered_mapping_sat_to_img_ids_train_micro.json'
+# config.sat_to_img_ids_val_json_path = '../scripts/preprocess/expt/sat_filtered/filtered_mapping_sat_to_img_ids_val_micro.json'
+# # config.filtered_train_json_path = '../scripts/preprocess/expt/inat2021_train_micro.json'
+# # config.filtered_val_json_path = '../scripts/preprocess/expt/inat2021_val_micro.json' # no filter needed
+
+# # Image level CLIP training
+# config.img_dir = '/mnt/hdd/inat2021_ds/inat21'
+# config.imo_dir = '/mnt/hdd/inat2021_ds/sat_train_jpg'
+# config.imo_dir_val = '/mnt/hdd/inat2021_ds/sat_test_jpg'
+# config.train_json_path = '/mnt/hdd/inat2021_ds/inat21_filtered_img_clip_train.json'
+# config.val_json_path = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed
+
+# # Image level CLIP training (Toy example)
+# config.img_dir = '../scripts/preprocess/expt/fake_img_dataset'
+# config.imo_dir = '../scripts/preprocess/expt/sat_train_dataset'
+# config.imo_dir_val = '../scripts/preprocess/expt/sat_val_dataset'
+# config.train_json_path = '../scripts/preprocess/expt/inat2021_train_micro.json'
+# config.val_json_path = '../scripts/preprocess/expt/inat2021_val_micro.json'
+
+# batch_size * accumulate_grad_batches * devices = MUST BE CONSTANT (i.e. 256 * 8 * 2 = 4096)
+config.batch_size = 32 # 256
+config.lr = 1e-4 # 1e-4
+config.accumulate_grad_batches = 64
+config.max_epochs = 20
+config.num_workers = 16
+config.devices = 2 # 2
+config.val_check_interval = 0.5
+config.sat_encoder = 'openai/clip-vit-large-patch14-336' # openai/clip-vit-base-patch16, openai/clip-vit-large-patch14-336
+config.patch_size = 14
+
+config.save_dir = 'checkpoints'
+config.filename = 'satbind-{epoch:02d}-{val_loss:.2f}'
+
+config.locked_tuning = True
+
+config.resume_from_checkpoint = False
+config.resume_checkpoint_name = 'satbind-resume'
+
+print("config: \n", config)
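As a quick check on the batch-size comment in config.py above: the committed values give batch_size * accumulate_grad_batches * devices = 32 * 64 * 2 = 4096, the same effective batch size as the 256 * 8 * 2 = 4096 example cited in that comment.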
Taxabind/Taxabind/SatBind/dataset.py
ADDED
@@ -0,0 +1,396 @@
+import math
+from pathlib import Path
+
+from matplotlib import pyplot as plt
+from torch.utils.data import Dataset
+from torchvision import transforms
+import json
+import os
+from PIL import Image
+from datetime import datetime
+from torchvision.transforms import v2
+import torch
+import numpy as np
+import glob
+
+# Just to import SoundBind (1 dir back)
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+from SoundBind.sound_encoder import get_audio_clap
+
+class SatNatDataset(Dataset):
+    def __init__(self, img_dir, imo_dir, json_path, sat_to_img_ids_json_path, patch_size, mode='train', get_img_path=False, sat_to_img_ids_json_is_train_dict=True, tax_to_filter_val="", sound_data_path=None):
+        self.img_dir = img_dir
+        self.imo_dir = imo_dir
+        self.patch_size = patch_size
+        self.get_img_path = get_img_path
+        self.mode = mode
+        self.sat_to_img_ids_json_is_train_dict = sat_to_img_ids_json_is_train_dict
+        self.tax_to_filter_val = tax_to_filter_val
+        self.sound_data_path = sound_data_path
+
+        # ADDED
+        self.current_epoch = 0
+
+        self.json = json.load(open(json_path, 'r'))
+        # self.sat_filt_json = json.load(open(sat_filtered_json_path, 'r'))
+        self.sat_to_img_ids_json_path = json.load(open(sat_to_img_ids_json_path, 'r'))
+        self.images = self.json['images']
+        self.annot = self.json['annotations']
+        for i in range(len(self.images)):
+            assert self.images[i]['id'] == self.annot[i]['id']
+            self.images[i]['label'] = self.annot[i]['category_id']
+        self.filtered_json = [d for d in self.images if d['latitude'] is not None and d['longitude'] is not None]
+        self.species_text = list(set([" ".join(d['file_name'].split("/")[1].split("_")[1:]) for d in self.filtered_json]))
+
+        # self.sat_filtered_json = [d for d in self.sat_filt_json['images'] if d['latitude'] is not None and d['longitude'] is not None]
+        # self.sat_paths = {d['id']: str(d['id'])+'_'+str(d['latitude'])+'_'+str(d['longitude'])+'.jpg' for d in self.sat_filtered_json}
+        self.inat_json_dict = {
+            "images": {img["id"]: img for img in self.images},
+            "annotations": {ann["id"]: ann for ann in self.annot},
+        }
+
+        # Expand dict
+        self.sat_to_img_ids_tuples = []
+        if self.sat_to_img_ids_json_is_train_dict:
+            # for i in range(len(self.sat_filtered_json)):
+            # id = self.sat_filtered_json[i]['id']
+            # sat_id = Path(self.sat_paths[id]).stem # remove file extension
+            # img_ids = self.sat_to_img_ids_json_path[sat_id]["img_ids"]
+            for sat_key, sat_sample in self.sat_to_img_ids_json_path.items():
+                id = sat_sample["id"] # int(sat_key.split("_")[0]) # sat_sample['id']
+                sat_path = sat_sample["sat_path"]
+                img_ids = sat_sample["img_ids"]
+                # img_locs = sat_sample["target_positions"] # NOTE: Not accurate after resize
+                for img_id in img_ids:
+                    self.sat_to_img_ids_tuples.append((id, sat_path, img_id))
+            print("len(self.sat_to_img_ids_json_path): ", len(self.sat_to_img_ids_json_path))
+            print("len(self.sat_to_img_ids_tuples): ", len(self.sat_to_img_ids_tuples))
+        else:
+            self.filtered_val_ds_by_tax = [d for d in self.sat_to_img_ids_json_path if self.tax_to_filter_val in d['taxonomy']]
+
+        if mode == 'train':
+            self.img_transform = transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.RandomCrop((224, 224)),
+                transforms.RandomHorizontalFlip(0.5),
+                transforms.GaussianBlur(5, (0.01, 1.0)),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+            ])
+            self.imo_transform = transforms.Compose([
+                # transforms.Resize((256,256)),
+                # transforms.RandomCrop((224, 224)),
+                # transforms.RandomHorizontalFlip(0.5),
+                # transforms.Resize((224,224)),
+                transforms.Resize((336,336)),
+                transforms.GaussianBlur(5, (0.01, 1.0)),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
92 |
+
])
|
93 |
+
else:
|
94 |
+
self.img_transform = transforms.Compose([
|
95 |
+
transforms.Resize((256, 256)),
|
96 |
+
transforms.CenterCrop((224, 224)),
|
97 |
+
transforms.ToTensor(),
|
98 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
99 |
+
std=[0.229, 0.224, 0.225])
|
100 |
+
])
|
101 |
+
self.imo_transform = transforms.Compose([
|
102 |
+
# transforms.Resize((256,256)),
|
103 |
+
# transforms.CenterCrop((224, 224)),
|
104 |
+
# transforms.Resize((224,224)),
|
105 |
+
transforms.Resize((336,336)),
|
106 |
+
transforms.ToTensor(),
|
107 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
108 |
+
std=[0.229, 0.224, 0.225])
|
109 |
+
])
|
110 |
+
self.debug_img_viz_transform = transforms.Compose([
|
111 |
+
transforms.Resize((256, 256)),
|
112 |
+
transforms.CenterCrop((224, 224))
|
113 |
+
])
|
114 |
+
self.debug_imo_viz_transform = transforms.Compose([
|
115 |
+
# transforms.Resize((224,224))
|
116 |
+
transforms.Resize((336,336))
|
117 |
+
])
|
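Side note (a sketch, not in the commit): with the 336x336 satellite resize above and the ViT-L/14 encoder selected in config.py (patch_size = 14), each satellite image is divided into a 24x24 grid of patches, which is the grid the patch indices below refer to.

img_size, patch_size = 336, 14
grid = img_size // patch_size       # 24 patches per side
num_patch_tokens = grid * grid      # 576 patch tokens
seq_len = num_patch_tokens + 1      # 577 ViT tokens including [CLS]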
118 |
+
|
119 |
+
def __len__(self):
|
120 |
+
# return len(self.filtered_json)
|
121 |
+
# return len(self.sat_filtered_json)
|
122 |
+
return len(self.sat_to_img_ids_tuples)
|
123 |
+
|
124 |
+
def __getitem__(self, idx):
|
125 |
+
|
126 |
+
# # DEBUG: Visualize original image and IMO pair
|
127 |
+
# self.debug_viz_img_imo_orig_pair(idx)
|
128 |
+
if not self.sat_to_img_ids_json_is_train_dict:
|
129 |
+
print("Json is not dict. Please reformat for training!")
|
130 |
+
exit()
|
131 |
+
|
132 |
+
## Pixel-level CLIP
|
133 |
+
# Get sat img
|
134 |
+
# id = self.sat_filtered_json[idx]['id']
|
135 |
+
id, sat_path, img_id = self.sat_to_img_ids_tuples[idx]
|
136 |
+
imo_path = os.path.join(self.imo_dir, sat_path)
|
137 |
+
imo = self.imo_transform(Image.open(imo_path))
|
138 |
+
sat_id = Path(sat_path).stem # remove file extension
|
139 |
+
|
140 |
+
img_path = os.path.join(self.img_dir, self.inat_json_dict["images"][img_id]["file_name"])
|
141 |
+
img = self.img_transform(Image.open(img_path))
|
142 |
+
|
143 |
+
# # Map lat-lon to pixel in sat img
|
144 |
+
sat_min_lon = self.sat_to_img_ids_json_path[sat_id]["sat_bounds"]["min_lon"]
|
145 |
+
sat_min_lat = self.sat_to_img_ids_json_path[sat_id]["sat_bounds"]["min_lat"]
|
146 |
+
sat_max_lon = self.sat_to_img_ids_json_path[sat_id]["sat_bounds"]["max_lon"]
|
147 |
+
sat_max_lat = self.sat_to_img_ids_json_path[sat_id]["sat_bounds"]["max_lat"]
|
148 |
+
|
149 |
+
img_lon = self.inat_json_dict["images"][img_id]["longitude"]
|
150 |
+
img_lat = self.inat_json_dict["images"][img_id]["latitude"]
|
151 |
+
row, col = self.latlon_to_pixel(img_lat, img_lon, sat_min_lat, sat_max_lat, sat_min_lon, sat_max_lon, imo.shape[2], imo.shape[1])
|
152 |
+
|
153 |
+
patch_idx = self.pixel_to_patch_idx(row, col, self.patch_size, imo.shape[2], imo.shape[1])
|
154 |
+
patch_idx += 1 # account for [CLS] token at the start of ViT input sequence
|
155 |
+
|
156 |
+
species_text = " ".join(self.inat_json_dict["images"][img_id]['file_name'].split("/")[1].split("_")[1:])
|
157 |
+
|
158 |
+
# # # DEBUG: Visualize patch
|
159 |
+
# img_debug = np.array(self.debug_img_viz_transform(Image.open(img_path)))
|
160 |
+
# imo_debug = np.array(self.debug_imo_viz_transform(Image.open(imo_path)))
|
161 |
+
# self.debug_viz_patch(img_debug, imo_debug, patch_idx, self.patch_size, species_text)
|
162 |
+
|
163 |
+
if self.get_img_path:
|
164 |
+
return img_path, imo_path, img, imo, self.inat_json_dict["annotations"][img_id]['category_id'], patch_idx, species_text, self.species_text.index(species_text)
|
165 |
+
else:
|
166 |
+
return img, imo, self.inat_json_dict["annotations"][img_id]['category_id'], patch_idx, species_text, self.species_text.index(species_text)
|
167 |
+
|
168 |
+
|
169 |
+
# NOTE: Image-level CLIP
|
170 |
+
# img_path = os.path.join(self.img_dir, self.filtered_json[idx]['file_name'])
|
171 |
+
# imo_path = os.path.join(self.imo_dir, self.sat_paths[self.filtered_json[idx]['id']])
|
172 |
+
# img = self.img_transform(Image.open(img_path))
|
173 |
+
# imo = self.imo_transform(Image.open(imo_path))
|
174 |
+
# species_text = " ".join(self.filtered_json[idx]['file_name'].split("/")[1].split("_")[1:])
|
175 |
+
# return img, imo, self.filtered_json[idx]['label'], species_text, self.species_text.index(species_text)
|
176 |
+
|
177 |
+
def latlon_to_pixel(self, lat, lon, lat_min, lat_max, lon_min, lon_max, img_width, img_height):
|
178 |
+
lat_res = (lat_max - lat_min) / img_height
|
179 |
+
lon_res = (lon_max - lon_min) / img_width
|
180 |
+
|
181 |
+
col = int(math.floor((lon - lon_min) / lon_res))
|
182 |
+
row = int(math.floor((lat_max - lat) / lat_res))
|
183 |
+
|
184 |
+
return row, col
|
185 |
+
|
186 |
+
def pixel_to_patch_idx(self, row, col, patch_size, img_width, img_height):
|
187 |
+
patch_size_width = patch_size
|
188 |
+
patch_size_height = patch_size
|
189 |
+
patch_row = row // patch_size_height
|
190 |
+
patch_col = col // patch_size_width
|
191 |
+
patch_idx = patch_row * (img_width // patch_size) + patch_col
|
192 |
+
|
193 |
+
return patch_idx
|
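A small worked example (hypothetical coordinates, not from the dataset) of the two helpers above, reproducing the lat/lon -> pixel -> patch-index mapping used in __getitem__:

import math

lat_min, lat_max = 40.00, 40.03     # hypothetical satellite tile bounds
lon_min, lon_max = -75.03, -75.00
img_w = img_h = 336
patch_size = 14

lat, lon = 40.01, -75.02            # hypothetical observation location
lat_res = (lat_max - lat_min) / img_h
lon_res = (lon_max - lon_min) / img_w
row = int(math.floor((lat_max - lat) / lat_res))   # 224
col = int(math.floor((lon - lon_min) / lon_res))   # 112

patches_per_row = img_w // patch_size              # 24
patch_idx = (row // patch_size) * patches_per_row + (col // patch_size)
patch_idx += 1                                     # +1 for the [CLS] token, as in __getitem__
print(row, col, patch_idx)                         # 224 112 393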
194 |
+
|
195 |
+
def set_epoch(self, epoch):
|
196 |
+
self.current_epoch = epoch
|
197 |
+
|
198 |
+
###########################################################
|
199 |
+
|
200 |
+
def debug_viz_img_imo_orig_pair(self, idx):
|
201 |
+
|
202 |
+
img_path = os.path.join(self.img_dir, self.filtered_json[idx]['file_name'])
|
203 |
+
imo_path = os.path.join(self.imo_dir, self.sat_paths[self.filtered_json[idx]['id']])
|
204 |
+
# img = self.img_transform(Image.open(img_path))
|
205 |
+
# imo = self.imo_transform(Image.open(imo_path))
|
206 |
+
# species_text = " ".join(self.filtered_json[idx]['file_name'].split("/")[1].split("_")[1:])
|
207 |
+
|
208 |
+
# Create a side-by-side plot
|
209 |
+
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
|
210 |
+
|
211 |
+
img_np = Image.open(img_path)
|
212 |
+
imo_np = Image.open(imo_path)
|
213 |
+
|
214 |
+
axes[0].imshow(img_np)
|
215 |
+
axes[0].set_title('Image')
|
216 |
+
axes[0].axis('off')
|
217 |
+
|
218 |
+
axes[1].imshow(imo_np)
|
219 |
+
axes[1].set_title('IMO')
|
220 |
+
axes[1].axis('off')
|
221 |
+
|
222 |
+
plt.tight_layout()
|
223 |
+
plt.show()
|
224 |
+
|
225 |
+
|
226 |
+
def debug_viz_patch(self, img_np, imo_np, patch_idx, patch_size, species_text):
|
227 |
+
|
228 |
+
# ----- Compute the patch boundaries -----
|
229 |
+
# Since your patch_idx has been incremented by 1 (for [CLS]), subtract 1 to get the zero-based index.
|
230 |
+
patch_idx_actual = patch_idx - 1
|
231 |
+
|
232 |
+
# Get the image dimensions from the satellite image (imo).
|
233 |
+
# Note: In your code, imo.shape[2] is width and imo.shape[1] is height.
|
234 |
+
H, W = imo_np.shape[0], imo_np.shape[1]
|
235 |
+
|
236 |
+
# Compute the number of patches per row.
|
237 |
+
patches_per_col = W // patch_size
|
238 |
+
patches_per_row = H // patch_size
|
239 |
+
|
240 |
+
# Determine the patch's row and column in the grid.
|
241 |
+
patch_row = patch_idx_actual // patches_per_col
|
242 |
+
patch_col = patch_idx_actual % patches_per_col  # modulo by the per-row patch count (W // patch_size), matching pixel_to_patch_idx
|
243 |
+
|
244 |
+
# Calculate pixel boundaries for the patch.
|
245 |
+
x_start = patch_col * patch_size
|
246 |
+
x_end = (patch_col + 1) * patch_size
|
247 |
+
y_start = patch_row * patch_size
|
248 |
+
y_end = (patch_row + 1) * patch_size
|
249 |
+
|
250 |
+
# ----- Extract the patch from the satellite image -----
|
251 |
+
imo_patch = imo_np[y_start:y_end, x_start:x_end, :]
|
252 |
+
|
253 |
+
# ----- Create a highlighted version of the satellite image -----
|
254 |
+
# We will grey out the patch region.
|
255 |
+
imo_highlight = imo_np.copy()
|
256 |
+
|
257 |
+
# Define a grey value based on the image's data type.
|
258 |
+
if imo_highlight.dtype == np.uint8:
|
259 |
+
grey_value = 128 # for 0-255 images
|
260 |
+
else:
|
261 |
+
grey_value = 0.5 # for images normalized to [0, 1]
|
262 |
+
|
263 |
+
# Replace the patch region with the grey color.
|
264 |
+
imo_highlight[y_start:y_end, x_start:x_end, :] = grey_value
|
265 |
+
|
266 |
+
# ----- Plotting the three images side by side -----
|
267 |
+
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
268 |
+
fig.suptitle(f"Species: {species_text}")
|
269 |
+
|
270 |
+
# Left: Satellite image with the patch greyed out.
|
271 |
+
axes[0].imshow(imo_highlight)
|
272 |
+
axes[0].set_title("IMO with Patch Greyed Out")
|
273 |
+
axes[0].axis("off")
|
274 |
+
|
275 |
+
# Middle: Cropped patch from the satellite image.
|
276 |
+
axes[1].imshow(imo_patch)
|
277 |
+
axes[1].set_title("IMO Patch")
|
278 |
+
axes[1].axis("off")
|
279 |
+
|
280 |
+
# Right: Ground image.
|
281 |
+
axes[2].imshow(img_np)
|
282 |
+
axes[2].set_title("Ground Image")
|
283 |
+
axes[2].axis("off")
|
284 |
+
|
285 |
+
plt.tight_layout()
|
286 |
+
plt.show()
|
287 |
+
|
288 |
+
###########################################################
|
289 |
+
|
290 |
+
def get_search_ds_data(self, idx):
|
291 |
+
|
292 |
+
if self.sat_to_img_ids_json_is_train_dict:
|
293 |
+
print("Json is dict. Please reformat for target search!")
|
294 |
+
exit()
|
295 |
+
|
296 |
+
|
297 |
+
# sat_sample = self.sat_to_img_ids_json_path[idx] if self.sat_to_img_ids_json_is_train_dict else self.filtered_val_ds_by_tax[idx]
|
298 |
+
bounded_idx = idx % len(self.filtered_val_ds_by_tax)
|
299 |
+
if idx >= len(self.filtered_val_ds_by_tax):
|
300 |
+
print("??? bounded_idx: ", bounded_idx)
|
301 |
+
sat_sample = self.filtered_val_ds_by_tax[bounded_idx]
|
302 |
+
target_positions = sat_sample["target_positions"]
|
303 |
+
imo_path = os.path.join(self.imo_dir, sat_sample["sat_path"])
|
304 |
+
imo = self.imo_transform(Image.open(imo_path))
|
305 |
+
|
306 |
+
img_paths = []
|
307 |
+
imgs = []
|
308 |
+
species_texts = []
|
309 |
+
for img_id in sat_sample["img_ids"]:
|
310 |
+
img_path = os.path.join(self.img_dir, self.inat_json_dict["images"][img_id]["file_name"])
|
311 |
+
img = self.img_transform(Image.open(img_path))
|
312 |
+
img_paths.append(img_path)
|
313 |
+
imgs.append(img)
|
314 |
+
|
315 |
+
species_text = " ".join(self.inat_json_dict["images"][img_id]['file_name'].split("/")[1].split("_")[1:])
|
316 |
+
species_texts.append(species_text)
|
317 |
+
imgs = torch.stack(imgs) # Stack all images into a single tensor
|
318 |
+
|
319 |
+
if len(set(species_texts)) > 1:
|
320 |
+
print("Species mismatch in search dataset!")
|
321 |
+
exit()
|
322 |
+
else:
|
323 |
+
species_name = species_texts[0]
|
324 |
+
|
325 |
+
gt_mask_name = str(sat_sample["id"]) + "_" + sat_sample["taxonomy"] + ".png" # Saved with 2x '_' by accident
|
326 |
+
gt_mask_name = gt_mask_name.replace(" ", "_")
|
327 |
+
|
328 |
+
# Consider sound if valid
|
329 |
+
sounds = []
|
330 |
+
sound_ids = []
|
331 |
+
# if self.sound_data_path is not None and "sound_ids" in sat_sample:
|
332 |
+
# for sound_id in sat_sample["sound_ids"]:
|
333 |
+
# sound_path = os.path.join(self.sound_data_path,"sounds_mp3",str(sound_id)+"."+'mp3')
|
334 |
+
# sound = get_audio_clap(sound_path) # , sound_format)
|
335 |
+
# for k in sound.keys():
|
336 |
+
# sound[k] = sound[k].squeeze(0)
|
337 |
+
# sounds.append(sound)
|
338 |
+
# sound_ids.append(sound_id)
|
339 |
+
# sounds = torch.stack(sounds)
|
340 |
+
|
341 |
+
# NOTE: Cannot stack cos sound format is different (Use first index)
|
342 |
+
if self.sound_data_path is not None and "sound_ids" in sat_sample:
|
343 |
+
sound_id = sat_sample["sound_ids"][0]
|
344 |
+
sound_path = os.path.join(self.sound_data_path,"sounds_mp3",str(sound_id)+"."+'mp3')
|
345 |
+
sound = get_audio_clap(sound_path) # , sound_format)
|
346 |
+
# NOTE: no need to squeeze if not sampled by Pytorch Lightning
|
347 |
+
# for k in sound.keys():
|
348 |
+
# sound[k] = sound[k].squeeze(0)
|
349 |
+
sounds = [sound]
|
350 |
+
sound_ids = [sound_id]
|
351 |
+
|
352 |
+
###################################
|
353 |
+
|
354 |
+
########################################################################################################
|
355 |
+
# # TEMP OVERRIDE:
|
356 |
+
# imo_path = "/home/user/search-tta-demo/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg"
|
357 |
+
# imo = self.imo_transform(Image.open(imo_path).convert("RGB"))
|
358 |
+
# img_paths = ["/home/user/search-tta-demo/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg"]
|
359 |
+
# imgs = [self.img_transform(Image.open(img_paths[0]))]
|
360 |
+
# imgs = torch.stack(imgs)
|
361 |
+
# species_name = "Animalia Chordata Aves Charadriiformes Laridae Larus marinus"
|
362 |
+
# gt_mask_name = "Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus"
|
363 |
+
|
364 |
+
# sounds = []
|
365 |
+
# sound_ids = []
|
366 |
+
# sound_id = 89758229
|
367 |
+
# sound_path = "/home/user/search-tta-demo/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3"
|
368 |
+
# sound = get_audio_clap(sound_path) # , sound_format)
|
369 |
+
# # NOTE: no need to squeeze if not sampled by Pytorch Lightning
|
370 |
+
# # for k in sound.keys():
|
371 |
+
# # sound[k] = sound[k].squeeze(0)
|
372 |
+
# sounds = [sound]
|
373 |
+
# sound_ids = [sound_id]
|
374 |
+
|
375 |
+
# # 2020 and beyond
|
376 |
+
# target_positions = [
|
377 |
+
# # (84, 57), # 2015
|
378 |
+
# # (59, 323), # Shouldn't be in water
|
379 |
+
# (50, 347), # ADDED: Shouldn't be in water
|
380 |
+
# # (54, 332), # In water
|
381 |
+
# # (19, 507), # 2014
|
382 |
+
# (68, 444), # shifted down
|
383 |
+
# (136, 505),
|
384 |
+
# (142, 465),
|
385 |
+
# (345, 343),
|
386 |
+
# (130, 241),
|
387 |
+
# (367, 281), # Should be more right actually
|
388 |
+
# # (463, 351),
|
389 |
+
# # (513, 381), # Out of bound
|
390 |
+
# # (405, 162), # 2019
|
391 |
+
# (315, 80) # Shifted down-left
|
392 |
+
# ]
|
393 |
+
|
394 |
+
########################################################################################################
|
395 |
+
|
396 |
+
return img_paths, imo_path, imgs, imo, sounds, sound_ids, species_name, target_positions, gt_mask_name
|
Taxabind/Taxabind/SatBind/kmeans_clustering.py
ADDED
@@ -0,0 +1,620 @@
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import matplotlib.patches as mpatches
|
7 |
+
import matplotlib.colors as colors
|
8 |
+
from sklearn.cluster import KMeans
|
9 |
+
from sklearn.metrics import silhouette_score
|
10 |
+
from kneed import KneeLocator
|
11 |
+
import scipy.ndimage as ndimage
|
12 |
+
from collections import Counter
|
13 |
+
|
14 |
+
# import matplotlib
|
15 |
+
# matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
|
16 |
+
|
17 |
+
class CombinedSilhouetteInertiaClusterer:
|
18 |
+
"""
|
19 |
+
A class that:
|
20 |
+
1) Runs combined silhouette & WGSS (inertia) analysis to find the best k.
|
21 |
+
2) Performs a 2D reshape of patch labels.
|
22 |
+
3) Optionally smooths the cluster map.
|
23 |
+
4) Plots the final clustering (before & after smoothing).
|
24 |
+
5) Returns the *raw* (unsmoothed) KMeans labels as a 1D array.
|
25 |
+
"""
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
k_min=1,
|
29 |
+
k_max=8,
|
30 |
+
k_avg_max=4,
|
31 |
+
silhouette_threshold=0.15,
|
32 |
+
relative_threshold=0.15,
|
33 |
+
random_state=0,
|
34 |
+
min_patch_size=5,
|
35 |
+
n_smooth_iter=2,
|
36 |
+
ignore_label=-1,
|
37 |
+
plot=False,
|
38 |
+
gifs_dir = "./"
|
39 |
+
):
|
40 |
+
"""
|
41 |
+
Parameters
|
42 |
+
----------
|
43 |
+
k_min : int
|
44 |
+
Minimum number of clusters for KMeans.
|
45 |
+
k_max : int
|
46 |
+
Maximum number of clusters for KMeans.
|
47 |
+
k_avg_max : int
|
48 |
+
Upper bound on k after combining elbow & silhouette if they disagree.
|
49 |
+
silhouette_threshold : float
|
50 |
+
Minimum silhouette score at k=2 to justify splitting.
|
51 |
+
relative_threshold : float
|
52 |
+
Minimum % improvement in inertia from k=1→k=2 to justify splitting.
|
53 |
+
random_state : int
|
54 |
+
RNG seed for KMeans.
|
55 |
+
min_patch_size : int
|
56 |
+
Patches smaller than this threshold are smoothed.
|
57 |
+
n_smooth_iter : int
|
58 |
+
Number of smoothing iterations.
|
59 |
+
ignore_label : int
|
60 |
+
Label to ignore in smoothing step.
|
61 |
+
"""
|
62 |
+
self.k_min = k_min
|
63 |
+
self.k_max = k_max
|
64 |
+
self.k_avg_max = k_avg_max
|
65 |
+
self.silhouette_threshold = silhouette_threshold
|
66 |
+
self.relative_threshold = relative_threshold
|
67 |
+
self.random_state = random_state
|
68 |
+
|
69 |
+
self.min_patch_size = min_patch_size
|
70 |
+
self.n_smooth_iter = n_smooth_iter
|
71 |
+
self.ignore_label = ignore_label
|
72 |
+
self.plot = plot
|
73 |
+
self.gifs_dir = gifs_dir
|
74 |
+
|
75 |
+
self.final_k = None
|
76 |
+
self.final_labels_1d = None # 1D array of cluster labels (unsmoothed)
|
77 |
+
self.smoothed_labels_2d = None # 2D array of labels (smoothed)
|
78 |
+
self.kmeans_frame_files = []
|
79 |
+
|
80 |
+
##############################
|
81 |
+
# Helper functions
|
82 |
+
##############################
|
83 |
+
def find_largest_connected_component(self, cluster_map, label):
|
84 |
+
"""
|
85 |
+
Identifies the largest connected component (patch) of a given label in the cluster map.
|
86 |
+
Uses **4-connectivity** (N, S, E, W) to define connected regions.
|
87 |
+
|
88 |
+
Parameters:
|
89 |
+
-----------
|
90 |
+
cluster_map : ndarray (H, W)
|
91 |
+
The 2D label map.
|
92 |
+
label : int
|
93 |
+
The label whose largest connected patch we want to find.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
--------
|
97 |
+
largest_component_mask : ndarray (H, W)
|
98 |
+
A binary mask where `True` indicates pixels belonging to the largest connected component.
|
99 |
+
"""
|
100 |
+
mask = (cluster_map == label)
|
101 |
+
|
102 |
+
# Label all connected components using **4-connectivity**
|
103 |
+
labeled_components, num_components = ndimage.label(
|
104 |
+
mask,
|
105 |
+
structure=np.array([[0,1,0], [1,1,1], [0,1,0]])
|
106 |
+
)
|
107 |
+
|
108 |
+
if num_components == 0:
|
109 |
+
return np.zeros_like(cluster_map, dtype=bool) # No components found
|
110 |
+
|
111 |
+
# Compute sizes of each connected component
|
112 |
+
component_sizes = np.bincount(labeled_components.ravel())[1:] # Ignore background (0)
|
113 |
+
largest_component_idx = np.argmax(component_sizes) + 1 # +1 to skip background
|
114 |
+
|
115 |
+
# Create a mask for the largest component
|
116 |
+
largest_component_mask = (labeled_components == largest_component_idx)
|
117 |
+
return largest_component_mask
|
118 |
+
|
119 |
+
|
120 |
+
def smooth_small_patches(self, labels_2d, min_patch_size=5, n_iter=1, ignore_label=-1):
|
121 |
+
"""
|
122 |
+
Smooths **only small disconnected patches** (relative to the largest patch) within each cluster.
|
123 |
+
Uses **majority voting** in a 4-connected neighborhood (N, S, E, W).
|
124 |
+
Ignores neighboring pixels that have the `ignore_label` when computing the majority vote.
|
125 |
+
Implements a **tie-breaker**: If there is a tie between the original class and another class,
|
126 |
+
the original class is retained.
|
127 |
+
|
128 |
+
Parameters:
|
129 |
+
-----------
|
130 |
+
labels_2d : ndarray of shape (H, W)
|
131 |
+
Integer labels (e.g., cluster IDs), including an ignore label (-1).
|
132 |
+
min_patch_size : int
|
133 |
+
Patches smaller than this threshold are considered **isolated** and will be smoothed.
|
134 |
+
n_iter : int
|
135 |
+
Number of smoothing iterations.
|
136 |
+
ignore_label : int
|
137 |
+
Label value to be ignored when counting neighbors.
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
--------
|
141 |
+
smoothed_labels : ndarray of shape (H, W)
|
142 |
+
The label map after selective smoothing.
|
143 |
+
"""
|
144 |
+
H, W = labels_2d.shape
|
145 |
+
smoothed = labels_2d.copy()
|
146 |
+
|
147 |
+
for _ in range(n_iter):
|
148 |
+
new_label = smoothed.copy() # Clone where updates will be applied
|
149 |
+
|
150 |
+
unique_labels = np.unique(smoothed)
|
151 |
+
np.random.shuffle(unique_labels) # Process clusters in random order to avoid bias
|
152 |
+
|
153 |
+
for label in unique_labels:
|
154 |
+
if label == ignore_label:
|
155 |
+
continue
|
156 |
+
|
157 |
+
# Find the largest connected component for this label
|
158 |
+
largest_component_mask = self.find_largest_connected_component(smoothed, label)
|
159 |
+
|
160 |
+
# Identify small disconnected patches
|
161 |
+
small_patches_mask = (smoothed == label) & (~largest_component_mask)
|
162 |
+
|
163 |
+
# Label these small patches separately
|
164 |
+
labeled_small_patches, num_patches = ndimage.label(
|
165 |
+
small_patches_mask,
|
166 |
+
structure=np.array([[0,1,0], [1,1,1], [0,1,0]])
|
167 |
+
)
|
168 |
+
|
169 |
+
# Process each small patch individually
|
170 |
+
for patch_id in range(1, num_patches + 1):
|
171 |
+
patch_mask = (labeled_small_patches == patch_id)
|
172 |
+
patch_size = np.sum(patch_mask)
|
173 |
+
|
174 |
+
if patch_size < min_patch_size:
|
175 |
+
# Only smooth small isolated patches
|
176 |
+
for i, j in zip(*np.where(patch_mask)):
|
177 |
+
# Collect 4-connected neighborhood (N, S, E, W) **excluding ignore_label**
|
178 |
+
neighbors = []
|
179 |
+
if i > 0 and smoothed[i-1, j] != ignore_label: # North
|
180 |
+
neighbors.append(smoothed[i-1, j])
|
181 |
+
if i < H - 1 and smoothed[i+1, j] != ignore_label: # South
|
182 |
+
neighbors.append(smoothed[i+1, j])
|
183 |
+
if j > 0 and smoothed[i, j-1] != ignore_label: # West
|
184 |
+
neighbors.append(smoothed[i, j-1])
|
185 |
+
if j < W - 1 and smoothed[i, j+1] != ignore_label: # East
|
186 |
+
neighbors.append(smoothed[i, j+1])
|
187 |
+
|
188 |
+
if neighbors: # Avoid empty neighbor list
|
189 |
+
neighbor_counts = Counter(neighbors)
|
190 |
+
most_common_label, count = neighbor_counts.most_common(1)[0]
|
191 |
+
|
192 |
+
# Tie-breaker condition
|
193 |
+
tied_labels = [lbl for lbl, cnt in neighbor_counts.items() if cnt == count]
|
194 |
+
if len(tied_labels) > 1 and smoothed[i, j] in tied_labels:
|
195 |
+
# If there's a tie and the original class is involved, keep the original class
|
196 |
+
new_label[i, j] = smoothed[i, j]
|
197 |
+
else:
|
198 |
+
# Otherwise, assign the most common label
|
199 |
+
new_label[i, j] = most_common_label
|
200 |
+
|
201 |
+
smoothed = new_label.copy() # Apply modifications
|
202 |
+
|
203 |
+
return smoothed
|
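A tiny sketch of the smoothing behaviour on a toy label map (illustrative only, assuming the class above is in scope): a single isolated class-1 pixel is voted back to class 0, while the large connected class-1 block is preserved as the largest component.

import numpy as np

toy = np.zeros((5, 8), dtype=int)
toy[2, 1] = 1            # isolated class-1 pixel (smaller than min_patch_size)
toy[:, 5:] = 1           # large connected class-1 region

clusterer = CombinedSilhouetteInertiaClusterer(min_patch_size=3)
smoothed = clusterer.smooth_small_patches(toy, min_patch_size=3, n_iter=1)
assert smoothed[2, 1] == 0 and (smoothed[:, 5:] == 1).all()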
204 |
+
|
205 |
+
|
206 |
+
def combined_silhouette_inertia_clustering(
|
207 |
+
self,
|
208 |
+
X,
|
209 |
+
k_min=1,
|
210 |
+
k_max=8,
|
211 |
+
k_avg_max=4,
|
212 |
+
silhouette_threshold=0.2,
|
213 |
+
relative_threshold=0.05,
|
214 |
+
random_state=0
|
215 |
+
):
|
216 |
+
"""
|
217 |
+
Runs KMeans for k in [k_min..k_max] exactly once each,
|
218 |
+
collects silhouette scores & inertias, and returns:
|
219 |
+
- best_k
|
220 |
+
- labels for best_k
|
221 |
+
- silhouette_scores (list or None for k=1)
|
222 |
+
- inertias (list)
|
223 |
+
|
224 |
+
The final k is chosen by both silhouette and elbow (WGSS):
|
225 |
+
1) If #points < 2 -> return k=1 and all-zero labels.
|
226 |
+
2) Compare k=1 vs. k=2 improvement in inertia and silhouette(2).
|
227 |
+
3) If both pass threshold, run k=2..k_max, pick best_k_sil (highest silhouette)
|
228 |
+
and best_k_elbow (via KneeLocator). If they differ, take their mean (rounded).
|
229 |
+
"""
|
230 |
+
n_samples = len(X)
|
231 |
+
if n_samples < 2:
|
232 |
+
return 1, np.zeros(n_samples, dtype=int), [None], [None]
|
233 |
+
|
234 |
+
# --- Fit once for k=1 ---
|
235 |
+
km1 = KMeans(n_clusters=1, random_state=random_state).fit(X)
|
236 |
+
inertia_k1 = km1.inertia_ / n_samples
|
237 |
+
silhouette_k1 = None # undefined for k=1
|
238 |
+
# print(f"k=1, inertia={inertia_k1:.4f}, silhouette=NA")
|
239 |
+
|
240 |
+
if k_max < 2:
|
241 |
+
# If k_max=1, no reason to check further
|
242 |
+
return 1, km1.labels_, [silhouette_k1], [inertia_k1]
|
243 |
+
|
244 |
+
# --- Fit once for k=2 ---
|
245 |
+
km2 = KMeans(n_clusters=2, random_state=random_state).fit(X)
|
246 |
+
inertia_k2 = km2.inertia_ / n_samples
|
247 |
+
sil_k2 = silhouette_score(X, km2.labels_)
|
248 |
+
# print(f"k=2, inertia={inertia_k2:.4f}, silhouette={sil_k2:.4f}")
|
249 |
+
|
250 |
+
relative_improvement = (inertia_k1 - inertia_k2) / inertia_k1
|
251 |
+
# print(f"[k=1 → k=2] inertia improvement: {relative_improvement:.4%}, sil(k=2)={sil_k2:.4f}")
|
252 |
+
|
253 |
+
# If improvement is too small or silhouette is too low => remain at k=1
|
254 |
+
if (relative_improvement < relative_threshold) or (sil_k2 < silhouette_threshold):
|
255 |
+
print("Choosing k=1 (failed thresholds).")
|
256 |
+
return 1, km1.labels_, [silhouette_k1, sil_k2], [inertia_k1, inertia_k2]
|
257 |
+
|
258 |
+
# --- Otherwise fit k=2..k_max and gather inertias & silhouettes ---
|
259 |
+
all_k = range(2, k_max + 1)
|
260 |
+
kmeans_models = {}
|
261 |
+
inertias = []
|
262 |
+
silhouettes = []
|
263 |
+
|
264 |
+
# We already have k=2
|
265 |
+
kmeans_models[2] = km2
|
266 |
+
inertias.append(inertia_k2)
|
267 |
+
silhouettes.append(sil_k2)
|
268 |
+
|
269 |
+
for k in range(3, k_max + 1):
|
270 |
+
km = KMeans(n_clusters=k, random_state=random_state).fit(X)
|
271 |
+
kmeans_models[k] = km
|
272 |
+
|
273 |
+
norm_inertia = km.inertia_ / n_samples
|
274 |
+
inertias.append(norm_inertia)
|
275 |
+
|
276 |
+
# If k>n_samples, silhouette_score is meaningless, but in normal usage k<<n_samples
|
277 |
+
sil_val = silhouette_score(X, km.labels_) if k <= n_samples else -1
|
278 |
+
silhouettes.append(sil_val)
|
279 |
+
|
280 |
+
# print(f"k={k}, inertia={norm_inertia:.4f}, silhouette={sil_val:.4f}")
|
281 |
+
|
282 |
+
# (a) Silhouette-based best_k_sil
|
283 |
+
best_idx_sil = np.argmax(silhouettes)
|
284 |
+
best_k_sil = best_idx_sil + 2 # offset since silhouettes[0] is k=2
|
285 |
+
|
286 |
+
# (b) Inertia-based best_k_elbow
|
287 |
+
k_candidates = np.arange(2, k_max + 1)
|
288 |
+
if len(k_candidates) == 1:
|
289 |
+
best_k_elbow = 2
|
290 |
+
else:
|
291 |
+
kn = KneeLocator(k_candidates, inertias, curve="convex", direction="decreasing")
|
292 |
+
best_k_elbow = kn.elbow
|
293 |
+
if best_k_elbow is None:
|
294 |
+
print("No elbow found => default to k=1.")
|
295 |
+
best_k_elbow = 1 # fallback
|
296 |
+
|
297 |
+
print(f"Silhouette-based best_k={best_k_sil}, elbow-based best_k={best_k_elbow}")
|
298 |
+
|
299 |
+
# Combine if there's disagreement
|
300 |
+
if best_k_sil == best_k_elbow:
|
301 |
+
final_k = max(1, min(best_k_sil, k_avg_max)) # best_k_sil
|
302 |
+
else:
|
303 |
+
avg_k = 0.5 * (best_k_sil + best_k_elbow)
|
304 |
+
final_k = int(math.ceil(avg_k))
|
305 |
+
final_k = max(1, min(final_k, k_avg_max))
|
306 |
+
# print(f"Disagreement => taking final_k=ceil((sil={best_k_sil} + elbow={best_k_elbow})/2) = {final_k}, capped by k_avg_max={k_avg_max}")
|
307 |
+
assert (final_k <= k_avg_max), f"Final k={final_k} is greater than k_avg_max={k_avg_max}"
|
308 |
+
|
309 |
+
# Get final labels from the chosen KMeans model
|
310 |
+
if final_k == 1:
|
311 |
+
final_labels = km1.labels_
|
312 |
+
else:
|
313 |
+
final_labels = kmeans_models[final_k].labels_
|
314 |
+
|
315 |
+
return final_k, final_labels, [silhouette_k1] + silhouettes, [inertia_k1] + inertias
|
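A minimal usage sketch on synthetic data (not from the repo; assumes this class and scikit-learn's make_blobs are importable), showing the combined silhouette/elbow selection on three well-separated blobs:

from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=3, cluster_std=0.5, random_state=0)
clusterer = CombinedSilhouetteInertiaClusterer(k_max=8, k_avg_max=4)
best_k, labels, sils, inertias = clusterer.combined_silhouette_inertia_clustering(
    X, k_min=1, k_max=8, k_avg_max=4,
    silhouette_threshold=0.15, relative_threshold=0.15, random_state=0,
)
print(best_k)  # expected to settle on 3 for clearly separated blobs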
316 |
+
|
317 |
+
|
318 |
+
def compute_region_statistics(self, label_map, heatmap, visited_indices, episode_num=0, step_num=0):
|
319 |
+
"""
|
320 |
+
Computes region statistics for the current smoothed label map.
|
321 |
+
|
322 |
+
Parameters
|
323 |
+
----------
|
324 |
+
heatmap : ndarray, shape (H, W)
|
325 |
+
A 2D array of probabilities (or intensities) for each patch.
|
326 |
+
Must match the shape of `self.smoothed_labels_2d`.
|
327 |
+
visited_indices : list or array of int
|
328 |
+
A list of flat indices (0-based) indicating which patches
|
329 |
+
the robot has visited. The flat index is equivalent to
|
330 |
+
r * W + c for patch at row r, column c.
|
331 |
+
|
332 |
+
Returns
|
333 |
+
-------
|
334 |
+
region_dict : dict
|
335 |
+
A dictionary keyed by cluster label ID, with values:
|
336 |
+
{
|
337 |
+
'num_patches': int,
|
338 |
+
'patches_visited': int,
|
339 |
+
'expectation': float
|
340 |
+
}
|
341 |
+
"""
|
342 |
+
# Flatten the cluster map and the heatmap to handle indexing uniformly
|
343 |
+
# label_map_2d = label_map # self.smoothed_labels_2d
|
344 |
+
# H, W = label_map_2d.shape
|
345 |
+
label_map_2d = self.smoothed_labels_2d
|
346 |
+
label_map_1d = self.smoothed_labels_2d.ravel()
|
347 |
+
heatmap_1d = heatmap.ravel()
|
348 |
+
|
349 |
+
# Identify unique labels (excluding ignore_label if present)
|
350 |
+
unique_labels = np.unique(label_map_1d)
|
351 |
+
region_dict = {}
|
352 |
+
for lbl in unique_labels:
|
353 |
+
if lbl == self.ignore_label: # skip ignore_label
|
354 |
+
continue
|
355 |
+
region_dict[lbl] = {
|
356 |
+
'num_patches': 0,
|
357 |
+
'patches_visited': 0,
|
358 |
+
'expectation': 0.0
|
359 |
+
}
|
360 |
+
|
361 |
+
# Accumulate totals for all patches
|
362 |
+
total_patches = len(label_map_1d)
|
363 |
+
for i in range(total_patches):
|
364 |
+
lbl = label_map_1d[i]
|
365 |
+
if lbl == self.ignore_label:
|
366 |
+
continue
|
367 |
+
region_dict[lbl]['num_patches'] += 1
|
368 |
+
region_dict[lbl]['expectation'] += float(heatmap_1d[i])
|
369 |
+
|
370 |
+
# # Exponential distribution (waiting time) = num_patches / expected_num_tgts
|
371 |
+
for lbl in region_dict:
|
372 |
+
region_dict[lbl]['expectation'] = region_dict[lbl]['num_patches'] / region_dict[lbl]['expectation']
|
373 |
+
|
374 |
+
# Count only unique visited patches by converting to a set.
|
375 |
+
unique_visited = set(visited_indices)
|
376 |
+
for vi in unique_visited:
|
377 |
+
if vi < 0 or vi >= total_patches:
|
378 |
+
continue # skip out-of-bounds indices
|
379 |
+
lbl = label_map_1d[vi]
|
380 |
+
# lbl = self.get_label_id(vi)
|
381 |
+
if lbl == self.ignore_label:
|
382 |
+
continue
|
383 |
+
region_dict[lbl]['patches_visited'] += 1
|
384 |
+
|
385 |
+
if self.plot:
|
386 |
+
self.plot_cluster_map(label_map_2d, heatmap, visited_indices, region_dict, episode_num, step_num)
|
387 |
+
|
388 |
+
return region_dict
|
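A toy sketch of the region statistics (hypothetical 2x3 patch grid; the label map is set directly here instead of via fit_predict): 'expectation' ends up as num_patches / sum(heatmap) per region, i.e. a waiting-time style estimate of patches per expected target.

import numpy as np

clusterer = CombinedSilhouetteInertiaClusterer(plot=False)
clusterer.smoothed_labels_2d = np.array([[0, 0, 1],
                                         [0, 1, 1]])
heatmap = np.array([[0.1, 0.1, 0.4],
                    [0.2, 0.2, 0.2]])
visited = [0, 1]  # flat indices of visited patches (row 0, cols 0 and 1)
stats = clusterer.compute_region_statistics(clusterer.smoothed_labels_2d, heatmap, visited)
# Region 0: 3 patches, prob sum 0.4 -> expectation 3 / 0.4 = 7.5, 2 visited
# Region 1: 3 patches, prob sum 0.8 -> expectation 3 / 0.8 = 3.75, 0 visited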
389 |
+
|
390 |
+
def plot_cluster_map(self, cluster_map, heatmap, path_taken, region_stats_dict, episode_num, step_num, cmap='tab20'):
|
391 |
+
|
392 |
+
# 4) Plot (side-by-side) if requested
|
393 |
+
fig, axes = plt.subplots(1, 3, figsize=(12, 6))
|
394 |
+
|
395 |
+
axes[0].imshow(cluster_map, cmap='tab20')
|
396 |
+
axes[0].set_title(f"Raw KMeans Clusters")
|
397 |
+
axes[0].axis('off')
|
398 |
+
|
399 |
+
axes[1].imshow(heatmap, cmap="viridis")
|
400 |
+
axes[1].set_title("Heatmap")
|
401 |
+
axes[1].axis('off')
|
402 |
+
|
403 |
+
axes[2].imshow(cluster_map, cmap='tab20')
|
404 |
+
axes[2].set_title("Raw KMeans Clusters")
|
405 |
+
axes[2].axis('off')
|
406 |
+
|
407 |
+
path_rows, path_cols = [], []
|
408 |
+
for i, idx in enumerate(path_taken):
|
409 |
+
rr = idx // cluster_map.shape[1]
|
410 |
+
cc = idx % cluster_map.shape[1]
|
411 |
+
path_rows.append(rr)
|
412 |
+
path_cols.append(cc)
|
413 |
+
axes[2].plot(path_cols, path_rows, c="r", linewidth=2)
|
414 |
+
axes[2].plot(path_cols[-1], path_rows[-1], markersize=12, zorder=99, marker="^", ls="-", c="r", mec="black")
|
415 |
+
axes[2].plot(path_cols[0], path_rows[0], 'co', c="r", markersize=8, zorder=5)
|
416 |
+
|
417 |
+
# Create legend patches for each region.
|
418 |
+
# We assume region labels are non-negative integers.
|
419 |
+
unique_labels = sorted(region_stats_dict.keys())
|
420 |
+
# For normalization, use max label value (or 1 if max==0)
|
421 |
+
max_label = max(unique_labels) if unique_labels else 1
|
422 |
+
cm = plt.get_cmap(cmap)
|
423 |
+
legend_patches = []
|
424 |
+
for lbl in unique_labels:
|
425 |
+
# Normalize the label to [0,1] for colormap lookup.
|
426 |
+
norm_value = lbl / max_label if max_label > 0 else 0.5
|
427 |
+
color = cm(norm_value)
|
428 |
+
patch = mpatches.Patch(color=color, label=f"R{lbl}")
|
429 |
+
legend_patches.append(patch)
|
430 |
+
|
431 |
+
# Add legends to both subplots.
|
432 |
+
axes[0].legend(handles=legend_patches, title="Regions", loc='upper right')
|
433 |
+
axes[2].legend(handles=legend_patches, title="Regions", loc='upper right')
|
434 |
+
|
435 |
+
# Build the legend text for each region using the provided format:
|
436 |
+
# "R{label}: patches={num_patches}, E={expectation:.3f}, visited={num_visited}"
|
437 |
+
legend_lines = []
|
438 |
+
for label, stats in region_stats_dict.items():
|
439 |
+
line = f"R{label}: patches={stats['num_patches']}, E={stats['expectation']:.3f}, visited={stats['patches_visited']}"
|
440 |
+
legend_lines.append(line)
|
441 |
+
legend_text = "\n".join(legend_lines)
|
442 |
+
|
443 |
+
# Add the legend text as a subtitle at the bottom of the figure
|
444 |
+
# Using fig.text to place the text at the center bottom with a background box
|
445 |
+
fig.text(0.5, 0.05, legend_text, ha='center', va='bottom', fontsize=10,
|
446 |
+
bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'))
|
447 |
+
|
448 |
+
# Adjust layout to reserve space for the legend text
|
449 |
+
plt.tight_layout()
|
450 |
+
plt.subplots_adjust(bottom=0.1) # Prevent overlap
|
451 |
+
# plt.tight_layout(rect=[0, 0.05, 1, 1])
|
452 |
+
# plt.show()
|
453 |
+
|
454 |
+
|
455 |
+
# gifs_folder = '/home/user/VLM-Search/inference/test_results/gifs'
|
456 |
+
if not os.path.exists(self.gifs_dir):
|
457 |
+
os.makedirs(self.gifs_dir)
|
458 |
+
|
459 |
+
plt.savefig(f'{self.gifs_dir}/kmeans_{episode_num}_{step_num}.png', dpi=150)
|
460 |
+
self.kmeans_frame_files.append(f'{self.gifs_dir}/kmeans_{episode_num}_{step_num}.png')
|
461 |
+
plt.close()
|
462 |
+
|
463 |
+
|
464 |
+
# ------ OTHER HELPER FUNCTIONS -------#
|
465 |
+
def get_label_id(self, patch_idx):
|
466 |
+
"""
|
467 |
+
Given a flattened index (row-major order) within the 24×24 grid,
|
468 |
+
return the cluster label ID of the corresponding merged region.
|
469 |
+
|
470 |
+
:param patch_idx: Flattened index (0 to H×W-1).
|
471 |
+
:return: The cluster label ID at that patch of the smoothed label
|
472 |
+
map (which may be the ignore label for unlabeled patches).
|
473 |
+
"""
|
474 |
+
# Guard: check we have a merged label map
|
475 |
+
if self.smoothed_labels_2d is None:
|
476 |
+
raise ValueError("Merged markers do not exist. "
|
477 |
+
"Please call `incorporate_kmeans_boundaries(...)` first.")
|
478 |
+
|
479 |
+
return self.smoothed_labels_2d.ravel()[patch_idx]
|
480 |
+
|
481 |
+
|
482 |
+
# ------ OTHER HELPER FUNCTIONS -------#
|
483 |
+
def get_probs(self, patch_idx, heatmap):
|
484 |
+
"""
|
485 |
+
Given a flattened index (row-major order) within the 24×24 grid,
|
486 |
+
return the probability assigned to that patch in the given heatmap.
|
487 |
+
|
488 |
+
:param patch_idx: Flattened index (0 to H×W-1).
|
489 |
+
:return: The region-average probability, or 0.0 if the patch
|
490 |
+
is out of bounds or not in a labeled region.
|
491 |
+
"""
|
492 |
+
|
493 |
+
return heatmap.ravel()[patch_idx]
|
494 |
+
|
495 |
+
##############################
|
496 |
+
# Main functions
|
497 |
+
##############################
|
498 |
+
def fit_predict(self, patch_embeds, map_shape):
|
499 |
+
"""
|
500 |
+
Main entry point.
|
501 |
+
|
502 |
+
Parameters
|
503 |
+
----------
|
504 |
+
patch_embeds : ndarray, shape (N, D)
|
505 |
+
The satellite patch embeddings to be clustered.
|
506 |
+
map_shape : tuple of (H, W)
|
507 |
+
The 2D layout of patches. Must satisfy H * W = N.
|
508 |
+
(Note: plotting of the clustering maps is controlled by the constructor's
|
509 |
+
`plot` flag; fit_predict itself takes no plotting argument.)
|
510 |
+
|
511 |
+
Returns
|
512 |
+
-------
|
513 |
+
labels_2d : ndarray, shape (H, W)
|
514 |
+
The final cluster assignments reshaped to the patch grid (smoothing is currently disabled, so these match the raw KMeans labels).
|
515 |
+
"""
|
516 |
+
# 1) Run combined silhouette & inertia
|
517 |
+
best_k, final_labels, silhouettes, inertias = self.combined_silhouette_inertia_clustering(
|
518 |
+
X=patch_embeds,
|
519 |
+
k_min=self.k_min,
|
520 |
+
k_max=self.k_max,
|
521 |
+
k_avg_max=self.k_avg_max,
|
522 |
+
silhouette_threshold=self.silhouette_threshold,
|
523 |
+
relative_threshold=self.relative_threshold,
|
524 |
+
random_state=self.random_state
|
525 |
+
)
|
526 |
+
self.final_k = best_k
|
527 |
+
self.final_labels_1d = final_labels.copy() # store raw labels
|
528 |
+
|
529 |
+
# 2) Reshape for display
|
530 |
+
H, W = map_shape
|
531 |
+
cluster_map = final_labels.reshape(H, W)
|
532 |
+
|
533 |
+
# 3) Apply smoothing
|
534 |
+
# cluster_map_smoothed = self.smooth_small_patches(
|
535 |
+
# labels_2d=cluster_map,
|
536 |
+
# min_patch_size=self.min_patch_size,
|
537 |
+
# n_iter=self.n_smooth_iter,
|
538 |
+
# ignore_label=self.ignore_label
|
539 |
+
# )
|
540 |
+
cluster_map_smoothed = cluster_map
|
541 |
+
self.smoothed_labels_2d = cluster_map_smoothed.copy()
|
542 |
+
|
543 |
+
# 5) Return the smoothed 2D labels
|
544 |
+
return self.smoothed_labels_2d
|
545 |
+
|
546 |
+
|
547 |
+
if __name__ == "__main__":
|
548 |
+
|
549 |
+
file_path="./expt/patch_embeds_99.npy"
|
550 |
+
patch_embeds = np.load(file_path)
|
551 |
+
map_shape = (int(np.sqrt(patch_embeds.shape[0])), int(np.sqrt(patch_embeds.shape[0])))
|
552 |
+
|
553 |
+
# 1) Load a 24x24 heatmap
|
554 |
+
file_path="./expt/seg_mask_step0_v1.npy"
|
555 |
+
heatmap = np.load(file_path)
|
556 |
+
heatmap = np.clip(heatmap,0,100)
|
557 |
+
|
558 |
+
# Normalize heatmap
|
559 |
+
num_targets = 19
|
560 |
+
scale = num_targets / (heatmap.sum())
|
561 |
+
heatmap *= scale
|
562 |
+
|
563 |
+
# Instantiate our clusterer
|
564 |
+
clusterer = CombinedSilhouetteInertiaClusterer(
|
565 |
+
k_min=1,
|
566 |
+
k_max=8,
|
567 |
+
k_avg_max=4,
|
568 |
+
silhouette_threshold=0.15,
|
569 |
+
relative_threshold=0.15,
|
570 |
+
random_state=0,
|
571 |
+
min_patch_size=5, # smoothing parameter
|
572 |
+
n_smooth_iter=2, # smoothing parameter
|
573 |
+
ignore_label=-1
|
574 |
+
)
|
575 |
+
|
576 |
+
# Fit & predict (this will also plot the clusters before & after smoothing)
|
577 |
+
smoothed_labels_2d = clusterer.fit_predict(
|
578 |
+
patch_embeds=patch_embeds,
|
579 |
+
map_shape=map_shape,
|
580 |
+
)
|
581 |
+
|
582 |
+
# 2) Simulate a random walk
|
583 |
+
max_r, max_c=heatmap.shape[0]-1, heatmap.shape[1]-1
|
584 |
+
steps=100
|
585 |
+
current_pos=(0,0)
|
586 |
+
moves_8=[(-1,-1),(-1,0),(-1,1),
|
587 |
+
( 0,-1), ( 0,1),
|
588 |
+
( 1,-1),( 1,0), ( 1,1)]
|
589 |
+
path_coords=[current_pos]
|
590 |
+
# num_found_list=[0]
|
591 |
+
for _ in range(steps):
|
592 |
+
r,c=current_pos
|
593 |
+
dr,dc=random.choice(moves_8)
|
594 |
+
nr, nc=r+dr,c+dc
|
595 |
+
if 0<=nr<=max_r and 0<=nc<=max_c:
|
596 |
+
current_pos=(nr,nc)
|
597 |
+
path_coords.append(current_pos)
|
598 |
+
# 10% chance for new target
|
599 |
+
# num_found_list.append(random.random()<0.1)
|
600 |
+
# num_found_list.append(0)
|
601 |
+
# num_found_list[0]=1
|
602 |
+
# num_found_list[1]=2
|
603 |
+
# num_found_list[2]=3
|
604 |
+
|
605 |
+
# Flatten
|
606 |
+
width=heatmap.shape[1]
|
607 |
+
flattened_visits=[]
|
608 |
+
for (rr,cc) in path_coords:
|
609 |
+
idx=rr*width+cc
|
610 |
+
flattened_visits.append(idx)
|
611 |
+
|
612 |
+
|
613 |
+
# Input heatmap (for region statistics)
|
614 |
+
region_dict = clusterer.compute_region_statistics(smoothed_labels_2d, heatmap, flattened_visits)  # plotting is controlled by the constructor's plot flag
|
615 |
+
|
616 |
+
# # save the smoother labels as .npy
|
617 |
+
# np.save('smoothed_labels_2d.npy', smoothed_labels_2d)
|
618 |
+
|
619 |
+
# print("Chosen k:", clusterer.final_k)
|
620 |
+
# print("Smoothed labels shape:", smoothed_labels_2d.shape)
|
Taxabind/Taxabind/SatBind/model.py
ADDED
@@ -0,0 +1,226 @@
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
# Ensure the correct directory is at the front of sys.path
|
4 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
5 |
+
|
6 |
+
import open_clip
|
7 |
+
import pytorch_lightning as pl
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import numpy as np
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from transformers import CLIPVisionModelWithProjection
|
13 |
+
from dataset import SatNatDataset
|
14 |
+
|
15 |
+
from pytorch_lightning.loggers import WandbLogger
|
16 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
17 |
+
from config import config
|
18 |
+
|
19 |
+
######################################
|
20 |
+
class UpdateDatasetEpochCallback(pl.Callback):
|
21 |
+
def on_train_epoch_start(self, trainer, pl_module):
|
22 |
+
"""Called at the start of each training epoch."""
|
23 |
+
print("!!!trainer.current_epoch: ", trainer.current_epoch)
|
24 |
+
current_epoch = trainer.current_epoch
|
25 |
+
|
26 |
+
# Access the dataset and set the epoch
|
27 |
+
train_loader = trainer.train_dataloader
|
28 |
+
# trainer.train_dataloader can be a list in some configs, handle that if needed
|
29 |
+
dataset = train_loader.dataset
|
30 |
+
|
31 |
+
# Now set the epoch:
|
32 |
+
if hasattr(dataset, 'set_epoch'):
|
33 |
+
dataset.set_epoch(current_epoch)
|
34 |
+
######################################
|
35 |
+
|
36 |
+
|
37 |
+
def create_pairwise_mask(labels):
|
38 |
+
labels = labels.reshape(-1)
|
39 |
+
num_samples = len(labels)
|
40 |
+
pairwise_mask = torch.zeros(num_samples, num_samples).to(labels.device)
|
41 |
+
|
42 |
+
for i in range(num_samples):
|
43 |
+
pairwise_mask[i, :] = (labels == labels[i])
|
44 |
+
|
45 |
+
return pairwise_mask
|
46 |
+
|
47 |
+
def clip_loss(similarity: torch.Tensor, label) -> torch.Tensor:
|
48 |
+
overhead_img_loss = contrastive_loss(similarity, label)
|
49 |
+
ground_img_loss = contrastive_loss(similarity.t(), label.t())
|
50 |
+
return 0.5*torch.mean(torch.sum(overhead_img_loss, dim=-1)) + 0.5*torch.mean(torch.sum(ground_img_loss, dim=-1))
|
51 |
+
|
52 |
+
def contrastive_loss(logits: torch.Tensor, label) -> torch.Tensor:
|
53 |
+
gt = create_pairwise_mask(label)
|
54 |
+
return -gt*torch.log(logits.softmax(-1)+1e-6)
|
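A quick sketch with toy tensors (assumes the helpers above are in scope): the pairwise mask treats every pair of samples with the same species label as a positive, instead of only the diagonal as in vanilla CLIP.

import torch

labels = torch.tensor([7, 7, 3])   # first two samples share a species
mask = create_pairwise_mask(labels)
# tensor([[1., 1., 0.],
#         [1., 1., 0.],
#         [0., 0., 1.]])

sim = torch.randn(3, 3)            # stand-in for img_embeds @ imo_embeds.t() * logit_scale
loss = clip_loss(sim, labels)      # scalar, differentiable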
55 |
+
|
56 |
+
class SatBind(pl.LightningModule):
|
57 |
+
def __init__(self, train_dataset, val_dataset, **kwargs):
|
58 |
+
super().__init__()
|
59 |
+
self.train_dataset = train_dataset
|
60 |
+
self.val_dataset = val_dataset
|
61 |
+
|
62 |
+
#initialize bio CLIP with frozen weights
|
63 |
+
self.bio_model, *_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
|
64 |
+
if config.locked_tuning:
|
65 |
+
for param in self.bio_model.parameters():
|
66 |
+
param.requires_grad = False
|
67 |
+
|
68 |
+
#initialize CLIP with trainable weights
|
69 |
+
self.imo_encoder = CLIPVisionModelWithProjection.from_pretrained(config.sat_encoder).train()
|
70 |
+
for layer in self.imo_encoder.children():
|
71 |
+
if hasattr(layer, 'reset_parameters'):
|
72 |
+
layer.reset_parameters()
|
73 |
+
|
74 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
75 |
+
self.batch_size = kwargs.get('batch_size', config.batch_size)
|
76 |
+
self.lr = kwargs.get('lr', config.lr)
|
77 |
+
|
78 |
+
# # ADDED
|
79 |
+
clip_cfg = self.imo_encoder.config
|
80 |
+
self.visual_projection_custom = nn.Linear(clip_cfg.hidden_size, 512, bias=False) # clip_cfg.projection_dim)
|
81 |
+
# print("clip_cfg.hidden_size: ", clip_cfg.hidden_size)
|
82 |
+
# print("clip_cfg.projection_dim: ", clip_cfg.projection_dim)
|
83 |
+
|
84 |
+
# # Since CLIP's vision projection is not used (prevent pytorch lightning issues)
|
85 |
+
# for param in self.imo_encoder.visual_projection.parameters():
|
86 |
+
# param.requires_grad = False
|
87 |
+
# for param in self.imo_encoder.transformer.parameters(): # If the transformer is unused
|
88 |
+
# param.requires_grad = False
|
89 |
+
|
90 |
+
def forward(self, batch):
|
91 |
+
img, imo, label, patch_idx, *_ = batch
|
92 |
+
batch_size = img.shape[0]
|
93 |
+
|
94 |
+
#compute bioclip embeddings
|
95 |
+
img_embeds, *_ = self.bio_model(img) # (batch_size, proj_dim)
|
96 |
+
|
97 |
+
# ## ORIGINAL: compute overhead embeddings
|
98 |
+
# imo_embeds = self.imo_encoder(imo).image_embeds # (batch, proj_dim)
|
99 |
+
# print("[ORIG] imo_embeds.shape: ", imo_embeds.shape)
|
100 |
+
# print("[ORIG] torch.min(imo_embeds): ", torch.min(imo_embeds))
|
101 |
+
# print("[ORIG] torch.max(imo_embeds): ", torch.max(imo_embeds))
|
102 |
+
|
103 |
+
# NEW:
|
104 |
+
imo_embeds = self.imo_encoder(imo).last_hidden_state # (batch, Patches, hidden_dim)
|
105 |
+
imo_embeds = imo_embeds[torch.arange(batch_size), patch_idx] # (batch, hidden_dim)
|
106 |
+
# imo_embeds = self.imo_encoder.vision_model.post_layernorm(imo_embeds) # (batch, hidden_dim)
|
107 |
+
# imo_embeds = self.imo_encoder.visual_projection(imo_embeds) # (batch_size, proj_dim)
|
108 |
+
imo_embeds = self.visual_projection_custom(imo_embeds) # (batch_size, proj_dim)
|
109 |
+
|
110 |
+
return img_embeds, imo_embeds, label
|
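A sketch of the patch-selection step above with dummy tensors (shapes taken from the ViT-L/14-336 encoder configured in config.py): from the (batch, 577, hidden) token sequence, one patch token is picked per sample using the dataset's patch_idx, which already carries the +1 [CLS] offset.

import torch

batch, seq_len, hidden = 4, 577, 1024          # 24*24 patches + [CLS]; hidden dim of ViT-L/14
last_hidden_state = torch.randn(batch, seq_len, hidden)
patch_idx = torch.tensor([1, 42, 300, 576])    # hypothetical per-sample patch indices

selected = last_hidden_state[torch.arange(batch), patch_idx]
print(selected.shape)  # torch.Size([4, 1024])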
111 |
+
|
112 |
+
|
113 |
+
def shared_step(self, batch, return_sim_matrix=False):
|
114 |
+
|
115 |
+
img_embeds, imo_embeds, label, *_ = self(batch)
|
116 |
+
#normalize embeddings
|
117 |
+
#img embeds is already normalized
|
118 |
+
img_embeds = img_embeds
|
119 |
+
imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1)
|
120 |
+
# print("[NORM] torch.min(imo_embeds): ", torch.min(imo_embeds))
|
121 |
+
# print("[NORM] torch.max(imo_embeds): ", torch.max(imo_embeds))
|
122 |
+
|
123 |
+
# exponentiate the log of the temperature
|
124 |
+
logit_scale = self.logit_scale.exp()
|
125 |
+
|
126 |
+
#compute similarity
|
127 |
+
img_to_imo_sim = img_embeds @ imo_embeds.t() * logit_scale
|
128 |
+
|
129 |
+
if return_sim_matrix:
|
130 |
+
img_to_imo_sim_copy = img_to_imo_sim.clone().detach()
|
131 |
+
|
132 |
+
loss = clip_loss(img_to_imo_sim, label)
|
133 |
+
|
134 |
+
if return_sim_matrix:
|
135 |
+
return loss, img_to_imo_sim_copy
|
136 |
+
else:
|
137 |
+
return loss
|
138 |
+
|
139 |
+
|
140 |
+
def training_step(self, batch, batch_idx):
|
141 |
+
loss = self.shared_step(batch)
|
142 |
+
self.log('train_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
|
143 |
+
self.log('temperature', self.logit_scale.data, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
|
144 |
+
return loss
|
145 |
+
|
146 |
+
def validation_step(self, batch, batch_idx):
|
147 |
+
loss = self.shared_step(batch)
|
148 |
+
self.log('val_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
|
149 |
+
return loss
|
150 |
+
|
151 |
+
def train_dataloader(self):
|
152 |
+
return DataLoader(self.train_dataset,
|
153 |
+
batch_size=self.batch_size,
|
154 |
+
num_workers=config.num_workers,
|
155 |
+
shuffle=True, # True
|
156 |
+
persistent_workers=False)
|
157 |
+
|
158 |
+
def val_dataloader(self):
|
159 |
+
return DataLoader(self.val_dataset,
|
160 |
+
batch_size=self.batch_size,
|
161 |
+
num_workers=config.num_workers,
|
162 |
+
shuffle=False,
|
163 |
+
persistent_workers=False)
|
164 |
+
|
165 |
+
def configure_optimizers(self):
|
166 |
+
params = self.parameters()
|
167 |
+
self.optim = torch.optim.AdamW(params,
|
168 |
+
lr=self.lr,
|
169 |
+
betas=(0.9,0.98),
|
170 |
+
eps=1e-6
|
171 |
+
)
|
172 |
+
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
173 |
+
optimizer=self.optim,
|
174 |
+
T_0=20,
|
175 |
+
eta_min=1e-6
|
176 |
+
)
|
177 |
+
return [self.optim], [self.scheduler]
|
178 |
+
|
179 |
+
|
180 |
+
if __name__ == '__main__':
|
181 |
+
img_dir = config.img_dir
|
182 |
+
imo_dir = config.imo_dir
|
183 |
+
imo_dir_val = config.imo_dir_val
|
184 |
+
train_json_path = config.train_json_path
|
185 |
+
val_json_path = config.val_json_path
|
186 |
+
# filtered_train_json_path = config.filtered_train_json_path
|
187 |
+
# filtered_val_json_path = config.filtered_val_json_path
|
188 |
+
sat_to_img_ids_train_json_path = config.sat_to_img_ids_train_json_path
|
189 |
+
sat_to_img_ids_val_json_path = config.sat_to_img_ids_val_json_path
|
190 |
+
patch_size = config.patch_size
|
191 |
+
|
192 |
+
#define dataset
|
193 |
+
train_dataset = SatNatDataset(img_dir, imo_dir, train_json_path, sat_to_img_ids_train_json_path, patch_size)
|
194 |
+
val_dataset = SatNatDataset(img_dir, imo_dir_val, val_json_path, sat_to_img_ids_val_json_path, patch_size, mode='val')
|
195 |
+
|
196 |
+
#define model
|
197 |
+
model = SatBind(train_dataset=train_dataset, val_dataset=val_dataset)
|
198 |
+
torch.cuda.empty_cache()
|
199 |
+
|
200 |
+
checkpoint = ModelCheckpoint(
|
201 |
+
monitor='val_loss',
|
202 |
+
dirpath=config.save_dir,
|
203 |
+
filename=config.filename,
|
204 |
+
mode='min',
|
205 |
+
save_top_k=1, # save only the best checkpoint (according to val_loss)
|
206 |
+
save_last=True # also save the last checkpoint every epoch
|
207 |
+
# last_filename=config.filename + "-LAST"
|
208 |
+
)
|
209 |
+
checkpoint.CHECKPOINT_NAME_LAST = config.filename + "-LAST" # Rename last checkpoint name
|
210 |
+
|
211 |
+
trainer = pl.Trainer(
|
212 |
+
accelerator='gpu',
|
213 |
+
strategy='ddp_find_unused_parameters_true', # ddp (orig), ddp_find_unused_parameters_true (supress pl issues with unused trainable params)
|
214 |
+
devices=config.devices,
|
215 |
+
max_epochs=config.max_epochs,
|
216 |
+
num_nodes=1,
|
217 |
+
callbacks=[checkpoint, UpdateDatasetEpochCallback()],
|
218 |
+
accumulate_grad_batches=config.accumulate_grad_batches,
|
219 |
+
log_every_n_steps=1,
|
220 |
+
val_check_interval=config.val_check_interval,
|
221 |
+
)
|
222 |
+
|
223 |
+
if config.resume_from_checkpoint:
|
224 |
+
trainer.fit(model, ckpt_path=f"{config.save_dir}/{config.resume_checkpoint_name}.ckpt")
|
225 |
+
else:
|
226 |
+
trainer.fit(model)
|
Taxabind/Taxabind/SatBind/watershed_segmentation.py
ADDED
@@ -0,0 +1,1094 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import cv2
|
5 |
+
import random
|
6 |
+
import math
|
7 |
+
import numpy as np
|
8 |
+
from collections import Counter
|
9 |
+
from scipy import ndimage
|
10 |
+
|
11 |
+
# import matplotlib
|
12 |
+
# matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
|
13 |
+
|
14 |
+
class WatershedBinomial:
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
N=5.0, # total expected # of targets across entire map
|
18 |
+
blur_kernel=(5, 5),
|
19 |
+
sure_fg_factor=0.7,
|
20 |
+
dilation_kernel_size=(3, 3),
|
21 |
+
dilation_iters=3,
|
22 |
+
region_size_thresh=10,
|
23 |
+
plot_img=False,
|
24 |
+
gifs_dir="./gifs",
|
25 |
+
colormap=plt.cm.tab10,
|
26 |
+
title="Watershed + Binomial Steps"
|
27 |
+
):
|
28 |
+
"""
|
29 |
+
A binomial viewpoint:
|
30 |
+
1) We run watershed on a 24x24 heatmap.
|
31 |
+
2) We scale each region's probability p_i so sum of expected targets across
|
32 |
+
all regions = N. Then expected # in region i is E_i = size_i * p_i.
|
33 |
+
3) We track how many targets each region has found so far (#found).
|
34 |
+
4) The average steps to find the next target is size_i/E_i = 1/p_i.
|
35 |
+
5) We store these as class variables so they can be accessed later:
|
36 |
+
- region_prob[lab]
|
37 |
+
- region_visited_str[lab]
|
38 |
+
- region_expected[lab]
|
39 |
+
- region_found[lab]
|
40 |
+
- region_steps_next[lab]
|
41 |
+
6) The final legend includes all this info for each region.
|
42 |
+
"""
|
43 |
+
|
44 |
+
self.N = N
|
45 |
+
self.blur_kernel = blur_kernel
|
46 |
+
self.sure_fg_factor = sure_fg_factor
|
47 |
+
self.dilation_kernel_size = dilation_kernel_size
|
48 |
+
self.dilation_iters = dilation_iters
|
49 |
+
self.region_size_thresh = region_size_thresh
|
50 |
+
self.plot_img = plot_img
|
51 |
+
self.gifs_dir = gifs_dir
|
52 |
+
self.colormap = colormap
|
53 |
+
self.title = title
|
54 |
+
|
55 |
+
self.raw_heatmap = None
|
56 |
+
self.flattened_visits = None
|
57 |
+
self.num_found_list = None
|
58 |
+
|
59 |
+
# --------------------------
|
60 |
+
# Intermediate images
|
61 |
+
# --------------------------
|
62 |
+
self.input_heatmap_24x24 = None
|
63 |
+
self.padded_heatmap_26x26 = None
|
64 |
+
self.blurred_heatmap_26x26 = None
|
65 |
+
self.binary_mask_26x26 = None
|
66 |
+
self.dist_transform_26x26 = None
|
67 |
+
self.sure_fg_mask_26x26 = None
|
68 |
+
self.watershed_markers_26x26 = None
|
69 |
+
self.color_watershed_raw_26x26 = None
|
70 |
+
|
71 |
+
# --------------------------
|
72 |
+
# Final label map + overlays
|
73 |
+
# --------------------------
|
74 |
+
self.final_markers_24x24 = None
|
75 |
+
self.final_output_no_vis_24x24 = None
|
76 |
+
self.final_output_with_vis_24x24 = None
|
77 |
+
|
78 |
+
# --------------------------
|
79 |
+
# Class variables of interest
|
80 |
+
# --------------------------
|
81 |
+
# region_prob[lab] -> scaled probability p_i
|
82 |
+
# region_visited_str[lab]-> visited fraction as "visited_count/size_i"
|
83 |
+
# region_expected[lab] -> E_i = size_i * p_i
|
84 |
+
# region_found[lab] -> number of targets found so far
|
85 |
+
# region_steps_next[lab] -> size_i/E_i (1/p_i)
|
86 |
+
self.region_prob = {}
|
87 |
+
self.visited_num = {}
|
88 |
+
self.region_visited_str = {}
|
89 |
+
self.region_centroids = {}
|
90 |
+
self.region_expected = {}
|
91 |
+
self.region_found = {}
|
92 |
+
self.region_steps_next = {}
|
93 |
+
|
94 |
+
# --------------------------
|
95 |
+
# For MERGED segmentation with k-means boundaries
|
96 |
+
# We'll store them in separate variables
|
97 |
+
# --------------------------
|
98 |
+
self.merged_markers_24x24 = None
|
99 |
+
self.merged_output_no_vis = None
|
100 |
+
self.merged_output_with_vis = None
|
101 |
+
|
102 |
+
self.merged_region_prob = {}
|
103 |
+
self.merged_visited_num = {}
|
104 |
+
self.merged_region_visited_str = {}
|
105 |
+
self.merged_region_expected = {}
|
106 |
+
self.merged_region_found = {}
|
107 |
+
self.merged_region_steps_next = {}
|
108 |
+
self.merged_region_centroids = {}
|
109 |
+
|
110 |
+
# OTHERS
|
111 |
+
self.segmentor_frame_files = []
|
112 |
+
self.kmeans_merged_frame_files = []
|
113 |
+
|
114 |
+
# Persistent label-color mapping (Preassign colors for labels)
|
115 |
+
self.label_to_color = {}
|
116 |
+
for i in range(1, 101):
|
117 |
+
color = np.array(self.colormap(i % 10)[:3]) * 255
|
118 |
+
self.label_to_color[i] = color.astype(np.uint8)
|
119 |
+
|
120 |
+
# --------------------------
|
121 |
+
# Public Entry
|
122 |
+
# --------------------------
|
123 |
+
|
124 |
+
def run(self, raw_heatmap, flattened_visits, num_found_list, step=0):
|
125 |
+
"""
|
126 |
+
1) Watershed pipeline
|
127 |
+
2) Scale probabilities => sum_i #expected = N
|
128 |
+
3) Build final overlays & store region stats
|
129 |
+
4) Plot subplots
|
130 |
+
"""
|
131 |
+
self.raw_heatmap = raw_heatmap
|
132 |
+
self.flattened_visits = flattened_visits
|
133 |
+
self.num_found_list = num_found_list
|
134 |
+
self.input_heatmap_24x24 = self.raw_heatmap.copy()
|
135 |
+
|
136 |
+
# Step A) Watershed + merges
|
137 |
+
self._watershed_on_padded()
|
138 |
+
self.final_markers_24x24 = self._postprocess_cropped_markers()
|
139 |
+
|
140 |
+
# Step B) Build final binomial-based stats
|
141 |
+
self._build_final_outputs()
|
142 |
+
|
143 |
+
# Step C) Plot
|
144 |
+
if self.plot_img:
|
145 |
+
self._plot_all_steps(step)
|
146 |
+
return self.final_markers_24x24
|
147 |
+
|
148 |
+
# --------------------------
|
149 |
+
# Watershed pipeline
|
150 |
+
# --------------------------
|
151 |
+
|
152 |
+
def _pad_image(self, image, pad_size=1):
|
153 |
+
return np.pad(image, pad_size, mode='edge')
|
154 |
+
|
155 |
+
def _crop_image(self, image, pad_size=1):
|
156 |
+
return image[pad_size:-pad_size, pad_size:-pad_size]
|
157 |
+
|
158 |
+
def _watershed_on_padded(self):
|
159 |
+
padded = self._pad_image(self.raw_heatmap, pad_size=1)
|
160 |
+
self.padded_heatmap_26x26 = padded
|
161 |
+
|
162 |
+
# Normalize
|
163 |
+
norm_padded = (
|
164 |
+
(padded - padded.min()) / (padded.max() - padded.min() + 1e-9)
|
165 |
+
* 255
|
166 |
+
).astype(np.uint8)
|
167 |
+
|
168 |
+
# (2) Blur
|
169 |
+
if self.blur_kernel != (0,0):
|
170 |
+
self.blurred_heatmap_26x26 = cv2.GaussianBlur(norm_padded, self.blur_kernel, 0)
|
171 |
+
else:
|
172 |
+
self.blurred_heatmap_26x26 = norm_padded.copy()
|
173 |
+
|
174 |
+
# (3) Binary
|
175 |
+
_, bin_mask = cv2.threshold(
|
176 |
+
self.blurred_heatmap_26x26,
|
177 |
+
0, 255,
|
178 |
+
cv2.THRESH_BINARY + cv2.THRESH_OTSU
|
179 |
+
)
|
180 |
+
self.binary_mask_26x26 = bin_mask
|
181 |
+
|
182 |
+
# (4) Distance transform
|
183 |
+
dist_transform = cv2.distanceTransform(bin_mask, cv2.DIST_L2, 5)
|
184 |
+
self.dist_transform_26x26 = dist_transform
|
185 |
+
|
186 |
+
# (5) Sure FG
|
187 |
+
_, sure_fg = cv2.threshold(
|
188 |
+
dist_transform,
|
189 |
+
self.sure_fg_factor * dist_transform.max(),
|
190 |
+
255, 0
|
191 |
+
)
|
192 |
+
sure_fg = np.uint8(sure_fg)
|
193 |
+
self.sure_fg_mask_26x26 = sure_fg
|
194 |
+
|
195 |
+
# sure BG
|
196 |
+
kernel = np.ones(self.dilation_kernel_size, np.uint8)
|
197 |
+
sure_bg = cv2.dilate(bin_mask, kernel, iterations=self.dilation_iters)
|
198 |
+
unknown = cv2.subtract(sure_bg, sure_fg)
|
199 |
+
|
200 |
+
# Markers
|
201 |
+
_, markers = cv2.connectedComponents(sure_fg)
|
202 |
+
markers += 1
|
203 |
+
markers[unknown==255] = 0
|
204 |
+
|
205 |
+
# (6) Watershed
|
206 |
+
color_img = cv2.applyColorMap(self.blurred_heatmap_26x26, cv2.COLORMAP_VIRIDIS)
|
207 |
+
markers = cv2.watershed(color_img, markers)
|
208 |
+
self.watershed_markers_26x26 = markers
|
209 |
+
|
210 |
+
self.color_watershed_raw_26x26 = self._color_label_map(markers)
|
211 |
+
|
212 |
+
def _postprocess_cropped_markers(self):
|
213 |
+
cropped = self._crop_image(self.watershed_markers_26x26, 1)
|
214 |
+
cropped = self._fix_boundary_labels(cropped)
|
215 |
+
cropped = self._separate_subregions(cropped)
|
216 |
+
cropped = self._merge_small_regions(cropped, self.region_size_thresh)
|
217 |
+
return cropped
|
218 |
+
|
219 |
+
def _fix_boundary_labels(self, markers):
|
220 |
+
b_idx = np.argwhere(markers==-1)
|
221 |
+
neighbors_8 = [
|
222 |
+
(-1,-1), (-1,0), (-1,1),
|
223 |
+
( 0,-1), ( 0,1),
|
224 |
+
( 1,-1), ( 1,0), ( 1,1)
|
225 |
+
]
|
226 |
+
for (r,c) in b_idx:
|
227 |
+
neigh_labs=[]
|
228 |
+
for dr,dc in neighbors_8:
|
229 |
+
rr,cc=r+dr,c+dc
|
230 |
+
if 0<=rr<markers.shape[0] and 0<=cc<markers.shape[1]:
|
231 |
+
lb=markers[rr,cc]
|
232 |
+
if lb>0:
|
233 |
+
neigh_labs.append(lb)
|
234 |
+
if neigh_labs:
|
235 |
+
markers[r,c]=np.bincount(neigh_labs).argmax()
|
236 |
+
else:
|
237 |
+
markers[r,c]=0
|
238 |
+
return markers
|
239 |
+
|
240 |
+
def _separate_subregions(self, markers):
|
241 |
+
refined = np.zeros_like(markers, dtype=np.int32)
|
242 |
+
refined[markers==-1] = -1
|
243 |
+
refined[markers==0] = 0
|
244 |
+
|
245 |
+
next_label=1
|
246 |
+
labs = np.unique(markers)
|
247 |
+
for lb in labs:
|
248 |
+
if lb<=0:
|
249 |
+
continue
|
250 |
+
region_mask=(markers==lb).astype(np.uint8)
|
251 |
+
n_sub, sub_markers = cv2.connectedComponents(region_mask)
|
252 |
+
for sub_id in range(1,n_sub):
|
253 |
+
subregion_mask=(sub_markers==sub_id)
|
254 |
+
refined[subregion_mask]=next_label
|
255 |
+
next_label+=1
|
256 |
+
return refined
|
257 |
+
|
258 |
+
|
259 |
+
def _merge_small_regions(self, markers, min_patch_size, n_iter=1, ignore_label=-99):
|
260 |
+
"""
|
261 |
+
Combined approach using:
|
262 |
+
(A) Pixel-wise smoothing of small patches (iterative),
|
263 |
+
but now using 8-connected neighbors.
|
264 |
+
(B) After that, remove any connected component whose
|
265 |
+
size < self.region_size_thresh by merging it into
|
266 |
+
the majority label among its 8-connected boundary.
|
267 |
+
|
268 |
+
1) We do not change the pixel-wise smoothing logic, except
|
269 |
+
that we consider 8 neighbors instead of 4.
|
270 |
+
2) Then we do a region-based check for all subregions whose
|
271 |
+
size < region_size_thresh. Each such region is assigned
|
272 |
+
to the 8-neighbor majority among that region's boundary
|
273 |
+
pixels. If a tie occurs, Counter.most_common() simply returns
|
274 |
+
one of the tied labels (the first one encountered), so the
|
275 |
+
tie-break is effectively arbitrary.
|
276 |
+
"""
|
277 |
+
|
278 |
+
H, W = markers.shape
|
279 |
+
|
280 |
+
# -----------------------------------------------------------
|
281 |
+
# (A) Pixel-wise smoothing with 8-connected neighbors
|
282 |
+
# for small sub-patches of each label
|
283 |
+
# -----------------------------------------------------------
|
284 |
+
|
285 |
+
markers = ndimage.median_filter(markers, size=3)
|
286 |
+
|
287 |
+
# -----------------------------------------------------------
|
288 |
+
# (B) Now remove entire subregions < min_patch_size by
|
289 |
+
# reassigning them to the 8-neighbor majority of boundary
|
290 |
+
# -----------------------------------------------------------
|
291 |
+
|
292 |
+
# smoothed = markers.copy()
|
293 |
+
structure_8 = np.array([[1,1,1],
|
294 |
+
[1,1,1],
|
295 |
+
[1,1,1]], dtype=np.uint8)
|
296 |
+
|
297 |
+
# identify connected components of each label
|
298 |
+
# step: we'll do a multi-pass until no small region is left
|
299 |
+
changed = True
|
300 |
+
while changed:
|
301 |
+
changed = False
|
302 |
+
label_map = markers
|
303 |
+
labels = np.unique(label_map[label_map>0])
|
304 |
+
for label in labels:
|
305 |
+
# Skip negative labels and labels whose total size already meets the threshold
|
306 |
+
if label < 0 or np.sum(label_map == label) >= min_patch_size:
|
307 |
+
continue
|
308 |
+
|
309 |
+
# find connected components of 'label'
|
310 |
+
mask = (label_map == label)
|
311 |
+
labeled_cc, num_cc = ndimage.label(mask, structure=structure_8)
|
312 |
+
if num_cc < 1:
|
313 |
+
continue
|
314 |
+
|
315 |
+
for cc_id in range(1, num_cc+1):
|
316 |
+
cc_mask = (labeled_cc == cc_id)
|
317 |
+
# size_cc = np.sum(cc_mask)
|
318 |
+
# if 0<size_cc<min_patch_size:
|
319 |
+
# compute the 8-neighbor majority from boundary
|
320 |
+
boundary_pixels = []
|
321 |
+
coords = np.argwhere(cc_mask)
|
322 |
+
for (r,c) in coords:
|
323 |
+
# check if (r,c) is on boundary => or we can check neighbors
|
324 |
+
for dr in (-1,0,1):
|
325 |
+
for dc in (-1,0,1):
|
326 |
+
if dr==0 and dc==0:
|
327 |
+
continue
|
328 |
+
rr,cc2 = r+dr, c+dc
|
329 |
+
if 0<=rr<H and 0<=cc2<W:
|
330 |
+
# If neighbor label != label, it's a boundary
|
331 |
+
neighbor_label = label_map[rr, cc2]
|
332 |
+
if neighbor_label != label and neighbor_label > 0:
|
333 |
+
boundary_pixels.append(neighbor_label)
|
334 |
+
|
335 |
+
if boundary_pixels:
|
336 |
+
counter = Counter(boundary_pixels)
|
337 |
+
# remove any '0' or negative labels
|
338 |
+
# for k in list(counter.keys()):
|
339 |
+
# if k<=0:
|
340 |
+
# del counter[k]
|
341 |
+
if len(counter)>0:
|
342 |
+
# pick the label with the largest count
|
343 |
+
(new_lab, _cnt) = counter.most_common(1)[0]
|
344 |
+
label_map[cc_mask] = new_lab
|
345 |
+
changed = True
|
346 |
+
else:
|
347 |
+
# no valid neighbor => keep the label, or reassign to 0
|
348 |
+
label_map[cc_mask] = 0
|
349 |
+
changed = True
|
350 |
+
|
351 |
+
return label_map
|
352 |
+
|
353 |
+
# def largest_component_mask(self, label_map, label):
|
354 |
+
# """Find largest connected component for 'label' using 8-neighbor structure."""
|
355 |
+
# mask = (label_map == label)
|
356 |
+
# structure_8 = np.array([[1,1,1],
|
357 |
+
# [1,1,1],
|
358 |
+
# [1,1,1]], dtype=np.uint8)
|
359 |
+
# labeled, num = ndimage.label(mask, structure=structure_8)
|
360 |
+
# if num < 1:
|
361 |
+
# return np.zeros_like(mask, dtype=bool)
|
362 |
+
# # measure region sizes
|
363 |
+
# sizes = ndimage.sum(mask, labeled, range(1,num+1))
|
364 |
+
# largest_id = (np.argmax(sizes) + 1)
|
365 |
+
# return (labeled == largest_id)
|
366 |
+
|
367 |
+
# NOT IN USE
|
368 |
+
def _merge_small_regions_smoothing(self, markers):
|
369 |
+
label_sizes={}
|
370 |
+
labs=np.unique(markers)
|
371 |
+
for lb in labs:
|
372 |
+
if lb>0:
|
373 |
+
label_sizes[lb]=np.sum(markers==lb)
|
374 |
+
sorted_labels=sorted(label_sizes, key=lambda x: label_sizes[x])
|
375 |
+
|
376 |
+
neighbors_8=[
|
377 |
+
(-1,-1), (-1,0), (-1,1),
|
378 |
+
( 0,-1), ( 0,1),
|
379 |
+
( 1,-1), ( 1,0), ( 1,1)
|
380 |
+
]
|
381 |
+
for lb in sorted_labels:
|
382 |
+
size=label_sizes[lb]
|
383 |
+
if 0<size<self.region_size_thresh:
|
384 |
+
coords=np.argwhere(markers==lb)
|
385 |
+
nbr_set=set()
|
386 |
+
for (r,c) in coords:
|
387 |
+
for dr,dc in neighbors_8:
|
388 |
+
rr,cc=r+dr,c+dc
|
389 |
+
if 0<=rr<markers.shape[0] and 0<=cc<markers.shape[1]:
|
390 |
+
nl=markers[rr,cc]
|
391 |
+
if nl>0 and nl!=lb:
|
392 |
+
nbr_set.add(nl)
|
393 |
+
if not nbr_set:
|
394 |
+
continue
|
395 |
+
max_lb=None
|
396 |
+
max_sz=-1
|
397 |
+
for nlb in nbr_set:
|
398 |
+
s_nlb=label_sizes.get(nlb,0)
|
399 |
+
if s_nlb>max_sz:
|
400 |
+
max_sz=s_nlb
|
401 |
+
max_lb=nlb
|
402 |
+
markers[markers==lb]=max_lb
|
403 |
+
label_sizes[max_lb]+=size
|
404 |
+
label_sizes[lb]=0
|
405 |
+
return markers
|
406 |
+
|
407 |
+
# --------------------------
|
408 |
+
# Build binomial-based stats
|
409 |
+
# --------------------------
|
410 |
+
|
411 |
+
def _build_final_outputs(self):
|
412 |
+
"""
|
413 |
+
1) color-labeled final overlay
|
414 |
+
2) visited overlay
|
415 |
+
3) scale region probabilities => sum of E_i = N
|
416 |
+
4) store visited fraction, expected #targets, #found, steps next
|
417 |
+
"""
|
418 |
+
markers=self.final_markers_24x24
|
419 |
+
valid_labels=np.unique(markers[markers>0])
|
420 |
+
|
421 |
+
# color-labeled overlay
|
422 |
+
color_map=self._color_label_map(markers)
|
423 |
+
self.final_output_no_vis_24x24=self._overlay_heatmap_and_labels(self.raw_heatmap, color_map)
|
424 |
+
|
425 |
+
# visited
|
426 |
+
visited_mask=np.zeros(markers.shape,dtype=bool)
|
427 |
+
h,w=markers.shape
|
428 |
+
for idx in self.flattened_visits:
|
429 |
+
rr=idx//w
|
430 |
+
cc=idx%w
|
431 |
+
if 0<=rr<h and 0<=cc<w:
|
432 |
+
visited_mask[rr,cc]=True
|
433 |
+
|
434 |
+
with_vis=self.final_output_no_vis_24x24.copy()
|
435 |
+
red_overlay=np.zeros_like(with_vis)
|
436 |
+
alpha=0.3
|
437 |
+
for idx in self.flattened_visits:
|
438 |
+
rr=idx//w
|
439 |
+
cc=idx%w
|
440 |
+
if 0<=rr<h and 0<=cc<w:
|
441 |
+
red_overlay[rr,cc]=[255,0,0]
|
442 |
+
self.final_output_with_vis_24x24=cv2.addWeighted(with_vis,1.0,red_overlay,alpha,0.0)
|
443 |
+
|
444 |
+
# gather region sums
|
445 |
+
label_sizes={}
|
446 |
+
label_sums={}
|
447 |
+
total_mass=0.0
|
448 |
+
for lb in valid_labels:
|
449 |
+
mask=(markers==lb)
|
450 |
+
size_i=np.sum(mask)
|
451 |
+
label_sizes[lb]=size_i
|
452 |
+
s=np.sum(self.raw_heatmap[mask])
|
453 |
+
label_sums[lb]=s
|
454 |
+
total_mass+=s
|
455 |
+
|
456 |
+
# scale factor
|
457 |
+
scale=self.N/(total_mass+1e-9)
|
458 |
+
|
459 |
+
# how many visited patches in each region
|
460 |
+
region_visited={}
|
461 |
+
for lb in valid_labels:
|
462 |
+
region_visited[lb]=np.sum((markers==lb) & visited_mask)
|
463 |
+
|
464 |
+
# how many targets found in each region
|
465 |
+
region_found_count={lb:0 for lb in valid_labels}
|
466 |
+
for idx, num_found in zip(self.flattened_visits,self.num_found_list):
|
467 |
+
rr=idx//w
|
468 |
+
cc=idx%w
|
469 |
+
lb=markers[rr,cc]
|
470 |
+
if lb>0 and num_found>0:
|
471 |
+
region_found_count[lb]+=num_found
|
472 |
+
|
473 |
+
# store stats
|
474 |
+
for lb in valid_labels:
|
475 |
+
size_i=label_sizes[lb]
|
476 |
+
raw_sum=label_sums[lb]
|
477 |
+
if size_i>0:
|
478 |
+
p_i=scale*(raw_sum/size_i)
|
479 |
+
else:
|
480 |
+
p_i=0.0
|
481 |
+
self.region_prob[lb]=p_i
|
482 |
+
|
483 |
+
visited_num=region_visited[lb]
|
484 |
+
# visited fraction as a string, e.g. "23/100"
|
485 |
+
self.visited_num[lb]=visited_num
|
486 |
+
self.region_visited_str[lb]=f"{visited_num}/{size_i}"
|
487 |
+
|
488 |
+
E_i=size_i*p_i
|
489 |
+
self.region_expected[lb]=E_i
|
490 |
+
|
491 |
+
fcount=region_found_count[lb]
|
492 |
+
self.region_found[lb]=fcount
|
493 |
+
|
494 |
+
self.region_centroids[lb] = self._find_centroid_in_region(markers, lb)
|
495 |
+
|
496 |
+
# expected steps to the (found+1)-th target => (found+1)*size_i/E_i = (found+1)/p_i
|
497 |
+
if E_i>0:
|
498 |
+
next_steps=(fcount+1)*size_i/E_i
|
499 |
+
else:
|
500 |
+
next_steps=float('inf')
|
501 |
+
self.region_steps_next[lb]=next_steps
|
502 |
+
|
503 |
+
|
504 |
+
def _find_centroid_in_region(self, markers, label_id):
|
505 |
+
coords = np.argwhere(markers == label_id)
|
506 |
+
if coords.size == 0:
|
507 |
+
return None
|
508 |
+
r_mean = np.mean(coords[:, 0])
|
509 |
+
c_mean = np.mean(coords[:, 1])
|
510 |
+
dists = (coords[:, 0] - r_mean)**2 + (coords[:, 1] - c_mean)**2
|
511 |
+
idx = np.argmin(dists)
|
512 |
+
return (coords[idx][0], coords[idx][1])
|
513 |
+
|
514 |
+
|
515 |
+
# --------------------------
|
516 |
+
# Visualization
|
517 |
+
# --------------------------
|
518 |
+
|
519 |
+
def _color_label_map(self, markers):
|
520 |
+
out_img=np.zeros((*markers.shape,3),dtype=np.uint8)
|
521 |
+
labs=np.unique(markers[markers>0])
|
522 |
+
# col_array=self.colormap(np.linspace(0,1,len(labs)))
|
523 |
+
|
524 |
+
for r in range(markers.shape[0]):
|
525 |
+
for c in range(markers.shape[1]):
|
526 |
+
lb = markers[r, c]
|
527 |
+
if lb in self.label_to_color:
|
528 |
+
out_img[r, c] = self.label_to_color[lb]
|
529 |
+
|
530 |
+
# OLD: Color keeps jumping
|
531 |
+
# lab2col={}
|
532 |
+
# for i,lb in enumerate(labs):
|
533 |
+
# lab2col[lb]=(col_array[i][:3]*255).astype(np.uint8)
|
534 |
+
|
535 |
+
# for r in range(markers.shape[0]):
|
536 |
+
# for c in range(markers.shape[1]):
|
537 |
+
# lb=markers[r,c]
|
538 |
+
# if lb in lab2col:
|
539 |
+
# out_img[r,c]=lab2col[lb]
|
540 |
+
return out_img
|
541 |
+
|
542 |
+
def _overlay_heatmap_and_labels(self, heatmap_24x24, color_map):
|
543 |
+
heatmap_rgb=plt.cm.gray((heatmap_24x24/100.0))[:,:,:3]
|
544 |
+
heatmap_rgb=(heatmap_rgb*255).astype(np.uint8)
|
545 |
+
alpha=0.5
|
546 |
+
blended=cv2.addWeighted(heatmap_rgb,1.0,color_map,alpha,0.0)
|
547 |
+
return blended
|
548 |
+
|
549 |
+
def _plot_all_steps(self, step):
|
550 |
+
fig, axs = plt.subplots(2,4, figsize=(20,10))
|
551 |
+
fig.suptitle(self.title, fontsize=16)
|
552 |
+
fig.text(0.5, 0.05, self._get_subtitles(), ha="center", va="center", fontsize=12)
|
553 |
+
plt.tight_layout()
|
554 |
+
plt.subplots_adjust(bottom=0.1) # Prevent overlap
|
555 |
+
|
556 |
+
# (1) Input
|
557 |
+
axs[0,0].imshow(self.input_heatmap_24x24, cmap='viridis')
|
558 |
+
axs[0,0].set_title("1) Input Heatmap")
|
559 |
+
axs[0,0].axis("off")
|
560 |
+
|
561 |
+
# (2) After Blur
|
562 |
+
if self.blurred_heatmap_26x26 is not None:
|
563 |
+
axs[0,1].imshow(self.blurred_heatmap_26x26, cmap='gray')
|
564 |
+
else:
|
565 |
+
axs[0,1].text(0.2,0.5,"No blur", fontsize=12)
|
566 |
+
axs[0,1].set_title("2) After Gaussian Blur")
|
567 |
+
axs[0,1].axis("off")
|
568 |
+
|
569 |
+
# (3) Binary Mask
|
570 |
+
if self.binary_mask_26x26 is not None:
|
571 |
+
axs[0,2].imshow(self.binary_mask_26x26, cmap='gray')
|
572 |
+
else:
|
573 |
+
axs[0,2].text(0.2,0.5,"No binary", fontsize=12)
|
574 |
+
axs[0,2].set_title("3) Binary Mask")
|
575 |
+
axs[0,2].axis("off")
|
576 |
+
|
577 |
+
# (4) Distance Transform
|
578 |
+
if self.dist_transform_26x26 is not None:
|
579 |
+
axs[0,3].imshow(self.dist_transform_26x26, cmap='jet')
|
580 |
+
else:
|
581 |
+
axs[0,3].text(0.2,0.5,"No dist transform", fontsize=12)
|
582 |
+
axs[0,3].set_title("4) Distance Transform")
|
583 |
+
axs[0,3].axis("off")
|
584 |
+
|
585 |
+
# (5) Sure Foreground
|
586 |
+
if self.sure_fg_mask_26x26 is not None:
|
587 |
+
axs[1,0].imshow(self.sure_fg_mask_26x26, cmap='gray')
|
588 |
+
else:
|
589 |
+
axs[1,0].text(0.2,0.5,"No sure-FG", fontsize=12)
|
590 |
+
axs[1,0].set_title("5) Sure Foreground Mask")
|
591 |
+
axs[1,0].axis("off")
|
592 |
+
|
593 |
+
# (6) Raw Watershed
|
594 |
+
if self.color_watershed_raw_26x26 is not None:
|
595 |
+
axs[1,1].imshow(self.color_watershed_raw_26x26)
|
596 |
+
else:
|
597 |
+
axs[1,1].text(0.2,0.5,"No watershed", fontsize=12)
|
598 |
+
axs[1,1].set_title("6) Raw Watershed Output")
|
599 |
+
axs[1,1].axis("off")
|
600 |
+
|
601 |
+
# (7) Final (No Vis)
|
602 |
+
if self.final_output_no_vis_24x24 is not None:
|
603 |
+
axs[1,2].imshow(self.final_output_no_vis_24x24)
|
604 |
+
self._plot_centroids(axs[1,2])
|
605 |
+
self._add_legend(axs[1,2])
|
606 |
+
else:
|
607 |
+
axs[1,2].text(0.2,0.5,"No final", fontsize=12)
|
608 |
+
axs[1,2].set_title("7) Final Output (No Vis)")
|
609 |
+
axs[1,2].axis("off")
|
610 |
+
|
611 |
+
# (8) Final (With Vis)
|
612 |
+
if self.final_output_no_vis_24x24 is not None:
|
613 |
+
axs[1,3].imshow(self.final_output_no_vis_24x24) # final_output_with_vis_24x24
|
614 |
+
# Red line for path
|
615 |
+
path_rows, path_cols = [], []
|
616 |
+
targets = []
|
617 |
+
for i, idx in enumerate(self.flattened_visits):
|
618 |
+
rr = idx // self.final_output_no_vis_24x24.shape[1]
|
619 |
+
cc = idx % self.final_output_no_vis_24x24.shape[1]
|
620 |
+
path_rows.append(rr)
|
621 |
+
path_cols.append(cc)
|
622 |
+
if self.num_found_list[i] > 0:
|
623 |
+
targets.append((rr, cc))
|
624 |
+
axs[1,3].plot(path_cols, path_rows, c="r", linewidth=2)
|
625 |
+
axs[1,3].plot(path_cols[-1], path_rows[-1], markersize=12, zorder=99, marker="^", ls="-", c="r", mec="black")
|
626 |
+
axs[1,3].plot(path_cols[0], path_rows[0], 'co', c="r", markersize=8, zorder=5)
|
627 |
+
for target in targets:
|
628 |
+
axs[1,3].plot(target[1], target[0], color='g', marker='x', linestyle='-', mec="black", markersize=12, markeredgewidth=4, zorder=999)
|
629 |
+
self._plot_centroids(axs[1,3])
|
630 |
+
self._add_legend(axs[1,3])
|
631 |
+
else:
|
632 |
+
axs[1,3].text(0.2,0.5,"No visited out", fontsize=12)
|
633 |
+
axs[1,3].set_title("8) Final Output (With Vis)")
|
634 |
+
axs[1,3].axis("off")
|
635 |
+
|
636 |
+
# plt.show()
|
637 |
+
|
638 |
+
if not os.path.exists(self.gifs_dir):
|
639 |
+
os.makedirs(self.gifs_dir)
|
640 |
+
|
641 |
+
plt.savefig(f'{self.gifs_dir}/watershed_{step}.png', dpi=150)
|
642 |
+
self.segmentor_frame_files.append(f'{self.gifs_dir}/watershed_{step}.png')
|
643 |
+
plt.close()
|
644 |
+
|
645 |
+
|
646 |
+
# def _plot_centroids(self, ax):
|
647 |
+
# markers=self.final_markers_24x24
|
648 |
+
# labs=np.unique(markers[markers>0])
|
649 |
+
# for lb in labs:
|
650 |
+
# coords=np.argwhere(markers==lb)
|
651 |
+
# if coords.size>0:
|
652 |
+
# r_mean=np.mean(coords[:,0])
|
653 |
+
# c_mean=np.mean(coords[:,1])
|
654 |
+
# ax.plot(c_mean, r_mean, 'x', color='white', markersize=7)
|
655 |
+
def _plot_centroids(self, ax):
|
656 |
+
markers = self.final_markers_24x24
|
657 |
+
labs = np.unique(markers[markers > 0])
|
658 |
+
for lb in labs:
|
659 |
+
# Use the stored centroid to ensure it lies within the region.
|
660 |
+
cent = self.region_centroids.get(lb, None)
|
661 |
+
if cent is not None:
|
662 |
+
ax.plot(cent[1], cent[0], 'x', color='white', markersize=7)
|
663 |
+
|
664 |
+
def _add_legend(self, ax):
|
665 |
+
markers=self.final_markers_24x24
|
666 |
+
labs=np.unique(markers[markers>0])
|
667 |
+
# col_array=self.colormap(np.linspace(0,1,len(labs)))
|
668 |
+
|
669 |
+
handles=[]
|
670 |
+
for i,lab in enumerate(labs):
|
671 |
+
color=self.label_to_color[lab] / 255.0
|
672 |
+
label_str = f"R{lab}"
|
673 |
+
h=plt.Line2D(
|
674 |
+
[0],[0],
|
675 |
+
marker='o',
|
676 |
+
color='w',
|
677 |
+
markerfacecolor=color, # col_array[i][:3],
|
678 |
+
markersize=10,
|
679 |
+
label=label_str
|
680 |
+
)
|
681 |
+
handles.append(h)
|
682 |
+
|
683 |
+
ax.legend(handles=handles, loc='upper right', fontsize=8, frameon=True)
|
684 |
+
|
685 |
+
|
686 |
+
def _get_subtitles(self):
|
687 |
+
markers=self.final_markers_24x24
|
688 |
+
labs=np.unique(markers[markers>0])
|
689 |
+
|
690 |
+
suptitle_str=""
|
691 |
+
for i,lab in enumerate(labs):
|
692 |
+
prob_val=self.region_prob.get(lab,0.0)
|
693 |
+
visit_str=self.region_visited_str.get(lab,"0/0")
|
694 |
+
E_i=self.region_expected.get(lab,0.0)
|
695 |
+
num_found=self.region_found.get(lab,0)
|
696 |
+
steps_next=self.region_steps_next.get(lab,float('inf'))
|
697 |
+
|
698 |
+
if math.isinf(steps_next):
|
699 |
+
steps_str="∞"
|
700 |
+
else:
|
701 |
+
steps_str=f"{steps_next:.2f}"
|
702 |
+
|
703 |
+
label_str = (
|
704 |
+
f"R{lab}: p={prob_val:.3f}, Vis={visit_str}, "
|
705 |
+
f"E_tgts={E_i:.2f}, #found={num_found}, "
|
706 |
+
f"E_steps({num_found+1}th)={steps_str}"
|
707 |
+
)
|
708 |
+
suptitle_str+=label_str+"\n"
|
709 |
+
|
710 |
+
return suptitle_str
|
711 |
+
|
712 |
+
|
713 |
+
def incorporate_kmeans_boundaries(self, kmeans_label_map, step=0):
|
714 |
+
"""
|
715 |
+
Merges the existing final watershed map (self.final_markers_24x24) with a
|
716 |
+
k-means label map of the same shape. Each watershed region is subdivided by
|
717 |
+
the boundaries from the k-means map to produce a new, finer segmentation.
|
718 |
+
|
719 |
+
Then we recompute region stats (probabilities, centroids, etc.) and plot a
|
720 |
+
2×2 figure:
|
721 |
+
1) Original Heatmap
|
722 |
+
2) Old Watershed Segmentation
|
723 |
+
3) k-Means Label Map
|
724 |
+
4) Merged Segmentation
|
725 |
+
"""
|
726 |
+
|
727 |
+
# ---------------------------
|
728 |
+
# 1) Ensure shapes match
|
729 |
+
# ---------------------------
|
730 |
+
if kmeans_label_map.shape != self.final_markers_24x24.shape:
|
731 |
+
raise ValueError("kmeans_label_map must match the shape of the final watershed map.")
|
732 |
+
|
733 |
+
# We produce a new merged label map (24×24) subdividing each watershed region
|
734 |
+
# by the kmeans labels. For example, if watershed label L intersects multiple
|
735 |
+
# kmeans labels, that region is split into multiple subregions.
|
736 |
+
|
737 |
+
# We'll store it in self.merged_markers_24x24
|
738 |
+
self.merged_markers_24x24 = np.zeros_like(self.final_markers_24x24, dtype=np.int32)
|
739 |
+
|
740 |
+
watershed_labels = np.unique(self.final_markers_24x24[self.final_markers_24x24 > 0])
|
741 |
+
next_label = 1
|
742 |
+
|
743 |
+
for wlab in watershed_labels:
|
744 |
+
# Mask for watershed label wlab
|
745 |
+
wmask = (self.final_markers_24x24 == wlab)
|
746 |
+
|
747 |
+
# Among those pixels, we see which kmeans labels appear
|
748 |
+
kmeans_vals = np.unique(kmeans_label_map[wmask])
|
749 |
+
for klabel in kmeans_vals:
|
750 |
+
# Intersection => subregion
|
751 |
+
subregion_mask = wmask & (kmeans_label_map == klabel)
|
752 |
+
if not np.any(subregion_mask):
|
753 |
+
continue
|
754 |
+
# Assign a new label
|
755 |
+
self.merged_markers_24x24[subregion_mask] = next_label
|
756 |
+
next_label += 1
|
757 |
+
|
758 |
+
# Optional: If you want to merge small subregions again:
|
759 |
+
self.merged_markers_24x24 = self._merge_small_regions(self.merged_markers_24x24, self.region_size_thresh, n_iter=1)
|
760 |
+
|
761 |
+
# Recompute all stats for the merged segmentation:
|
762 |
+
self._build_merged_outputs()
|
763 |
+
|
764 |
+
# Finally, plot the 2×2 figure:
|
765 |
+
if self.plot_img:
|
766 |
+
self._plot_kmeans_merge(kmeans_label_map, step)
|
767 |
+
|
768 |
+
return self.merged_markers_24x24
|
769 |
+
|
770 |
+
|
771 |
+
def _build_merged_outputs(self):
|
772 |
+
"""
|
773 |
+
Re-run binomial logic on self.merged_markers_24x24:
|
774 |
+
1) build color overlay
|
775 |
+
2) scale probabilities
|
776 |
+
3) compute visited fraction, found targets, centroids, etc.
|
777 |
+
We'll store them in new variables (prefixed 'merged_') so we don't overwrite
|
778 |
+
the original watershed data.
|
779 |
+
"""
|
780 |
+
|
781 |
+
markers = self.merged_markers_24x24
|
782 |
+
valid_labels = np.unique(markers[markers > 0])
|
783 |
+
|
784 |
+
# 1) color-labeled overlay
|
785 |
+
self.merged_output_color = self._color_label_map(markers)
|
786 |
+
self.merged_output_no_vis = self._overlay_heatmap_and_labels(self.raw_heatmap, self.merged_output_color)
|
787 |
+
|
788 |
+
# 2) visited overlay
|
789 |
+
h, w = markers.shape
|
790 |
+
visited_mask = np.zeros(markers.shape, dtype=bool)
|
791 |
+
for idx in self.flattened_visits:
|
792 |
+
rr = idx // w
|
793 |
+
cc = idx % w
|
794 |
+
if 0 <= rr < h and 0 <= cc < w:
|
795 |
+
visited_mask[rr, cc] = True
|
796 |
+
|
797 |
+
merged_with_vis = self.merged_output_no_vis.copy()
|
798 |
+
red_overlay = np.zeros_like(merged_with_vis)
|
799 |
+
alpha = 0.3
|
800 |
+
for idx in self.flattened_visits:
|
801 |
+
rr = idx // w
|
802 |
+
cc = idx % w
|
803 |
+
if 0 <= rr < h and 0 <= cc < w:
|
804 |
+
red_overlay[rr, cc] = [255, 0, 0]
|
805 |
+
self.merged_output_with_vis = cv2.addWeighted(merged_with_vis, 1.0, red_overlay, alpha, 0.0)
|
806 |
+
|
807 |
+
# 3) compute scale factor
|
808 |
+
label_sizes = {}
|
809 |
+
label_sums = {}
|
810 |
+
total_mass = 0.0
|
811 |
+
for lb in valid_labels:
|
812 |
+
mask = (markers == lb)
|
813 |
+
size_i = np.sum(mask)
|
814 |
+
label_sizes[lb] = size_i
|
815 |
+
s = np.sum(self.raw_heatmap[mask])
|
816 |
+
label_sums[lb] = s
|
817 |
+
total_mass += s
|
818 |
+
|
819 |
+
scale = self.N/(total_mass + 1e-9)
|
820 |
+
|
821 |
+
# 4) visited fraction, found count, etc.
|
822 |
+
merged_visited = {}
|
823 |
+
for lb in valid_labels:
|
824 |
+
merged_visited[lb] = np.sum((markers == lb) & visited_mask)
|
825 |
+
|
826 |
+
# found_count
|
827 |
+
merged_found_count = {lb:0 for lb in valid_labels}
|
828 |
+
for idx, num_found in zip(self.flattened_visits, self.num_found_list):
|
829 |
+
rr = idx // w
|
830 |
+
cc = idx % w
|
831 |
+
lb = markers[rr, cc]
|
832 |
+
if lb > 0 and num_found > 0:
|
833 |
+
merged_found_count[lb] += num_found
|
834 |
+
|
835 |
+
# store them
|
836 |
+
self.merged_region_prob = {}
|
837 |
+
self.merged_region_visited_str = {}
|
838 |
+
self.merged_region_expected = {}
|
839 |
+
self.merged_region_found = {}
|
840 |
+
self.merged_region_steps_next = {}
|
841 |
+
self.merged_region_centroids = {}
|
842 |
+
|
843 |
+
for lb in valid_labels:
|
844 |
+
size_i = label_sizes[lb]
|
845 |
+
raw_sum = label_sums[lb]
|
846 |
+
if size_i > 0:
|
847 |
+
p_i = scale*(raw_sum/size_i)
|
848 |
+
else:
|
849 |
+
p_i = 0.0
|
850 |
+
self.merged_region_prob[lb] = p_i
|
851 |
+
|
852 |
+
visited_num = merged_visited[lb]
|
853 |
+
self.merged_visited_num[lb]=visited_num
|
854 |
+
self.merged_region_visited_str[lb] = f"{visited_num}/{size_i}"
|
855 |
+
|
856 |
+
E_i = size_i*p_i
|
857 |
+
self.merged_region_expected[lb] = E_i
|
858 |
+
|
859 |
+
fcount = merged_found_count[lb]
|
860 |
+
self.merged_region_found[lb] = fcount
|
861 |
+
|
862 |
+
# recompute centroid for merged label
|
863 |
+
cent = self._find_centroid_in_region(markers, lb)
|
864 |
+
self.merged_region_centroids[lb] = cent
|
865 |
+
|
866 |
+
if E_i > 0:
|
867 |
+
next_steps = (fcount+1)*size_i/E_i
|
868 |
+
else:
|
869 |
+
next_steps = float('inf')
|
870 |
+
self.merged_region_steps_next[lb] = next_steps
|
871 |
+
|
872 |
+
|
873 |
+
|
874 |
+
def _plot_kmeans_merge(self, kmeans_label_map, step):
|
875 |
+
"""
|
876 |
+
Plots a 2×2 figure:
|
877 |
+
(1) Original Heatmap
|
878 |
+
(2) Old Final Markers (Watershed)
|
879 |
+
(3) k-means Label Map
|
880 |
+
(4) Merged Markers
|
881 |
+
"""
|
882 |
+
|
883 |
+
|
884 |
+
fig, axs = plt.subplots(2, 2, figsize=(12, 12))
|
885 |
+
fig.suptitle("Merging K-Means Boundaries", fontsize=16)
|
886 |
+
fig.text(0.5, 0.05, self._get_merged_subtitles(), ha="center", va="center", fontsize=12)
|
887 |
+
plt.tight_layout()
|
888 |
+
plt.subplots_adjust(bottom=0.1) # Prevent overlap
|
889 |
+
|
890 |
+
# (1) Original heatmap
|
891 |
+
axs[0,0].imshow(self.input_heatmap_24x24, cmap='viridis')
|
892 |
+
axs[0,0].set_title("1) Original Heatmap")
|
893 |
+
axs[0,0].axis("off")
|
894 |
+
|
895 |
+
# (2) Watershed Segmentation
|
896 |
+
old_color = self._color_label_map(self.final_markers_24x24)
|
897 |
+
old_overlay = self._overlay_heatmap_and_labels(self.raw_heatmap, old_color)
|
898 |
+
axs[0,1].imshow(old_overlay)
|
899 |
+
axs[0,1].set_title("2) Watershed Segmentation")
|
900 |
+
axs[0,1].axis("off")
|
901 |
+
# -------------- NEW: Add the merged legend to the watershed subplot --------------
|
902 |
+
self._add_merged_legend(axs[0,1])
|
903 |
+
# -------------------------------------------------------------------------------
|
904 |
+
|
905 |
+
# (3) k-Means Label Map
|
906 |
+
kmeans_color = self._color_label_map(kmeans_label_map)
|
907 |
+
kmeans_overlay = self._overlay_heatmap_and_labels(self.raw_heatmap, kmeans_color)
|
908 |
+
axs[1,0].imshow(kmeans_overlay)
|
909 |
+
axs[1,0].set_title("3) K-Means Label Map")
|
910 |
+
axs[1,0].axis("off")
|
911 |
+
|
912 |
+
# (4) Merged Markers
|
913 |
+
# We already have self.merged_output_with_vis, but we'll now add a pink overlay for the path
|
914 |
+
# merged_with_vis = self.merged_output_with_vis.copy()
|
915 |
+
|
916 |
+
# Red line for path
|
917 |
+
axs[1,1].imshow(self.merged_output_no_vis)
|
918 |
+
path_rows, path_cols = [], []
|
919 |
+
targets = []
|
920 |
+
for i, idx in enumerate(self.flattened_visits):
|
921 |
+
rr = idx // self.merged_markers_24x24.shape[1]
|
922 |
+
cc = idx % self.merged_markers_24x24.shape[1]
|
923 |
+
path_rows.append(rr)
|
924 |
+
path_cols.append(cc)
|
925 |
+
if self.num_found_list[i] > 0:
|
926 |
+
targets.append((rr, cc))
|
927 |
+
axs[1,1].plot(path_cols, path_rows, c="r", linewidth=2)
|
928 |
+
axs[1,1].plot(path_cols[-1], path_rows[-1], markersize=12, zorder=99, marker="^", ls="-", c="r", mec="black")
|
929 |
+
axs[1,1].plot(path_cols[0], path_rows[0], 'co', c="r", markersize=8, zorder=5)
|
930 |
+
for target in targets:
|
931 |
+
axs[1,1].plot(target[1], target[0], color='g', marker='x', linestyle='-', mec="black", markersize=12, markeredgewidth=4, zorder=999)
|
932 |
+
|
933 |
+
axs[1,1].set_title("4) Merged Map")
|
934 |
+
axs[1,1].axis("off")
|
935 |
+
|
936 |
+
# Mark centroids for merged subregions
|
937 |
+
merged_labs = np.unique(self.merged_markers_24x24[self.merged_markers_24x24 > 0])
|
938 |
+
for lab in merged_labs:
|
939 |
+
cent = self.merged_region_centroids.get(lab, None)
|
940 |
+
if cent is not None:
|
941 |
+
axs[1,1].plot(cent[1], cent[0], 'x', color='white', markersize=7)
|
942 |
+
|
943 |
+
# Add the merged legend to the final subplot as well
|
944 |
+
self._add_merged_legend(axs[1,1])
|
945 |
+
|
946 |
+
# Save figure (same naming logic as your code)
|
947 |
+
if not os.path.exists(self.gifs_dir):
|
948 |
+
os.makedirs(self.gifs_dir)
|
949 |
+
|
950 |
+
plt.savefig(f'{self.gifs_dir}/kmeans_merged_{step}.png', dpi=150)
|
951 |
+
self.kmeans_merged_frame_files.append(f'{self.gifs_dir}/kmeans_merged_{step}.png')
|
952 |
+
plt.close()
|
953 |
+
|
954 |
+
|
955 |
+
def _add_merged_legend(self, ax):
|
956 |
+
"""
|
957 |
+
Similar to _add_legend, but pulling data from self.merged_region_* variables.
|
958 |
+
"""
|
959 |
+
markers = self.merged_markers_24x24
|
960 |
+
labs = np.unique(markers[markers>0])
|
961 |
+
# col_array = self.colormap(np.linspace(0,1,len(labs)))
|
962 |
+
|
963 |
+
handles = []
|
964 |
+
for i,lab in enumerate(labs):
|
965 |
+
color = self.label_to_color[lab] / 255.0
|
966 |
+
label_str = f"R{lab}"
|
967 |
+
h=plt.Line2D(
|
968 |
+
[0],[0],
|
969 |
+
marker='o',
|
970 |
+
color='w',
|
971 |
+
markerfacecolor=color, #col_array[i][:3],
|
972 |
+
markersize=10,
|
973 |
+
label=label_str
|
974 |
+
)
|
975 |
+
handles.append(h)
|
976 |
+
|
977 |
+
ax.legend(handles=handles, loc='upper right', fontsize=8, frameon=True)
|
978 |
+
|
979 |
+
# TODO: Merge with _get_subtitles
|
980 |
+
def _get_merged_subtitles(self):
|
981 |
+
markers=self.merged_markers_24x24
|
982 |
+
labs=np.unique(markers[markers>0])
|
983 |
+
|
984 |
+
suptitle_str=""
|
985 |
+
for i,lab in enumerate(labs):
|
986 |
+
prob_val=self.merged_region_prob.get(lab,0.0)
|
987 |
+
visit_str=self.merged_region_visited_str.get(lab,"0/0")
|
988 |
+
E_i=self.merged_region_expected.get(lab,0.0)
|
989 |
+
num_found=self.merged_region_found.get(lab,0)
|
990 |
+
steps_next=self.merged_region_steps_next.get(lab,float('inf'))
|
991 |
+
|
992 |
+
if math.isinf(steps_next):
|
993 |
+
steps_str="∞"
|
994 |
+
else:
|
995 |
+
steps_str=f"{steps_next:.2f}"
|
996 |
+
|
997 |
+
label_str = (
|
998 |
+
f"R{lab}: p={prob_val:.3f}, Vis={visit_str}, "
|
999 |
+
f"E_tgts={E_i:.2f}, #found={num_found}, "
|
1000 |
+
f"E_steps({num_found+1}th)={steps_str}"
|
1001 |
+
)
|
1002 |
+
suptitle_str+=label_str+"\n"
|
1003 |
+
|
1004 |
+
return suptitle_str
|
1005 |
+
|
1006 |
+
|
1007 |
+
# ------ OTHER HELPER FUNCTIONS -------#
|
1008 |
+
def get_label_id(self, patch_idx):
|
1009 |
+
"""
|
1010 |
+
Given a flattened index (row-major order) within the 24×24 grid,
|
1011 |
+
return the average probability of the corresponding merged region.
|
1012 |
+
|
1013 |
+
:param patch_idx: Flattened index (0 to H×W-1).
|
1014 |
+
:return: The region-average probability, or 0.0 if the patch
|
1015 |
+
is out of bounds or not in a labeled region.
|
1016 |
+
"""
|
1017 |
+
# Guard: check we have a merged label map
|
1018 |
+
if self.merged_markers_24x24 is None:
|
1019 |
+
raise ValueError("Merged markers do not exist. "
|
1020 |
+
"Please call `incorporate_kmeans_boundaries(...)` first.")
|
1021 |
+
|
1022 |
+
h, w = self.merged_markers_24x24.shape
|
1023 |
+
# Derive row, col
|
1024 |
+
row = patch_idx // w
|
1025 |
+
col = patch_idx % w
|
1026 |
+
|
1027 |
+
if row < 0 or row >= h or col < 0 or col >= w:
|
1028 |
+
return 0.0
|
1029 |
+
|
1030 |
+
# Find which merged label region that pixel belongs to
|
1031 |
+
label_id = self.merged_markers_24x24[row, col]
|
1032 |
+
|
1033 |
+
return label_id
|
1034 |
+
|
1035 |
+
# ----------------------------
|
1036 |
+
# Example MAIN usage
|
1037 |
+
# ----------------------------
|
1038 |
+
if __name__=="__main__":
|
1039 |
+
# 1) Load a 24x24 heatmap
|
1040 |
+
file_path="./expt/seg_mask_step0_v1.npy"
|
1041 |
+
heatmap_24x24=np.load(file_path)
|
1042 |
+
heatmap_24x24=np.clip(heatmap_24x24,0,100)
|
1043 |
+
|
1044 |
+
# 2) Simulate a random walk
|
1045 |
+
max_r, max_c=heatmap_24x24.shape[0]-1, heatmap_24x24.shape[1]-1
|
1046 |
+
steps=100
|
1047 |
+
current_pos=(0,0)
|
1048 |
+
moves_8=[(-1,-1),(-1,0),(-1,1),
|
1049 |
+
( 0,-1), ( 0,1),
|
1050 |
+
( 1,-1),( 1,0), ( 1,1)]
|
1051 |
+
path_coords=[current_pos]
|
1052 |
+
num_found_list=[0]
|
1053 |
+
for _ in range(steps):
|
1054 |
+
r,c=current_pos
|
1055 |
+
dr,dc=random.choice(moves_8)
|
1056 |
+
nr, nc=r+dr,c+dc
|
1057 |
+
if 0<=nr<=max_r and 0<=nc<=max_c:
|
1058 |
+
current_pos=(nr,nc)
|
1059 |
+
path_coords.append(current_pos)
|
1060 |
+
# 10% chance for new target
|
1061 |
+
# num_found_list.append(random.random()<0.1)
|
1062 |
+
num_found_list.append(0)
|
1063 |
+
num_found_list[0]=1
|
1064 |
+
num_found_list[1]=2
|
1065 |
+
num_found_list[2]=3
|
1066 |
+
|
1067 |
+
# Flatten
|
1068 |
+
width=heatmap_24x24.shape[1]
|
1069 |
+
flattened_visits=[]
|
1070 |
+
for (rr,cc) in path_coords:
|
1071 |
+
idx=rr*width+cc
|
1072 |
+
flattened_visits.append(idx)
|
1073 |
+
|
1074 |
+
# 3) Instantiate & run
|
1075 |
+
segmenter=WatershedBinomial(
|
1076 |
+
N=10.0, # total expected # across map
|
1077 |
+
blur_kernel=(5,5),
|
1078 |
+
sure_fg_factor=0.5,
|
1079 |
+
dilation_kernel_size=(2,2),
|
1080 |
+
dilation_iters=3,
|
1081 |
+
region_size_thresh=9,
|
1082 |
+
plot_img=True,
|
1083 |
+
gifs_dir="/home/user/VLM-Search/inference/test_results/gifs",
|
1084 |
+
colormap=plt.cm.tab20,
|
1085 |
+
title="Watershed + Binomial Steps"
|
1086 |
+
)
|
1087 |
+
final=segmenter.run(heatmap_24x24, flattened_visits, num_found_list, step=0)
|
1088 |
+
|
1089 |
+
# Suppose you load a kmeans_label_map in the same shape (24×24)
|
1090 |
+
file_path="./expt/smoothed_labels_2d.npy"
|
1091 |
+
kmeans_label_map = np.load(file_path)
|
1092 |
+
|
1093 |
+
# Now incorporate it into the final segmentation:
|
1094 |
+
segmenter.incorporate_kmeans_boundaries(kmeans_label_map, step=0)
|
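The binomial bookkeeping in _build_final_outputs reduces to three formulas: each region's per-patch probability is p_i = N * (heatmap mass in region i / total mass) / size_i, its expected target count is E_i = size_i * p_i, and the expected number of steps to the (found+1)-th target is (found+1) * size_i / E_i = (found+1) / p_i. A short self-contained sketch on a made-up two-region label map (the 4x4 arrays and N=2.0 are illustrative assumptions) reproduces those numbers outside the class.

import numpy as np

# Illustrative inputs: a 4x4 heatmap and a 2-region label map (assumptions).
heatmap = np.array([[9, 9, 1, 1]] * 4, dtype=float)
markers = np.array([[1, 1, 2, 2]] * 4)
N = 2.0                                   # expected targets across the whole map

scale = N / (heatmap.sum() + 1e-9)
for lb in np.unique(markers):
    mask = markers == lb
    size_i = mask.sum()
    p_i = scale * heatmap[mask].sum() / size_i   # per-patch probability
    E_i = size_i * p_i                           # expected targets in the region
    steps_first = size_i / E_i                   # expected steps to the 1st target (found=0)
    print(f"R{lb}: p={p_i:.3f}, E_tgts={E_i:.2f}, E_steps(1st)={steps_first:.1f}")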
Taxabind/Taxabind/SoundBind/config.py
ADDED
@@ -0,0 +1,30 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
config = edict()
|
4 |
+
config.train_df = '/mnt/hdd/inat2021_ds/sound_train/sound_image_pairs_filtered.csv' # 'train_df.json'
|
5 |
+
config.val_df= '/mnt/hdd/inat2021_ds/sound_val/sound_image_pairs_filtered.csv' # 'val_df.csv'
|
6 |
+
config.data_path = '/mnt/hdd/inat2021_ds/sound_train'
|
7 |
+
# HOW ABOUT VAL?
|
8 |
+
|
9 |
+
# # Toy Example
|
10 |
+
# config.train_df = '../scripts/preprocess/expt/sound_train/sound_image_pairs_filtered.csv' # 'train_df.json'
|
11 |
+
# config.val_df= '../scripts/preprocess/expt/sound_train/sound_image_pairs_filtered.csv' # 'val_df.csv'
|
12 |
+
# config.data_path = '../scripts/preprocess/expt/sound_train'
|
13 |
+
|
14 |
+
config.batch_size = 256
|
15 |
+
config.lr = 1e-4
|
16 |
+
config.accumulate_grad_batches = 8
|
17 |
+
config.max_epochs = 20
|
18 |
+
config.num_workers = 16
|
19 |
+
config.devices = 2
|
20 |
+
config.val_check_interval = 0.5
|
21 |
+
|
22 |
+
|
23 |
+
config.save_dir = 'checkpoints'
|
24 |
+
config.filename = 'soundbind-{epoch:02d}-{val_loss:.2f}'
|
25 |
+
|
26 |
+
config.locked_tuning = True
|
27 |
+
|
28 |
+
# TEMP
|
29 |
+
config.sat_encoder = 'openai/clip-vit-large-patch14-336' # openai/clip-vit-base-patch16, openai/clip-vit-large-patch14-336
|
30 |
+
config.patch_size = 14
|
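One number the config above leaves implicit is the effective batch size: with gradient accumulation and multi-GPU DDP, each optimizer update aggregates batch_size x accumulate_grad_batches x devices samples. The arithmetic for the values listed here:

batch_size, accumulate_grad_batches, devices = 256, 8, 2
print(batch_size * accumulate_grad_batches * devices)   # 4096 samples per optimizer update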
Taxabind/Taxabind/SoundBind/dataloader.py
ADDED
@@ -0,0 +1,57 @@
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from torchvision import transforms
|
3 |
+
import os
|
4 |
+
from PIL import Image
|
5 |
+
import pandas as pd
|
6 |
+
from config import config
|
7 |
+
from sound_encoder import get_audio_clap
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
#img, sound, year, month, day
|
12 |
+
class INatDataset(Dataset):
|
13 |
+
def __init__(self,
|
14 |
+
data_file,
|
15 |
+
mode='train'):
|
16 |
+
self.data_file = pd.read_csv(data_file)
|
17 |
+
if mode=='train':
|
18 |
+
self.transform = transforms.Compose([
|
19 |
+
transforms.Resize((256, 256)),
|
20 |
+
transforms.RandomCrop((224, 224)),
|
21 |
+
transforms.RandomHorizontalFlip(0.5),
|
22 |
+
transforms.GaussianBlur(5, (0.01, 1.0)),
|
23 |
+
transforms.ToTensor(),
|
24 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
25 |
+
std=[0.229, 0.224, 0.225])
|
26 |
+
])
|
27 |
+
else:
|
28 |
+
self.transform = transforms.Compose([
|
29 |
+
transforms.Resize((256, 256)),
|
30 |
+
transforms.CenterCrop((224, 224)),
|
31 |
+
transforms.ToTensor(),
|
32 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
33 |
+
std=[0.229, 0.224, 0.225])
|
34 |
+
])
|
35 |
+
self.species_text = self.data_file['scientific_name'].tolist()
|
36 |
+
self.species_classes = list(set(self.species_text))
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return len(self.data_file)
|
40 |
+
|
41 |
+
def get_sample(self,idx):
|
42 |
+
sample = self.data_file.iloc[idx]
|
43 |
+
id = sample.id
|
44 |
+
sound_format = sample.sound_format
|
45 |
+
image_path = os.path.join(config['data_path'],"images",str(id)+".jpg")
|
46 |
+
sound_path = os.path.join(config['data_path'],"sounds_mp3",str(id)+"."+'mp3')
|
47 |
+
sound = get_audio_clap(sound_path) # , sound_format)
|
48 |
+
|
49 |
+
for k in sound.keys():
|
50 |
+
sound[k] = sound[k].squeeze(0)
|
51 |
+
image = self.transform(Image.open(image_path))
|
52 |
+
|
53 |
+
return image, sound
|
54 |
+
|
55 |
+
def __getitem__(self, idx):
|
56 |
+
image, sound = self.get_sample(idx)
|
57 |
+
return image, sound, self.species_classes.index(self.data_file.iloc[idx]['scientific_name'])
|
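Each INatDataset item is an (image, sound, class_index) triple, where sound is the ClapProcessor output with its leading batch dimension squeezed away. PyTorch's default collate stacks dict values key by key, so no custom collate_fn is needed. The sketch below mimics two such items with dummy tensors (the CLAP feature shape follows the [4, 1001, 64] per-clip features printed in sound_encoder.py; the dummy dataset itself is an assumption) to show what a batch looks like.

import torch
from torch.utils.data import DataLoader, Dataset

class DummyINat(Dataset):
    # Stand-in items shaped like INatDataset's output (illustration only).
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        image = torch.zeros(3, 224, 224)
        sound = {"input_features": torch.zeros(4, 1001, 64),
                 "is_longer": torch.zeros(1, dtype=torch.bool)}
        return image, sound, idx % 2

image, sound, label = next(iter(DataLoader(DummyINat(), batch_size=2)))
print(image.shape, sound["input_features"].shape, label)
# torch.Size([2, 3, 224, 224]) torch.Size([2, 4, 1001, 64]) tensor([0, 1])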
Taxabind/Taxabind/SoundBind/model.py
ADDED
@@ -0,0 +1,141 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
# Ensure the correct directory is at the front of sys.path
|
4 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
5 |
+
|
6 |
+
import open_clip
|
7 |
+
import pytorch_lightning as pl
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from sound_encoder import CLAP_audiomodel_withProjection as AudioEncoder
|
11 |
+
import numpy as np
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
from config import config
|
14 |
+
import random
|
15 |
+
from dataloader import INatDataset
|
16 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
17 |
+
|
18 |
+
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
19 |
+
audio_loss = contrastive_loss(similarity)
|
20 |
+
ground_img_loss = contrastive_loss(similarity.t())
|
21 |
+
return 0.5*audio_loss + 0.5*ground_img_loss
|
22 |
+
|
23 |
+
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
24 |
+
|
25 |
+
return nn.functional.cross_entropy(logits[:logits.shape[1]], torch.arange(logits.shape[1], device=logits.device))
|
26 |
+
|
27 |
+
class AudioBind(pl.LightningModule):
|
28 |
+
def __init__(self, train_dataset, val_dataset, **kwargs):
|
29 |
+
super().__init__()
|
30 |
+
self.train_dataset = train_dataset
|
31 |
+
self.val_dataset = val_dataset
|
32 |
+
self.model, *_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
|
33 |
+
if config.locked_tuning:
|
34 |
+
for param in self.model.parameters():
|
35 |
+
param.requires_grad = False
|
36 |
+
self.audio_encoder = AudioEncoder(freeze=False)
|
37 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
38 |
+
self.batch_size = kwargs.get('batch_size')
|
39 |
+
self.num_workers = kwargs.get('num_workers')
|
40 |
+
self.lr = kwargs.get('lr', 1e-4)
|
41 |
+
|
42 |
+
def forward(self, image, audio):
|
43 |
+
with torch.no_grad():
|
44 |
+
image_embeds, *_ = self.model(image)
|
45 |
+
unnormalized_audio_embeds = self.audio_encoder(audio)
|
46 |
+
audio_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
|
47 |
+
return image_embeds, audio_embeds
|
48 |
+
|
49 |
+
def shared_step(self, batch):
|
50 |
+
image, audio, *_ = batch
|
51 |
+
image_embeds, audio_embeds = self(image, audio)
|
52 |
+
logit_scale = self.logit_scale.exp()
|
53 |
+
logits_per_img = torch.matmul(image_embeds,audio_embeds.t())*logit_scale
|
54 |
+
cross_contrastive_loss = clip_loss(logits_per_img)
|
55 |
+
return cross_contrastive_loss
|
56 |
+
|
57 |
+
def training_step(self, batch, batch_idx):
|
58 |
+
loss = self.shared_step(batch)
|
59 |
+
self.log('train_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
|
60 |
+
self.log('temperature', self.logit_scale.data, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
|
61 |
+
return loss
|
62 |
+
|
63 |
+
def on_train_batch_end(self,outputs,batch, batch_idx):
|
64 |
+
if self.logit_scale.data > np.log(100):
|
65 |
+
self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, np.log(100))
|
66 |
+
|
67 |
+
def validation_step(self, batch, batch_idx):
|
68 |
+
loss = self.shared_step(batch)
|
69 |
+
self.log('val_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
|
70 |
+
return loss
|
71 |
+
|
72 |
+
def train_dataloader(self):
|
73 |
+
return DataLoader(self.train_dataset,
|
74 |
+
batch_size=self.batch_size,
|
75 |
+
num_workers=self.num_workers,
|
76 |
+
shuffle=True,
|
77 |
+
persistent_workers=False)
|
78 |
+
|
79 |
+
def val_dataloader(self):
|
80 |
+
return DataLoader(self.val_dataset,
|
81 |
+
batch_size=self.batch_size,
|
82 |
+
num_workers=self.num_workers,
|
83 |
+
shuffle=False,
|
84 |
+
persistent_workers=False)
|
85 |
+
|
86 |
+
def configure_optimizers(self):
|
87 |
+
params = self.parameters()
|
88 |
+
self.optim = torch.optim.AdamW(params,
|
89 |
+
lr=self.lr,
|
90 |
+
betas=(0.9,0.98),
|
91 |
+
eps=1e-6
|
92 |
+
)
|
93 |
+
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
94 |
+
optimizer=self.optim,
|
95 |
+
T_0=20
|
96 |
+
)
|
97 |
+
return [self.optim], [self.scheduler]
|
98 |
+
|
99 |
+
def seed_everything(seed=42):
|
100 |
+
"""
|
101 |
+
seed: int
|
102 |
+
"""
|
103 |
+
torch.manual_seed(seed)
|
104 |
+
torch.cuda.manual_seed_all(seed)
|
105 |
+
np.random.seed(seed)
|
106 |
+
random.seed(seed)
|
107 |
+
torch.backends.cudnn.deterministic = True
|
108 |
+
torch.backends.cudnn.benchmark = False
|
109 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
110 |
+
|
111 |
+
if __name__=='__main__':
|
112 |
+
import warnings
|
113 |
+
warnings.filterwarnings("ignore")
|
114 |
+
torch.set_warn_always(False)
|
115 |
+
|
116 |
+
seed_everything()
|
117 |
+
train_dataset = INatDataset(data_file=config.train_df, mode='train')
|
118 |
+
val_dataset = INatDataset(data_file=config.val_df, mode='val')
|
119 |
+
kwargs = {'batch_size':config.batch_size, 'num_workers': config.num_workers}
|
120 |
+
|
121 |
+
model = AudioBind(train_dataset, val_dataset, **kwargs)
|
122 |
+
torch.cuda.empty_cache()
|
123 |
+
|
124 |
+
checkpoint = ModelCheckpoint(
|
125 |
+
monitor='val_loss',
|
126 |
+
dirpath=config.save_dir,
|
127 |
+
filename=config.filename,
|
128 |
+
mode='min',
|
129 |
+
save_top_k=3
|
130 |
+
)
|
131 |
+
trainer = pl.Trainer(
|
132 |
+
accelerator='gpu',
|
133 |
+
devices=config.devices,
|
134 |
+
strategy='ddp',
|
135 |
+
max_epochs=config.max_epochs,
|
136 |
+
num_nodes=1,
|
137 |
+
callbacks=[checkpoint],
|
138 |
+
accumulate_grad_batches=config.accumulate_grad_batches,
|
139 |
+
log_every_n_steps=1
|
140 |
+
)
|
141 |
+
trainer.fit(model)
|
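clip_loss above is the standard symmetric InfoNCE objective: the image-audio similarity matrix is treated as classification logits in both directions, with the diagonal (true image-audio pairs) as the positive class. A minimal sketch with random embeddings (the random tensors and batch size of 4 are assumptions, only there to exercise the loss) shows the call pattern used in shared_step.

import torch
import torch.nn as nn

def contrastive_loss(logits):
    return nn.functional.cross_entropy(
        logits[:logits.shape[1]], torch.arange(logits.shape[1], device=logits.device))

def clip_loss(similarity):
    return 0.5 * contrastive_loss(similarity) + 0.5 * contrastive_loss(similarity.t())

# Toy embeddings standing in for BioCLIP image and CLAP audio features.
img = nn.functional.normalize(torch.randn(4, 512), dim=-1)
aud = nn.functional.normalize(torch.randn(4, 512), dim=-1)
logit_scale = 1 / 0.07                      # exp of the module's initial log-temperature
loss = clip_loss(img @ aud.t() * logit_scale)
print(loss)                                 # scalar; diagonal pairs are the positives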
Taxabind/Taxabind/SoundBind/sound_encoder.py
ADDED
@@ -0,0 +1,45 @@
1 |
+
# Hugging Face way of loading the CLAP audio model
|
2 |
+
from transformers import ClapProcessor
|
3 |
+
from transformers import ClapAudioModelWithProjection
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import torchaudio
|
9 |
+
import os
|
10 |
+
|
11 |
+
processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
|
12 |
+
SAMPLE_RATE = 48000
|
13 |
+
|
14 |
+
def get_audio_clap(path_to_audio,format="mp3",padding="repeatpad",truncation="fusion"):
|
15 |
+
track, sr = torchaudio.load(path_to_audio, format=format) # torchaudio.load(path_to_audio)
|
16 |
+
track = track.mean(axis=0)
|
17 |
+
track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
|
18 |
+
output = processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt",padding=padding,truncation=truncation)
|
19 |
+
return output
|
20 |
+
|
21 |
+
|
22 |
+
class CLAP_audiomodel_withProjection(pl.LightningModule):
|
23 |
+
def __init__(self,freeze=False):
|
24 |
+
super().__init__()
|
25 |
+
if freeze:
|
26 |
+
self.model = ClapAudioModelWithProjection.from_pretrained("laion/clap-htsat-fused").eval()
|
27 |
+
for params in self.model.parameters():
|
28 |
+
params.requires_grad=False
|
29 |
+
else:
|
30 |
+
self.model = ClapAudioModelWithProjection.from_pretrained("laion/clap-htsat-fused").train()
|
31 |
+
def forward(self,audio):
|
32 |
+
batch_embeddings_audio = self.model(**audio)['audio_embeds']
|
33 |
+
return batch_embeddings_audio
|
34 |
+
|
35 |
+
if __name__ == '__main__':
|
36 |
+
path_to_audio ="/mnt/hdd/inat2021_ds/sound_train/sounds_mp3/165878447.mp3"
|
37 |
+
sample = get_audio_clap(path_to_audio)
|
38 |
+
print(sample.keys())
|
39 |
+
|
40 |
+
sample['input_features'] = torch.concat([sample['input_features'],sample['input_features']],axis=0)
|
41 |
+
sample['is_longer'] = torch.concat([sample['is_longer'],sample['is_longer']],axis=0)
|
42 |
+
print(sample['input_features'].shape,sample['is_longer'].shape) #torch.Size([2, 4, 1001, 64]), torch.Size([2, 1])
|
43 |
+
model = CLAP_audiomodel_withProjection(freeze=False)
|
44 |
+
audio_feat = model(sample)
|
45 |
+
print(audio_feat.shape) #torch.Size([2, 512])
|
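The 512-d audio embedding produced by `CLAP_audiomodel_withProjection` is consumed elsewhere in this commit by scoring it against per-patch satellite embeddings with a temperature-scaled dot product followed by a sigmoid (see `_similarity_heatmap` in `app_BACKUP.py`). The sketch below is a hypothetical illustration of that scoring step, not code from the commit: the random tensors stand in for real model outputs, and the 576-patch count is an assumption matching a 24x24 satellite grid.

import torch
import torch.nn.functional as F

audio_embed = F.normalize(torch.randn(1, 512), dim=-1)     # stand-in for a CLAP audio embedding
patch_embeds = F.normalize(torch.randn(576, 512), dim=-1)  # stand-in for 24x24 satellite patch embeddings

logit_scale = 1 / 0.07  # same temperature, exp(log(1/0.07)), used by the demo's heatmap code
probs = (audio_embed @ patch_embeds.t() * logit_scale).sigmoid()
print(probs.shape)      # torch.Size([1, 576]) -> reshape to (24, 24) for a heatmap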
app.py
CHANGED
@@ -1,222 +1,123 @@
"""
Simplified Gradio demo for Search-TTA evaluation.
This version mirrors the layout of `app_BACKUP.py` but:
    1. Loads no OpenCLIP / CLAP / Satellite encoders at import-time.
    2. Keeps only the Satellite and Ground-level image inputs.
    3. Exposes the high-level wrapper classes `ClipSegTTA` and
       `TestWorker` and calls `TestWorker.run_episode` inside the
       `process` callback.
"""

# ────────────────────────── imports ───────────────────────────────────
from pathlib import Path

import gradio as gr
import torch
from PIL import Image

# Import configuration & RL / TTA utilities -------------------------------------------------
# NOTE: we import * so that the global names (e.g. USE_GPU, MODEL_NAME, etc.)
#       are available exactly as referenced later in the unchanged snippet.
from test_parameter import *  # noqa: F403, F401 (wild-import is intentional here)

from model import PolicyNet  # noqa: E402 – after wild import on purpose
from test_multi_robot_worker import TestWorker  # noqa: E402
# from Taxabind.TaxaBind.SatBind.clip_seg_tta import ClipSegTTA  # noqa: E402
from Taxabind.Taxabind.SatBind.clip_seg_tta import ClipSegTTA


# CHANGE ME!
currEpisode = 0

# Prepare the model
# device = torch.device('cpu') #if USE_GPU_TRAINING else torch.device('cpu')
device = torch.device('cuda') if USE_GPU else torch.device('cpu')
policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device)
# script_dir = os.path.dirname(os.path.abspath(__file__))
script_dir = Path(__file__).resolve().parent
print("real_script_dir: ", script_dir)
# checkpoint = torch.load(f'{script_dir}/modules/vlm_search/{model_path}/{MODEL_NAME}')
checkpoint = torch.load(f'{model_path}/{MODEL_NAME}')
policy_net.load_state_dict(checkpoint['policy_model'])
print('Model loaded!')
# print(next(policy_net.parameters()).device)

# Init Taxabind here (only need to init once)
if TAXABIND_TTA:
    # self.clip_seg_tta = None
    clip_seg_tta = ClipSegTTA(
        img_dir=TAXABIND_IMG_DIR,
        imo_dir=TAXABIND_IMO_DIR,
        json_path=TAXABIND_INAT_JSON_PATH,
        sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
        patch_size=TAXABIND_PATCH_SIZE,
        sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
        sample_index=0,  # Set using 'reset' in worker
        blur_kernel=TAXABIND_GAUSSIAN_BLUR_KERNEL,
        device=device,
        sat_to_img_ids_json_is_train_dict=False,  # for search ds val
        tax_to_filter_val=QUERY_TAX,
        load_model=USE_CLIP_PREDS,
        initial_modality=INITIAL_MODALITY,
        sound_data_path=TAXABIND_SOUND_DATA_PATH,
        sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
        # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
    )
    print("ClipSegTTA Loaded!")
else:
    clip_seg_tta = None

# Define TestWorker
planner = TestWorker(
    meta_agent_id=0,
    n_agent=1,
    policy_net=policy_net,
    global_step=3,
    device='cuda',
    greedy=True,
    save_image=SAVE_GIFS,
    clip_seg_tta=clip_seg_tta
)

# ────────────────────────── Gradio process fn ─────────────────────────

def process(
    sat_img: Image.Image | None,
    ground_img: Image.Image | None,
    taxonomy: str | None = None,
):
    """Callback executed when the user presses **Run** in the UI.

    At test-time we simply trigger the RL search episode via
    ``planner.run_episode`` and return its performance metrics.
    The image inputs are currently *not* used directly here but are
    retained to conform to the requested interface.
    """
    # If no satellite image is provided we bail out early.
    if sat_img is None:
        return {"error": "Please provide a satellite image."}

    # Optionally you may want to reset episode index or make it configurable.
    # For now we hard-code episode 0, mirroring the snippet.
    planner.run_episode(currEpisode)

    print("planner.perf_metrics: ", planner.perf_metrics)

    # Return the collected performance metrics so they can be inspected.
    return planner.perf_metrics


# ────────────────────────── Gradio UI ─────────────────────────────────
with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:

    gr.Markdown(
        """
        # Search-TTA – Simplified Demo
        **Satellite ↔ Ground-level Visual Search** via RL Test-Time Adaptation.
        """
    )

    with gr.Row(variant="panel"):
        with gr.Column():
            sat_input = gr.Image(
                label="Satellite Image",
                sources=["upload"],
                type="pil",
                height=320,
            )
            taxonomy_input = gr.Textbox(
                label="Full Taxonomy Name (optional)",
                placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
            )
            ground_input = gr.Image(
                label="Ground-level Image (optional)",
                sources=["upload"],
                type="pil",
                height=320,
            )
            run_btn = gr.Button("Run", variant="primary")

        with gr.Column():
            gr.Markdown("### Episode Metrics")
            metrics_out = gr.JSON(label="Performance Metrics")

    # Bind callback

    # EXAMPLES – copied from original demo (satellite, ground, taxonomy only)
    with gr.Row():
        gr.Markdown("### In-Domain Taxonomy")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
                    "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
                    "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
                ],
                [
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
                    "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
                    "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
                ],
                [
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
                    "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
                ],
            ],
            inputs=[sat_input, ground_input, taxonomy_input],
            outputs=[metrics_out],
            fn=lambda sat, grd, tax: process(sat, grd),
            cache_examples=False,
        )

    with gr.Row():
        gr.Markdown("### Out-Domain Taxonomy")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
                    "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/yosemite_v3_resized.png",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/248820933.jpeg",
                    "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
                ],
            ],
            inputs=[sat_input, ground_input, taxonomy_input],
            outputs=[metrics_out],
            fn=lambda sat, grd, tax: process(sat, grd),
            cache_examples=False,
        )


    run_btn.click(
        fn=process,
        inputs=[sat_input, ground_input, taxonomy_input],
        outputs=metrics_out,
    )

# ────────────────────────── unchanged worker initialisation ───────────
# NOTE: **Do NOT modify the code below.**  It is copied verbatim from the
#       user-provided snippet so that the exact same objects are created.
#       The variables referenced here come from `test_parameter` which we
#       imported with a wildcard earlier.

# if def main
if __name__ == "__main__":

    # Finally launch the Gradio interface (queue for concurrency).
    demo.queue(max_size=15)
    demo.launch(share=True)
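The simplified `process` callback above delegates all work to `planner.run_episode` and returns `planner.perf_metrics`. A hypothetical smoke test (not part of the commit) that exercises it outside the Gradio UI might look like the sketch below; the blank placeholder image and the assumption that `perf_metrics` prints as a dict-like summary are illustrative only.

from PIL import Image

sat = Image.new("RGB", (336, 336))                       # placeholder satellite image
metrics = process(sat_img=sat, ground_img=None, taxonomy=None)
print(metrics)                                           # e.g. episode success / travel-distance statistics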
app_BACKUP.py
ADDED
@@ -0,0 +1,375 @@
"""
Search-TTA demo
"""

# ────────────────────────── imports ───────────────────────────────────
import cv2
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
import torchaudio
import spaces  # integration with ZeroGPU on hf

from torchvision import transforms
import open_clip
from clip_vision_per_patch_model import CLIPVisionPerPatchModel
from transformers import ClapAudioModelWithProjection
from transformers import ClapProcessor

# ────────────────────────── global config & models ────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BioCLIP (ground-image & text encoder)
bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
bio_model = bio_model.to(device).eval()
bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")

# Satellite patch encoder CLIP-L-336 per-patch)
sat_model: CLIPVisionPerPatchModel = (
    CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta-sat")
    .to(device)
    .eval()
)

# Sound CLAP model
sound_model: ClapAudioModelWithProjection = (
    ClapAudioModelWithProjection.from_pretrained("derektan95/search-tta-sound")
    .to(device)
    .eval()
)
sound_processor: ClapProcessor = ClapProcessor.from_pretrained("derektan95/search-tta-sound")
SAMPLE_RATE = 48000

logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
logit_scale = logit_scale.exp()
blur_kernel = (5,5)

# ────────────────────────── transforms (exact spec) ───────────────────
img_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

imo_transform = transforms.Compose(
    [
        transforms.Resize((336, 336)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def get_audio_clap(path_to_audio, format="mp3", padding="repeatpad", truncation="fusion"):
    track, sr = torchaudio.load(path_to_audio, format=format)  # torchaudio.load(path_to_audio)
    track = track.mean(axis=0)
    track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
    output = sound_processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt", padding=padding, truncation=truncation)
    return output

# ────────────────────────── helpers ───────────────────────────────────

@torch.no_grad()
def _encode_ground(img_pil: Image.Image) -> torch.Tensor:
    img = img_transform(img_pil).unsqueeze(0).to(device)
    img_embeds, *_ = bio_model(img)
    return img_embeds


@torch.no_grad()
def _encode_text(text: str) -> torch.Tensor:
    toks = bio_tokenizer(text).to(device)
    _, txt_embeds, _ = bio_model(text=toks)
    return txt_embeds


@torch.no_grad()
def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
    imo = imo_transform(img_pil).unsqueeze(0).to(device)
    imo_embeds = sat_model(imo)
    return imo_embeds


@torch.no_grad()
def _encode_sound(sound) -> torch.Tensor:
    processed_sound = get_audio_clap(sound)
    for k in processed_sound.keys():
        processed_sound[k] = processed_sound[k].to(device)
    unnormalized_audio_embeds = sound_model(**processed_sound).audio_embeds
    sound_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
    return sound_embeds


def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
    sims = torch.matmul(query, patches.t()) * logit_scale
    sims = sims.t().sigmoid()
    sims = sims[1:].squeeze()  # drop CLS token
    side = int(np.sqrt(len(sims)))
    sims = sims.reshape(side, side)
    return sims.cpu().detach().numpy()


def _array_to_pil(arr: np.ndarray) -> Image.Image:
    """
    Render arr with viridis, automatically stretching its own min→max to 0→1
    so that the most-similar patches appear yellow.
    """

    # Gausian Smoothing
    if blur_kernel != (0,0):
        arr = cv2.GaussianBlur(arr, blur_kernel, 0)

    # --- contrast-stretch to local 0-1 range --------------------------
    arr_min, arr_max = float(arr.min()), float(arr.max())
    if arr_max - arr_min < 1e-6:  # avoid /0 when the heat-map is flat
        arr_scaled = np.zeros_like(arr)
    else:
        arr_scaled = (arr - arr_min) / (arr_max - arr_min)
    # ------------------------------------------------------------------
    fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
    ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
    ax.axis("off")
    buf = io.BytesIO()
    plt.tight_layout(pad=0)
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

# ────────────────────────── main inference ────────────────────────────
# integration with ZeroGPU on hf
@spaces.GPU
def process(
    sat_img: Image.Image,
    taxonomy: str,
    ground_img: Image.Image | None,
    sound: torch.Tensor | None,
):
    if sat_img is None:
        return None, None

    patches = _encode_sat(sat_img)

    heat_ground, heat_text, heat_sound = None, None, None

    if ground_img is not None:
        q_img = _encode_ground(ground_img)
        heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))

    if taxonomy.strip():
        q_txt = _encode_text(taxonomy.strip())
        heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))

    if sound is not None:
        q_sound = _encode_sound(sound)
        heat_sound = _array_to_pil(_similarity_heatmap(q_sound, patches))

    return heat_ground, heat_text, heat_sound


# ────────────────────────── Gradio UI ─────────────────────────────────
with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:

    with gr.Row():
        gr.Markdown(
            """
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <div>
                <h1>Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild</h1>
                <span></span>
                <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
                    <a href="https://search-tta.github.io">Project Website</a>
                </h2>
                <span></span>
                <h2 style='font-weight: 450; font-size: 0.5rem; margin: 0rem'>[Work in Progress]</h2>
            </div>
            </div>
            """
            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>

            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
            #     <a href="https://derektan95.github.io">Derek M. S. Tan</a>,
            #     <a href="https://chinchinati.github.io/">Shailesh</a>,
            #     <a href="https://www.linkedin.com/in/boyang-liu-nus">Boyang Liu</a>,
            #     <a href="https://www.linkedin.com/in/loki-silvres">Alok Raj</a>,
            #     <a href="https://www.linkedin.com/in/ang-qi-xuan-714347142">Qi Xuan Ang</a>,
            #     <a href="https://weihengdai.top">Weiheng Dai</a>,
            #     <a href="https://www.linkedin.com/in/tanishqduhan">Tanishq Duhan</a>,
            #     <a href="https://www.linkedin.com/in/jimmychiun">Jimmy Chiun</a>,
            #     <a href="https://www.yuhongcao.online/">Yuhong Cao</a>,
            #     <a href="https://www.cs.toronto.edu/~florian/">Florian Shkurti</a>,
            #     <a href="https://www.marmotlab.org/bio.html">Guillaume Sartoretti</a>
            # </h2>
            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>National University of Singapore, University of Toronto, IIT-Dhanbad, Singapore Technologies Engineering</h2>
        )

    with gr.Row(variant="panel"):

        # LEFT COLUMN (satellite, taxonomy, run)
        with gr.Column():
            sat_input = gr.Image(
                label="Satellite Image",
                sources=["upload"],
                type="pil",
                height=320,
            )
            taxonomy_input = gr.Textbox(
                label="Full Taxonomy Name (optional)",
                placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
            )

            # ─── NEW: sound input ───────────────────────────
            sound_input = gr.Audio(
                label="Sound Input (optional)",
                sources=["upload"],   # or "microphone" / "url" as you prefer
                type="filepath",      # or "numpy" if you want raw arrays
            )
            run_btn = gr.Button("Run", variant="primary")

        # RIGHT COLUMN (ground image + two heat-maps)
        with gr.Column():
            ground_input = gr.Image(
                label="Ground-level Image (optional)",
                sources=["upload"],
                type="pil",
                height=320,
            )
            gr.Markdown("### Heat-map Results")
            with gr.Row():
                # Separate label and image to avoid overlap
                with gr.Column(scale=1, min_width=100):
                    gr.Markdown("**Ground Image Query**", elem_id="label-ground")
                    heat_ground_out = gr.Image(
                        show_label=False,
                        height=160,
                        # width=160,
                    )
                with gr.Column(scale=1, min_width=100):
                    gr.Markdown("**Text Query**", elem_id="label-text")
                    heat_text_out = gr.Image(
                        show_label=False,
                        height=160,
                        # width=160,
                    )
                with gr.Column(scale=1, min_width=100):
                    gr.Markdown("**Sound Query**", elem_id="label-sound")
                    heat_sound_out = gr.Image(
                        show_label=False,
                        height=160,
                        # width=160,
                    )
            # ─── NEW: sound output ─────────────────────────
            # sound_output = gr.Audio(
            #     label="Playback",
            # )


    # EXAMPLES
    with gr.Row():
        gr.Markdown("### In-Domain Taxonomy")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
                    "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3"
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
                    "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3"
                ],
                [
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
                    "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3"
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
                    "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3"
                ],
                [
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
                    "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
                    None
                ],
            ],
            inputs=[sat_input, ground_input, taxonomy_input, sound_input],
            outputs=[heat_ground_out, heat_text_out, heat_sound_out],
            fn=process,
            cache_examples=False,
        )

    # EXAMPLES
    with gr.Row():
        gr.Markdown("### Out-Domain Taxonomy")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
                    "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3"
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3"
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/yosemite_v3_resized.png",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/248820933.jpeg",
                    "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
                    None
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
                    None
                ],
            ],
            inputs=[sat_input, ground_input, taxonomy_input, sound_input],
            outputs=[heat_ground_out, heat_text_out, heat_sound_out],
            fn=process,
            cache_examples=False,
        )

    # CALLBACK
    run_btn.click(
        fn=process,
        inputs=[sat_input, taxonomy_input, ground_input, sound_input],
        outputs=[heat_ground_out, heat_text_out, heat_sound_out],
    )

    # Footer to point out to model and data from app page.
    gr.Markdown(
        """
        The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite image and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process.
        """
    )

# LAUNCH
if __name__ == "__main__":
    demo.queue(max_size=15)
    demo.launch(share=True)
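A note on `_similarity_heatmap` above: the per-patch satellite encoder returns one embedding per ViT token, so the first (CLS) similarity is discarded and the remainder is folded into a square grid before plotting. The sketch below is an illustrative assumption, not code from the commit; the 577-token count corresponds to a CLIP ViT-L/14 encoder at 336 px (1 CLS + 24x24 patches).

import numpy as np

sims = np.random.rand(577)           # stand-in for sigmoid(per-token similarities)
patch_sims = sims[1:]                # drop the CLS token, keeping the 576 patch scores
side = int(np.sqrt(len(patch_sims))) # 24
heatmap = patch_sims.reshape(side, side)
print(heatmap.shape)                 # (24, 24), ready for Gaussian blur + viridis rendering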
env.py
ADDED
@@ -0,0 +1,811 @@
|
|
|
1 |
+
#######################################################################
|
2 |
+
# Name: env.py
|
3 |
+
#
|
4 |
+
# - Reads and processes training and test maps (E.g. DungeonMaps)
|
5 |
+
# - Processes rewards, new frontiers given action
|
6 |
+
# - Updates a graph representation of environment for input into network
|
7 |
+
#######################################################################
|
8 |
+
|
9 |
+
import sys
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
from matplotlib.colors import LogNorm, PowerNorm
|
13 |
+
if sys.modules['TRAINING']:
|
14 |
+
from parameter import *
|
15 |
+
else:
|
16 |
+
from test_parameter import *
|
17 |
+
|
18 |
+
import copy
|
19 |
+
import pandas as pd
|
20 |
+
import rasterio
|
21 |
+
from skimage import io
|
22 |
+
import matplotlib.pyplot as plt
|
23 |
+
import os
|
24 |
+
from skimage.measure import block_reduce
|
25 |
+
from sensor import *
|
26 |
+
from graph_generator import *
|
27 |
+
from node import *
|
28 |
+
from scipy.ndimage import label, find_objects
|
29 |
+
import matplotlib.image as mpimg
|
30 |
+
from matplotlib.colors import Normalize
|
31 |
+
|
32 |
+
# import matplotlib
|
33 |
+
# matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
|
34 |
+
|
35 |
+
|
36 |
+
class Env():
|
37 |
+
def __init__(self, map_index, n_agent, k_size=20, plot=False, test=False, mask_index=None):
|
38 |
+
self.n_agent = n_agent
|
39 |
+
self.test = test
|
40 |
+
self.map_dir = GRIDMAP_SET_DIR
|
41 |
+
|
42 |
+
# Import environment gridmap
|
43 |
+
self.map_list = os.listdir(self.map_dir)
|
44 |
+
self.map_list.sort(reverse=True)
|
45 |
+
|
46 |
+
# NEW: Import segmentation utility map
|
47 |
+
self.seg_dir = MASK_SET_DIR
|
48 |
+
self.segmentation_mask, self.target_positions, self.target_found_idxs = None, [], []
|
49 |
+
self.segmentation_mask_list = os.listdir(self.seg_dir)
|
50 |
+
self.segmentation_mask_list.sort(reverse=True)
|
51 |
+
|
52 |
+
# Import target maps (if relevant)
|
53 |
+
if TARGETS_SET_DIR != "":
|
54 |
+
self.targets_map_list = os.listdir(TARGETS_SET_DIR)
|
55 |
+
self.targets_map_list.sort(reverse=True)
|
56 |
+
|
57 |
+
# # NEW: Find common files in both directories
|
58 |
+
# if TARGETS_SET_DIR != "":
|
59 |
+
# common_files = [file for file in self.map_list if file in self.segmentation_mask_list and file in self.targets_map_list]
|
60 |
+
# else:
|
61 |
+
# common_files = [file for file in self.map_list if file in self.segmentation_mask_list]
|
62 |
+
self.map_index = map_index % len(self.map_list)
|
63 |
+
if mask_index is not None:
|
64 |
+
self.mask_index = mask_index % len(self.segmentation_mask_list)
|
65 |
+
else:
|
66 |
+
self.mask_index = map_index % len(self.segmentation_mask_list)
|
67 |
+
# self.common_map_file = common_files[self.map_index]
|
68 |
+
# print("self.common_map_file: ", self.common_map_file)
|
69 |
+
|
70 |
+
# Import ground truth and segmentation mask
|
71 |
+
self.ground_truth, self.map_start_position = self.import_ground_truth(
|
72 |
+
os.path.join(self.map_dir, self.map_list[self.map_index]))# self.common_map_file))
|
73 |
+
self.ground_truth_size = np.shape(self.ground_truth) # (480, 640)
|
74 |
+
self.robot_belief = np.ones(self.ground_truth_size) * 127 # unexplored 127
|
75 |
+
self.downsampled_belief = None
|
76 |
+
self.old_robot_belief = copy.deepcopy(self.robot_belief)
|
77 |
+
self.coverage_belief = np.ones(self.ground_truth_size) * 127 # unexplored 127
|
78 |
+
|
79 |
+
# Import segmentation mask
|
80 |
+
mask_filename = self.segmentation_mask_list[self.mask_index]
|
81 |
+
self.segmentation_mask = self.import_segmentation_mask(
|
82 |
+
os.path.join(self.seg_dir, mask_filename))# self.common_map_file))
|
83 |
+
# print("mask_filename: ", mask_filename)
|
84 |
+
|
85 |
+
# Overwrite target positions if directory specified
|
86 |
+
self.gt_segmentation_mask = None
|
87 |
+
if self.test and TARGETS_SET_DIR != "":
|
88 |
+
self.gt_segmentation_mask = self.import_segmentation_mask(
|
89 |
+
os.path.join(TARGETS_SET_DIR, self.map_list[self.map_index])) # UNUSED - self.common_map_file))
|
90 |
+
# print("target_positions: ", self.target_positions)
|
91 |
+
# print("np.unique(self.segmentation_mask): ", np.unique(self.segmentation_mask))
|
92 |
+
|
93 |
+
self.segmentation_info_mask = None
|
94 |
+
self.gt_segmentation_info_mask = None
|
95 |
+
self.segmentation_info_mask_unnormalized = None
|
96 |
+
self.filtered_seg_info_mask = None
|
97 |
+
self.num_targets_found = 0
|
98 |
+
self.num_new_targets_found = 0
|
99 |
+
|
100 |
+
# # Link score masks to raw image files
|
101 |
+
# csv_file = pd.read_csv(RAW_IMG_PATH_DICT, header=None)
|
102 |
+
# img_score_paths = csv_file.iloc[:,2].tolist()
|
103 |
+
# raw_img_paths = csv_file.iloc[:,0].tolist()
|
104 |
+
# self.score_to_img_dict = {os.path.basename(img_score_path): raw_img_path for img_score_path, raw_img_path in zip(img_score_paths, raw_img_paths)}
|
105 |
+
|
106 |
+
self.resolution = 4
|
107 |
+
self.sensor_range = SENSOR_RANGE
|
108 |
+
self.explored_rate = 0
|
109 |
+
self.targets_found_rate = 0
|
110 |
+
self.info_gain = 0
|
111 |
+
self.total_info = 0
|
112 |
+
|
113 |
+
self.graph_generator = Graph_generator(map_size=self.ground_truth_size, sensor_range=self.sensor_range, k_size=k_size, plot=plot)
|
114 |
+
self.node_coords, self.graph, self.node_utility, self.guidepost = None, None, None, None
|
115 |
+
|
116 |
+
self.frontiers = None
|
117 |
+
|
118 |
+
self.start_positions = []
|
119 |
+
self.begin(self.map_start_position)
|
120 |
+
|
121 |
+
self.plot = plot
|
122 |
+
self.frame_files = []
|
123 |
+
|
124 |
+
def find_index_from_coords(self, position):
|
125 |
+
index = np.argmin(np.linalg.norm(self.node_coords - position, axis=1))
|
126 |
+
return index
|
127 |
+
|
128 |
+
def begin(self, start_position):
|
129 |
+
# self.robot_belief = self.update_robot_belief(robot_position, self.sensor_range, self.robot_belief,
|
130 |
+
# self.ground_truth)
|
131 |
+
self.robot_belief = self.ground_truth
|
132 |
+
|
133 |
+
self.downsampled_belief = block_reduce(self.robot_belief.copy(), block_size=(self.resolution, self.resolution),
|
134 |
+
func=np.min)
|
135 |
+
self.frontiers = self.find_frontier()
|
136 |
+
self.old_robot_belief = copy.deepcopy(self.robot_belief)
|
137 |
+
|
138 |
+
self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.generate_graph(
|
139 |
+
self.robot_belief, self.frontiers)
|
140 |
+
|
141 |
+
# Find non-conflicting start positions
|
142 |
+
if FIX_START_POSITION:
|
143 |
+
coords_res_row = int(self.robot_belief.shape[0]/NUM_COORDS_HEIGHT)
|
144 |
+
coords_res_col = int(self.robot_belief.shape[1]/NUM_COORDS_WIDTH)
|
145 |
+
self.start_positions = [(int(self.robot_belief.shape[1]/2)-coords_res_col/2,int(self.robot_belief.shape[0]/2)-coords_res_row/2) for _ in range(self.n_agent)] # bottom-left corner
|
146 |
+
else:
|
147 |
+
nearby_coords = self.graph_generator.get_neighbors_grid_coords(start_position)
|
148 |
+
itr = 0
|
149 |
+
for i in range(self.n_agent):
|
150 |
+
if i == 0 or len(nearby_coords) == 0:
|
151 |
+
self.start_positions.append(start_position)
|
152 |
+
else:
|
153 |
+
idx = min(itr, len(nearby_coords)-1)
|
154 |
+
self.start_positions.append(nearby_coords[idx])
|
155 |
+
itr += 1
|
156 |
+
|
157 |
+
for i in range(len(self.start_positions)):
|
158 |
+
self.start_positions[i] = self.node_coords[self.find_index_from_coords(self.start_positions[i])]
|
159 |
+
self.coverage_belief = self.update_robot_belief(self.start_positions[i], self.sensor_range, self.coverage_belief,
|
160 |
+
self.ground_truth)
|
161 |
+
|
162 |
+
for start_position in self.start_positions:
|
163 |
+
self.graph_generator.route_node.append(start_position)
|
164 |
+
|
165 |
+
# Info map from ground truth
|
166 |
+
rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
|
167 |
+
rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
|
168 |
+
self.segmentation_info_mask = np.zeros((len(self.node_coords), 1))
|
169 |
+
self.gt_segmentation_info_mask = np.zeros((len(self.node_coords), 1))
|
170 |
+
for i, node_coord in enumerate(self.node_coords):
|
171 |
+
max_x = min(node_coord[0] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
|
172 |
+
min_x = max(node_coord[0] - int(math.ceil(rng_x)), 0)
|
173 |
+
max_y = min(node_coord[1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
|
174 |
+
min_y = max(node_coord[1] - int(math.ceil(rng_y)), 0)
|
175 |
+
|
176 |
+
# if np.any(self.segmentation_mask[min_y:max_y, min_x:max_x] == 255):
|
177 |
+
# self.segmentation_info_mask[i] = 1.0
|
178 |
+
# else:
|
179 |
+
# self.segmentation_info_mask[i] = 0.0
|
180 |
+
# self.segmentation_info_mask[i] = np.mean(self.segmentation_mask[min_y:max_y, min_x:max_x])
|
181 |
+
# self.segmentation_info_mask[i] = np.max(self.segmentation_mask[min_y:max_y, min_x:max_x])
|
182 |
+
if TARGETS_SET_DIR == "": # If targets combined with segmentation mask
|
183 |
+
exclude = {208} # Exclude target positions
|
184 |
+
else:
|
185 |
+
exclude = {}
|
186 |
+
self.segmentation_info_mask[i] = max(x for x in self.segmentation_mask[min_y:max_y, min_x:max_x].flatten() if x not in exclude) / 100.0
|
187 |
+
if self.gt_segmentation_mask is not None:
|
188 |
+
self.gt_segmentation_info_mask[i] = max(x for x in self.gt_segmentation_mask[min_y:max_y, min_x:max_x].flatten() if x not in exclude) / 100.0
|
189 |
+
# print("np.unique(self.segmentation_info_mask): ", np.unique(self.segmentation_info_mask))
|
190 |
+
self.filtered_seg_info_mask = copy.deepcopy(self.segmentation_info_mask)
|
191 |
+
|
192 |
+
# In case targets found at beginning...
|
193 |
+
done, num_targets_found = self.check_done()
|
194 |
+
self.num_targets_found = num_targets_found
|
195 |
+
|
196 |
+
|
197 |
+
def multi_robot_step(self, next_position_list, dist_list, travel_dist_list):
|
198 |
+
temp_frontiers = copy.deepcopy(self.frontiers)
|
199 |
+
reward_list = []
|
200 |
+
for dist, robot_position in zip(dist_list, next_position_list):
|
201 |
+
self.graph_generator.route_node.append(robot_position)
|
202 |
+
next_node_index = self.find_index_from_coords(robot_position)
|
203 |
+
self.graph_generator.nodes_list[next_node_index].set_visited()
|
204 |
+
# self.robot_belief = self.update_robot_belief(robot_position, self.sensor_range, self.robot_belief,
|
205 |
+
# self.ground_truth)
|
206 |
+
self.coverage_belief = self.update_robot_belief(robot_position, self.sensor_range, self.coverage_belief,
|
207 |
+
self.ground_truth)
|
208 |
+
self.robot_belief = self.ground_truth
|
209 |
+
|
210 |
+
self.downsampled_belief = block_reduce(self.robot_belief.copy(),
|
211 |
+
block_size=(self.resolution, self.resolution),
|
212 |
+
func=np.min)
|
213 |
+
|
214 |
+
frontiers = self.find_frontier()
|
215 |
+
#num_observed_frontiers = self.calculate_num_observed_frontiers(temp_frontiers, frontiers)
|
216 |
+
#temp_frontiers = frontiers
|
217 |
+
|
218 |
+
num_observed_frontiers = self.node_utility[next_node_index]
|
219 |
+
|
220 |
+
# individual_reward = num_observed_frontiers / 50 - dist / 64
|
221 |
+
individual_reward = -dist / 32 # 64
|
222 |
+
|
223 |
+
info_gain_reward = 0
|
224 |
+
robot_position_idx = self.find_index_from_coords(robot_position)
|
225 |
+
# if self.segmentation_info_mask[robot_position_idx] == 1.0 and self.guidepost[robot_position_idx] == 0.0:
|
226 |
+
# # print("High Info (Unvisited)")
|
227 |
+
# info_gain_reward = (HIGH_INFO_REWARD_RATIO * 1.5)
|
228 |
+
# elif self.segmentation_info_mask[robot_position_idx] == 0.0 and self.guidepost[robot_position_idx] == 0.0:
|
229 |
+
# # print("Low Info (Unvisited)")
|
230 |
+
# info_gain_reward = ((1-HIGH_INFO_REWARD_RATIO) * 1.5)
|
231 |
+
info_gain_reward = self.filtered_seg_info_mask[robot_position_idx][0] * 1.5
|
232 |
+
if self.guidepost[robot_position_idx] == 0.0:
|
233 |
+
info_gain_reward += 0.2
|
234 |
+
# print("info_gain_reward: ", info_gain_reward)
|
235 |
+
individual_reward += info_gain_reward
|
236 |
+
|
237 |
+
# print("dist / 64: ", dist / 64)
|
238 |
+
# print("info gain reward: ", info_gain_reward)
|
239 |
+
|
240 |
+
reward_list.append(individual_reward)
|
241 |
+
|
242 |
+
self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.update_graph(self.robot_belief, self.old_robot_belief, frontiers, self.frontiers)
|
243 |
+
self.old_robot_belief = copy.deepcopy(self.robot_belief)
|
244 |
+
|
245 |
+
self.filtered_seg_info_mask = [info[0] if self.guidepost[i] == 0.0 else 0.0 for i, info in enumerate(self.segmentation_info_mask)]
|
246 |
+
self.filtered_seg_info_mask = np.expand_dims(np.array(self.filtered_seg_info_mask), axis=1)
|
247 |
+
|
248 |
+
num_observed_frontiers = self.calculate_num_observed_frontiers(self.frontiers, frontiers)
|
249 |
+
self.frontiers = frontiers
|
250 |
+
self.explored_rate = self.evaluate_exploration_rate()
|
251 |
+
|
252 |
+
done, num_targets_found = self.check_done()
|
253 |
+
self.num_new_targets_found = num_targets_found - self.num_targets_found
|
254 |
+
# #team_reward = sum(reward_list) / len(reward_list)
|
255 |
+
# # team_reward = num_observed_frontiers / 50
|
256 |
+
# team_reward = self.num_new_targets_found * 5.0
|
257 |
+
team_reward = 0.0
|
258 |
+
# # print("target found reward: ", self.num_new_targets_found * 5.0)
|
259 |
+
|
260 |
+
self.num_targets_found = num_targets_found
|
261 |
+
self.targets_found_rate = self.evaluate_targets_found_rate()
|
262 |
+
self.info_gain, self.total_info = self.evaluate_info_gain()
|
263 |
+
|
264 |
+
if done:
|
265 |
+
# team_reward += np.sum(self.robot_belief == 255) / sum(travel_dist_list)
|
266 |
+
team_reward += 40 # 20
|
267 |
+
for i in range(len(reward_list)):
|
268 |
+
reward_list[i] += team_reward
|
269 |
+
|
270 |
+
return reward_list, done
|
271 |
+
|
272 |
+
|
273 |
+
def import_ground_truth(self, map_index):
|
274 |
+
# occupied 1, free 255, unexplored 127
|
275 |
+
|
276 |
+
try:
|
277 |
+
# ground_truth = (io.imread(map_index, 1) * 255).astype(int)
|
278 |
+
ground_truth = (io.imread(map_index, 1)).astype(int)
|
279 |
+
if np.all(ground_truth == 0):
|
280 |
+
ground_truth = (io.imread(map_index, 1) * 255).astype(int)
|
281 |
+
except:
|
282 |
+
new_map_index = self.map_dir + '/' + self.map_list[0]
|
283 |
+
ground_truth = (io.imread(new_map_index, 1)).astype(int)
|
284 |
+
print('could not read the map_path ({}), hence skipping it and using ({}).'.format(map_index, new_map_index))
|
285 |
+
|
286 |
+
robot_location = np.nonzero(ground_truth == 208)
|
287 |
+
|
288 |
+
# print("robot_location: ", robot_location)
|
289 |
+
# print("np.array(robot_location)[1, 127]: ", np.array(robot_location)[1, 127])
|
290 |
+
|
291 |
+
robot_location = np.array([np.array(robot_location)[1, 127], np.array(robot_location)[0, 127]])
|
292 |
+
ground_truth = (ground_truth > 150)
|
293 |
+
ground_truth = ground_truth * 254 + 1
|
294 |
+
return ground_truth, robot_location
|
295 |
+
|
296 |
+
|
297 |
+
def import_segmentation_mask(self, map_index):
|
298 |
+
# occupied 1, free 255, unexplored 127
|
299 |
+
|
300 |
+
# mask = (io.imread(map_index, 1) * 255).astype(int) # NOTE: Cannot work well with seg mask self-generated
|
301 |
+
mask = cv2.imread(map_index).astype(int)
|
302 |
+
# print("np.unique(segmentation_mask): ", np.unique(mask))
|
303 |
+
|
304 |
+
# NOTE: Could contain mutiple start positions
|
305 |
+
# target_position = np.nonzero(mask == 208)
|
306 |
+
# target_positions = self.find_target_locations(mask)
|
307 |
+
|
308 |
+
# target_position = np.array([np.array(target_position)[1, 127], np.array(target_position)[0, 127]])
|
309 |
+
return mask #, target_positions
|
310 |
+
|
311 |
+
|
312 |
+
def find_target_locations(self, image_array, grey_value=208):
|
313 |
+
# Load the image
|
314 |
+
# image = Image.open(image_path)
|
315 |
+
# image_array = np.array(image)
|
316 |
+
|
317 |
+
# Identify pixels equal to the grey value
|
318 |
+
grey_pixels = np.where(image_array == grey_value)
|
319 |
+
|
320 |
+
# Create a binary array where grey pixels are marked as True
|
321 |
+
binary_array = np.zeros_like(image_array, dtype=bool)
|
322 |
+
binary_array[grey_pixels] = True
|
323 |
+
|
324 |
+
# Label connected components
|
325 |
+
labeled_array, num_features = label(binary_array)
|
326 |
+
|
327 |
+
# Find objects returns slices for each connected component
|
328 |
+
slices = find_objects(labeled_array)
|
329 |
+
|
330 |
+
# Calculate the center of each box
|
331 |
+
centers = []
|
332 |
+
for slice in slices:
|
333 |
+
row_center = (slice[0].start + slice[0].stop - 1) // 2
|
334 |
+
col_center = (slice[1].start + slice[1].stop - 1) // 2
|
335 |
+
centers.append((col_center, row_center)) # (y,x)
|
336 |
+
|
337 |
+
return centers
|
338 |
+
|
339 |
+
def free_cells(self):
|
340 |
+
index = np.where(self.ground_truth == 255)
|
341 |
+
free = np.asarray([index[1], index[0]]).T
|
342 |
+
return free
|
343 |
+
|
344 |
+
def update_robot_belief(self, robot_position, sensor_range, robot_belief, ground_truth):
|
345 |
+
robot_belief = sensor_work(robot_position, sensor_range, robot_belief, ground_truth)
|
346 |
+
return robot_belief
|
347 |
+
|
348 |
+
|
349 |
+
def check_done(self, robot_id=0):
|
350 |
+
""" All agnets to have explored most of the env map """
|
351 |
+
done = False
|
352 |
+
# for idx in range(self.n_agent):
|
353 |
+
# if np.sum(self.ground_truth == 255) - np.sum(self.all_robot_belief[idx][idx] == 255) > 40:
|
354 |
+
# done = False
|
355 |
+
|
356 |
+
# NEW: ADDITIONAL VLM SEARCH CRITERIA
|
357 |
+
num_targets_found = 0
|
358 |
+
self.target_found_idxs = []
|
359 |
+
for i, target in enumerate(self.target_positions):
|
360 |
+
if self.coverage_belief[target[1], target[0]] == 255: # 255:
|
361 |
+
num_targets_found += 1
|
362 |
+
self.target_found_idxs.append(i)
|
363 |
+
# free_cells_mask = self.all_robot_belief[robot_id][robot_id] == 255
|
364 |
+
# filtered_segmentation_mask = np.where(free_cells_mask, self.segmentation_mask, 0)
|
365 |
+
# targets = self.find_target_locations(filtered_segmentation_mask)
|
366 |
+
# print("num_targets_found: ", num_targets_found)
|
367 |
+
|
368 |
+
if TERMINATE_ON_TGTS_FOUND and num_targets_found >= len(self.target_positions):
|
369 |
+
done = True
|
370 |
+
if not TERMINATE_ON_TGTS_FOUND and np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255) >= 0.99:
|
371 |
+
done = True
|
372 |
+
|
373 |
+
return done, num_targets_found
|
374 |
+
|
375 |
+
|
376 |
+
def calculate_num_observed_frontiers(self, old_frontiers, frontiers):
|
377 |
+
frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
|
378 |
+
pre_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j
|
379 |
+
frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0]
|
380 |
+
pre_frontiers_num = pre_frontiers_to_check.shape[0]
|
381 |
+
delta_num = pre_frontiers_num - frontiers_num
|
382 |
+
|
383 |
+
return delta_num
|
384 |
+
|
385 |
+
def calculate_reward(self, dist, frontiers):
|
386 |
+
reward = 0
|
387 |
+
reward -= dist / 64
|
388 |
+
|
389 |
+
frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
|
390 |
+
pre_frontiers_to_check = self.frontiers[:, 0] + self.frontiers[:, 1] * 1j
|
391 |
+
frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0]
|
392 |
+
pre_frontiers_num = pre_frontiers_to_check.shape[0]
|
393 |
+
delta_num = pre_frontiers_num - frontiers_num
|
394 |
+
|
395 |
+
reward += delta_num / 50
|
396 |
+
|
397 |
+
return reward
|
398 |
+
|
399 |
+
def evaluate_exploration_rate(self):
|
400 |
+
# rate = np.sum(self.robot_belief == 255) / np.sum(self.ground_truth == 255)
|
401 |
+
rate = np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255)
|
402 |
+
return rate
|
403 |
+
|
404 |
+
def evaluate_targets_found_rate(self):
|
405 |
+
if len(self.target_positions) == 0:
|
406 |
+
return 0
|
407 |
+
else:
|
408 |
+
rate = self.num_targets_found / len(self.target_positions)
|
409 |
+
return rate
|
410 |
+
|
411 |
+
def evaluate_info_gain(self):
|
412 |
+
# print("self.segmentation_mask.shape: ", self.segmentation_mask.shape)
|
413 |
+
# coverage_belief = (self.coverage_belief == 255)
|
414 |
+
# print("coverage_belief.shape: ", coverage_belief.shape)
|
415 |
+
# print("np.unique(coverage_belief): ", np.unique(coverage_belief))
|
416 |
+
# print("np.count_nonzero(coverage_belief): ", np.count_nonzero(coverage_belief))
|
417 |
+
# print("np.count_zero(coverage_belief): ", coverage_belief.size - np.count_nonzero(coverage_belief))
|
418 |
+
# print("self.segmentation_mask[self.coverage_belief == 255].shape: ", self.segmentation_mask[self.coverage_belief == 255].shape)
|
419 |
+
if self.test and TARGETS_SET_DIR != "":
|
420 |
+
info_gained = np.sum(self.gt_segmentation_mask[self.coverage_belief == 255]) / 100.0
|
421 |
+
total_info = np.sum(self.gt_segmentation_mask) / 100.0
|
422 |
+
else:
|
423 |
+
info_gained = np.sum(self.segmentation_mask[self.coverage_belief == 255]) / 100.0
|
424 |
+
total_info = np.sum(self.segmentation_mask) / 100.0
|
425 |
+
return info_gained, total_info
|
426 |
+
|
427 |
+
def calculate_new_free_area(self):
|
428 |
+
old_free_area = self.old_robot_belief == 255
|
429 |
+
current_free_area = self.robot_belief == 255
|
430 |
+
|
431 |
+
new_free_area = (current_free_area.astype(np.int) - old_free_area.astype(np.int)) * 255
|
432 |
+
|
433 |
+
return new_free_area, np.sum(old_free_area)
|
434 |
+
|
435 |
+
def calculate_dist_path(self, path):
|
436 |
+
dist = 0
|
437 |
+
start = path[0]
|
438 |
+
end = path[-1]
|
439 |
+
for index in path:
|
440 |
+
if index == end:
|
441 |
+
break
|
442 |
+
dist += np.linalg.norm(self.node_coords[start] - self.node_coords[index])
|
443 |
+
start = index
|
444 |
+
return dist
|
445 |
+
|
446 |
+
def find_frontier(self):
|
447 |
+
y_len = self.downsampled_belief.shape[0]
|
448 |
+
x_len = self.downsampled_belief.shape[1]
|
449 |
+
mapping = self.downsampled_belief.copy()
|
450 |
+
belief = self.downsampled_belief.copy()
|
451 |
+
# 0-1 unknown area map
|
452 |
+
mapping = (mapping == 127) * 1
|
453 |
+
mapping = np.lib.pad(mapping, ((1, 1), (1, 1)), 'constant', constant_values=0)
|
454 |
+
fro_map = mapping[2:][:, 1:x_len + 1] + mapping[:y_len][:, 1:x_len + 1] + mapping[1:y_len + 1][:, 2:] + \
|
455 |
+
mapping[1:y_len + 1][:, :x_len] + mapping[:y_len][:, 2:] + mapping[2:][:, :x_len] + mapping[2:][:,
|
456 |
+
2:] + \
|
457 |
+
mapping[:y_len][:, :x_len]
|
458 |
+
ind_free = np.where(belief.ravel(order='F') == 255)[0]
|
459 |
+
ind_fron_1 = np.where(1 < fro_map.ravel(order='F'))[0]
|
460 |
+
ind_fron_2 = np.where(fro_map.ravel(order='F') < 8)[0]
|
461 |
+
ind_fron = np.intersect1d(ind_fron_1, ind_fron_2)
|
462 |
+
ind_to = np.intersect1d(ind_free, ind_fron)
|
463 |
+
|
464 |
+
map_x = x_len
|
465 |
+
map_y = y_len
|
466 |
+
x = np.linspace(0, map_x - 1, map_x)
|
467 |
+
y = np.linspace(0, map_y - 1, map_y)
|
468 |
+
t1, t2 = np.meshgrid(x, y)
|
469 |
+
points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T
|
470 |
+
|
471 |
+
f = points[ind_to]
|
472 |
+
f = f.astype(int)
|
473 |
+
|
474 |
+
f = f * self.resolution
|
475 |
+
|
476 |
+
return f
|
477 |
+
|
478 |
+
def plot_env(self, n, path, step, travel_dist, robots_route, img_path_override=None, sat_path_override=None, msk_name_override=None, sound_id_override=None, colormap_mid_val=None):
|
479 |
+
|
480 |
+
# # TEMP
|
481 |
+
# if TAXABIND_TTA:
|
482 |
+
# # Save self.segmentation_info_mask as .npy file in gifs_path
|
483 |
+
# side_dim = int(np.sqrt(self.segmentation_info_mask.shape[0]))
|
484 |
+
# mask_viz = self.segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
|
485 |
+
# np.save(os.path.join(path, f"seg_mask_step{step}.npy"), mask_viz)
|
486 |
+
|
487 |
+
plt.switch_backend('agg')
|
488 |
+
# plt.ion()
|
489 |
+
plt.cla()
|
490 |
+
color_list = ["r", "g", "c", "m", "y", "k"]
|
491 |
+
|
492 |
+
if TARGETS_SET_DIR == "" and not TAXABIND_TTA:
|
493 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
|
494 |
+
else:
|
495 |
+
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, figsize=(12, 8))
|
496 |
+
|
497 |
+
### Fig1: Environment ###
|
498 |
+
msk_name = ""
|
499 |
+
if TAXABIND_TTA:
|
500 |
+
image = mpimg.imread(sat_path_override)
|
501 |
+
msk_name = msk_name_override
|
502 |
+
# else:
|
503 |
+
# plt.imshow(self.robot_belief, cmap='gray')
|
504 |
+
# ax1.imshow(self.coverage_belief, cmap='gray')
|
505 |
+
# image = mpimg.imread("Maps/real_maps/real/4259_masked_img_0.jpg")
|
506 |
+
# msk_name = self.map_list[self.map_index]
|
507 |
+
# raw_img_path = self.score_to_img_dict[msk_name]
|
508 |
+
# if "flair" in raw_img_path:
|
509 |
+
# with rasterio.open(raw_img_path) as src_img:
|
510 |
+
# image = src_img.read([1,2,3])
|
511 |
+
# image = np.transpose(image, (1, 2, 0))
|
512 |
+
# else:
|
513 |
+
# image = mpimg.imread(raw_img_path)
|
514 |
+
|
515 |
+
|
516 |
+
### Fig1: Environment ###
|
517 |
+
ax = ax1 # if TAXABIND_TTA else ax1
|
518 |
+
ax.imshow(image)
|
519 |
+
ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
|
520 |
+
ax.set_title("Image")
|
521 |
+
# if VIZ_GRAPH_EDGES:
|
522 |
+
# for i in range(len(self.graph_generator.x)):
|
523 |
+
# ax.plot(self.graph_generator.x[i], self.graph_generator.y[i], 'tan', zorder=1)
|
524 |
+
# # ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.node_utility, zorder=5)
|
525 |
+
# ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.segmentation_info_mask, zorder=5)
|
526 |
+
# ax.scatter(self.frontiers[:, 0], self.frontiers[:, 1], c='r', s=2, zorder=3)
|
527 |
+
for i, route in enumerate(robots_route):
|
528 |
+
robot_marker_color = color_list[i % len(color_list)]
|
529 |
+
xPoints = route[0]
|
530 |
+
yPoints = route[1]
|
531 |
+
ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
|
532 |
+
# ax.plot(xPoints[-1], yPoints[-1], 'mo', markersize=8, zorder=10)
|
533 |
+
ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
|
534 |
+
ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
|
535 |
+
|
536 |
+
# Sensor range
|
537 |
+
rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
|
538 |
+
rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
|
539 |
+
max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
|
540 |
+
min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
|
541 |
+
max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
|
542 |
+
min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
|
543 |
+
ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
|
544 |
+
ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
|
545 |
+
ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
|
546 |
+
ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
|
547 |
+
|
548 |
+
### Fig2: Graph ###
|
549 |
+
ax = ax4 if TAXABIND_TTA else ax1
|
550 |
+
# ax.imshow(image)
|
551 |
+
ax.imshow(self.coverage_belief, cmap='gray')
|
552 |
+
ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
|
553 |
+
ax.set_title("Information Graph")
|
554 |
+
if VIZ_GRAPH_EDGES:
|
555 |
+
for i in range(len(self.graph_generator.x)):
|
556 |
+
ax.plot(self.graph_generator.x[i], self.graph_generator.y[i], 'tan', zorder=1)
|
557 |
+
# ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.node_utility, zorder=5)
|
558 |
+
# ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.segmentation_info_mask, zorder=5)
|
559 |
+
# filtered_seg_info_mask = [info[0] if self.guidepost[i] == 0.0 else 0.0 for i, info in enumerate(self.segmentation_info_mask)]
|
560 |
+
ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.filtered_seg_info_mask, zorder=5, s=8)
|
561 |
+
# ax.scatter(self.frontiers[:, 0], self.frontiers[:, 1], c='r', s=2, zorder=3)
|
562 |
+
|
563 |
+
for i, route in enumerate(robots_route):
|
564 |
+
robot_marker_color = color_list[i % len(color_list)]
|
565 |
+
xPoints = route[0]
|
566 |
+
yPoints = route[1]
|
567 |
+
ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
|
568 |
+
# ax.plot(xPoints[-1], yPoints[-1], 'mo', markersize=8, zorder=10)
|
569 |
+
ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
|
570 |
+
ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
|
571 |
+
|
572 |
+
# Sensor range
|
573 |
+
rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
|
574 |
+
rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
|
575 |
+
max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
|
576 |
+
min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
|
577 |
+
max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
|
578 |
+
min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
|
579 |
+
ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
|
580 |
+
ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
|
581 |
+
ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
|
582 |
+
ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
|
583 |
+
|
584 |
+
# Plot target positions
|
585 |
+
for target in self.target_positions:
|
586 |
+
if self.coverage_belief[target[1], target[0]] == 255:
|
587 |
+
# ax.plot(target[0], target[1], 'go', markersize=8, zorder=99)
|
588 |
+
ax.plot(target[0], target[1], color='g', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
|
589 |
+
else:
|
590 |
+
# ax.plot(target[0], target[1], 'ro', markersize=8, zorder=99)
|
591 |
+
ax.plot(target[0], target[1], color='r', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
|
592 |
+
|
593 |
+
# ax.pause(0.1)
|
594 |
+
|
595 |
+
### Fig3: Segmentation Mask ###
|
596 |
+
ax = ax5 if TAXABIND_TTA else ax2
|
597 |
+
if TAXABIND_TTA and USE_CLIP_PREDS:
|
598 |
+
side_dim = int(np.sqrt(self.segmentation_info_mask.shape[0]))
|
599 |
+
mask_viz = self.segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
|
600 |
+
scale_y = math.ceil(self.ground_truth_size[1] / side_dim)
|
601 |
+
scale_x = math.ceil(self.ground_truth_size[0] / side_dim)
|
602 |
+
upscaled_mask_viz = np.kron(mask_viz, np.ones((scale_y, scale_x))) # Integer scaling only
|
603 |
+
upscaled_mask_viz = upscaled_mask_viz[:self.ground_truth_size[1], :self.ground_truth_size[0]]
|
604 |
+
im = ax.imshow(upscaled_mask_viz, cmap="viridis")
|
605 |
+
ax.axis("off")
|
606 |
+
else:
|
607 |
+
im = ax.imshow(self.segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100) # cmap='gray'
|
608 |
+
ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
|
609 |
+
ax.set_title(f"Predicted Mask (Normalized)")
|
610 |
+
for i, route in enumerate(robots_route):
|
611 |
+
robot_marker_color = color_list[i % len(color_list)]
|
612 |
+
xPoints = route[0]
|
613 |
+
yPoints = route[1]
|
614 |
+
ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
|
615 |
+
# ax.plot(xPoints[-1], yPoints[-1], 'mo', markersize=8, zorder=10)
|
616 |
+
ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
|
617 |
+
ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
|
618 |
+
|
619 |
+
# Sensor range
|
620 |
+
rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
|
621 |
+
rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
|
622 |
+
max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
|
623 |
+
min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
|
624 |
+
max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
|
625 |
+
min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
|
626 |
+
ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
|
627 |
+
ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
|
628 |
+
ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
|
629 |
+
ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
|
630 |
+
|
631 |
+
# Add a colorbar
|
632 |
+
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
633 |
+
cbar.set_label("Normalized Probs")
|
634 |
+
|
635 |
+
# ax.pause(0.1)
|
636 |
+
|
637 |
+
### Fig4: Segmentation Mask ###
|
638 |
+
if TAXABIND_TTA and USE_CLIP_PREDS:
|
639 |
+
ax = ax6
|
640 |
+
side_dim = int(np.sqrt(self.segmentation_info_mask_unnormalized.shape[0]))
|
641 |
+
mask_viz = self.segmentation_info_mask_unnormalized.squeeze().reshape((side_dim, side_dim)).T
|
642 |
+
scale_y = math.ceil(self.ground_truth_size[1] / side_dim)
|
643 |
+
scale_x = math.ceil(self.ground_truth_size[0] / side_dim)
|
644 |
+
upscaled_mask_viz = np.kron(mask_viz, np.ones((scale_y, scale_x))) # Integer scaling only
|
645 |
+
upscaled_mask_viz = upscaled_mask_viz[:self.ground_truth_size[1], :self.ground_truth_size[0]]
|
646 |
+
|
647 |
+
max_val = 0.15 # TO CHANGE
|
648 |
+
mid_val = colormap_mid_val if colormap_mid_val is not None else 0.05
|
649 |
+
# mid_val = np.max(self.segmentation_info_mask_unnormalized)
|
650 |
+
norm = CustomNorm(vmin=0.0, vmax=max_val, mid=mid_val, lower_portion=0.8)
|
651 |
+
im = ax.imshow(upscaled_mask_viz, cmap="viridis", norm=norm) # norm=LogNorm(vmin=0.01, vmax=0.1))
|
652 |
+
# norm = PowerNorm(gamma=0.25, vmin=0.01, vmax=0.2)
|
653 |
+
# norm=LogNorm(vmin=0.01, vmax=0.2)
|
654 |
+
im = ax.imshow(upscaled_mask_viz, cmap="viridis", norm=norm) # norm=LogNorm(vmin=0.01, vmax=0.1))
|
655 |
+
ax.axis("off")
|
656 |
+
# else:
|
657 |
+
# im = ax.imshow(self.segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100) # cmap='gray'
|
658 |
+
# ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
|
659 |
+
ax.set_title(f"Predicted Mask (Unnormalized)")
|
660 |
+
for i, route in enumerate(robots_route):
|
661 |
+
robot_marker_color = color_list[i % len(color_list)]
|
662 |
+
xPoints = route[0]
|
663 |
+
yPoints = route[1]
|
664 |
+
ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
|
665 |
+
# ax.plot(xPoints[-1], yPoints[-1], 'mo', markersize=8, zorder=10)
|
666 |
+
ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
|
667 |
+
ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
|
668 |
+
|
669 |
+
# Sensor range
|
670 |
+
rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
|
671 |
+
rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
|
672 |
+
max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
|
673 |
+
min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
|
674 |
+
max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
|
675 |
+
min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
|
676 |
+
ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
|
677 |
+
ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
|
678 |
+
ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
|
679 |
+
ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
|
680 |
+
|
681 |
+
# Add a colorbar
|
682 |
+
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
683 |
+
if TAXABIND_TTA and USE_CLIP_PREDS:
|
684 |
+
cbar.set_ticks([0.0, mid_val, max_val])
|
685 |
+
cbar.set_label("Probs (Scaled by expectation)")
|
686 |
+
|
687 |
+
|
688 |
+
# Fog5: GT Mask
|
689 |
+
if TARGETS_SET_DIR != "":
|
690 |
+
ax = ax2
|
691 |
+
im = ax.imshow(self.gt_segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100) # cmap='gray'
|
692 |
+
ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
|
693 |
+
ax.set_title(f"Ground Truth Mask")
|
694 |
+
for i, route in enumerate(robots_route):
|
695 |
+
robot_marker_color = color_list[i % len(color_list)]
|
696 |
+
xPoints = route[0]
|
697 |
+
yPoints = route[1]
|
698 |
+
ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
|
699 |
+
# ax.plot(xPoints[-1], yPoints[-1], 'mo', markersize=8, zorder=10)
|
700 |
+
ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
|
701 |
+
ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
|
702 |
+
|
703 |
+
# Sensor range
|
704 |
+
rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
|
705 |
+
rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
|
706 |
+
max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
|
707 |
+
min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
|
708 |
+
max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
|
709 |
+
min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
|
710 |
+
ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
|
711 |
+
ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
|
712 |
+
ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
|
713 |
+
ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
|
714 |
+
|
715 |
+
# Add a colorbar
|
716 |
+
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
717 |
+
cbar.set_label("Normalized Mask Value")
|
718 |
+
|
719 |
+
# ax4.pause(0.1)
|
720 |
+
|
721 |
+
|
722 |
+
### Fig6: Segmentation Mask (GT) ###
|
723 |
+
if TAXABIND_TTA:
|
724 |
+
ax = ax3
|
725 |
+
image = mpimg.imread(img_path_override)
|
726 |
+
ax.imshow(image)
|
727 |
+
ax.set_title("Ground Image")
|
728 |
+
ax.axis("off")
|
729 |
+
|
730 |
+
|
731 |
+
sound_id = sound_id_override if sound_id_override is not None else "-1"
|
732 |
+
plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g} Info Gain: {:.4g}% \n ({}) \n (Sound ID: {})'.format(self.num_targets_found, \
|
733 |
+
len(self.target_positions), self.explored_rate, travel_dist, (100*self.info_gain/self.total_info), msk_name,
|
734 |
+
sound_id))
|
735 |
+
plt.tight_layout()
|
736 |
+
plt.savefig('{}/{}_{}_samples.png'.format(path, n, step, dpi=100))
|
737 |
+
# plt.show()
|
738 |
+
frame = '{}/{}_{}_samples.png'.format(path, n, step)
|
739 |
+
self.frame_files.append(frame)
|
740 |
+
plt.close()
|
741 |
+
|
742 |
+
|
743 |
+
####################
|
744 |
+
|
745 |
+
class CustomNorm(Normalize):
|
746 |
+
"""
|
747 |
+
A custom normalization that allocates a larger fraction of the colormap
|
748 |
+
to the lower data range [vmin, mid] than to [mid, vmax].
|
749 |
+
|
750 |
+
Parameters
|
751 |
+
----------
|
752 |
+
vmin : float
|
753 |
+
Minimum data value
|
754 |
+
vmax : float
|
755 |
+
Maximum data value
|
756 |
+
mid : float
|
757 |
+
Midpoint in data where we switch from 'lower' to 'upper' mapping
|
758 |
+
lower_portion : float
|
759 |
+
Fraction of the colormap to allocate for [vmin, mid].
|
760 |
+
For example, 0.8 => 80% of colors for [vmin, mid], 20% for [mid, vmax].
|
761 |
+
clip : bool
|
762 |
+
Whether to clip data outside [vmin, vmax].
|
763 |
+
"""
|
764 |
+
|
765 |
+
def __init__(self, vmin=None, vmax=None, mid=0.05, lower_portion=0.8, clip=False):
|
766 |
+
self.mid = mid
|
767 |
+
self.lower_portion = lower_portion
|
768 |
+
super().__init__(vmin, vmax, clip)
|
769 |
+
|
770 |
+
def __call__(self, value, clip=None):
|
771 |
+
"""Forward transform: data -> [0..1] color space."""
|
772 |
+
vmin, vmax, mid = self.vmin, self.vmax, self.mid
|
773 |
+
lp = self.lower_portion
|
774 |
+
|
775 |
+
value = np.asarray(value, dtype=np.float64)
|
776 |
+
|
777 |
+
# Piecewise linear mapping:
|
778 |
+
# [vmin..mid] => [0..lp]
|
779 |
+
# [mid..vmax] => [lp..1]
|
780 |
+
normed = np.where(
|
781 |
+
value <= mid,
|
782 |
+
lp * (value - vmin) / (mid - vmin),
|
783 |
+
lp + (value - mid) / (vmax - mid) * (1 - lp)
|
784 |
+
)
|
785 |
+
return np.clip(normed, 0, 1)
|
786 |
+
|
787 |
+
def inverse(self, value):
|
788 |
+
"""
|
789 |
+
Inverse transform: [0..1] color space -> data space.
|
790 |
+
Matplotlib's colorbar calls this to place ticks correctly.
|
791 |
+
"""
|
792 |
+
vmin, vmax, mid = self.vmin, self.vmax, self.mid
|
793 |
+
lp = self.lower_portion
|
794 |
+
|
795 |
+
value = np.asarray(value, dtype=np.float64)
|
796 |
+
|
797 |
+
# For color space [0..lp], invert to [vmin..mid]
|
798 |
+
# For color space [lp..1], invert to [mid..vmax]
|
799 |
+
below = (value <= lp)
|
800 |
+
above = ~below
|
801 |
+
|
802 |
+
# Allocate array for results
|
803 |
+
data = np.zeros_like(value, dtype=np.float64)
|
804 |
+
|
805 |
+
# Invert lower segment
|
806 |
+
data[below] = vmin + (value[below] / lp) * (mid - vmin)
|
807 |
+
|
808 |
+
# Invert upper segment
|
809 |
+
data[above] = mid + ((value[above] - lp) / (1 - lp)) * (vmax - mid)
|
810 |
+
|
811 |
+
return data
|
graph.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#######################################################################
|
2 |
+
# Name: env.py
|
3 |
+
#
|
4 |
+
# - Adapted from https://gist.github.com/betandr/541a1f6466b6855471de5ca30b74cb31
|
5 |
+
# - Simple graph class to perform distance calculations (E.g. A-Star, Djikstra)
|
6 |
+
#######################################################################
|
7 |
+
|
8 |
+
from decimal import Decimal
|
9 |
+
|
10 |
+
|
11 |
+
class Edge:
|
12 |
+
def __init__(self, to_node, length):
|
13 |
+
self.to_node = to_node
|
14 |
+
self.length = length
|
15 |
+
|
16 |
+
|
17 |
+
class Graph:
|
18 |
+
def __init__(self):
|
19 |
+
self.nodes = set()
|
20 |
+
self.edges = dict()
|
21 |
+
|
22 |
+
def add_node(self, node):
|
23 |
+
self.nodes.add(node)
|
24 |
+
|
25 |
+
def add_edge(self, from_node, to_node, length):
|
26 |
+
edge = Edge(to_node, length)
|
27 |
+
# edge = to_node
|
28 |
+
if from_node in self.edges:
|
29 |
+
from_node_edges = self.edges[from_node]
|
30 |
+
else:
|
31 |
+
self.edges[from_node] = dict()
|
32 |
+
from_node_edges = self.edges[from_node]
|
33 |
+
# if edge not in from_node_edges:
|
34 |
+
# from_node_edges.append(edge)
|
35 |
+
from_node_edges[to_node] = edge
|
36 |
+
|
37 |
+
def clear_edge(self, from_node):
|
38 |
+
if from_node in self.edges:
|
39 |
+
self.edges[from_node] = dict()
|
40 |
+
# else:
|
41 |
+
# print("no edge node or no this node")
|
42 |
+
|
43 |
+
def min_dist(q, dist):
|
44 |
+
"""
|
45 |
+
Returns the node with the smallest distance in q.
|
46 |
+
Implemented to keep the main algorithm clean.
|
47 |
+
"""
|
48 |
+
min_node = None
|
49 |
+
for node in q:
|
50 |
+
if min_node == None:
|
51 |
+
min_node = node
|
52 |
+
elif dist[node] < dist[min_node]:
|
53 |
+
min_node = node
|
54 |
+
|
55 |
+
return min_node
|
56 |
+
|
57 |
+
|
58 |
+
INFINITY = float('Infinity')
|
59 |
+
|
60 |
+
|
61 |
+
def dijkstra(graph, source):
|
62 |
+
q = set()
|
63 |
+
dist = {}
|
64 |
+
prev = {}
|
65 |
+
|
66 |
+
for v in graph.nodes: # initialization
|
67 |
+
dist[v] = INFINITY # unknown distance from source to v
|
68 |
+
prev[v] = INFINITY # previous node in optimal path from source
|
69 |
+
q.add(v) # all nodes initially in q (unvisited nodes)
|
70 |
+
|
71 |
+
# distance from source to source
|
72 |
+
dist[source] = 0
|
73 |
+
|
74 |
+
while q:
|
75 |
+
# node with the least distance selected first
|
76 |
+
u = min_dist(q, dist)
|
77 |
+
|
78 |
+
q.remove(u)
|
79 |
+
|
80 |
+
try:
|
81 |
+
if u in graph.edges:
|
82 |
+
for _, v in graph.edges[u].items():
|
83 |
+
alt = dist[u] + v.length
|
84 |
+
if alt < dist[v.to_node]:
|
85 |
+
# a shorter path to v has been found
|
86 |
+
dist[v.to_node] = alt
|
87 |
+
prev[v.to_node] = u
|
88 |
+
except:
|
89 |
+
pass
|
90 |
+
|
91 |
+
return dist, prev
|
92 |
+
|
93 |
+
|
94 |
+
def to_array(prev, from_node):
|
95 |
+
"""Creates an ordered list of labels as a route."""
|
96 |
+
previous_node = prev[from_node]
|
97 |
+
route = [from_node]
|
98 |
+
|
99 |
+
while previous_node != INFINITY:
|
100 |
+
route.append(previous_node)
|
101 |
+
temp = previous_node
|
102 |
+
previous_node = prev[temp]
|
103 |
+
|
104 |
+
route.reverse()
|
105 |
+
return route
|
106 |
+
|
107 |
+
|
108 |
+
def h(index, destination, node_coords):
|
109 |
+
current = node_coords[index]
|
110 |
+
end = node_coords[destination]
|
111 |
+
h = abs(end[0] - current[0]) + abs(end[1] - current[1])
|
112 |
+
# h = ((end[0]-current[0])**2 + (end[1] - current[1])**2)**(1/2)
|
113 |
+
return h
|
114 |
+
|
115 |
+
|
116 |
+
def a_star(start, destination, node_coords, graph):
|
117 |
+
if start == destination:
|
118 |
+
return [], 0
|
119 |
+
if str(destination) in graph.edges[str(start)].keys():
|
120 |
+
cost = graph.edges[str(start)][str(destination)].length
|
121 |
+
return [start, destination], cost
|
122 |
+
open_list = {start}
|
123 |
+
closed_list = set([])
|
124 |
+
|
125 |
+
g = {start: 0}
|
126 |
+
parents = {start: start}
|
127 |
+
|
128 |
+
while len(open_list) > 0:
|
129 |
+
n = None
|
130 |
+
h_n = 1e5
|
131 |
+
# print('open list', open_list)
|
132 |
+
for v in open_list:
|
133 |
+
h_v = h(v, destination, node_coords)
|
134 |
+
if n is not None:
|
135 |
+
h_n = h(n, destination, node_coords)
|
136 |
+
if n is None or g[v] + h_v < g[n] + h_n:
|
137 |
+
n = v
|
138 |
+
|
139 |
+
if n is None:
|
140 |
+
print('Path does not exist!')
|
141 |
+
return None, 1e5
|
142 |
+
|
143 |
+
if n == destination:
|
144 |
+
reconst_path = []
|
145 |
+
while parents[n] != n:
|
146 |
+
reconst_path.append(n)
|
147 |
+
n = parents[n]
|
148 |
+
reconst_path.append(start)
|
149 |
+
reconst_path.reverse()
|
150 |
+
# print('Path found: {}'.format(reconst_path))
|
151 |
+
# print(g[destination])
|
152 |
+
return reconst_path, g[destination]
|
153 |
+
|
154 |
+
for edge in graph.edges[str(n)].values():
|
155 |
+
m = int(edge.to_node)
|
156 |
+
cost = edge.length
|
157 |
+
# print(m, cost)
|
158 |
+
if m not in open_list and m not in closed_list:
|
159 |
+
open_list.add(m)
|
160 |
+
parents[m] = n
|
161 |
+
g[m] = g[n] + cost
|
162 |
+
|
163 |
+
else:
|
164 |
+
if g[m] > g[n] + cost:
|
165 |
+
g[m] = g[n] + cost
|
166 |
+
parents[m] = n
|
167 |
+
|
168 |
+
if m in closed_list:
|
169 |
+
closed_list.remove(m)
|
170 |
+
open_list.add(m)
|
171 |
+
|
172 |
+
open_list.remove(n)
|
173 |
+
closed_list.add(n)
|
174 |
+
|
175 |
+
print('Path does not exist!')
|
176 |
+
return None, 1e5
|
177 |
+
|
178 |
+
|
179 |
+
|
graph_generator.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#######################################################################
|
2 |
+
# Name: graph_generator.py
|
3 |
+
#
|
4 |
+
# - Wrapper for graph.py
|
5 |
+
# - Sends the formatted inputs into graph.py to get useful info
|
6 |
+
#######################################################################
|
7 |
+
|
8 |
+
import sys
|
9 |
+
if sys.modules['TRAINING']:
|
10 |
+
from parameter import *
|
11 |
+
else:
|
12 |
+
from test_parameter import *
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
from sklearn.neighbors import NearestNeighbors
|
16 |
+
import shapely.geometry
|
17 |
+
|
18 |
+
from node import Node
|
19 |
+
from graph import Graph, a_star
|
20 |
+
|
21 |
+
|
22 |
+
class Graph_generator:
|
23 |
+
def __init__(self, map_size, k_size, sensor_range, plot=False):
|
24 |
+
self.k_size = k_size
|
25 |
+
self.graph = Graph()
|
26 |
+
self.node_coords = None
|
27 |
+
self.plot = plot
|
28 |
+
self.x = []
|
29 |
+
self.y = []
|
30 |
+
self.map_x = map_size[1]
|
31 |
+
self.map_y = map_size[0]
|
32 |
+
self.uniform_points, self.grid_coords = self.generate_uniform_points()
|
33 |
+
|
34 |
+
self.sensor_range = sensor_range
|
35 |
+
|
36 |
+
self.route_node = []
|
37 |
+
self.nodes_list = []
|
38 |
+
self.node_utility = None
|
39 |
+
self.guidepost = None
|
40 |
+
|
41 |
+
def edge_clear_all_nodes(self):
|
42 |
+
self.graph = Graph()
|
43 |
+
self.x = []
|
44 |
+
self.y = []
|
45 |
+
|
46 |
+
def edge_clear(self, coords):
|
47 |
+
node_index = str(self.find_index_from_coords(self.node_coords, coords))
|
48 |
+
self.graph.clear_edge(node_index)
|
49 |
+
|
50 |
+
def generate_graph(self, robot_belief, frontiers):
|
51 |
+
self.edge_clear_all_nodes()
|
52 |
+
free_area = self.free_area(robot_belief)
|
53 |
+
|
54 |
+
free_area_to_check = free_area[:, 0] + free_area[:, 1] * 1j
|
55 |
+
uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j
|
56 |
+
_, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True)
|
57 |
+
node_coords = self.uniform_points[candidate_indices]
|
58 |
+
|
59 |
+
# node_coords = np.concatenate((robot_location.reshape(1, 2), node_coords))
|
60 |
+
self.node_coords = self.unique_coords(node_coords).reshape(-1, 2)
|
61 |
+
# self.find_k_neighbor_all_nodes(self.node_coords, robot_belief)
|
62 |
+
self.find_nearest_neighbor_all_nodes(self.node_coords, robot_belief)
|
63 |
+
|
64 |
+
self.node_utility = []
|
65 |
+
for coords in self.node_coords:
|
66 |
+
node = Node(coords, frontiers, robot_belief)
|
67 |
+
self.nodes_list.append(node)
|
68 |
+
utility = node.utility
|
69 |
+
self.node_utility.append(utility)
|
70 |
+
self.node_utility = np.array(self.node_utility)
|
71 |
+
|
72 |
+
self.guidepost = np.zeros((self.node_coords.shape[0], 1))
|
73 |
+
x = self.node_coords[:,0] + self.node_coords[:,1]*1j
|
74 |
+
for node in self.route_node:
|
75 |
+
index = np.argwhere(x.reshape(-1) == node[0]+node[1]*1j)[0]
|
76 |
+
self.guidepost[index] = 1
|
77 |
+
|
78 |
+
return self.node_coords, self.graph.edges, self.node_utility, self.guidepost
|
79 |
+
|
80 |
+
def update_graph(self, robot_belief, old_robot_belief, frontiers, old_frontiers):
|
81 |
+
new_free_area = self.free_area((robot_belief - old_robot_belief > 0) * 255)
|
82 |
+
free_area_to_check = new_free_area[:, 0] + new_free_area[:, 1] * 1j
|
83 |
+
uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j
|
84 |
+
_, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True)
|
85 |
+
new_node_coords = self.uniform_points[candidate_indices]
|
86 |
+
self.node_coords = np.concatenate((self.node_coords, new_node_coords))
|
87 |
+
|
88 |
+
old_node_to_update = []
|
89 |
+
for coords in new_node_coords:
|
90 |
+
neighbor_indices = self.find_k_neighbor(coords, self.node_coords, robot_belief)
|
91 |
+
old_node_to_update += neighbor_indices
|
92 |
+
old_node_to_update = set(old_node_to_update)
|
93 |
+
for index in old_node_to_update:
|
94 |
+
coords = self.node_coords[index]
|
95 |
+
self.edge_clear(coords)
|
96 |
+
self.find_k_neighbor(coords, self.node_coords, robot_belief)
|
97 |
+
|
98 |
+
#self.edge_clear_all_nodes()
|
99 |
+
#self.find_k_neighbor_all_nodes(self.node_coords, robot_belief)
|
100 |
+
|
101 |
+
old_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j
|
102 |
+
new_frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
|
103 |
+
observed_frontiers_index = np.where(
|
104 |
+
np.isin(old_frontiers_to_check, new_frontiers_to_check, assume_unique=True) == False)
|
105 |
+
new_frontiers_index = np.where(
|
106 |
+
np.isin(new_frontiers_to_check, old_frontiers_to_check, assume_unique=True) == False)
|
107 |
+
observed_frontiers = old_frontiers[observed_frontiers_index]
|
108 |
+
new_frontiers = frontiers[new_frontiers_index]
|
109 |
+
for node in self.nodes_list:
|
110 |
+
if node.zero_utility_node is True:
|
111 |
+
pass
|
112 |
+
else:
|
113 |
+
node.update_observable_frontiers(observed_frontiers, new_frontiers, robot_belief)
|
114 |
+
|
115 |
+
for new_coords in new_node_coords:
|
116 |
+
node = Node(new_coords, frontiers, robot_belief)
|
117 |
+
self.nodes_list.append(node)
|
118 |
+
|
119 |
+
self.node_utility = []
|
120 |
+
for i, coords in enumerate(self.node_coords):
|
121 |
+
utility = self.nodes_list[i].utility
|
122 |
+
self.node_utility.append(utility)
|
123 |
+
self.node_utility = np.array(self.node_utility)
|
124 |
+
|
125 |
+
self.guidepost = np.zeros((self.node_coords.shape[0], 1))
|
126 |
+
x = self.node_coords[:, 0] + self.node_coords[:, 1] * 1j
|
127 |
+
for node in self.route_node:
|
128 |
+
index = np.argwhere(x.reshape(-1) == node[0] + node[1] * 1j)
|
129 |
+
self.guidepost[index] = 1
|
130 |
+
|
131 |
+
return self.node_coords, self.graph.edges, self.node_utility, self.guidepost
|
132 |
+
|
133 |
+
def generate_uniform_points(self):
|
134 |
+
# x = np.linspace(0, self.map_x - 1, NUM_COORDS_WIDTH).round().astype(int)
|
135 |
+
# y = np.linspace(0, self.map_y - 1, NUM_COORDS_HEIGHT).round().astype(int)
|
136 |
+
padding_x = 0.5 * (self.map_x / NUM_COORDS_WIDTH)
|
137 |
+
padding_y = 0.5 * (self.map_y / NUM_COORDS_HEIGHT)
|
138 |
+
x = np.linspace(padding_x, self.map_x - padding_x - 1, NUM_COORDS_WIDTH).round().astype(int)
|
139 |
+
y = np.linspace(padding_y, self.map_y - padding_y - 1, NUM_COORDS_HEIGHT).round().astype(int)
|
140 |
+
|
141 |
+
t1, t2 = np.meshgrid(x, y)
|
142 |
+
points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T
|
143 |
+
matrix = np.stack((t1, t2), axis=-1)
|
144 |
+
return points, matrix
|
145 |
+
|
146 |
+
def free_area(self, robot_belief):
|
147 |
+
index = np.where(robot_belief == 255)
|
148 |
+
free = np.asarray([index[1], index[0]]).T
|
149 |
+
return free
|
150 |
+
|
151 |
+
def unique_coords(self, coords):
|
152 |
+
x = coords[:, 0] + coords[:, 1] * 1j
|
153 |
+
indices = np.unique(x, return_index=True)[1]
|
154 |
+
coords = np.array([coords[idx] for idx in sorted(indices)])
|
155 |
+
return coords
|
156 |
+
|
157 |
+
def find_k_neighbor(self, coords, node_coords, robot_belief):
|
158 |
+
dist_list = np.linalg.norm((coords-node_coords), axis=-1)
|
159 |
+
sorted_index = np.argsort(dist_list)
|
160 |
+
k = 0
|
161 |
+
neighbor_index_list = []
|
162 |
+
while k < self.k_size and k< node_coords.shape[0]:
|
163 |
+
neighbor_index = sorted_index[k]
|
164 |
+
neighbor_index_list.append(neighbor_index)
|
165 |
+
dist = dist_list[k]
|
166 |
+
start = coords
|
167 |
+
end = node_coords[neighbor_index]
|
168 |
+
if not self.check_collision(start, end, robot_belief):
|
169 |
+
a = str(self.find_index_from_coords(node_coords, start))
|
170 |
+
b = str(neighbor_index)
|
171 |
+
self.graph.add_node(a)
|
172 |
+
self.graph.add_edge(a, b, dist)
|
173 |
+
|
174 |
+
if self.plot:
|
175 |
+
self.x.append([start[0], end[0]])
|
176 |
+
self.y.append([start[1], end[1]])
|
177 |
+
k += 1
|
178 |
+
return neighbor_index_list
|
179 |
+
|
180 |
+
def find_k_neighbor_all_nodes(self, node_coords, robot_belief):
|
181 |
+
X = node_coords
|
182 |
+
if len(node_coords) >= self.k_size:
|
183 |
+
knn = NearestNeighbors(n_neighbors=self.k_size)
|
184 |
+
else:
|
185 |
+
knn = NearestNeighbors(n_neighbors=len(node_coords))
|
186 |
+
knn.fit(X)
|
187 |
+
distances, indices = knn.kneighbors(X)
|
188 |
+
|
189 |
+
for i, p in enumerate(X):
|
190 |
+
for j, neighbour in enumerate(X[indices[i][:]]):
|
191 |
+
start = p
|
192 |
+
end = neighbour
|
193 |
+
if not self.check_collision(start, end, robot_belief):
|
194 |
+
a = str(self.find_index_from_coords(node_coords, p))
|
195 |
+
b = str(self.find_index_from_coords(node_coords, neighbour))
|
196 |
+
self.graph.add_node(a)
|
197 |
+
self.graph.add_edge(a, b, distances[i, j])
|
198 |
+
|
199 |
+
if self.plot:
|
200 |
+
self.x.append([p[0], neighbour[0]])
|
201 |
+
self.y.append([p[1], neighbour[1]])
|
202 |
+
|
203 |
+
def find_nearest_neighbor_all_nodes(self, node_coords, robot_belief):
|
204 |
+
for i, p in enumerate(node_coords):
|
205 |
+
filtered_coords = self.get_neighbors_grid_coords(p)
|
206 |
+
|
207 |
+
for j, neighbour in enumerate(filtered_coords):
|
208 |
+
start = p
|
209 |
+
end = neighbour
|
210 |
+
if not self.check_collision(start, end, robot_belief):
|
211 |
+
a = str(self.find_index_from_coords(node_coords, p))
|
212 |
+
b = str(self.find_index_from_coords(node_coords, neighbour))
|
213 |
+
self.graph.add_node(a)
|
214 |
+
self.graph.add_edge(a, b, np.linalg.norm(start-end))
|
215 |
+
|
216 |
+
if self.plot:
|
217 |
+
self.x.append([p[0], neighbour[0]])
|
218 |
+
self.y.append([p[1], neighbour[1]])
|
219 |
+
|
220 |
+
|
221 |
+
|
222 |
+
def find_index_from_coords(self, node_coords, p):
|
223 |
+
return np.where(np.linalg.norm(node_coords - p, axis=1) < 1e-5)[0][0]
|
224 |
+
|
225 |
+
def find_closest_index_from_coords(self, node_coords, p):
|
226 |
+
return np.argmin(np.linalg.norm(node_coords - p, axis=1))
|
227 |
+
|
228 |
+
def find_index_from_grid_coords_2d(self, p):
|
229 |
+
|
230 |
+
# Calculate the distance between the target coord and each point in grid_coords
|
231 |
+
diffs = np.linalg.norm(self.grid_coords - p, axis=2) # Compute along the last axis (x, y)
|
232 |
+
indices = np.where(diffs < 1e-5)
|
233 |
+
|
234 |
+
# Return the 2D index as a tuple (row, col)
|
235 |
+
if indices[0].size > 0:
|
236 |
+
return indices[0][0], indices[1][0]
|
237 |
+
else:
|
238 |
+
raise ValueError(f"Coordinate {p} not found in self.grid_coords.")
|
239 |
+
|
240 |
+
def find_closest_index_from_grid_coords_2d(self, p):
|
241 |
+
|
242 |
+
# Calculate the distance between the target coord and each point in grid_coords
|
243 |
+
distances = np.linalg.norm(self.grid_coords - p, axis=2) # Compute along the last axis (x, y)
|
244 |
+
flat_index = np.argmin(distances)
|
245 |
+
return np.unravel_index(flat_index, distances.shape)
|
246 |
+
|
247 |
+
|
248 |
+
def check_collision(self, start, end, robot_belief):
|
249 |
+
collision = False
|
250 |
+
line = shapely.geometry.LineString([start, end])
|
251 |
+
|
252 |
+
sortx = np.sort([start[0], end[0]])
|
253 |
+
sorty = np.sort([start[1], end[1]])
|
254 |
+
|
255 |
+
# print(robot_belief.shape)
|
256 |
+
robot_belief = robot_belief[sorty[0]:sorty[1]+1, sortx[0]:sortx[1]+1]
|
257 |
+
|
258 |
+
occupied_area_index = np.where(robot_belief == 1)
|
259 |
+
occupied_area_coords = np.asarray([occupied_area_index[1]+sortx[0], occupied_area_index[0]+sorty[0]]).T
|
260 |
+
unexplored_area_index = np.where(robot_belief == 127)
|
261 |
+
unexplored_area_coords = np.asarray([unexplored_area_index[1]+sortx[0], unexplored_area_index[0]+sorty[0]]).T
|
262 |
+
unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords))
|
263 |
+
|
264 |
+
# obstacles = []
|
265 |
+
for i in range(unfree_area_coords.shape[0]):
|
266 |
+
coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]),
|
267 |
+
(unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]),
|
268 |
+
(unfree_area_coords[i][0], unfree_area_coords[i][1] + 1),
|
269 |
+
(unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1)])
|
270 |
+
obstacle = shapely.geometry.Polygon(coords)
|
271 |
+
#obstacles.append(obstacle)
|
272 |
+
# if obstacles != []:
|
273 |
+
#all_obstacles = shapely.geometry.MultiPolygon(obstacles)
|
274 |
+
# print(obstacle.is_valid)
|
275 |
+
collision = line.intersects(obstacle)
|
276 |
+
if collision:
|
277 |
+
break
|
278 |
+
|
279 |
+
return collision
|
280 |
+
|
281 |
+
def find_shortest_path(self, current, destination, node_coords):
|
282 |
+
start_node = str(self.find_index_from_coords(node_coords, current))
|
283 |
+
end_node = str(self.find_index_from_coords(node_coords, destination))
|
284 |
+
route, dist = a_star(int(start_node), int(end_node), self.node_coords, self.graph)
|
285 |
+
if start_node != end_node:
|
286 |
+
assert route != []
|
287 |
+
# t3 = time.time()
|
288 |
+
# print(t2-t1, t3-t2)
|
289 |
+
route = list(map(str, route))
|
290 |
+
return dist, route
|
291 |
+
|
292 |
+
|
293 |
+
# def get_neighbors_grid_coords(self, coord):
|
294 |
+
# # Return the 4 closest neighbors (N, S, E, W) of a given coordinate
|
295 |
+
|
296 |
+
# nearest_coord = self.node_coords[self.find_closest_index_from_coords(self.node_coords, coord)]
|
297 |
+
# rows, cols = self.grid_coords.shape[:2]
|
298 |
+
# neighbors = []
|
299 |
+
# i, j = self.find_index_from_grid_coords_2d(nearest_coord)
|
300 |
+
|
301 |
+
# # Define NSEW neighbor offsets
|
302 |
+
# directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # N, S, W, E
|
303 |
+
|
304 |
+
# for di, dj in directions:
|
305 |
+
# ni, nj = i + di, j + dj
|
306 |
+
# if 0 <= ni < rows and 0 <= nj < cols:
|
307 |
+
# neighbors.append(tuple(self.grid_coords[ni, nj]))
|
308 |
+
|
309 |
+
# return neighbors
|
310 |
+
|
311 |
+
|
312 |
+
def get_neighbors_grid_coords(self, coord):
|
313 |
+
# Return the 8 closest neighbors of a given coordinate
|
314 |
+
|
315 |
+
nearest_coord = self.node_coords[self.find_closest_index_from_coords(self.node_coords, coord)]
|
316 |
+
rows, cols = self.grid_coords.shape[:2]
|
317 |
+
neighbors = []
|
318 |
+
i, j = self.find_index_from_grid_coords_2d(nearest_coord)
|
319 |
+
|
320 |
+
# Create a range of indices for rows and columns
|
321 |
+
row_range = np.clip([i - 1, i, i + 1], 0, rows - 1)
|
322 |
+
col_range = np.clip([j - 1, j, j + 1], 0, cols - 1)
|
323 |
+
|
324 |
+
# Iterate over the valid indices
|
325 |
+
for ni in row_range:
|
326 |
+
for nj in col_range:
|
327 |
+
if (ni, nj) != (i, j): # Skip the center point
|
328 |
+
neighbors.append(tuple(self.grid_coords[ni, nj]))
|
329 |
+
|
330 |
+
return neighbors
|
inference/model/STAGE1_vlm_search_24x24_040425_no_tgt_rewards_iNAT_DS_16k.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:44e642df9aaa2847ba44dd4707985c67ef712f5264272ef7993aeb7805c80f5a
|
3 |
+
size 52167246
|
model.py
ADDED
@@ -0,0 +1,319 @@
#######################################################################
# Name: model.py
#
# - Attention-based encoders & decoders
# - Policy Net: Input = Augmented Graph, Output = Node to go to
# - Critic Net: Input = Augmented Graph + Action, Output = Q_Value
#######################################################################

import torch
import torch.nn as nn
import math


class SingleHeadAttention(nn.Module):
    def __init__(self, embedding_dim):
        super(SingleHeadAttention, self).__init__()
        self.input_dim = embedding_dim
        self.embedding_dim = embedding_dim
        self.value_dim = embedding_dim
        self.key_dim = self.value_dim
        self.tanh_clipping = 10
        self.norm_factor = 1 / math.sqrt(self.key_dim)

        self.w_query = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))
        self.w_key = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))

        self.init_parameters()

    def init_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, k, mask=None):
        n_batch, n_key, n_dim = k.size()
        n_query = q.size(1)

        k_flat = k.reshape(-1, n_dim)
        q_flat = q.reshape(-1, n_dim)

        shape_k = (n_batch, n_key, -1)
        shape_q = (n_batch, n_query, -1)

        Q = torch.matmul(q_flat, self.w_query).view(shape_q)
        K = torch.matmul(k_flat, self.w_key).view(shape_k)

        U = self.norm_factor * torch.matmul(Q, K.transpose(1, 2))
        U = self.tanh_clipping * torch.tanh(U)

        if mask is not None:
            U = U.masked_fill(mask == 1, -1e8)
        attention = torch.log_softmax(U, dim=-1)  # n_batch*n_query*n_key

        return attention


class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, n_heads=8):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.input_dim = embedding_dim
        self.embedding_dim = embedding_dim
        self.value_dim = self.embedding_dim // self.n_heads
        self.key_dim = self.value_dim
        self.norm_factor = 1 / math.sqrt(self.key_dim)

        self.w_query = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
        self.w_key = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
        self.w_value = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.value_dim))
        self.w_out = nn.Parameter(torch.Tensor(self.n_heads, self.value_dim, self.embedding_dim))

        self.init_parameters()

    def init_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, k=None, v=None, key_padding_mask=None, attn_mask=None):
        if k is None:
            k = q
        if v is None:
            v = q

        n_batch, n_key, n_dim = k.size()
        n_query = q.size(1)
        n_value = v.size(1)

        k_flat = k.contiguous().view(-1, n_dim)
        v_flat = v.contiguous().view(-1, n_dim)
        q_flat = q.contiguous().view(-1, n_dim)
        shape_v = (self.n_heads, n_batch, n_value, -1)
        shape_k = (self.n_heads, n_batch, n_key, -1)
        shape_q = (self.n_heads, n_batch, n_query, -1)

        Q = torch.matmul(q_flat, self.w_query).view(shape_q)  # n_heads*batch_size*n_query*key_dim
        K = torch.matmul(k_flat, self.w_key).view(shape_k)  # n_heads*batch_size*targets_size*key_dim
        V = torch.matmul(v_flat, self.w_value).view(shape_v)  # n_heads*batch_size*targets_size*value_dim

        U = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))  # n_heads*batch_size*n_query*targets_size

        if attn_mask is not None:
            attn_mask = attn_mask.view(1, n_batch, n_query, n_key).expand_as(U)

        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.repeat(1, n_query, 1)
            key_padding_mask = key_padding_mask.view(1, n_batch, n_query, n_key).expand_as(U)  # copy for n_heads times

        if attn_mask is not None and key_padding_mask is not None:
            mask = (attn_mask + key_padding_mask)
        elif attn_mask is not None:
            mask = attn_mask
        elif key_padding_mask is not None:
            mask = key_padding_mask
        else:
            mask = None

        if mask is not None:
            U = U.masked_fill(mask > 0, -1e8)

        attention = torch.softmax(U, dim=-1)  # n_heads*batch_size*n_query*targets_size

        heads = torch.matmul(attention, V)  # n_heads*batch_size*n_query*value_dim

        # out = heads.permute(1, 2, 0, 3).reshape(n_batch, n_query, n_dim)
        out = torch.mm(
            heads.permute(1, 2, 0, 3).reshape(-1, self.n_heads * self.value_dim),
            # batch_size*n_query*n_heads*value_dim
            self.w_out.view(-1, self.embedding_dim)
            # n_heads*value_dim*embedding_dim
        ).view(-1, n_query, self.embedding_dim)

        return out, attention  # batch_size*n_query*embedding_dim


class Normalization(nn.Module):
    def __init__(self, embedding_dim):
        super(Normalization, self).__init__()
        self.normalizer = nn.LayerNorm(embedding_dim)

    def forward(self, input):
        return self.normalizer(input.view(-1, input.size(-1))).view(*input.size())


class EncoderLayer(nn.Module):
    def __init__(self, embedding_dim, n_head):
        super(EncoderLayer, self).__init__()
        self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
        self.normalization1 = Normalization(embedding_dim)
        self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512), nn.ReLU(inplace=True),
                                         nn.Linear(512, embedding_dim))
        self.normalization2 = Normalization(embedding_dim)

    def forward(self, src, key_padding_mask=None, attn_mask=None):
        h0 = src
        h = self.normalization1(src)
        h, _ = self.multiHeadAttention(q=h, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        h = h + h0
        h1 = h
        h = self.normalization2(h)
        h = self.feedForward(h)
        h2 = h + h1
        return h2


class DecoderLayer(nn.Module):
    def __init__(self, embedding_dim, n_head):
        super(DecoderLayer, self).__init__()
        self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
        self.normalization1 = Normalization(embedding_dim)
        self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(512, embedding_dim))
        self.normalization2 = Normalization(embedding_dim)

    def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
        h0 = tgt
        tgt = self.normalization1(tgt)
        memory = self.normalization1(memory)
        h, w = self.multiHeadAttention(q=tgt, k=memory, v=memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        h = h + h0
        h1 = h
        h = self.normalization2(h)
        h = self.feedForward(h)
        h2 = h + h1
        return h2, w


class Encoder(nn.Module):
    def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList(EncoderLayer(embedding_dim, n_head) for i in range(n_layer))

    def forward(self, src, key_padding_mask=None, attn_mask=None):
        for layer in self.layers:
            src = layer(src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return src


class Decoder(nn.Module):
    def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(embedding_dim, n_head) for i in range(n_layer)])

    def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
        for layer in self.layers:
            tgt, w = layer(tgt, memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return tgt, w


class PolicyNet(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super(PolicyNet, self).__init__()
        self.initial_embedding = nn.Linear(input_dim, embedding_dim)  # layer for non-end position

        self.current_embedding = nn.Linear(embedding_dim * 2, embedding_dim)

        self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
        self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
        self.pointer = SingleHeadAttention(embedding_dim)

    def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
        node_feature = self.initial_embedding(node_inputs)
        enhanced_node_feature = self.encoder(src=node_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)

        return enhanced_node_feature

    def output_policy(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
        k_size = edge_inputs.size()[2]
        current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
        current_edge = current_edge.permute(0, 2, 1)
        embedding_dim = enhanced_node_feature.size()[2]

        neigboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))

        current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))

        if edge_padding_mask is not None:
            current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to(enhanced_node_feature.device)
            # print(current_mask)
        else:
            current_mask = None
        current_mask[:, :, 0] = 1  # don't stay at current position

        # ADDED: If nowhere to go, then STAY at current position
        # #assert 0 in current_mask  # Will cause sim to crash
        if not 0 in current_mask:
            current_mask[:, :, 0] = 0

        enhanced_current_node_feature, _ = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
        enhanced_current_node_feature = self.current_embedding(torch.cat((enhanced_current_node_feature, current_node_feature), dim=-1))
        logp = self.pointer(enhanced_current_node_feature, neigboring_feature, current_mask)
        logp = logp.squeeze(1)  # batch_size*k_size

        return logp

    def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None, edge_mask=None):
        enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
        logp = self.output_policy(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
        return logp


class QNet(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super(QNet, self).__init__()
        self.initial_embedding = nn.Linear(input_dim, embedding_dim)  # layer for non-end position
        self.action_embedding = nn.Linear(embedding_dim * 3, embedding_dim)

        self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
        self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)

        self.q_values_layer = nn.Linear(embedding_dim, 1)

    def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
        embedding_feature = self.initial_embedding(node_inputs)
        embedding_feature = self.encoder(src=embedding_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)

        return embedding_feature

    def output_q_values(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
        k_size = edge_inputs.size()[2]
        current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
        current_edge = current_edge.permute(0, 2, 1)
        embedding_dim = enhanced_node_feature.size()[2]

        neigboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))

        current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))

        enhanced_current_node_feature, attention_weights = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
        action_features = torch.cat((enhanced_current_node_feature.repeat(1, k_size, 1), current_node_feature.repeat(1, k_size, 1), neigboring_feature), dim=-1)
        action_features = self.action_embedding(action_features)
        q_values = self.q_values_layer(action_features)

        if edge_padding_mask is not None:
            current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to(
                enhanced_node_feature.device)
        else:
            current_mask = None
        current_mask[:, :, 0] = 1  # don't stay at current position

        # assert 0 in current_mask  # Will cause sim to crash
        if not 0 in current_mask:
            current_mask[:, :, 0] = 0

        current_mask = current_mask.permute(0, 2, 1)
        zero = torch.zeros_like(q_values).to(q_values.device)
        q_values = torch.where(current_mask == 1, zero, q_values)

        return q_values, attention_weights

    def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None,
                edge_mask=None):
        enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
        q_values, attention_weights = self.output_q_values(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
        return q_values, attention_weights
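For orientation, a minimal sketch of driving PolicyNet with dummy inputs. Shapes mirror get_observations in test_multi_robot_worker.py; input_dim=4 assumes the (x, y, seg_info, guidepost) node features built there, and the zero masks and random edges are placeholders, not the real graph.

import torch
from model import PolicyNet

N, K = 24 * 24, 20                                  # number of graph nodes, neighbours per node (illustrative)
policy = PolicyNet(input_dim=4, embedding_dim=128)

node_inputs = torch.rand(1, N, 4)                   # per-node features
edge_inputs = torch.randint(0, N, (1, N, K))        # neighbour indices per node
edge_padding_mask = torch.zeros(1, N, K, dtype=torch.int64)  # 0 = valid edge
edge_mask = torch.zeros(1, N, N)                    # attention bias (0 = attend everywhere here)
current_index = torch.zeros(1, 1, 1, dtype=torch.int64)      # robot sits at node 0

logp = policy(node_inputs, edge_inputs, current_index,
              node_padding_mask=None, edge_padding_mask=edge_padding_mask, edge_mask=edge_mask)
next_action = torch.argmax(logp, dim=1)             # greedy choice among the K neighbours
print(logp.shape, next_action)                      # torch.Size([1, 20]) tensor([...])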
node.py
ADDED
@@ -0,0 +1,102 @@
#######################################################################
# Name: node.py
#
# - Contains info per node on graph (edge)
# - Contains: Position, Utility, Visitation History
#######################################################################

import sys
if sys.modules['TRAINING']:
    from parameter import *
else:
    from test_parameter import *

import numpy as np
import shapely.geometry


class Node():
    def __init__(self, coords, frontiers, robot_belief):
        self.coords = coords
        self.observable_frontiers = []
        self.sensor_range = SENSOR_RANGE
        self.initialize_observable_frontiers(frontiers, robot_belief)
        self.utility = self.get_node_utility()
        if self.utility == 0:
            self.zero_utility_node = True
        else:
            self.zero_utility_node = False

    def initialize_observable_frontiers(self, frontiers, robot_belief):
        dist_list = np.linalg.norm(frontiers - self.coords, axis=-1)
        frontiers_in_range = frontiers[dist_list < self.sensor_range - 10]
        for point in frontiers_in_range:
            collision = self.check_collision(self.coords, point, robot_belief)
            if not collision:
                self.observable_frontiers.append(point)

    def get_node_utility(self):
        return len(self.observable_frontiers)

    def update_observable_frontiers(self, observed_frontiers, new_frontiers, robot_belief):
        if observed_frontiers != []:
            observed_index = []
            for i, point in enumerate(self.observable_frontiers):
                if point[0] + point[1] * 1j in observed_frontiers[:, 0] + observed_frontiers[:, 1] * 1j:
                    observed_index.append(i)
            for index in reversed(observed_index):
                self.observable_frontiers.pop(index)
        #
        if new_frontiers != []:
            dist_list = np.linalg.norm(new_frontiers - self.coords, axis=-1)
            new_frontiers_in_range = new_frontiers[dist_list < self.sensor_range - 15]
            for point in new_frontiers_in_range:
                collision = self.check_collision(self.coords, point, robot_belief)
                if not collision:
                    self.observable_frontiers.append(point)

        self.utility = self.get_node_utility()
        if self.utility == 0:
            self.zero_utility_node = True
        else:
            self.zero_utility_node = False

    def set_visited(self):
        self.observable_frontiers = []
        self.utility = 0
        self.zero_utility_node = True

    def check_collision(self, start, end, robot_belief):
        collision = False
        line = shapely.geometry.LineString([start, end])

        sortx = np.sort([start[0], end[0]])
        sorty = np.sort([start[1], end[1]])

        # print(robot_belief.shape)
        robot_belief = robot_belief[sorty[0]:sorty[1] + 1, sortx[0]:sortx[1] + 1]

        occupied_area_index = np.where(robot_belief == 1)
        occupied_area_coords = np.asarray(
            [occupied_area_index[1] + sortx[0], occupied_area_index[0] + sorty[0]]).T
        unexplored_area_index = np.where(robot_belief == 127)
        unexplored_area_coords = np.asarray(
            [unexplored_area_index[1] + sortx[0], unexplored_area_index[0] + sorty[0]]).T
        unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords))

        # obstacles = []
        for i in range(unfree_area_coords.shape[0]):
            coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]),
                       (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]),
                       (unfree_area_coords[i][0], unfree_area_coords[i][1] + 1),
                       (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1)])
            obstacle = shapely.geometry.Polygon(coords)
            # obstacles.append(obstacle)
            # if obstacles != []:
            #     all_obstacles = shapely.geometry.MultiPolygon(obstacles)
            # print(obstacle.is_valid)
            collision = line.intersects(obstacle)
            if collision:
                break

        return collision
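A minimal sketch of building a Node on a toy belief map (assumes test_parameter.py is importable and defines SENSOR_RANGE; setting the TRAINING flag mirrors how the repo selects its parameter file):

import sys
sys.modules['TRAINING'] = False        # select test_parameter, mirroring the repo's pattern
import numpy as np
from node import Node

robot_belief = np.full((100, 100), 255)          # 255 = observed free space, no obstacles
frontiers = np.array([[12, 10], [14, 12], [90, 90]])
node = Node(coords=np.array([10, 10]), frontiers=frontiers, robot_belief=robot_belief)
print(node.utility, node.zero_utility_node)      # counts frontiers within SENSOR_RANGE - 10 and line-of-sight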
robot.py
ADDED
@@ -0,0 +1,60 @@
#######################################################################
# Name: robot.py
#
# - Stores S(t), A(t), R(t), S(t+1)
#######################################################################

from copy import deepcopy
import torch

class Robot:
    def __init__(self, robot_id, position, plot=False):
        self.robot_id = robot_id
        self.plot = plot
        self.travel_dist = 0
        self.robot_position = position
        self.observations = None
        self.trajectory_coords = []
        self.targets_found_on_path = []

        self.episode_buffer = []
        for i in range(15):
            self.episode_buffer.append([])

        if self.plot:
            # initialize the route
            self.xPoints = [self.robot_position[0]]
            self.yPoints = [self.robot_position[1]]

    def save_observations(self, observations):
        node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
        self.episode_buffer[0] += deepcopy(node_inputs).to('cpu')
        self.episode_buffer[1] += deepcopy(edge_inputs).to('cpu')
        self.episode_buffer[2] += deepcopy(current_index).to('cpu')
        self.episode_buffer[3] += deepcopy(node_padding_mask).to('cpu')
        self.episode_buffer[4] += deepcopy(edge_padding_mask).to('cpu')
        self.episode_buffer[5] += deepcopy(edge_mask).to('cpu')

    def save_action(self, action_index):
        self.episode_buffer[6] += action_index.unsqueeze(0).unsqueeze(0)

    def save_reward_done(self, reward, done):
        self.episode_buffer[7] += deepcopy(torch.FloatTensor([[[reward]]])).to('cpu')
        self.episode_buffer[8] += deepcopy(torch.tensor([[[(int(done))]]])).to('cpu')
        if self.plot:
            self.xPoints.append(self.robot_position[0])
            self.yPoints.append(self.robot_position[1])

    def save_next_observations(self, observations):
        node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
        self.episode_buffer[9] += deepcopy(node_inputs).to('cpu')
        self.episode_buffer[10] += deepcopy(edge_inputs).to('cpu')
        self.episode_buffer[11] += deepcopy(current_index).to('cpu')
        self.episode_buffer[12] += deepcopy(node_padding_mask).to('cpu')
        self.episode_buffer[13] += deepcopy(edge_padding_mask).to('cpu')
        self.episode_buffer[14] += deepcopy(edge_mask).to('cpu')

    # NEW: ADDED TO SAVE TRAJECTORY COORDS DURING TTA
    def save_trajectory_coords(self, robot_position_coords, num_target_found):
        self.trajectory_coords.append(robot_position_coords)
        self.targets_found_on_path.append(num_target_found)
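A minimal sketch of the buffer bookkeeping (placeholder values: slots 0-5 hold S(t), slot 6 the action, slots 7-8 reward/done, slots 9-14 S(t+1)):

import numpy as np
import torch
from robot import Robot

robot = Robot(robot_id=0, position=np.array([0, 0]), plot=False)
robot.save_reward_done(reward=0.5, done=False)                               # slot 7 reward, slot 8 done flag
robot.save_trajectory_coords(robot_position_coords=17, num_target_found=0)   # TTA bookkeeping
print(len(robot.episode_buffer), robot.episode_buffer[7], robot.trajectory_coords)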
sensor.py
ADDED
@@ -0,0 +1,128 @@
#######################################################################
# Name: sensor.py
#
# - Computes sensor related checks (e.g. collision, utility etc)
#######################################################################

import sys
if sys.modules['TRAINING']:
    from parameter import *
else:
    from test_parameter import *

import math
import numpy as np
import copy

def collision_check(x0, y0, x1, y1, ground_truth, robot_belief):
    x0 = x0.round()
    y0 = y0.round()
    x1 = x1.round()
    y1 = y1.round()
    dx, dy = abs(x1 - x0), abs(y1 - y0)
    x, y = x0, y0
    error = dx - dy
    x_inc = 1 if x1 > x0 else -1
    y_inc = 1 if y1 > y0 else -1
    dx *= 2
    dy *= 2

    collision_flag = 0
    max_collision = 10

    while 0 <= x < ground_truth.shape[1] and 0 <= y < ground_truth.shape[0]:
        k = ground_truth.item(y, x)
        if k == 1 and collision_flag < max_collision:
            collision_flag += 1
            if collision_flag >= max_collision:
                break

        if k != 1 and collision_flag > 0:
            break

        if x == x1 and y == y1:
            break

        robot_belief.itemset((y, x), k)

        if error > 0:
            x += x_inc
            error -= dy
        else:
            y += y_inc
            error += dx

    return robot_belief


def sensor_work(robot_position, sensor_range, robot_belief, ground_truth, sensor_model=SENSOR_MODEL):
    x0 = robot_position[0]
    y0 = robot_position[1]
    rng_x = 0.5 * (ground_truth.shape[1] / NUM_COORDS_WIDTH)
    rng_y = 0.5 * (ground_truth.shape[0] / NUM_COORDS_HEIGHT)

    if sensor_model == "rectangular":  # TODO: add collision check
        max_x = min(x0 + int(math.ceil(rng_x)), ground_truth.shape[1])
        min_x = max(x0 - int(math.ceil(rng_x)), 0)
        max_y = min(y0 + int(math.ceil(rng_y)), ground_truth.shape[0])
        min_y = max(y0 - int(math.ceil(rng_y)), 0)
        robot_belief[min_y:max_y, min_x:max_x] = ground_truth[min_y:max_y, min_x:max_x]
    else:
        sensor_angle_inc = 0.5 / 180 * np.pi
        sensor_angle = 0
        while sensor_angle < 2 * np.pi:
            x1 = x0 + np.cos(sensor_angle) * sensor_range
            y1 = y0 + np.sin(sensor_angle) * sensor_range
            robot_belief = collision_check(x0, y0, x1, y1, ground_truth, robot_belief)
            sensor_angle += sensor_angle_inc
    return robot_belief


def unexplored_area_check(x0, y0, x1, y1, current_belief):
    x0 = x0.round()
    y0 = y0.round()
    x1 = x1.round()
    y1 = y1.round()
    dx, dy = abs(x1 - x0), abs(y1 - y0)
    x, y = x0, y0
    error = dx - dy
    x_inc = 1 if x1 > x0 else -1
    y_inc = 1 if y1 > y0 else -1
    dx *= 2
    dy *= 2

    while 0 <= x < current_belief.shape[1] and 0 <= y < current_belief.shape[0]:
        k = current_belief.item(y, x)
        if x == x1 and y == y1:
            break

        if k == 1:
            break

        if k == 127:
            current_belief.itemset((y, x), 0)
            break

        if error > 0:
            x += x_inc
            error -= dy
        else:
            y += y_inc
            error += dx

    return current_belief


def calculate_utility(waypoint_position, sensor_range, robot_belief):
    sensor_angle_inc = 5 / 180 * np.pi
    sensor_angle = 0
    x0 = waypoint_position[0]
    y0 = waypoint_position[1]
    current_belief = copy.deepcopy(robot_belief)
    while sensor_angle < 2 * np.pi:
        x1 = x0 + np.cos(sensor_angle) * sensor_range
        y1 = y0 + np.sin(sensor_angle) * sensor_range
        current_belief = unexplored_area_check(x0, y0, x1, y1, current_belief)
        sensor_angle += sensor_angle_inc
    utility = np.sum(robot_belief == 127) - np.sum(current_belief == 127)
    return utility
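A minimal sketch of updating a belief map with the ray-casting branch (assumes test_parameter.py imports cleanly; map values follow the conventions above: 1 = occupied, 127 = unexplored):

import sys
sys.modules['TRAINING'] = False          # select test_parameter, mirroring the repo's pattern
import numpy as np
from sensor import sensor_work, calculate_utility

ground_truth = np.full((120, 120), 255, dtype=np.uint8)   # toy map: free space...
ground_truth[50:70, 50:70] = 1                             # ...with one square obstacle
robot_belief = np.full_like(ground_truth, 127)             # everything starts unexplored

robot_belief = sensor_work(np.array([20, 20]), sensor_range=40,
                           robot_belief=robot_belief, ground_truth=ground_truth,
                           sensor_model="circular")         # any value other than "rectangular" ray-casts
print("observed cells:", int(np.sum(robot_belief != 127)))
print("utility of (90, 90):", calculate_utility(np.array([90, 90]), 40, robot_belief))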
test_multi_robot_worker.py
ADDED
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#######################################################################
|
2 |
+
# Name: multi_robot_worker.py
|
3 |
+
#
|
4 |
+
# - Runs robot in environment for N steps
|
5 |
+
# - Collects & Returns S(t), A(t), R(t), S(t+1)
|
6 |
+
# - NOTE: Applicable for multiple robots
|
7 |
+
#######################################################################
|
8 |
+
|
9 |
+
from pathlib import Path
|
10 |
+
from test_parameter import *
|
11 |
+
|
12 |
+
from time import time
|
13 |
+
import imageio
|
14 |
+
import csv
|
15 |
+
import os
|
16 |
+
import copy
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import matplotlib.pyplot as plt
|
20 |
+
import json
|
21 |
+
from env import Env
|
22 |
+
from model import PolicyNet
|
23 |
+
from robot import Robot
|
24 |
+
from Taxabind.Taxabind.SatBind.watershed_segmentation import WatershedBinomial
|
25 |
+
from Taxabind.Taxabind.SatBind.kmeans_clustering import CombinedSilhouetteInertiaClusterer
|
26 |
+
from Taxabind.Taxabind.SatBind.clip_seg_tta import ClipSegTTA
|
27 |
+
|
28 |
+
np.seterr(invalid='raise', divide='raise')
|
29 |
+
|
30 |
+
|
31 |
+
class TestWorker:
|
32 |
+
def __init__(self, meta_agent_id, n_agent, policy_net, global_step, device='cuda', greedy=False, save_image=False, clip_seg_tta=None):
|
33 |
+
self.device = device
|
34 |
+
self.greedy = greedy
|
35 |
+
self.n_agent = n_agent
|
36 |
+
self.metaAgentID = meta_agent_id
|
37 |
+
self.global_step = global_step
|
38 |
+
self.k_size = K_SIZE
|
39 |
+
self.save_image = save_image
|
40 |
+
self.tta = TAXABIND_TTA
|
41 |
+
self.clip_seg_tta = clip_seg_tta
|
42 |
+
|
43 |
+
self.env = Env(map_index=self.global_step, n_agent=n_agent, k_size=self.k_size, plot=save_image, test=True)
|
44 |
+
self.local_policy_net = policy_net
|
45 |
+
|
46 |
+
self.robot_list = []
|
47 |
+
self.all_robot_positions = []
|
48 |
+
for i in range(self.n_agent):
|
49 |
+
# robot_position = self.env.node_coords[i]
|
50 |
+
robot_position = self.env.start_positions[i]
|
51 |
+
robot = Robot(robot_id=i, position=robot_position, plot=save_image)
|
52 |
+
self.robot_list.append(robot)
|
53 |
+
self.all_robot_positions.append(robot_position)
|
54 |
+
|
55 |
+
self.perf_metrics = dict()
|
56 |
+
self.bad_mask_init = False
|
57 |
+
|
58 |
+
# # TEMP - EXPORT START POSES FOR BASELINES
|
59 |
+
# json_path = "eval_start_positions.json"
|
60 |
+
# sat_to_start_pose_dict = {}
|
61 |
+
# print("len(self.env.map_list): ", len(self.env.map_list))
|
62 |
+
# for i in range(4000):
|
63 |
+
# print("i: ", i)
|
64 |
+
# map_idx = i % len(self.env.map_list)
|
65 |
+
# _, map_start_position = self.env.import_ground_truth(os.path.join(self.env.map_dir, self.env.map_list[map_idx]))
|
66 |
+
# self.clip_seg_tta.reset(sample_idx=i)
|
67 |
+
# sat_path = self.clip_seg_tta.gt_mask_name
|
68 |
+
# sat_to_start_pose_dict[sat_path] = tuple(map(int, map_start_position))
|
69 |
+
# # Save to json
|
70 |
+
# with open(json_path, 'w') as f:
|
71 |
+
# json.dump(sat_to_start_pose_dict, f)
|
72 |
+
# print("len(sat_to_start_pose_dict): ", len(sat_to_start_pose_dict))
|
73 |
+
# exit()
|
74 |
+
|
75 |
+
if self.tta:
|
76 |
+
# NOTE: Moved to test_driver.py for efficiency (avoid repeated init)
|
77 |
+
# self.clip_seg_tta = ClipSegTTA(
|
78 |
+
# img_dir=TAXABIND_IMG_DIR,
|
79 |
+
# imo_dir=TAXABIND_IMO_DIR,
|
80 |
+
# json_path=TAXABIND_INAT_JSON_PATH,
|
81 |
+
# sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
|
82 |
+
# patch_size=TAXABIND_PATCH_SIZE,
|
83 |
+
# sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
|
84 |
+
# sample_index = TAXABIND_SAMPLE_INDEX, #global_step
|
85 |
+
# blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
|
86 |
+
# device=self.device,
|
87 |
+
# sat_to_img_ids_json_is_train_dict=False # for search ds val
|
88 |
+
# # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
|
89 |
+
# )
|
90 |
+
if clip_seg_tta is not None:
|
91 |
+
self.clip_seg_tta.reset(sample_idx=self.global_step)
|
92 |
+
# print("Resetting for sample index: ", self.global_step)
|
93 |
+
|
94 |
+
# Override target positions in env
|
95 |
+
self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions] # Must transpose to (y, x) format
|
96 |
+
|
97 |
+
# Override GT seg mask for info_gain metric
|
98 |
+
if OVERRIDE_GT_MASK_DIR != "":
|
99 |
+
self.tta_gt_seg_path = os.path.join(OVERRIDE_GT_MASK_DIR, self.clip_seg_tta.gt_mask_name)
|
100 |
+
print("self.clip_seg_tta.gt_mask_name: ", self.clip_seg_tta.gt_mask_name)
|
101 |
+
if os.path.exists(self.tta_gt_seg_path):
|
102 |
+
self.env.gt_segmentation_mask = self.env.import_segmentation_mask(self.tta_gt_seg_path)
|
103 |
+
else:
|
104 |
+
print("\n\n!!!!!! WARNING: GT mask not found at path: ", self.tta_gt_seg_path)
|
105 |
+
|
106 |
+
if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "":
|
107 |
+
# mask_name = self.clip_seg_tta.gt_mask_name.split('_')[:-TAX_HIERARCHY_TO_CONDENSE]
|
108 |
+
# mask_name = '_'.join(mask_name) + ".png"
|
109 |
+
score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name)
|
110 |
+
print("score_mask_path: ", score_mask_path)
|
111 |
+
if os.path.exists(score_mask_path):
|
112 |
+
self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path)
|
113 |
+
self.env.begin(self.env.map_start_position)
|
114 |
+
else:
|
115 |
+
print(f"\n\n{RED}!!!!!! ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path)
|
116 |
+
self.bad_mask_init = True
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
# # # # TEMP: Additional targets
|
121 |
+
# self.env.target_positions = [] # Reset
|
122 |
+
# print("self.env.target_positions", self.env.target_positions)
|
123 |
+
# # self.env.target_positions = [self.env.target_positions[2]]
|
124 |
+
# # self.env.target_positions = [self.env.target_positions[0]]
|
125 |
+
# # self.env.target_positions.append((251, 297))
|
126 |
+
# self.env.target_positions.append((40,40))
|
127 |
+
# self.env.target_positions.append((80,40))
|
128 |
+
# self.env.target_positions.append((120,40))
|
129 |
+
# self.env.target_positions.append((160,40))
|
130 |
+
# self.env.target_positions.append((200,40))
|
131 |
+
|
132 |
+
# Save clustered embeds from sat encoder
|
133 |
+
# In thery, we only need to do this once (same satellite map throughout)
|
134 |
+
if USE_CLIP_PREDS:
|
135 |
+
self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer(
|
136 |
+
k_min=1,
|
137 |
+
k_max=8,
|
138 |
+
k_avg_max=4,
|
139 |
+
silhouette_threshold=0.15,
|
140 |
+
relative_threshold=0.15,
|
141 |
+
random_state=0,
|
142 |
+
min_patch_size=5, # smoothing parameter
|
143 |
+
n_smooth_iter=2, # smoothing parameter
|
144 |
+
ignore_label=-1,
|
145 |
+
plot=self.save_image,
|
146 |
+
gifs_dir = gifs_path
|
147 |
+
)
|
148 |
+
# Fit & predict (this will also plot the clusters before & after smoothing)
|
149 |
+
map_shape = (int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])), int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])))
|
150 |
+
self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict(
|
151 |
+
patch_embeds=self.clip_seg_tta.patch_embeds,
|
152 |
+
map_shape=map_shape,
|
153 |
+
)
|
154 |
+
print("Chosen k:", self.kmeans_clusterer.final_k)
|
155 |
+
# print("Smoothed labels shape:", self.kmeans_sat_embeds_clusters.shape)
|
156 |
+
|
157 |
+
if EXECUTE_TTA:
|
158 |
+
print("Will execute TTA...")
|
159 |
+
|
160 |
+
# Define Poisson TTA params
|
161 |
+
self.step_since_tta = 0
|
162 |
+
self.steps_to_first_tgt = None
|
163 |
+
self.steps_to_mid_tgt = None
|
164 |
+
self.steps_to_last_tgt = None
|
165 |
+
|
166 |
+
self.sim_0perc = None
|
167 |
+
self.sim_25perc = None
|
168 |
+
self.sim_50perc = None
|
169 |
+
self.sim_75perc = None
|
170 |
+
self.sim_100perc = None
|
171 |
+
|
172 |
+
|
173 |
+
def run_episode(self, curr_episode):
|
174 |
+
|
175 |
+
# Return all metrics as None if faulty mask init
|
176 |
+
if self.bad_mask_init:
|
177 |
+
self.perf_metrics['tax'] = None
|
178 |
+
self.perf_metrics['tax_first'] = None
|
179 |
+
self.perf_metrics['travel_dist'] = None
|
180 |
+
self.perf_metrics['travel_steps'] = None
|
181 |
+
self.perf_metrics['steps_to_first_tgt'] = None
|
182 |
+
self.perf_metrics['steps_to_mid_tgt'] = None
|
183 |
+
self.perf_metrics['steps_to_last_tgt'] = None
|
184 |
+
self.perf_metrics['explored_rate'] = None
|
185 |
+
self.perf_metrics['targets_found'] = None
|
186 |
+
self.perf_metrics['targets_total'] = None
|
187 |
+
self.perf_metrics['sim_0perc'] = None
|
188 |
+
self.perf_metrics['sim_25perc'] = None
|
189 |
+
self.perf_metrics['sim_50perc'] = None
|
190 |
+
self.perf_metrics['sim_75perc'] = None
|
191 |
+
self.perf_metrics['sim_100perc'] = None
|
192 |
+
self.perf_metrics['kmeans_k'] = None
|
193 |
+
self.perf_metrics['tgts_gt_score'] = None
|
194 |
+
self.perf_metrics['clip_inference_time'] = None
|
195 |
+
self.perf_metrics['tta_time'] = None
|
196 |
+
self.perf_metrics['info_gain'] = None
|
197 |
+
self.perf_metrics['total_info'] = None
|
198 |
+
self.perf_metrics['success_rate'] = None
|
199 |
+
return
|
200 |
+
|
201 |
+
eps_start = time()
|
202 |
+
done = False
|
203 |
+
for robot_id, deciding_robot in enumerate(self.robot_list):
|
204 |
+
deciding_robot.observations = self.get_observations(deciding_robot.robot_position)
|
205 |
+
if self.tta and USE_CLIP_PREDS:
|
206 |
+
self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
|
207 |
+
self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
|
208 |
+
# print("self.env.segmentation_info_mask.shape", self.env.segmentation_info_mask.shape)
|
209 |
+
|
210 |
+
### Run episode for 128 steps ###
|
211 |
+
for step in range(NUM_EPS_STEPS):
|
212 |
+
|
213 |
+
# print("\n\n\n~~~~~~~~~~~~~~~~~~~~~~ Step: ", step, " ~~~~~~~~~~~~~~~~~~~~~~")
|
214 |
+
|
215 |
+
next_position_list = []
|
216 |
+
dist_list = []
|
217 |
+
travel_dist_list = []
|
218 |
+
dist_array = np.zeros((self.n_agent, 1))
|
219 |
+
for robot_id, deciding_robot in enumerate(self.robot_list):
|
220 |
+
observations = deciding_robot.observations
|
221 |
+
# if self.env.node_coords.shape[0] >= self.k_size:
|
222 |
+
# deciding_robot.save_observations(observations)
|
223 |
+
|
224 |
+
### Forward pass through policy to get next position ###
|
225 |
+
next_position, action_index = self.select_node(observations)
|
226 |
+
# if self.env.node_coords.shape[0] >= self.k_size:
|
227 |
+
# deciding_robot.save_action(action_index)
|
228 |
+
|
229 |
+
dist = np.linalg.norm(next_position - deciding_robot.robot_position)
|
230 |
+
|
231 |
+
### Log results of action (e.g. distance travelled) ###
|
232 |
+
dist_array[robot_id] = dist
|
233 |
+
dist_list.append(dist)
|
234 |
+
travel_dist_list.append(deciding_robot.travel_dist)
|
235 |
+
next_position_list.append(next_position)
|
236 |
+
self.all_robot_positions[robot_id] = next_position
|
237 |
+
|
238 |
+
arriving_sequence = np.argsort(dist_list)
|
239 |
+
next_position_list = np.array(next_position_list)
|
240 |
+
dist_list = np.array(dist_list)
|
241 |
+
travel_dist_list = np.array(travel_dist_list)
|
242 |
+
next_position_list = next_position_list[arriving_sequence]
|
243 |
+
dist_list = dist_list[arriving_sequence]
|
244 |
+
travel_dist_list = travel_dist_list[arriving_sequence]
|
245 |
+
|
246 |
+
### Take Action (Deconflict if 2 agents choose the same target position) ###
|
247 |
+
next_position_list, dist_list = self.solve_conflict(arriving_sequence, next_position_list, dist_list)
|
248 |
+
# dist_travelled = np.linalg.norm(next_position - deciding_robot.robot_position)
|
249 |
+
# deciding_robot.travel_dist += dist_travelled
|
250 |
+
# deciding_robot.robot_position = next_position
|
251 |
+
reward_list, done = self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)
|
252 |
+
|
253 |
+
### Update observations + rewards from action ###
|
254 |
+
for reward, robot_id in zip(reward_list, arriving_sequence):
|
255 |
+
robot = self.robot_list[robot_id]
|
256 |
+
robot.save_trajectory_coords(self.env.find_index_from_coords(robot.robot_position), self.env.num_new_targets_found)
|
257 |
+
|
258 |
+
# # TTA Update via Poisson Test (with KMeans clustering stats)
|
259 |
+
if self.tta and USE_CLIP_PREDS and EXECUTE_TTA:
|
260 |
+
self.poisson_tta_update(robot, self.global_step, step)
|
261 |
+
|
262 |
+
robot.observations = self.get_observations(robot.robot_position)
|
263 |
+
# if self.env.node_coords.shape[0] >= self.k_size:
|
264 |
+
robot.save_reward_done(reward, done)
|
265 |
+
# robot.save_next_observations(robot.observations)
|
266 |
+
|
267 |
+
# Update metrics
|
268 |
+
# NOTE: For 1 robot for now
|
269 |
+
self.log_metrics(step=step) # robot.targets_found_on_path)
|
270 |
+
|
271 |
+
|
272 |
+
### Save a frame to generate gif of robot trajectories ###
|
273 |
+
if self.save_image:
|
274 |
+
robots_route = []
|
275 |
+
for robot in self.robot_list:
|
276 |
+
robots_route.append([robot.xPoints, robot.yPoints])
|
277 |
+
if not os.path.exists(gifs_path):
|
278 |
+
os.makedirs(gifs_path)
|
279 |
+
sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0]
|
280 |
+
if TAXABIND_TTA and USE_CLIP_PREDS:
|
281 |
+
self.env.plot_env(
|
282 |
+
self.global_step,
|
283 |
+
gifs_path,
|
284 |
+
step,
|
285 |
+
max(travel_dist_list),
|
286 |
+
robots_route,
|
287 |
+
img_path_override=self.clip_seg_tta.img_paths[0], # Viz 1st
|
288 |
+
sat_path_override=self.clip_seg_tta.imo_path,
|
289 |
+
msk_name_override=self.clip_seg_tta.species_name,
|
290 |
+
sound_id_override=sound_id_override,
|
291 |
+
colormap_mid_val=np.max(self.clip_seg_tta.heatmap_unnormalized_initial)
|
292 |
+
)
|
293 |
+
else:
|
294 |
+
self.env.plot_env(
|
295 |
+
self.global_step,
|
296 |
+
gifs_path,
|
297 |
+
step,
|
298 |
+
max(travel_dist_list),
|
299 |
+
robots_route,
|
300 |
+
img_path_override=self.clip_seg_tta.img_paths[0], # Viz 1st
|
301 |
+
sat_path_override=self.clip_seg_tta.imo_path,
|
302 |
+
msk_name_override=self.clip_seg_tta.species_name,
|
303 |
+
sound_id_override=sound_id_override,
|
304 |
+
)
|
305 |
+
|
306 |
+
if done:
|
307 |
+
break
|
308 |
+
|
309 |
+
if self.tta:
|
310 |
+
tax = Path(self.clip_seg_tta.gt_mask_name).stem
|
311 |
+
self.perf_metrics['tax'] = " ".join(tax.split("_")[1:])
|
312 |
+
self.perf_metrics['tax_first'] = tax.split("_")[1]
|
313 |
+
else:
|
314 |
+
self.perf_metrics['tax'] = None
|
315 |
+
self.perf_metrics['tax_first'] = None
|
316 |
+
self.perf_metrics['travel_dist'] = max(travel_dist_list)
|
317 |
+
self.perf_metrics['travel_steps'] = step + 1
|
318 |
+
self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt
|
319 |
+
self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt
|
320 |
+
self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt
|
321 |
+
self.perf_metrics['explored_rate'] = self.env.explored_rate
|
322 |
+
self.perf_metrics['targets_found'] = self.env.targets_found_rate
|
323 |
+
self.perf_metrics['targets_total'] = len(self.env.target_positions)
|
324 |
+
self.perf_metrics['sim_0perc'] = self.sim_0perc
|
325 |
+
self.perf_metrics['sim_25perc'] = self.sim_25perc
|
326 |
+
self.perf_metrics['sim_50perc'] = self.sim_50perc
|
327 |
+
self.perf_metrics['sim_75perc'] = self.sim_75perc
|
328 |
+
self.perf_metrics['sim_100perc'] = self.sim_100perc
|
329 |
+
if USE_CLIP_PREDS:
|
330 |
+
self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k
|
331 |
+
self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score
|
332 |
+
self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time
|
333 |
+
self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time
|
334 |
+
else:
|
335 |
+
self.perf_metrics['kmeans_k'] = None
|
336 |
+
self.perf_metrics['tgts_gt_score'] = None
|
337 |
+
self.perf_metrics['clip_inference_time'] = None
|
338 |
+
self.perf_metrics['tta_time'] = None
|
339 |
+
if OVERRIDE_GT_MASK_DIR != "" and os.path.exists(self.tta_gt_seg_path):
|
340 |
+
self.perf_metrics['info_gain'] = self.env.info_gain
|
341 |
+
self.perf_metrics['total_info'] = self.env.total_info
|
342 |
+
else:
|
343 |
+
self.perf_metrics['info_gain'] = None
|
344 |
+
self.perf_metrics['total_info'] = None
|
345 |
+
if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0:
|
346 |
+
self.perf_metrics['success_rate'] = True
|
347 |
+
else:
|
348 |
+
self.perf_metrics['success_rate'] = done
|
349 |
+
|
350 |
+
# save gif
|
351 |
+
if self.save_image:
|
352 |
+
path = gifs_path
|
353 |
+
self.make_gif(path, curr_episode)
|
354 |
+
|
355 |
+
print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)
|
356 |
+
|
357 |
+
def get_observations(self, robot_position):
|
358 |
+
""" Get robot's sensor observation of environment given position """
|
359 |
+
current_node_index = self.env.find_index_from_coords(robot_position)
|
360 |
+
current_index = torch.tensor([current_node_index]).unsqueeze(0).unsqueeze(0).to(self.device) # (1,1,1)
|
361 |
+
|
362 |
+
node_coords = copy.deepcopy(self.env.node_coords)
|
363 |
+
graph = copy.deepcopy(self.env.graph)
|
364 |
+
node_utility = copy.deepcopy(self.env.node_utility)
|
365 |
+
guidepost = copy.deepcopy(self.env.guidepost)
|
366 |
+
# segmentation_info_mask = copy.deepcopy(self.env.segmentation_info_mask)
|
367 |
+
segmentation_info_mask = copy.deepcopy(self.env.filtered_seg_info_mask)
|
368 |
+
|
369 |
+
# ADDED - SEGMENTATION INFORATION MASK
|
370 |
+
n_nodes = node_coords.shape[0]
|
371 |
+
|
372 |
+
node_coords = node_coords / 640
|
373 |
+
node_utility = node_utility / 50
|
374 |
+
|
375 |
+
node_utility_inputs = node_utility.reshape((n_nodes, 1))
|
376 |
+
|
377 |
+
occupied_node = np.zeros((n_nodes, 1))
|
378 |
+
for position in self.all_robot_positions:
|
379 |
+
index = self.env.find_index_from_coords(position)
|
380 |
+
if index == current_index.item():
|
381 |
+
occupied_node[index] = -1
|
382 |
+
else:
|
383 |
+
occupied_node[index] = 1
|
384 |
+
|
385 |
+
# node_inputs = np.concatenate((node_coords, node_utility_inputs, guidepost, occupied_node), axis=1)
|
386 |
+
node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1)
|
387 |
+
# node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1)
|
388 |
+
node_inputs = torch.FloatTensor(node_inputs).unsqueeze(0).to(self.device) # (1, node_padding_size+1, 3)
|
389 |
+
|
390 |
+
# assert node_coords.shape[0] < self.node_padding_size
|
391 |
+
# padding = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - node_coords.shape[0]))
|
392 |
+
# node_inputs = padding(node_inputs)
|
393 |
+
|
394 |
+
# node_padding_mask = torch.zeros((1, 1, node_coords.shape[0]), dtype=torch.int64).to(self.device)
|
395 |
+
# node_padding = torch.ones((1, 1, self.node_padding_size - node_coords.shape[0]), dtype=torch.int64).to(
|
396 |
+
# self.device)
|
397 |
+
# node_padding_mask = torch.cat((node_padding_mask, node_padding), dim=-1)
|
398 |
+
node_padding_mask = None
|
399 |
+
|
400 |
+
graph = list(graph.values())
|
401 |
+
edge_inputs = []
|
402 |
+
for node in graph:
|
403 |
+
node_edges = list(map(int, node))
|
404 |
+
edge_inputs.append(node_edges)
|
405 |
+
|
406 |
+
bias_matrix = self.calculate_edge_mask(edge_inputs)
|
407 |
+
edge_mask = torch.from_numpy(bias_matrix).float().unsqueeze(0).to(self.device)
|
408 |
+
|
409 |
+
# assert len(edge_inputs) < self.node_padding_size
|
410 |
+
# padding = torch.nn.ConstantPad2d(
|
411 |
+
# (0, self.node_padding_size - len(edge_inputs), 0, self.node_padding_size - len(edge_inputs)), 1)
|
412 |
+
# edge_mask = padding(edge_mask)
|
413 |
+
# padding2 = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - len(edge_inputs)))
|
414 |
+
|
415 |
+
for edges in edge_inputs:
|
416 |
+
while len(edges) < self.k_size:
|
417 |
+
edges.append(0)
|
418 |
+
|
419 |
+
edge_inputs = torch.tensor(edge_inputs).unsqueeze(0).to(self.device) # (1, node_padding_size+1, k_size)
|
420 |
+
# edge_inputs = padding2(edge_inputs)
|
421 |
+
|
422 |
+
edge_padding_mask = torch.zeros((1, len(edge_inputs), K_SIZE), dtype=torch.int64).to(self.device)
|
423 |
+
one = torch.ones_like(edge_padding_mask, dtype=torch.int64).to(self.device)
|
424 |
+
edge_padding_mask = torch.where(edge_inputs == 0, one, edge_padding_mask)
|
425 |
+
|
426 |
+
observations = node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask
|
427 |
+
return observations
|
428 |
+
|
429 |
+
|
430 |
+
def select_node(self, observations):
|
431 |
+
""" Forward pass through policy to get next position to go to on map """
|
432 |
+
node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
|
433 |
+
with torch.no_grad():
|
434 |
+
logp_list = self.local_policy_net(node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask)
|
435 |
+
|
436 |
+
if self.greedy:
|
437 |
+
action_index = torch.argmax(logp_list, dim=1).long()
|
438 |
+
else:
|
439 |
+
action_index = torch.multinomial(logp_list.exp(), 1).long().squeeze(1)
|
440 |
+
|
441 |
+
next_node_index = edge_inputs[:, current_index.item(), action_index.item()]
|
442 |
+
|
443 |
+
next_position = self.env.node_coords[next_node_index]
|
444 |
+
|
445 |
+
return next_position, action_index
|
446 |
+
|
447 |
+
def solve_conflict(self, arriving_sequence, next_position_list, dist_list):
|
448 |
+
""" Deconflict if 2 agents choose the same target position """
|
449 |
+
for j, [robot_id, next_position] in enumerate(zip(arriving_sequence, next_position_list)):
|
450 |
+
moving_robot = self.robot_list[robot_id]
|
451 |
+
# if next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
|
452 |
+
# dist_to_next_position = np.argsort(np.linalg.norm(self.env.node_coords - next_position, axis=1))
|
453 |
+
# k = 0
|
454 |
+
# while next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
|
455 |
+
# k += 1
|
456 |
+
# next_position = self.env.node_coords[dist_to_next_position[k]]
|
457 |
+
|
458 |
+
dist = np.linalg.norm(next_position - moving_robot.robot_position)
|
459 |
+
next_position_list[j] = next_position
|
460 |
+
dist_list[j] = dist
|
461 |
+
moving_robot.travel_dist += dist
|
462 |
+
moving_robot.robot_position = next_position
|
463 |
+
|
464 |
+
return next_position_list, dist_list
|
465 |
+
|
466 |
+
def work(self, currEpisode):
|
467 |
+
'''
|
468 |
+
Interacts with the environment. The agent gets either gradients or experience buffer
|
469 |
+
'''
|
470 |
+
self.run_episode(currEpisode)
|
471 |
+
|
472 |
+
def calculate_edge_mask(self, edge_inputs):
|
473 |
+
size = len(edge_inputs)
|
474 |
+
bias_matrix = np.ones((size, size))
|
475 |
+
for i in range(size):
|
476 |
+
for j in range(size):
|
477 |
+
if j in edge_inputs[i]:
|
478 |
+
bias_matrix[i][j] = 0
|
479 |
+
return bias_matrix
|
480 |
+
|
481 |
+
def make_gif(self, path, n):
|
482 |
+
""" Generate a gif given list of images """
|
483 |
+
with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I',
|
484 |
+
fps=5) as writer:
|
485 |
+
for frame in self.env.frame_files:
|
486 |
+
image = imageio.imread(frame)
|
487 |
+
writer.append_data(image)
|
488 |
+
print('gif complete\n')
|
489 |
+
|
490 |
+
# Remove files
|
491 |
+
for filename in self.env.frame_files[:-1]:
|
492 |
+
os.remove(filename)
|
493 |
+
|
494 |
+
# For watershed segmenter gif during TTA
|
495 |
+
if self.tta:
|
496 |
+
|
497 |
+
# print("self.kmeans_clusterer.kmeans_frame_files", self.kmeans_clusterer.kmeans_frame_files)
|
498 |
+
with imageio.get_writer('{}/{}_kmeans_stats.gif'.format(path, n), mode='I',
|
499 |
+
fps=5) as writer:
|
500 |
+
for frame in self.kmeans_clusterer.kmeans_frame_files:
|
501 |
+
image = imageio.imread(frame)
|
502 |
+
writer.append_data(image)
|
503 |
+
print('Kmeans Clusterer gif complete\n')
|
504 |
+
|
505 |
+
# Remove files
|
506 |
+
for filename in self.kmeans_clusterer.kmeans_frame_files[:-1]:
|
507 |
+
os.remove(filename)
|
508 |
+
|
509 |
+
################################################################################
|
510 |
+
# ADDED
|
511 |
+
################################################################################
|
512 |
+
|
513 |
+
def log_metrics(self, step):
|
514 |
+
# Update tgt found metrics
|
515 |
+
if self.steps_to_first_tgt is None and self.env.num_targets_found == 1:
|
516 |
+
self.steps_to_first_tgt = step + 1
|
517 |
+
if self.steps_to_mid_tgt is None and self.env.num_targets_found == int(len(self.env.target_positions) / 2):
|
518 |
+
self.steps_to_mid_tgt = step + 1
|
519 |
+
if self.steps_to_last_tgt is None and self.env.num_targets_found == len(self.env.target_positions):
|
520 |
+
self.steps_to_last_tgt = step + 1
|
521 |
+
|
522 |
+
# Update sim metrics
|
523 |
+
if OVERRIDE_GT_MASK_DIR != "" and os.path.exists(self.tta_gt_seg_path):
|
524 |
+
side_dim = int(np.sqrt(self.env.segmentation_info_mask.shape[0]))
|
525 |
+
pred_mask = self.env.segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
|
526 |
+
gt_mask = self.env.gt_segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
|
527 |
+
if step == 0:
|
528 |
+
self.sim_0perc = self.norm_distance(pred_mask, gt_mask)
|
529 |
+
elif step == int(NUM_EPS_STEPS * 0.25):
|
530 |
+
self.sim_25perc = self.norm_distance(pred_mask, gt_mask)
|
531 |
+
elif step == int(NUM_EPS_STEPS * 0.5):
|
532 |
+
self.sim_50perc = self.norm_distance(pred_mask, gt_mask)
|
533 |
+
elif step == int(NUM_EPS_STEPS * 0.75):
|
534 |
+
self.sim_75perc = self.norm_distance(pred_mask, gt_mask)
|
535 |
+
elif step == NUM_EPS_STEPS - 1:
|
536 |
+
self.sim_100perc = self.norm_distance(pred_mask, gt_mask)
|
537 |
+
|
538 |
+
def norm_distance(self, P, Q, norm_type="L2"):
|
539 |
+
|
540 |
+
# Normalize both grids to [0,1]
|
541 |
+
try:
|
542 |
+
if P.max() != P.min():
|
543 |
+
P_norm = (P - P.min()) / (P.max() - P.min())
|
544 |
+
else:
|
545 |
+
P_norm = P
|
546 |
+
if Q.max() != Q.min():
|
547 |
+
Q_norm = (Q - Q.min()) / (Q.max() - Q.min())
|
548 |
+
else:
|
549 |
+
Q_norm = Q
|
550 |
+
except FloatingPointError as e:
|
551 |
+
print(f"{RED}Caught floating point error:{NC} {e}")
|
552 |
+
print("P min/max:", P.min(), P.max())
|
553 |
+
print("Q min/max:", Q.min(), Q.max())
|
554 |
+
print("Q: ", Q)
|
555 |
+
similarity = None
|
556 |
+
return similarity
|
557 |
+
|
558 |
+
if norm_type == "L1":
|
559 |
+
num_cells = P.shape[0] * P.shape[1]
|
560 |
+
# L1 distance: sum of absolute differences
|
561 |
+
l1_dist = np.sum(np.abs(P_norm - Q_norm))
|
562 |
+
# Normalize: maximum L1 distance is num_cells (if every cell differs by 1)
|
563 |
+
similarity = 1 - (l1_dist / num_cells)
|
564 |
+
elif norm_type == "L2":
|
565 |
+
# L2 distance via Root Mean Squared Error (RMSE)
|
566 |
+
rmse = np.sqrt(np.mean((P_norm - Q_norm)**2))
|
567 |
+
# Since both grids are in [0,1], maximum RMSE is 1.
|
568 |
+
similarity = 1 - rmse
|
569 |
+
else:
|
570 |
+
raise ValueError("norm_type must be either 'L1' or 'L2'")
|
571 |
+
|
572 |
+
return similarity
|
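# Hedged usage sketch (illustrative only; `worker` stands in for a TestWorker instance,
# it is not a name defined in this file):
#   P = np.array([[0.0, 1.0], [1.0, 0.0]])
#   worker.norm_distance(P, P, norm_type="L2")        # -> 1.0 (identical grids, RMSE = 0)
#   worker.norm_distance(P, 1.0 - P, norm_type="L2")  # -> 0.0 (maximally different, RMSE = 1)
#   worker.norm_distance(P, 1.0 - P, norm_type="L1")  # -> 0.0 (L1 distance equals num_cells)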
573 |
+
|
574 |
+
def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH):
|
575 |
+
"""
|
576 |
+
Given a flattened index X in an H x W matrix,
|
577 |
+
return the new index X' after transposing the matrix.
|
578 |
+
"""
|
579 |
+
row = idx // W
|
580 |
+
col = idx % W
|
581 |
+
idx_T = col * H + row
|
582 |
+
return idx_T
|
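# Illustrative example (not in the original source): with the default 24x24 grid,
# flat index 5 (row 0, col 5) maps to 5 * 24 + 0 = 120 after transposition,
# and transpose_flat_idx(120) maps back to 5.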
583 |
+
|
584 |
+
def poisson_tta_update(self, robot, episode, step):
|
585 |
+
|
586 |
+
# TODO: Move into TTA loop to save computation
|
587 |
+
# Generate Kmeans Clusters Stats
|
588 |
+
visited_indices = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords]
|
589 |
+
region_stats_dict = self.kmeans_clusterer.compute_region_statistics(
|
590 |
+
self.kmeans_sat_embeds_clusters,
|
591 |
+
self.clip_seg_tta.heatmap_unnormalized,
|
592 |
+
visited_indices,
|
593 |
+
episode_num=episode,
|
594 |
+
step_num=step
|
595 |
+
)
|
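# Note (inferred from the usage further down, not from kmeans_clustering.py itself):
# region_stats_dict maps each cluster label to a dict exposing at least
# 'num_patches', 'patches_visited' and 'expectation', which drive the
# per-sample TTA weights computed below.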
596 |
+
|
597 |
+
# Prep & execute TTA
|
598 |
+
self.step_since_tta += 1
|
599 |
+
if robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0: # Allow even if no positive at start
|
600 |
+
|
601 |
+
# for _ in range(NUM_TTA_STEPS):
|
602 |
+
|
603 |
+
filt_traj_coords = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords]
|
604 |
+
filt_targets_found_on_path = robot.targets_found_on_path
|
605 |
+
num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1]
|
606 |
+
# per_sample_weight_scale = [min(1.0, (step/num_cells)) for _ in filt_traj_coords]
|
607 |
+
# per_sample_weight_scale = [0.5 * min(1.0, (step/100)) for _ in filt_traj_coords]
|
608 |
+
# per_sample_weight_scale = [1.0 for _ in filt_traj_coords]
|
609 |
+
pos_sample_weight_scale, neg_sample_weight_scale = [], []
|
610 |
+
for i, sample_loc in enumerate(filt_traj_coords):
|
611 |
+
label = self.kmeans_clusterer.get_label_id(sample_loc)
|
612 |
+
num_patches = region_stats_dict[label]['num_patches']
|
613 |
+
patches_visited = region_stats_dict[label]['patches_visited']
|
614 |
+
expectation = region_stats_dict[label]['expectation']
|
615 |
+
|
616 |
+
## BEST so far: focal-loss-like exponent, waits for more samples before confidently decreasing
|
617 |
+
pos_weight = 4.0 # 2.0
|
618 |
+
# pos_weight = 1.0 + 4.0 * min(1.0, (patches_visited/(num_patches))**GAMMA_EXPONENT) # (1,5)
|
619 |
+
# pos_weight = 1.0 + 4.0 * min(1.0, (patches_visited/(3*expectation))**GAMMA_EXPONENT)
|
620 |
+
# neg_weight = min(1.0, (patches_visited/(3*expectation))**GAMMA_EXPONENT)
|
621 |
+
neg_weight = min(1.0, (patches_visited/(3*num_patches))**GAMMA_EXPONENT) # (0,1)
|
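# Worked example (illustrative): with GAMMA_EXPONENT = 2, a cluster with
# num_patches = 100 and patches_visited = 30 gives
#   neg_weight = min(1.0, (30 / 300) ** 2) = 0.01
# so negative samples in a barely-explored region contribute very little,
# while pos_weight stays fixed at 4.0.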
622 |
+
pos_sample_weight_scale.append(pos_weight)
|
623 |
+
neg_sample_weight_scale.append(neg_weight)
|
624 |
+
|
625 |
+
## Prelim thoughts (BAD - quickly reduces low-prob regions even with few samples)
|
626 |
+
# neg_weight = min(1.0, patches_visited/(3*expectation))
|
627 |
+
# local_probs = self.kmeans_clusterer.get_probs(sample_loc, self.clip_seg_tta.heatmap)
|
628 |
+
# neg_weight = min(local_probs/2, patches_visited/(3*expectation)) # 2*expectation (if don't want TTA scheduler - 3x TTA)
|
629 |
+
# neg_weight = min(1.0, patches_visited/num_patches) # 2*expectation (if don't want TTA scheduler - 3x TTA)
|
630 |
+
|
631 |
+
## Hacky, but works better (does not decrease low-prob regions too fast)
|
632 |
+
# if label == 0:
|
633 |
+
# neg_weight = min(0.5, patches_visited/(3*expectation))
|
634 |
+
# else:
|
635 |
+
# neg_weight = min(0.05, patches_visited/(3*expectation))
|
636 |
+
# squared
|
637 |
+
|
638 |
+
# Adaptive LR (as samples increase, increase LR to fit more datapoints - else won't update)
|
639 |
+
adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells)
|
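# Illustrative values: with MIN_LR = 1e-6, MAX_LR = 1e-5 and a 24x24 heatmap
# (num_cells = 576), adaptive_lr ramps linearly from 1e-6 at step 0 towards
# 1e-5 as step approaches num_cells.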
640 |
+
# print("!!! adaptive_lr", adaptive_lr)
|
641 |
+
# adaptive_lr = 2e-6
|
642 |
+
|
643 |
+
# NOTE: Not as good as adaptive LR (cos discrete)
|
644 |
+
# # Num TTA steps scheduler
|
645 |
+
# min_tta_steps = 3
|
646 |
+
# max_tta_steps = 10
|
647 |
+
# num_tta_steps = int((max_tta_steps - min_tta_steps) * (step / num_cells) + min_tta_steps)
|
648 |
+
# print("!!! num_tta_steps", num_tta_steps)
|
649 |
+
|
650 |
+
# TTA Update
|
651 |
+
self.clip_seg_tta.execute_tta(
|
652 |
+
filt_traj_coords,
|
653 |
+
filt_targets_found_on_path,
|
654 |
+
tta_steps=NUM_TTA_STEPS,
|
655 |
+
lr=adaptive_lr,
|
656 |
+
pos_sample_weight=pos_sample_weight_scale,
|
657 |
+
neg_sample_weight=neg_sample_weight_scale,
|
658 |
+
modality=MODALITY,
|
659 |
+
query_variety=QUERY_VARIETY,
|
660 |
+
target_found_idxs=self.env.target_found_idxs,
|
661 |
+
reset_weights=RESET_WEIGHTS
|
662 |
+
)
|
663 |
+
self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
|
664 |
+
self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
|
665 |
+
self.step_since_tta = 0
|
666 |
+
|
667 |
+
################################################################################
|
668 |
+
|
669 |
+
|
670 |
+
# Main entry point
|
671 |
+
if __name__ == "__main__":
|
672 |
+
|
673 |
+
# CHANGE ME!
|
674 |
+
currEpisode = 0
|
675 |
+
|
676 |
+
# Prepare the model
|
677 |
+
# device = torch.device('cpu') #if USE_GPU_TRAINING else torch.device('cpu')
|
678 |
+
device = torch.device('cuda') if USE_GPU else torch.device('cpu')
|
679 |
+
policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device)
|
680 |
+
# script_dir = os.path.dirname(os.path.abspath(__file__))
|
681 |
+
script_dir = Path(__file__).resolve().parent
|
682 |
+
print("real_script_dir: ", script_dir)
|
683 |
+
# checkpoint = torch.load(f'{script_dir}/modules/vlm_search/{model_path}/{MODEL_NAME}')
|
684 |
+
checkpoint = torch.load(f'{model_path}/{MODEL_NAME}')
|
685 |
+
policy_net.load_state_dict(checkpoint['policy_model'])
|
686 |
+
print('Model loaded!')
|
687 |
+
# print(next(policy_net.parameters()).device)
|
688 |
+
|
689 |
+
# Init Taxabind here (only need to init once)
|
690 |
+
if TAXABIND_TTA:
|
691 |
+
# self.clip_seg_tta = None
|
692 |
+
clip_seg_tta = ClipSegTTA(
|
693 |
+
img_dir=TAXABIND_IMG_DIR,
|
694 |
+
imo_dir=TAXABIND_IMO_DIR,
|
695 |
+
json_path=TAXABIND_INAT_JSON_PATH,
|
696 |
+
sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
|
697 |
+
patch_size=TAXABIND_PATCH_SIZE,
|
698 |
+
sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
|
699 |
+
sample_index = 0, # Set using 'reset' in worker
|
700 |
+
blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
|
701 |
+
device=device,
|
702 |
+
sat_to_img_ids_json_is_train_dict=False, # for search ds val
|
703 |
+
tax_to_filter_val=QUERY_TAX,
|
704 |
+
load_model=USE_CLIP_PREDS,
|
705 |
+
initial_modality=INITIAL_MODALITY,
|
706 |
+
sound_data_path = TAXABIND_SOUND_DATA_PATH,
|
707 |
+
sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
|
708 |
+
# sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
|
709 |
+
)
|
710 |
+
print("ClipSegTTA Loaded!")
|
711 |
+
else:
|
712 |
+
clip_seg_tta = None
|
713 |
+
|
714 |
+
# Define TestWorker
|
715 |
+
planner = TestWorker(
|
716 |
+
meta_agent_id=0,
|
717 |
+
n_agent=1,
|
718 |
+
policy_net=policy_net,
|
719 |
+
global_step=3,
|
720 |
+
device='cuda',
|
721 |
+
greedy=True,
|
722 |
+
save_image=SAVE_GIFS,
|
723 |
+
clip_seg_tta=clip_seg_tta
|
724 |
+
)
|
725 |
+
planner.run_episode(currEpisode)
|
726 |
+
|
727 |
+
print("planner.perf_metrics: ", planner.perf_metrics)
|
test_parameter.py
ADDED
@@ -0,0 +1,183 @@
|
1 |
+
############################################################################################
|
2 |
+
# Name: test_parameter.py
|
3 |
+
#
|
4 |
+
# NOTE: Change all your hyper-params here!
|
5 |
+
# Simple How-To Guide:
|
6 |
+
# 1. CLIP TTA: USE_CLIP_PREDS = True, EXECUTE_TTA = True
|
7 |
+
# 2. CLIP (No TTA): USE_CLIP_PREDS = True, EXECUTE_TTA = False
|
8 |
+
# 3. Custom masks (e.g. LLMSeg): USE_CLIP_PREDS = False, EXECUTE_TTA = False
|
9 |
+
############################################################################################
|
10 |
+
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
sys.modules['TRAINING'] = False # False = Inference Testing
|
14 |
+
|
15 |
+
###############################################################
|
16 |
+
OPT_VARS = {}
|
17 |
+
def getenv(var_name, default=None, cast_type=str):
|
18 |
+
try:
|
19 |
+
value = os.environ.get(var_name, None)
|
20 |
+
if value is None:
|
21 |
+
result = default
|
22 |
+
elif cast_type == bool:
|
23 |
+
result = value.lower() in ("true", "1", "yes")
|
24 |
+
else:
|
25 |
+
result = cast_type(value)
|
26 |
+
except (ValueError, TypeError):
|
27 |
+
result = default
|
28 |
+
|
29 |
+
OPT_VARS[var_name] = result # Log the result
|
30 |
+
return result
|
31 |
+
###############################################################
|
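# Usage sketch (illustrative, not part of the original file): the parameters below can
# be overridden from the shell before launching the worker, e.g.
#   SAVE_GIFS=false NUM_EPS_STEPS=256 python test_multi_robot_worker.py
# Booleans are parsed case-insensitively from "true"/"1"/"yes"; any value that fails
# to cast falls back to the declared default.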
32 |
+
|
33 |
+
POLICY = getenv("POLICY", default="RL", cast_type=str)
|
34 |
+
# TAX_HIERARCHY_TO_CONDENSE = 3 # Remove N layers of the taxonomy hierarchy from the back
|
35 |
+
|
36 |
+
NUM_TEST = 800 # Overridden if TAXABIND_TTA is True and performing search ds val
|
37 |
+
NUM_RUN = 1
|
38 |
+
SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool) # do you want to save GIFs
|
39 |
+
SAVE_TRAJECTORY = False # do you want to save per-step metrics
|
40 |
+
SAVE_LENGTH = False # do you want to save per-episode metrics
|
41 |
+
VIZ_GRAPH_EDGES = False # do you want to visualize the graph edges
|
42 |
+
# MODEL_NAME = "pure_coverage_no_pose_obs_230325_stage1.pth" # checkpoint.pth
|
43 |
+
# MODEL_NAME = "STAGE2_20k_vlm_search_24x24_290225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
|
44 |
+
# MODEL_NAME = "vlm_search_24x24_230225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
|
45 |
+
# MODEL_NAME = "vlm_search_20x20_200125_256steps_CORRECT_REWARDS.pth" # checkpoint.pth
|
46 |
+
MODEL_NAME = "STAGE1_vlm_search_24x24_040425_no_tgt_rewards_iNAT_DS_16k.pth"
|
47 |
+
|
48 |
+
NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=384, cast_type=int)
|
49 |
+
TERMINATE_ON_TGTS_FOUND = False # Whether to terminate episode when all targets found
|
50 |
+
FORCE_LOGGING_DONE_TGTS_FOUND = True # Whether to force csv logging when all targets found
|
51 |
+
FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool) # Whether to fix the starting position of the robots (middle index)
|
52 |
+
|
53 |
+
## Whether to override initial score mask from CLIP
|
54 |
+
USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool) # If false, use custom masks from OVERRIDE_MASK_DIR
|
55 |
+
OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in", cast_type=str)
|
56 |
+
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Rodentia"
|
57 |
+
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Artiodactyla"
|
58 |
+
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Arthropoda_Arachnida_Araneae"
|
59 |
+
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Plantae_Tracheophyta_Magnoliopsida_Caryophyllales"
|
60 |
+
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in"
|
61 |
+
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_out/out_mask_val_out"
|
62 |
+
|
63 |
+
# Used to calculate info_gain metric
|
64 |
+
OVERRIDE_GT_MASK_DIR = getenv("OVERRIDE_GT_MASK_DIR", default="", cast_type=str)
|
65 |
+
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_in_4gsnet_score_map"
|
66 |
+
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_out_4gsnet_score_map"
|
67 |
+
|
68 |
+
#######################################################################
|
69 |
+
# iNAT TTA
|
70 |
+
#######################################################################
|
71 |
+
|
72 |
+
# Query Params
|
73 |
+
QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str) # "" = Test all tax
|
74 |
+
# QUERY_TAX = "Animalia Chordata Mammalia Rodentia" # search_val_in
|
75 |
+
# QUERY_TAX = "Animalia Chordata Mammalia Artiodactyla" # search_val_in
|
76 |
+
# QUERY_TAX = "Animalia Arthropoda Arachnida Araneae" # search_val_out
|
77 |
+
# QUERY_TAX = "Plantae Tracheophyta Magnoliopsida Caryophyllales" # search_val_out
|
78 |
+
|
79 |
+
# TTA PARAMS
|
80 |
+
EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool) # Whether to execute TTA mask updates
|
81 |
+
STEPS_PER_TTA = 20 # no. steps before each TTA series
|
82 |
+
NUM_TTA_STEPS = 1 # no. of TTA steps during each series
|
83 |
+
INITIAL_MODALITY = getenv("INITIAL_MODALITY", default="image", cast_type=str) # "image", "text", "combined"
|
84 |
+
MODALITY = getenv("MODALITY", default="image", cast_type=str) # "image", "text", "combined"
|
85 |
+
QUERY_VARIETY = getenv("QUERY_VARIETY", default=False, cast_type=bool) # Whether to use query variety during TTA (bool)
|
86 |
+
RESET_WEIGHTS = True
|
87 |
+
MIN_LR = 1e-6
|
88 |
+
MAX_LR = 1e-5 # 1e-5
|
89 |
+
GAMMA_EXPONENT = 2 # 2
|
90 |
+
|
91 |
+
# Paths related to taxabind (TRAIN w/ TARGETS)
|
92 |
+
TAXABIND_TTA = True # Whether to init TTA classes - FOR NOW: Always True
|
93 |
+
TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
|
94 |
+
TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
|
95 |
+
TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json' # no filter needed
|
96 |
+
TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json' # no filter needed
|
97 |
+
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
|
98 |
+
TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = getenv("TAXABIND_SAT_TO_IMG_IDS_JSON_PATH", default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/search_val_in.json", cast_type=str)
|
99 |
+
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_in.json"
|
100 |
+
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_out.json"
|
101 |
+
TAXABIND_PATCH_SIZE=14
|
102 |
+
TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_100625_CLIP-L-336_FINAL_SPLIT_LARGE_BUGFIX_CLIP_TRAIN_CORRECT_VAL_IN_TAX_FILTER_TGT_ONLY/satbind-epoch=02-val_loss=2.50_BACKUP.ckpt" # "/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_070425_CLIP-L-336_FINAL_SPLIT_LARGE/satbind-epoch=02-val_loss=2.48-BACKUP.ckpt"
|
103 |
+
TAXABIND_GAUSSIAN_BLUR_KERNEL = (5,5)
|
104 |
+
TAXABIND_SAMPLE_INDEX = 8 # DEBUG (Starting point) 5, 6, 8
|
105 |
+
|
106 |
+
# Sound
|
107 |
+
TAXABIND_SOUND_DATA_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/sound_test'
|
108 |
+
TAXABIND_SOUND_CHECKPOINT_PATH = "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_without_out_domain_taxs_v4_220625/soundbind-epoch=19-val_loss=3.92_BACKUP.ckpt"
|
109 |
+
|
110 |
+
# # Paths related to taxabind (TRAIN w/ TARGETS)
|
111 |
+
# TAXABIND_TTA = True
|
112 |
+
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
|
113 |
+
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
|
114 |
+
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json' # no filter needed
|
115 |
+
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json' # no filter needed
|
116 |
+
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
|
117 |
+
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/OLD/taxon_sat_target_search_100x_per_10-20counts.json"
|
118 |
+
# TAXABIND_PATCH_SIZE=14
|
119 |
+
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
|
120 |
+
# TAXABIND_SAMPLE_INDEX = 99 # (Starting point) 99,141
|
121 |
+
|
122 |
+
# # Paths related to taxabind (VAL)
|
123 |
+
# TAXABIND_TTA = True
|
124 |
+
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
|
125 |
+
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_test_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
|
126 |
+
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed
|
127 |
+
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_val.json' # no filter needed
|
128 |
+
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
|
129 |
+
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/taxon_sat_target_search_100x_per_10-20counts.json"
|
130 |
+
# TAXABIND_PATCH_SIZE=14
|
131 |
+
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
|
132 |
+
# TAXABIND_SAMPLE_INDEX = 45 # TEMP
|
133 |
+
|
134 |
+
#######################################################################
|
135 |
+
# Pretraining
|
136 |
+
#######################################################################
|
137 |
+
|
138 |
+
# TODO: Get rid of the LISA stuff...
|
139 |
+
# If LISA trained clss
|
140 |
+
GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_trained_clss"
|
141 |
+
MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_trained_clss_v3" # original_LISA, finetuned_LISA_v3_original_losses
|
142 |
+
TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_trained_clss" # If empty, then targets assumed to be on MASK_SET_DIR
|
143 |
+
RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-trained-clss.csv" # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv
|
144 |
+
|
145 |
+
# # If LISA out clss
|
146 |
+
# GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_out_clss"
|
147 |
+
# MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_out_clss_v3" # original_LISA, finetuned_LISA_v3_original_losses
|
148 |
+
# TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_out_clss" # If empty, then targets assumed to be on MASK_SET_DIR
|
149 |
+
# RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-out-clss.csv" # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv
|
150 |
+
|
151 |
+
#######################################################################
|
152 |
+
|
153 |
+
NUM_ROBOTS = 1
|
154 |
+
NUM_COORDS_WIDTH=24 # How many node coords across width?
|
155 |
+
NUM_COORDS_HEIGHT=24 # How many node coords across height?
|
156 |
+
HIGH_INFO_REWARD_RATIO = 0.75 # Ratio of rewards for moving to uncertain area (high info vs low info)
|
157 |
+
|
158 |
+
SENSOR_RANGE=80 # Only applicable to 'circle' sensor model
|
159 |
+
SENSOR_MODEL="rectangular" # "rectangular", "circle" (NOTE: no collision check for rectangular)
|
160 |
+
|
161 |
+
INPUT_DIM = 4
|
162 |
+
EMBEDDING_DIM = 128
|
163 |
+
K_SIZE = 8 # 8
|
164 |
+
|
165 |
+
USE_GPU = True # do you want to use GPUs?
|
166 |
+
NUM_GPU = getenv("NUM_GPU", default=2, cast_type=int) # the number of GPUs
|
167 |
+
NUM_META_AGENT = getenv("NUM_META_AGENT", default=4, cast_type=int) # the number of processes
|
168 |
+
FOLDER_NAME = 'inference'
|
169 |
+
model_path = f'{FOLDER_NAME}/model'
|
170 |
+
gifs_path = f'{FOLDER_NAME}/test_results/gifs'
|
171 |
+
trajectory_path = f'{FOLDER_NAME}/test_results/trajectory'
|
172 |
+
length_path = f'{FOLDER_NAME}/test_results/length'
|
173 |
+
log_path = f'{FOLDER_NAME}/test_results/log'
|
174 |
+
CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)
|
175 |
+
# trajectory_path = f'results/trajectory'
|
176 |
+
# length_path = f'results/length'
|
177 |
+
|
178 |
+
# COLORS (for printing)
|
179 |
+
RED='\033[1;31m'
|
180 |
+
GREEN='\033[1;32m'
|
181 |
+
YELLOW='\033[1;93m'
|
182 |
+
NC_BOLD='\033[1m' # Bold, No Color
|
183 |
+
NC='\033[0m' # No Color
|